• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

glass-dev / glass / 18910467102

29 Oct 2025 02:00PM UTC coverage: 93.69% (+0.001%) from 93.689%
18910467102

Pull #722

github

web-flow
Merge 074bdfc9d into c4cfa4a63
Pull Request #722: gh-721: Port RNG functions in `shapes.py`

219 of 221 branches covered (99.1%)

Branch coverage included in aggregate %.

70 of 72 new or added lines in 2 files covered. (97.22%)

3 existing lines in 1 file now uncovered.

1444 of 1554 relevant lines covered (92.92%)

7.42 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

84.53
/glass/_array_api_utils.py
1
"""
2
Array API Utilities for glass.
3
============================
4

5
This module provides utility functions and classes for working with multiple array
6
backends in the glass project, including NumPy, JAX, and array-api-strict. It includes
7
functions for importing backends, determining array namespaces, dispatching random
8
number generators, and providing missing functionality for array-api-strict through the
9
XPAdditions class.
10

11
Classes and functions in this module help ensure consistent behavior and compatibility
12
across different array libraries, and provide wrappers for common operations such as
13
integration, interpolation, and linear algebra.
14

15
"""
16

17
from __future__ import annotations
8✔
18

19
from typing import TYPE_CHECKING, Any, TypeAlias
8✔
20

21
if TYPE_CHECKING:
22
    from collections.abc import Callable
23
    from types import ModuleType
24

25
    import numpy as np
26
    from jaxtyping import Array as JAXArray
27
    from numpy.typing import DTypeLike, NDArray
28

29
    from array_api_strict._array_object import Array as AArray
30

31
    import glass.jax
32

33
    Size: TypeAlias = int | tuple[int, ...] | None
34

35
    AnyArray: TypeAlias = NDArray[Any] | JAXArray | AArray
36
    ComplexArray: TypeAlias = NDArray[np.complex128] | JAXArray | AArray
37
    DoubleArray: TypeAlias = NDArray[np.double] | JAXArray | AArray
38
    FloatArray: TypeAlias = NDArray[np.float64] | JAXArray | AArray
39
    IntArray: TypeAlias = NDArray[np.int_] | JAXArray | AArray
40

41

42
class CompatibleBackendNotFoundError(Exception):
    """
    Raised when no installed array library implements a function GLASS needs.

    Parameters
    ----------
    missing_backend
        Name of the library that GLASS needs here (for example ``"numpy"``).
    users_backend
        Name of the array backend chosen by the user, which lacks the function.

    Attributes
    ----------
    message
        The human-readable error message, also passed to ``Exception``.
    """

    def __init__(self, missing_backend: str, users_backend: str) -> None:
        text = (
            f"{missing_backend} is required here as some functions required by GLASS "
            f"are not supported by {users_backend}"
        )
        # Stored on the instance as well as in ``args`` so callers can read it
        # directly.
        self.message = text
        super().__init__(text)
54

55

56
def import_numpy(backend: str) -> ModuleType:
    """
    Import the NumPy module, raising a helpful error if NumPy is not installed.

    Parameters
    ----------
    backend
        The name of the backend requested by the user.

    Returns
    -------
        The NumPy module.

    Raises
    ------
    CompatibleBackendNotFoundError
        If NumPy is not found in the user's environment. The underlying
        ``ModuleNotFoundError`` is chained as the exception's cause.

    Notes
    -----
    This is useful for explaining to the user why NumPy is required when their chosen
    backend does not implement a needed function.
    """
    try:
        # Imported lazily so NumPy is only required on code paths that need it.
        import numpy  # noqa: ICN001, PLC0415
    except ModuleNotFoundError as err:
        raise CompatibleBackendNotFoundError("numpy", backend) from err
    else:
        return numpy
86

87

88
def rng_dispatcher(
    *,
    xp: ModuleType,
) -> np.random.Generator | glass.jax.Generator | Generator:
    """
    Dispatch a random number generator based on the provided array's backend.

    Parameters
    ----------
    xp
        The array backend which determines the RNG.

    Returns
    -------
        The appropriate random number generator for the array's backend.

    Raises
    ------
    NotImplementedError
        If the array backend is not supported.

    Notes
    -----
    The JAX and array-api-strict generators are constructed with ``seed=42``. The
    NumPy generator is created with ``default_rng()`` and no seed, so it is seeded
    from fresh OS entropy on each call.
    """
    backend = xp.__name__

    if backend == "jax.numpy":
        # Imported lazily so JAX is only required when the JAX backend is used.
        import glass.jax  # noqa: PLC0415

        return glass.jax.Generator(seed=42)

    if backend == "numpy":
        return xp.random.default_rng()  # type: ignore[no-any-return]

    if backend == "array_api_strict":
        return Generator(seed=42)

    msg = "the array backend is not supported"
    raise NotImplementedError(msg)
122

123

124
class Generator:
    """
    Wrapper around NumPy's random ``Generator`` that yields array-api-strict arrays.

    Each sampling method draws from an underlying ``numpy.random.Generator`` and
    converts the result with ``array_api_strict.asarray``.
    """

    __slots__ = ("axp", "nxp", "rng")

    def __init__(
        self,
        seed: int | bool | AArray | None = None,  # noqa: FBT001
    ) -> None:
        """
        Create the wrapped NumPy generator.

        Parameters
        ----------
        seed
            Seed forwarded to ``numpy.random.default_rng``.
        """
        import numpy as np  # noqa: PLC0415

        import array_api_strict  # noqa: PLC0415

        self.nxp = np
        self.axp = array_api_strict
        self.rng = self.nxp.random.default_rng(seed=seed)

    def random(
        self,
        size: Size = None,
        dtype: DTypeLike | None = None,
        out: AArray | None = None,
    ) -> AArray:
        """
        Sample floats uniformly from the half-open interval [0.0, 1.0).

        Parameters
        ----------
        size
            Output shape.
        dtype
            Desired data type; ``float64`` when not given.
        out
            Optional output array, forwarded to NumPy.

        Returns
        -------
            Array of random floats.
        """
        if dtype is None:
            dtype = self.nxp.float64
        samples = self.rng.random(size=size, dtype=dtype, out=out)  # type: ignore[arg-type]
        return self.axp.asarray(samples)

    def normal(
        self,
        loc: float | FloatArray = 0.0,
        scale: float | FloatArray = 1.0,
        size: Size = None,
    ) -> AArray:
        """
        Sample from a Normal distribution with mean ``loc`` and stdev ``scale``.

        Parameters
        ----------
        loc
            Mean of the distribution.
        scale
            Standard deviation of the distribution.
        size
            Output shape.

        Returns
        -------
            Array of samples from the normal distribution.
        """
        samples = self.rng.normal(loc=loc, scale=scale, size=size)
        return self.axp.asarray(samples)

    def poisson(self, lam: float | AArray, size: Size = None) -> AArray:
        """
        Sample from a Poisson distribution.

        Parameters
        ----------
        lam
            Expected number of events.
        size
            Output shape.

        Returns
        -------
            Array of samples from the Poisson distribution.
        """
        samples = self.rng.poisson(lam=lam, size=size)
        return self.axp.asarray(samples)

    def standard_normal(
        self,
        size: Size = None,
        dtype: DTypeLike | None = None,
        out: AArray | None = None,
    ) -> AArray:
        """
        Sample from the standard Normal distribution (mean 0, stdev 1).

        Parameters
        ----------
        size
            Output shape.
        dtype
            Desired data type; ``float64`` when not given.
        out
            Optional output array, forwarded to NumPy.

        Returns
        -------
            Array of samples from the standard normal distribution.
        """
        if dtype is None:
            dtype = self.nxp.float64
        samples = self.rng.standard_normal(size=size, dtype=dtype, out=out)  # type: ignore[arg-type]
        return self.axp.asarray(samples)

    def uniform(
        self,
        low: float | AArray = 0.0,
        high: float | AArray = 1.0,
        size: Size = None,
    ) -> AArray:
        """
        Sample from a Uniform distribution over [low, high).

        Parameters
        ----------
        low
            Lower bound of the distribution.
        high
            Upper bound of the distribution.
        size
            Output shape.

        Returns
        -------
            Array of samples from the uniform distribution.
        """
        samples = self.rng.uniform(low=low, high=high, size=size)
        return self.axp.asarray(samples)
268

269

270
class XPAdditions:
8✔
271
    """
272
    Additional functions missing from both array-api-strict and array-api-extra.
273

274
    This class provides wrappers for common array operations such as integration,
275
    interpolation, and linear algebra, ensuring compatibility across NumPy, JAX,
276
    and array-api-strict backends.
277

278
    This is intended as a temporary solution. See https://github.com/glass-dev/glass/issues/645
279
    for details.
280
    """
281

282
    xp: ModuleType
8✔
283
    backend: str
8✔
284

285
    def __init__(self, xp: ModuleType) -> None:
8✔
286
        """
287
        Initialize XPAdditions with the given array namespace.
288

289
        Parameters
290
        ----------
291
        xp
292
            The array namespace module.
293
        """
294
        self.xp = xp
8✔
295

296
    def trapezoid(
8✔
297
        self, y: AnyArray, x: AnyArray = None, dx: float = 1.0, axis: int = -1
298
    ) -> AnyArray:
299
        """
300
        Integrate along the given axis using the composite trapezoidal rule.
301

302
        Parameters
303
        ----------
304
        y
305
            Input array to integrate.
306
        x
307
            Sample points corresponding to y.
308
        dx
309
            Spacing between sample points.
310
        axis
311
            Axis along which to integrate.
312

313
        Returns
314
        -------
315
            Integrated result.
316

317
        Raises
318
        ------
319
        NotImplementedError
320
            If the array backend is not supported.
321

322
        Notes
323
        -----
324
        See https://github.com/glass-dev/glass/issues/646
325
        """
326
        if self.xp.__name__ == "jax.numpy":
8✔
327
            import glass.jax  # noqa: PLC0415
8✔
328

329
            return glass.jax.trapezoid(y, x=x, dx=dx, axis=axis)
8✔
330

331
        if self.xp.__name__ == "numpy":
8✔
332
            return self.xp.trapezoid(y, x=x, dx=dx, axis=axis)
8✔
333

334
        if self.xp.__name__ == "array_api_strict":
8✔
335
            np = import_numpy(self.xp.__name__)
8✔
336

337
            # Using design principle of scipy (i.e. copy, use np, copy back)
338
            y_np = np.asarray(y, copy=True)
8✔
339
            x_np = np.asarray(x, copy=True)
8✔
340
            result_np = np.trapezoid(y_np, x_np, dx=dx, axis=axis)
8✔
341
            return self.xp.asarray(result_np, copy=True)
8✔
342

343
        msg = "the array backend in not supported"
×
344
        raise NotImplementedError(msg)
×
345

346
    def union1d(self, ar1: AnyArray, ar2: AnyArray) -> AnyArray:
8✔
347
        """
348
        Compute the set union of two 1D arrays.
349

350
        Parameters
351
        ----------
352
        ar1
353
            First input array.
354
        ar2
355
            Second input array.
356

357
        Returns
358
        -------
359
            The union of the two arrays.
360

361
        Raises
362
        ------
363
        NotImplementedError
364
            If the array backend is not supported.
365

366
        Notes
367
        -----
368
        See https://github.com/glass-dev/glass/issues/647
369
        """
370
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
8✔
371
            return self.xp.union1d(ar1, ar2)
8✔
372

373
        if self.xp.__name__ == "array_api_strict":
8✔
374
            np = import_numpy(self.xp.__name__)
8✔
375

376
            # Using design principle of scipy (i.e. copy, use np, copy back)
377
            ar1_np = np.asarray(ar1, copy=True)
8✔
378
            ar2_np = np.asarray(ar2, copy=True)
8✔
379
            result_np = np.union1d(ar1_np, ar2_np)
8✔
380
            return self.xp.asarray(result_np, copy=True)
8✔
381

382
        msg = "the array backend in not supported"
×
383
        raise NotImplementedError(msg)
×
384

385
    def interp(  # noqa: PLR0913
8✔
386
        self,
387
        x: AnyArray,
388
        x_points: AnyArray,
389
        y_points: AnyArray,
390
        left: float | None = None,
391
        right: float | None = None,
392
        period: float | None = None,
393
    ) -> AnyArray:
394
        """
395
        One-dimensional linear interpolation for monotonically increasing sample points.
396

397
        Parameters
398
        ----------
399
        x
400
            The x-coordinates at which to evaluate the interpolated values.
401
        x_points
402
            The x-coordinates of the data points.
403
        y_points
404
            The y-coordinates of the data points.
405
        left
406
            Value to return for x < x_points[0].
407
        right
408
            Value to return for x > x_points[-1].
409
        period
410
            Period for periodic interpolation.
411

412
        Returns
413
        -------
414
            Interpolated values.
415

416
        Raises
417
        ------
418
        NotImplementedError
419
            If the array backend is not supported.
420

421
        Notes
422
        -----
423
        See https://github.com/glass-dev/glass/issues/650
424
        """
425
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
8✔
426
            return self.xp.interp(
8✔
427
                x, x_points, y_points, left=left, right=right, period=period
428
            )
429

430
        if self.xp.__name__ == "array_api_strict":
8✔
431
            np = import_numpy(self.xp.__name__)
8✔
432

433
            # Using design principle of scipy (i.e. copy, use np, copy back)
434
            x_np = np.asarray(x, copy=True)
8✔
435
            x_points_np = np.asarray(x_points, copy=True)
8✔
436
            y_points_np = np.asarray(y_points, copy=True)
8✔
437
            result_np = np.interp(
8✔
438
                x_np, x_points_np, y_points_np, left=left, right=right, period=period
439
            )
440
            return self.xp.asarray(result_np, copy=True)
8✔
441

442
        msg = "the array backend in not supported"
×
443
        raise NotImplementedError(msg)
×
444

445
    def gradient(self, f: AnyArray) -> AnyArray:
8✔
446
        """
447
        Return the gradient of an N-dimensional array.
448

449
        Parameters
450
        ----------
451
        f
452
            Input array.
453

454
        Returns
455
        -------
456
            Gradient of the input array.
457

458
        Raises
459
        ------
460
        NotImplementedError
461
            If the array backend is not supported.
462

463
        Notes
464
        -----
465
        See https://github.com/glass-dev/glass/issues/648
466
        """
467
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
8✔
468
            return self.xp.gradient(f)
8✔
469

470
        if self.xp.__name__ == "array_api_strict":
8✔
471
            np = import_numpy(self.xp.__name__)
8✔
472

473
            # Using design principle of scipy (i.e. copy, use np, copy back)
474
            f_np = np.asarray(f, copy=True)
8✔
475
            result_np = np.gradient(f_np)
8✔
476
            return self.xp.asarray(result_np, copy=True)
8✔
477

478
        msg = "the array backend in not supported"
×
479
        raise NotImplementedError(msg)
×
480

481
    def linalg_lstsq(
8✔
482
        self, a: AnyArray, b: AnyArray, rcond: float | None = None
483
    ) -> tuple[AnyArray, AnyArray, int, AnyArray]:
484
        """
485
        Solve a linear least squares problem.
486

487
        Parameters
488
        ----------
489
        a
490
            Coefficient matrix.
491
        b
492
            Ordinate or "dependent variable" values.
493
        rcond
494
            Cut-off ratio for small singular values.
495

496
        Returns
497
        -------
498
        x
499
            Least-squares solution. If b is two-dimensional, the solutions are in the K
500
            columns of x.
501

502
        residuals
503
            Sums of squared residuals: Squared Euclidean 2-norm for each column in b - a
504
            @ x. If the rank of a is < N or M <= N, this is an empty array. If b is
505
            1-dimensional, this is a (1,) shape array. Otherwise the shape is (K,).
506

507
        rank
508
            Rank of matrix a.
509

510
        s
511
            Singular values of a.
512

513
        Raises
514
        ------
515
        NotImplementedError
516
            If the array backend is not supported.
517

518
        Notes
519
        -----
520
        See https://github.com/glass-dev/glass/issues/649
521
        """
522
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
8✔
523
            return self.xp.linalg.lstsq(a, b, rcond=rcond)  # type: ignore[no-any-return]
8✔
524

525
        if self.xp.__name__ == "array_api_strict":
8✔
526
            np = import_numpy(self.xp.__name__)
8✔
527

528
            # Using design principle of scipy (i.e. copy, use np, copy back)
529
            a_np = np.asarray(a, copy=True)
8✔
530
            b_np = np.asarray(b, copy=True)
8✔
531
            result_np = np.linalg.lstsq(a_np, b_np, rcond=rcond)
8✔
532
            return tuple(self.xp.asarray(res, copy=True) for res in result_np)
8✔
533

534
        msg = "the array backend in not supported"
×
535
        raise NotImplementedError(msg)
×
536

537
    def einsum(self, subscripts: str, *operands: AnyArray) -> AnyArray:
8✔
538
        """
539
        Evaluate the Einstein summation convention on the operands.
540

541
        Parameters
542
        ----------
543
        subscripts
544
            Specifies the subscripts for summation.
545
        *operands
546
            Arrays to be summed.
547

548
        Returns
549
        -------
550
            Result of the Einstein summation.
551

552
        Raises
553
        ------
554
        NotImplementedError
555
            If the array backend is not supported.
556

557
        Notes
558
        -----
559
        See https://github.com/glass-dev/glass/issues/657
560
        """
561
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
8✔
562
            return self.xp.einsum(subscripts, *operands)
8✔
563

564
        if self.xp.__name__ == "array_api_strict":
8✔
565
            np = import_numpy(self.xp.__name__)
8✔
566

567
            # Using design principle of scipy (i.e. copy, use np, copy back)
568
            operands_np = (np.asarray(op, copy=True) for op in operands)
8✔
569
            result_np = np.einsum(subscripts, *operands_np)
8✔
570
            return self.xp.asarray(result_np, copy=True)
8✔
571

572
        msg = "the array backend in not supported"
×
573
        raise NotImplementedError(msg)
×
574

575
    def apply_along_axis(
8✔
576
        self,
577
        func1d: Callable[..., Any],
578
        axis: int,
579
        arr: AnyArray,
580
        *args: object,
581
        **kwargs: object,
582
    ) -> AnyArray:
583
        """
584
        Apply a function to 1-D slices along the given axis.
585

586
        Parameters
587
        ----------
588
        func1d
589
            Function to apply to 1-D slices.
590
        axis
591
            Axis along which to apply the function.
592
        arr
593
            Input array.
594
        *args
595
            Additional positional arguments to pass to func1d.
596
        **kwargs
597
            Additional keyword arguments to pass to func1d.
598

599
        Returns
600
        -------
601
            Result of applying the function along the axis.
602

603
        Raises
604
        ------
605
        NotImplementedError
606
            If the array backend is not supported.
607

608
        Notes
609
        -----
610
        See https://github.com/glass-dev/glass/issues/651
611

612
        """
613
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
8✔
614
            return self.xp.apply_along_axis(func1d, axis, arr, *args, **kwargs)
8✔
615

616
        if self.xp.__name__ == "array_api_strict":
8✔
617
            # Import here to prevent users relying on numpy unless in this instance
618
            np = import_numpy(self.xp.__name__)
8✔
619

620
            return self.xp.asarray(
8✔
621
                np.apply_along_axis(func1d, axis, arr, *args, **kwargs), copy=True
622
            )
623

624
        msg = "the array backend in not supported"
×
625
        raise NotImplementedError(msg)
×
626

627
    def vectorize(
8✔
628
        self,
629
        pyfunc: Callable[..., Any],
630
        otypes: tuple[type[float]],
631
    ) -> Callable[..., Any]:
632
        """
633
        Returns an object that acts like pyfunc, but takes arrays as input.
634

635
        Parameters
636
        ----------
637
        pyfunc
638
            Python function to vectorize.
639
        otypes
640
            Output types.
641

642
        Returns
643
        -------
644
            Vectorized function.
645

646
        Raises
647
        ------
648
        NotImplementedError
649
            If the array backend is not supported.
650

651
        Notes
652
        -----
653
        See https://github.com/glass-dev/glass/issues/671
654
        """
655
        if self.xp.__name__ == "numpy":
8✔
656
            return self.xp.vectorize(pyfunc, otypes=otypes)  # type: ignore[no-any-return]
8✔
657

658
        if self.xp.__name__ in {"array_api_strict", "jax.numpy"}:
8✔
659
            # Import here to prevent users relying on numpy unless in this instance
660
            np = import_numpy(self.xp.__name__)
8✔
661

662
            return np.vectorize(pyfunc, otypes=otypes)  # type: ignore[no-any-return]
8✔
663

664
        msg = "the array backend in not supported"
×
665
        raise NotImplementedError(msg)
×
666

667
    def radians(self, deg_arr: AnyArray) -> AnyArray:
8✔
668
        """
669
        Convert angles from degrees to radians.
670

671
        Parameters
672
        ----------
673
        deg_arr
674
            Array of angles in degrees.
675

676
        Returns
677
        -------
678
            Array of angles in radians.
679

680
        Raises
681
        ------
682
        NotImplementedError
683
            If the array backend is not supported.
684
        """
685
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
8✔
686
            return self.xp.radians(deg_arr)
8✔
687

688
        if self.xp.__name__ == "array_api_strict":
8✔
689
            np = import_numpy(self.xp.__name__)
8✔
690

691
            return self.xp.asarray(np.radians(deg_arr))
8✔
692

693
        msg = "the array backend in not supported"
×
694
        raise NotImplementedError(msg)
×
695

696
    def degrees(self, deg_arr: AnyArray) -> AnyArray:
8✔
697
        """
698
        Convert angles from radians to degrees.
699

700
        Parameters
701
        ----------
702
        deg_arr
703
            Array of angles in radians.
704

705
        Returns
706
        -------
707
            Array of angles in degrees.
708

709
        Raises
710
        ------
711
        NotImplementedError
712
            If the array backend is not supported.
713
        """
714
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
8✔
715
            return self.xp.degrees(deg_arr)
8✔
716

717
        if self.xp.__name__ == "array_api_strict":
8✔
718
            np = import_numpy(self.xp.__name__)
8✔
719

720
            return self.xp.asarray(np.degrees(deg_arr))
8✔
721

722
        msg = "the array backend in not supported"
×
723
        raise NotImplementedError(msg)
×
724

725
    def ndindex(self, shape: tuple[int, ...]) -> np.ndindex:
8✔
726
        """
727
        Wrapper for numpy.ndindex.
728

729
        See relevant docs for details:
730
        - NumPy, https://numpy.org/doc/2.2/reference/generated/numpy.ndindex.html
731

732
        Raises
733
        ------
734
        NotImplementedError
735
            If the array backend is not supported.
736

737
        """
738
        if self.xp.__name__ == "numpy":
8✔
739
            return self.xp.ndindex(shape)  # type: ignore[no-any-return]
8✔
740

741
        if self.xp.__name__ in {"array_api_strict", "jax.numpy"}:
8✔
742
            np = import_numpy(self.xp.__name__)
8✔
743

744
            return np.ndindex(shape)  # type: ignore[no-any-return]
8✔
745

NEW
746
        msg = "the array backend in not supported"
×
NEW
747
        raise NotImplementedError(msg)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc