• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

glass-dev / glass / 18496983159

14 Oct 2025 12:46PM UTC coverage: 94.471% (-0.9%) from 95.38%
18496983159

Pull #643

github

web-flow
Merge 772d55885 into 1d0b73949
Pull Request #643: gh-408: porting straightforward functions in `fields`

200 of 202 branches covered (99.01%)

Branch coverage included in aggregate %.

187 of 212 new or added lines in 5 files covered. (88.21%)

1 existing line in 1 file now uncovered.

1389 of 1480 relevant lines covered (93.85%)

7.49 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.74
/glass/_array_api_utils.py
1
from __future__ import annotations
8✔
2

3
from typing import TYPE_CHECKING, Any, TypeAlias
8✔
4

5
import array_api_strict
8✔
6
import numpy as np
8✔
7
import numpy.random
8✔
8

9
import glass.jax
8✔
10

11
if TYPE_CHECKING:
12
    from types import FunctionType, ModuleType
13

14
    from array_api_strict._array_object import Array as AArray
15
    from jaxtyping import Array as JAXArray
16
    from numpy.typing import DTypeLike, NDArray
17

18
    Size: TypeAlias = int | tuple[int, ...] | None
19
    GLASSAnyArray: TypeAlias = JAXArray | NDArray[Any]
20
    GLASSFloatArray: TypeAlias = JAXArray | NDArray[np.float64]
21
    GLASSComplexArray: TypeAlias = JAXArray | NDArray[np.complex128]
22

23
    GlassAnyArray: TypeAlias = NDArray[Any] | JAXArray
24
    GlassFloatArray: TypeAlias = NDArray[np.float64] | JAXArray
25

26

27
def get_namespace(*arrays: GlassAnyArray | None) -> ModuleType:
    """
    Return the array library (array namespace) of input arrays
    if they belong to the same library or raise a :class:`ValueError`
    if they do not.

    ``None`` entries are ignored, so optional array arguments can be passed
    straight through.

    Parameters
    ----------
    *arrays
        Arrays implementing ``__array_namespace__``, or ``None``.

    Returns
    -------
    ModuleType
        The namespace shared by all non-``None`` arrays.

    Raises
    ------
    ValueError
        If no non-``None`` array is given, or if the arrays belong to
        different array libraries.
    """
    # filter first so that a None in the leading position is skipped too
    present = [array for array in arrays if array is not None]
    if not present:
        msg = "at least one array is required to determine the array library"
        raise ValueError(msg)

    namespace = present[0].__array_namespace__()
    if any(array.__array_namespace__() != namespace for array in present[1:]):
        msg = "input arrays should belong to the same array library"
        raise ValueError(msg)

    return namespace
8✔
43

44

45
def rng_dispatcher(array: GlassAnyArray) -> UnifiedGenerator:
    """
    Dispatch RNG on the basis of the provided array.

    Parameters
    ----------
    array
        Array whose array namespace selects the generator implementation.

    Returns
    -------
    UnifiedGenerator
        A :class:`glass.jax.Generator` for ``jax.numpy``, a
        :class:`numpy.random.Generator` for ``numpy`` and this module's
        :class:`Generator` for ``array_api_strict``.

    Raises
    ------
    NotImplementedError
        If the array belongs to any other array library.
    """
    backend = array.__array_namespace__().__name__
    if backend == "jax.numpy":
        # NOTE(review): the JAX and array-api-strict generators use a fixed
        # seed (42) while the NumPy generator is created unseeded — confirm
        # this asymmetry is intended.
        return glass.jax.Generator(seed=42)
    if backend == "numpy":
        return np.random.default_rng()
    if backend == "array_api_strict":
        return Generator(seed=42)
    msg = f"the array backend {backend!r} is not supported"
    raise NotImplementedError(msg)
×
56

57

58
class Generator:
    """
    NumPy-backed random number generator that returns array_api_strict arrays.

    Every sampling method forwards its arguments to a wrapped
    :class:`numpy.random.Generator` and converts the NumPy result with
    :func:`array_api_strict.asarray`.
    """

    __slots__ = ("rng",)

    def __init__(
        self, seed: int | bool | NDArray[np.int_ | np.bool] | None = None
    ) -> None:
        # the wrapped NumPy generator that performs all of the sampling
        self.rng = numpy.random.default_rng(seed=seed)  # type: ignore[arg-type]

    @staticmethod
    def _to_strict(samples: Any) -> AArray:
        """Wrap a NumPy sampling result as an array_api_strict array."""
        return array_api_strict.asarray(samples)

    def random(
        self,
        size: Size = None,
        dtype: DTypeLike | None = np.float64,
        out: NDArray[Any] | None = None,
    ) -> AArray:
        """Return random floats in the half-open interval [0.0, 1.0)."""
        floats = self.rng.random(size=size, dtype=dtype, out=out)  # type: ignore[arg-type]
        return self._to_strict(floats)

    def normal(
        self,
        loc: float | NDArray[np.floating] = 0.0,
        scale: float | NDArray[np.floating] = 1.0,
        size: Size = None,
    ) -> AArray:
        """Draw samples from a Normal distribution (mean=loc, stdev=scale)."""
        draws = self.rng.normal(loc=loc, scale=scale, size=size)
        return self._to_strict(draws)

    def poisson(self, lam: float | NDArray[np.floating], size: Size = None) -> AArray:
        """Draw samples from a Poisson distribution with rate ``lam``."""
        counts = self.rng.poisson(lam=lam, size=size)
        return self._to_strict(counts)

    def standard_normal(
        self,
        size: Size = None,
        dtype: DTypeLike | None = np.float64,
        out: NDArray[Any] | None = None,
    ) -> AArray:
        """Draw samples from a standard Normal distribution (mean=0, stdev=1)."""
        draws = self.rng.standard_normal(size=size, dtype=dtype, out=out)  # type: ignore[arg-type]
        return self._to_strict(draws)

    def uniform(
        self,
        low: float | NDArray[np.floating] = 0.0,
        high: float | NDArray[np.floating] = 1.0,
        size: Size = None,
    ) -> AArray:
        """Draw samples from a Uniform distribution on [low, high)."""
        draws = self.rng.uniform(low=low, high=high, size=size)
        return self._to_strict(draws)
8✔
107

108

109
# Union of the generator types that ``rng_dispatcher`` can return: NumPy's
# ``numpy.random.Generator``, ``glass.jax.Generator`` and this module's
# array_api_strict ``Generator``. Unlike the aliases under TYPE_CHECKING,
# this union is built at import time.
UnifiedGenerator: TypeAlias = np.random.Generator | glass.jax.Generator | Generator
8✔
110

111

112
class GlassXPAdditions:
8✔
113
    """
114
    Additional functions missing from both array-api-strict and array-api-extra.
115

116
    This is intended as a temporary solution. See https://github.com/glass-dev/glass/issues/645
117
    for details.
118
    """
119

120
    xp: ModuleType
8✔
121
    backend: str
8✔
122

123
    def __init__(self, xp: ModuleType) -> None:
8✔
124
        self.xp = xp
8✔
125
        self.backend = xp.__name__
8✔
126

127
    def trapezoid(
8✔
128
        self, y: GlassAnyArray, x: GlassAnyArray = None, dx: float = 1.0, axis: int = -1
129
    ) -> GlassAnyArray:
130
        """
131
        Integrate along the given axis using the composite trapezoidal rule.
132

133
        See https://github.com/glass-dev/glass/issues/646
134
        """
135
        self.backend = self.xp.__name__
8✔
136
        if self.backend == "jax.numpy":
8✔
137
            return glass.jax.trapezoid(y, x=x, dx=dx, axis=axis)
8✔
138
        if self.backend == "numpy":
8✔
139
            return np.trapezoid(y, x=x, dx=dx, axis=axis)
8✔
140
        if self.backend == "array_api_strict":
8✔
141
            # Using design principle of scipy (i.e. copy, use np, copy back)
142
            y_np = np.asarray(y, copy=True)
8✔
143
            x_np = np.asarray(x, copy=True)
8✔
144
            result_np = np.trapezoid(y_np, x_np, dx=dx, axis=axis)
8✔
145
            return self.xp.asarray(result_np, copy=True)
8✔
146

NEW
147
        msg = "the array backend in not supported"
×
NEW
148
        raise NotImplementedError(msg)
×
149

150
    def union1d(self, ar1: GlassAnyArray, ar2: GlassAnyArray) -> GlassAnyArray:
8✔
151
        """
152
        Compute the set union of two 1D arrays.
153

154
        See https://github.com/glass-dev/glass/issues/647
155
        """
156
        if self.backend == "jax.numpy":
8✔
157
            return glass.jax.union1d(ar1, ar2)
8✔
158
        if self.backend == "numpy":
8✔
159
            return np.union1d(ar1, ar2)
8✔
160
        if self.backend == "array_api_strict":
8✔
161
            # Using design principle of scipy (i.e. copy, use np, copy back)
162
            ar1_np = np.asarray(ar1, copy=True)
8✔
163
            ar2_np = np.asarray(ar2, copy=True)
8✔
164
            result_np = np.union1d(ar1_np, ar2_np)
8✔
165
            return self.xp.asarray(result_np, copy=True)
8✔
166

NEW
167
        msg = "the array backend in not supported"
×
NEW
168
        raise NotImplementedError(msg)
×
169

170
    def interp(  # noqa: PLR0913
8✔
171
        self,
172
        x: GlassAnyArray,
173
        x_points: GlassAnyArray,
174
        y_points: GlassAnyArray,
175
        left: float | None = None,
176
        right: float | None = None,
177
        period: float | None = None,
178
    ) -> GlassAnyArray:
179
        """
180
        One-dimensional linear interpolation for monotonically increasing
181
        sample points.
182

183
        See https://github.com/glass-dev/glass/issues/650
184
        """
185
        if self.backend == "jax.numpy":
8✔
186
            return glass.jax.interp(
8✔
187
                x, x_points, y_points, left=left, right=right, period=period
188
            )
189
        if self.backend == "numpy":
8✔
190
            return np.interp(
8✔
191
                x, x_points, y_points, left=left, right=right, period=period
192
            )
193
        if self.backend == "array_api_strict":
8✔
194
            # Using design principle of scipy (i.e. copy, use np, copy back)
195
            x_np = np.asarray(x, copy=True)
8✔
196
            x_points_np = np.asarray(x_points, copy=True)
8✔
197
            y_points_np = np.asarray(y_points, copy=True)
8✔
198
            result_np = np.interp(
8✔
199
                x_np, x_points_np, y_points_np, left=left, right=right, period=period
200
            )
201
            return self.xp.asarray(result_np, copy=True)
8✔
202

NEW
203
        msg = "the array backend in not supported"
×
NEW
204
        raise NotImplementedError(msg)
×
205

206
    def gradient(self, f: GlassAnyArray) -> GlassAnyArray:
8✔
207
        """
208
        Return the gradient of an N-dimensional array.
209

210
        See https://github.com/glass-dev/glass/issues/648
211
        """
212
        if self.backend == "jax.numpy":
8✔
213
            return glass.jax.gradient(f)
8✔
214
        if self.backend == "numpy":
8✔
215
            return np.gradient(f)
8✔
216
        if self.backend == "array_api_strict":
8✔
217
            # Using design principle of scipy (i.e. copy, use np, copy back)
218
            f_np = np.asarray(f, copy=True)
8✔
219
            result_np = np.gradient(f_np)
8✔
220
            return self.xp.asarray(result_np, copy=True)
8✔
221

NEW
222
        msg = "the array backend in not supported"
×
NEW
223
        raise NotImplementedError(msg)
×
224

225
    def linalg_lstsq(
8✔
226
        self, a: GlassAnyArray, b: GlassAnyArray, rcond: float | None = None
227
    ) -> tuple[GlassAnyArray, GlassAnyArray, GlassAnyArray, GlassAnyArray]:
228
        """
229
        Return the gradient of an N-dimensional array.
230

231
        See https://github.com/glass-dev/glass/issues/649
232
        """
233
        if self.backend == "jax.numpy":
8✔
234
            return glass.jax.linalg_lstsq(a, b, rcond=rcond)
8✔
235
        if self.backend == "numpy":
8✔
236
            return np.linalg.lstsq(a, b, rcond=rcond)
8✔
237
        if self.backend == "array_api_strict":
8✔
238
            # Using design principle of scipy (i.e. copy, use np, copy back)
239
            a_np = np.asarray(a, copy=True)
8✔
240
            b_np = np.asarray(b, copy=True)
8✔
241
            result_np = np.linalg.lstsq(a_np, b_np, rcond=rcond)
8✔
242
            return tuple(self.xp.asarray(res, copy=True) for res in result_np)
8✔
243

NEW
244
        msg = "the array backend in not supported"
×
NEW
245
        raise NotImplementedError(msg)
×
246

247
    def einsum(self, subscripts: str, *operands: GlassAnyArray) -> GlassAnyArray:
8✔
248
        """
249
        Evaluates the Einstein summation convention on the operands.
250

251
        See https://github.com/glass-dev/glass/issues/657
252
        """
253
        if self.backend == "jax.numpy":
8✔
NEW
254
            return glass.jax.einsum(subscripts, *operands)
×
255
        if self.backend == "numpy":
8✔
256
            return np.einsum(subscripts, *operands)
8✔
257
        if self.backend == "array_api_strict":
8✔
258
            # Using design principle of scipy (i.e. copy, use np, copy back)
259
            operands_np = (np.asarray(op, copy=True) for op in operands)
8✔
260
            result_np = np.einsum(subscripts, *operands_np)
8✔
261
            return self.xp.asarray(result_np, copy=True)
8✔
262

NEW
263
        msg = "the array backend in not supported"
×
NEW
264
        raise NotImplementedError(msg)
×
265

266
    def apply_along_axis(
8✔
267
        self,
268
        func1d: FunctionType,
269
        axis: int,
270
        arr: GlassAnyArray,
271
        *args: object,
272
        **kwargs: object,
273
    ) -> GlassAnyArray:
274
        """
275
        Apply a function to 1-D slices along the given axis.
276

277
        See https://github.com/glass-dev/glass/issues/651
278
        """
279
        if self.backend == "jax.numpy":
8✔
280
            return glass.jax.apply_along_axis(func1d, axis, arr, *args, **kwargs)
8✔
281
        if self.backend == "numpy":
8✔
282
            return np.apply_along_axis(func1d, axis, arr, *args, **kwargs)
8✔
283
        if self.backend == "array_api_strict":
8✔
284
            # Using design principle of scipy (i.e. copy, use np, copy back)
285
            arr_np = np.asarray(arr, copy=True)
8✔
286
            result_np = np.apply_along_axis(func1d, axis, arr_np, *args, **kwargs)
8✔
287
            return self.xp.asarray(result_np, copy=True)
8✔
288

NEW
289
        msg = "the array backend in not supported"
×
NEW
290
        raise NotImplementedError(msg)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc