• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

glass-dev / glass / 18493491435

14 Oct 2025 10:25AM UTC coverage: 94.424%. First build
18493491435

Pull #643

github

web-flow
Merge 090fb00ed into 6f6ee4c58
Pull Request #643: gh-408: porting straightforward functions in `fields`

200 of 202 branches covered (99.01%)

Branch coverage included in aggregate %.

187 of 212 new or added lines in 5 files covered. (88.21%)

1375 of 1466 relevant lines covered (93.79%)

7.49 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.74
/glass/_array_api_utils.py
1
from __future__ import annotations
8✔
2

3
from typing import TYPE_CHECKING, Any, TypeAlias
8✔
4

5
import array_api_strict
8✔
6
import numpy as np
8✔
7
import numpy.random
8✔
8

9
import glass.jax
8✔
10

11
if TYPE_CHECKING:
12
    from types import FunctionType, ModuleType
13

14
    from array_api_strict._array_object import Array as AArray
15
    from jaxtyping import Array as JAXArray
16
    from numpy.typing import DTypeLike, NDArray
17

18
    Size: TypeAlias = int | tuple[int, ...] | None
19

20
    GlassAnyArray: TypeAlias = NDArray[Any] | JAXArray
21
    GlassFloatArray: TypeAlias = NDArray[np.float64] | JAXArray
22

23

24
def get_namespace(*arrays: GlassAnyArray) -> ModuleType:
    """
    Return the array library (array namespace) of input arrays
    if they belong to the same library or raise a :class:`ValueError`
    if they do not.

    ``None`` entries are ignored wherever they appear, including in the
    first position.

    Parameters
    ----------
    *arrays
        Objects exposing ``__array_namespace__()``; ``None`` entries are
        skipped.

    Returns
    -------
        The namespace reported by the first non-``None`` array.

    Raises
    ------
    ValueError
        If no non-``None`` array is given, or if the arrays report
        different namespaces.

    """
    # Drop the optional (``None``) inputs up front so that a leading ``None``
    # cannot be dereferenced when picking the reference namespace.
    present = [array for array in arrays if array is not None]
    if not present:
        msg = "at least one array is required to determine the array library"
        raise ValueError(msg)

    namespace = present[0].__array_namespace__()
    if any(array.__array_namespace__() != namespace for array in present[1:]):
        msg = "input arrays should belong to the same array library"
        raise ValueError(msg)

    return namespace
8✔
40

41

42
def rng_dispatcher(array: NDArray[Any] | JAXArray) -> UnifiedGenerator:
    """
    Dispatch RNG on the basis of the provided array.

    The backend is identified by the ``__name__`` of the namespace returned
    by ``array.__array_namespace__()``.

    Parameters
    ----------
    array
        Array whose library decides which generator is created.

    Returns
    -------
        ``glass.jax.Generator(seed=42)`` for ``jax.numpy``,
        ``np.random.default_rng()`` for ``numpy`` and ``Generator(seed=42)``
        for ``array_api_strict``.

    Raises
    ------
    NotImplementedError
        If the array belongs to any other library.

    """
    backend = array.__array_namespace__().__name__
    if backend == "jax.numpy":
        return glass.jax.Generator(seed=42)
    if backend == "numpy":
        # NOTE(review): this generator is unseeded while the other two
        # backends use a fixed seed of 42 -- confirm this is intentional.
        return np.random.default_rng()
    if backend == "array_api_strict":
        return Generator(seed=42)
    msg = "the array backend is not supported"
    raise NotImplementedError(msg)
×
53

54

55
class Generator:
    """
    Thin wrapper around a NumPy random generator.

    Every sampling method draws from the wrapped NumPy generator and passes
    the result through ``array_api_strict.asarray``.
    """

    __slots__ = ("rng",)

    def __init__(
        self, seed: int | bool | NDArray[np.int_ | np.bool] | None = None
    ) -> None:
        """Create the wrapped NumPy generator from ``seed``."""
        self.rng = numpy.random.default_rng(seed=seed)  # type: ignore[arg-type]

    def random(
        self,
        size: Size = None,
        dtype: DTypeLike | None = np.float64,
        out: NDArray[Any] | None = None,
    ) -> AArray:
        """Sample floats from the half-open interval [0.0, 1.0)."""
        samples = self.rng.random(size=size, dtype=dtype, out=out)  # type: ignore[arg-type]
        return array_api_strict.asarray(samples)

    def normal(
        self,
        loc: float | NDArray[np.floating] = 0.0,
        scale: float | NDArray[np.floating] = 1.0,
        size: Size = None,
    ) -> AArray:
        """Sample from a Normal distribution with mean ``loc`` and stdev ``scale``."""
        samples = self.rng.normal(loc=loc, scale=scale, size=size)
        return array_api_strict.asarray(samples)

    def poisson(self, lam: float | NDArray[np.floating], size: Size = None) -> AArray:
        """Sample from a Poisson distribution with rate ``lam``."""
        samples = self.rng.poisson(lam=lam, size=size)
        return array_api_strict.asarray(samples)

    def standard_normal(
        self,
        size: Size = None,
        dtype: DTypeLike | None = np.float64,
        out: NDArray[Any] | None = None,
    ) -> AArray:
        """Sample from the standard Normal distribution (mean=0, stdev=1)."""
        samples = self.rng.standard_normal(size=size, dtype=dtype, out=out)  # type: ignore[arg-type]
        return array_api_strict.asarray(samples)

    def uniform(
        self,
        low: float | NDArray[np.floating] = 0.0,
        high: float | NDArray[np.floating] = 1.0,
        size: Size = None,
    ) -> AArray:
        """Sample from a Uniform distribution between ``low`` and ``high``."""
        samples = self.rng.uniform(low=low, high=high, size=size)
        return array_api_strict.asarray(samples)
8✔
104

105

106
# Union of the generator types that ``rng_dispatcher`` returns: the NumPy
# generator, ``glass.jax.Generator`` and the local ``Generator`` wrapper.
UnifiedGenerator: TypeAlias = np.random.Generator | glass.jax.Generator | Generator
8✔
107

108

109
class GlassXPAdditions:
8✔
110
    """
111
    Additional functions missing from both array-api-strict and array-api-extra.
112

113
    This is intended as a temporary solution. See https://github.com/glass-dev/glass/issues/645
114
    for details.
115
    """
116

117
    xp: ModuleType
8✔
118
    backend: str
8✔
119

120
    def __init__(self, xp: ModuleType) -> None:
8✔
121
        self.xp = xp
8✔
122
        self.backend = xp.__name__
8✔
123

124
    def trapezoid(
8✔
125
        self, y: GlassAnyArray, x: GlassAnyArray = None, dx: float = 1.0, axis: int = -1
126
    ) -> GlassAnyArray:
127
        """
128
        Integrate along the given axis using the composite trapezoidal rule.
129

130
        See https://github.com/glass-dev/glass/issues/646
131
        """
132
        self.backend = self.xp.__name__
8✔
133
        if self.backend == "jax.numpy":
8✔
134
            return glass.jax.trapezoid(y, x=x, dx=dx, axis=axis)
8✔
135
        if self.backend == "numpy":
8✔
136
            return np.trapezoid(y, x=x, dx=dx, axis=axis)
8✔
137
        if self.backend == "array_api_strict":
8✔
138
            # Using design principle of scipy (i.e. copy, use np, copy back)
139
            y_np = np.asarray(y, copy=True)
8✔
140
            x_np = np.asarray(x, copy=True)
8✔
141
            result_np = np.trapezoid(y_np, x_np, dx=dx, axis=axis)
8✔
142
            return self.xp.asarray(result_np, copy=True)
8✔
143

NEW
144
        msg = "the array backend in not supported"
×
NEW
145
        raise NotImplementedError(msg)
×
146

147
    def union1d(self, ar1: GlassAnyArray, ar2: GlassAnyArray) -> GlassAnyArray:
8✔
148
        """
149
        Compute the set union of two 1D arrays.
150

151
        See https://github.com/glass-dev/glass/issues/647
152
        """
153
        if self.backend == "jax.numpy":
8✔
154
            return glass.jax.union1d(ar1, ar2)
8✔
155
        if self.backend == "numpy":
8✔
156
            return np.union1d(ar1, ar2)
8✔
157
        if self.backend == "array_api_strict":
8✔
158
            # Using design principle of scipy (i.e. copy, use np, copy back)
159
            ar1_np = np.asarray(ar1, copy=True)
8✔
160
            ar2_np = np.asarray(ar2, copy=True)
8✔
161
            result_np = np.union1d(ar1_np, ar2_np)
8✔
162
            return self.xp.asarray(result_np, copy=True)
8✔
163

NEW
164
        msg = "the array backend in not supported"
×
NEW
165
        raise NotImplementedError(msg)
×
166

167
    def interp(  # noqa: PLR0913
8✔
168
        self,
169
        x: GlassAnyArray,
170
        x_points: GlassAnyArray,
171
        y_points: GlassAnyArray,
172
        left: float | None = None,
173
        right: float | None = None,
174
        period: float | None = None,
175
    ) -> GlassAnyArray:
176
        """
177
        One-dimensional linear interpolation for monotonically increasing
178
        sample points.
179

180
        See https://github.com/glass-dev/glass/issues/650
181
        """
182
        if self.backend == "jax.numpy":
8✔
183
            return glass.jax.interp(
8✔
184
                x, x_points, y_points, left=left, right=right, period=period
185
            )
186
        if self.backend == "numpy":
8✔
187
            return np.interp(
8✔
188
                x, x_points, y_points, left=left, right=right, period=period
189
            )
190
        if self.backend == "array_api_strict":
8✔
191
            # Using design principle of scipy (i.e. copy, use np, copy back)
192
            x_np = np.asarray(x, copy=True)
8✔
193
            x_points_np = np.asarray(x_points, copy=True)
8✔
194
            y_points_np = np.asarray(y_points, copy=True)
8✔
195
            result_np = np.interp(
8✔
196
                x_np, x_points_np, y_points_np, left=left, right=right, period=period
197
            )
198
            return self.xp.asarray(result_np, copy=True)
8✔
199

NEW
200
        msg = "the array backend in not supported"
×
NEW
201
        raise NotImplementedError(msg)
×
202

203
    def gradient(self, f: GlassAnyArray) -> GlassAnyArray:
8✔
204
        """
205
        Return the gradient of an N-dimensional array.
206

207
        See https://github.com/glass-dev/glass/issues/648
208
        """
209
        if self.backend == "jax.numpy":
8✔
210
            return glass.jax.gradient(f)
8✔
211
        if self.backend == "numpy":
8✔
212
            return np.gradient(f)
8✔
213
        if self.backend == "array_api_strict":
8✔
214
            # Using design principle of scipy (i.e. copy, use np, copy back)
215
            f_np = np.asarray(f, copy=True)
8✔
216
            result_np = np.gradient(f_np)
8✔
217
            return self.xp.asarray(result_np, copy=True)
8✔
218

NEW
219
        msg = "the array backend in not supported"
×
NEW
220
        raise NotImplementedError(msg)
×
221

222
    def linalg_lstsq(
8✔
223
        self, a: GlassAnyArray, b: GlassAnyArray, rcond: float | None = None
224
    ) -> tuple[GlassAnyArray, GlassAnyArray, GlassAnyArray, GlassAnyArray]:
225
        """
226
        Return the gradient of an N-dimensional array.
227

228
        See https://github.com/glass-dev/glass/issues/649
229
        """
230
        if self.backend == "jax.numpy":
8✔
231
            return glass.jax.linalg_lstsq(a, b, rcond=rcond)
8✔
232
        if self.backend == "numpy":
8✔
233
            return np.linalg.lstsq(a, b, rcond=rcond)
8✔
234
        if self.backend == "array_api_strict":
8✔
235
            # Using design principle of scipy (i.e. copy, use np, copy back)
236
            a_np = np.asarray(a, copy=True)
8✔
237
            b_np = np.asarray(b, copy=True)
8✔
238
            result_np = np.linalg.lstsq(a_np, b_np, rcond=rcond)
8✔
239
            return tuple(self.xp.asarray(res, copy=True) for res in result_np)
8✔
240

NEW
241
        msg = "the array backend in not supported"
×
NEW
242
        raise NotImplementedError(msg)
×
243

244
    def einsum(self, subscripts: str, *operands: GlassAnyArray) -> GlassAnyArray:
8✔
245
        """
246
        Evaluates the Einstein summation convention on the operands.
247

248
        See https://github.com/glass-dev/glass/issues/657
249
        """
250
        if self.backend == "jax.numpy":
8✔
NEW
251
            return glass.jax.einsum(subscripts, *operands)
×
252
        if self.backend == "numpy":
8✔
253
            return np.einsum(subscripts, *operands)
8✔
254
        if self.backend == "array_api_strict":
8✔
255
            # Using design principle of scipy (i.e. copy, use np, copy back)
256
            operands_np = (np.asarray(op, copy=True) for op in operands)
8✔
257
            result_np = np.einsum(subscripts, *operands_np)
8✔
258
            return self.xp.asarray(result_np, copy=True)
8✔
259

NEW
260
        msg = "the array backend in not supported"
×
NEW
261
        raise NotImplementedError(msg)
×
262

263
    def apply_along_axis(
8✔
264
        self,
265
        func1d: FunctionType,
266
        axis: int,
267
        arr: GlassAnyArray,
268
        *args: object,
269
        **kwargs: object,
270
    ) -> GlassAnyArray:
271
        """
272
        Apply a function to 1-D slices along the given axis.
273

274
        See https://github.com/glass-dev/glass/issues/651
275
        """
276
        if self.backend == "jax.numpy":
8✔
277
            return glass.jax.apply_along_axis(func1d, axis, arr, *args, **kwargs)
8✔
278
        if self.backend == "numpy":
8✔
279
            return np.apply_along_axis(func1d, axis, arr, *args, **kwargs)
8✔
280
        if self.backend == "array_api_strict":
8✔
281
            # Using design principle of scipy (i.e. copy, use np, copy back)
282
            arr_np = np.asarray(arr, copy=True)
8✔
283
            result_np = np.apply_along_axis(func1d, axis, arr_np, *args, **kwargs)
8✔
284
            return self.xp.asarray(result_np, copy=True)
8✔
285

NEW
286
        msg = "the array backend in not supported"
×
NEW
287
        raise NotImplementedError(msg)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc