• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

glass-dev / glass / 19740964514

27 Nov 2025 03:20PM UTC coverage: 93.566% (-2.2%) from 95.78%
19740964514

push

github

web-flow
gh-818: Adding benchmark(s) for `shells.py`  (#833)

Co-authored-by: Patrick J. Roddy <patrickjamesroddy@gmail.com>

219 of 221 branches covered (99.1%)

Branch coverage included in aggregate %.

1468 of 1582 relevant lines covered (92.79%)

5.17 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

84.78
/glass/_array_api_utils.py
1
"""
2
Array API Utilities for glass.
3
==============================
4

5
This module provides utility functions and classes for working with multiple array
6
backends in the glass project, including NumPy, JAX, and array-api-strict. It includes
7
functions for importing backends, selecting the default backend, dispatching random
8
number generators, and providing missing functionality for array-api-strict through the
9
XPAdditions class.
10

11
Classes and functions in this module help ensure consistent behavior and compatibility
12
across different array libraries, and provide wrappers for common operations such as
13
integration, interpolation, and linear algebra.
14

15
"""
16

17
from __future__ import annotations
8✔
18

19
from typing import TYPE_CHECKING, Any
8✔
20

21
if TYPE_CHECKING:
22
    from collections.abc import Callable
23
    from types import ModuleType
24

25
    import numpy as np
26
    from numpy.typing import DTypeLike
27

28
    from array_api_strict._array_object import Array as AArray
29

30
    from glass._types import AnyArray, FloatArray, UnifiedGenerator
31

32

33
class CompatibleBackendNotFoundError(Exception):
    """
    Raised when no available array backend implements a function GLASS needs.

    The message depends on whether the user chose a backend: without one,
    ``missing_backend`` is reported as required; with one, the user's backend
    is reported as lacking the functionality.
    """

    def __init__(self, missing_backend: str, users_backend: str | None) -> None:
        """
        Compose the error message and initialise the exception.

        Parameters
        ----------
        missing_backend
            Name of the backend that is required here, e.g. ``"numpy"``.
        users_backend
            Name of the backend the user asked for, or None if they did not
            ask for one.
        """
        if users_backend is not None:
            text = f"GLASS depends on functions not supported by {users_backend}"
        else:
            text = (
                f"{missing_backend} is required here as "
                "no alternative has been provided by the user."
            )
        # Kept as an attribute as well as the exception argument.
        self.message = text
        super().__init__(self.message)
47

48

49
def import_numpy(backend: str | None = None) -> ModuleType:
    """
    Import the NumPy module, raising a helpful error if NumPy is not installed.

    Parameters
    ----------
    backend
        The name of the backend requested by the user. It is used only to
        tailor the error message; None means the user did not request one.

    Returns
    -------
        The NumPy module.

    Raises
    ------
    CompatibleBackendNotFoundError
        If NumPy is not found in the user's environment. The underlying
        ModuleNotFoundError is chained as the cause.

    Notes
    -----
    This is useful for explaining to the user why NumPy is required when their chosen
    backend does not implement a needed function.
    """
    try:
        import numpy  # noqa: ICN001, PLC0415

    except ModuleNotFoundError as err:
        # Re-raise as the package's own error so the message explains *why*
        # NumPy is needed for the requested backend.
        raise CompatibleBackendNotFoundError("numpy", backend) from err
    else:
        return numpy
79

80

81
def default_xp() -> ModuleType:
    """
    Return the array library used when the user does not specify a backend.

    Returns
    -------
        The NumPy module, obtained through ``import_numpy`` so that a missing
        NumPy installation produces a descriptive error.
    """
    numpy_module = import_numpy()
    return numpy_module
84

85

86
def rng_dispatcher(*, xp: ModuleType, seed: int = 42) -> UnifiedGenerator:
    """
    Dispatch a seeded random number generator for the given array backend.

    Parameters
    ----------
    xp
        The array library backend to use for array operations.
    seed
        Seed for the returned generator. The default of 42 makes repeated
        calls produce the same random stream.

    Returns
    -------
        The appropriate random number generator for the backend:
        ``glass.jax.Generator`` for ``jax.numpy``, ``numpy.random.Generator``
        for ``numpy``, and this module's ``Generator`` for ``array_api_strict``.

    Raises
    ------
    NotImplementedError
        If the array backend is not supported.
    """
    backend = xp.__name__

    if backend == "jax.numpy":
        # Imported lazily so JAX is only required when it is the backend.
        import glass.jax  # noqa: PLC0415

        return glass.jax.Generator(seed=seed)

    if backend == "numpy":
        return xp.random.default_rng(seed=seed)  # type: ignore[no-any-return]

    if backend == "array_api_strict":
        return Generator(seed=seed)

    msg = "the array backend is not supported"
    raise NotImplementedError(msg)
119

120

121
class Generator:
    """
    Random number generator whose samples are array_api_strict arrays.

    Samples are drawn with a NumPy ``Generator`` and every result is passed
    through ``array_api_strict.asarray`` before it is returned.
    """

    __slots__ = ("axp", "nxp", "rng")

    def __init__(
        self,
        seed: int | bool | AArray | None = None,  # noqa: FBT001
    ) -> None:
        """
        Create the underlying NumPy generator.

        Parameters
        ----------
        seed
            Seed forwarded unchanged to ``numpy.random.default_rng``.
        """
        # Imported here rather than at module level, where these names are
        # only imported for type checking.
        import numpy  # noqa: ICN001, PLC0415

        import array_api_strict  # noqa: PLC0415

        self.axp = array_api_strict
        self.nxp = numpy
        self.rng = numpy.random.default_rng(seed=seed)

    def _as_strict(self, values: object) -> AArray:
        """Convert a NumPy result into an array of the ``axp`` namespace."""
        return self.axp.asarray(values)

    def random(
        self,
        size: int | tuple[int, ...] | None = None,
        dtype: DTypeLike | None = None,
        out: AArray | None = None,
    ) -> AArray:
        """
        Return random floats in the half-open interval [0.0, 1.0).

        Parameters
        ----------
        size
            Output shape.
        dtype
            Desired data type; NumPy's ``float64`` when None.
        out
            Optional output array, forwarded to NumPy.

        Returns
        -------
            Array of random floats.
        """
        if dtype is None:
            dtype = self.nxp.float64
        draws = self.rng.random(size, dtype, out)  # type: ignore[arg-type]
        return self._as_strict(draws)

    def normal(
        self,
        loc: float | FloatArray = 0.0,
        scale: float | FloatArray = 1.0,
        size: int | tuple[int, ...] | None = None,
    ) -> AArray:
        """
        Draw samples from a Normal distribution (mean=loc, stdev=scale).

        Parameters
        ----------
        loc
            Mean of the distribution.
        scale
            Standard deviation of the distribution.
        size
            Output shape.

        Returns
        -------
            Array of samples from the normal distribution.
        """
        draws = self.rng.normal(loc, scale, size)
        return self._as_strict(draws)

    def poisson(
        self,
        lam: float | AArray,
        size: int | tuple[int, ...] | None = None,
    ) -> AArray:
        """
        Draw samples from a Poisson distribution.

        Parameters
        ----------
        lam
            Expected number of events.
        size
            Output shape.

        Returns
        -------
            Array of samples from the Poisson distribution.
        """
        draws = self.rng.poisson(lam, size)
        return self._as_strict(draws)

    def standard_normal(
        self,
        size: int | tuple[int, ...] | None = None,
        dtype: DTypeLike | None = None,
        out: AArray | None = None,
    ) -> AArray:
        """
        Draw samples from a standard Normal distribution (mean=0, stdev=1).

        Parameters
        ----------
        size
            Output shape.
        dtype
            Desired data type; NumPy's ``float64`` when None.
        out
            Optional output array, forwarded to NumPy.

        Returns
        -------
            Array of samples from the standard normal distribution.
        """
        if dtype is None:
            dtype = self.nxp.float64
        draws = self.rng.standard_normal(size, dtype, out)  # type: ignore[arg-type]
        return self._as_strict(draws)

    def uniform(
        self,
        low: float | AArray = 0.0,
        high: float | AArray = 1.0,
        size: int | tuple[int, ...] | None = None,
    ) -> AArray:
        """
        Draw samples from a Uniform distribution.

        Parameters
        ----------
        low
            Lower bound of the distribution.
        high
            Upper bound of the distribution.
        size
            Output shape.

        Returns
        -------
            Array of samples from the uniform distribution.
        """
        draws = self.rng.uniform(low, high, size)
        return self._as_strict(draws)
269

270

271
class XPAdditions:
    """
    Additional functions missing from both array-api-strict and array-api-extra.

    This class provides wrappers for common array operations such as integration,
    interpolation, and linear algebra, ensuring compatibility across NumPy, JAX,
    and array-api-strict backends.

    For NumPy (and, where it provides the function, JAX) the backend's own
    implementation is called. For array-api-strict the inputs are copied into
    NumPy arrays, the NumPy implementation is called, and the result is copied
    back into an array-api-strict array.

    This is intended as a temporary solution. See https://github.com/glass-dev/glass/issues/645
    for details.
    """

    xp: ModuleType
    backend: str

    def __init__(self, xp: ModuleType) -> None:
        """
        Initialize XPAdditions with the given array namespace.

        Parameters
        ----------
        xp
            The array library backend to use for array operations.
        """
        self.xp = xp
        # Populate the declared ``backend`` attribute (the backend's module
        # name, e.g. "numpy"), which was previously annotated but never set.
        self.backend = xp.__name__

    def _unsupported_backend(self) -> NotImplementedError:
        """Return the error raised when ``self.xp`` is not a supported backend."""
        msg = "the array backend is not supported"
        return NotImplementedError(msg)

    def trapezoid(
        self,
        y: AnyArray,
        x: AnyArray | None = None,
        dx: float = 1.0,
        axis: int = -1,
    ) -> AnyArray:
        """
        Integrate along the given axis using the composite trapezoidal rule.

        Parameters
        ----------
        y
            Input array to integrate.
        x
            Sample points corresponding to y. If None, the sample points are
            taken to be evenly spaced ``dx`` apart.
        dx
            Spacing between sample points, used when ``x`` is None.
        axis
            Axis along which to integrate.

        Returns
        -------
            Integrated result.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/646
        """
        if self.xp.__name__ == "jax.numpy":
            import glass.jax  # noqa: PLC0415

            return glass.jax.trapezoid(y, x=x, dx=dx, axis=axis)

        if self.xp.__name__ == "numpy":
            return self.xp.trapezoid(y, x=x, dx=dx, axis=axis)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            y_np = np.asarray(y, copy=True)
            # np.asarray(None) yields a 0-d object array, which np.trapezoid
            # would try to difference and fail on; pass None through so
            # NumPy falls back to the uniform spacing ``dx``.
            x_np = None if x is None else np.asarray(x, copy=True)
            result_np = np.trapezoid(y_np, x_np, dx=dx, axis=axis)
            return self.xp.asarray(result_np, copy=True)

        raise self._unsupported_backend()

    def union1d(self, ar1: AnyArray, ar2: AnyArray) -> AnyArray:
        """
        Compute the set union of two 1D arrays.

        Parameters
        ----------
        ar1
            First input array.
        ar2
            Second input array.

        Returns
        -------
            The union of the two arrays.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/647
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.union1d(ar1, ar2)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            ar1_np = np.asarray(ar1, copy=True)
            ar2_np = np.asarray(ar2, copy=True)
            result_np = np.union1d(ar1_np, ar2_np)
            return self.xp.asarray(result_np, copy=True)

        raise self._unsupported_backend()

    def interp(  # noqa: PLR0913
        self,
        x: AnyArray,
        x_points: AnyArray,
        y_points: AnyArray,
        left: float | None = None,
        right: float | None = None,
        period: float | None = None,
    ) -> AnyArray:
        """
        One-dimensional linear interpolation for monotonically increasing sample points.

        Parameters
        ----------
        x
            The x-coordinates at which to evaluate the interpolated values.
        x_points
            The x-coordinates of the data points.
        y_points
            The y-coordinates of the data points.
        left
            Value to return for x < x_points[0].
        right
            Value to return for x > x_points[-1].
        period
            Period for periodic interpolation.

        Returns
        -------
            Interpolated values.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/650
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.interp(
                x,
                x_points,
                y_points,
                left=left,
                right=right,
                period=period,
            )

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            x_np = np.asarray(x, copy=True)
            x_points_np = np.asarray(x_points, copy=True)
            y_points_np = np.asarray(y_points, copy=True)
            result_np = np.interp(
                x_np,
                x_points_np,
                y_points_np,
                left=left,
                right=right,
                period=period,
            )
            return self.xp.asarray(result_np, copy=True)

        raise self._unsupported_backend()

    def gradient(self, f: AnyArray) -> AnyArray:
        """
        Return the gradient of an N-dimensional array.

        Parameters
        ----------
        f
            Input array.

        Returns
        -------
            Gradient of the input array.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/648
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.gradient(f)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            f_np = np.asarray(f, copy=True)
            result_np = np.gradient(f_np)
            return self.xp.asarray(result_np, copy=True)

        raise self._unsupported_backend()

    def linalg_lstsq(
        self,
        a: AnyArray,
        b: AnyArray,
        rcond: float | None = None,
    ) -> tuple[AnyArray, AnyArray, int, AnyArray]:
        """
        Solve a linear least squares problem.

        Parameters
        ----------
        a
            Coefficient matrix.
        b
            Ordinate or "dependent variable" values.
        rcond
            Cut-off ratio for small singular values.

        Returns
        -------
        x
            Least-squares solution. If b is two-dimensional, the solutions are in the K
            columns of x.

        residuals
            Sums of squared residuals: Squared Euclidean 2-norm for each column in b - a
            @ x. If the rank of a is < N or M <= N, this is an empty array. If b is
            1-dimensional, this is a (1,) shape array. Otherwise the shape is (K,).

        rank
            Rank of matrix a.

        s
            Singular values of a.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/649
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.linalg.lstsq(a, b, rcond=rcond)  # type: ignore[no-any-return]

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            a_np = np.asarray(a, copy=True)
            b_np = np.asarray(b, copy=True)
            x_np, residuals_np, rank, s_np = np.linalg.lstsq(a_np, b_np, rcond=rcond)
            # NumPy reports the rank as an integer; keep it one (as annotated
            # and as the NumPy backend returns) instead of a 0-d array.
            return (
                self.xp.asarray(x_np, copy=True),
                self.xp.asarray(residuals_np, copy=True),
                int(rank),
                self.xp.asarray(s_np, copy=True),
            )

        raise self._unsupported_backend()

    def einsum(self, subscripts: str, *operands: AnyArray) -> AnyArray:
        """
        Evaluate the Einstein summation convention on the operands.

        Parameters
        ----------
        subscripts
            Specifies the subscripts for summation.
        *operands
            Arrays to be summed.

        Returns
        -------
            Result of the Einstein summation.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/657
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.einsum(subscripts, *operands)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            operands_np = (np.asarray(op, copy=True) for op in operands)
            result_np = np.einsum(subscripts, *operands_np)
            return self.xp.asarray(result_np, copy=True)

        raise self._unsupported_backend()

    def apply_along_axis(
        self,
        func1d: Callable[..., Any],
        axis: int,
        arr: AnyArray,
        *args: object,
        **kwargs: object,
    ) -> AnyArray:
        """
        Apply a function to 1-D slices along the given axis.

        Parameters
        ----------
        func1d
            Function to apply to 1-D slices. For array-api-strict input it is
            called on NumPy slices.
        axis
            Axis along which to apply the function.
        arr
            Input array.
        *args
            Additional positional arguments to pass to func1d.
        **kwargs
            Additional keyword arguments to pass to func1d.

        Returns
        -------
            Result of applying the function along the axis.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/651

        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.apply_along_axis(func1d, axis, arr, *args, **kwargs)

        if self.xp.__name__ == "array_api_strict":
            # Import here to prevent users relying on numpy unless in this instance
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            arr_np = np.asarray(arr, copy=True)
            result_np = np.apply_along_axis(func1d, axis, arr_np, *args, **kwargs)
            return self.xp.asarray(result_np, copy=True)

        raise self._unsupported_backend()

    def vectorize(
        self,
        pyfunc: Callable[..., Any],
        otypes: tuple[type[float]],
    ) -> Callable[..., Any]:
        """
        Returns an object that acts like pyfunc, but takes arrays as input.

        Parameters
        ----------
        pyfunc
            Python function to vectorize.
        otypes
            Output types.

        Returns
        -------
            Vectorized function. For JAX and array-api-strict this is a
            ``numpy.vectorize`` object.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/671
        """
        if self.xp.__name__ == "numpy":
            return self.xp.vectorize(pyfunc, otypes=otypes)  # type: ignore[no-any-return]

        if self.xp.__name__ in {"array_api_strict", "jax.numpy"}:
            # Import here to prevent users relying on numpy unless in this instance
            np = import_numpy(self.xp.__name__)

            return np.vectorize(pyfunc, otypes=otypes)  # type: ignore[no-any-return]

        raise self._unsupported_backend()

    def radians(self, deg_arr: AnyArray) -> AnyArray:
        """
        Convert angles from degrees to radians.

        Parameters
        ----------
        deg_arr
            Array of angles in degrees.

        Returns
        -------
            Array of angles in radians.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.radians(deg_arr)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            deg_np = np.asarray(deg_arr, copy=True)
            return self.xp.asarray(np.radians(deg_np), copy=True)

        raise self._unsupported_backend()

    def degrees(self, deg_arr: AnyArray) -> AnyArray:
        """
        Convert angles from radians to degrees.

        Parameters
        ----------
        deg_arr
            Array of angles in radians. (The parameter name is kept for
            backward compatibility with keyword callers.)

        Returns
        -------
            Array of angles in degrees.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.degrees(deg_arr)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            rad_np = np.asarray(deg_arr, copy=True)
            return self.xp.asarray(np.degrees(rad_np), copy=True)

        raise self._unsupported_backend()

    def ndindex(self, shape: tuple[int, ...]) -> np.ndindex:
        """
        Wrapper for numpy.ndindex.

        See relevant docs for details:
        - NumPy, https://numpy.org/doc/2.2/reference/generated/numpy.ndindex.html

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        """
        if self.xp.__name__ == "numpy":
            return self.xp.ndindex(shape)  # type: ignore[no-any-return]

        if self.xp.__name__ in {"array_api_strict", "jax.numpy"}:
            np = import_numpy(self.xp.__name__)

            return np.ndindex(shape)  # type: ignore[no-any-return]

        raise self._unsupported_backend()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc