• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

glass-dev / glass / 20752181965

06 Jan 2026 02:58PM UTC coverage: 93.966% (+0.7%) from 93.306%
20752181965

push

github

web-flow
gh-945: Add `uv.lock` pre-commit hook (#946)

211 of 213 branches covered (99.06%)

Branch coverage included in aggregate %.

1315 of 1411 relevant lines covered (93.2%)

5.07 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

83.95
/glass/_array_api_utils.py
1
"""
2
Array API Utilities for glass.
3
==============================
4

5
This module provides utility functions and classes for working with multiple array
6
backends in the glass project, including NumPy, JAX, and array-api-strict. It includes
7
functions for importing backends, determining array namespaces, dispatching random
8
number generators, and providing missing functionality for array-api-strict through the
9
XPAdditions class.
10

11
Classes and functions in this module help ensure consistent behavior and compatibility
12
across different array libraries, and provide wrappers for common operations such as
13
integration, interpolation, and linear algebra.
14

15
"""
16

17
from __future__ import annotations
18

19
from typing import TYPE_CHECKING, Any
20

21
if TYPE_CHECKING:
22
    from collections.abc import Callable
23
    from types import ModuleType
24

25
    import numpy as np
26
    from numpy.typing import DTypeLike
27

28
    from array_api_strict._array_object import Array as AArray
29

30
    from glass._types import AnyArray, FloatArray, UnifiedGenerator
31

32

33
class CompatibleBackendNotFoundError(Exception):
    """
    Signal that no usable array backend provides a required function.

    Raised when GLASS needs a specific library (for example NumPy, see
    ``import_numpy``) and either the user supplied no alternative backend or
    the user's backend lacks the needed functionality.
    """

    def __init__(self, missing_backend: str, users_backend: str | None) -> None:
        """
        Build the error message from the backend names.

        Parameters
        ----------
        missing_backend
            Name of the library that GLASS needs at this point.
        users_backend
            Name of the backend chosen by the user, or ``None`` when the user
            did not pick one.
        """
        if users_backend is not None:
            text = f"GLASS depends on functions not supported by {users_backend}"
        else:
            text = (
                f"{missing_backend} is required here as "
                "no alternative has been provided by the user."
            )
        # Expose the text as ``message`` in addition to ``str(exc)``.
        self.message = text
        super().__init__(text)
47

48

49
def import_numpy(backend: str | None = None) -> ModuleType:
    """
    Import the NumPy module, raising a helpful error if NumPy is not installed.

    Parameters
    ----------
    backend
        The name of the backend requested by the user. It is only used to tailor
        the error message; ``None`` means the user did not request a backend.

    Returns
    -------
        The NumPy module.

    Raises
    ------
    CompatibleBackendNotFoundError
        If NumPy is not found in the user's environment. The original
        ``ModuleNotFoundError`` is chained as ``__cause__``.

    Notes
    -----
    This is useful for explaining to the user why NumPy is required when their chosen
    backend does not implement a needed function.
    """
    try:
        # Imported lazily so NumPy is only required on code paths that need it.
        import numpy  # noqa: ICN001, PLC0415

    except ModuleNotFoundError as err:
        # Translate into a GLASS error explaining why NumPy is needed; ``from err``
        # keeps the original import failure visible in the traceback.
        raise CompatibleBackendNotFoundError("numpy", backend) from err
    else:
        return numpy
79

80

81
def default_xp() -> ModuleType:
    """
    Return the array library used when the user has not chosen a backend.

    Returns
    -------
        The NumPy module, obtained through ``import_numpy`` so that a missing
        NumPy installation produces GLASS's explanatory error.
    """
    fallback = import_numpy()
    return fallback
84

85

86
def rng_dispatcher(*, xp: ModuleType, seed: int = 42) -> UnifiedGenerator:
    """
    Dispatch a random number generator based on the given array backend.

    Parameters
    ----------
    xp
        The array library backend to use for array operations.
    seed
        Seed for the returned generator. Defaults to 42, so calls that omit it
        always receive an identically seeded (reproducible) generator.

    Returns
    -------
        A seeded random number generator appropriate for *xp*:
        ``glass.jax.Generator`` for ``jax.numpy``, ``numpy.random.default_rng``
        for ``numpy``, and this module's ``Generator`` for ``array_api_strict``.

    Raises
    ------
    NotImplementedError
        If the array backend is not supported.
    """
    name = xp.__name__

    if name == "jax.numpy":
        # Imported lazily so JAX is only required when the JAX backend is used.
        import glass.jax  # noqa: PLC0415

        return glass.jax.Generator(seed=seed)

    if name == "numpy":
        return xp.random.default_rng(seed=seed)  # type: ignore[no-any-return]

    if name == "array_api_strict":
        return Generator(seed=seed)

    msg = f"the array backend {name!r} is not supported"
    raise NotImplementedError(msg)
119

120

121
class Generator:
    """
    Random number generator producing array_api_strict arrays.

    Samples are drawn with a NumPy random generator, and every result is
    converted with ``array_api_strict.asarray`` before being returned.
    """

    __slots__ = ("axp", "nxp", "rng")

    def __init__(
        self,
        seed: int | bool | AArray | None = None,  # noqa: FBT001
    ) -> None:
        """
        Create the underlying NumPy generator.

        Parameters
        ----------
        seed
            Seed forwarded to ``numpy.random.default_rng``.
        """
        # Both libraries are imported here rather than at module level, so
        # importing this module does not require them.
        import numpy  # noqa: ICN001, PLC0415

        import array_api_strict  # noqa: PLC0415

        self.axp = array_api_strict  # namespace used to wrap the outputs
        self.nxp = numpy  # namespace that performs the sampling
        self.rng = numpy.random.default_rng(seed=seed)

    def random(
        self,
        size: int | tuple[int, ...] | None = None,
        dtype: DTypeLike | None = None,
        out: AArray | None = None,
    ) -> AArray:
        """
        Sample floats uniformly from the half-open interval [0.0, 1.0).

        Parameters
        ----------
        size
            Shape of the output.
        dtype
            Floating-point type of the samples; ``None`` selects float64.
        out
            Array to write the samples into. NOTE(review): this is forwarded
            unchanged to NumPy, which expects a NumPy array here; confirm that
            array_api_strict arrays are accepted.

        Returns
        -------
            The samples as an array_api_strict array.
        """
        if dtype is None:
            dtype = self.nxp.float64
        samples = self.rng.random(size=size, dtype=dtype, out=out)  # type: ignore[arg-type]
        return self.axp.asarray(samples)

    def normal(
        self,
        loc: float | FloatArray = 0.0,
        scale: float | FloatArray = 1.0,
        size: int | tuple[int, ...] | None = None,
    ) -> AArray:
        """
        Sample from a Normal (Gaussian) distribution.

        Parameters
        ----------
        loc
            Mean of the distribution.
        scale
            Standard deviation of the distribution.
        size
            Shape of the output.

        Returns
        -------
            Normal samples as an array_api_strict array.
        """
        draws = self.rng.normal(loc=loc, scale=scale, size=size)
        return self.axp.asarray(draws)

    def poisson(
        self,
        lam: float | AArray,
        size: int | tuple[int, ...] | None = None,
    ) -> AArray:
        """
        Sample from a Poisson distribution.

        Parameters
        ----------
        lam
            Expected number of events.
        size
            Shape of the output.

        Returns
        -------
            Poisson samples as an array_api_strict array.
        """
        counts = self.rng.poisson(lam=lam, size=size)
        return self.axp.asarray(counts)

    def standard_normal(
        self,
        size: int | tuple[int, ...] | None = None,
        dtype: DTypeLike | None = None,
        out: AArray | None = None,
    ) -> AArray:
        """
        Sample from the standard Normal distribution (mean 0, standard deviation 1).

        Parameters
        ----------
        size
            Shape of the output.
        dtype
            Floating-point type of the samples; ``None`` selects float64.
        out
            Array to write the samples into. NOTE(review): forwarded unchanged
            to NumPy, which expects a NumPy array here; confirm usage.

        Returns
        -------
            Standard normal samples as an array_api_strict array.
        """
        if dtype is None:
            dtype = self.nxp.float64
        draws = self.rng.standard_normal(size=size, dtype=dtype, out=out)  # type: ignore[arg-type]
        return self.axp.asarray(draws)

    def uniform(
        self,
        low: float | AArray = 0.0,
        high: float | AArray = 1.0,
        size: int | tuple[int, ...] | None = None,
    ) -> AArray:
        """
        Sample from a Uniform distribution over [low, high).

        Parameters
        ----------
        low
            Lower bound of the distribution.
        high
            Upper bound of the distribution.
        size
            Shape of the output.

        Returns
        -------
            Uniform samples as an array_api_strict array.
        """
        draws = self.rng.uniform(low=low, high=high, size=size)
        return self.axp.asarray(draws)
269

270

271
class XPAdditions:
    """
    Additional functions missing from both array-api-strict and array-api-extra.

    This class provides wrappers for common array operations such as integration,
    interpolation, and linear algebra, ensuring compatibility across NumPy, JAX,
    and array-api-strict backends. The backend is selected from ``xp.__name__``.

    For array-api-strict, the wrappers follow SciPy's approach: copy the inputs
    to NumPy, compute with NumPy, and copy the result back.

    This is intended as a temporary solution. See https://github.com/glass-dev/glass/issues/645
    for details.
    """

    xp: ModuleType
    # Declared for type checkers only; nothing in this class assigns it.
    backend: str

    def __init__(self, xp: ModuleType) -> None:
        """
        Initialize XPAdditions with the given array namespace.

        Parameters
        ----------
        xp
            The array library backend to use for array operations. Supported
            namespaces are recognised by ``xp.__name__``: ``"numpy"``,
            ``"jax.numpy"`` and ``"array_api_strict"``.
        """
        self.xp = xp

    def _unsupported_backend(self) -> NotImplementedError:
        """Return the error raised when ``self.xp`` is not a supported backend."""
        msg = f"the array backend {self.xp.__name__!r} is not supported"
        return NotImplementedError(msg)

    def trapezoid(
        self,
        y: AnyArray,
        x: AnyArray | None = None,
        dx: float = 1.0,
        axis: int = -1,
    ) -> AnyArray:
        """
        Integrate along the given axis using the composite trapezoidal rule.

        Parameters
        ----------
        y
            Input array to integrate.
        x
            Sample points corresponding to y. If ``None``, the samples are taken
            to be evenly spaced ``dx`` apart.
        dx
            Spacing between sample points, used when ``x`` is ``None``.
        axis
            Axis along which to integrate.

        Returns
        -------
            Integrated result.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/646
        """
        if self.xp.__name__ == "jax.numpy":
            import glass.jax  # noqa: PLC0415

            return glass.jax.trapezoid(y, x=x, dx=dx, axis=axis)

        if self.xp.__name__ == "numpy":
            return self.xp.trapezoid(y, x=x, dx=dx, axis=axis)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            y_np = np.asarray(y, copy=True)
            # np.asarray(None) is a 0-d object array, which np.trapezoid does not
            # recognise as "no x" (it fails in np.diff); keep None so NumPy
            # falls back to the uniform spacing dx.
            x_np = None if x is None else np.asarray(x, copy=True)
            result_np = np.trapezoid(y_np, x=x_np, dx=dx, axis=axis)
            return self.xp.asarray(result_np, copy=True)

        raise self._unsupported_backend()

    def interp(  # noqa: PLR0913
        self,
        x: AnyArray,
        x_points: AnyArray,
        y_points: AnyArray,
        left: float | None = None,
        right: float | None = None,
        period: float | None = None,
    ) -> AnyArray:
        """
        One-dimensional linear interpolation for monotonically increasing sample points.

        Parameters
        ----------
        x
            The x-coordinates at which to evaluate the interpolated values.
        x_points
            The x-coordinates of the data points.
        y_points
            The y-coordinates of the data points.
        left
            Value to return for x < x_points[0].
        right
            Value to return for x > x_points[-1].
        period
            Period for periodic interpolation.

        Returns
        -------
            Interpolated values.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/650
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.interp(
                x,
                x_points,
                y_points,
                left=left,
                right=right,
                period=period,
            )

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            x_np = np.asarray(x, copy=True)
            x_points_np = np.asarray(x_points, copy=True)
            y_points_np = np.asarray(y_points, copy=True)
            result_np = np.interp(
                x_np,
                x_points_np,
                y_points_np,
                left=left,
                right=right,
                period=period,
            )
            return self.xp.asarray(result_np, copy=True)

        raise self._unsupported_backend()

    def gradient(self, f: AnyArray) -> AnyArray:
        """
        Return the gradient of an N-dimensional array.

        Parameters
        ----------
        f
            Input array.

        Returns
        -------
            Gradient of the input array.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/648
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.gradient(f)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            f_np = np.asarray(f, copy=True)
            result_np = np.gradient(f_np)
            return self.xp.asarray(result_np, copy=True)

        raise self._unsupported_backend()

    def linalg_lstsq(
        self,
        a: AnyArray,
        b: AnyArray,
        rcond: float | None = None,
    ) -> tuple[AnyArray, AnyArray, int, AnyArray]:
        """
        Solve a linear least squares problem.

        Parameters
        ----------
        a
            Coefficient matrix.
        b
            Ordinate or "dependent variable" values.
        rcond
            Cut-off ratio for small singular values.

        Returns
        -------
        x
            Least-squares solution. If b is two-dimensional, the solutions are in the K
            columns of x.

        residuals
            Sums of squared residuals: Squared Euclidean 2-norm for each column in b - a
            @ x. If the rank of a is < N or M <= N, this is an empty array. If b is
            1-dimensional, this is a (1,) shape array. Otherwise the shape is (K,).

        rank
            Rank of matrix a. For array-api-strict this is returned as a Python
            ``int``, matching the annotated return type.

        s
            Singular values of a.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/649
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.linalg.lstsq(a, b, rcond=rcond)  # type: ignore[no-any-return]

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            a_np = np.asarray(a, copy=True)
            b_np = np.asarray(b, copy=True)
            solution, residuals, rank, singular_values = np.linalg.lstsq(
                a_np,
                b_np,
                rcond=rcond,
            )
            # rank is a scalar count, not an array; keep it an int as annotated.
            return (
                self.xp.asarray(solution, copy=True),
                self.xp.asarray(residuals, copy=True),
                int(rank),
                self.xp.asarray(singular_values, copy=True),
            )

        raise self._unsupported_backend()

    def einsum(self, subscripts: str, *operands: AnyArray) -> AnyArray:
        """
        Evaluate the Einstein summation convention on the operands.

        Parameters
        ----------
        subscripts
            Specifies the subscripts for summation.
        *operands
            Arrays to be summed.

        Returns
        -------
            Result of the Einstein summation.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/657
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.einsum(subscripts, *operands)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            operands_np = [np.asarray(op, copy=True) for op in operands]
            result_np = np.einsum(subscripts, *operands_np)
            return self.xp.asarray(result_np, copy=True)

        raise self._unsupported_backend()

    def apply_along_axis(
        self,
        func1d: Callable[..., Any],
        axis: int,
        arr: AnyArray,
        *args: object,
        **kwargs: object,
    ) -> AnyArray:
        """
        Apply a function to 1-D slices along the given axis.

        Parameters
        ----------
        func1d
            Function to apply to 1-D slices. For array-api-strict it is called
            on NumPy slices, not on array-api-strict arrays.
        axis
            Axis along which to apply the function.
        arr
            Input array.
        *args
            Additional positional arguments to pass to func1d.
        **kwargs
            Additional keyword arguments to pass to func1d.

        Returns
        -------
            Result of applying the function along the axis.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/651

        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.apply_along_axis(func1d, axis, arr, *args, **kwargs)

        if self.xp.__name__ == "array_api_strict":
            # Import here to prevent users relying on numpy unless in this instance
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            arr_np = np.asarray(arr, copy=True)
            result_np = np.apply_along_axis(func1d, axis, arr_np, *args, **kwargs)
            return self.xp.asarray(result_np, copy=True)

        raise self._unsupported_backend()

    def vectorize(
        self,
        pyfunc: Callable[..., Any],
        otypes: tuple[type[float]],
    ) -> Callable[..., Any]:
        """
        Returns an object that acts like pyfunc, but takes arrays as input.

        Parameters
        ----------
        pyfunc
            Python function to vectorize.
        otypes
            Output types.

        Returns
        -------
            Vectorized function. For JAX and array-api-strict this is a
            ``numpy.vectorize`` object, so it computes with NumPy.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/671
        """
        if self.xp.__name__ == "numpy":
            return self.xp.vectorize(pyfunc, otypes=otypes)  # type: ignore[no-any-return]

        if self.xp.__name__ in {"array_api_strict", "jax.numpy"}:
            # Import here to prevent users relying on numpy unless in this instance
            np = import_numpy(self.xp.__name__)

            return np.vectorize(pyfunc, otypes=otypes)  # type: ignore[no-any-return]

        raise self._unsupported_backend()

    def radians(self, deg_arr: AnyArray) -> AnyArray:
        """
        Convert angles from degrees to radians.

        Parameters
        ----------
        deg_arr
            Array of angles in degrees.

        Returns
        -------
            Array of angles in radians.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.radians(deg_arr)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            return self.xp.asarray(np.radians(deg_arr))

        raise self._unsupported_backend()

    def degrees(self, deg_arr: AnyArray) -> AnyArray:
        """
        Convert angles from radians to degrees.

        Parameters
        ----------
        deg_arr
            Array of angles in radians (despite the parameter name).

        Returns
        -------
            Array of angles in degrees.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.degrees(deg_arr)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            return self.xp.asarray(np.degrees(deg_arr))

        raise self._unsupported_backend()

    def ndindex(self, shape: tuple[int, ...]) -> np.ndindex:
        """
        Wrapper for numpy.ndindex.

        NumPy's implementation is used for every supported backend, since the
        result is an iterator of plain index tuples.

        See relevant docs for details:
        - NumPy, https://numpy.org/doc/2.2/reference/generated/numpy.ndindex.html

        Parameters
        ----------
        shape
            Shape of the array whose indices are iterated.

        Returns
        -------
            An iterator over all index tuples for ``shape``.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        """
        if self.xp.__name__ == "numpy":
            return self.xp.ndindex(shape)  # type: ignore[no-any-return]

        if self.xp.__name__ in {"array_api_strict", "jax.numpy"}:
            np = import_numpy(self.xp.__name__)

            return np.ndindex(shape)  # type: ignore[no-any-return]

        raise self._unsupported_backend()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc