• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

glass-dev / glass / 18994847882

31 Oct 2025 07:30PM UTC coverage: 93.414% (+0.2%) from 93.208%
18994847882

Pull #746

github

web-flow
Merge db06ad960 into 13f6b6e52
Pull Request #746: gh-417: Improve type rendering in the documentation

223 of 225 branches covered (99.11%)

Branch coverage included in aggregate %.

8 of 9 new or added lines in 4 files covered. (88.89%)

26 existing lines in 2 files now uncovered.

1465 of 1582 relevant lines covered (92.6%)

7.39 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

84.53
/glass/_array_api_utils.py
1
"""
2
Array API Utilities for glass.
3
==============================
4

5
This module provides utility functions and classes for working with multiple array
6
backends in the glass project, including NumPy, JAX, and array-api-strict. It includes
7
functions for importing backends, determining array namespaces, dispatching random
8
number generators, and providing missing functionality for array-api-strict through the
9
XPAdditions class.
10

11
Classes and functions in this module help ensure consistent behavior and compatibility
12
across different array libraries, and provide wrappers for common operations such as
13
integration, interpolation, and linear algebra.
14

15
"""
16

17
from __future__ import annotations
8✔
18

19
from typing import TYPE_CHECKING, Any
8✔
20

21
if TYPE_CHECKING:
22
    from collections.abc import Callable
23
    from types import ModuleType
24

25
    import numpy as np
26
    from numpy.typing import DTypeLike
27

28
    from array_api_strict._array_object import Array as AArray
29

30
    from glass._types import AnyArray, FloatArray, UnifiedGenerator
31

32

33
class CompatibleBackendNotFoundError(Exception):
    """
    Raised when a backend needed to implement a requested function is missing.

    GLASS falls back to another array library (for example NumPy) for some
    operations that the user's chosen backend does not provide; this error is
    raised when that fallback library cannot be found.
    """

    def __init__(self, missing_backend: str, users_backend: str) -> None:
        # Keep the rendered text on ``self.message`` as well as in ``args`` so
        # callers can read it either way.
        text = (
            f"{missing_backend} is required here as some functions required by GLASS "
            f"are not supported by {users_backend}"
        )
        self.message = text
        super().__init__(text)
45

46

47
def import_numpy(backend: str) -> ModuleType:
    """
    Import the NumPy module, raising a helpful error if NumPy is not installed.

    Parameters
    ----------
    backend
        The name of the backend requested by the user. It is only used to build
        the error message when NumPy is missing.

    Returns
    -------
        The NumPy module.

    Raises
    ------
    CompatibleBackendNotFoundError
        If NumPy is not found in the user's environment. The underlying
        ModuleNotFoundError is chained as the cause.

    Notes
    -----
    This is useful for explaining to the user why NumPy is required when their chosen
    backend does not implement a needed function.
    """
    try:
        import numpy  # noqa: ICN001, PLC0415

    except ModuleNotFoundError as err:
        # Re-raise with a message naming the user's backend, keeping the
        # original import failure as the cause.
        raise CompatibleBackendNotFoundError("numpy", backend) from err
    else:
        return numpy
77

78

79
def rng_dispatcher(*, xp: ModuleType) -> UnifiedGenerator:
    """
    Select a random number generator matching the given array namespace.

    Parameters
    ----------
    xp
        The array namespace (``numpy``, ``jax.numpy`` or ``array_api_strict``)
        whose name decides which generator is created.

    Returns
    -------
        A NumPy ``default_rng()`` generator for NumPy, a ``glass.jax.Generator``
        seeded with 42 for JAX, or this module's ``Generator`` seeded with 42 for
        array-api-strict.

    Raises
    ------
    NotImplementedError
        If the namespace name is not one of the supported backends.
    """
    backend_name = xp.__name__

    if backend_name == "numpy":
        return xp.random.default_rng()  # type: ignore[no-any-return]

    if backend_name == "jax.numpy":
        # Deferred import: JAX support is optional.
        import glass.jax  # noqa: PLC0415

        return glass.jax.Generator(seed=42)

    if backend_name == "array_api_strict":
        return Generator(seed=42)

    msg = "the array backend in not supported"
    raise NotImplementedError(msg)
110

111

112
class Generator:
    """
    Random number generator producing array_api_strict arrays.

    All sampling is delegated to a NumPy ``numpy.random.Generator``; each result
    is passed through ``array_api_strict.asarray`` before being returned.
    """

    __slots__ = ("axp", "nxp", "rng")

    def __init__(
        self,
        seed: int | bool | AArray | None = None,  # noqa: FBT001
    ) -> None:
        """
        Create the wrapped NumPy generator.

        Parameters
        ----------
        seed
            Seed forwarded to ``numpy.random.default_rng``.
        """
        import numpy as np  # noqa: PLC0415

        import array_api_strict  # noqa: PLC0415

        self.nxp = np
        self.axp = array_api_strict
        self.rng = self.nxp.random.default_rng(seed=seed)

    def random(
        self,
        size: int | tuple[int, ...] | None = None,
        dtype: DTypeLike | None = None,
        out: AArray | None = None,
    ) -> AArray:
        """
        Sample floats uniformly from the half-open interval [0.0, 1.0).

        Parameters
        ----------
        size
            Output shape.
        dtype
            Desired data type; float64 when omitted.
        out
            Optional output array, forwarded to NumPy.

        Returns
        -------
            Array of random floats.
        """
        if dtype is None:
            dtype = self.nxp.float64
        samples = self.rng.random(size, dtype, out)  # type: ignore[arg-type]
        return self.axp.asarray(samples)

    def normal(
        self,
        loc: float | FloatArray = 0.0,
        scale: float | FloatArray = 1.0,
        size: int | tuple[int, ...] | None = None,
    ) -> AArray:
        """
        Sample from a Normal distribution with mean ``loc`` and stdev ``scale``.

        Parameters
        ----------
        loc
            Mean of the distribution.
        scale
            Standard deviation of the distribution.
        size
            Output shape.

        Returns
        -------
            Array of normally distributed samples.
        """
        samples = self.rng.normal(loc, scale, size)
        return self.axp.asarray(samples)

    def poisson(
        self, lam: float | AArray, size: int | tuple[int, ...] | None = None
    ) -> AArray:
        """
        Sample from a Poisson distribution.

        Parameters
        ----------
        lam
            Expected number of events.
        size
            Output shape.

        Returns
        -------
            Array of Poisson-distributed samples.
        """
        samples = self.rng.poisson(lam, size)
        return self.axp.asarray(samples)

    def standard_normal(
        self,
        size: int | tuple[int, ...] | None = None,
        dtype: DTypeLike | None = None,
        out: AArray | None = None,
    ) -> AArray:
        """
        Sample from the standard Normal distribution (mean=0, stdev=1).

        Parameters
        ----------
        size
            Output shape.
        dtype
            Desired data type; float64 when omitted.
        out
            Optional output array, forwarded to NumPy.

        Returns
        -------
            Array of standard-normal samples.
        """
        if dtype is None:
            dtype = self.nxp.float64
        samples = self.rng.standard_normal(size, dtype, out)  # type: ignore[arg-type]
        return self.axp.asarray(samples)

    def uniform(
        self,
        low: float | AArray = 0.0,
        high: float | AArray = 1.0,
        size: int | tuple[int, ...] | None = None,
    ) -> AArray:
        """
        Sample from a Uniform distribution over [low, high).

        Parameters
        ----------
        low
            Lower bound of the distribution.
        high
            Upper bound of the distribution.
        size
            Output shape.

        Returns
        -------
            Array of uniformly distributed samples.
        """
        samples = self.rng.uniform(low, high, size)
        return self.axp.asarray(samples)
258

259

260
class XPAdditions:
8✔
261
    """
262
    Additional functions missing from both array-api-strict and array-api-extra.
263

264
    This class provides wrappers for common array operations such as integration,
265
    interpolation, and linear algebra, ensuring compatibility across NumPy, JAX,
266
    and array-api-strict backends.
267

268
    This is intended as a temporary solution. See https://github.com/glass-dev/glass/issues/645
269
    for details.
270
    """
271

272
    xp: ModuleType
8✔
273
    backend: str
8✔
274

275
    def __init__(self, xp: ModuleType) -> None:
8✔
276
        """
277
        Initialize XPAdditions with the given array namespace.
278

279
        Parameters
280
        ----------
281
        xp
282
            The array namespace module.
283
        """
284
        self.xp = xp
8✔
285

286
    def trapezoid(
8✔
287
        self, y: AnyArray, x: AnyArray = None, dx: float = 1.0, axis: int = -1
288
    ) -> AnyArray:
289
        """
290
        Integrate along the given axis using the composite trapezoidal rule.
291

292
        Parameters
293
        ----------
294
        y
295
            Input array to integrate.
296
        x
297
            Sample points corresponding to y.
298
        dx
299
            Spacing between sample points.
300
        axis
301
            Axis along which to integrate.
302

303
        Returns
304
        -------
305
            Integrated result.
306

307
        Raises
308
        ------
309
        NotImplementedError
310
            If the array backend is not supported.
311

312
        Notes
313
        -----
314
        See https://github.com/glass-dev/glass/issues/646
315
        """
316
        if self.xp.__name__ == "jax.numpy":
8✔
317
            import glass.jax  # noqa: PLC0415
8✔
318

319
            return glass.jax.trapezoid(y, x=x, dx=dx, axis=axis)
8✔
320

321
        if self.xp.__name__ == "numpy":
8✔
322
            return self.xp.trapezoid(y, x=x, dx=dx, axis=axis)
8✔
323

324
        if self.xp.__name__ == "array_api_strict":
8✔
325
            np = import_numpy(self.xp.__name__)
8✔
326

327
            # Using design principle of scipy (i.e. copy, use np, copy back)
328
            y_np = np.asarray(y, copy=True)
8✔
329
            x_np = np.asarray(x, copy=True)
8✔
330
            result_np = np.trapezoid(y_np, x_np, dx=dx, axis=axis)
8✔
331
            return self.xp.asarray(result_np, copy=True)
8✔
332

UNCOV
333
        msg = "the array backend in not supported"
×
UNCOV
334
        raise NotImplementedError(msg)
×
335

336
    def union1d(self, ar1: AnyArray, ar2: AnyArray) -> AnyArray:
8✔
337
        """
338
        Compute the set union of two 1D arrays.
339

340
        Parameters
341
        ----------
342
        ar1
343
            First input array.
344
        ar2
345
            Second input array.
346

347
        Returns
348
        -------
349
            The union of the two arrays.
350

351
        Raises
352
        ------
353
        NotImplementedError
354
            If the array backend is not supported.
355

356
        Notes
357
        -----
358
        See https://github.com/glass-dev/glass/issues/647
359
        """
360
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
8✔
361
            return self.xp.union1d(ar1, ar2)
8✔
362

363
        if self.xp.__name__ == "array_api_strict":
8✔
364
            np = import_numpy(self.xp.__name__)
8✔
365

366
            # Using design principle of scipy (i.e. copy, use np, copy back)
367
            ar1_np = np.asarray(ar1, copy=True)
8✔
368
            ar2_np = np.asarray(ar2, copy=True)
8✔
369
            result_np = np.union1d(ar1_np, ar2_np)
8✔
370
            return self.xp.asarray(result_np, copy=True)
8✔
371

UNCOV
372
        msg = "the array backend in not supported"
×
UNCOV
373
        raise NotImplementedError(msg)
×
374

375
    def interp(  # noqa: PLR0913
8✔
376
        self,
377
        x: AnyArray,
378
        x_points: AnyArray,
379
        y_points: AnyArray,
380
        left: float | None = None,
381
        right: float | None = None,
382
        period: float | None = None,
383
    ) -> AnyArray:
384
        """
385
        One-dimensional linear interpolation for monotonically increasing sample points.
386

387
        Parameters
388
        ----------
389
        x
390
            The x-coordinates at which to evaluate the interpolated values.
391
        x_points
392
            The x-coordinates of the data points.
393
        y_points
394
            The y-coordinates of the data points.
395
        left
396
            Value to return for x < x_points[0].
397
        right
398
            Value to return for x > x_points[-1].
399
        period
400
            Period for periodic interpolation.
401

402
        Returns
403
        -------
404
            Interpolated values.
405

406
        Raises
407
        ------
408
        NotImplementedError
409
            If the array backend is not supported.
410

411
        Notes
412
        -----
413
        See https://github.com/glass-dev/glass/issues/650
414
        """
415
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
8✔
416
            return self.xp.interp(
8✔
417
                x, x_points, y_points, left=left, right=right, period=period
418
            )
419

420
        if self.xp.__name__ == "array_api_strict":
8✔
421
            np = import_numpy(self.xp.__name__)
8✔
422

423
            # Using design principle of scipy (i.e. copy, use np, copy back)
424
            x_np = np.asarray(x, copy=True)
8✔
425
            x_points_np = np.asarray(x_points, copy=True)
8✔
426
            y_points_np = np.asarray(y_points, copy=True)
8✔
427
            result_np = np.interp(
8✔
428
                x_np, x_points_np, y_points_np, left=left, right=right, period=period
429
            )
430
            return self.xp.asarray(result_np, copy=True)
8✔
431

UNCOV
432
        msg = "the array backend in not supported"
×
UNCOV
433
        raise NotImplementedError(msg)
×
434

435
    def gradient(self, f: AnyArray) -> AnyArray:
8✔
436
        """
437
        Return the gradient of an N-dimensional array.
438

439
        Parameters
440
        ----------
441
        f
442
            Input array.
443

444
        Returns
445
        -------
446
            Gradient of the input array.
447

448
        Raises
449
        ------
450
        NotImplementedError
451
            If the array backend is not supported.
452

453
        Notes
454
        -----
455
        See https://github.com/glass-dev/glass/issues/648
456
        """
457
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
8✔
458
            return self.xp.gradient(f)
8✔
459

460
        if self.xp.__name__ == "array_api_strict":
8✔
461
            np = import_numpy(self.xp.__name__)
8✔
462

463
            # Using design principle of scipy (i.e. copy, use np, copy back)
464
            f_np = np.asarray(f, copy=True)
8✔
465
            result_np = np.gradient(f_np)
8✔
466
            return self.xp.asarray(result_np, copy=True)
8✔
467

UNCOV
468
        msg = "the array backend in not supported"
×
UNCOV
469
        raise NotImplementedError(msg)
×
470

471
    def linalg_lstsq(
8✔
472
        self, a: AnyArray, b: AnyArray, rcond: float | None = None
473
    ) -> tuple[AnyArray, AnyArray, int, AnyArray]:
474
        """
475
        Solve a linear least squares problem.
476

477
        Parameters
478
        ----------
479
        a
480
            Coefficient matrix.
481
        b
482
            Ordinate or "dependent variable" values.
483
        rcond
484
            Cut-off ratio for small singular values.
485

486
        Returns
487
        -------
488
        x
489
            Least-squares solution. If b is two-dimensional, the solutions are in the K
490
            columns of x.
491

492
        residuals
493
            Sums of squared residuals: Squared Euclidean 2-norm for each column in b - a
494
            @ x. If the rank of a is < N or M <= N, this is an empty array. If b is
495
            1-dimensional, this is a (1,) shape array. Otherwise the shape is (K,).
496

497
        rank
498
            Rank of matrix a.
499

500
        s
501
            Singular values of a.
502

503
        Raises
504
        ------
505
        NotImplementedError
506
            If the array backend is not supported.
507

508
        Notes
509
        -----
510
        See https://github.com/glass-dev/glass/issues/649
511
        """
512
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
8✔
513
            return self.xp.linalg.lstsq(a, b, rcond=rcond)  # type: ignore[no-any-return]
8✔
514

515
        if self.xp.__name__ == "array_api_strict":
8✔
516
            np = import_numpy(self.xp.__name__)
8✔
517

518
            # Using design principle of scipy (i.e. copy, use np, copy back)
519
            a_np = np.asarray(a, copy=True)
8✔
520
            b_np = np.asarray(b, copy=True)
8✔
521
            result_np = np.linalg.lstsq(a_np, b_np, rcond=rcond)
8✔
522
            return tuple(self.xp.asarray(res, copy=True) for res in result_np)
8✔
523

UNCOV
524
        msg = "the array backend in not supported"
×
UNCOV
525
        raise NotImplementedError(msg)
×
526

527
    def einsum(self, subscripts: str, *operands: AnyArray) -> AnyArray:
8✔
528
        """
529
        Evaluate the Einstein summation convention on the operands.
530

531
        Parameters
532
        ----------
533
        subscripts
534
            Specifies the subscripts for summation.
535
        *operands
536
            Arrays to be summed.
537

538
        Returns
539
        -------
540
            Result of the Einstein summation.
541

542
        Raises
543
        ------
544
        NotImplementedError
545
            If the array backend is not supported.
546

547
        Notes
548
        -----
549
        See https://github.com/glass-dev/glass/issues/657
550
        """
551
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
8✔
552
            return self.xp.einsum(subscripts, *operands)
8✔
553

554
        if self.xp.__name__ == "array_api_strict":
8✔
555
            np = import_numpy(self.xp.__name__)
8✔
556

557
            # Using design principle of scipy (i.e. copy, use np, copy back)
558
            operands_np = (np.asarray(op, copy=True) for op in operands)
8✔
559
            result_np = np.einsum(subscripts, *operands_np)
8✔
560
            return self.xp.asarray(result_np, copy=True)
8✔
561

UNCOV
562
        msg = "the array backend in not supported"
×
UNCOV
563
        raise NotImplementedError(msg)
×
564

565
    def apply_along_axis(
8✔
566
        self,
567
        func1d: Callable[..., Any],
568
        axis: int,
569
        arr: AnyArray,
570
        *args: object,
571
        **kwargs: object,
572
    ) -> AnyArray:
573
        """
574
        Apply a function to 1-D slices along the given axis.
575

576
        Parameters
577
        ----------
578
        func1d
579
            Function to apply to 1-D slices.
580
        axis
581
            Axis along which to apply the function.
582
        arr
583
            Input array.
584
        *args
585
            Additional positional arguments to pass to func1d.
586
        **kwargs
587
            Additional keyword arguments to pass to func1d.
588

589
        Returns
590
        -------
591
            Result of applying the function along the axis.
592

593
        Raises
594
        ------
595
        NotImplementedError
596
            If the array backend is not supported.
597

598
        Notes
599
        -----
600
        See https://github.com/glass-dev/glass/issues/651
601

602
        """
603
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
8✔
604
            return self.xp.apply_along_axis(func1d, axis, arr, *args, **kwargs)
8✔
605

606
        if self.xp.__name__ == "array_api_strict":
8✔
607
            # Import here to prevent users relying on numpy unless in this instance
608
            np = import_numpy(self.xp.__name__)
8✔
609

610
            return self.xp.asarray(
8✔
611
                np.apply_along_axis(func1d, axis, arr, *args, **kwargs), copy=True
612
            )
613

UNCOV
614
        msg = "the array backend in not supported"
×
UNCOV
615
        raise NotImplementedError(msg)
×
616

617
    def vectorize(
8✔
618
        self,
619
        pyfunc: Callable[..., Any],
620
        otypes: tuple[type[float]],
621
    ) -> Callable[..., Any]:
622
        """
623
        Returns an object that acts like pyfunc, but takes arrays as input.
624

625
        Parameters
626
        ----------
627
        pyfunc
628
            Python function to vectorize.
629
        otypes
630
            Output types.
631

632
        Returns
633
        -------
634
            Vectorized function.
635

636
        Raises
637
        ------
638
        NotImplementedError
639
            If the array backend is not supported.
640

641
        Notes
642
        -----
643
        See https://github.com/glass-dev/glass/issues/671
644
        """
645
        if self.xp.__name__ == "numpy":
8✔
646
            return self.xp.vectorize(pyfunc, otypes=otypes)  # type: ignore[no-any-return]
8✔
647

648
        if self.xp.__name__ in {"array_api_strict", "jax.numpy"}:
8✔
649
            # Import here to prevent users relying on numpy unless in this instance
650
            np = import_numpy(self.xp.__name__)
8✔
651

652
            return np.vectorize(pyfunc, otypes=otypes)  # type: ignore[no-any-return]
8✔
653

UNCOV
654
        msg = "the array backend in not supported"
×
UNCOV
655
        raise NotImplementedError(msg)
×
656

657
    def radians(self, deg_arr: AnyArray) -> AnyArray:
8✔
658
        """
659
        Convert angles from degrees to radians.
660

661
        Parameters
662
        ----------
663
        deg_arr
664
            Array of angles in degrees.
665

666
        Returns
667
        -------
668
            Array of angles in radians.
669

670
        Raises
671
        ------
672
        NotImplementedError
673
            If the array backend is not supported.
674
        """
675
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
8✔
676
            return self.xp.radians(deg_arr)
8✔
677

678
        if self.xp.__name__ == "array_api_strict":
8✔
679
            np = import_numpy(self.xp.__name__)
8✔
680

681
            return self.xp.asarray(np.radians(deg_arr))
8✔
682

UNCOV
683
        msg = "the array backend in not supported"
×
UNCOV
684
        raise NotImplementedError(msg)
×
685

686
    def degrees(self, deg_arr: AnyArray) -> AnyArray:
8✔
687
        """
688
        Convert angles from radians to degrees.
689

690
        Parameters
691
        ----------
692
        deg_arr
693
            Array of angles in radians.
694

695
        Returns
696
        -------
697
            Array of angles in degrees.
698

699
        Raises
700
        ------
701
        NotImplementedError
702
            If the array backend is not supported.
703
        """
704
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
8✔
705
            return self.xp.degrees(deg_arr)
8✔
706

707
        if self.xp.__name__ == "array_api_strict":
8✔
708
            np = import_numpy(self.xp.__name__)
8✔
709

710
            return self.xp.asarray(np.degrees(deg_arr))
8✔
711

UNCOV
712
        msg = "the array backend in not supported"
×
UNCOV
713
        raise NotImplementedError(msg)
×
714

715
    def ndindex(self, shape: tuple[int, ...]) -> np.ndindex:
8✔
716
        """
717
        Wrapper for numpy.ndindex.
718

719
        See relevant docs for details:
720
        - NumPy, https://numpy.org/doc/2.2/reference/generated/numpy.ndindex.html
721

722
        Raises
723
        ------
724
        NotImplementedError
725
            If the array backend is not supported.
726

727
        """
728
        if self.xp.__name__ == "numpy":
8✔
729
            return self.xp.ndindex(shape)  # type: ignore[no-any-return]
8✔
730

731
        if self.xp.__name__ in {"array_api_strict", "jax.numpy"}:
8✔
732
            np = import_numpy(self.xp.__name__)
8✔
733

734
            return np.ndindex(shape)  # type: ignore[no-any-return]
8✔
735

UNCOV
736
        msg = "the array backend in not supported"
×
UNCOV
737
        raise NotImplementedError(msg)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc