• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

glass-dev / glass / 18650097072

20 Oct 2025 11:08AM UTC coverage: 94.467%. Remained the same
18650097072

Pull #684

github

web-flow
Merge 1771cf649 into c5b0599c8
Pull Request #684: gh-682: create an `isort` section for Array API

200 of 202 branches covered (99.01%)

Branch coverage included in aggregate %.

2 of 2 new or added lines in 2 files covered. (100.0%)

3 existing lines in 1 file now uncovered.

1405 of 1497 relevant lines covered (93.85%)

7.49 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.54
/glass/_array_api_utils.py
1
from __future__ import annotations
8✔
2

3
from typing import TYPE_CHECKING, Any, TypeAlias
8✔
4

5
import numpy as np
8✔
6
import numpy.random
8✔
7

8
import array_api_strict
8✔
9

10
if TYPE_CHECKING:
    # Everything in this block is only seen by static type checkers. The
    # module starts with ``from __future__ import annotations``, so the
    # annotations that use these names are never evaluated at runtime.
    from types import FunctionType, ModuleType

    from jaxtyping import Array as JAXArray
    from numpy.typing import DTypeLike, NDArray

    from array_api_strict._array_object import Array as AArray

    import glass.jax

    # Shape argument accepted by the ``Generator`` sampling methods.
    Size: TypeAlias = int | tuple[int, ...] | None

    # Array from either the NumPy or the JAX backend.
    # NOTE(review): AArray (array_api_strict) is not part of this union even
    # though XPAdditions dispatches on that backend — confirm intended.
    AnyArray: TypeAlias = NDArray[Any] | JAXArray
    ComplexArray: TypeAlias = JAXArray | NDArray[np.complex128]
    DoubleArray: TypeAlias = NDArray[np.double] | JAXArray
    FloatArray: TypeAlias = NDArray[np.float64] | JAXArray
26

27

28
def get_namespace(*arrays: AnyArray) -> ModuleType:
    """
    Return the array library (array namespace) of input arrays
    if they belong to the same library or raise a :class:`ValueError`
    if they do not.

    ``None`` entries are ignored wherever they appear, including in the
    first position.

    Parameters
    ----------
    *arrays
        Arrays (or ``None``) whose common namespace is requested.

    Returns
    -------
    ModuleType
        The object returned by ``__array_namespace__()`` of the arrays.

    Raises
    ------
    ValueError
        If no non-``None`` array is given, or if the arrays belong to
        different array libraries.
    """
    present = [array for array in arrays if array is not None]
    if not present:
        # Previously this surfaced as an IndexError/AttributeError.
        msg = "at least one array is required to determine the array namespace"
        raise ValueError(msg)

    namespace = present[0].__array_namespace__()
    if any(array.__array_namespace__() != namespace for array in present[1:]):
        msg = "input arrays should belong to the same array library"
        raise ValueError(msg)

    return namespace
44

45

46
def rng_dispatcher(
    array: NDArray[Any] | JAXArray,
) -> np.random.Generator | glass.jax.Generator | Generator:
    """
    Dispatch RNG on the basis of the provided array.

    The generator is selected from the name of the array's namespace
    (``array.__array_namespace__().__name__``):

    * ``"numpy"`` -> :func:`numpy.random.default_rng` (no seed);
    * ``"array_api_strict"`` -> :class:`Generator` with ``seed=42``;
    * ``"jax.numpy"`` -> ``glass.jax.Generator`` with ``seed=42``.

    Raises
    ------
    NotImplementedError
        If the array belongs to any other backend.
    """
    backend_name = array.__array_namespace__().__name__

    if backend_name == "numpy":
        return np.random.default_rng()

    if backend_name == "array_api_strict":
        return Generator(seed=42)

    if backend_name == "jax.numpy":
        # Imported lazily so JAX is only required when JAX arrays are used.
        import glass.jax  # noqa: PLC0415

        return glass.jax.Generator(seed=42)

    msg = "the array backend in not supported"
    raise NotImplementedError(msg)
61

62

63
class Generator:
    """
    NumPy random number generator returning array_api_strict Array.

    Thin wrapper around :func:`numpy.random.default_rng`: each sampling method
    delegates to the wrapped NumPy generator (``self.rng``) and converts the
    result with :func:`array_api_strict.asarray`.
    """

    __slots__ = ("rng",)

    def __init__(
        self,
        seed: int | bool | NDArray[np.int_ | np.bool] | None = None,  # noqa: FBT001
    ) -> None:
        """
        Create the wrapped NumPy generator.

        Parameters
        ----------
        seed
            Forwarded unchanged to :func:`numpy.random.default_rng`.
        """
        self.rng = numpy.random.default_rng(seed=seed)  # type: ignore[arg-type]

    def random(
        self,
        size: Size = None,
        dtype: DTypeLike | None = np.float64,
        out: NDArray[Any] | None = None,
    ) -> AArray:
        """Return random floats in the half-open interval [0.0, 1.0)."""
        return array_api_strict.asarray(self.rng.random(size, dtype, out))  # type: ignore[arg-type]

    def normal(
        self,
        loc: float | NDArray[np.floating] = 0.0,
        scale: float | NDArray[np.floating] = 1.0,
        size: Size = None,
    ) -> AArray:
        """Draw samples from a Normal distribution (mean=loc, stdev=scale)."""
        return array_api_strict.asarray(self.rng.normal(loc, scale, size))

    def poisson(
        self,
        lam: float | NDArray[np.floating] = 1.0,
        size: Size = None,
    ) -> AArray:
        """
        Draw samples from a Poisson distribution.

        *lam* defaults to ``1.0``, the same default as
        :meth:`numpy.random.Generator.poisson`. Calls that pass *lam*
        explicitly behave exactly as before.
        """
        return array_api_strict.asarray(self.rng.poisson(lam, size))

    def standard_normal(
        self,
        size: Size = None,
        dtype: DTypeLike | None = np.float64,
        out: NDArray[Any] | None = None,
    ) -> AArray:
        """Draw samples from a standard Normal distribution (mean=0, stdev=1)."""
        return array_api_strict.asarray(self.rng.standard_normal(size, dtype, out))  # type: ignore[arg-type]

    def uniform(
        self,
        low: float | NDArray[np.floating] = 0.0,
        high: float | NDArray[np.floating] = 1.0,
        size: Size = None,
    ) -> AArray:
        """Draw samples from a Uniform distribution over [low, high)."""
        return array_api_strict.asarray(self.rng.uniform(low, high, size))
113

114

115
class XPAdditions:
8✔
116
    """
117
    Additional functions missing from both array-api-strict and array-api-extra.
118

119
    This is intended as a temporary solution. See https://github.com/glass-dev/glass/issues/645
120
    for details.
121
    """
122

123
    xp: ModuleType
8✔
124
    backend: str
8✔
125

126
    def __init__(self, xp: ModuleType) -> None:
8✔
127
        self.xp = xp
8✔
128
        self.backend = xp.__name__
8✔
129

130
    def trapezoid(
8✔
131
        self, y: AnyArray, x: AnyArray = None, dx: float = 1.0, axis: int = -1
132
    ) -> AnyArray:
133
        """
134
        Integrate along the given axis using the composite trapezoidal rule.
135

136
        See https://github.com/glass-dev/glass/issues/646
137
        """
138
        self.backend = self.xp.__name__
8✔
139
        if self.backend == "jax.numpy":
8✔
140
            import glass.jax  # noqa: PLC0415
8✔
141

142
            return glass.jax.trapezoid(y, x=x, dx=dx, axis=axis)
8✔
143
        if self.backend == "numpy":
8✔
144
            return np.trapezoid(y, x=x, dx=dx, axis=axis)
8✔
145
        if self.backend == "array_api_strict":
8✔
146
            # Using design principle of scipy (i.e. copy, use np, copy back)
147
            y_np = np.asarray(y, copy=True)
8✔
148
            x_np = np.asarray(x, copy=True)
8✔
149
            result_np = np.trapezoid(y_np, x_np, dx=dx, axis=axis)
8✔
150
            return self.xp.asarray(result_np, copy=True)
8✔
151

152
        msg = "the array backend in not supported"
×
UNCOV
153
        raise NotImplementedError(msg)
×
154

155
    def union1d(self, ar1: AnyArray, ar2: AnyArray) -> AnyArray:
8✔
156
        """
157
        Compute the set union of two 1D arrays.
158

159
        See https://github.com/glass-dev/glass/issues/647
160
        """
161
        if self.backend == "jax.numpy":
8✔
162
            import glass.jax  # noqa: PLC0415
8✔
163

164
            return glass.jax.union1d(ar1, ar2)
8✔
165
        if self.backend == "numpy":
8✔
166
            return np.union1d(ar1, ar2)
8✔
167
        if self.backend == "array_api_strict":
8✔
168
            # Using design principle of scipy (i.e. copy, use np, copy back)
169
            ar1_np = np.asarray(ar1, copy=True)
8✔
170
            ar2_np = np.asarray(ar2, copy=True)
8✔
171
            result_np = np.union1d(ar1_np, ar2_np)
8✔
172
            return self.xp.asarray(result_np, copy=True)
8✔
173

174
        msg = "the array backend in not supported"
×
UNCOV
175
        raise NotImplementedError(msg)
×
176

177
    def interp(  # noqa: PLR0913
8✔
178
        self,
179
        x: AnyArray,
180
        x_points: AnyArray,
181
        y_points: AnyArray,
182
        left: float | None = None,
183
        right: float | None = None,
184
        period: float | None = None,
185
    ) -> AnyArray:
186
        """
187
        One-dimensional linear interpolation for monotonically increasing
188
        sample points.
189

190
        See https://github.com/glass-dev/glass/issues/650
191
        """
192
        if self.backend == "jax.numpy":
8✔
193
            import glass.jax  # noqa: PLC0415
8✔
194

195
            return glass.jax.interp(
8✔
196
                x, x_points, y_points, left=left, right=right, period=period
197
            )
198
        if self.backend == "numpy":
8✔
199
            return np.interp(
8✔
200
                x, x_points, y_points, left=left, right=right, period=period
201
            )
202
        if self.backend == "array_api_strict":
8✔
203
            # Using design principle of scipy (i.e. copy, use np, copy back)
204
            x_np = np.asarray(x, copy=True)
8✔
205
            x_points_np = np.asarray(x_points, copy=True)
8✔
206
            y_points_np = np.asarray(y_points, copy=True)
8✔
207
            result_np = np.interp(
8✔
208
                x_np, x_points_np, y_points_np, left=left, right=right, period=period
209
            )
210
            return self.xp.asarray(result_np, copy=True)
8✔
211

212
        msg = "the array backend in not supported"
×
213
        raise NotImplementedError(msg)
×
214

215
    def gradient(self, f: AnyArray) -> AnyArray:
8✔
216
        """
217
        Return the gradient of an N-dimensional array.
218

219
        See https://github.com/glass-dev/glass/issues/648
220
        """
221
        if self.backend == "jax.numpy":
8✔
222
            import glass.jax  # noqa: PLC0415
8✔
223

224
            return glass.jax.gradient(f)
8✔
225
        if self.backend == "numpy":
8✔
226
            return np.gradient(f)
8✔
227
        if self.backend == "array_api_strict":
8✔
228
            # Using design principle of scipy (i.e. copy, use np, copy back)
229
            f_np = np.asarray(f, copy=True)
8✔
230
            result_np = np.gradient(f_np)
8✔
231
            return self.xp.asarray(result_np, copy=True)
8✔
232

233
        msg = "the array backend in not supported"
×
234
        raise NotImplementedError(msg)
×
235

236
    def linalg_lstsq(
8✔
237
        self, a: AnyArray, b: AnyArray, rcond: float | None = None
238
    ) -> tuple[AnyArray, AnyArray, AnyArray, AnyArray]:
239
        """
240
        Return the gradient of an N-dimensional array.
241

242
        See https://github.com/glass-dev/glass/issues/649
243
        """
244
        if self.backend == "jax.numpy":
8✔
245
            import glass.jax  # noqa: PLC0415
8✔
246

247
            return glass.jax.linalg_lstsq(a, b, rcond=rcond)
8✔
248
        if self.backend == "numpy":
8✔
249
            return np.linalg.lstsq(a, b, rcond=rcond)
8✔
250
        if self.backend == "array_api_strict":
8✔
251
            # Using design principle of scipy (i.e. copy, use np, copy back)
252
            a_np = np.asarray(a, copy=True)
8✔
253
            b_np = np.asarray(b, copy=True)
8✔
254
            result_np = np.linalg.lstsq(a_np, b_np, rcond=rcond)
8✔
255
            return tuple(self.xp.asarray(res, copy=True) for res in result_np)
8✔
256

257
        msg = "the array backend in not supported"
×
258
        raise NotImplementedError(msg)
×
259

260
    def einsum(self, subscripts: str, *operands: AnyArray) -> AnyArray:
8✔
261
        """
262
        Evaluates the Einstein summation convention on the operands.
263

264
        See https://github.com/glass-dev/glass/issues/657
265
        """
266
        if self.backend == "jax.numpy":
8✔
267
            import glass.jax  # noqa: PLC0415
×
268

269
            return glass.jax.einsum(subscripts, *operands)
×
270
        if self.backend == "numpy":
8✔
271
            return np.einsum(subscripts, *operands)
8✔
272
        if self.backend == "array_api_strict":
8✔
273
            # Using design principle of scipy (i.e. copy, use np, copy back)
274
            operands_np = (np.asarray(op, copy=True) for op in operands)
8✔
275
            result_np = np.einsum(subscripts, *operands_np)
8✔
276
            return self.xp.asarray(result_np, copy=True)
8✔
277

278
        msg = "the array backend in not supported"
×
279
        raise NotImplementedError(msg)
×
280

281
    def apply_along_axis(
8✔
282
        self,
283
        func1d: FunctionType,
284
        axis: int,
285
        arr: AnyArray,
286
        *args: object,
287
        **kwargs: object,
288
    ) -> AnyArray:
289
        """
290
        Apply a function to 1-D slices along the given axis.
291

292
        See https://github.com/glass-dev/glass/issues/651
293
        """
294
        if self.backend == "jax.numpy":
8✔
295
            import glass.jax  # noqa: PLC0415
8✔
296

297
            return glass.jax.apply_along_axis(func1d, axis, arr, *args, **kwargs)
8✔
298
        if self.backend == "numpy":
8✔
299
            return np.apply_along_axis(func1d, axis, arr, *args, **kwargs)
8✔
300
        if self.backend == "array_api_strict":
8✔
301
            # Using design principle of scipy (i.e. copy, use np, copy back)
302
            arr_np = np.asarray(arr, copy=True)
8✔
303
            result_np = np.apply_along_axis(func1d, axis, arr_np, *args, **kwargs)
8✔
304
            return self.xp.asarray(result_np, copy=True)
8✔
305

306
        msg = "the array backend in not supported"
×
307
        raise NotImplementedError(msg)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc