• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

glass-dev / glass / 18653204061

20 Oct 2025 01:15PM UTC coverage: 94.342%. Remained the same
18653204061

Pull #684

github

web-flow
Merge e5ba5567f into 7fd0c3a42
Pull Request #684: gh-682: create an `isort` section for Array API

194 of 196 branches covered (98.98%)

Branch coverage included in aggregate %.

3 of 3 new or added lines in 3 files covered. (100.0%)

6 existing lines in 1 file now uncovered.

1390 of 1483 relevant lines covered (93.73%)

7.48 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

86.58
/glass/_array_api_utils.py
1
"""
2
Array API Utilities for glass.
3
============================
4

5
This module provides utility functions and classes for working with multiple array
6
backends in the glass project, including NumPy, JAX, and array-api-strict. It includes
7
functions for importing backends, determining array namespaces, dispatching random
8
number generators, and providing missing functionality for array-api-strict through the
9
XPAdditions class.
10

11
Classes and functions in this module help ensure consistent behavior and compatibility
12
across different array libraries, and provide wrappers for common operations such as
13
integration, interpolation, and linear algebra.
14

15
"""
16

17
from __future__ import annotations
8✔
18

19
from typing import TYPE_CHECKING, Any, TypeAlias
8✔
20

21
import array_api_strict
8✔
22

23
if TYPE_CHECKING:
24
    from collections.abc import Callable
25
    from types import ModuleType
26

27
    import numpy as np
28
    from jaxtyping import Array as JAXArray
29
    from numpy.typing import DTypeLike, NDArray
30

31
    from array_api_strict._array_object import Array as AArray
32

33
    import glass.jax
34

35
    Size: TypeAlias = int | tuple[int, ...] | None
36

37
    AnyArray: TypeAlias = NDArray[Any] | JAXArray | AArray
38
    ComplexArray: TypeAlias = NDArray[np.complex128] | JAXArray | AArray
39
    DoubleArray: TypeAlias = NDArray[np.double] | JAXArray | AArray
40
    FloatArray: TypeAlias = NDArray[np.float64] | JAXArray | AArray
41

42

43
class CompatibleBackendNotFoundError(Exception):
    """
    Error signalling that a required array library backend is unavailable.

    It is raised when the backend chosen by the user does not implement a
    function that is needed, and the library that would provide that function
    cannot be found.
    """

    def __init__(self, missing_backend: str, users_backend: str) -> None:
        """
        Compose the explanatory message and initialise the exception with it.

        Parameters
        ----------
        missing_backend
            Name of the array library that is required but could not be found.
        users_backend
            Name of the backend that the user asked for.
        """
        explanation = (
            f"{missing_backend} is required here as some functions required by GLASS "
            f"are not supported by {users_backend}"
        )
        super().__init__(explanation)
        # Kept as an attribute as well, so the text can be read without str().
        self.message = explanation
55

56

57
def import_numpy(backend: str) -> ModuleType:
    """
    Return the NumPy module, or explain why it is needed when it is missing.

    Parameters
    ----------
    backend
        The name of the backend requested by the user. It is only used to build
        the error message.

    Returns
    -------
        The NumPy module.

    Raises
    ------
    CompatibleBackendNotFoundError
        If importing NumPy fails with ``ModuleNotFoundError``. The original
        import error is chained as the cause.

    Notes
    -----
    The import is done lazily inside the function so that NumPy is only needed
    by callers that actually fall back to it.
    """
    try:
        import numpy as np  # noqa: PLC0415
    except ModuleNotFoundError as err:
        raise CompatibleBackendNotFoundError("numpy", backend) from err

    return np
87

88

89
def get_namespace(*arrays: AnyArray | None) -> ModuleType:
    """
    Return the array library (namespace) of input arrays if they all belong to the same
    library.

    Entries that are ``None`` are ignored, wherever they appear in the argument
    list; the namespace is taken from the arrays that remain.

    Parameters
    ----------
    *arrays
        Arrays whose namespace is to be determined. ``None`` entries are skipped.

    Returns
    -------
        The array namespace module.

    Raises
    ------
    ValueError
        If no array is given (or every entry is ``None``), or if the input arrays
        do not all belong to the same array library.
    """
    # Drop the None placeholders first so that a leading None cannot be used as
    # the reference array.
    namespaces = [
        array.__array_namespace__() for array in arrays if array is not None
    ]
    if not namespaces:
        msg = "at least one array is required to determine the array library"
        raise ValueError(msg)

    namespace = namespaces[0]
    if any(other != namespace for other in namespaces[1:]):
        msg = "input arrays should belong to the same array library"
        raise ValueError(msg)

    return namespace
118

119

120
def rng_dispatcher(
    array: AnyArray,
) -> np.random.Generator | glass.jax.Generator | Generator:
    """
    Dispatch a random number generator based on the provided array's backend.

    Parameters
    ----------
    array
        The array whose backend determines the RNG.

    Returns
    -------
        The appropriate random number generator for the array's backend.

    Raises
    ------
    NotImplementedError
        If the array backend is not supported.

    Notes
    -----
    The JAX and array-api-strict generators are created with the fixed seed 42,
    whereas the NumPy generator is created without a seed.
    """
    xp = get_namespace(array)
    backend = xp.__name__

    if backend == "jax.numpy":
        # Imported lazily so that JAX is only needed when a JAX array is used.
        import glass.jax  # noqa: PLC0415

        return glass.jax.Generator(seed=42)

    if backend == "numpy":
        return xp.random.default_rng()  # type: ignore[no-any-return]

    if backend == "array_api_strict":
        return Generator(seed=42)

    msg = "the array backend is not supported"
    raise NotImplementedError(msg)
155

156

157
class Generator:
    """
    Random number generator producing array_api_strict arrays.

    Sampling is delegated to a NumPy random number generator; every result is
    converted with ``array_api_strict.asarray`` before it is returned.
    """

    __slots__ = ("axp", "nxp", "rng")

    def __init__(
        self,
        seed: int | bool | AArray | None = None,  # noqa: FBT001
    ) -> None:
        """
        Create the underlying NumPy generator.

        Parameters
        ----------
        seed
            Seed forwarded to ``numpy.random.default_rng``.
        """
        import numpy as np  # noqa: PLC0415

        self.nxp = np
        self.axp = array_api_strict
        self.rng = np.random.default_rng(seed=seed)

    def random(
        self,
        size: Size = None,
        dtype: DTypeLike | None = None,
        out: AArray | None = None,
    ) -> AArray:
        """
        Sample floats from the half-open interval [0.0, 1.0).

        Parameters
        ----------
        size
            Shape of the output.
        dtype
            Data type of the samples; ``float64`` is used when omitted.
        out
            Optional output array, forwarded to the NumPy generator.

        Returns
        -------
            The random floats.
        """
        if dtype is None:
            dtype = self.nxp.float64
        samples = self.rng.random(size, dtype, out)  # type: ignore[arg-type]
        return self.axp.asarray(samples)

    def normal(
        self,
        loc: float | AArray = 0.0,
        scale: float | AArray = 1.0,
        size: Size = None,
    ) -> AArray:
        """
        Sample from a Normal distribution with mean ``loc`` and stdev ``scale``.

        Parameters
        ----------
        loc
            Mean of the distribution.
        scale
            Standard deviation of the distribution.
        size
            Shape of the output.

        Returns
        -------
            The samples drawn from the normal distribution.
        """
        samples = self.rng.normal(loc, scale, size)
        return self.axp.asarray(samples)

    def poisson(self, lam: float | AArray, size: Size = None) -> AArray:
        """
        Sample from a Poisson distribution.

        Parameters
        ----------
        lam
            Expected number of events.
        size
            Shape of the output.

        Returns
        -------
            The samples drawn from the Poisson distribution.
        """
        samples = self.rng.poisson(lam, size)
        return self.axp.asarray(samples)

    def standard_normal(
        self,
        size: Size = None,
        dtype: DTypeLike | None = None,
        out: AArray | None = None,
    ) -> AArray:
        """
        Sample from the standard Normal distribution (mean=0, stdev=1).

        Parameters
        ----------
        size
            Shape of the output.
        dtype
            Data type of the samples; ``float64`` is used when omitted.
        out
            Optional output array, forwarded to the NumPy generator.

        Returns
        -------
            The samples drawn from the standard normal distribution.
        """
        if dtype is None:
            dtype = self.nxp.float64
        samples = self.rng.standard_normal(size, dtype, out)  # type: ignore[arg-type]
        return self.axp.asarray(samples)

    def uniform(
        self,
        low: float | AArray = 0.0,
        high: float | AArray = 1.0,
        size: Size = None,
    ) -> AArray:
        """
        Sample from a Uniform distribution.

        Parameters
        ----------
        low
            Lower bound of the distribution.
        high
            Upper bound of the distribution.
        size
            Shape of the output.

        Returns
        -------
            The samples drawn from the uniform distribution.
        """
        samples = self.rng.uniform(low, high, size)
        return self.axp.asarray(samples)
299

300

301
class XPAdditions:
    """
    Additional functions missing from both array-api-strict and array-api-extra.

    This class provides wrappers for common array operations such as integration,
    interpolation, and linear algebra, ensuring compatibility across NumPy, JAX,
    and array-api-strict backends.

    Each method dispatches on ``xp.__name__``. NumPy and JAX are served by their
    own implementations; for array-api-strict the inputs are copied to NumPy, the
    NumPy function is applied, and the result is copied back.

    This is intended as a temporary solution. See https://github.com/glass-dev/glass/issues/645
    for details.
    """

    xp: ModuleType
    # NOTE(review): declared here but never assigned anywhere in this class.
    backend: str

    def __init__(self, xp: ModuleType) -> None:
        """
        Initialize XPAdditions with the given array namespace.

        Parameters
        ----------
        xp
            The array namespace module.
        """
        self.xp = xp

    def trapezoid(
        self, y: AnyArray, x: AnyArray | None = None, dx: float = 1.0, axis: int = -1
    ) -> AnyArray:
        """
        Integrate along the given axis using the composite trapezoidal rule.

        Parameters
        ----------
        y
            Input array to integrate.
        x
            Sample points corresponding to y. If None, the sample points are
            assumed to be evenly spaced ``dx`` apart.
        dx
            Spacing between sample points.
        axis
            Axis along which to integrate.

        Returns
        -------
            Integrated result.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/646
        """
        if self.xp.__name__ == "jax.numpy":
            import glass.jax  # noqa: PLC0415

            return glass.jax.trapezoid(y, x=x, dx=dx, axis=axis)

        if self.xp.__name__ == "numpy":
            return self.xp.trapezoid(y, x=x, dx=dx, axis=axis)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            y_np = np.asarray(y, copy=True)
            # Keep None as None: np.asarray(None) would give a 0-d object array
            # instead of selecting the evenly spaced ``dx`` behaviour.
            x_np = None if x is None else np.asarray(x, copy=True)
            result_np = np.trapezoid(y_np, x_np, dx=dx, axis=axis)
            return self.xp.asarray(result_np, copy=True)

        msg = "the array backend is not supported"
        raise NotImplementedError(msg)

    def union1d(self, ar1: AnyArray, ar2: AnyArray) -> AnyArray:
        """
        Compute the set union of two 1D arrays.

        Parameters
        ----------
        ar1
            First input array.
        ar2
            Second input array.

        Returns
        -------
            The union of the two arrays.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/647
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.union1d(ar1, ar2)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            ar1_np = np.asarray(ar1, copy=True)
            ar2_np = np.asarray(ar2, copy=True)
            result_np = np.union1d(ar1_np, ar2_np)
            return self.xp.asarray(result_np, copy=True)

        msg = "the array backend is not supported"
        raise NotImplementedError(msg)

    def interp(  # noqa: PLR0913
        self,
        x: AnyArray,
        x_points: AnyArray,
        y_points: AnyArray,
        left: float | None = None,
        right: float | None = None,
        period: float | None = None,
    ) -> AnyArray:
        """
        One-dimensional linear interpolation for monotonically increasing sample points.

        Parameters
        ----------
        x
            The x-coordinates at which to evaluate the interpolated values.
        x_points
            The x-coordinates of the data points.
        y_points
            The y-coordinates of the data points.
        left
            Value to return for x < x_points[0].
        right
            Value to return for x > x_points[-1].
        period
            Period for periodic interpolation.

        Returns
        -------
            Interpolated values.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/650
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.interp(
                x, x_points, y_points, left=left, right=right, period=period
            )

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            x_np = np.asarray(x, copy=True)
            x_points_np = np.asarray(x_points, copy=True)
            y_points_np = np.asarray(y_points, copy=True)
            result_np = np.interp(
                x_np, x_points_np, y_points_np, left=left, right=right, period=period
            )
            return self.xp.asarray(result_np, copy=True)

        msg = "the array backend is not supported"
        raise NotImplementedError(msg)

    def gradient(self, f: AnyArray) -> AnyArray:
        """
        Return the gradient of an N-dimensional array.

        Parameters
        ----------
        f
            Input array.

        Returns
        -------
            Gradient of the input array.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/648
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.gradient(f)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            f_np = np.asarray(f, copy=True)
            result_np = np.gradient(f_np)
            return self.xp.asarray(result_np, copy=True)

        msg = "the array backend is not supported"
        raise NotImplementedError(msg)

    def linalg_lstsq(
        self, a: AnyArray, b: AnyArray, rcond: float | None = None
    ) -> tuple[AnyArray, AnyArray, int, AnyArray]:
        """
        Solve a linear least squares problem.

        Parameters
        ----------
        a
            Coefficient matrix.
        b
            Ordinate or "dependent variable" values.
        rcond
            Cut-off ratio for small singular values.

        Returns
        -------
        x
            Least-squares solution. If b is two-dimensional, the solutions are in the K
            columns of x.

        residuals
            Sums of squared residuals: Squared Euclidean 2-norm for each column in b - a
            @ x. If the rank of a is < N or M <= N, this is an empty array. If b is
            1-dimensional, this is a (1,) shape array. Otherwise the shape is (K,).

        rank
            Rank of matrix a.

        s
            Singular values of a.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/649
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.linalg.lstsq(a, b, rcond=rcond)  # type: ignore[no-any-return]

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            a_np = np.asarray(a, copy=True)
            b_np = np.asarray(b, copy=True)
            result_np = np.linalg.lstsq(a_np, b_np, rcond=rcond)
            # Every element of the NumPy result, the rank included, is passed
            # through xp.asarray here.
            return tuple(self.xp.asarray(res, copy=True) for res in result_np)

        msg = "the array backend is not supported"
        raise NotImplementedError(msg)

    def einsum(self, subscripts: str, *operands: AnyArray) -> AnyArray:
        """
        Evaluate the Einstein summation convention on the operands.

        Parameters
        ----------
        subscripts
            Specifies the subscripts for summation.
        *operands
            Arrays to be summed.

        Returns
        -------
            Result of the Einstein summation.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/657
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.einsum(subscripts, *operands)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            operands_np = (np.asarray(op, copy=True) for op in operands)
            result_np = np.einsum(subscripts, *operands_np)
            return self.xp.asarray(result_np, copy=True)

        msg = "the array backend is not supported"
        raise NotImplementedError(msg)

    def apply_along_axis(
        self,
        func1d: Callable[..., Any],
        axis: int,
        arr: AnyArray,
        *args: object,
        **kwargs: object,
    ) -> AnyArray:
        """
        Apply a function to 1-D slices along the given axis.

        Parameters
        ----------
        func1d
            Function to apply to 1-D slices.
        axis
            Axis along which to apply the function.
        arr
            Input array.
        *args
            Additional positional arguments to pass to func1d.
        **kwargs
            Additional keyword arguments to pass to func1d.

        Returns
        -------
            Result of applying the function along the axis.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/651

        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.apply_along_axis(func1d, axis, arr, *args, **kwargs)

        if self.xp.__name__ == "array_api_strict":
            # Import here to prevent users relying on numpy unless in this instance
            np = import_numpy(self.xp.__name__)

            # Unlike the other wrappers, arr is handed to NumPy without an
            # explicit copy; only the result is copied back.
            return self.xp.asarray(
                np.apply_along_axis(func1d, axis, arr, *args, **kwargs), copy=True
            )

        msg = "the array backend is not supported"
        raise NotImplementedError(msg)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc