• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

glass-dev / glass / 18599168943

17 Oct 2025 04:45PM UTC coverage: 94.086% (-0.4%) from 94.467%
18599168943

Pull #680

github

web-flow
Merge 114088965 into c5b0599c8
Pull Request #680: gh-679: only import numpy when it is absolutely necessary and add doc…

194 of 196 branches covered (98.98%)

Branch coverage included in aggregate %.

42 of 45 new or added lines in 1 file covered. (93.33%)

5 existing lines in 1 file now uncovered.

1397 of 1495 relevant lines covered (93.44%)

7.46 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

87.25
/glass/_array_api_utils.py
1
"""
2
Array API Utilities for glass.
3
==============================
4

5
This module provides utility functions and classes for working with multiple array
6
backends in the glass project, including NumPy, JAX, and array-api-strict. It includes
7
functions for importing backends, determining array namespaces, dispatching random
8
number generators, and providing missing functionality for array-api-strict through the
9
XPAdditions class.
10

11
Classes and functions in this module help ensure consistent behavior and compatibility
12
across different array libraries, and provide wrappers for common operations such as
13
integration, interpolation, and linear algebra.
14

15
"""
16

17
from __future__ import annotations
8✔
18

19
from typing import TYPE_CHECKING, Any, TypeAlias
8✔
20

21
if TYPE_CHECKING:
22
    from collections.abc import Callable
23
    from types import ModuleType
24

25
    import numpy as np
26
    from array_api_strict._array_object import Array as AArray
27
    from jaxtyping import Array as JAXArray
28
    from numpy.typing import DTypeLike, NDArray
29

30
    import glass.jax
31

32
    Size: TypeAlias = int | tuple[int, ...] | None
33

34
    AnyArray: TypeAlias = NDArray[Any] | JAXArray | AArray
35
    ComplexArray: TypeAlias = NDArray[np.complex128] | JAXArray | AArray
36
    DoubleArray: TypeAlias = NDArray[np.double] | JAXArray | AArray
37
    FloatArray: TypeAlias = NDArray[np.float64] | JAXArray | AArray
38

39

40
def import_numpy(backend: str, function_name: str) -> ModuleType:
    """
    Import and return NumPy, explaining to the user why it is needed if it is missing.

    Parameters
    ----------
    backend : str
        Name of the array backend the user selected.
    function_name : str
        Name of the function the selected backend does not provide, which is why
        NumPy has to be used instead.

    Returns
    -------
    ModuleType
        The imported ``numpy`` module.

    Raises
    ------
    ModuleNotFoundError
        If NumPy cannot be imported; the message names the backend and the missing
        function, and the original import error is chained as the cause.

    Notes
    -----
    The import is deferred to call time so that NumPy is only required when a
    fallback to NumPy is actually taken.
    """
    try:
        import numpy  # noqa: ICN001, PLC0415
    except ModuleNotFoundError as err:
        reason = (
            "numpy is required here as "
            + backend
            + " does not implement "
            + function_name
        )
        raise ModuleNotFoundError(reason) from err
    return numpy
79

80

81
def get_namespace(*arrays: AnyArray) -> ModuleType:
    """
    Return the array library (namespace) of input arrays if they all belong to the same
    library.

    ``None`` entries are ignored, so optional array arguments can be passed straight
    through without filtering them first.

    Parameters
    ----------
    *arrays : AnyArray
        Arrays whose namespace is to be determined. ``None`` values are skipped.

    Returns
    -------
    ModuleType
        The array namespace module, as reported by ``__array_namespace__()``.

    Raises
    ------
    ValueError
        If no non-``None`` arrays are given, or if the input arrays do not all belong
        to the same array library.
    """
    # Filter out None up front so that a leading None (e.g. an omitted optional
    # argument) does not get asked for its namespace.
    present = [array for array in arrays if array is not None]
    if not present:
        msg = "at least one array is required to determine the array library"
        raise ValueError(msg)

    namespace = present[0].__array_namespace__()
    if any(array.__array_namespace__() != namespace for array in present[1:]):
        msg = "input arrays should belong to the same array library"
        raise ValueError(msg)

    return namespace
111

112

113
def rng_dispatcher(
8✔
114
    array: AnyArray,
115
) -> np.random.Generator | glass.jax.Generator | Generator:
116
    """
117
    Dispatch a random number generator based on the provided array's backend.
118

119
    Parameters
120
    ----------
121
    array : AnyArray
122
        The array whose backend determines the RNG.
123

124
    Returns
125
    -------
126
    np.random.Generator | glass.jax.Generator | Generator
127
        The appropriate random number generator for the array's backend.
128

129
    Raises
130
    ------
131
    NotImplementedError
132
        If the array backend is not supported.
133
    """
134
    xp = get_namespace(array)
8✔
135
    backend = xp.__name__
8✔
136

137
    if backend == "jax.numpy":
8✔
138
        import glass.jax  # noqa: PLC0415
8✔
139

140
        return glass.jax.Generator(seed=42)
8✔
141

142
    if backend == "numpy":
8✔
143
        return xp.random.default_rng()  # type: ignore[no-any-return]
8✔
144

145
    if backend == "array_api_strict":
8✔
146
        return Generator(seed=42)
8✔
147

148
    msg = "the array backend in not supported"
×
149
    raise NotImplementedError(msg)
×
150

151

152
class Generator:
    """
    NumPy random number generator returning array_api_strict Array.

    Each sampling method draws from a wrapped ``numpy.random.Generator`` and passes
    the NumPy result through ``array_api_strict.asarray`` before returning it.
    """

    __slots__ = ("axp", "nxp", "rng")

    def __init__(
        self,
        seed: int | bool | AArray | None = None,  # noqa: FBT001
    ) -> None:
        """
        Initialize the Generator.

        Parameters
        ----------
        seed : int | bool | AArray | None, optional
            Seed forwarded unchanged to ``numpy.random.default_rng``.
        """
        # Imported lazily so neither package is needed until a Generator is built.
        import array_api_strict  # noqa: PLC0415
        import numpy as np  # noqa: PLC0415

        self.axp = array_api_strict
        self.nxp = np
        self.rng = np.random.default_rng(seed=seed)

    def random(
        self,
        size: Size = None,
        dtype: DTypeLike | None = None,
        out: AArray | None = None,
    ) -> AArray:
        """
        Return random floats in the half-open interval [0.0, 1.0).

        Parameters
        ----------
        size : Size, optional
            Output shape.
        dtype : DTypeLike | None, optional
            Desired data type; ``None`` selects ``numpy.float64``.
        out : AArray | None, optional
            Forwarded unchanged as the ``out`` argument of the NumPy generator.

        Returns
        -------
        AArray
            Array of random floats.
        """
        if dtype is None:
            dtype = self.nxp.float64
        samples = self.rng.random(size=size, dtype=dtype, out=out)  # type: ignore[arg-type]
        return self.axp.asarray(samples)

    def normal(
        self,
        loc: float | AArray = 0.0,
        scale: float | AArray = 1.0,
        size: Size = None,
    ) -> AArray:
        """
        Draw samples from a Normal distribution (mean=loc, stdev=scale).

        Parameters
        ----------
        loc : float | AArray, optional
            Mean of the distribution.
        scale : float | AArray, optional
            Standard deviation of the distribution.
        size : Size, optional
            Output shape.

        Returns
        -------
        AArray
            Array of samples from the normal distribution.
        """
        samples = self.rng.normal(loc=loc, scale=scale, size=size)
        return self.axp.asarray(samples)

    def poisson(self, lam: float | AArray, size: Size = None) -> AArray:
        """
        Draw samples from a Poisson distribution.

        Parameters
        ----------
        lam : float | AArray
            Expected number of events.
        size : Size, optional
            Output shape.

        Returns
        -------
        AArray
            Array of samples from the Poisson distribution.
        """
        samples = self.rng.poisson(lam=lam, size=size)
        return self.axp.asarray(samples)

    def standard_normal(
        self,
        size: Size = None,
        dtype: DTypeLike | None = None,
        out: AArray | None = None,
    ) -> AArray:
        """
        Draw samples from a standard Normal distribution (mean=0, stdev=1).

        Parameters
        ----------
        size : Size, optional
            Output shape.
        dtype : DTypeLike | None, optional
            Desired data type; ``None`` selects ``numpy.float64``.
        out : AArray | None, optional
            Forwarded unchanged as the ``out`` argument of the NumPy generator.

        Returns
        -------
        AArray
            Array of samples from the standard normal distribution.
        """
        if dtype is None:
            dtype = self.nxp.float64
        samples = self.rng.standard_normal(size=size, dtype=dtype, out=out)  # type: ignore[arg-type]
        return self.axp.asarray(samples)

    def uniform(
        self,
        low: float | AArray = 0.0,
        high: float | AArray = 1.0,
        size: Size = None,
    ) -> AArray:
        """
        Draw samples from a Uniform distribution over [low, high).

        Parameters
        ----------
        low : float | AArray, optional
            Lower bound of the distribution.
        high : float | AArray, optional
            Upper bound of the distribution.
        size : Size, optional
            Output shape.

        Returns
        -------
        AArray
            Array of samples from the uniform distribution.
        """
        samples = self.rng.uniform(low=low, high=high, size=size)
        return self.axp.asarray(samples)
300

301

302
class XPAdditions:
    """
    Additional functions missing from both array-api-strict and array-api-extra.

    This class provides wrappers for common array operations such as integration,
    interpolation, and linear algebra, ensuring compatibility across NumPy, JAX,
    and array-api-strict backends.

    For the NumPy and JAX backends each call is forwarded to the namespace (or to
    ``glass.jax``). For array-api-strict the inputs are copied into NumPy arrays, the
    NumPy function is applied, and the result is copied back with ``xp.asarray``.

    This is intended as a temporary solution. See https://github.com/glass-dev/glass/issues/645
    for details.
    """

    xp: ModuleType
    backend: str

    def __init__(self, xp: ModuleType) -> None:
        """
        Initialize XPAdditions with the given array namespace.

        Parameters
        ----------
        xp : ModuleType
            The array namespace module. Its ``__name__`` selects the backend.
        """
        self.xp = xp
        self.backend = xp.__name__

    def trapezoid(
        self,
        y: AnyArray,
        x: AnyArray | None = None,
        dx: float = 1.0,
        axis: int = -1,
    ) -> AnyArray:
        """
        Integrate along the given axis using the composite trapezoidal rule.

        Parameters
        ----------
        y : AnyArray
            Input array to integrate.
        x : AnyArray | None, optional
            Sample points corresponding to y. If None, the samples are assumed to
            be evenly spaced ``dx`` apart.
        dx : float, optional
            Spacing between sample points when ``x`` is None.
        axis : int, optional
            Axis along which to integrate.

        Returns
        -------
        AnyArray
            Integrated result.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/646
        """
        if self.backend == "jax.numpy":
            import glass.jax  # noqa: PLC0415

            return glass.jax.trapezoid(y, x=x, dx=dx, axis=axis)

        if self.backend == "numpy":
            return self.xp.trapezoid(y, x=x, dx=dx, axis=axis)

        if self.backend == "array_api_strict":
            np = import_numpy(self.backend, "trapezoid")

            # Using design principle of scipy (i.e. copy, use np, copy back)
            y_np = np.asarray(y, copy=True)
            # np.asarray(None) is a 0-d object array, which np.trapezoid would
            # treat as sample points and fail on; keep None so that dx is used.
            x_np = None if x is None else np.asarray(x, copy=True)
            result_np = np.trapezoid(y_np, x_np, dx=dx, axis=axis)
            return self.xp.asarray(result_np, copy=True)

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def union1d(self, ar1: AnyArray, ar2: AnyArray) -> AnyArray:
        """
        Compute the set union of two 1D arrays.

        Parameters
        ----------
        ar1 : AnyArray
            First input array.
        ar2 : AnyArray
            Second input array.

        Returns
        -------
        AnyArray
            The union of the two arrays.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/647
        """
        if self.backend in {"numpy", "jax.numpy"}:
            return self.xp.union1d(ar1, ar2)

        if self.backend == "array_api_strict":
            np = import_numpy(self.backend, "union1d")

            # Using design principle of scipy (i.e. copy, use np, copy back)
            ar1_np = np.asarray(ar1, copy=True)
            ar2_np = np.asarray(ar2, copy=True)
            result_np = np.union1d(ar1_np, ar2_np)
            return self.xp.asarray(result_np, copy=True)

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def interp(  # noqa: PLR0913
        self,
        x: AnyArray,
        x_points: AnyArray,
        y_points: AnyArray,
        left: float | None = None,
        right: float | None = None,
        period: float | None = None,
    ) -> AnyArray:
        """
        One-dimensional linear interpolation for monotonically increasing sample points.

        Parameters
        ----------
        x : AnyArray
            The x-coordinates at which to evaluate the interpolated values.
        x_points : AnyArray
            The x-coordinates of the data points.
        y_points : AnyArray
            The y-coordinates of the data points.
        left : float | None, optional
            Value to return for x < x_points[0].
        right : float | None, optional
            Value to return for x > x_points[-1].
        period : float | None, optional
            Period for periodic interpolation.

        Returns
        -------
        AnyArray
            Interpolated values.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/650
        """
        if self.backend in {"numpy", "jax.numpy"}:
            return self.xp.interp(
                x, x_points, y_points, left=left, right=right, period=period
            )

        if self.backend == "array_api_strict":
            np = import_numpy(self.backend, "interp")

            # Using design principle of scipy (i.e. copy, use np, copy back)
            x_np = np.asarray(x, copy=True)
            x_points_np = np.asarray(x_points, copy=True)
            y_points_np = np.asarray(y_points, copy=True)
            result_np = np.interp(
                x_np, x_points_np, y_points_np, left=left, right=right, period=period
            )
            return self.xp.asarray(result_np, copy=True)

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def gradient(self, f: AnyArray) -> AnyArray:
        """
        Return the gradient of an N-dimensional array.

        Parameters
        ----------
        f : AnyArray
            Input array.

        Returns
        -------
        AnyArray
            Gradient of the input array.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/648
        """
        if self.backend in {"numpy", "jax.numpy"}:
            return self.xp.gradient(f)

        if self.backend == "array_api_strict":
            np = import_numpy(self.backend, "gradient")

            # Using design principle of scipy (i.e. copy, use np, copy back)
            f_np = np.asarray(f, copy=True)
            result_np = np.gradient(f_np)
            return self.xp.asarray(result_np, copy=True)

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def linalg_lstsq(
        self, a: AnyArray, b: AnyArray, rcond: float | None = None
    ) -> tuple[AnyArray, AnyArray, int, AnyArray]:
        """
        Solve a linear least squares problem.

        Parameters
        ----------
        a : AnyArray
            Coefficient matrix.
        b : AnyArray
            Ordinate or "dependent variable" values.
        rcond : float | None, optional
            Cut-off ratio for small singular values.

        Returns
        -------
        x : {(N,), (N, K)} AnyArray
            Least-squares solution. If b is two-dimensional, the solutions are in the K
            columns of x.

        residuals : {(1,), (K,), (0,)} AnyArray
            Sums of squared residuals: Squared Euclidean 2-norm for each column in b - a
            @ x. If the rank of a is < N or M <= N, this is an empty array. If b is
            1-dimensional, this is a (1,) shape array. Otherwise the shape is (K,).

        rank : int
            Rank of matrix a.

        s : (min(M, N),) AnyArray
            Singular values of a.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/649
        """
        if self.backend in {"numpy", "jax.numpy"}:
            return self.xp.linalg.lstsq(a, b, rcond=rcond)  # type: ignore[no-any-return]

        if self.backend == "array_api_strict":
            np = import_numpy(self.backend, "linalg.lstsq")

            # Using design principle of scipy (i.e. copy, use np, copy back)
            a_np = np.asarray(a, copy=True)
            b_np = np.asarray(b, copy=True)
            x_np, residuals_np, rank, s_np = np.linalg.lstsq(a_np, b_np, rcond=rcond)
            # The rank is returned as a plain int, matching the return annotation,
            # instead of being wrapped into a 0-d array like the other outputs.
            return (
                self.xp.asarray(x_np, copy=True),
                self.xp.asarray(residuals_np, copy=True),
                int(rank),
                self.xp.asarray(s_np, copy=True),
            )

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def einsum(self, subscripts: str, *operands: AnyArray) -> AnyArray:
        """
        Evaluate the Einstein summation convention on the operands.

        Parameters
        ----------
        subscripts : str
            Specifies the subscripts for summation.
        *operands : AnyArray
            Arrays to be summed.

        Returns
        -------
        AnyArray
            Result of the Einstein summation.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/657
        """
        if self.backend in {"numpy", "jax.numpy"}:
            return self.xp.einsum(subscripts, *operands)

        if self.backend == "array_api_strict":
            np = import_numpy(self.backend, "einsum")

            # Using design principle of scipy (i.e. copy, use np, copy back)
            operands_np = [np.asarray(op, copy=True) for op in operands]
            result_np = np.einsum(subscripts, *operands_np)
            return self.xp.asarray(result_np, copy=True)

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def apply_along_axis(
        self,
        func1d: Callable[..., Any],
        axis: int,
        arr: AnyArray,
        *args: object,
        **kwargs: object,
    ) -> AnyArray:
        """
        Apply a function to 1-D slices along the given axis.

        Parameters
        ----------
        func1d : Callable[..., Any]
            Function to apply to 1-D slices. For the array-api-strict backend it
            receives NumPy slices.
        axis : int
            Axis along which to apply the function.
        arr : AnyArray
            Input array.
        *args : object
            Additional positional arguments to pass to func1d.
        **kwargs : object
            Additional keyword arguments to pass to func1d.

        Returns
        -------
        AnyArray
            Result of applying the function along the axis.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/651
        """
        if self.backend in {"numpy", "jax.numpy"}:
            return self.xp.apply_along_axis(func1d, axis, arr, *args, **kwargs)

        if self.backend == "array_api_strict":
            # Import here to prevent users relying on numpy unless in this instance
            np = import_numpy(self.backend, "apply_along_axis")

            # Using design principle of scipy (i.e. copy, use np, copy back), like
            # the other methods, so the caller's array is never handed to NumPy.
            arr_np = np.asarray(arr, copy=True)
            result_np = np.apply_along_axis(func1d, axis, arr_np, *args, **kwargs)
            return self.xp.asarray(result_np, copy=True)

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc