• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

glass-dev / glass / 21367203855

26 Jan 2026 05:24PM UTC coverage: 94.174% (+0.02%) from 94.156%
21367203855

push

github

web-flow
gh-902: port `cls2cov`, `spectra_indices`, `_glass_to_healpix_alm` in `fields.py` (#963)

Co-authored-by: Connor Aird <c.aird@ucl.ac.uk>

212 of 214 branches covered (99.07%)

Branch coverage included in aggregate %.

7 of 7 new or added lines in 1 file covered. (100.0%)

18 existing lines in 1 file now uncovered.

1372 of 1468 relevant lines covered (93.46%)

5.07 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

81.95
/glass/_array_api_utils.py
1
"""
2
Array API Utilities for glass.
3
==============================
4

5
This module provides utility functions and classes for working with multiple array
6
backends in the glass project, including NumPy, JAX, and array-api-strict. It includes
7
functions for importing backends, determining array namespaces, dispatching random
8
number generators, and providing missing functionality for array-api-strict through the
9
XPAdditions class.
10

11
Classes and functions in this module help ensure consistent behavior and compatibility
12
across different array libraries, and provide wrappers for common operations such as
13
integration, interpolation, and linear algebra.
14

15
"""
16

17
from __future__ import annotations
18

19
import functools
20
from typing import TYPE_CHECKING, Any
21

22
if TYPE_CHECKING:
23
    from collections.abc import Callable
24
    from types import ModuleType
25

26
    import numpy as np
27

28
    from glass._types import AnyArray
29

30

31
class CompatibleBackendNotFoundError(Exception):
    """
    Raised when no available array backend implements a requested function.

    The message depends on whether the user chose a backend. If they did not,
    it explains that ``missing_backend`` is required. If they did, it names the
    user's backend as lacking the needed functionality.
    """

    def __init__(self, missing_backend: str, users_backend: str | None) -> None:
        """
        Build the error message and initialise the exception with it.

        Parameters
        ----------
        missing_backend
            Name of the backend GLASS needs to fall back on (e.g. ``"numpy"``).
        users_backend
            Name of the backend requested by the user, or None if the user
            did not provide one.
        """
        if users_backend is not None:
            message = f"GLASS depends on functions not supported by {users_backend}"
        else:
            message = (
                f"{missing_backend} is required here as "
                "no alternative has been provided by the user."
            )
        # Also exposed as an attribute so callers can read the text directly.
        self.message = message
        super().__init__(message)
45

46

47
def import_numpy(backend: str | None = None) -> ModuleType:
    """
    Import the NumPy module, raising a helpful error if NumPy is not installed.

    Parameters
    ----------
    backend
        The name of the backend requested by the user, or None if the user did
        not request one. It is only used to build the error message.

    Returns
    -------
        The NumPy module.

    Raises
    ------
    CompatibleBackendNotFoundError
        If NumPy is not found in the user's environment. The underlying
        ModuleNotFoundError is chained as the exception's cause.

    Notes
    -----
    This is useful for explaining to the user why NumPy is required when their chosen
    backend does not implement a needed function.
    """
    try:
        # Imported inside the function so NumPy is only needed when requested.
        import numpy  # noqa: ICN001, PLC0415

    except ModuleNotFoundError as err:
        raise CompatibleBackendNotFoundError("numpy", backend) from err
    else:
        return numpy
77

78

79
def default_xp() -> ModuleType:
    """
    Return the array backend used when the user does not specify one.

    Returns
    -------
        The NumPy module, obtained through ``import_numpy`` with no user backend.
    """
    return import_numpy(backend=None)
82

83

84
class XPAdditions:
6✔
85
    """
86
    Additional functions missing from both array-api-strict and array-api-extra.
87

88
    This class provides wrappers for common array operations such as integration,
89
    interpolation, and linear algebra, ensuring compatibility across NumPy, JAX,
90
    and array-api-strict backends.
91

92
    This is intended as a temporary solution. See https://github.com/glass-dev/glass/issues/645
93
    for details.
94
    """
95

96
    def __init__(self, xp: ModuleType) -> None:
6✔
97
        """
98
        Initialize XPAdditions with the given array namespace.
99

100
        Parameters
101
        ----------
102
        xp
103
            The array library backend to use for array operations.
104
        """
105
        self.xp = xp
6✔
106

107
    def trapezoid(
6✔
108
        self,
109
        y: AnyArray,
110
        x: AnyArray = None,
111
        dx: float = 1.0,
112
        axis: int = -1,
113
    ) -> AnyArray:
114
        """
115
        Integrate along the given axis using the composite trapezoidal rule.
116

117
        Parameters
118
        ----------
119
        y
120
            Input array to integrate.
121
        x
122
            Sample points corresponding to y.
123
        dx
124
            Spacing between sample points.
125
        axis
126
            Axis along which to integrate.
127

128
        Returns
129
        -------
130
            Integrated result.
131

132
        Raises
133
        ------
134
        NotImplementedError
135
            If the array backend is not supported.
136

137
        Notes
138
        -----
139
        See https://github.com/glass-dev/glass/issues/646
140
        """
141
        if self.xp.__name__ == "jax.numpy":
6✔
142
            import glass.jax  # noqa: PLC0415
143

144
            return glass.jax.trapezoid(y, x=x, dx=dx, axis=axis)
5✔
145

146
        if self.xp.__name__ == "numpy":
6✔
147
            return self.xp.trapezoid(y, x=x, dx=dx, axis=axis)
6✔
148

149
        if self.xp.__name__ == "array_api_strict":
5✔
150
            np = import_numpy(self.xp.__name__)
5✔
151

152
            # Using design principle of scipy (i.e. copy, use np, copy back)
153
            y_np = np.asarray(y, copy=True)
5✔
154
            x_np = np.asarray(x, copy=True)
5✔
155
            result_np = np.trapezoid(y_np, x_np, dx=dx, axis=axis)
5✔
156
            return self.xp.asarray(result_np, copy=True)
5✔
157

158
        msg = "the array backend in not supported"
×
UNCOV
159
        raise NotImplementedError(msg)
×
160

161
    def interp(  # noqa: PLR0913
6✔
162
        self,
163
        x: AnyArray,
164
        x_points: AnyArray,
165
        y_points: AnyArray,
166
        left: float | None = None,
167
        right: float | None = None,
168
        period: float | None = None,
169
    ) -> AnyArray:
170
        """
171
        One-dimensional linear interpolation for monotonically increasing sample points.
172

173
        Parameters
174
        ----------
175
        x
176
            The x-coordinates at which to evaluate the interpolated values.
177
        x_points
178
            The x-coordinates of the data points.
179
        y_points
180
            The y-coordinates of the data points.
181
        left
182
            Value to return for x < x_points[0].
183
        right
184
            Value to return for x > x_points[-1].
185
        period
186
            Period for periodic interpolation.
187

188
        Returns
189
        -------
190
            Interpolated values.
191

192
        Raises
193
        ------
194
        NotImplementedError
195
            If the array backend is not supported.
196

197
        Notes
198
        -----
199
        See https://github.com/glass-dev/glass/issues/650
200
        """
201
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
6✔
202
            return self.xp.interp(
6✔
203
                x,
204
                x_points,
205
                y_points,
206
                left=left,
207
                right=right,
208
                period=period,
209
            )
210

211
        if self.xp.__name__ == "array_api_strict":
5✔
212
            np = import_numpy(self.xp.__name__)
5✔
213

214
            # Using design principle of scipy (i.e. copy, use np, copy back)
215
            x_np = np.asarray(x, copy=True)
5✔
216
            x_points_np = np.asarray(x_points, copy=True)
5✔
217
            y_points_np = np.asarray(y_points, copy=True)
5✔
218
            result_np = np.interp(
5✔
219
                x_np,
220
                x_points_np,
221
                y_points_np,
222
                left=left,
223
                right=right,
224
                period=period,
225
            )
226
            return self.xp.asarray(result_np, copy=True)
5✔
227

228
        msg = "the array backend in not supported"
×
UNCOV
229
        raise NotImplementedError(msg)
×
230

231
    def gradient(self, f: AnyArray) -> AnyArray:
6✔
232
        """
233
        Return the gradient of an N-dimensional array.
234

235
        Parameters
236
        ----------
237
        f
238
            Input array.
239

240
        Returns
241
        -------
242
            Gradient of the input array.
243

244
        Raises
245
        ------
246
        NotImplementedError
247
            If the array backend is not supported.
248

249
        Notes
250
        -----
251
        See https://github.com/glass-dev/glass/issues/648
252
        """
253
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
5✔
254
            return self.xp.gradient(f)
5✔
255

256
        if self.xp.__name__ == "array_api_strict":
5✔
257
            np = import_numpy(self.xp.__name__)
5✔
258

259
            # Using design principle of scipy (i.e. copy, use np, copy back)
260
            f_np = np.asarray(f, copy=True)
5✔
261
            result_np = np.gradient(f_np)
5✔
262
            return self.xp.asarray(result_np, copy=True)
5✔
263

264
        msg = "the array backend in not supported"
×
UNCOV
265
        raise NotImplementedError(msg)
×
266

267
    def linalg_lstsq(
6✔
268
        self,
269
        a: AnyArray,
270
        b: AnyArray,
271
        rcond: float | None = None,
272
    ) -> tuple[AnyArray, AnyArray, int, AnyArray]:
273
        """
274
        Solve a linear least squares problem.
275

276
        Parameters
277
        ----------
278
        a
279
            Coefficient matrix.
280
        b
281
            Ordinate or "dependent variable" values.
282
        rcond
283
            Cut-off ratio for small singular values.
284

285
        Returns
286
        -------
287
        x
288
            Least-squares solution. If b is two-dimensional, the solutions are in the K
289
            columns of x.
290

291
        residuals
292
            Sums of squared residuals: Squared Euclidean 2-norm for each column in b - a
293
            @ x. If the rank of a is < N or M <= N, this is an empty array. If b is
294
            1-dimensional, this is a (1,) shape array. Otherwise the shape is (K,).
295

296
        rank
297
            Rank of matrix a.
298

299
        s
300
            Singular values of a.
301

302
        Raises
303
        ------
304
        NotImplementedError
305
            If the array backend is not supported.
306

307
        Notes
308
        -----
309
        See https://github.com/glass-dev/glass/issues/649
310
        """
311
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
5✔
312
            return self.xp.linalg.lstsq(a, b, rcond=rcond)  # type: ignore[no-any-return]
5✔
313

314
        if self.xp.__name__ == "array_api_strict":
5✔
315
            np = import_numpy(self.xp.__name__)
5✔
316

317
            # Using design principle of scipy (i.e. copy, use np, copy back)
318
            a_np = np.asarray(a, copy=True)
5✔
319
            b_np = np.asarray(b, copy=True)
5✔
320
            result_np = np.linalg.lstsq(a_np, b_np, rcond=rcond)
5✔
321
            return tuple(self.xp.asarray(res, copy=True) for res in result_np)
5✔
322

323
        msg = "the array backend in not supported"
×
UNCOV
324
        raise NotImplementedError(msg)
×
325

326
    def einsum(self, subscripts: str, *operands: AnyArray) -> AnyArray:
6✔
327
        """
328
        Evaluate the Einstein summation convention on the operands.
329

330
        Parameters
331
        ----------
332
        subscripts
333
            Specifies the subscripts for summation.
334
        *operands
335
            Arrays to be summed.
336

337
        Returns
338
        -------
339
            Result of the Einstein summation.
340

341
        Raises
342
        ------
343
        NotImplementedError
344
            If the array backend is not supported.
345

346
        Notes
347
        -----
348
        See https://github.com/glass-dev/glass/issues/657
349
        """
350
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
5✔
351
            return self.xp.einsum(subscripts, *operands)
5✔
352

353
        if self.xp.__name__ == "array_api_strict":
5✔
354
            np = import_numpy(self.xp.__name__)
5✔
355

356
            # Using design principle of scipy (i.e. copy, use np, copy back)
357
            operands_np = (np.asarray(op, copy=True) for op in operands)
5✔
358
            result_np = np.einsum(subscripts, *operands_np)
5✔
359
            return self.xp.asarray(result_np, copy=True)
5✔
360

361
        msg = "the array backend in not supported"
×
UNCOV
362
        raise NotImplementedError(msg)
×
363

364
    def apply_along_axis(
6✔
365
        self,
366
        func: Callable[..., Any],
367
        func_inputs: tuple[Any, ...],
368
        axis: int,
369
        arr: AnyArray,
370
        *args: object,
371
        **kwargs: object,
372
    ) -> AnyArray:
373
        """
374
        Apply a function to 1-D slices along the given axis.
375

376
        Rather than accepting a partial function as usual, the function and
377
        its inputs are passed in separately for better compatibility.
378

379
        Parameters
380
        ----------
381
        func
382
            Function to apply to 1-D slices.
383
        func_inputs
384
            All inputs to the func besides arr.
385
        axis
386
            Axis along which to apply the function.
387
        arr
388
            Input array.
389
        *args
390
            Additional positional arguments to pass to func1d.
391
        **kwargs
392
            Additional keyword arguments to pass to func1d.
393

394
        Returns
395
        -------
396
            Result of applying the function along the axis.
397

398
        Raises
399
        ------
400
        NotImplementedError
401
            If the array backend is not supported.
402

403
        Notes
404
        -----
405
        See https://github.com/glass-dev/glass/issues/651
406

407
        """
408
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
5✔
409
            func1d = functools.partial(func, *func_inputs)
5✔
410
            return self.xp.apply_along_axis(func1d, axis, arr, *args, **kwargs)
5✔
411

412
        if self.xp.__name__ == "array_api_strict":
5✔
413
            # Import here to prevent users relying on numpy unless in this instance
414
            np = import_numpy(self.xp.__name__)
5✔
415

416
            # Everything must be NumPy to avoid mismatches between array types
417
            inputs_np = (np.asarray(inp) for inp in func_inputs)
5✔
418
            func1d = functools.partial(func, *inputs_np)
5✔
419

420
            return self.xp.asarray(
5✔
421
                np.apply_along_axis(func1d, axis, arr, *args, **kwargs),
422
                copy=True,
423
            )
424

UNCOV
425
        msg = "the array backend in not supported"
×
UNCOV
426
        raise NotImplementedError(msg)
×
427

428
    def vectorize(
6✔
429
        self,
430
        pyfunc: Callable[..., Any],
431
        otypes: tuple[type[float]],
432
    ) -> Callable[..., Any]:
433
        """
434
        Returns an object that acts like pyfunc, but takes arrays as input.
435

436
        Parameters
437
        ----------
438
        pyfunc
439
            Python function to vectorize.
440
        otypes
441
            Output types.
442

443
        Returns
444
        -------
445
            Vectorized function.
446

447
        Raises
448
        ------
449
        NotImplementedError
450
            If the array backend is not supported.
451

452
        Notes
453
        -----
454
        See https://github.com/glass-dev/glass/issues/671
455
        """
456
        if self.xp.__name__ == "numpy":
5✔
457
            return self.xp.vectorize(pyfunc, otypes=otypes)  # type: ignore[no-any-return]
5✔
458

459
        if self.xp.__name__ in {"array_api_strict", "jax.numpy"}:
5✔
460
            # Import here to prevent users relying on numpy unless in this instance
461
            np = import_numpy(self.xp.__name__)
5✔
462

463
            return np.vectorize(pyfunc, otypes=otypes)  # type: ignore[no-any-return]
5✔
464

UNCOV
465
        msg = "the array backend in not supported"
×
UNCOV
466
        raise NotImplementedError(msg)
×
467

468
    def radians(self, deg_arr: AnyArray) -> AnyArray:
6✔
469
        """
470
        Convert angles from degrees to radians.
471

472
        Parameters
473
        ----------
474
        deg_arr
475
            Array of angles in degrees.
476

477
        Returns
478
        -------
479
            Array of angles in radians.
480

481
        Raises
482
        ------
483
        NotImplementedError
484
            If the array backend is not supported.
485
        """
486
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
5✔
487
            return self.xp.radians(deg_arr)
5✔
488

489
        if self.xp.__name__ == "array_api_strict":
5✔
490
            np = import_numpy(self.xp.__name__)
5✔
491

492
            return self.xp.asarray(np.radians(deg_arr))
5✔
493

UNCOV
494
        msg = "the array backend in not supported"
×
UNCOV
495
        raise NotImplementedError(msg)
×
496

497
    def degrees(self, deg_arr: AnyArray) -> AnyArray:
6✔
498
        """
499
        Convert angles from radians to degrees.
500

501
        Parameters
502
        ----------
503
        deg_arr
504
            Array of angles in radians.
505

506
        Returns
507
        -------
508
            Array of angles in degrees.
509

510
        Raises
511
        ------
512
        NotImplementedError
513
            If the array backend is not supported.
514
        """
515
        if self.xp.__name__ in {"numpy", "jax.numpy"}:
6✔
516
            return self.xp.degrees(deg_arr)
6✔
517

518
        if self.xp.__name__ == "array_api_strict":
5✔
519
            np = import_numpy(self.xp.__name__)
5✔
520

521
            return self.xp.asarray(np.degrees(deg_arr))
5✔
522

UNCOV
523
        msg = "the array backend in not supported"
×
UNCOV
524
        raise NotImplementedError(msg)
×
525

526
    def ndindex(self, shape: tuple[int, ...]) -> np.ndindex:
6✔
527
        """
528
        Wrapper for numpy.ndindex.
529

530
        See relevant docs for details:
531
        - NumPy, https://numpy.org/doc/2.2/reference/generated/numpy.ndindex.html
532

533
        Raises
534
        ------
535
        NotImplementedError
536
            If the array backend is not supported.
537

538
        """
539
        if self.xp.__name__ == "numpy":
6✔
540
            return self.xp.ndindex(shape)  # type: ignore[no-any-return]
6✔
541

542
        if self.xp.__name__ in {"array_api_strict", "jax.numpy"}:
5✔
543
            np = import_numpy(self.xp.__name__)
5✔
544

545
            return np.ndindex(shape)  # type: ignore[no-any-return]
5✔
546

UNCOV
547
        msg = "the array backend in not supported"
×
UNCOV
548
        raise NotImplementedError(msg)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc