• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

glass-dev / glass / 19262707898

11 Nov 2025 10:29AM UTC coverage: 93.341% (+0.1%) from 93.208%
19262707898

Pull #726

github

web-flow
Merge 8b8eaf98d into 9bcdbe8e9
Pull Request #726: gh-725: change @dependabot commit title

220 of 222 branches covered (99.1%)

Branch coverage included in aggregate %.

1462 of 1580 relevant lines covered (92.53%)

7.39 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

84.53
/glass/_array_api_utils.py
1
"""
2
Array API Utilities for glass.
3
==============================
4

5
This module provides utility functions and classes for working with multiple array
6
backends in the glass project, including NumPy, JAX, and array-api-strict. It includes
7
functions for importing backends, determining array namespaces, dispatching random
8
number generators, and providing missing functionality for array-api-strict through the
9
XPAdditions class.
10

11
Classes and functions in this module help ensure consistent behavior and compatibility
12
across different array libraries, and provide wrappers for common operations such as
13
integration, interpolation, and linear algebra.
14

15
"""
16

17
from __future__ import annotations
8✔
18

19
from typing import TYPE_CHECKING, Any
8✔
20

21
if TYPE_CHECKING:
22
    from collections.abc import Callable
23
    from types import ModuleType
24

25
    import numpy as np
26
    from numpy.typing import DTypeLike
27

28
    from array_api_strict._array_object import Array as AArray
29

30
    from glass._types import AnyArray, FloatArray, Size, UnifiedGenerator
31

32

33
class CompatibleBackendNotFoundError(Exception):
    """
    Exception raised when an array library backend that implements a
    requested function is not found.

    The formatted explanation is stored on the ``message`` attribute and is
    also passed to ``Exception`` as the sole argument.
    """

    def __init__(self, missing_backend: str, users_backend: str) -> None:
        """
        Build the explanatory message and initialise the exception with it.

        Parameters
        ----------
        missing_backend
            Name of the backend that is needed but could not be found.
        users_backend
            Name of the backend the user asked for.
        """
        explanation = (
            f"{missing_backend} is required here as some functions required by GLASS "
            f"are not supported by {users_backend}"
        )
        self.message = explanation
        super().__init__(explanation)
45

46

47
def import_numpy(backend: str) -> ModuleType:
    """
    Import and return NumPy, explaining why it is needed if it is missing.

    Parameters
    ----------
    backend
        The name of the backend requested by the user. It is only used to
        build the error message.

    Returns
    -------
        The NumPy module.

    Raises
    ------
    CompatibleBackendNotFoundError
        If NumPy is not found in the user's environment. The original
        ``ModuleNotFoundError`` is chained as the cause.

    Notes
    -----
    This is useful for explaining to the user why NumPy is required when their
    chosen backend does not implement a needed function.
    """
    try:
        import numpy as numpy_module  # noqa: ICN001, PLC0415
    except ModuleNotFoundError as err:
        raise CompatibleBackendNotFoundError("numpy", backend) from err

    return numpy_module
77

78

79
def rng_dispatcher(*, xp: ModuleType) -> UnifiedGenerator:
    """
    Dispatch a random number generator based on the provided array backend.

    Parameters
    ----------
    xp
        The array library backend to use for array operations. The backend is
        identified by ``xp.__name__``.

    Returns
    -------
        The appropriate random number generator for the backend:
        ``glass.jax.Generator`` for ``jax.numpy``, NumPy's ``default_rng`` for
        ``numpy`` and this module's ``Generator`` for ``array_api_strict``.

    Raises
    ------
    NotImplementedError
        If the array backend is not supported.

    Notes
    -----
    The JAX and array-api-strict generators are created with ``seed=42``,
    whereas the NumPy generator is created without a seed.
    NOTE(review): confirm that this difference in seeding is intentional.
    """
    backend = xp.__name__

    if backend == "jax.numpy":
        # Imported lazily so that JAX is only needed when it is requested.
        import glass.jax  # noqa: PLC0415

        return glass.jax.Generator(seed=42)

    if backend == "numpy":
        return xp.random.default_rng()  # type: ignore[no-any-return]

    if backend == "array_api_strict":
        return Generator(seed=42)

    msg = "the array backend is not supported"
    raise NotImplementedError(msg)
110

111

112
class Generator:
    """
    NumPy random number generator returning array_api_strict Array.

    Every sampling method draws from a wrapped ``numpy.random.Generator`` and
    converts the result with ``array_api_strict.asarray`` before returning it.
    """

    # axp: the array_api_strict module, nxp: the numpy module,
    # rng: the wrapped numpy.random.Generator instance.
    __slots__ = ("axp", "nxp", "rng")

    def __init__(
        self,
        seed: int | bool | AArray | None = None,  # noqa: FBT001
    ) -> None:
        """
        Initialize the Generator.

        Parameters
        ----------
        seed
            Seed for the random number generator; forwarded unchanged to
            ``numpy.random.default_rng``.
        """
        # Both libraries are imported lazily, i.e. only when a Generator is
        # actually constructed.
        import numpy  # noqa: ICN001, PLC0415

        import array_api_strict  # noqa: PLC0415

        self.nxp = numpy
        self.axp = array_api_strict
        self.rng = numpy.random.default_rng(seed=seed)

    def random(
        self,
        size: Size = None,
        dtype: DTypeLike | None = None,
        out: AArray | None = None,
    ) -> AArray:
        """
        Return random floats in the half-open interval [0.0, 1.0).

        Parameters
        ----------
        size
            Output shape.
        dtype
            Desired data type. ``float64`` is used when omitted.
        out
            Optional output array, forwarded to the NumPy generator.

        Returns
        -------
            Array of random floats.
        """
        if dtype is None:
            dtype = self.nxp.float64
        samples = self.rng.random(size, dtype, out)  # type: ignore[arg-type]
        return self.axp.asarray(samples)

    def normal(
        self,
        loc: float | FloatArray = 0.0,
        scale: float | FloatArray = 1.0,
        size: Size = None,
    ) -> AArray:
        """
        Draw samples from a Normal distribution (mean=loc, stdev=scale).

        Parameters
        ----------
        loc
            Mean of the distribution.
        scale
            Standard deviation of the distribution.
        size
            Output shape.

        Returns
        -------
            Array of samples from the normal distribution.
        """
        samples = self.rng.normal(loc, scale, size)
        return self.axp.asarray(samples)

    def poisson(self, lam: float | AArray, size: Size = None) -> AArray:
        """
        Draw samples from a Poisson distribution.

        Parameters
        ----------
        lam
            Expected number of events.
        size
            Output shape.

        Returns
        -------
            Array of samples from the Poisson distribution.
        """
        samples = self.rng.poisson(lam, size)
        return self.axp.asarray(samples)

    def standard_normal(
        self,
        size: Size = None,
        dtype: DTypeLike | None = None,
        out: AArray | None = None,
    ) -> AArray:
        """
        Draw samples from a standard Normal distribution (mean=0, stdev=1).

        Parameters
        ----------
        size
            Output shape.
        dtype
            Desired data type. ``float64`` is used when omitted.
        out
            Optional output array, forwarded to the NumPy generator.

        Returns
        -------
            Array of samples from the standard normal distribution.
        """
        if dtype is None:
            dtype = self.nxp.float64
        samples = self.rng.standard_normal(size, dtype, out)  # type: ignore[arg-type]
        return self.axp.asarray(samples)

    def uniform(
        self,
        low: float | AArray = 0.0,
        high: float | AArray = 1.0,
        size: Size = None,
    ) -> AArray:
        """
        Draw samples from a Uniform distribution.

        Parameters
        ----------
        low
            Lower bound of the distribution.
        high
            Upper bound of the distribution.
        size
            Output shape.

        Returns
        -------
            Array of samples from the uniform distribution.
        """
        samples = self.rng.uniform(low, high, size)
        return self.axp.asarray(samples)
256

257

258
class XPAdditions:
    """
    Additional functions missing from both array-api-strict and array-api-extra.

    This class provides wrappers for common array operations such as integration,
    interpolation, and linear algebra, ensuring compatibility across NumPy, JAX,
    and array-api-strict backends. Each wrapper dispatches on ``self.xp.__name__``
    and raises ``NotImplementedError`` for any other backend.

    For array-api-strict the wrappers follow the design principle used by SciPy:
    copy the inputs to NumPy, compute with NumPy, and copy the result back.

    This is intended as a temporary solution. See https://github.com/glass-dev/glass/issues/645
    for details.
    """

    # The array namespace that every wrapper dispatches on.
    xp: ModuleType
    # Declared for type checkers; not assigned anywhere in this class.
    backend: str

    def __init__(self, xp: ModuleType) -> None:
        """
        Initialize XPAdditions with the given array namespace.

        Parameters
        ----------
        xp
            The array library backend to use for array operations.
        """
        self.xp = xp

    def _unsupported(self) -> NotImplementedError:
        """Build the error raised when ``self.xp`` is not a supported backend."""
        msg = "the array backend is not supported"
        return NotImplementedError(msg)

    def trapezoid(
        self,
        y: AnyArray,
        x: AnyArray | None = None,
        dx: float = 1.0,
        axis: int = -1,
    ) -> AnyArray:
        """
        Integrate along the given axis using the composite trapezoidal rule.

        Parameters
        ----------
        y
            Input array to integrate.
        x
            Sample points corresponding to y. If omitted, the samples are
            taken to be evenly spaced ``dx`` apart.
        dx
            Spacing between sample points.
        axis
            Axis along which to integrate.

        Returns
        -------
            Integrated result.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/646
        """
        name = self.xp.__name__

        if name == "jax.numpy":
            import glass.jax  # noqa: PLC0415

            return glass.jax.trapezoid(y, x=x, dx=dx, axis=axis)

        if name == "numpy":
            return self.xp.trapezoid(y, x=x, dx=dx, axis=axis)

        if name == "array_api_strict":
            np = import_numpy(name)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            y_np = np.asarray(y, copy=True)
            # ``x`` is optional: converting ``None`` would give a 0-d object
            # array that NumPy cannot use as sample points, so pass ``None``
            # through and let NumPy fall back to the ``dx`` spacing.
            x_np = None if x is None else np.asarray(x, copy=True)
            result_np = np.trapezoid(y_np, x_np, dx=dx, axis=axis)
            return self.xp.asarray(result_np, copy=True)

        raise self._unsupported()

    def union1d(self, ar1: AnyArray, ar2: AnyArray) -> AnyArray:
        """
        Compute the set union of two 1D arrays.

        Parameters
        ----------
        ar1
            First input array.
        ar2
            Second input array.

        Returns
        -------
            The union of the two arrays.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/647
        """
        name = self.xp.__name__

        if name in {"numpy", "jax.numpy"}:
            return self.xp.union1d(ar1, ar2)

        if name == "array_api_strict":
            np = import_numpy(name)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            ar1_np = np.asarray(ar1, copy=True)
            ar2_np = np.asarray(ar2, copy=True)
            result_np = np.union1d(ar1_np, ar2_np)
            return self.xp.asarray(result_np, copy=True)

        raise self._unsupported()

    def interp(  # noqa: PLR0913
        self,
        x: AnyArray,
        x_points: AnyArray,
        y_points: AnyArray,
        left: float | None = None,
        right: float | None = None,
        period: float | None = None,
    ) -> AnyArray:
        """
        One-dimensional linear interpolation for monotonically increasing sample points.

        Parameters
        ----------
        x
            The x-coordinates at which to evaluate the interpolated values.
        x_points
            The x-coordinates of the data points.
        y_points
            The y-coordinates of the data points.
        left
            Value to return for x < x_points[0].
        right
            Value to return for x > x_points[-1].
        period
            Period for periodic interpolation.

        Returns
        -------
            Interpolated values.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/650
        """
        name = self.xp.__name__

        if name in {"numpy", "jax.numpy"}:
            return self.xp.interp(
                x,
                x_points,
                y_points,
                left=left,
                right=right,
                period=period,
            )

        if name == "array_api_strict":
            np = import_numpy(name)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            x_np = np.asarray(x, copy=True)
            x_points_np = np.asarray(x_points, copy=True)
            y_points_np = np.asarray(y_points, copy=True)
            result_np = np.interp(
                x_np,
                x_points_np,
                y_points_np,
                left=left,
                right=right,
                period=period,
            )
            return self.xp.asarray(result_np, copy=True)

        raise self._unsupported()

    def gradient(self, f: AnyArray) -> AnyArray:
        """
        Return the gradient of an N-dimensional array.

        Parameters
        ----------
        f
            Input array.

        Returns
        -------
            Gradient of the input array.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/648
        """
        name = self.xp.__name__

        if name in {"numpy", "jax.numpy"}:
            return self.xp.gradient(f)

        if name == "array_api_strict":
            np = import_numpy(name)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            f_np = np.asarray(f, copy=True)
            result_np = np.gradient(f_np)
            return self.xp.asarray(result_np, copy=True)

        raise self._unsupported()

    def linalg_lstsq(
        self,
        a: AnyArray,
        b: AnyArray,
        rcond: float | None = None,
    ) -> tuple[AnyArray, AnyArray, int, AnyArray]:
        """
        Solve a linear least squares problem.

        Parameters
        ----------
        a
            Coefficient matrix.
        b
            Ordinate or "dependent variable" values.
        rcond
            Cut-off ratio for small singular values.

        Returns
        -------
        x
            Least-squares solution. If b is two-dimensional, the solutions are in the K
            columns of x.

        residuals
            Sums of squared residuals: Squared Euclidean 2-norm for each column in b - a
            @ x. If the rank of a is < N or M <= N, this is an empty array. If b is
            1-dimensional, this is a (1,) shape array. Otherwise the shape is (K,).

        rank
            Rank of matrix a.

        s
            Singular values of a.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/649
        """
        name = self.xp.__name__

        if name in {"numpy", "jax.numpy"}:
            return self.xp.linalg.lstsq(a, b, rcond=rcond)  # type: ignore[no-any-return]

        if name == "array_api_strict":
            np = import_numpy(name)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            a_np = np.asarray(a, copy=True)
            b_np = np.asarray(b, copy=True)
            result_np = np.linalg.lstsq(a_np, b_np, rcond=rcond)
            # NOTE(review): every element of NumPy's result, including the
            # rank, is passed through ``xp.asarray`` here, although the
            # annotated return type lists ``int`` for the rank - confirm that
            # callers expect this before changing it.
            return tuple(self.xp.asarray(res, copy=True) for res in result_np)

        raise self._unsupported()

    def einsum(self, subscripts: str, *operands: AnyArray) -> AnyArray:
        """
        Evaluate the Einstein summation convention on the operands.

        Parameters
        ----------
        subscripts
            Specifies the subscripts for summation.
        *operands
            Arrays to be summed.

        Returns
        -------
            Result of the Einstein summation.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/657
        """
        name = self.xp.__name__

        if name in {"numpy", "jax.numpy"}:
            return self.xp.einsum(subscripts, *operands)

        if name == "array_api_strict":
            np = import_numpy(name)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            operands_np = (np.asarray(op, copy=True) for op in operands)
            result_np = np.einsum(subscripts, *operands_np)
            return self.xp.asarray(result_np, copy=True)

        raise self._unsupported()

    def apply_along_axis(
        self,
        func1d: Callable[..., Any],
        axis: int,
        arr: AnyArray,
        *args: object,
        **kwargs: object,
    ) -> AnyArray:
        """
        Apply a function to 1-D slices along the given axis.

        Parameters
        ----------
        func1d
            Function to apply to 1-D slices.
        axis
            Axis along which to apply the function.
        arr
            Input array.
        *args
            Additional positional arguments to pass to func1d.
        **kwargs
            Additional keyword arguments to pass to func1d.

        Returns
        -------
            Result of applying the function along the axis.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/651
        """
        name = self.xp.__name__

        if name in {"numpy", "jax.numpy"}:
            return self.xp.apply_along_axis(func1d, axis, arr, *args, **kwargs)

        if name == "array_api_strict":
            # Import here to prevent users relying on numpy unless in this instance
            np = import_numpy(name)

            return self.xp.asarray(
                np.apply_along_axis(func1d, axis, arr, *args, **kwargs),
                copy=True,
            )

        raise self._unsupported()

    def vectorize(
        self,
        pyfunc: Callable[..., Any],
        otypes: tuple[type[float]],
    ) -> Callable[..., Any]:
        """
        Return an object that acts like pyfunc, but takes arrays as input.

        For every supported backend the returned callable is a NumPy
        ``vectorize`` object; for JAX and array-api-strict NumPy is imported
        for this purpose.

        Parameters
        ----------
        pyfunc
            Python function to vectorize.
        otypes
            Output types.

        Returns
        -------
            Vectorized function.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/671
        """
        name = self.xp.__name__

        if name == "numpy":
            return self.xp.vectorize(pyfunc, otypes=otypes)  # type: ignore[no-any-return]

        if name in {"array_api_strict", "jax.numpy"}:
            # Import here to prevent users relying on numpy unless in this instance
            np = import_numpy(name)

            return np.vectorize(pyfunc, otypes=otypes)  # type: ignore[no-any-return]

        raise self._unsupported()

    def radians(self, deg_arr: AnyArray) -> AnyArray:
        """
        Convert angles from degrees to radians.

        Parameters
        ----------
        deg_arr
            Array of angles in degrees.

        Returns
        -------
            Array of angles in radians.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.
        """
        name = self.xp.__name__

        if name in {"numpy", "jax.numpy"}:
            return self.xp.radians(deg_arr)

        if name == "array_api_strict":
            np = import_numpy(name)

            return self.xp.asarray(np.radians(deg_arr))

        raise self._unsupported()

    def degrees(self, deg_arr: AnyArray) -> AnyArray:
        """
        Convert angles from radians to degrees.

        Parameters
        ----------
        deg_arr
            Array of angles in radians. Despite its name this is the input in
            radians; the parameter name is kept for backward compatibility.

        Returns
        -------
            Array of angles in degrees.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.
        """
        name = self.xp.__name__

        if name in {"numpy", "jax.numpy"}:
            return self.xp.degrees(deg_arr)

        if name == "array_api_strict":
            np = import_numpy(name)

            return self.xp.asarray(np.degrees(deg_arr))

        raise self._unsupported()

    def ndindex(self, shape: tuple[int, ...]) -> np.ndindex:
        """
        Wrapper for numpy.ndindex.

        See relevant docs for details:
        - NumPy, https://numpy.org/doc/2.2/reference/generated/numpy.ndindex.html

        Parameters
        ----------
        shape
            The dimensions of the array to iterate over.

        Returns
        -------
            A NumPy ``ndindex`` iterator, whichever backend is in use.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.
        """
        name = self.xp.__name__

        if name == "numpy":
            return self.xp.ndindex(shape)  # type: ignore[no-any-return]

        if name in {"array_api_strict", "jax.numpy"}:
            np = import_numpy(name)

            return np.ndindex(shape)  # type: ignore[no-any-return]

        raise self._unsupported()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc