• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

glass-dev / glass / 18588513829

17 Oct 2025 09:26AM UTC coverage: 94.431% (-0.04%) from 94.471%
18588513829

push

github

web-flow
gh-667: Make JAX an optional dependency as intended (#672)

200 of 202 branches covered (99.01%)

Branch coverage included in aggregate %.

21 of 22 new or added lines in 3 files covered. (95.45%)

1 existing line in 1 file now uncovered.

1394 of 1486 relevant lines covered (93.81%)

7.49 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.54
/glass/_array_api_utils.py
1
from __future__ import annotations
8✔
2

3
from typing import TYPE_CHECKING, Any, TypeAlias
8✔
4

5
import array_api_strict
8✔
6
import numpy as np
8✔
7
import numpy.random
8✔
8

9
if TYPE_CHECKING:
10
    from types import FunctionType, ModuleType
11

12
    from array_api_strict._array_object import Array as AArray
13
    from jaxtyping import Array as JAXArray
14
    from numpy.typing import DTypeLike, NDArray
15

16
    import glass.jax
17

18
    Size: TypeAlias = int | tuple[int, ...] | None
19
    GLASSAnyArray: TypeAlias = JAXArray | NDArray[Any]
20
    GLASSFloatArray: TypeAlias = JAXArray | NDArray[np.float64]
21
    GLASSComplexArray: TypeAlias = JAXArray | NDArray[np.complex128]
22

23
    AnyArray: TypeAlias = NDArray[Any] | JAXArray
24
    FloatArray: TypeAlias = NDArray[np.float64] | JAXArray
25

26

27
def get_namespace(*arrays: AnyArray) -> ModuleType:
8✔
28
    """
29
    Return the array library (array namespace) of input arrays
30
    if they belong to the same library or raise a :class:`ValueError`
31
    if they do not.
32
    """
33
    namespace = arrays[0].__array_namespace__()
8✔
34
    if any(
8✔
35
        array.__array_namespace__() != namespace
36
        for array in arrays
37
        if array is not None
38
    ):
39
        msg = "input arrays should belong to the same array library"
8✔
40
        raise ValueError(msg)
8✔
41

42
    return namespace
8✔
43

44

45
def rng_dispatcher(
8✔
46
    array: NDArray[Any] | JAXArray,
47
) -> np.random.Generator | glass.jax.Generator | Generator:
48
    """Dispatch RNG on the basis of the provided array."""
49
    backend = array.__array_namespace__().__name__
8✔
50
    if backend == "jax.numpy":
8✔
51
        import glass.jax  # noqa: PLC0415
8✔
52

53
        return glass.jax.Generator(seed=42)
8✔
54
    if backend == "numpy":
8✔
55
        return np.random.default_rng()
8✔
56
    if backend == "array_api_strict":
8✔
57
        return Generator(seed=42)
8✔
58
    msg = "the array backend in not supported"
×
59
    raise NotImplementedError(msg)
×
60

61

62
class Generator:
    """
    Random number generator backed by NumPy that returns array_api_strict arrays.

    Every sampling method forwards its arguments to an internal
    :class:`numpy.random.Generator` and wraps the draw with
    :func:`array_api_strict.asarray`.
    """

    __slots__ = ("rng",)

    def __init__(
        self,
        seed: int | bool | NDArray[np.int_ | np.bool] | None = None,  # noqa: FBT001
    ) -> None:
        # The wrapped NumPy generator; ``seed`` is handed over unchanged.
        self.rng = np.random.default_rng(seed)  # type: ignore[arg-type]

    def random(
        self,
        size: Size = None,
        dtype: DTypeLike | None = np.float64,
        out: NDArray[Any] | None = None,
    ) -> AArray:
        """Sample floats uniformly from the half-open interval [0.0, 1.0)."""
        draws = self.rng.random(size=size, dtype=dtype, out=out)  # type: ignore[arg-type]
        return array_api_strict.asarray(draws)

    def normal(
        self,
        loc: float | NDArray[np.floating] = 0.0,
        scale: float | NDArray[np.floating] = 1.0,
        size: Size = None,
    ) -> AArray:
        """Sample a normal distribution with mean ``loc`` and stdev ``scale``."""
        draws = self.rng.normal(loc=loc, scale=scale, size=size)
        return array_api_strict.asarray(draws)

    def poisson(self, lam: float | NDArray[np.floating], size: Size = None) -> AArray:
        """Sample a Poisson distribution with expectation ``lam``."""
        draws = self.rng.poisson(lam=lam, size=size)
        return array_api_strict.asarray(draws)

    def standard_normal(
        self,
        size: Size = None,
        dtype: DTypeLike | None = np.float64,
        out: NDArray[Any] | None = None,
    ) -> AArray:
        """Sample the standard normal distribution (mean 0, stdev 1)."""
        draws = self.rng.standard_normal(size=size, dtype=dtype, out=out)  # type: ignore[arg-type]
        return array_api_strict.asarray(draws)

    def uniform(
        self,
        low: float | NDArray[np.floating] = 0.0,
        high: float | NDArray[np.floating] = 1.0,
        size: Size = None,
    ) -> AArray:
        """Sample uniformly from the half-open interval [``low``, ``high``)."""
        draws = self.rng.uniform(low=low, high=high, size=size)
        return array_api_strict.asarray(draws)
8✔
112

113

114
class XPAdditions:
8✔
115
    """
116
    Additional functions missing from both array-api-strict and array-api-extra.
117

118
    This is intended as a temporary solution. See https://github.com/glass-dev/glass/issues/645
119
    for details.
120
    """
121

122
    xp: ModuleType
8✔
123
    backend: str
8✔
124

125
    def __init__(self, xp: ModuleType) -> None:
8✔
126
        self.xp = xp
8✔
127
        self.backend = xp.__name__
8✔
128

129
    def trapezoid(
8✔
130
        self, y: AnyArray, x: AnyArray = None, dx: float = 1.0, axis: int = -1
131
    ) -> AnyArray:
132
        """
133
        Integrate along the given axis using the composite trapezoidal rule.
134

135
        See https://github.com/glass-dev/glass/issues/646
136
        """
137
        self.backend = self.xp.__name__
8✔
138
        if self.backend == "jax.numpy":
8✔
139
            import glass.jax  # noqa: PLC0415
8✔
140

141
            return glass.jax.trapezoid(y, x=x, dx=dx, axis=axis)
8✔
142
        if self.backend == "numpy":
8✔
143
            return np.trapezoid(y, x=x, dx=dx, axis=axis)
8✔
144
        if self.backend == "array_api_strict":
8✔
145
            # Using design principle of scipy (i.e. copy, use np, copy back)
146
            y_np = np.asarray(y, copy=True)
8✔
147
            x_np = np.asarray(x, copy=True)
8✔
148
            result_np = np.trapezoid(y_np, x_np, dx=dx, axis=axis)
8✔
149
            return self.xp.asarray(result_np, copy=True)
8✔
150

151
        msg = "the array backend in not supported"
×
152
        raise NotImplementedError(msg)
×
153

154
    def union1d(self, ar1: AnyArray, ar2: AnyArray) -> AnyArray:
8✔
155
        """
156
        Compute the set union of two 1D arrays.
157

158
        See https://github.com/glass-dev/glass/issues/647
159
        """
160
        if self.backend == "jax.numpy":
8✔
161
            import glass.jax  # noqa: PLC0415
8✔
162

163
            return glass.jax.union1d(ar1, ar2)
8✔
164
        if self.backend == "numpy":
8✔
165
            return np.union1d(ar1, ar2)
8✔
166
        if self.backend == "array_api_strict":
8✔
167
            # Using design principle of scipy (i.e. copy, use np, copy back)
168
            ar1_np = np.asarray(ar1, copy=True)
8✔
169
            ar2_np = np.asarray(ar2, copy=True)
8✔
170
            result_np = np.union1d(ar1_np, ar2_np)
8✔
171
            return self.xp.asarray(result_np, copy=True)
8✔
172

173
        msg = "the array backend in not supported"
×
174
        raise NotImplementedError(msg)
×
175

176
    def interp(  # noqa: PLR0913
8✔
177
        self,
178
        x: AnyArray,
179
        x_points: AnyArray,
180
        y_points: AnyArray,
181
        left: float | None = None,
182
        right: float | None = None,
183
        period: float | None = None,
184
    ) -> AnyArray:
185
        """
186
        One-dimensional linear interpolation for monotonically increasing
187
        sample points.
188

189
        See https://github.com/glass-dev/glass/issues/650
190
        """
191
        if self.backend == "jax.numpy":
8✔
192
            import glass.jax  # noqa: PLC0415
8✔
193

194
            return glass.jax.interp(
8✔
195
                x, x_points, y_points, left=left, right=right, period=period
196
            )
197
        if self.backend == "numpy":
8✔
198
            return np.interp(
8✔
199
                x, x_points, y_points, left=left, right=right, period=period
200
            )
201
        if self.backend == "array_api_strict":
8✔
202
            # Using design principle of scipy (i.e. copy, use np, copy back)
203
            x_np = np.asarray(x, copy=True)
8✔
204
            x_points_np = np.asarray(x_points, copy=True)
8✔
205
            y_points_np = np.asarray(y_points, copy=True)
8✔
206
            result_np = np.interp(
8✔
207
                x_np, x_points_np, y_points_np, left=left, right=right, period=period
208
            )
209
            return self.xp.asarray(result_np, copy=True)
8✔
210

211
        msg = "the array backend in not supported"
×
212
        raise NotImplementedError(msg)
×
213

214
    def gradient(self, f: AnyArray) -> AnyArray:
8✔
215
        """
216
        Return the gradient of an N-dimensional array.
217

218
        See https://github.com/glass-dev/glass/issues/648
219
        """
220
        if self.backend == "jax.numpy":
8✔
221
            import glass.jax  # noqa: PLC0415
8✔
222

223
            return glass.jax.gradient(f)
8✔
224
        if self.backend == "numpy":
8✔
225
            return np.gradient(f)
8✔
226
        if self.backend == "array_api_strict":
8✔
227
            # Using design principle of scipy (i.e. copy, use np, copy back)
228
            f_np = np.asarray(f, copy=True)
8✔
229
            result_np = np.gradient(f_np)
8✔
230
            return self.xp.asarray(result_np, copy=True)
8✔
231

232
        msg = "the array backend in not supported"
×
233
        raise NotImplementedError(msg)
×
234

235
    def linalg_lstsq(
8✔
236
        self, a: AnyArray, b: AnyArray, rcond: float | None = None
237
    ) -> tuple[AnyArray, AnyArray, AnyArray, AnyArray]:
238
        """
239
        Return the gradient of an N-dimensional array.
240

241
        See https://github.com/glass-dev/glass/issues/649
242
        """
243
        if self.backend == "jax.numpy":
8✔
244
            import glass.jax  # noqa: PLC0415
8✔
245

246
            return glass.jax.linalg_lstsq(a, b, rcond=rcond)
8✔
247
        if self.backend == "numpy":
8✔
248
            return np.linalg.lstsq(a, b, rcond=rcond)
8✔
249
        if self.backend == "array_api_strict":
8✔
250
            # Using design principle of scipy (i.e. copy, use np, copy back)
251
            a_np = np.asarray(a, copy=True)
8✔
252
            b_np = np.asarray(b, copy=True)
8✔
253
            result_np = np.linalg.lstsq(a_np, b_np, rcond=rcond)
8✔
254
            return tuple(self.xp.asarray(res, copy=True) for res in result_np)
8✔
255

256
        msg = "the array backend in not supported"
×
257
        raise NotImplementedError(msg)
×
258

259
    def einsum(self, subscripts: str, *operands: AnyArray) -> AnyArray:
8✔
260
        """
261
        Evaluates the Einstein summation convention on the operands.
262

263
        See https://github.com/glass-dev/glass/issues/657
264
        """
265
        if self.backend == "jax.numpy":
8✔
NEW
266
            import glass.jax  # noqa: PLC0415
×
267

UNCOV
268
            return glass.jax.einsum(subscripts, *operands)
×
269
        if self.backend == "numpy":
8✔
270
            return np.einsum(subscripts, *operands)
8✔
271
        if self.backend == "array_api_strict":
8✔
272
            # Using design principle of scipy (i.e. copy, use np, copy back)
273
            operands_np = (np.asarray(op, copy=True) for op in operands)
8✔
274
            result_np = np.einsum(subscripts, *operands_np)
8✔
275
            return self.xp.asarray(result_np, copy=True)
8✔
276

277
        msg = "the array backend in not supported"
×
278
        raise NotImplementedError(msg)
×
279

280
    def apply_along_axis(
8✔
281
        self,
282
        func1d: FunctionType,
283
        axis: int,
284
        arr: AnyArray,
285
        *args: object,
286
        **kwargs: object,
287
    ) -> AnyArray:
288
        """
289
        Apply a function to 1-D slices along the given axis.
290

291
        See https://github.com/glass-dev/glass/issues/651
292
        """
293
        if self.backend == "jax.numpy":
8✔
294
            import glass.jax  # noqa: PLC0415
8✔
295

296
            return glass.jax.apply_along_axis(func1d, axis, arr, *args, **kwargs)
8✔
297
        if self.backend == "numpy":
8✔
298
            return np.apply_along_axis(func1d, axis, arr, *args, **kwargs)
8✔
299
        if self.backend == "array_api_strict":
8✔
300
            # Using design principle of scipy (i.e. copy, use np, copy back)
301
            arr_np = np.asarray(arr, copy=True)
8✔
302
            result_np = np.apply_along_axis(func1d, axis, arr_np, *args, **kwargs)
8✔
303
            return self.xp.asarray(result_np, copy=True)
8✔
304

305
        msg = "the array backend in not supported"
×
306
        raise NotImplementedError(msg)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc