
glass-dev / glass / 15018870526

14 May 2025 10:54AM UTC coverage: 95.081% (-0.02%) from 95.103%
15018870526 · push · github · web-flow
gh-402: add NumPy-like interface for JAX RNGs (#610)

Co-authored-by: Nicolas Tessore <n.tessore@ucl.ac.uk>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Patrick J. Roddy <patrickjamesroddy@gmail.com>

173 of 175 branches covered (98.86%)

Branch coverage included in aggregate %.

66 of 70 new or added lines in 2 files covered. (94.29%)

1238 of 1309 relevant lines covered (94.58%)

7.55 hits per line
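The headline aggregate can be reproduced from the counts above if covered lines and covered branches are pooled into one ratio, (1238 + 173) / (1309 + 175). The sketch below assumes that pooling rule, since the report only says that branch coverage is included. It also applies the rounding the report displays: two decimals for the sub-figures and three for the aggregate.

```python
def percent(covered: int, total: int, digits: int) -> float:
    """Return covered/total as a percentage rounded to `digits` decimal places."""
    return round(100 * covered / total, digits)


# Counts copied from this report.
lines_covered, lines_relevant = 1238, 1309
branches_covered, branches_total = 173, 175
new_lines_covered, new_lines_total = 66, 70

line_pct = percent(lines_covered, lines_relevant, 2)           # 94.58
branch_pct = percent(branches_covered, branches_total, 2)      # 98.86
new_line_pct = percent(new_lines_covered, new_lines_total, 2)  # 94.29

# Assumption: the aggregate pools lines and branches into a single ratio.
aggregate_pct = percent(
    lines_covered + branches_covered, lines_relevant + branches_total, 3
)  # 95.081

# Change relative to the previous build's aggregate of 95.103%.
delta_pct = round(aggregate_pct - 95.103, 2)  # -0.02
```

With these counts, the pooled ratio gives exactly the 95.081% and -0.02% shown in the header. The line-only figure (94.58%) does not match the header on its own.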

Source File

/glass/_array_api_utils.py (90.91% covered)

```python
from __future__ import annotations

from typing import TYPE_CHECKING, Any, TypeAlias

import numpy as np

from glass.jax import Generator

if TYPE_CHECKING:
    from types import ModuleType

    from jaxtyping import Array
    from numpy.typing import NDArray
```
```python
def get_namespace(*arrays: NDArray[Any] | Array) -> ModuleType:
    """
    Return the array library (array namespace) of input arrays
    if they belong to the same library or raise a :class:`ValueError`
    if they do not.
    """
    namespace = arrays[0].__array_namespace__()
    if any(
        array.__array_namespace__() != namespace
        for array in arrays
        if array is not None
    ):
        msg = "input arrays should belong to the same array library"
        raise ValueError(msg)

    return namespace
```
```python
UnifiedGenerator: TypeAlias = np.random.Generator | Generator


def rng_dispatcher(array: NDArray[Any] | Array) -> UnifiedGenerator:
    """Dispatch RNG on the basis of the provided array."""
    backend = array.__array_namespace__().__name__
    if backend == "jax.numpy":
        return Generator(seed=42)
    if backend in {"numpy", "array_api_strict"}:
        return np.random.default_rng()
    msg = "the array backend is not supported"
    raise NotImplementedError(msg)
```

Hit counts: every covered line in this file was hit 8 times. The last two lines of `rng_dispatcher` (file lines 44-45: the `msg` assignment and `raise NotImplementedError(msg)`) are new in this build and were not hit.