• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

glass-dev / glass / 19327482174

13 Nov 2025 09:53AM UTC coverage: 93.355% (+0.01%) from 93.341%
19327482174

Pull #771

github

web-flow
Merge 5aeec7902 into b42d62f8c
Pull Request #771: gh-770: Default to numpy for redshift_grid and fixed_zbins if no xp is provided

220 of 222 branches covered (99.1%)

Branch coverage included in aggregate %.

5 of 6 new or added lines in 3 files covered. (83.33%)

27 existing lines in 1 file now uncovered.

1466 of 1584 relevant lines covered (92.55%)

7.39 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

84.7
/glass/_array_api_utils.py
1
"""
2
Array API Utilities for glass.
3
==============================
4

5
This module provides utility functions and classes for working with multiple array
6
backends in the glass project, including NumPy, JAX, and array-api-strict. It includes
7
functions for importing backends, determining array namespaces, dispatching random
8
number generators, and providing missing functionality for array-api-strict through the
9
XPAdditions class.
10

11
Classes and functions in this module help ensure consistent behavior and compatibility
12
across different array libraries, and provide wrappers for common operations such as
13
integration, interpolation, and linear algebra.
14

15
"""
16

17
from __future__ import annotations
8✔
18

19
from typing import TYPE_CHECKING, Any
8✔
20

21
if TYPE_CHECKING:
22
    from collections.abc import Callable
23
    from types import ModuleType
24

25
    import numpy as np
26
    from numpy.typing import DTypeLike
27

28
    from array_api_strict._array_object import Array as AArray
29

30
    from glass._types import AnyArray, FloatArray, Size, UnifiedGenerator
31

32

33
class CompatibleBackendNotFoundError(Exception):
    """
    Raised when no array library implementing a required function is available.

    The text passed to ``Exception`` is also kept on the ``message`` attribute.
    """

    def __init__(self, missing_backend: str, users_backend: str | None) -> None:
        """
        Build the error message and initialise the exception.

        Parameters
        ----------
        missing_backend
            Name of the library that could not be imported.
        users_backend
            Name of the backend selected by the user, or ``None`` when the user
            did not select one.
        """
        if users_backend is None:
            # No backend was chosen, so the missing library is the only option.
            text = (
                f"{missing_backend} is required here as "
                "no alternative has been provided by the user."
            )
        else:
            # A backend was chosen, but it lacks functionality GLASS relies on.
            text = f"GLASS depends on functions not supported by {users_backend}"

        self.message = text
        super().__init__(text)
47

48

49
def import_numpy(backend: str | None = None) -> ModuleType:
    """
    Import the NumPy module, raising a helpful error if NumPy is not installed.

    Parameters
    ----------
    backend
        The name of the backend requested by the user, or ``None`` if the user
        did not request one. It is only used to build the error message.

    Returns
    -------
        The NumPy module.

    Raises
    ------
    CompatibleBackendNotFoundError
        If NumPy is not found in the user's environment. The original
        ``ModuleNotFoundError`` is attached as the cause of this exception.

    Notes
    -----
    This is useful for explaining to the user why NumPy is required when their chosen
    backend does not implement a needed function.
    """
    try:
        # Imported lazily so that NumPy is only needed when a fallback is required.
        import numpy  # noqa: ICN001, PLC0415

    except ModuleNotFoundError as err:
        raise CompatibleBackendNotFoundError("numpy", backend) from err
    else:
        return numpy
79

80

81
def default_xp() -> ModuleType:
    """
    Return the array library used when the user does not specify a backend.

    Returns
    -------
        The module returned by ``import_numpy()`` called without a backend name.
    """
    fallback = import_numpy()
    return fallback
84

85

86
def rng_dispatcher(*, xp: ModuleType) -> UnifiedGenerator:
    """
    Dispatch a random number generator based on the provided array backend.

    Parameters
    ----------
    xp
        The array library backend to use for array operations. The backend is
        identified by its module name (``xp.__name__``).

    Returns
    -------
        The appropriate random number generator for the backend.

    Raises
    ------
    NotImplementedError
        If the array backend is not supported.

    Notes
    -----
    The generators created for ``jax.numpy`` and ``array_api_strict`` use the fixed
    seed 42, whereas the NumPy generator is created without a seed.
    """
    backend = xp.__name__

    if backend == "jax.numpy":
        # Imported lazily so that JAX is only required when it is actually used.
        import glass.jax  # noqa: PLC0415

        return glass.jax.Generator(seed=42)

    if backend == "numpy":
        return xp.random.default_rng()  # type: ignore[no-any-return]

    if backend == "array_api_strict":
        return Generator(seed=42)

    msg = "the array backend is not supported"
    raise NotImplementedError(msg)
117

118

119
class Generator:
    """
    Random number generator that returns array_api_strict arrays.

    Sampling is delegated to a NumPy random number generator; every result is
    converted with ``array_api_strict.asarray`` before being returned.
    """

    __slots__ = ("axp", "nxp", "rng")

    def __init__(
        self,
        seed: int | bool | AArray | None = None,  # noqa: FBT001
    ) -> None:
        """
        Create the underlying NumPy generator.

        Parameters
        ----------
        seed
            Seed forwarded to ``numpy.random.default_rng``.
        """
        # Both libraries are imported lazily, NumPy first.
        import numpy  # noqa: ICN001, PLC0415

        import array_api_strict  # noqa: PLC0415

        self.axp = array_api_strict
        self.nxp = numpy
        self.rng = self.nxp.random.default_rng(seed=seed)

    def random(
        self,
        size: Size = None,
        dtype: DTypeLike | None = None,
        out: AArray | None = None,
    ) -> AArray:
        """
        Sample floats from the half-open interval [0.0, 1.0).

        Parameters
        ----------
        size
            Shape of the output.
        dtype
            Data type of the output; ``float64`` is used when omitted.
        out
            Optional output array, forwarded to the NumPy generator.

        Returns
        -------
            The sampled values.
        """
        if dtype is None:
            dtype = self.nxp.float64
        samples = self.rng.random(size, dtype, out)  # type: ignore[arg-type]
        return self.axp.asarray(samples)

    def normal(
        self,
        loc: float | FloatArray = 0.0,
        scale: float | FloatArray = 1.0,
        size: Size = None,
    ) -> AArray:
        """
        Sample from a Normal distribution with mean ``loc`` and stdev ``scale``.

        Parameters
        ----------
        loc
            Mean of the distribution.
        scale
            Standard deviation of the distribution.
        size
            Shape of the output.

        Returns
        -------
            The sampled values.
        """
        samples = self.rng.normal(loc, scale, size)
        return self.axp.asarray(samples)

    def poisson(self, lam: float | AArray, size: Size = None) -> AArray:
        """
        Sample from a Poisson distribution.

        Parameters
        ----------
        lam
            Expected number of events.
        size
            Shape of the output.

        Returns
        -------
            The sampled values.
        """
        samples = self.rng.poisson(lam, size)
        return self.axp.asarray(samples)

    def standard_normal(
        self,
        size: Size = None,
        dtype: DTypeLike | None = None,
        out: AArray | None = None,
    ) -> AArray:
        """
        Sample from the standard Normal distribution (mean 0, stdev 1).

        Parameters
        ----------
        size
            Shape of the output.
        dtype
            Data type of the output; ``float64`` is used when omitted.
        out
            Optional output array, forwarded to the NumPy generator.

        Returns
        -------
            The sampled values.
        """
        if dtype is None:
            dtype = self.nxp.float64
        samples = self.rng.standard_normal(size, dtype, out)  # type: ignore[arg-type]
        return self.axp.asarray(samples)

    def uniform(
        self,
        low: float | AArray = 0.0,
        high: float | AArray = 1.0,
        size: Size = None,
    ) -> AArray:
        """
        Sample from a Uniform distribution.

        Parameters
        ----------
        low
            Lower bound of the distribution.
        high
            Upper bound of the distribution.
        size
            Shape of the output.

        Returns
        -------
            The sampled values.
        """
        samples = self.rng.uniform(low, high, size)
        return self.axp.asarray(samples)
263

264

265
class XPAdditions:
    """
    Additional functions missing from both array-api-strict and array-api-extra.

    This class provides wrappers for common array operations such as integration,
    interpolation, and linear algebra, ensuring compatibility across NumPy, JAX,
    and array-api-strict backends. Each method dispatches on the module name of
    the wrapped namespace (``xp.__name__``).

    This is intended as a temporary solution. See https://github.com/glass-dev/glass/issues/645
    for details.
    """

    xp: ModuleType
    backend: str

    def __init__(self, xp: ModuleType) -> None:
        """
        Initialize XPAdditions with the given array namespace.

        Parameters
        ----------
        xp
            The array library backend to use for array operations.
        """
        self.xp = xp
        # ``backend`` is declared on the class; populate it so that reading the
        # attribute does not raise AttributeError. Presumably it is meant to hold
        # the backend's module name, which is what the methods dispatch on.
        self.backend = xp.__name__

    def _unsupported(self) -> NotImplementedError:
        """Build the error raised by every method for an unsupported backend."""
        msg = "the array backend is not supported"
        return NotImplementedError(msg)

    def trapezoid(
        self,
        y: AnyArray,
        x: AnyArray | None = None,
        dx: float = 1.0,
        axis: int = -1,
    ) -> AnyArray:
        """
        Integrate along the given axis using the composite trapezoidal rule.

        Parameters
        ----------
        y
            Input array to integrate.
        x
            Sample points corresponding to y. If ``None``, the spacing ``dx`` is
            used instead.
        dx
            Spacing between sample points.
        axis
            Axis along which to integrate.

        Returns
        -------
            Integrated result.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/646
        """
        backend = self.xp.__name__

        if backend == "jax.numpy":
            import glass.jax  # noqa: PLC0415

            return glass.jax.trapezoid(y, x=x, dx=dx, axis=axis)

        if backend == "numpy":
            return self.xp.trapezoid(y, x=x, dx=dx, axis=axis)

        if backend == "array_api_strict":
            np = import_numpy(backend)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            y_np = np.asarray(y, copy=True)
            # ``x`` is optional: keep None as None so NumPy falls back to ``dx``
            # rather than receiving a 0-d object array it cannot integrate over.
            x_np = None if x is None else np.asarray(x, copy=True)
            result_np = np.trapezoid(y_np, x_np, dx=dx, axis=axis)
            return self.xp.asarray(result_np, copy=True)

        raise self._unsupported()

    def union1d(self, ar1: AnyArray, ar2: AnyArray) -> AnyArray:
        """
        Compute the set union of two 1D arrays.

        Parameters
        ----------
        ar1
            First input array.
        ar2
            Second input array.

        Returns
        -------
            The union of the two arrays.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/647
        """
        backend = self.xp.__name__

        if backend in {"numpy", "jax.numpy"}:
            return self.xp.union1d(ar1, ar2)

        if backend == "array_api_strict":
            np = import_numpy(backend)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            ar1_np = np.asarray(ar1, copy=True)
            ar2_np = np.asarray(ar2, copy=True)
            result_np = np.union1d(ar1_np, ar2_np)
            return self.xp.asarray(result_np, copy=True)

        raise self._unsupported()

    def interp(  # noqa: PLR0913
        self,
        x: AnyArray,
        x_points: AnyArray,
        y_points: AnyArray,
        left: float | None = None,
        right: float | None = None,
        period: float | None = None,
    ) -> AnyArray:
        """
        One-dimensional linear interpolation for monotonically increasing sample points.

        Parameters
        ----------
        x
            The x-coordinates at which to evaluate the interpolated values.
        x_points
            The x-coordinates of the data points.
        y_points
            The y-coordinates of the data points.
        left
            Value to return for x < x_points[0].
        right
            Value to return for x > x_points[-1].
        period
            Period for periodic interpolation.

        Returns
        -------
            Interpolated values.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/650
        """
        backend = self.xp.__name__

        if backend in {"numpy", "jax.numpy"}:
            return self.xp.interp(
                x,
                x_points,
                y_points,
                left=left,
                right=right,
                period=period,
            )

        if backend == "array_api_strict":
            np = import_numpy(backend)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            x_np = np.asarray(x, copy=True)
            x_points_np = np.asarray(x_points, copy=True)
            y_points_np = np.asarray(y_points, copy=True)
            result_np = np.interp(
                x_np,
                x_points_np,
                y_points_np,
                left=left,
                right=right,
                period=period,
            )
            return self.xp.asarray(result_np, copy=True)

        raise self._unsupported()

    def gradient(self, f: AnyArray) -> AnyArray:
        """
        Return the gradient of an N-dimensional array.

        Parameters
        ----------
        f
            Input array.

        Returns
        -------
            Gradient of the input array.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/648
        """
        backend = self.xp.__name__

        if backend in {"numpy", "jax.numpy"}:
            return self.xp.gradient(f)

        if backend == "array_api_strict":
            np = import_numpy(backend)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            f_np = np.asarray(f, copy=True)
            result_np = np.gradient(f_np)
            return self.xp.asarray(result_np, copy=True)

        raise self._unsupported()

    def linalg_lstsq(
        self,
        a: AnyArray,
        b: AnyArray,
        rcond: float | None = None,
    ) -> tuple[AnyArray, AnyArray, int, AnyArray]:
        """
        Solve a linear least squares problem.

        Parameters
        ----------
        a
            Coefficient matrix.
        b
            Ordinate or "dependent variable" values.
        rcond
            Cut-off ratio for small singular values.

        Returns
        -------
        x
            Least-squares solution. If b is two-dimensional, the solutions are in the K
            columns of x.

        residuals
            Sums of squared residuals: Squared Euclidean 2-norm for each column in b - a
            @ x. If the rank of a is < N or M <= N, this is an empty array. If b is
            1-dimensional, this is a (1,) shape array. Otherwise the shape is (K,).

        rank
            Rank of matrix a.

        s
            Singular values of a.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/649

        For array_api_strict, every element of the NumPy result (including the
        rank) is converted with ``xp.asarray`` before being returned.
        """
        backend = self.xp.__name__

        if backend in {"numpy", "jax.numpy"}:
            return self.xp.linalg.lstsq(a, b, rcond=rcond)  # type: ignore[no-any-return]

        if backend == "array_api_strict":
            np = import_numpy(backend)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            a_np = np.asarray(a, copy=True)
            b_np = np.asarray(b, copy=True)
            result_np = np.linalg.lstsq(a_np, b_np, rcond=rcond)
            return tuple(self.xp.asarray(res, copy=True) for res in result_np)

        raise self._unsupported()

    def einsum(self, subscripts: str, *operands: AnyArray) -> AnyArray:
        """
        Evaluate the Einstein summation convention on the operands.

        Parameters
        ----------
        subscripts
            Specifies the subscripts for summation.
        *operands
            Arrays to be summed.

        Returns
        -------
            Result of the Einstein summation.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/657
        """
        backend = self.xp.__name__

        if backend in {"numpy", "jax.numpy"}:
            return self.xp.einsum(subscripts, *operands)

        if backend == "array_api_strict":
            np = import_numpy(backend)

            # Using design principle of scipy (i.e. copy, use np, copy back)
            operands_np = (np.asarray(op, copy=True) for op in operands)
            result_np = np.einsum(subscripts, *operands_np)
            return self.xp.asarray(result_np, copy=True)

        raise self._unsupported()

    def apply_along_axis(
        self,
        func1d: Callable[..., Any],
        axis: int,
        arr: AnyArray,
        *args: object,
        **kwargs: object,
    ) -> AnyArray:
        """
        Apply a function to 1-D slices along the given axis.

        Parameters
        ----------
        func1d
            Function to apply to 1-D slices.
        axis
            Axis along which to apply the function.
        arr
            Input array.
        *args
            Additional positional arguments to pass to func1d.
        **kwargs
            Additional keyword arguments to pass to func1d.

        Returns
        -------
            Result of applying the function along the axis.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/651

        """
        backend = self.xp.__name__

        if backend in {"numpy", "jax.numpy"}:
            return self.xp.apply_along_axis(func1d, axis, arr, *args, **kwargs)

        if backend == "array_api_strict":
            # Import here to prevent users relying on numpy unless in this instance
            np = import_numpy(backend)

            return self.xp.asarray(
                np.apply_along_axis(func1d, axis, arr, *args, **kwargs),
                copy=True,
            )

        raise self._unsupported()

    def vectorize(
        self,
        pyfunc: Callable[..., Any],
        otypes: tuple[type[float]],
    ) -> Callable[..., Any]:
        """
        Returns an object that acts like pyfunc, but takes arrays as input.

        Parameters
        ----------
        pyfunc
            Python function to vectorize.
        otypes
            Output types.

        Returns
        -------
            Vectorized function.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        Notes
        -----
        See https://github.com/glass-dev/glass/issues/671

        For array_api_strict and jax.numpy the function is vectorized with NumPy's
        ``vectorize``.
        """
        backend = self.xp.__name__

        if backend == "numpy":
            return self.xp.vectorize(pyfunc, otypes=otypes)  # type: ignore[no-any-return]

        if backend in {"array_api_strict", "jax.numpy"}:
            # Import here to prevent users relying on numpy unless in this instance
            np = import_numpy(backend)

            return np.vectorize(pyfunc, otypes=otypes)  # type: ignore[no-any-return]

        raise self._unsupported()

    def radians(self, deg_arr: AnyArray) -> AnyArray:
        """
        Convert angles from degrees to radians.

        Parameters
        ----------
        deg_arr
            Array of angles in degrees.

        Returns
        -------
            Array of angles in radians.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.
        """
        backend = self.xp.__name__

        if backend in {"numpy", "jax.numpy"}:
            return self.xp.radians(deg_arr)

        if backend == "array_api_strict":
            np = import_numpy(backend)

            return self.xp.asarray(np.radians(deg_arr))

        raise self._unsupported()

    def degrees(self, deg_arr: AnyArray) -> AnyArray:
        """
        Convert angles from radians to degrees.

        Parameters
        ----------
        deg_arr
            Array of angles in radians (despite the parameter name, which is kept
            for backward compatibility with keyword callers).

        Returns
        -------
            Array of angles in degrees.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.
        """
        backend = self.xp.__name__

        if backend in {"numpy", "jax.numpy"}:
            return self.xp.degrees(deg_arr)

        if backend == "array_api_strict":
            np = import_numpy(backend)

            return self.xp.asarray(np.degrees(deg_arr))

        raise self._unsupported()

    def ndindex(self, shape: tuple[int, ...]) -> np.ndindex:
        """
        Wrapper for numpy.ndindex.

        See relevant docs for details:
        - NumPy, https://numpy.org/doc/2.2/reference/generated/numpy.ndindex.html

        Parameters
        ----------
        shape
            The shape passed on to ``ndindex``.

        Returns
        -------
            The ``ndindex`` object for the given shape. For array_api_strict and
            jax.numpy it is created with NumPy.

        Raises
        ------
        NotImplementedError
            If the array backend is not supported.

        """
        backend = self.xp.__name__

        if backend == "numpy":
            return self.xp.ndindex(shape)  # type: ignore[no-any-return]

        if backend in {"array_api_strict", "jax.numpy"}:
            np = import_numpy(backend)

            return np.ndindex(shape)  # type: ignore[no-any-return]

        raise self._unsupported()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc