• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

glass-dev / glass / 18910519250

29 Oct 2025 02:01PM UTC coverage: 93.582% (-0.1%) from 93.689%
18910519250

Pull #729

github

web-flow
Merge b7e662b40 into c4cfa4a63
Pull Request #729: gh-644: create `UnifiedGenerator` type in its own file

213 of 215 branches covered (99.07%)

Branch coverage included in aggregate %.

1 of 3 new or added lines in 2 files covered. (33.33%)

54 existing lines in 6 files now uncovered.

1420 of 1530 relevant lines covered (92.81%)

7.41 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

84.88
/glass/_array_api_utils.py
1
"""
2
Array API Utilities for glass.
3
==============================
4

5
This module provides utility functions and classes for working with multiple array
6
backends in the glass project, including NumPy, JAX, and array-api-strict. It includes
7
functions for importing backends, determining array namespaces, dispatching random
8
number generators, and providing missing functionality for array-api-strict through the
9
XPAdditions class.
10

11
Classes and functions in this module help ensure consistent behavior and compatibility
12
across different array libraries, and provide wrappers for common operations such as
13
integration, interpolation, and linear algebra.
14

15
"""
16

17
from __future__ import annotations
8✔
18

19
from typing import TYPE_CHECKING, Any, TypeAlias
8✔
20

21
if TYPE_CHECKING:
22
    from collections.abc import Callable
23
    from types import ModuleType
24

25
    import numpy as np
26
    from jaxtyping import Array as JAXArray
27
    from numpy.typing import DTypeLike, NDArray
28

29
    from array_api_strict._array_object import Array as AArray
30

31
    from glass._types import UnifiedGenerator
32

33
    Size: TypeAlias = int | tuple[int, ...] | None
34

35
    AnyArray: TypeAlias = NDArray[Any] | JAXArray | AArray
36
    ComplexArray: TypeAlias = NDArray[np.complex128] | JAXArray | AArray
37
    DoubleArray: TypeAlias = NDArray[np.double] | JAXArray | AArray
38
    FloatArray: TypeAlias = NDArray[np.float64] | JAXArray | AArray
39
    IntArray: TypeAlias = NDArray[np.int_] | JAXArray | AArray
40

41

42
class CompatibleBackendNotFoundError(Exception):
    """
    Signal that a required array library is unavailable for the user's backend.

    The explanation is stored on ``message`` and is also passed to
    ``Exception`` so that ``str(err)`` shows the same text.
    """

    def __init__(self, missing_backend: str, users_backend: str) -> None:
        """
        Build the error message from the two backend names.

        Parameters
        ----------
        missing_backend
            Name of the library GLASS needs but could not import.
        users_backend
            Name of the backend the user chose, which lacks the needed function.
        """
        explanation = (
            f"{missing_backend} is required here as some functions required by GLASS "
            f"are not supported by {users_backend}"
        )
        self.message = explanation
        super().__init__(explanation)
54

55

56
def import_numpy(backend: str) -> ModuleType:
    """
    Import the NumPy module, raising a helpful error if NumPy is not installed.

    Parameters
    ----------
    backend
        The name of the backend requested by the user. It is only used to build
        the error message.

    Returns
    -------
        The NumPy module.

    Raises
    ------
    CompatibleBackendNotFoundError
        If NumPy is not found in the user's environment. The underlying
        ``ModuleNotFoundError`` is chained as the exception's cause.

    Notes
    -----
    This is useful for explaining to the user why NumPy is required when their chosen
    backend does not implement a needed function.
    """
    try:
        # Imported lazily so NumPy is only required on the code paths needing it.
        import numpy  # noqa: ICN001, PLC0415
    except ModuleNotFoundError as err:
        raise CompatibleBackendNotFoundError("numpy", backend) from err
    return numpy
86

87

88
def rng_dispatcher(array: AnyArray) -> UnifiedGenerator:
8✔
89
    """
90
    Dispatch a random number generator based on the provided array's backend.
91

92
    Parameters
93
    ----------
94
    array
95
        The array whose backend determines the RNG.
96

97
    Returns
98
    -------
99
        The appropriate random number generator for the array's backend.
100

101
    Raises
102
    ------
103
    NotImplementedError
104
        If the array backend is not supported.
105
    """
106
    xp = array.__array_namespace__()
8✔
107

108
    if xp.__name__ == "jax.numpy":
8✔
109
        import glass.jax  # noqa: PLC0415
8✔
110

111
        return glass.jax.Generator(seed=42)
8✔
112

113
    if xp.__name__ == "numpy":
8✔
114
        return xp.random.default_rng()  # type: ignore[no-any-return]
8✔
115

116
    if xp.__name__ == "array_api_strict":
8✔
117
        return Generator(seed=42)
8✔
118

UNCOV
119
    msg = "the array backend in not supported"
×
UNCOV
120
    raise NotImplementedError(msg)
×
121

122

123
class Generator:
    """
    Wrapper around ``numpy.random.Generator`` that yields array_api_strict arrays.

    Every sampling method draws with the wrapped NumPy generator and passes the
    result through ``array_api_strict.asarray`` before returning it.
    """

    __slots__ = ("axp", "nxp", "rng")

    def __init__(
        self,
        seed: int | bool | AArray | None = None,  # noqa: FBT001
    ) -> None:
        """
        Create the wrapped NumPy generator.

        Parameters
        ----------
        seed
            Passed unchanged to ``numpy.random.default_rng``.
        """
        import numpy as np  # noqa: PLC0415

        import array_api_strict  # noqa: PLC0415

        self.nxp = np
        self.axp = array_api_strict
        self.rng = np.random.default_rng(seed=seed)

    def random(
        self,
        size: Size = None,
        dtype: DTypeLike | None = None,
        out: AArray | None = None,
    ) -> AArray:
        """
        Sample floats uniformly from the half-open interval [0.0, 1.0).

        Parameters
        ----------
        size
            Output shape.
        dtype
            Desired data type; ``float64`` when not given.
        out
            Optional output array, forwarded to NumPy.

        Returns
        -------
            Array of random floats.
        """
        if dtype is None:
            dtype = self.nxp.float64
        samples = self.rng.random(size, dtype, out)  # type: ignore[arg-type]
        return self.axp.asarray(samples)

    def normal(
        self,
        loc: float | AArray = 0.0,
        scale: float | AArray = 1.0,
        size: Size = None,
    ) -> AArray:
        """
        Sample from a Normal distribution with mean ``loc`` and stdev ``scale``.

        Parameters
        ----------
        loc
            Mean of the distribution.
        scale
            Standard deviation of the distribution.
        size
            Output shape.

        Returns
        -------
            Array of normally distributed samples.
        """
        samples = self.rng.normal(loc, scale, size)
        return self.axp.asarray(samples)

    def poisson(self, lam: float | AArray, size: Size = None) -> AArray:
        """
        Sample from a Poisson distribution.

        Parameters
        ----------
        lam
            Expected number of events.
        size
            Output shape.

        Returns
        -------
            Array of Poisson-distributed samples.
        """
        samples = self.rng.poisson(lam, size)
        return self.axp.asarray(samples)

    def standard_normal(
        self,
        size: Size = None,
        dtype: DTypeLike | None = None,
        out: AArray | None = None,
    ) -> AArray:
        """
        Sample from the standard Normal distribution (mean 0, stdev 1).

        Parameters
        ----------
        size
            Output shape.
        dtype
            Desired data type; ``float64`` when not given.
        out
            Optional output array, forwarded to NumPy.

        Returns
        -------
            Array of standard-normal samples.
        """
        if dtype is None:
            dtype = self.nxp.float64
        samples = self.rng.standard_normal(size, dtype, out)  # type: ignore[arg-type]
        return self.axp.asarray(samples)

    def uniform(
        self,
        low: float | AArray = 0.0,
        high: float | AArray = 1.0,
        size: Size = None,
    ) -> AArray:
        """
        Sample from a Uniform distribution over [low, high).

        Parameters
        ----------
        low
            Lower bound of the distribution.
        high
            Upper bound of the distribution.
        size
            Output shape.

        Returns
        -------
            Array of uniformly distributed samples.
        """
        samples = self.rng.uniform(low, high, size)
        return self.axp.asarray(samples)
267

268

269
class XPAdditions:
8✔
270
    """
271
    Additional functions missing from both array-api-strict and array-api-extra.
272

273
    This class provides wrappers for common array operations such as integration,
274
    interpolation, and linear algebra, ensuring compatibility across NumPy, JAX,
275
    and array-api-strict backends.
276

277
    This is intended as a temporary solution. See https://github.com/glass-dev/glass/issues/645
278
    for details.
279
    """
280

281
    xp: ModuleType
8✔
282
    backend: str
8✔
283

284
    def __init__(self, xp: ModuleType) -> None:
8✔
285
        """
286
        Initialize XPAdditions with the given array namespace.
287

288
        Parameters
289
        ----------
290
        xp
291
            The array namespace module.
292
        """
293
        self.xp = xp
8✔
294

295
    def trapezoid(
8✔
296
        self, y: AnyArray, x: AnyArray = None, dx: float = 1.0, axis: int = -1
297
    ) -> AnyArray:
298
        """
299
        Integrate along the given axis using the composite trapezoidal rule.
300

301
        Parameters
302
        ----------
303
        y
304
            Input array to integrate.
305
        x
306
            Sample points corresponding to y.
307
        dx
308
            Spacing between sample points.
309
        axis
310
            Axis along which to integrate.
311

312
        Returns
313
        -------
314
            Integrated result.
315

316
        Raises
317
        ------
318
        NotImplementedError
319
            If the array backend is not supported.
320

321
        Notes
322
        -----
323
        See https://github.com/glass-dev/glass/issues/646
324
        """
325
        if self.xp.__name__ == "jax.numpy":
8✔
326
            import glass.jax  # noqa: PLC0415
8✔
327

328
            return glass.jax.trapezoid(y, x=x, dx=dx, axis=axis)
8✔
329

330
        if self.xp.__name__ == "numpy":
8✔
331
            return self.xp.trapezoid(y, x=x, dx=dx, axis=axis)
8✔
332

333
        if self.xp.__name__ == "array_api_strict":
8✔
334
            np = import_numpy(self.xp.__name__)
8✔
335

336
            # Using design principle of scipy (i.e. copy, use np, copy back)
337
            y_np = np.asarray(y, copy=True)
8✔
338
            x_np = np.asarray(x, copy=True)
8✔
339
            result_np = np.trapezoid(y_np, x_np, dx=dx, axis=axis)
8✔
340
            return self.xp.asarray(result_np, copy=True)
8✔
341

UNCOV
342
        msg = "the array backend in not supported"
×
UNCOV
343
        raise NotImplementedError(msg)
×
344

345
    def union1d(self, ar1: AnyArray, ar2: AnyArray) -> AnyArray:
8✔
346
        """
347
        Compute the set union of two 1D arrays.
348

349
        Parameters
350
        ----------
351
        ar1
352
            First input array.
353
        ar2
354
            Second input array.
355

356
        Returns
357
        -------
358
            The union of the two arrays.
359

360
        Raises
361
        ------
362
        NotImplementedError
363
            If the array backend is not supported.
364

365
        Notes
366
        -----
367
        See https://github.com/glass-dev/glass/issues/647
368
        """
369
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
8✔
370
            return self.xp.union1d(ar1, ar2)
8✔
371

372
        if self.xp.__name__ == "array_api_strict":
8✔
373
            np = import_numpy(self.xp.__name__)
8✔
374

375
            # Using design principle of scipy (i.e. copy, use np, copy back)
376
            ar1_np = np.asarray(ar1, copy=True)
8✔
377
            ar2_np = np.asarray(ar2, copy=True)
8✔
378
            result_np = np.union1d(ar1_np, ar2_np)
8✔
379
            return self.xp.asarray(result_np, copy=True)
8✔
380

UNCOV
381
        msg = "the array backend in not supported"
×
UNCOV
382
        raise NotImplementedError(msg)
×
383

384
    def interp(  # noqa: PLR0913
8✔
385
        self,
386
        x: AnyArray,
387
        x_points: AnyArray,
388
        y_points: AnyArray,
389
        left: float | None = None,
390
        right: float | None = None,
391
        period: float | None = None,
392
    ) -> AnyArray:
393
        """
394
        One-dimensional linear interpolation for monotonically increasing sample points.
395

396
        Parameters
397
        ----------
398
        x
399
            The x-coordinates at which to evaluate the interpolated values.
400
        x_points
401
            The x-coordinates of the data points.
402
        y_points
403
            The y-coordinates of the data points.
404
        left
405
            Value to return for x < x_points[0].
406
        right
407
            Value to return for x > x_points[-1].
408
        period
409
            Period for periodic interpolation.
410

411
        Returns
412
        -------
413
            Interpolated values.
414

415
        Raises
416
        ------
417
        NotImplementedError
418
            If the array backend is not supported.
419

420
        Notes
421
        -----
422
        See https://github.com/glass-dev/glass/issues/650
423
        """
424
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
8✔
425
            return self.xp.interp(
8✔
426
                x, x_points, y_points, left=left, right=right, period=period
427
            )
428

429
        if self.xp.__name__ == "array_api_strict":
8✔
430
            np = import_numpy(self.xp.__name__)
8✔
431

432
            # Using design principle of scipy (i.e. copy, use np, copy back)
433
            x_np = np.asarray(x, copy=True)
8✔
434
            x_points_np = np.asarray(x_points, copy=True)
8✔
435
            y_points_np = np.asarray(y_points, copy=True)
8✔
436
            result_np = np.interp(
8✔
437
                x_np, x_points_np, y_points_np, left=left, right=right, period=period
438
            )
439
            return self.xp.asarray(result_np, copy=True)
8✔
440

UNCOV
441
        msg = "the array backend in not supported"
×
UNCOV
442
        raise NotImplementedError(msg)
×
443

444
    def gradient(self, f: AnyArray) -> AnyArray:
8✔
445
        """
446
        Return the gradient of an N-dimensional array.
447

448
        Parameters
449
        ----------
450
        f
451
            Input array.
452

453
        Returns
454
        -------
455
            Gradient of the input array.
456

457
        Raises
458
        ------
459
        NotImplementedError
460
            If the array backend is not supported.
461

462
        Notes
463
        -----
464
        See https://github.com/glass-dev/glass/issues/648
465
        """
466
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
8✔
467
            return self.xp.gradient(f)
8✔
468

469
        if self.xp.__name__ == "array_api_strict":
8✔
470
            np = import_numpy(self.xp.__name__)
8✔
471

472
            # Using design principle of scipy (i.e. copy, use np, copy back)
473
            f_np = np.asarray(f, copy=True)
8✔
474
            result_np = np.gradient(f_np)
8✔
475
            return self.xp.asarray(result_np, copy=True)
8✔
476

UNCOV
477
        msg = "the array backend in not supported"
×
UNCOV
478
        raise NotImplementedError(msg)
×
479

480
    def linalg_lstsq(
8✔
481
        self, a: AnyArray, b: AnyArray, rcond: float | None = None
482
    ) -> tuple[AnyArray, AnyArray, int, AnyArray]:
483
        """
484
        Solve a linear least squares problem.
485

486
        Parameters
487
        ----------
488
        a
489
            Coefficient matrix.
490
        b
491
            Ordinate or "dependent variable" values.
492
        rcond
493
            Cut-off ratio for small singular values.
494

495
        Returns
496
        -------
497
        x
498
            Least-squares solution. If b is two-dimensional, the solutions are in the K
499
            columns of x.
500

501
        residuals
502
            Sums of squared residuals: Squared Euclidean 2-norm for each column in b - a
503
            @ x. If the rank of a is < N or M <= N, this is an empty array. If b is
504
            1-dimensional, this is a (1,) shape array. Otherwise the shape is (K,).
505

506
        rank
507
            Rank of matrix a.
508

509
        s
510
            Singular values of a.
511

512
        Raises
513
        ------
514
        NotImplementedError
515
            If the array backend is not supported.
516

517
        Notes
518
        -----
519
        See https://github.com/glass-dev/glass/issues/649
520
        """
521
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
8✔
522
            return self.xp.linalg.lstsq(a, b, rcond=rcond)  # type: ignore[no-any-return]
8✔
523

524
        if self.xp.__name__ == "array_api_strict":
8✔
525
            np = import_numpy(self.xp.__name__)
8✔
526

527
            # Using design principle of scipy (i.e. copy, use np, copy back)
528
            a_np = np.asarray(a, copy=True)
8✔
529
            b_np = np.asarray(b, copy=True)
8✔
530
            result_np = np.linalg.lstsq(a_np, b_np, rcond=rcond)
8✔
531
            return tuple(self.xp.asarray(res, copy=True) for res in result_np)
8✔
532

UNCOV
533
        msg = "the array backend in not supported"
×
UNCOV
534
        raise NotImplementedError(msg)
×
535

536
    def einsum(self, subscripts: str, *operands: AnyArray) -> AnyArray:
8✔
537
        """
538
        Evaluate the Einstein summation convention on the operands.
539

540
        Parameters
541
        ----------
542
        subscripts
543
            Specifies the subscripts for summation.
544
        *operands
545
            Arrays to be summed.
546

547
        Returns
548
        -------
549
            Result of the Einstein summation.
550

551
        Raises
552
        ------
553
        NotImplementedError
554
            If the array backend is not supported.
555

556
        Notes
557
        -----
558
        See https://github.com/glass-dev/glass/issues/657
559
        """
560
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
8✔
561
            return self.xp.einsum(subscripts, *operands)
8✔
562

563
        if self.xp.__name__ == "array_api_strict":
8✔
564
            np = import_numpy(self.xp.__name__)
8✔
565

566
            # Using design principle of scipy (i.e. copy, use np, copy back)
567
            operands_np = (np.asarray(op, copy=True) for op in operands)
8✔
568
            result_np = np.einsum(subscripts, *operands_np)
8✔
569
            return self.xp.asarray(result_np, copy=True)
8✔
570

UNCOV
571
        msg = "the array backend in not supported"
×
UNCOV
572
        raise NotImplementedError(msg)
×
573

574
    def apply_along_axis(
8✔
575
        self,
576
        func1d: Callable[..., Any],
577
        axis: int,
578
        arr: AnyArray,
579
        *args: object,
580
        **kwargs: object,
581
    ) -> AnyArray:
582
        """
583
        Apply a function to 1-D slices along the given axis.
584

585
        Parameters
586
        ----------
587
        func1d
588
            Function to apply to 1-D slices.
589
        axis
590
            Axis along which to apply the function.
591
        arr
592
            Input array.
593
        *args
594
            Additional positional arguments to pass to func1d.
595
        **kwargs
596
            Additional keyword arguments to pass to func1d.
597

598
        Returns
599
        -------
600
            Result of applying the function along the axis.
601

602
        Raises
603
        ------
604
        NotImplementedError
605
            If the array backend is not supported.
606

607
        Notes
608
        -----
609
        See https://github.com/glass-dev/glass/issues/651
610

611
        """
612
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
8✔
613
            return self.xp.apply_along_axis(func1d, axis, arr, *args, **kwargs)
8✔
614

615
        if self.xp.__name__ == "array_api_strict":
8✔
616
            # Import here to prevent users relying on numpy unless in this instance
617
            np = import_numpy(self.xp.__name__)
8✔
618

619
            return self.xp.asarray(
8✔
620
                np.apply_along_axis(func1d, axis, arr, *args, **kwargs), copy=True
621
            )
622

UNCOV
623
        msg = "the array backend in not supported"
×
UNCOV
624
        raise NotImplementedError(msg)
×
625

626
    def vectorize(
8✔
627
        self,
628
        pyfunc: Callable[..., Any],
629
        otypes: tuple[type[float]],
630
    ) -> Callable[..., Any]:
631
        """
632
        Returns an object that acts like pyfunc, but takes arrays as input.
633

634
        Parameters
635
        ----------
636
        pyfunc
637
            Python function to vectorize.
638
        otypes
639
            Output types.
640

641
        Returns
642
        -------
643
            Vectorized function.
644

645
        Raises
646
        ------
647
        NotImplementedError
648
            If the array backend is not supported.
649

650
        Notes
651
        -----
652
        See https://github.com/glass-dev/glass/issues/671
653
        """
654
        if self.xp.__name__ == "numpy":
8✔
655
            return self.xp.vectorize(pyfunc, otypes=otypes)  # type: ignore[no-any-return]
8✔
656

657
        if self.xp.__name__ in {"array_api_strict", "jax.numpy"}:
8✔
658
            # Import here to prevent users relying on numpy unless in this instance
659
            np = import_numpy(self.xp.__name__)
8✔
660

661
            return np.vectorize(pyfunc, otypes=otypes)  # type: ignore[no-any-return]
8✔
662

UNCOV
663
        msg = "the array backend in not supported"
×
UNCOV
664
        raise NotImplementedError(msg)
×
665

666
    def radians(self, deg_arr: AnyArray) -> AnyArray:
8✔
667
        """
668
        Convert angles from degrees to radians.
669

670
        Parameters
671
        ----------
672
        deg_arr
673
            Array of angles in degrees.
674

675
        Returns
676
        -------
677
            Array of angles in radians.
678

679
        Raises
680
        ------
681
        NotImplementedError
682
            If the array backend is not supported.
683
        """
684
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
8✔
685
            return self.xp.radians(deg_arr)
8✔
686

687
        if self.xp.__name__ == "array_api_strict":
8✔
688
            np = import_numpy(self.xp.__name__)
8✔
689

690
            return self.xp.asarray(np.radians(deg_arr))
8✔
691

UNCOV
692
        msg = "the array backend in not supported"
×
UNCOV
693
        raise NotImplementedError(msg)
×
694

695
    def degrees(self, deg_arr: AnyArray) -> AnyArray:
8✔
696
        """
697
        Convert angles from radians to degrees.
698

699
        Parameters
700
        ----------
701
        deg_arr
702
            Array of angles in radians.
703

704
        Returns
705
        -------
706
            Array of angles in degrees.
707

708
        Raises
709
        ------
710
        NotImplementedError
711
            If the array backend is not supported.
712
        """
713
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
8✔
714
            return self.xp.degrees(deg_arr)
8✔
715

716
        if self.xp.__name__ == "array_api_strict":
8✔
717
            np = import_numpy(self.xp.__name__)
8✔
718

719
            return self.xp.asarray(np.degrees(deg_arr))
8✔
720

UNCOV
721
        msg = "the array backend in not supported"
×
UNCOV
722
        raise NotImplementedError(msg)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc