• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

glass-dev / glass / 21474312807

29 Jan 2026 10:18AM UTC coverage: 94.259% (-2.7%) from 96.916%
21474312807

push

github

web-flow
gh-905: Port lensing.py functions to array-api with jax (#925)

Co-authored-by: Patrick J. Roddy <patrickjamesroddy@gmail.com>

213 of 215 branches covered (99.07%)

Branch coverage included in aggregate %.

38 of 38 new or added lines in 4 files covered. (100.0%)

35 existing lines in 5 files now uncovered.

1396 of 1492 relevant lines covered (93.57%)

5.09 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

83.89
/glass/_array_api_utils.py
1
"""
Array API Utilities for glass.
==============================

This module provides utility functions and classes for working with multiple array
backends in the glass project, including NumPy, JAX, and array-api-strict. It includes
an importer for NumPy that explains why NumPy is needed when it is missing, a helper
returning the default backend, and the ``xp_additions`` class, which supplies
functionality missing from array-api-strict (and array-api-extra).

Classes and functions in this module help ensure consistent behavior and compatibility
across different array libraries, and provide wrappers for common operations such as
integration, interpolation, and linear algebra.

"""
16

17
from __future__ import annotations
18

19
import functools
20
from typing import TYPE_CHECKING, Any
21

22
import array_api_compat
23

24
if TYPE_CHECKING:
25
    from collections.abc import Callable, Sequence
26
    from types import ModuleType
27

28
    import numpy as np
29

30
    from glass._types import AnyArray, DTypeLike
31

32

33
class CompatibleBackendNotFoundError(Exception):
    """
    Raised when no available array backend implements a function GLASS needs.

    The message depends on whether the user picked a backend: if they did, it
    names that backend as lacking support; otherwise it names the library that
    is required because no alternative was supplied.

    """

    def __init__(self, missing_backend: str, users_backend: str | None) -> None:
        if users_backend is not None:
            text = f"GLASS depends on functions not supported by {users_backend}"
        else:
            text = (
                f"{missing_backend} is required here as "
                "no alternative has been provided by the user."
            )
        # Keep the message available as an attribute as well as via str(err).
        self.message = text
        super().__init__(text)


def import_numpy(backend: str | None = None) -> ModuleType:
    """
    Import the NumPy module, raising a helpful error if NumPy is not installed.

    Parameters
    ----------
    backend
        The name of the backend requested by the user, or ``None`` if the user
        did not request one. It is only used to build the error message.

    Returns
    -------
        The NumPy module.

    Raises
    ------
    CompatibleBackendNotFoundError
        If NumPy is not found in the user's environment. The underlying
        ``ModuleNotFoundError`` is chained as the cause.

    Notes
    -----
    This is useful for explaining to the user why NumPy is required when their chosen
    backend does not implement a needed function.

    """
    try:
        import numpy  # noqa: ICN001, PLC0415
    except ModuleNotFoundError as err:
        # Re-raise as the GLASS-specific error so the message explains *why*
        # NumPy is needed; keep the original error as the cause.
        raise CompatibleBackendNotFoundError("numpy", backend) from err
    else:
        return numpy


def default_xp() -> ModuleType:
    """Return the backend used when the user does not choose one: the NumPy module."""
    numpy_module = import_numpy()
    return numpy_module


class xp_additions:  # noqa: N801
6✔
89
    """
90
    Additional functions missing from both array-api-strict and array-api-extra.
91

92
    This class provides wrappers for common array operations such as integration,
93
    interpolation, and linear algebra, ensuring compatibility across NumPy, JAX,
94
    and array-api-strict backends.
95

96
    This is intended as a temporary solution. See https://github.com/glass-dev/glass/issues/645
97
    for details.
98

99
    """
100

101
    @staticmethod
6✔
102
    def trapezoid(
6✔
103
        y: AnyArray,
104
        x: AnyArray | None = None,
105
        dx: float = 1.0,
106
        axis: int = -1,
107
    ) -> AnyArray:
108
        """
109
        Integrate along the given axis using the composite trapezoidal rule.
110

111
        Parameters
112
        ----------
113
        y
114
            Input array to integrate.
115
        x
116
            Sample points corresponding to y.
117
        dx
118
            Spacing between sample points.
119
        axis
120
            Axis along which to integrate.
121

122
        Returns
123
        -------
124
            Integrated result.
125

126
        Raises
127
        ------
128
        NotImplementedError
129
            If the array backend is not supported.
130

131
        Notes
132
        -----
133
        See https://github.com/glass-dev/glass/issues/646
134

135
        """
136
        xp = array_api_compat.array_namespace(y, x, use_compat=False)
6✔
137

138
        if xp.__name__ == "jax.numpy":
6✔
139
            import glass.jax  # noqa: PLC0415
140

141
            return glass.jax.trapezoid(y, x=x, dx=dx, axis=axis)
5✔
142

143
        if xp.__name__ == "numpy":
6✔
144
            return xp.trapezoid(y, x=x, dx=dx, axis=axis)
6✔
145

146
        if xp.__name__ == "array_api_strict":
5✔
147
            np = import_numpy(xp.__name__)
5✔
148

149
            # Using design principle of scipy (i.e. copy, use np, copy back)
150
            y_np = np.asarray(y, copy=True)
5✔
151
            x_np = np.asarray(x, copy=True)
5✔
152
            result_np = np.trapezoid(y_np, x_np, dx=dx, axis=axis)
5✔
153
            return xp.asarray(result_np, copy=True)
5✔
154

UNCOV
155
        msg = "the array backend in not supported"
×
UNCOV
156
        raise NotImplementedError(msg)
×
157

158
    @staticmethod
6✔
159
    def interp(  # noqa: PLR0913
6✔
160
        x: AnyArray,
161
        x_points: AnyArray,
162
        y_points: AnyArray,
163
        left: float | None = None,
164
        right: float | None = None,
165
        period: float | None = None,
166
    ) -> AnyArray:
167
        """
168
        One-dimensional linear interpolation for monotonically increasing sample points.
169

170
        Parameters
171
        ----------
172
        x
173
            The x-coordinates at which to evaluate the interpolated values.
174
        x_points
175
            The x-coordinates of the data points.
176
        y_points
177
            The y-coordinates of the data points.
178
        left
179
            Value to return for x < x_points[0].
180
        right
181
            Value to return for x > x_points[-1].
182
        period
183
            Period for periodic interpolation.
184

185
        Returns
186
        -------
187
            Interpolated values.
188

189
        Raises
190
        ------
191
        NotImplementedError
192
            If the array backend is not supported.
193

194
        Notes
195
        -----
196
        See https://github.com/glass-dev/glass/issues/650
197

198
        """
199
        xp = array_api_compat.array_namespace(x, x_points, y_points, use_compat=False)
6✔
200

201
        if xp.__name__ in {"numpy", "jax.numpy"}:
6✔
202
            return xp.interp(
6✔
203
                x,
204
                x_points,
205
                y_points,
206
                left=left,
207
                right=right,
208
                period=period,
209
            )
210

211
        if xp.__name__ == "array_api_strict":
5✔
212
            np = import_numpy(xp.__name__)
5✔
213

214
            # Using design principle of scipy (i.e. copy, use np, copy back)
215
            x_np = np.asarray(x, copy=True)
5✔
216
            x_points_np = np.asarray(x_points, copy=True)
5✔
217
            y_points_np = np.asarray(y_points, copy=True)
5✔
218
            result_np = np.interp(
5✔
219
                x_np,
220
                x_points_np,
221
                y_points_np,
222
                left=left,
223
                right=right,
224
                period=period,
225
            )
226
            return xp.asarray(result_np, copy=True)
5✔
227

UNCOV
228
        msg = "the array backend in not supported"
×
UNCOV
229
        raise NotImplementedError(msg)
×
230

231
    @staticmethod
6✔
232
    def gradient(f: AnyArray) -> AnyArray:
6✔
233
        """
234
        Return the gradient of an N-dimensional array.
235

236
        Parameters
237
        ----------
238
        f
239
            Input array.
240

241
        Returns
242
        -------
243
            Gradient of the input array.
244

245
        Raises
246
        ------
247
        NotImplementedError
248
            If the array backend is not supported.
249

250
        Notes
251
        -----
252
        See https://github.com/glass-dev/glass/issues/648
253

254
        """
255
        xp = f.__array_namespace__()
5✔
256

257
        if xp.__name__ in {"numpy", "jax.numpy"}:
5✔
258
            return xp.gradient(f)
5✔
259

260
        if xp.__name__ == "array_api_strict":
5✔
261
            np = import_numpy(xp.__name__)
5✔
262
            # Using design principle of scipy (i.e. copy, use np, copy back)
263
            f_np = np.asarray(f, copy=True)
5✔
264
            result_np = np.gradient(f_np)
5✔
265
            return xp.asarray(result_np, copy=True)
5✔
266

UNCOV
267
        msg = "the array backend in not supported"
×
UNCOV
268
        raise NotImplementedError(msg)
×
269

270
    @staticmethod
6✔
271
    def linalg_lstsq(
6✔
272
        a: AnyArray,
273
        b: AnyArray,
274
        rcond: float | None = None,
275
    ) -> tuple[AnyArray, AnyArray, int, AnyArray]:
276
        """
277
        Solve a linear least squares problem.
278

279
        Parameters
280
        ----------
281
        a
282
            Coefficient matrix.
283
        b
284
            Ordinate or "dependent variable" values.
285
        rcond
286
            Cut-off ratio for small singular values.
287

288
        Returns
289
        -------
290
        x
291
            Least-squares solution. If b is two-dimensional, the solutions are in the K
292
            columns of x.
293

294
        residuals
295
            Sums of squared residuals: Squared Euclidean 2-norm for each column in b - a
296
            @ x. If the rank of a is < N or M <= N, this is an empty array. If b is
297
            1-dimensional, this is a (1,) shape array. Otherwise the shape is (K,).
298

299
        rank
300
            Rank of matrix a.
301

302
        s
303
            Singular values of a.
304

305
        Raises
306
        ------
307
        NotImplementedError
308
            If the array backend is not supported.
309

310
        Notes
311
        -----
312
        See https://github.com/glass-dev/glass/issues/649
313

314
        """
315
        xp = array_api_compat.array_namespace(a, b, use_compat=False)
5✔
316

317
        if xp.__name__ in {"numpy", "jax.numpy"}:
5✔
318
            return xp.linalg.lstsq(a, b, rcond=rcond)  # type: ignore[no-any-return]
5✔
319

320
        if xp.__name__ == "array_api_strict":
5✔
321
            np = import_numpy(xp.__name__)
5✔
322

323
            # Using design principle of scipy (i.e. copy, use np, copy back)
324
            a_np = np.asarray(a, copy=True)
5✔
325
            b_np = np.asarray(b, copy=True)
5✔
326
            result_np = np.linalg.lstsq(a_np, b_np, rcond=rcond)
5✔
327
            return tuple(xp.asarray(res, copy=True) for res in result_np)
5✔
328

UNCOV
329
        msg = "the array backend in not supported"
×
UNCOV
330
        raise NotImplementedError(msg)
×
331

332
    @staticmethod
6✔
333
    def einsum(subscripts: str, *operands: AnyArray) -> AnyArray:
6✔
334
        """
335
        Evaluate the Einstein summation convention on the operands.
336

337
        Parameters
338
        ----------
339
        subscripts
340
            Specifies the subscripts for summation.
341
        *operands
342
            Arrays to be summed.
343

344
        Returns
345
        -------
346
            Result of the Einstein summation.
347

348
        Raises
349
        ------
350
        NotImplementedError
351
            If the array backend is not supported.
352

353
        Notes
354
        -----
355
        See https://github.com/glass-dev/glass/issues/657
356

357
        """
358
        xp = array_api_compat.array_namespace(*operands, use_compat=False)
5✔
359

360
        if xp.__name__ in {"numpy", "jax.numpy"}:
5✔
361
            return xp.einsum(subscripts, *operands)
5✔
362

363
        if xp.__name__ == "array_api_strict":
5✔
364
            np = import_numpy(xp.__name__)
5✔
365

366
            # Using design principle of scipy (i.e. copy, use np, copy back)
367
            operands_np = (np.asarray(op, copy=True) for op in operands)
5✔
368
            result_np = np.einsum(subscripts, *operands_np)
5✔
369
            return xp.asarray(result_np, copy=True)
5✔
370

UNCOV
371
        msg = "the array backend in not supported"
×
UNCOV
372
        raise NotImplementedError(msg)
×
373

374
    @staticmethod
6✔
375
    def apply_along_axis(
6✔
376
        func: Callable[..., Any],
377
        func_inputs: tuple[Any, ...],
378
        axis: int,
379
        arr: AnyArray,
380
        *args: object,
381
        **kwargs: object,
382
    ) -> AnyArray:
383
        """
384
        Apply a function to 1-D slices along the given axis.
385

386
        Rather than accepting a partial function as usual, the function and
387
        its inputs are passed in separately for better compatibility.
388

389
        Parameters
390
        ----------
391
        func
392
            Function to apply to 1-D slices.
393
        func_inputs
394
            All inputs to the func besides arr.
395
        axis
396
            Axis along which to apply the function.
397
        arr
398
            Input array.
399
        *args
400
            Additional positional arguments to pass to func1d.
401
        **kwargs
402
            Additional keyword arguments to pass to func1d.
403

404
        Returns
405
        -------
406
            Result of applying the function along the axis.
407

408
        Raises
409
        ------
410
        NotImplementedError
411
            If the array backend is not supported.
412

413
        Notes
414
        -----
415
        See https://github.com/glass-dev/glass/issues/651
416

417
        """
418
        xp = array_api_compat.array_namespace(arr, *func_inputs, use_compat=False)
5✔
419

420
        if xp.__name__ in {"numpy", "jax.numpy"}:
5✔
421
            func1d = functools.partial(func, *func_inputs)
5✔
422
            return xp.apply_along_axis(func1d, axis, arr, *args, **kwargs)
5✔
423

424
        if xp.__name__ == "array_api_strict":
5✔
425
            # Import here to prevent users relying on numpy unless in this instance
426
            np = import_numpy(xp.__name__)
5✔
427

428
            # Everything must be NumPy to avoid mismatches between array types
429
            inputs_np = (np.asarray(inp) for inp in func_inputs)
5✔
430
            func1d = functools.partial(func, *inputs_np)
5✔
431

432
            return xp.asarray(
5✔
433
                np.apply_along_axis(func1d, axis, arr, *args, **kwargs),
434
                copy=True,
435
            )
436

UNCOV
437
        msg = "the array backend in not supported"
×
UNCOV
438
        raise NotImplementedError(msg)
×
439

440
    @staticmethod
6✔
441
    def vectorize(
6✔
442
        pyfunc: Callable[..., Any],
443
        otypes: str | Sequence[DTypeLike],
444
        *,
445
        xp: ModuleType,
446
    ) -> Callable[..., Any]:
447
        """
448
        Returns an object that acts like pyfunc, but takes arrays as input.
449

450
        Parameters
451
        ----------
452
        pyfunc
453
            Python function to vectorize.
454
        otypes
455
            Output types.
456
        xp
457
            The array library backend to use for array operations.
458

459
        Returns
460
        -------
461
            Vectorized function.
462

463
        Raises
464
        ------
465
        NotImplementedError
466
            If the array backend is not supported.
467

468
        Notes
469
        -----
470
        See https://github.com/glass-dev/glass/issues/671
471

472
        """
473
        if xp.__name__ == "numpy":
5✔
474
            return xp.vectorize(pyfunc, otypes=otypes)  # type: ignore[no-any-return]
5✔
475

476
        if xp.__name__ in {"array_api_strict", "jax.numpy"}:
5✔
477
            # Import here to prevent users relying on numpy unless in this instance
478
            np = import_numpy(xp.__name__)
5✔
479

480
            return np.vectorize(pyfunc, otypes=otypes)  # type: ignore[no-any-return]
5✔
481

UNCOV
482
        msg = "the array backend in not supported"
×
UNCOV
483
        raise NotImplementedError(msg)
×
484

485
    @staticmethod
6✔
486
    def radians(deg_arr: AnyArray) -> AnyArray:
6✔
487
        """
488
        Convert angles from degrees to radians.
489

490
        Parameters
491
        ----------
492
        deg_arr
493
            Array of angles in degrees.
494

495
        Returns
496
        -------
497
            Array of angles in radians.
498

499
        Raises
500
        ------
501
        NotImplementedError
502
            If the array backend is not supported.
503

504
        """
505
        xp = deg_arr.__array_namespace__()
6✔
506

507
        if xp.__name__ in {"numpy", "jax.numpy"}:
6✔
508
            return xp.radians(deg_arr)
6✔
509

510
        if xp.__name__ == "array_api_strict":
5✔
511
            np = import_numpy(xp.__name__)
5✔
512
            return xp.asarray(np.radians(deg_arr))
5✔
513

UNCOV
514
        msg = "the array backend in not supported"
×
UNCOV
515
        raise NotImplementedError(msg)
×
516

517
    @staticmethod
6✔
518
    def degrees(rad_arr: AnyArray) -> AnyArray:
6✔
519
        """
520
        Convert angles from radians to degrees.
521

522
        Parameters
523
        ----------
524
        rad_arr
525
            Array of angles in radians.
526

527
        Returns
528
        -------
529
            Array of angles in degrees.
530

531
        Raises
532
        ------
533
        NotImplementedError
534
            If the array backend is not supported.
535

536
        """
537
        xp = rad_arr.__array_namespace__()
6✔
538

539
        if xp.__name__ in {"numpy", "jax.numpy"}:
6✔
540
            return xp.degrees(rad_arr)
6✔
541

542
        if xp.__name__ == "array_api_strict":
5✔
543
            np = import_numpy(xp.__name__)
5✔
544
            return xp.asarray(np.degrees(rad_arr))
5✔
545

UNCOV
546
        msg = "the array backend in not supported"
×
UNCOV
547
        raise NotImplementedError(msg)
×
548

549
    @staticmethod
6✔
550
    def ndindex(shape: tuple[int, ...], *, xp: ModuleType) -> np.ndindex:
6✔
551
        """
552
        Wrapper for numpy.ndindex.
553

554
        See relevant docs for details:
555
        - NumPy, https://numpy.org/doc/2.2/reference/generated/numpy.ndindex.html
556

557
        Parameters
558
        ----------
559
        shape
560
            Shape of the array to index.
561
        xp
562
            The array library backend to use for array operations.
563

564
        Raises
565
        ------
566
        NotImplementedError
567
            If the array backend is not supported.
568

569
        """
570
        if xp.__name__ == "numpy":
6✔
571
            return xp.ndindex(shape)  # type: ignore[no-any-return]
6✔
572

573
        if xp.__name__ in {"array_api_strict", "jax.numpy"}:
5✔
574
            np = import_numpy(xp.__name__)
5✔
575
            return np.ndindex(shape)  # type: ignore[no-any-return]
5✔
576

UNCOV
577
        msg = "the array backend in not supported"
×
UNCOV
578
        raise NotImplementedError(msg)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc