• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

glass-dev / glass / 21400369777

27 Jan 2026 02:09PM UTC coverage: 93.956% (-0.2%) from 94.174%
21400369777

Pull #911

github

web-flow
Merge e6b956b38 into 89a5c97e5
Pull Request #911: gh-910: consistent definition of displacement

211 of 213 branches covered (99.06%)

Branch coverage included in aggregate %.

4 of 4 new or added lines in 1 file covered. (100.0%)

50 existing lines in 6 files now uncovered.

1359 of 1458 relevant lines covered (93.21%)

5.07 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

83.89
/glass/_array_api_utils.py
1
"""
2
Array API Utilities for glass.
3
==============================
4

5
This module provides utility functions and classes for working with multiple array
6
backends in the glass project, including NumPy, JAX, and array-api-strict. It includes
7
functions for importing backends, determining array namespaces, dispatching random
8
number generators, and providing missing functionality for array-api-strict through the
9
xp_additions class.
10

11
Classes and functions in this module help ensure consistent behavior and compatibility
12
across different array libraries, and provide wrappers for common operations such as
13
integration, interpolation, and linear algebra.
14

15
"""
16

17
from __future__ import annotations
18

19
import functools
20
from typing import TYPE_CHECKING, Any
21

22
import array_api_compat
23

24
if TYPE_CHECKING:
25
    from collections.abc import Callable, Sequence
26
    from types import ModuleType
27

28
    import numpy as np
29

30
    from glass._types import AnyArray, DTypeLike
31

32

33
class CompatibleBackendNotFoundError(Exception):
    """
    Raised when a required array library backend is unavailable.

    The error message depends on whether the user chose a backend of their own.
    Without one, the message explains that ``missing_backend`` is needed.
    Otherwise it states that GLASS relies on functionality the user's backend
    lacks.
    """

    def __init__(self, missing_backend: str, users_backend: str | None) -> None:
        """
        Build the error message and store it on ``self.message``.

        Parameters
        ----------
        missing_backend
            Name of the library GLASS needs here (e.g. ``"numpy"``).
        users_backend
            Name of the backend the user requested, or None if none was given.
        """
        if users_backend is not None:
            text = f"GLASS depends on functions not supported by {users_backend}"
        else:
            text = (
                f"{missing_backend} is required here as "
                "no alternative has been provided by the user."
            )
        self.message = text
        super().__init__(text)
47

48

49
def import_numpy(backend: str | None = None) -> ModuleType:
    """
    Import the NumPy module, raising a helpful error if NumPy is not installed.

    Parameters
    ----------
    backend
        The name of the backend requested by the user, or None if the user did
        not request one. It is only used to tailor the error message.

    Returns
    -------
        The NumPy module.

    Raises
    ------
    CompatibleBackendNotFoundError
        If NumPy is not found in the user's environment. The underlying
        ModuleNotFoundError is chained as the cause.

    Notes
    -----
    This is useful for explaining to the user why NumPy is required when their chosen
    backend does not implement a needed function.
    """
    try:
        import numpy  # noqa: ICN001, PLC0415
    except ModuleNotFoundError as err:
        # Replace the bare import failure with a message explaining why NumPy
        # is needed, keeping the original error as the cause.
        raise CompatibleBackendNotFoundError("numpy", backend) from err
    else:
        return numpy
79

80

81
def default_xp() -> ModuleType:
    """
    Return the array library backend used when the user does not specify one.

    Returns
    -------
        The NumPy module, obtained through ``import_numpy`` so that a missing
        installation produces a descriptive error.
    """
    numpy_module = import_numpy()
    return numpy_module
84

85

86
class xp_additions:  # noqa: N801
    """
    Additional functions missing from both array-api-strict and array-api-extra.

    This class provides wrappers for common array operations such as integration,
    interpolation, and linear algebra, ensuring compatibility across NumPy, JAX,
    and array-api-strict backends.

    For array-api-strict inputs, the wrappers follow the SciPy design principle:
    copy the inputs into NumPy, compute with NumPy, and copy the result back into
    the array-api-strict namespace.

    This is intended as a temporary solution. See https://github.com/glass-dev/glass/issues/645
    for details.
    """

    @staticmethod
    def trapezoid(
        y: AnyArray,
        x: AnyArray | None = None,
        dx: float = 1.0,
        axis: int = -1,
    ) -> AnyArray:
        """
        Integrate along the given axis using the composite trapezoidal rule.

        Parameters
        ----------
        y
            Input array to integrate.
        x
            Sample points corresponding to y. If None, the sample points are
            taken to be spaced ``dx`` apart.
        dx
            Spacing between sample points, used when x is None.
        axis
            Axis along which to integrate.

        Returns
        -------
            Integrated result.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/646

        """
        # x is optional: only include it in the namespace lookup when given.
        arrays = (y,) if x is None else (y, x)
        xp = array_api_compat.array_namespace(*arrays, use_compat=False)

        if xp.__name__ == "jax.numpy":
            import glass.jax  # noqa: PLC0415

            return glass.jax.trapezoid(y, x=x, dx=dx, axis=axis)

        if xp.__name__ == "numpy":
            return xp.trapezoid(y, x=x, dx=dx, axis=axis)

        if xp.__name__ == "array_api_strict":
            np = import_numpy(xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            y_np = np.asarray(y, copy=True)
            # np.asarray(None) would produce a 0-d object array, which
            # np.trapezoid cannot use as sample points, so keep None as None.
            x_np = None if x is None else np.asarray(x, copy=True)
            result_np = np.trapezoid(y_np, x=x_np, dx=dx, axis=axis)
            return xp.asarray(result_np, copy=True)

        msg = "the array backend is not supported"
        raise NotImplementedError(msg)

    @staticmethod
    def interp(  # noqa: PLR0913
        x: AnyArray,
        x_points: AnyArray,
        y_points: AnyArray,
        left: float | None = None,
        right: float | None = None,
        period: float | None = None,
    ) -> AnyArray:
        """
        One-dimensional linear interpolation for monotonically increasing sample points.

        Parameters
        ----------
        x
            The x-coordinates at which to evaluate the interpolated values.
        x_points
            The x-coordinates of the data points.
        y_points
            The y-coordinates of the data points.
        left
            Value to return for x < x_points[0].
        right
            Value to return for x > x_points[-1].
        period
            Period for periodic interpolation.

        Returns
        -------
            Interpolated values.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/650

        """
        xp = array_api_compat.array_namespace(x, x_points, y_points, use_compat=False)

        if xp.__name__ in {"numpy", "jax.numpy"}:
            return xp.interp(
                x,
                x_points,
                y_points,
                left=left,
                right=right,
                period=period,
            )

        if xp.__name__ == "array_api_strict":
            np = import_numpy(xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            x_np = np.asarray(x, copy=True)
            x_points_np = np.asarray(x_points, copy=True)
            y_points_np = np.asarray(y_points, copy=True)
            result_np = np.interp(
                x_np,
                x_points_np,
                y_points_np,
                left=left,
                right=right,
                period=period,
            )
            return xp.asarray(result_np, copy=True)

        msg = "the array backend is not supported"
        raise NotImplementedError(msg)

    @staticmethod
    def gradient(f: AnyArray) -> AnyArray:
        """
        Return the gradient of an N-dimensional array.

        Parameters
        ----------
        f
            Input array.

        Returns
        -------
            Gradient of the input array. For 1-D input this is a single array.
            For N-D input it is a sequence with one array per axis, as returned
            by NumPy.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/648

        """
        xp = f.__array_namespace__()

        if xp.__name__ in {"numpy", "jax.numpy"}:
            return xp.gradient(f)

        if xp.__name__ == "array_api_strict":
            np = import_numpy(xp.__name__)
            # Using design principle of scipy (i.e. copy, use np, copy back)
            f_np = np.asarray(f, copy=True)
            result_np = np.gradient(f_np)
            # For N-D input NumPy returns one array per axis; convert each one
            # instead of coercing them into a single stacked array, so the
            # result has the same structure as the NumPy/JAX branch.
            if isinstance(result_np, (tuple, list)):
                return tuple(xp.asarray(res, copy=True) for res in result_np)
            return xp.asarray(result_np, copy=True)

        msg = "the array backend is not supported"
        raise NotImplementedError(msg)

    @staticmethod
    def linalg_lstsq(
        a: AnyArray,
        b: AnyArray,
        rcond: float | None = None,
    ) -> tuple[AnyArray, AnyArray, int, AnyArray]:
        """
        Solve a linear least squares problem.

        Parameters
        ----------
        a
            Coefficient matrix.
        b
            Ordinate or "dependent variable" values.
        rcond
            Cut-off ratio for small singular values.

        Returns
        -------
        x
            Least-squares solution. If b is two-dimensional, the solutions are in the K
            columns of x.

        residuals
            Sums of squared residuals: Squared Euclidean 2-norm for each column in b - a
            @ x. If the rank of a is < N or M <= N, this is an empty array. If b is
            1-dimensional, this is a (1,) shape array. Otherwise the shape is (K,).

        rank
            Rank of matrix a.

        s
            Singular values of a.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/649

        """
        xp = array_api_compat.array_namespace(a, b, use_compat=False)

        if xp.__name__ in {"numpy", "jax.numpy"}:
            return xp.linalg.lstsq(a, b, rcond=rcond)  # type: ignore[no-any-return]

        if xp.__name__ == "array_api_strict":
            np = import_numpy(xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            a_np = np.asarray(a, copy=True)
            b_np = np.asarray(b, copy=True)
            x_np, residuals_np, rank, s_np = np.linalg.lstsq(a_np, b_np, rcond=rcond)
            # The rank is a scalar; return it as a plain int (as annotated)
            # rather than wrapping it in a 0-d array.
            return (
                xp.asarray(x_np, copy=True),
                xp.asarray(residuals_np, copy=True),
                int(rank),
                xp.asarray(s_np, copy=True),
            )

        msg = "the array backend is not supported"
        raise NotImplementedError(msg)

    @staticmethod
    def einsum(subscripts: str, *operands: AnyArray) -> AnyArray:
        """
        Evaluate the Einstein summation convention on the operands.

        Parameters
        ----------
        subscripts
            Specifies the subscripts for summation.
        *operands
            Arrays to be summed.

        Returns
        -------
            Result of the Einstein summation.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/657

        """
        xp = array_api_compat.array_namespace(*operands, use_compat=False)

        if xp.__name__ in {"numpy", "jax.numpy"}:
            return xp.einsum(subscripts, *operands)

        if xp.__name__ == "array_api_strict":
            np = import_numpy(xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            operands_np = [np.asarray(op, copy=True) for op in operands]
            result_np = np.einsum(subscripts, *operands_np)
            return xp.asarray(result_np, copy=True)

        msg = "the array backend is not supported"
        raise NotImplementedError(msg)

    @staticmethod
    def apply_along_axis(
        func: Callable[..., Any],
        func_inputs: tuple[Any, ...],
        axis: int,
        arr: AnyArray,
        *args: object,
        **kwargs: object,
    ) -> AnyArray:
        """
        Apply a function to 1-D slices along the given axis.

        Rather than accepting a partial function as usual, the function and
        its inputs are passed in separately for better compatibility.

        Parameters
        ----------
        func
            Function to apply to 1-D slices.
        func_inputs
            All inputs to the func besides arr.
        axis
            Axis along which to apply the function.
        arr
            Input array.
        *args
            Additional positional arguments to pass to func1d.
        **kwargs
            Additional keyword arguments to pass to func1d.

        Returns
        -------
            Result of applying the function along the axis.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/651

        """
        xp = array_api_compat.array_namespace(arr, *func_inputs, use_compat=False)

        if xp.__name__ in {"numpy", "jax.numpy"}:
            func1d = functools.partial(func, *func_inputs)
            return xp.apply_along_axis(func1d, axis, arr, *args, **kwargs)

        if xp.__name__ == "array_api_strict":
            # Import here to prevent users relying on numpy unless in this instance
            np = import_numpy(xp.__name__)

            # Everything must be NumPy to avoid mismatches between array types,
            # including the array whose slices are passed to func.
            inputs_np = [np.asarray(inp) for inp in func_inputs]
            func1d = functools.partial(func, *inputs_np)
            arr_np = np.asarray(arr, copy=True)

            return xp.asarray(
                np.apply_along_axis(func1d, axis, arr_np, *args, **kwargs),
                copy=True,
            )

        msg = "the array backend is not supported"
        raise NotImplementedError(msg)

    @staticmethod
    def vectorize(
        pyfunc: Callable[..., Any],
        otypes: str | Sequence[DTypeLike],
        *,
        xp: ModuleType,
    ) -> Callable[..., Any]:
        """
        Returns an object that acts like pyfunc, but takes arrays as input.

        Parameters
        ----------
        pyfunc
            Python function to vectorize.
        otypes
            Output types.
        xp
            The array library backend to use for array operations.

        Returns
        -------
            Vectorized function.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/671

        """
        if xp.__name__ == "numpy":
            return xp.vectorize(pyfunc, otypes=otypes)  # type: ignore[no-any-return]

        if xp.__name__ in {"array_api_strict", "jax.numpy"}:
            # Import here to prevent users relying on numpy unless in this instance
            np = import_numpy(xp.__name__)

            return np.vectorize(pyfunc, otypes=otypes)  # type: ignore[no-any-return]

        msg = "the array backend is not supported"
        raise NotImplementedError(msg)

    @staticmethod
    def radians(deg_arr: AnyArray) -> AnyArray:
        """
        Convert angles from degrees to radians.

        Parameters
        ----------
        deg_arr
            Array of angles in degrees.

        Returns
        -------
            Array of angles in radians.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        """
        xp = deg_arr.__array_namespace__()

        if xp.__name__ in {"numpy", "jax.numpy"}:
            return xp.radians(deg_arr)

        if xp.__name__ == "array_api_strict":
            np = import_numpy(xp.__name__)
            # Using design principle of scipy (i.e. copy, use np, copy back)
            deg_np = np.asarray(deg_arr, copy=True)
            return xp.asarray(np.radians(deg_np), copy=True)

        msg = "the array backend is not supported"
        raise NotImplementedError(msg)

    @staticmethod
    def degrees(rad_arr: AnyArray) -> AnyArray:
        """
        Convert angles from radians to degrees.

        Parameters
        ----------
        rad_arr
            Array of angles in radians.

        Returns
        -------
            Array of angles in degrees.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        """
        xp = rad_arr.__array_namespace__()

        if xp.__name__ in {"numpy", "jax.numpy"}:
            return xp.degrees(rad_arr)

        if xp.__name__ == "array_api_strict":
            np = import_numpy(xp.__name__)
            # Using design principle of scipy (i.e. copy, use np, copy back)
            rad_np = np.asarray(rad_arr, copy=True)
            return xp.asarray(np.degrees(rad_np), copy=True)

        msg = "the array backend is not supported"
        raise NotImplementedError(msg)

    @staticmethod
    def ndindex(shape: tuple[int, ...], *, xp: ModuleType) -> np.ndindex:
        """
        Wrapper for numpy.ndindex.

        See relevant docs for details:
        - NumPy, https://numpy.org/doc/2.2/reference/generated/numpy.ndindex.html

        Parameters
        ----------
        shape
            Shape of the array to index.
        xp
            The array library backend to use for array operations.

        Returns
        -------
            A NumPy ``ndindex`` iterator over the index tuples of ``shape``.
            NumPy's implementation is used for every supported backend.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        """
        if xp.__name__ == "numpy":
            return xp.ndindex(shape)  # type: ignore[no-any-return]

        if xp.__name__ in {"array_api_strict", "jax.numpy"}:
            np = import_numpy(xp.__name__)
            return np.ndindex(shape)  # type: ignore[no-any-return]

        msg = "the array backend is not supported"
        raise NotImplementedError(msg)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc