• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

glass-dev / glass / 18773426547

23 Oct 2025 04:36PM UTC coverage: 93.634% (+0.003%) from 93.631%
18773426547

Pull #711

github

web-flow
Merge 7fe2e7915 into c1716301d
Pull Request #711: gh-623: use `array_api_compat.array_namespace` helper

214 of 216 branches covered (99.07%)

Branch coverage included in aggregate %.

27 of 27 new or added lines in 8 files covered. (100.0%)

57 existing lines in 4 files now uncovered.

1404 of 1512 relevant lines covered (92.86%)

7.41 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

85.47
/glass/_array_api_utils.py
1
"""
2
Array API Utilities for glass.
3
============================
4

5
This module provides utility functions and classes for working with multiple array
6
backends in the glass project, including NumPy, JAX, and array-api-strict. It includes
7
functions for importing backends, determining array namespaces, dispatching random
8
number generators, and providing missing functionality for array-api-strict through the
9
XPAdditions class.
10

11
Classes and functions in this module help ensure consistent behavior and compatibility
12
across different array libraries, and provide wrappers for common operations such as
13
integration, interpolation, and linear algebra.
14

15
"""
16

17
from __future__ import annotations
8✔
18

19
from typing import TYPE_CHECKING, Any, TypeAlias
8✔
20

21
if TYPE_CHECKING:
22
    from collections.abc import Callable
23
    from types import ModuleType
24

25
    import numpy as np
26
    from jaxtyping import Array as JAXArray
27
    from numpy.typing import DTypeLike, NDArray
28

29
    from array_api_strict._array_object import Array as AArray
30

31
    import glass.jax
32

33
    Size: TypeAlias = int | tuple[int, ...] | None
34

35
    AnyArray: TypeAlias = NDArray[Any] | JAXArray | AArray
36
    ComplexArray: TypeAlias = NDArray[np.complex128] | JAXArray | AArray
37
    DoubleArray: TypeAlias = NDArray[np.double] | JAXArray | AArray
38
    FloatArray: TypeAlias = NDArray[np.float64] | JAXArray | AArray
39
    IntArray: TypeAlias = NDArray[np.int_] | JAXArray | AArray
40

41

42
class CompatibleBackendNotFoundError(Exception):
    """
    Signal that no installed array library can provide a requested function.

    The error message names both the library GLASS needs and the library the user
    supplied arrays from. The message is also exposed as the ``message`` attribute.

    Parameters
    ----------
    missing_backend
        Name of the library that would provide the function (for example ``"numpy"``).
    users_backend
        Name of the library the user's arrays belong to.
    """

    def __init__(self, missing_backend: str, users_backend: str) -> None:
        text = (
            f"{missing_backend} is required here as some functions required by GLASS "
            f"are not supported by {users_backend}"
        )
        # Keep the text on the instance as well as in ``args`` for easy access.
        self.message = text
        super().__init__(text)
54

55

56
def import_numpy(backend: str) -> ModuleType:
    """
    Import the NumPy module, raising a helpful error if NumPy is not installed.

    Parameters
    ----------
    backend
        The name of the backend requested by the user. It is only used to build the
        error message.

    Returns
    -------
        The NumPy module.

    Raises
    ------
    CompatibleBackendNotFoundError
        If NumPy is not found in the user's environment. The underlying
        ``ModuleNotFoundError`` is chained as the cause.

    Notes
    -----
    This is useful for explaining to the user why NumPy is required when their chosen
    backend does not implement a needed function.
    """
    try:
        import numpy  # noqa: ICN001, PLC0415
    except ModuleNotFoundError as err:
        # Re-raise as a GLASS-specific error that explains *why* NumPy is needed.
        raise CompatibleBackendNotFoundError("numpy", backend) from err
    return numpy
86

87

88
def array_namespace(*arrays: AnyArray | None) -> ModuleType:
    """
    Return the array library (namespace) of input arrays if they all belong to the same
    library.

    ``None`` entries are ignored wherever they appear, so optional array arguments can
    be passed straight through.

    Parameters
    ----------
    *arrays
        Arrays whose namespace is to be determined. ``None`` entries are skipped.

    Returns
    -------
        The array namespace module.

    Raises
    ------
    ValueError
        If no non-``None`` array is given, or if the input arrays do not all belong
        to the same array library.
    """
    # Skip ``None`` consistently. Previously only entries after the first were
    # filtered, so a leading ``None`` raised AttributeError.
    present = [array for array in arrays if array is not None]
    if not present:
        msg = "at least one array is required to determine the array namespace"
        raise ValueError(msg)

    namespace = present[0].__array_namespace__()
    if any(array.__array_namespace__() != namespace for array in present[1:]):
        msg = "input arrays should belong to the same array library"
        raise ValueError(msg)

    return namespace
117

118

119
def rng_dispatcher(
    array: AnyArray,
) -> np.random.Generator | glass.jax.Generator | Generator:
    """
    Pick a random number generator that matches the backend of ``array``.

    Parameters
    ----------
    array
        Array whose namespace selects the generator.

    Returns
    -------
        ``glass.jax.Generator(seed=42)`` for JAX, ``numpy.random.default_rng()`` for
        NumPy, or this module's ``Generator(seed=42)`` for array-api-strict.

    Raises
    ------
    NotImplementedError
        If the array backend is not one of the three above.
    """
    xp = array_namespace(array)
    backend_name = xp.__name__

    if backend_name == "numpy":
        return xp.random.default_rng()  # type: ignore[no-any-return]

    if backend_name == "array_api_strict":
        return Generator(seed=42)

    if backend_name == "jax.numpy":
        # Imported lazily so JAX is only required when JAX arrays are used.
        import glass.jax  # noqa: PLC0415

        return glass.jax.Generator(seed=42)

    msg = "the array backend in not supported"
    raise NotImplementedError(msg)
154

155

156
class Generator:
    """
    NumPy random number generator returning array_api_strict Array.

    Each sampling method draws from a wrapped ``numpy.random.Generator`` and converts
    the NumPy result with ``array_api_strict.asarray``.
    """

    __slots__ = ("axp", "nxp", "rng")

    def __init__(
        self,
        seed: int | bool | AArray | None = None,  # noqa: FBT001
    ) -> None:
        """
        Create the wrapped NumPy generator.

        Parameters
        ----------
        seed
            Seed passed unchanged to ``numpy.random.default_rng``.
        """
        import numpy  # noqa: ICN001, PLC0415

        import array_api_strict  # noqa: PLC0415

        self.axp = array_api_strict
        self.nxp = numpy
        self.rng = numpy.random.default_rng(seed=seed)

    def _wrap(self, values: Any) -> AArray:
        """Convert a NumPy result into an array of the ``axp`` namespace."""
        return self.axp.asarray(values)

    def random(
        self,
        size: Size = None,
        dtype: DTypeLike | None = None,
        out: AArray | None = None,
    ) -> AArray:
        """
        Return random floats in the half-open interval [0.0, 1.0).

        Parameters
        ----------
        size
            Output shape.
        dtype
            Desired data type; NumPy ``float64`` when omitted.
        out
            Optional output array, forwarded to NumPy.

        Returns
        -------
            Array of random floats.
        """
        if dtype is None:
            dtype = self.nxp.float64
        draws = self.rng.random(size, dtype, out)  # type: ignore[arg-type]
        return self._wrap(draws)

    def normal(
        self,
        loc: float | AArray = 0.0,
        scale: float | AArray = 1.0,
        size: Size = None,
    ) -> AArray:
        """
        Draw samples from a Normal distribution (mean=loc, stdev=scale).

        Parameters
        ----------
        loc
            Mean of the distribution.
        scale
            Standard deviation of the distribution.
        size
            Output shape.

        Returns
        -------
            Array of samples from the normal distribution.
        """
        draws = self.rng.normal(loc, scale, size)
        return self._wrap(draws)

    def poisson(self, lam: float | AArray, size: Size = None) -> AArray:
        """
        Draw samples from a Poisson distribution.

        Parameters
        ----------
        lam
            Expected number of events.
        size
            Output shape.

        Returns
        -------
            Array of samples from the Poisson distribution.
        """
        draws = self.rng.poisson(lam, size)
        return self._wrap(draws)

    def standard_normal(
        self,
        size: Size = None,
        dtype: DTypeLike | None = None,
        out: AArray | None = None,
    ) -> AArray:
        """
        Draw samples from a standard Normal distribution (mean=0, stdev=1).

        Parameters
        ----------
        size
            Output shape.
        dtype
            Desired data type; NumPy ``float64`` when omitted.
        out
            Optional output array, forwarded to NumPy.

        Returns
        -------
            Array of samples from the standard normal distribution.
        """
        if dtype is None:
            dtype = self.nxp.float64
        draws = self.rng.standard_normal(size, dtype, out)  # type: ignore[arg-type]
        return self._wrap(draws)

    def uniform(
        self,
        low: float | AArray = 0.0,
        high: float | AArray = 1.0,
        size: Size = None,
    ) -> AArray:
        """
        Draw samples from a Uniform distribution.

        Parameters
        ----------
        low
            Lower bound of the distribution.
        high
            Upper bound of the distribution.
        size
            Output shape.

        Returns
        -------
            Array of samples from the uniform distribution.
        """
        draws = self.rng.uniform(low, high, size)
        return self._wrap(draws)
300

301

302
class XPAdditions:
    """
    Additional functions missing from both array-api-strict and array-api-extra.

    This class provides wrappers for common array operations such as integration,
    interpolation, and linear algebra, ensuring compatibility across NumPy, JAX,
    and array-api-strict backends.

    The backend is chosen from ``xp.__name__``. NumPy and JAX implementations are
    called directly where they exist. For array-api-strict the inputs are copied to
    NumPy arrays, the NumPy function is applied, and the result is copied back
    (following SciPy's copy / use NumPy / copy back approach).

    This is intended as a temporary solution. See https://github.com/glass-dev/glass/issues/645
    for details.
    """

    xp: ModuleType
    backend: str  # declared here but never assigned by this class

    def __init__(self, xp: ModuleType) -> None:
        """
        Initialize XPAdditions with the given array namespace.

        Parameters
        ----------
        xp
            The array namespace module.
        """
        self.xp = xp

    @staticmethod
    def _unsupported() -> NotImplementedError:
        """Build the error raised when the array namespace is not handled."""
        msg = "the array backend in not supported"
        return NotImplementedError(msg)

    def trapezoid(
        self,
        y: AnyArray,
        x: AnyArray | None = None,
        dx: float = 1.0,
        axis: int = -1,
    ) -> AnyArray:
        """
        Integrate along the given axis using the composite trapezoidal rule.

        Parameters
        ----------
        y
            Input array to integrate.
        x
            Sample points corresponding to y. If ``None``, the points are assumed
            to be evenly spaced ``dx`` apart.
        dx
            Spacing between sample points, used only when ``x`` is ``None``.
        axis
            Axis along which to integrate.

        Returns
        -------
            Integrated result.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/646
        """
        if self.xp.__name__ == "jax.numpy":
            import glass.jax  # noqa: PLC0415

            return glass.jax.trapezoid(y, x=x, dx=dx, axis=axis)

        if self.xp.__name__ == "numpy":
            return self.xp.trapezoid(y, x=x, dx=dx, axis=axis)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            y_np = np.asarray(y, copy=True)
            # Keep a missing ``x`` as ``None``. ``np.asarray(None)`` is a 0-d object
            # array, which NumPy would treat as sample points and fail on.
            x_np = None if x is None else np.asarray(x, copy=True)
            result_np = np.trapezoid(y_np, x_np, dx=dx, axis=axis)
            return self.xp.asarray(result_np, copy=True)

        raise self._unsupported()

    def union1d(self, ar1: AnyArray, ar2: AnyArray) -> AnyArray:
        """
        Compute the set union of two 1D arrays.

        Parameters
        ----------
        ar1
            First input array.
        ar2
            Second input array.

        Returns
        -------
            The union of the two arrays.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/647
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.union1d(ar1, ar2)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            ar1_np = np.asarray(ar1, copy=True)
            ar2_np = np.asarray(ar2, copy=True)
            result_np = np.union1d(ar1_np, ar2_np)
            return self.xp.asarray(result_np, copy=True)

        raise self._unsupported()

    def interp(  # noqa: PLR0913
        self,
        x: AnyArray,
        x_points: AnyArray,
        y_points: AnyArray,
        left: float | None = None,
        right: float | None = None,
        period: float | None = None,
    ) -> AnyArray:
        """
        One-dimensional linear interpolation for monotonically increasing sample points.

        Parameters
        ----------
        x
            The x-coordinates at which to evaluate the interpolated values.
        x_points
            The x-coordinates of the data points.
        y_points
            The y-coordinates of the data points.
        left
            Value to return for x < x_points[0].
        right
            Value to return for x > x_points[-1].
        period
            Period for periodic interpolation.

        Returns
        -------
            Interpolated values.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/650
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.interp(
                x, x_points, y_points, left=left, right=right, period=period
            )

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            x_np = np.asarray(x, copy=True)
            x_points_np = np.asarray(x_points, copy=True)
            y_points_np = np.asarray(y_points, copy=True)
            result_np = np.interp(
                x_np, x_points_np, y_points_np, left=left, right=right, period=period
            )
            return self.xp.asarray(result_np, copy=True)

        raise self._unsupported()

    def gradient(self, f: AnyArray) -> AnyArray:
        """
        Return the gradient of an N-dimensional array.

        Parameters
        ----------
        f
            Input array.

        Returns
        -------
            Gradient of the input array.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/648
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.gradient(f)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            f_np = np.asarray(f, copy=True)
            result_np = np.gradient(f_np)
            return self.xp.asarray(result_np, copy=True)

        raise self._unsupported()

    def linalg_lstsq(
        self, a: AnyArray, b: AnyArray, rcond: float | None = None
    ) -> tuple[AnyArray, AnyArray, int, AnyArray]:
        """
        Solve a linear least squares problem.

        Parameters
        ----------
        a
            Coefficient matrix.
        b
            Ordinate or "dependent variable" values.
        rcond
            Cut-off ratio for small singular values.

        Returns
        -------
        x
            Least-squares solution. If b is two-dimensional, the solutions are in the K
            columns of x.

        residuals
            Sums of squared residuals: Squared Euclidean 2-norm for each column in b - a
            @ x. If the rank of a is < N or M <= N, this is an empty array. If b is
            1-dimensional, this is a (1,) shape array. Otherwise the shape is (K,).

        rank
            Rank of matrix a.

        s
            Singular values of a.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/649
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.linalg.lstsq(a, b, rcond=rcond)  # type: ignore[no-any-return]

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            a_np = np.asarray(a, copy=True)
            b_np = np.asarray(b, copy=True)
            solution, residuals, rank, singular_values = np.linalg.lstsq(
                a_np, b_np, rcond=rcond
            )
            # ``rank`` is a scalar count; return it as ``int`` as annotated, not as
            # a 0-d array.
            return (
                self.xp.asarray(solution, copy=True),
                self.xp.asarray(residuals, copy=True),
                int(rank),
                self.xp.asarray(singular_values, copy=True),
            )

        raise self._unsupported()

    def einsum(self, subscripts: str, *operands: AnyArray) -> AnyArray:
        """
        Evaluate the Einstein summation convention on the operands.

        Parameters
        ----------
        subscripts
            Specifies the subscripts for summation.
        *operands
            Arrays to be summed.

        Returns
        -------
            Result of the Einstein summation.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/657
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.einsum(subscripts, *operands)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            operands_np = [np.asarray(op, copy=True) for op in operands]
            result_np = np.einsum(subscripts, *operands_np)
            return self.xp.asarray(result_np, copy=True)

        raise self._unsupported()

    def apply_along_axis(
        self,
        func1d: Callable[..., Any],
        axis: int,
        arr: AnyArray,
        *args: object,
        **kwargs: object,
    ) -> AnyArray:
        """
        Apply a function to 1-D slices along the given axis.

        Parameters
        ----------
        func1d
            Function to apply to 1-D slices.
        axis
            Axis along which to apply the function.
        arr
            Input array.
        *args
            Additional positional arguments to pass to func1d.
        **kwargs
            Additional keyword arguments to pass to func1d.

        Returns
        -------
            Result of applying the function along the axis.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/651
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.apply_along_axis(func1d, axis, arr, *args, **kwargs)

        if self.xp.__name__ == "array_api_strict":
            # Import here to prevent users relying on numpy unless in this instance
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            arr_np = np.asarray(arr, copy=True)
            result_np = np.apply_along_axis(func1d, axis, arr_np, *args, **kwargs)
            return self.xp.asarray(result_np, copy=True)

        raise self._unsupported()

    def vectorize(
        self,
        pyfunc: Callable[..., Any],
        otypes: tuple[type[float]],
    ) -> Callable[..., Any]:
        """
        Return an object that acts like pyfunc, but takes arrays as input.

        Parameters
        ----------
        pyfunc
            Python function to vectorize.
        otypes
            Output types.

        Returns
        -------
            Vectorized function. For array-api-strict and JAX, this is a NumPy
            ``vectorize`` object.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/671
        """
        if self.xp.__name__ == "numpy":
            return self.xp.vectorize(pyfunc, otypes=otypes)  # type: ignore[no-any-return]

        if self.xp.__name__ in {"array_api_strict", "jax.numpy"}:
            # Import here to prevent users relying on numpy unless in this instance
            np = import_numpy(self.xp.__name__)

            return np.vectorize(pyfunc, otypes=otypes)  # type: ignore[no-any-return]

        raise self._unsupported()

    def radians(self, deg_arr: AnyArray) -> AnyArray:
        """
        Convert angles from degrees to radians.

        Parameters
        ----------
        deg_arr
            Array of angles in degrees.

        Returns
        -------
            Array of angles in radians.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.radians(deg_arr)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            return self.xp.asarray(np.radians(deg_arr))

        raise self._unsupported()

    def degrees(self, deg_arr: AnyArray) -> AnyArray:
        """
        Convert angles from radians to degrees.

        Parameters
        ----------
        deg_arr
            Array of angles in radians. The parameter name is kept for backward
            compatibility with keyword callers.

        Returns
        -------
            Array of angles in degrees.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.degrees(deg_arr)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            return self.xp.asarray(np.degrees(deg_arr))

        raise self._unsupported()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc