• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

glass-dev / glass / 18917191455

29 Oct 2025 05:45PM UTC coverage: 93.113% (+0.01%) from 93.101%
18917191455

Pull #722

github

web-flow
Merge 334c71bb3 into 84a77e10b
Pull Request #722: gh-721: Port RNG functions in `shapes.py`

219 of 221 branches covered (99.1%)

Branch coverage included in aggregate %.

46 of 48 new or added lines in 2 files covered. (95.83%)

14 existing lines in 1 file now uncovered.

1444 of 1565 relevant lines covered (92.27%)

7.37 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

84.53
/glass/_array_api_utils.py
1
"""
2
Array API Utilities for glass.
3
============================
4

5
This module provides utility functions and classes for working with multiple array
6
backends in the glass project, including NumPy, JAX, and array-api-strict. It includes
7
functions for importing backends, determining array namespaces, dispatching random
8
number generators, and providing missing functionality for array-api-strict through the
9
XPAdditions class.
10

11
Classes and functions in this module help ensure consistent behavior and compatibility
12
across different array libraries, and provide wrappers for common operations such as
13
integration, interpolation, and linear algebra.
14

15
"""
16

17
from __future__ import annotations

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from collections.abc import Callable
    from types import ModuleType

    import numpy as np
    from numpy.typing import DTypeLike

    from array_api_strict._array_object import Array as AArray

    from glass._types import AnyArray, Size, UnifiedGenerator
30

31

32
class CompatibleBackendNotFoundError(Exception):
    """
    Signal that a library needed to implement a requested function is missing.

    Raised, for example, by ``import_numpy`` when NumPy cannot be imported.

    Parameters
    ----------
    missing_backend
        Name of the library that is required but unavailable.
    users_backend
        Name of the backend the user selected.
    """

    def __init__(self, missing_backend: str, users_backend: str) -> None:
        text = (
            f"{missing_backend} is required here as some functions required by GLASS "
            f"are not supported by {users_backend}"
        )
        # keep the message available as an attribute as well as in ``args``
        self.message = text
        super().__init__(text)
44

45

46
def import_numpy(backend: str) -> ModuleType:
    """
    Import the NumPy module, raising a helpful error if NumPy is not installed.

    Parameters
    ----------
    backend
        The name of the backend requested by the user. It is only used to
        build the error message.

    Returns
    -------
        The NumPy module.

    Raises
    ------
    CompatibleBackendNotFoundError
        If NumPy is not found in the user's environment. The underlying
        ``ModuleNotFoundError`` is chained as the cause.

    Notes
    -----
    This is useful for explaining to the user why NumPy is required when their chosen
    backend does not implement a needed function.
    """
    try:
        import numpy  # noqa: ICN001, PLC0415

    except ModuleNotFoundError as err:
        # translate the bare import failure into an explanation of why NumPy
        # is needed for the user's backend
        raise CompatibleBackendNotFoundError("numpy", backend) from err
    else:
        return numpy
76

77

78
def rng_dispatcher(*, xp: ModuleType) -> UnifiedGenerator:
    """
    Return a random number generator matching the given array namespace.

    Parameters
    ----------
    xp
        The array namespace whose ``__name__`` selects the generator
        (``"numpy"``, ``"jax.numpy"`` or ``"array_api_strict"``).

    Returns
    -------
        A random number generator for ``xp``. NumPy gets a fresh, unseeded
        ``numpy.random.Generator``; JAX and array-api-strict get a generator
        seeded with 42.

    Raises
    ------
    NotImplementedError
        If ``xp`` is not one of the supported namespaces.
    """
    name = xp.__name__

    if name == "numpy":
        return xp.random.default_rng()  # type: ignore[no-any-return]

    if name == "jax.numpy":
        # deferred so JAX is only imported when it is the active backend
        import glass.jax  # noqa: PLC0415

        return glass.jax.Generator(seed=42)

    if name == "array_api_strict":
        return Generator(seed=42)

    msg = "the array backend in not supported"
    raise NotImplementedError(msg)
109

110

111
class Generator:
    """
    NumPy random number generator returning array_api_strict Array.

    This class wraps NumPy's random number generator and returns arrays compatible
    with array_api_strict. Every draw is made by NumPy and the result is
    converted with ``array_api_strict.asarray``.
    """

    __slots__ = ("axp", "nxp", "rng")

    def __init__(
        self,
        seed: int | bool | AArray | None = None,  # noqa: FBT001
    ) -> None:
        """
        Initialize the Generator.

        Parameters
        ----------
        seed
            Seed for the random number generator, passed to
            ``numpy.random.default_rng``.
        """
        import numpy as np  # noqa: PLC0415

        import array_api_strict  # noqa: PLC0415

        # namespace used to wrap every returned sample
        self.axp = array_api_strict
        # NumPy namespace, used for the default float64 dtype
        self.nxp = np
        # the NumPy generator that actually produces the samples
        self.rng = np.random.default_rng(seed=seed)

    def random(
        self,
        size: Size = None,
        dtype: DTypeLike | None = None,
        out: AArray | None = None,
    ) -> AArray:
        """
        Return random floats in the half-open interval [0.0, 1.0).

        Parameters
        ----------
        size
            Output shape.
        dtype
            Desired data type. ``None`` selects float64.
        out
            Optional output array, forwarded unchanged to NumPy.

        Returns
        -------
            Array of random floats.
        """
        dtype = dtype if dtype is not None else self.nxp.float64
        return self.axp.asarray(self.rng.random(size, dtype, out))  # type: ignore[arg-type]

    def normal(
        self,
        loc: float | AArray = 0.0,
        scale: float | AArray = 1.0,
        size: Size = None,
    ) -> AArray:
        """
        Draw samples from a Normal distribution (mean=loc, stdev=scale).

        Parameters
        ----------
        loc
            Mean of the distribution.
        scale
            Standard deviation of the distribution.
        size
            Output shape.

        Returns
        -------
            Array of samples from the normal distribution.
        """
        return self.axp.asarray(self.rng.normal(loc, scale, size))

    def poisson(self, lam: float | AArray, size: Size = None) -> AArray:
        """
        Draw samples from a Poisson distribution.

        Parameters
        ----------
        lam
            Expected number of events.
        size
            Output shape.

        Returns
        -------
            Array of samples from the Poisson distribution.
        """
        return self.axp.asarray(self.rng.poisson(lam, size))

    def standard_normal(
        self,
        size: Size = None,
        dtype: DTypeLike | None = None,
        out: AArray | None = None,
    ) -> AArray:
        """
        Draw samples from a standard Normal distribution (mean=0, stdev=1).

        Parameters
        ----------
        size
            Output shape.
        dtype
            Desired data type. ``None`` selects float64.
        out
            Optional output array, forwarded unchanged to NumPy.

        Returns
        -------
            Array of samples from the standard normal distribution.
        """
        dtype = dtype if dtype is not None else self.nxp.float64
        return self.axp.asarray(self.rng.standard_normal(size, dtype, out))  # type: ignore[arg-type]

    def uniform(
        self,
        low: float | AArray = 0.0,
        high: float | AArray = 1.0,
        size: Size = None,
    ) -> AArray:
        """
        Draw samples from a Uniform distribution.

        Parameters
        ----------
        low
            Lower bound of the distribution.
        high
            Upper bound of the distribution.
        size
            Output shape.

        Returns
        -------
            Array of samples from the uniform distribution.
        """
        return self.axp.asarray(self.rng.uniform(low, high, size))
255

256

257
class XPAdditions:
    """
    Additional functions missing from both array-api-strict and array-api-extra.

    This class provides wrappers for common array operations such as integration,
    interpolation, and linear algebra, ensuring compatibility across NumPy, JAX,
    and array-api-strict backends.

    Each method dispatches on ``xp.__name__``. NumPy and JAX mostly call their
    native implementation. For array-api-strict, inputs are copied into NumPy
    arrays, processed with NumPy, and the result is copied back into ``xp``.

    This is intended as a temporary solution. See https://github.com/glass-dev/glass/issues/645
    for details.
    """

    # the array namespace every method dispatches on
    xp: ModuleType
    # NOTE(review): declared but never assigned in this class; confirm it is still used
    backend: str

    def __init__(self, xp: ModuleType) -> None:
        """
        Initialize XPAdditions with the given array namespace.

        Parameters
        ----------
        xp
            The array namespace module.
        """
        self.xp = xp

    def trapezoid(
        self,
        y: AnyArray,
        x: AnyArray | None = None,
        dx: float = 1.0,
        axis: int = -1,
    ) -> AnyArray:
        """
        Integrate along the given axis using the composite trapezoidal rule.

        Parameters
        ----------
        y
            Input array to integrate.
        x
            Sample points corresponding to y. If None, the sample points are
            taken to be evenly spaced ``dx`` apart.
        dx
            Spacing between sample points, used when ``x`` is None.
        axis
            Axis along which to integrate.

        Returns
        -------
            Integrated result.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/646
        """
        if self.xp.__name__ == "jax.numpy":
            import glass.jax  # noqa: PLC0415

            return glass.jax.trapezoid(y, x=x, dx=dx, axis=axis)

        if self.xp.__name__ == "numpy":
            return self.xp.trapezoid(y, x=x, dx=dx, axis=axis)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            y_np = np.asarray(y, copy=True)
            # keep a missing x as None so NumPy uses the uniform spacing dx;
            # np.asarray(None) would produce a 0-d object array and fail
            x_np = None if x is None else np.asarray(x, copy=True)
            result_np = np.trapezoid(y_np, x=x_np, dx=dx, axis=axis)
            return self.xp.asarray(result_np, copy=True)

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def union1d(self, ar1: AnyArray, ar2: AnyArray) -> AnyArray:
        """
        Compute the set union of two 1D arrays.

        Parameters
        ----------
        ar1
            First input array.
        ar2
            Second input array.

        Returns
        -------
            The union of the two arrays.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/647
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.union1d(ar1, ar2)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            ar1_np = np.asarray(ar1, copy=True)
            ar2_np = np.asarray(ar2, copy=True)
            result_np = np.union1d(ar1_np, ar2_np)
            return self.xp.asarray(result_np, copy=True)

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def interp(  # noqa: PLR0913
        self,
        x: AnyArray,
        x_points: AnyArray,
        y_points: AnyArray,
        left: float | None = None,
        right: float | None = None,
        period: float | None = None,
    ) -> AnyArray:
        """
        One-dimensional linear interpolation for monotonically increasing sample points.

        Parameters
        ----------
        x
            The x-coordinates at which to evaluate the interpolated values.
        x_points
            The x-coordinates of the data points.
        y_points
            The y-coordinates of the data points.
        left
            Value to return for x < x_points[0].
        right
            Value to return for x > x_points[-1].
        period
            Period for periodic interpolation.

        Returns
        -------
            Interpolated values.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/650
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.interp(
                x, x_points, y_points, left=left, right=right, period=period
            )

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            x_np = np.asarray(x, copy=True)
            x_points_np = np.asarray(x_points, copy=True)
            y_points_np = np.asarray(y_points, copy=True)
            result_np = np.interp(
                x_np, x_points_np, y_points_np, left=left, right=right, period=period
            )
            return self.xp.asarray(result_np, copy=True)

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def gradient(self, f: AnyArray) -> AnyArray:
        """
        Return the gradient of an N-dimensional array.

        Parameters
        ----------
        f
            Input array.

        Returns
        -------
            Gradient of the input array.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/648
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.gradient(f)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            f_np = np.asarray(f, copy=True)
            result_np = np.gradient(f_np)
            return self.xp.asarray(result_np, copy=True)

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def linalg_lstsq(
        self, a: AnyArray, b: AnyArray, rcond: float | None = None
    ) -> tuple[AnyArray, AnyArray, int, AnyArray]:
        """
        Solve a linear least squares problem.

        Parameters
        ----------
        a
            Coefficient matrix.
        b
            Ordinate or "dependent variable" values.
        rcond
            Cut-off ratio for small singular values.

        Returns
        -------
        x
            Least-squares solution. If b is two-dimensional, the solutions are in the K
            columns of x.

        residuals
            Sums of squared residuals: Squared Euclidean 2-norm for each column in b - a
            @ x. If the rank of a is < N or M <= N, this is an empty array. If b is
            1-dimensional, this is a (1,) shape array. Otherwise the shape is (K,).

        rank
            Rank of matrix a. For array-api-strict this is returned as a 0-d
            array, because every element of the result tuple is converted.

        s
            Singular values of a.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/649
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.linalg.lstsq(a, b, rcond=rcond)  # type: ignore[no-any-return]

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            a_np = np.asarray(a, copy=True)
            b_np = np.asarray(b, copy=True)
            result_np = np.linalg.lstsq(a_np, b_np, rcond=rcond)
            return tuple(self.xp.asarray(res, copy=True) for res in result_np)

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def einsum(self, subscripts: str, *operands: AnyArray) -> AnyArray:
        """
        Evaluate the Einstein summation convention on the operands.

        Parameters
        ----------
        subscripts
            Specifies the subscripts for summation.
        *operands
            Arrays to be summed.

        Returns
        -------
            Result of the Einstein summation.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/657
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.einsum(subscripts, *operands)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            operands_np = (np.asarray(op, copy=True) for op in operands)
            result_np = np.einsum(subscripts, *operands_np)
            return self.xp.asarray(result_np, copy=True)

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def apply_along_axis(
        self,
        func1d: Callable[..., Any],
        axis: int,
        arr: AnyArray,
        *args: object,
        **kwargs: object,
    ) -> AnyArray:
        """
        Apply a function to 1-D slices along the given axis.

        Parameters
        ----------
        func1d
            Function to apply to 1-D slices. For array-api-strict it receives
            NumPy slices.
        axis
            Axis along which to apply the function.
        arr
            Input array.
        *args
            Additional positional arguments to pass to func1d.
        **kwargs
            Additional keyword arguments to pass to func1d.

        Returns
        -------
            Result of applying the function along the axis.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/651

        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.apply_along_axis(func1d, axis, arr, *args, **kwargs)

        if self.xp.__name__ == "array_api_strict":
            # Import here to prevent users relying on numpy unless in this instance
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back);
            # the copy also keeps a func1d that writes to its slices from
            # touching the caller's array
            arr_np = np.asarray(arr, copy=True)
            result_np = np.apply_along_axis(func1d, axis, arr_np, *args, **kwargs)
            return self.xp.asarray(result_np, copy=True)

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def vectorize(
        self,
        pyfunc: Callable[..., Any],
        otypes: tuple[type[float]],
    ) -> Callable[..., Any]:
        """
        Returns an object that acts like pyfunc, but takes arrays as input.

        Parameters
        ----------
        pyfunc
            Python function to vectorize.
        otypes
            Output types.

        Returns
        -------
            Vectorized function. NumPy's ``vectorize`` is used for every
            supported backend, including JAX and array-api-strict.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/671
        """
        if self.xp.__name__ == "numpy":
            return self.xp.vectorize(pyfunc, otypes=otypes)  # type: ignore[no-any-return]

        if self.xp.__name__ in {"array_api_strict", "jax.numpy"}:
            # Import here to prevent users relying on numpy unless in this instance
            np = import_numpy(self.xp.__name__)

            return np.vectorize(pyfunc, otypes=otypes)  # type: ignore[no-any-return]

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def radians(self, deg_arr: AnyArray) -> AnyArray:
        """
        Convert angles from degrees to radians.

        Parameters
        ----------
        deg_arr
            Array of angles in degrees.

        Returns
        -------
            Array of angles in radians.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.radians(deg_arr)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            deg_np = np.asarray(deg_arr, copy=True)
            return self.xp.asarray(np.radians(deg_np), copy=True)

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def degrees(self, deg_arr: AnyArray) -> AnyArray:
        """
        Convert angles from radians to degrees.

        Parameters
        ----------
        deg_arr
            Array of angles in radians (despite the parameter name).

        Returns
        -------
            Array of angles in degrees.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.degrees(deg_arr)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            rad_np = np.asarray(deg_arr, copy=True)
            return self.xp.asarray(np.degrees(rad_np), copy=True)

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def ndindex(self, shape: tuple[int, ...]) -> np.ndindex:
        """
        Wrapper for numpy.ndindex.

        See relevant docs for details:
        - NumPy, https://numpy.org/doc/2.2/reference/generated/numpy.ndindex.html

        Parameters
        ----------
        shape
            Shape of the array whose indices are iterated.

        Returns
        -------
            A ``numpy.ndindex`` iterator over the index tuples for ``shape``.
            NumPy's implementation is used for every supported backend.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        """
        if self.xp.__name__ == "numpy":
            return self.xp.ndindex(shape)  # type: ignore[no-any-return]

        if self.xp.__name__ in {"array_api_strict", "jax.numpy"}:
            np = import_numpy(self.xp.__name__)

            return np.ndindex(shape)  # type: ignore[no-any-return]

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc