• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

glass-dev / glass / 19344074984

13 Nov 2025 07:55PM UTC coverage: 93.367% (+0.03%) from 93.341%
19344074984

Pull #771

github

web-flow
Merge cf6786f59 into ab28ea36f
Pull Request #771: gh-770: Default to numpy for redshift_grid and fixed_zbins if no xp is provided

220 of 222 branches covered (99.1%)

Branch coverage included in aggregate %.

5 of 6 new or added lines in 3 files covered. (83.33%)

27 existing lines in 1 file now uncovered.

1469 of 1587 relevant lines covered (92.56%)

4.72 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

84.95
/glass/_array_api_utils.py
1
"""
2
Array API Utilities for glass.
3
==============================
4

5
This module provides utility functions and classes for working with multiple array
6
backends in the glass project, including NumPy, JAX, and array-api-strict. It includes
7
functions for importing backends, determining array namespaces, dispatching random
8
number generators, and providing missing functionality for array-api-strict through the
9
XPAdditions class.
10

11
Classes and functions in this module help ensure consistent behavior and compatibility
12
across different array libraries, and provide wrappers for common operations such as
13
integration, interpolation, and linear algebra.
14

15
"""
16

17
from __future__ import annotations
8✔
18

19
import inspect
8✔
20
import warnings
8✔
21
from typing import TYPE_CHECKING, Any
8✔
22

23
if TYPE_CHECKING:
24
    from collections.abc import Callable
25
    from types import ModuleType
26

27
    import numpy as np
28
    from numpy.typing import DTypeLike
29

30
    from array_api_strict._array_object import Array as AArray
31

32
    from glass._types import AnyArray, FloatArray, Size, UnifiedGenerator
33

34

35
class CompatibleBackendNotFoundError(Exception):
    """
    Raised when no usable array backend provides a function GLASS needs.

    The message depends on whether the user supplied a backend. Without one,
    it states that ``missing_backend`` is required. With one, it states that
    GLASS relies on functions the user's backend does not support.
    """

    def __init__(self, missing_backend: str, users_backend: str | None) -> None:
        """
        Build and store the exception message.

        Parameters
        ----------
        missing_backend
            Name of the library that is required here.
        users_backend
            Name of the backend chosen by the user, or ``None`` if none was given.
        """
        if users_backend is not None:
            text = f"GLASS depends on functions not supported by {users_backend}"
        else:
            text = (
                f"{missing_backend} is required here as "
                "no alternative has been provided by the user."
            )
        # Kept as an attribute as well as the exception's args.
        self.message = text
        super().__init__(self.message)
49

50

51
def import_numpy(backend: str | None = None) -> ModuleType:
    """
    Import the NumPy module, raising a helpful error if NumPy is not installed.

    Parameters
    ----------
    backend
        The name of the backend requested by the user, or ``None`` if the user
        did not request one. It is only used to build the error message.

    Returns
    -------
        The NumPy module.

    Raises
    ------
    CompatibleBackendNotFoundError
        If NumPy is not found in the user's environment. The original
        ``ModuleNotFoundError`` is chained as the exception's cause.

    Notes
    -----
    This is useful for explaining to the user why NumPy is required when their chosen
    backend does not implement a needed function.
    """
    try:
        # Imported lazily so NumPy is only needed when a fallback is required.
        import numpy  # noqa: ICN001, PLC0415

    except ModuleNotFoundError as err:
        raise CompatibleBackendNotFoundError("numpy", backend) from err
    else:
        return numpy
81

82

83
def default_xp() -> ModuleType:
    """
    Return the library backend we default to if none is specified by the user.

    A warning naming the calling function is emitted, so the user can see
    that no array library was provided and a default is being used.

    Returns
    -------
        The NumPy module.

    Raises
    ------
    CompatibleBackendNotFoundError
        If NumPy is not installed.
    """
    # Only the immediate caller's name is needed. ``inspect.stack()`` builds a
    # FrameInfo record (including source context read from disk) for every
    # frame on the stack, which is needlessly expensive for a single name.
    frame = inspect.currentframe()
    caller_name = (
        frame.f_back.f_code.co_name
        if frame is not None and frame.f_back is not None
        else "<unknown>"
    )
    # Drop the frame reference to avoid a reference cycle through our locals.
    del frame
    warnings.warn(
        f"No array library has been provided for call to {caller_name}",
        stacklevel=2,
    )
    return import_numpy()
90

91

92
def rng_dispatcher(*, xp: ModuleType) -> UnifiedGenerator:
    """
    Pick a random number generator that matches the given array backend.

    Parameters
    ----------
    xp
        The array library backend to use for array operations.

    Returns
    -------
        A generator for the backend:

        - ``numpy.random.default_rng()`` (unseeded) for ``numpy``;
        - ``glass.jax.Generator(seed=42)`` for ``jax.numpy``;
        - this module's ``Generator(seed=42)`` for ``array_api_strict``.

    Raises
    ------
    NotImplementedError
        If the array backend is not supported.
    """
    backend_name = xp.__name__

    if backend_name == "numpy":
        return xp.random.default_rng()  # type: ignore[no-any-return]

    if backend_name == "jax.numpy":
        # Imported lazily so JAX is only required when it is the chosen backend.
        import glass.jax  # noqa: PLC0415

        return glass.jax.Generator(seed=42)

    if backend_name == "array_api_strict":
        return Generator(seed=42)

    msg = "the array backend in not supported"
    raise NotImplementedError(msg)
123

124

125
class Generator:
    """
    NumPy random number generator returning array_api_strict Array.

    This class wraps NumPy's random number generator and returns arrays compatible
    with array_api_strict. Every sample is drawn with NumPy and then converted
    with ``array_api_strict.asarray``.
    """

    __slots__ = ("axp", "nxp", "rng")

    def __init__(
        self,
        seed: int | bool | AArray | None = None,  # noqa: FBT001
    ) -> None:
        """
        Initialize the Generator.

        Parameters
        ----------
        seed
            Seed for the random number generator, passed unchanged to
            ``numpy.random.default_rng``.
        """
        import numpy  # noqa: ICN001, PLC0415

        import array_api_strict  # noqa: PLC0415

        # Output namespace: samples are converted to array_api_strict arrays.
        self.axp = array_api_strict
        # Sampling namespace: NumPy performs the actual random draws.
        self.nxp = numpy
        self.rng = self.nxp.random.default_rng(seed=seed)

    def random(
        self,
        size: Size = None,
        dtype: DTypeLike | None = None,
        out: AArray | None = None,
    ) -> AArray:
        """
        Return random floats in the half-open interval [0.0, 1.0).

        Parameters
        ----------
        size
            Output shape.
        dtype
            Desired data type. Defaults to ``numpy.float64`` when ``None``.
        out
            Optional output array, forwarded unchanged to NumPy's
            ``Generator.random``.

        Returns
        -------
            Array of random floats.
        """
        dtype = self.nxp.float64 if dtype is None else dtype
        return self.axp.asarray(self.rng.random(size, dtype, out))  # type: ignore[arg-type]

    def normal(
        self,
        loc: float | FloatArray = 0.0,
        scale: float | FloatArray = 1.0,
        size: Size = None,
    ) -> AArray:
        """
        Draw samples from a Normal distribution (mean=loc, stdev=scale).

        Parameters
        ----------
        loc
            Mean of the distribution.
        scale
            Standard deviation of the distribution.
        size
            Output shape.

        Returns
        -------
            Array of samples from the normal distribution.
        """
        return self.axp.asarray(self.rng.normal(loc, scale, size))

    def poisson(self, lam: float | AArray, size: Size = None) -> AArray:
        """
        Draw samples from a Poisson distribution.

        Parameters
        ----------
        lam
            Expected number of events.
        size
            Output shape.

        Returns
        -------
            Array of samples from the Poisson distribution.
        """
        return self.axp.asarray(self.rng.poisson(lam, size))

    def standard_normal(
        self,
        size: Size = None,
        dtype: DTypeLike | None = None,
        out: AArray | None = None,
    ) -> AArray:
        """
        Draw samples from a standard Normal distribution (mean=0, stdev=1).

        Parameters
        ----------
        size
            Output shape.
        dtype
            Desired data type. Defaults to ``numpy.float64`` when ``None``.
        out
            Optional output array, forwarded unchanged to NumPy's
            ``Generator.standard_normal``.

        Returns
        -------
            Array of samples from the standard normal distribution.
        """
        dtype = self.nxp.float64 if dtype is None else dtype
        return self.axp.asarray(self.rng.standard_normal(size, dtype, out))  # type: ignore[arg-type]

    def uniform(
        self,
        low: float | AArray = 0.0,
        high: float | AArray = 1.0,
        size: Size = None,
    ) -> AArray:
        """
        Draw samples from a Uniform distribution.

        Parameters
        ----------
        low
            Lower bound of the distribution.
        high
            Upper bound of the distribution.
        size
            Output shape.

        Returns
        -------
            Array of samples from the uniform distribution.
        """
        return self.axp.asarray(self.rng.uniform(low, high, size))
269

270

271
class XPAdditions:
    """
    Additional functions missing from both array-api-strict and array-api-extra.

    This class provides wrappers for common array operations such as integration,
    interpolation, and linear algebra, ensuring compatibility across NumPy, JAX,
    and array-api-strict backends.

    Most methods call the backend's own implementation for NumPy and JAX. For
    array-api-strict, inputs are converted to NumPy arrays, the NumPy function
    is called, and the result is converted back with ``xp.asarray``. Any other
    backend raises ``NotImplementedError``.

    This is intended as a temporary solution. See https://github.com/glass-dev/glass/issues/645
    for details.
    """

    xp: ModuleType
    # Declared for typing; not assigned anywhere in this class.
    backend: str

    def __init__(self, xp: ModuleType) -> None:
        """
        Initialize XPAdditions with the given array namespace.

        Parameters
        ----------
        xp
            The array library backend to use for array operations.
        """
        self.xp = xp

    def trapezoid(
        self,
        y: AnyArray,
        x: AnyArray | None = None,
        dx: float = 1.0,
        axis: int = -1,
    ) -> AnyArray:
        """
        Integrate along the given axis using the composite trapezoidal rule.

        Parameters
        ----------
        y
            Input array to integrate.
        x
            Sample points corresponding to y. If ``None``, ``dx`` is used as
            the spacing between sample points.
        dx
            Spacing between sample points, used when ``x`` is ``None``.
        axis
            Axis along which to integrate.

        Returns
        -------
            Integrated result.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/646
        """
        if self.xp.__name__ == "jax.numpy":
            import glass.jax  # noqa: PLC0415

            return glass.jax.trapezoid(y, x=x, dx=dx, axis=axis)

        if self.xp.__name__ == "numpy":
            return self.xp.trapezoid(y, x=x, dx=dx, axis=axis)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            y_np = np.asarray(y, copy=True)
            # ``np.asarray(None)`` is a 0-d object array, which ``np.trapezoid``
            # would try to difference (and fail on) instead of falling back to
            # ``dx``; keep ``None`` so the evenly-spaced path is taken.
            x_np = None if x is None else np.asarray(x, copy=True)
            result_np = np.trapezoid(y_np, x_np, dx=dx, axis=axis)
            return self.xp.asarray(result_np, copy=True)

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def union1d(self, ar1: AnyArray, ar2: AnyArray) -> AnyArray:
        """
        Compute the set union of two 1D arrays.

        Parameters
        ----------
        ar1
            First input array.
        ar2
            Second input array.

        Returns
        -------
            The union of the two arrays.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/647
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.union1d(ar1, ar2)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            ar1_np = np.asarray(ar1, copy=True)
            ar2_np = np.asarray(ar2, copy=True)
            result_np = np.union1d(ar1_np, ar2_np)
            return self.xp.asarray(result_np, copy=True)

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def interp(  # noqa: PLR0913
        self,
        x: AnyArray,
        x_points: AnyArray,
        y_points: AnyArray,
        left: float | None = None,
        right: float | None = None,
        period: float | None = None,
    ) -> AnyArray:
        """
        One-dimensional linear interpolation for monotonically increasing sample points.

        Parameters
        ----------
        x
            The x-coordinates at which to evaluate the interpolated values.
        x_points
            The x-coordinates of the data points.
        y_points
            The y-coordinates of the data points.
        left
            Value to return for x < x_points[0].
        right
            Value to return for x > x_points[-1].
        period
            Period for periodic interpolation.

        Returns
        -------
            Interpolated values.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/650
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.interp(
                x,
                x_points,
                y_points,
                left=left,
                right=right,
                period=period,
            )

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            x_np = np.asarray(x, copy=True)
            x_points_np = np.asarray(x_points, copy=True)
            y_points_np = np.asarray(y_points, copy=True)
            result_np = np.interp(
                x_np,
                x_points_np,
                y_points_np,
                left=left,
                right=right,
                period=period,
            )
            return self.xp.asarray(result_np, copy=True)

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def gradient(self, f: AnyArray) -> AnyArray:
        """
        Return the gradient of an N-dimensional array.

        Parameters
        ----------
        f
            Input array.

        Returns
        -------
            Gradient of the input array.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/648
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.gradient(f)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            f_np = np.asarray(f, copy=True)
            result_np = np.gradient(f_np)
            return self.xp.asarray(result_np, copy=True)

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def linalg_lstsq(
        self,
        a: AnyArray,
        b: AnyArray,
        rcond: float | None = None,
    ) -> tuple[AnyArray, AnyArray, int, AnyArray]:
        """
        Solve a linear least squares problem.

        Parameters
        ----------
        a
            Coefficient matrix.
        b
            Ordinate or "dependent variable" values.
        rcond
            Cut-off ratio for small singular values.

        Returns
        -------
        x
            Least-squares solution. If b is two-dimensional, the solutions are in the K
            columns of x.

        residuals
            Sums of squared residuals: Squared Euclidean 2-norm for each column in b - a
            @ x. If the rank of a is < N or M <= N, this is an empty array. If b is
            1-dimensional, this is a (1,) shape array. Otherwise the shape is (K,).

        rank
            Rank of matrix a.

        s
            Singular values of a.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/649
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.linalg.lstsq(a, b, rcond=rcond)  # type: ignore[no-any-return]

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            a_np = np.asarray(a, copy=True)
            b_np = np.asarray(b, copy=True)
            x_np, residuals_np, rank_np, s_np = np.linalg.lstsq(
                a_np,
                b_np,
                rcond=rcond,
            )
            # ``rank`` is a scalar count, not array data: return it as an int
            # (as annotated) instead of wrapping it in a 0-d array.
            return (
                self.xp.asarray(x_np, copy=True),
                self.xp.asarray(residuals_np, copy=True),
                int(rank_np),
                self.xp.asarray(s_np, copy=True),
            )

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def einsum(self, subscripts: str, *operands: AnyArray) -> AnyArray:
        """
        Evaluate the Einstein summation convention on the operands.

        Parameters
        ----------
        subscripts
            Specifies the subscripts for summation.
        *operands
            Arrays to be summed.

        Returns
        -------
            Result of the Einstein summation.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/657
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.einsum(subscripts, *operands)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            operands_np = (np.asarray(op, copy=True) for op in operands)
            result_np = np.einsum(subscripts, *operands_np)
            return self.xp.asarray(result_np, copy=True)

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def apply_along_axis(
        self,
        func1d: Callable[..., Any],
        axis: int,
        arr: AnyArray,
        *args: object,
        **kwargs: object,
    ) -> AnyArray:
        """
        Apply a function to 1-D slices along the given axis.

        Parameters
        ----------
        func1d
            Function to apply to 1-D slices.
        axis
            Axis along which to apply the function.
        arr
            Input array.
        *args
            Additional positional arguments to pass to func1d.
        **kwargs
            Additional keyword arguments to pass to func1d.

        Returns
        -------
            Result of applying the function along the axis.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/651
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.apply_along_axis(func1d, axis, arr, *args, **kwargs)

        if self.xp.__name__ == "array_api_strict":
            # Import here to prevent users relying on numpy unless in this instance
            np = import_numpy(self.xp.__name__)

            # ``arr`` is handed to NumPy as-is; NumPy converts it internally.
            return self.xp.asarray(
                np.apply_along_axis(func1d, axis, arr, *args, **kwargs),
                copy=True,
            )

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def vectorize(
        self,
        pyfunc: Callable[..., Any],
        otypes: tuple[type[float]],
    ) -> Callable[..., Any]:
        """
        Returns an object that acts like pyfunc, but takes arrays as input.

        Parameters
        ----------
        pyfunc
            Python function to vectorize.
        otypes
            Output types.

        Returns
        -------
            Vectorized function. For ``array_api_strict`` and ``jax.numpy``
            this is ``numpy.vectorize`` applied to ``pyfunc``.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/671
        """
        if self.xp.__name__ == "numpy":
            return self.xp.vectorize(pyfunc, otypes=otypes)  # type: ignore[no-any-return]

        if self.xp.__name__ in {"array_api_strict", "jax.numpy"}:
            # Import here to prevent users relying on numpy unless in this instance
            np = import_numpy(self.xp.__name__)

            return np.vectorize(pyfunc, otypes=otypes)  # type: ignore[no-any-return]

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def radians(self, deg_arr: AnyArray) -> AnyArray:
        """
        Convert angles from degrees to radians.

        Parameters
        ----------
        deg_arr
            Array of angles in degrees.

        Returns
        -------
            Array of angles in radians.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.radians(deg_arr)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            return self.xp.asarray(np.radians(deg_arr))

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def degrees(self, deg_arr: AnyArray) -> AnyArray:
        """
        Convert angles from radians to degrees.

        Parameters
        ----------
        deg_arr
            Array of angles in radians (despite the parameter's name).

        Returns
        -------
            Array of angles in degrees.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.
        """
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
            return self.xp.degrees(deg_arr)

        if self.xp.__name__ == "array_api_strict":
            np = import_numpy(self.xp.__name__)

            return self.xp.asarray(np.degrees(deg_arr))

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)

    def ndindex(self, shape: tuple[int, ...]) -> np.ndindex:
        """
        Wrapper for numpy.ndindex.

        See relevant docs for details:
        - NumPy, https://numpy.org/doc/2.2/reference/generated/numpy.ndindex.html

        Parameters
        ----------
        shape
            Shape of the array whose index tuples are iterated over.

        Returns
        -------
            A ``numpy.ndindex`` iterator. NumPy's implementation is used for
            every supported backend.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.
        """
        if self.xp.__name__ == "numpy":
            return self.xp.ndindex(shape)  # type: ignore[no-any-return]

        if self.xp.__name__ in {"array_api_strict", "jax.numpy"}:
            np = import_numpy(self.xp.__name__)

            return np.ndindex(shape)  # type: ignore[no-any-return]

        msg = "the array backend in not supported"
        raise NotImplementedError(msg)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc