• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

glass-dev / glass / 19109688739

05 Nov 2025 04:54PM UTC coverage: 93.548% (+0.2%) from 93.341%
19109688739

Pull #746

github

web-flow
Merge cbf034bc0 into 433f7fdd4
Pull Request #746: gh-417: Improve type rendering in the documentation

220 of 222 branches covered (99.1%)

Branch coverage included in aggregate %.

8 of 8 new or added lines in 4 files covered. (100.0%)

25 existing lines in 2 files now uncovered.

1462 of 1576 relevant lines covered (92.77%)

7.41 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

84.53
/glass/_array_api_utils.py
1
"""
2
Array API Utilities for glass.
3
============================
4

5
This module provides utility functions and classes for working with multiple array
6
backends in the glass project, including NumPy, JAX, and array-api-strict. It includes
7
functions for importing backends, determining array namespaces, dispatching random
8
number generators, and providing missing functionality for array-api-strict through the
9
XPAdditions class.
10

11
Classes and functions in this module help ensure consistent behavior and compatibility
12
across different array libraries, and provide wrappers for common operations such as
13
integration, interpolation, and linear algebra.
14

15
"""
16

17
from __future__ import annotations
8✔
18

19
from typing import TYPE_CHECKING, Any
8✔
20

21
if TYPE_CHECKING:
22
    from collections.abc import Callable
23
    from types import ModuleType
24

25
    import numpy as np
26
    from numpy.typing import DTypeLike
27

28
    from array_api_strict._array_object import Array as AArray
29

30
    from glass._types import AnyArray, FloatArray, UnifiedGenerator
31

32

33
class CompatibleBackendNotFoundError(Exception):
    """
    Signal that no available array library implements a function GLASS needs.

    The message names both the library that would be required and the backend
    the user actually selected, and is also stored on the ``message`` attribute.
    """

    def __init__(self, missing_backend: str, users_backend: str) -> None:
        """
        Build the error message from the two backend names.

        Parameters
        ----------
        missing_backend
            Name of the library that provides the required functionality.
        users_backend
            Name of the backend the user is running with.
        """
        message_parts = (
            f"{missing_backend} is required here as some functions required by GLASS ",
            f"are not supported by {users_backend}",
        )
        self.message = "".join(message_parts)
        super().__init__(self.message)
45

46

47
def import_numpy(backend: str) -> ModuleType:
    """
    Import the NumPy module, raising a helpful error if NumPy is not installed.

    Parameters
    ----------
    backend
        The name of the backend requested by the user. It is only used to build
        the error message.

    Returns
    -------
        The NumPy module.

    Raises
    ------
    CompatibleBackendNotFoundError
        If NumPy is not found in the user's environment. The underlying
        :class:`ModuleNotFoundError` is chained as the ``__cause__``.

    Notes
    -----
    This is useful for explaining to the user why NumPy is required when their chosen
    backend does not implement a needed function.
    """
    try:
        import numpy  # noqa: ICN001, PLC0415

    except ModuleNotFoundError as err:
        # Re-raise with a message naming the user's backend, keeping the
        # original import failure as the cause.
        raise CompatibleBackendNotFoundError("numpy", backend) from err
    else:
        return numpy
77

78

79
def rng_dispatcher(*, xp: ModuleType) -> UnifiedGenerator:
    """
    Select a random number generator matching the given array namespace.

    Parameters
    ----------
    xp
        The array library backend (namespace module) to use for array operations.

    Returns
    -------
        The generator for that backend:

        - NumPy: ``numpy.random.default_rng()`` (unseeded),
        - array-api-strict: this module's :class:`Generator` seeded with ``42``,
        - JAX: ``glass.jax.Generator`` seeded with ``42``.

    Raises
    ------
    NotImplementedError
        If the array backend is not supported.
    """
    backend_name = xp.__name__

    if backend_name == "numpy":
        return xp.random.default_rng()  # type: ignore[no-any-return]

    if backend_name == "array_api_strict":
        return Generator(seed=42)

    if backend_name == "jax.numpy":
        import glass.jax  # noqa: PLC0415

        return glass.jax.Generator(seed=42)

    msg = "the array backend in not supported"
    raise NotImplementedError(msg)
110

111

112
class Generator:
    """
    Wrapper around NumPy's random number generator that returns array-api-strict arrays.

    Samples are drawn with a :class:`numpy.random.Generator`. Every result is passed
    through ``array_api_strict.asarray`` before it is returned.
    """

    __slots__ = ("axp", "nxp", "rng")

    def __init__(
        self,
        seed: int | bool | AArray | None = None,  # noqa: FBT001
    ) -> None:
        """
        Set up the wrapped NumPy generator.

        Parameters
        ----------
        seed
            Value forwarded as ``seed`` to :func:`numpy.random.default_rng`.
        """
        import numpy  # noqa: ICN001, PLC0415

        import array_api_strict  # noqa: PLC0415

        self.nxp = numpy
        self.axp = array_api_strict
        self.rng = numpy.random.default_rng(seed=seed)

    def random(
        self,
        size: int | tuple[int, ...] | None = None,
        dtype: DTypeLike | None = None,
        out: AArray | None = None,
    ) -> AArray:
        """
        Sample floats uniformly from the half-open interval [0.0, 1.0).

        Parameters
        ----------
        size
            Shape of the returned array.
        dtype
            Floating point type of the samples; ``None`` selects ``float64``.
        out
            Forwarded unchanged to NumPy's ``Generator.random``.

        Returns
        -------
            Array of uniformly distributed floats.
        """
        if dtype is None:
            dtype = self.nxp.float64
        draws = self.rng.random(size, dtype, out)  # type: ignore[arg-type]
        return self.axp.asarray(draws)

    def normal(
        self,
        loc: float | FloatArray = 0.0,
        scale: float | FloatArray = 1.0,
        size: int | tuple[int, ...] | None = None,
    ) -> AArray:
        """
        Sample from a Gaussian with mean ``loc`` and standard deviation ``scale``.

        Parameters
        ----------
        loc
            Mean of the distribution.
        scale
            Standard deviation of the distribution.
        size
            Shape of the returned array.

        Returns
        -------
            Array of normally distributed samples.
        """
        draws = self.rng.normal(loc, scale, size)
        return self.axp.asarray(draws)

    def poisson(
        self, lam: float | AArray, size: int | tuple[int, ...] | None = None
    ) -> AArray:
        """
        Sample from a Poisson distribution.

        Parameters
        ----------
        lam
            Expected number of events.
        size
            Shape of the returned array.

        Returns
        -------
            Array of Poisson distributed samples.
        """
        draws = self.rng.poisson(lam, size)
        return self.axp.asarray(draws)

    def standard_normal(
        self,
        size: int | tuple[int, ...] | None = None,
        dtype: DTypeLike | None = None,
        out: AArray | None = None,
    ) -> AArray:
        """
        Sample from a Gaussian with zero mean and unit standard deviation.

        Parameters
        ----------
        size
            Shape of the returned array.
        dtype
            Floating point type of the samples; ``None`` selects ``float64``.
        out
            Forwarded unchanged to NumPy's ``Generator.standard_normal``.

        Returns
        -------
            Array of standard normal samples.
        """
        if dtype is None:
            dtype = self.nxp.float64
        draws = self.rng.standard_normal(size, dtype, out)  # type: ignore[arg-type]
        return self.axp.asarray(draws)

    def uniform(
        self,
        low: float | AArray = 0.0,
        high: float | AArray = 1.0,
        size: int | tuple[int, ...] | None = None,
    ) -> AArray:
        """
        Sample uniformly between ``low`` and ``high``.

        Parameters
        ----------
        low
            Lower bound of the distribution.
        high
            Upper bound of the distribution.
        size
            Shape of the returned array.

        Returns
        -------
            Array of uniformly distributed samples.
        """
        draws = self.rng.uniform(low, high, size)
        return self.axp.asarray(draws)
258

259

260
class XPAdditions:
    """
    Additional functions missing from both array-api-strict and array-api-extra.

    This class provides wrappers for common array operations such as integration,
    interpolation, and linear algebra, ensuring compatibility across NumPy, JAX,
    and array-api-strict backends.

    Every method dispatches on ``xp.__name__``. Array-api-strict inputs are copied
    to NumPy, processed there, and copied back.

    This is intended as a temporary solution. See https://github.com/glass-dev/glass/issues/645
    for details.
    """

    xp: ModuleType
    # NOTE(review): ``backend`` is declared but never assigned by this class.
    backend: str

    def __init__(self, xp: ModuleType) -> None:
        """
        Initialize XPAdditions with the given array namespace.

        Parameters
        ----------
        xp
            The array library backend to use for array operations.
        """
        self.xp = xp

    def trapezoid(
        self,
        y: AnyArray,
        x: AnyArray | None = None,
        dx: float = 1.0,
        axis: int = -1,
    ) -> AnyArray:
        """
        Integrate along the given axis using the composite trapezoidal rule.

        Parameters
        ----------
        y
            Input array to integrate.
        x
            Sample points corresponding to y. If None (the default), the sample
            points are taken to be ``dx`` apart.
        dx
            Spacing between sample points, used when ``x`` is None.
        axis
            Axis along which to integrate.

        Returns
        -------
            Integrated result.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/646
        """
        if self.xp.__name__ == "jax.numpy":
            import glass.jax  # noqa: PLC0415

            return glass.jax.trapezoid(y, x=x, dx=dx, axis=axis)

        if self.xp.__name__ == "numpy":
            return self.xp.trapezoid(y, x=x, dx=dx, axis=axis)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            y_np = np.asarray(y, copy=True)
            # Keep ``None`` as-is: ``np.asarray(None)`` would be a 0-d object
            # array that ``np.trapezoid`` treats as sample points and rejects,
            # instead of falling back to the ``dx`` spacing.
            x_np = None if x is None else np.asarray(x, copy=True)
            result_np = np.trapezoid(y_np, x=x_np, dx=dx, axis=axis)
            return self.xp.asarray(result_np, copy=True)

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def union1d(self, ar1: AnyArray, ar2: AnyArray) -> AnyArray:
        """
        Compute the set union of two 1D arrays.

        Parameters
        ----------
        ar1
            First input array.
        ar2
            Second input array.

        Returns
        -------
            The union of the two arrays.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/647
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.union1d(ar1, ar2)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            ar1_np = np.asarray(ar1, copy=True)
            ar2_np = np.asarray(ar2, copy=True)
            result_np = np.union1d(ar1_np, ar2_np)
            return self.xp.asarray(result_np, copy=True)

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def interp(  # noqa: PLR0913
        self,
        x: AnyArray,
        x_points: AnyArray,
        y_points: AnyArray,
        left: float | None = None,
        right: float | None = None,
        period: float | None = None,
    ) -> AnyArray:
        """
        One-dimensional linear interpolation for monotonically increasing sample points.

        Parameters
        ----------
        x
            The x-coordinates at which to evaluate the interpolated values.
        x_points
            The x-coordinates of the data points.
        y_points
            The y-coordinates of the data points.
        left
            Value to return for x < x_points[0].
        right
            Value to return for x > x_points[-1].
        period
            Period for periodic interpolation.

        Returns
        -------
            Interpolated values.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/650
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.interp(
                x,
                x_points,
                y_points,
                left=left,
                right=right,
                period=period,
            )

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            x_np = np.asarray(x, copy=True)
            x_points_np = np.asarray(x_points, copy=True)
            y_points_np = np.asarray(y_points, copy=True)
            result_np = np.interp(
                x_np,
                x_points_np,
                y_points_np,
                left=left,
                right=right,
                period=period,
            )
            return self.xp.asarray(result_np, copy=True)

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def gradient(self, f: AnyArray) -> AnyArray:
        """
        Return the gradient of an N-dimensional array.

        Parameters
        ----------
        f
            Input array.

        Returns
        -------
            Gradient of the input array.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/648
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.gradient(f)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            f_np = np.asarray(f, copy=True)
            result_np = np.gradient(f_np)
            return self.xp.asarray(result_np, copy=True)

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def linalg_lstsq(
        self,
        a: AnyArray,
        b: AnyArray,
        rcond: float | None = None,
    ) -> tuple[AnyArray, AnyArray, int, AnyArray]:
        """
        Solve a linear least squares problem.

        Parameters
        ----------
        a
            Coefficient matrix.
        b
            Ordinate or "dependent variable" values.
        rcond
            Cut-off ratio for small singular values.

        Returns
        -------
        x
            Least-squares solution. If b is two-dimensional, the solutions are in the K
            columns of x.

        residuals
            Sums of squared residuals: Squared Euclidean 2-norm for each column in b - a
            @ x. If the rank of a is < N or M <= N, this is an empty array. If b is
            1-dimensional, this is a (1,) shape array. Otherwise the shape is (K,).

        rank
            Rank of matrix a, as a Python int.

        s
            Singular values of a.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/649
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.linalg.lstsq(a, b, rcond=rcond)  # type: ignore[no-any-return]

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            a_np = np.asarray(a, copy=True)
            b_np = np.asarray(b, copy=True)
            x_np, residuals_np, rank, s_np = np.linalg.lstsq(a_np, b_np, rcond=rcond)
            # Only the array results are converted back. ``rank`` stays a Python
            # int, as the declared return type requires.
            return (
                self.xp.asarray(x_np, copy=True),
                self.xp.asarray(residuals_np, copy=True),
                int(rank),
                self.xp.asarray(s_np, copy=True),
            )

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def einsum(self, subscripts: str, *operands: AnyArray) -> AnyArray:
        """
        Evaluate the Einstein summation convention on the operands.

        Parameters
        ----------
        subscripts
            Specifies the subscripts for summation.
        *operands
            Arrays to be summed.

        Returns
        -------
            Result of the Einstein summation.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/657
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.einsum(subscripts, *operands)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            operands_np = (np.asarray(op, copy=True) for op in operands)
            result_np = np.einsum(subscripts, *operands_np)
            return self.xp.asarray(result_np, copy=True)

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def apply_along_axis(
        self,
        func1d: Callable[..., Any],
        axis: int,
        arr: AnyArray,
        *args: object,
        **kwargs: object,
    ) -> AnyArray:
        """
        Apply a function to 1-D slices along the given axis.

        Parameters
        ----------
        func1d
            Function to apply to 1-D slices.
        axis
            Axis along which to apply the function.
        arr
            Input array.
        *args
            Additional positional arguments to pass to func1d.
        **kwargs
            Additional keyword arguments to pass to func1d.

        Returns
        -------
            Result of applying the function along the axis.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/651

        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.apply_along_axis(func1d, axis, arr, *args, **kwargs)

        if self.xp.__name__ == "array_api_strict":
            # Import here to prevent users relying on numpy unless in this instance
            np = import_numpy(self.xp.__name__)

            return self.xp.asarray(
                np.apply_along_axis(func1d, axis, arr, *args, **kwargs),
                copy=True,
            )

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def vectorize(
        self,
        pyfunc: Callable[..., Any],
        otypes: tuple[type[float]],
    ) -> Callable[..., Any]:
        """
        Returns an object that acts like pyfunc, but takes arrays as input.

        Parameters
        ----------
        pyfunc
            Python function to vectorize.
        otypes
            Output types.

        Returns
        -------
            Vectorized function. NumPy's ``vectorize`` is used for every
            supported backend, including JAX and array-api-strict.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/671
        """
        if self.xp.__name__ == "numpy":
            return self.xp.vectorize(pyfunc, otypes=otypes)  # type: ignore[no-any-return]

        if self.xp.__name__ in {"array_api_strict", "jax.numpy"}:
            # Import here to prevent users relying on numpy unless in this instance
            np = import_numpy(self.xp.__name__)

            return np.vectorize(pyfunc, otypes=otypes)  # type: ignore[no-any-return]

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def radians(self, deg_arr: AnyArray) -> AnyArray:
        """
        Convert angles from degrees to radians.

        Parameters
        ----------
        deg_arr
            Array of angles in degrees.

        Returns
        -------
            Array of angles in radians.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.radians(deg_arr)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            return self.xp.asarray(np.radians(deg_arr))

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def degrees(self, deg_arr: AnyArray) -> AnyArray:
        """
        Convert angles from radians to degrees.

        Parameters
        ----------
        deg_arr
            Array of angles in radians (despite the parameter name).

        Returns
        -------
            Array of angles in degrees.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.degrees(deg_arr)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            return self.xp.asarray(np.degrees(deg_arr))

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def ndindex(self, shape: tuple[int, ...]) -> np.ndindex:
        """
        Wrapper for numpy.ndindex.

        See relevant docs for details:
        - NumPy, https://numpy.org/doc/2.2/reference/generated/numpy.ndindex.html

        Parameters
        ----------
        shape
            Shape whose index tuples are iterated.

        Returns
        -------
            A NumPy ``ndindex`` iterator. It is NumPy's for every supported
            backend, since JAX and array-api-strict are routed through NumPy.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        """
        if self.xp.__name__ == "numpy":
            return self.xp.ndindex(shape)  # type: ignore[no-any-return]

        if self.xp.__name__ in {"array_api_strict", "jax.numpy"}:
            np = import_numpy(self.xp.__name__)

            return np.ndindex(shape)  # type: ignore[no-any-return]

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc