• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

glass-dev / glass / 21030569361

15 Jan 2026 12:01PM UTC coverage: 93.785%. Remained the same
21030569361

Pull #971

github

web-flow
Merge 5fdc16647 into bee3e55d2
Pull Request #971: gh-970: Move rng code to its own module

211 of 213 branches covered (99.06%)

Branch coverage included in aggregate %.

35 of 37 new or added lines in 5 files covered. (94.59%)

24 existing lines in 1 file now uncovered.

1313 of 1412 relevant lines covered (92.99%)

5.06 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

81.82
/glass/_array_api_utils.py
1
"""
2
Array API Utilities for glass.
3
==============================
4

5
This module provides utility functions and classes for working with multiple array
6
backends in the glass project, including NumPy, JAX, and array-api-strict. It includes
7
functions for importing NumPy and selecting the default backend, and for providing
8
missing functionality for array-api-strict through the
9
XPAdditions class.
10

11
Classes and functions in this module help ensure consistent behavior and compatibility
12
across different array libraries, and provide wrappers for common operations such as
13
integration, interpolation, and linear algebra.
14

15
"""
16

17
from __future__ import annotations
18

19
from typing import TYPE_CHECKING, Any
20

21
if TYPE_CHECKING:
22
    from collections.abc import Callable
23
    from types import ModuleType
24

25
    import numpy as np
26

27
    from glass._types import AnyArray
28

29

30
class CompatibleBackendNotFoundError(Exception):
    """
    Raised when no array library backend implementing a requested function is found.

    The message depends on whether the user picked a backend themselves: if they
    did not, it explains which backend is required; if they did, it states that
    their backend lacks functions GLASS relies on.
    """

    def __init__(self, missing_backend: str, users_backend: str | None) -> None:
        if users_backend is None:
            # No user choice to blame, so name the backend that is needed.
            message = (
                f"{missing_backend} is required here as "
                "no alternative has been provided by the user."
            )
        else:
            message = f"GLASS depends on functions not supported by {users_backend}"
        self.message = message
        super().__init__(self.message)
44

45

46
def import_numpy(backend: str | None = None) -> ModuleType:
    """
    Import the NumPy module, raising a helpful error if NumPy is not installed.

    Parameters
    ----------
    backend
        The name of the backend requested by the user, if any. It is only used
        to build the error message when NumPy cannot be imported.

    Returns
    -------
        The NumPy module.

    Raises
    ------
    CompatibleBackendNotFoundError
        If NumPy is not found in the user's environment. The original
        ``ModuleNotFoundError`` is chained as the cause.

    Notes
    -----
    This is useful for explaining to the user why NumPy is required when their chosen
    backend does not implement a needed function.
    """
    try:
        import numpy  # noqa: ICN001, PLC0415

    except ModuleNotFoundError as err:
        # Translate the import failure into a GLASS-specific error whose message
        # explains why NumPy is needed given the user's backend choice.
        raise CompatibleBackendNotFoundError("numpy", backend) from err
    else:
        return numpy
76

77

78
def default_xp() -> ModuleType:
    """
    Return the array library backend used when the user does not specify one.

    The default backend is NumPy, loaded via ``import_numpy`` with no
    user-requested backend.
    """
    return import_numpy(backend=None)
81

82

83
class XPAdditions:
    """
    Additional functions missing from both array-api-strict and array-api-extra.

    This class provides wrappers for common array operations such as integration,
    interpolation, and linear algebra, ensuring compatibility across NumPy, JAX,
    and array-api-strict backends.

    Each method dispatches on ``self.xp.__name__``:

    - ``"numpy"`` and (where supported) ``"jax.numpy"`` call the backend's own
      function.
    - ``"array_api_strict"`` copies the inputs to NumPy, calls the NumPy function,
      and (for array results) copies the result back with ``self.xp.asarray``.
    - Any other backend raises ``NotImplementedError``.

    This is intended as a temporary solution. See https://github.com/glass-dev/glass/issues/645
    for details.
    """

    # The array namespace whose ``__name__`` every method dispatches on.
    xp: ModuleType
    # NOTE(review): declared but never assigned anywhere in this class; confirm
    # whether any caller still relies on it.
    backend: str

    def __init__(self, xp: ModuleType) -> None:
        """
        Initialize XPAdditions with the given array namespace.

        Parameters
        ----------
        xp
            The array library backend to use for array operations.
        """
        self.xp = xp

    def trapezoid(
        self,
        y: AnyArray,
        x: AnyArray | None = None,
        dx: float = 1.0,
        axis: int = -1,
    ) -> AnyArray:
        """
        Integrate along the given axis using the composite trapezoidal rule.

        Parameters
        ----------
        y
            Input array to integrate.
        x
            Sample points corresponding to y. If None, the sample points are
            taken to be evenly spaced ``dx`` apart.
        dx
            Spacing between sample points, used only when ``x`` is None.
        axis
            Axis along which to integrate.

        Returns
        -------
            Integrated result.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/646
        """
        if self.xp.__name__ == "jax.numpy":
            import glass.jax  # noqa: PLC0415

            return glass.jax.trapezoid(y, x=x, dx=dx, axis=axis)

        if self.xp.__name__ == "numpy":
            return self.xp.trapezoid(y, x=x, dx=dx, axis=axis)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            y_np = np.asarray(y, copy=True)
            # ``x`` is optional. Converting None with ``np.asarray`` would give a
            # 0-d object array that ``np.trapezoid`` cannot difference. Pass None
            # through so NumPy falls back to the uniform spacing ``dx``.
            x_np = None if x is None else np.asarray(x, copy=True)
            result_np = np.trapezoid(y_np, x_np, dx=dx, axis=axis)
            return self.xp.asarray(result_np, copy=True)

        msg = "the array backend is not supported"
        raise NotImplementedError(msg)

    def interp(  # noqa: PLR0913
        self,
        x: AnyArray,
        x_points: AnyArray,
        y_points: AnyArray,
        left: float | None = None,
        right: float | None = None,
        period: float | None = None,
    ) -> AnyArray:
        """
        One-dimensional linear interpolation for monotonically increasing sample points.

        Parameters
        ----------
        x
            The x-coordinates at which to evaluate the interpolated values.
        x_points
            The x-coordinates of the data points.
        y_points
            The y-coordinates of the data points.
        left
            Value to return for x < x_points[0].
        right
            Value to return for x > x_points[-1].
        period
            Period for periodic interpolation.

        Returns
        -------
            Interpolated values.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/650
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.interp(
                x,
                x_points,
                y_points,
                left=left,
                right=right,
                period=period,
            )

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            x_np = np.asarray(x, copy=True)
            x_points_np = np.asarray(x_points, copy=True)
            y_points_np = np.asarray(y_points, copy=True)
            result_np = np.interp(
                x_np,
                x_points_np,
                y_points_np,
                left=left,
                right=right,
                period=period,
            )
            return self.xp.asarray(result_np, copy=True)

        msg = "the array backend is not supported"
        raise NotImplementedError(msg)

    def gradient(self, f: AnyArray) -> AnyArray:
        """
        Return the gradient of an N-dimensional array.

        Parameters
        ----------
        f
            Input array.

        Returns
        -------
            Gradient of the input array.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/648
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.gradient(f)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            f_np = np.asarray(f, copy=True)
            result_np = np.gradient(f_np)
            return self.xp.asarray(result_np, copy=True)

        msg = "the array backend is not supported"
        raise NotImplementedError(msg)

    def linalg_lstsq(
        self,
        a: AnyArray,
        b: AnyArray,
        rcond: float | None = None,
    ) -> tuple[AnyArray, AnyArray, int, AnyArray]:
        """
        Solve a linear least squares problem.

        Parameters
        ----------
        a
            Coefficient matrix.
        b
            Ordinate or "dependent variable" values.
        rcond
            Cut-off ratio for small singular values.

        Returns
        -------
        x
            Least-squares solution. If b is two-dimensional, the solutions are in the K
            columns of x.

        residuals
            Sums of squared residuals: Squared Euclidean 2-norm for each column in b - a
            @ x. If the rank of a is < N or M <= N, this is an empty array. If b is
            1-dimensional, this is a (1,) shape array. Otherwise the shape is (K,).

        rank
            Rank of matrix a. For the array-api-strict backend this is wrapped in
            a 0-d array like the other outputs, rather than returned as an int.

        s
            Singular values of a.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/649
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.linalg.lstsq(a, b, rcond=rcond)  # type: ignore[no-any-return]

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            a_np = np.asarray(a, copy=True)
            b_np = np.asarray(b, copy=True)
            result_np = np.linalg.lstsq(a_np, b_np, rcond=rcond)
            # Every element of the result (including rank) is converted back.
            return tuple(self.xp.asarray(res, copy=True) for res in result_np)

        msg = "the array backend is not supported"
        raise NotImplementedError(msg)

    def einsum(self, subscripts: str, *operands: AnyArray) -> AnyArray:
        """
        Evaluate the Einstein summation convention on the operands.

        Parameters
        ----------
        subscripts
            Specifies the subscripts for summation.
        *operands
            Arrays to be summed.

        Returns
        -------
            Result of the Einstein summation.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/657
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.einsum(subscripts, *operands)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            operands_np = (np.asarray(op, copy=True) for op in operands)
            result_np = np.einsum(subscripts, *operands_np)
            return self.xp.asarray(result_np, copy=True)

        msg = "the array backend is not supported"
        raise NotImplementedError(msg)

    def apply_along_axis(
        self,
        func1d: Callable[..., Any],
        axis: int,
        arr: AnyArray,
        *args: object,
        **kwargs: object,
    ) -> AnyArray:
        """
        Apply a function to 1-D slices along the given axis.

        Parameters
        ----------
        func1d
            Function to apply to 1-D slices.
        axis
            Axis along which to apply the function.
        arr
            Input array.
        *args
            Additional positional arguments to pass to func1d.
        **kwargs
            Additional keyword arguments to pass to func1d.

        Returns
        -------
            Result of applying the function along the axis.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/651
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.apply_along_axis(func1d, axis, arr, *args, **kwargs)

        if self.xp.__name__ == "array_api_strict":
            # Import here to prevent users relying on numpy unless in this instance
            np = import_numpy(self.xp.__name__)

            # ``arr`` is handed to NumPy directly (no explicit copy); the NumPy
            # result is then copied back into the array-api-strict namespace.
            return self.xp.asarray(
                np.apply_along_axis(func1d, axis, arr, *args, **kwargs),
                copy=True,
            )

        msg = "the array backend is not supported"
        raise NotImplementedError(msg)

    def vectorize(
        self,
        pyfunc: Callable[..., Any],
        otypes: tuple[type[float]],
    ) -> Callable[..., Any]:
        """
        Return an object that acts like pyfunc, but takes arrays as input.

        Parameters
        ----------
        pyfunc
            Python function to vectorize.
        otypes
            Output types.

        Returns
        -------
            Vectorized function. For the JAX and array-api-strict backends this is
            a NumPy ``vectorize`` object.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/671
        """
        if self.xp.__name__ == "numpy":
            return self.xp.vectorize(pyfunc, otypes=otypes)  # type: ignore[no-any-return]

        if self.xp.__name__ in {"array_api_strict", "jax.numpy"}:
            # Import here to prevent users relying on numpy unless in this instance
            np = import_numpy(self.xp.__name__)

            return np.vectorize(pyfunc, otypes=otypes)  # type: ignore[no-any-return]

        msg = "the array backend is not supported"
        raise NotImplementedError(msg)

    def radians(self, deg_arr: AnyArray) -> AnyArray:
        """
        Convert angles from degrees to radians.

        Parameters
        ----------
        deg_arr
            Array of angles in degrees.

        Returns
        -------
            Array of angles in radians.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.radians(deg_arr)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            return self.xp.asarray(np.radians(deg_arr))

        msg = "the array backend is not supported"
        raise NotImplementedError(msg)

    def degrees(self, deg_arr: AnyArray) -> AnyArray:
        """
        Convert angles from radians to degrees.

        Parameters
        ----------
        deg_arr
            Array of angles in radians (despite the parameter's name).

        Returns
        -------
            Array of angles in degrees.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.degrees(deg_arr)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            return self.xp.asarray(np.degrees(deg_arr))

        msg = "the array backend is not supported"
        raise NotImplementedError(msg)

    def ndindex(self, shape: tuple[int, ...]) -> np.ndindex:
        """
        Wrapper for numpy.ndindex.

        See relevant docs for details:
        - NumPy, https://numpy.org/doc/2.2/reference/generated/numpy.ndindex.html

        Parameters
        ----------
        shape
            Shape of the array whose indices are iterated.

        Returns
        -------
            A NumPy ``ndindex`` iterator over index tuples. NumPy is used for the
            JAX and array-api-strict backends as well.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.
        """
        if self.xp.__name__ == "numpy":
            return self.xp.ndindex(shape)  # type: ignore[no-any-return]

        if self.xp.__name__ in {"array_api_strict", "jax.numpy"}:
            np = import_numpy(self.xp.__name__)

            return np.ndindex(shape)  # type: ignore[no-any-return]

        msg = "the array backend is not supported"
        raise NotImplementedError(msg)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc