• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

nz-gravity / LogPSplinePSD / 17394129984

02 Sep 2025 05:25AM UTC coverage: 90.419% (+9.4%) from 81.047%
17394129984

push

github

avivajpeyi
run precommits

169 of 180 branches covered (93.89%)

Branch coverage included in aggregate %.

145 of 166 new or added lines in 15 files covered. (87.35%)

62 existing lines in 11 files now uncovered.

1492 of 1657 relevant lines covered (90.04%)

1.8 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

78.46
/src/log_psplines/plotting/utils.py
1
import dataclasses
2✔
2

3
import arviz as az
2✔
4
import jax.numpy as jnp
2✔
5
import numpy as np
2✔
6

7
from ..datatypes import Periodogram
2✔
8
from ..psplines import LogPSplines
2✔
9

10
# Only `unpack_data` is exported by a star-import of this module; the
# PlottingData container and the `_get_uni_ci` helper are internal.
__all__ = ["unpack_data"]
2✔
11

12

13
@dataclasses.dataclass
class PlottingData:
    """Bundle of arrays used to plot a periodogram next to a spline model.

    Every field is optional and defaults to None.
    """

    # Frequency grid for the x-axis.
    freqs: np.ndarray = None
    # Periodogram power values (unpack_data stores them already multiplied by yscalar).
    pdgrm: np.ndarray = None
    # Model values on the same grid.
    model: np.ndarray = None
    # Band around the model; unpack_data stores it as stacked (lower, median, upper) rows.
    ci: np.ndarray = None

    @property
    def n(self):
        """Number of points, read from the first set field among freqs, pdgrm, model.

        ``ci`` is intentionally not consulted: its first axis indexes the
        band rows rather than the points.

        Raises:
            ValueError: if freqs, pdgrm and model are all None.
        """
        candidates = (self.freqs, self.pdgrm, self.model)
        populated = [arr for arr in candidates if arr is not None]
        if not populated:
            raise ValueError("No data to get length from.")
        return len(populated[0])
30

31

32
def unpack_data(
    pdgrm: Periodogram = None,
    spline_model: LogPSplines = None,
    weights=None,
    yscalar=1.0,
    use_uniform_ci=True,
    use_parametric_model=True,
    freqs=None,
):
    """Collect periodogram and spline-model arrays into a :class:`PlottingData`.

    Args:
        pdgrm: Optional periodogram; its ``power`` (as float64, times
            ``yscalar``) and ``freqs`` are copied into the result.
        spline_model: Optional callable model. It is evaluated as
            ``spline_model(use_parametric_model=...)`` when ``weights`` is
            None, otherwise as ``spline_model(w, use_parametric_model)``.
            Its output is treated as log-values and exponentiated.
        weights: None, a single weight vector (1-D), or a stack of weight
            vectors (2-D, one per row). Only the 2-D form yields a band.
        yscalar: Multiplicative factor applied to the periodogram, the model
            and the band.
        use_uniform_ci: If True, build a uniform (simultaneous) band with
            ``_get_uni_ci``; otherwise use pointwise 16/50/84 percentiles.
        use_parametric_model: Forwarded to ``spline_model``.
        freqs: Optional frequency grid; when given it overrides
            ``pdgrm.freqs`` and is stored as an ``np.ndarray``.

    Returns:
        PlottingData: ``freqs`` is always set. ``pdgrm`` and ``model`` are set
        when the corresponding inputs are given. ``ci`` (rows: lower, median,
        upper) is set only for 2-D ``weights``, and then ``model`` is its
        middle row.

    Raises:
        ValueError: if ``weights`` is neither 1-D nor 2-D, or if no input
            determines the number of points for the fallback frequency grid.
    """
    plt_dat = PlottingData()
    if pdgrm is not None:
        plt_dat.pdgrm = np.array(pdgrm.power, dtype=np.float64) * yscalar
        plt_dat.freqs = np.array(pdgrm.freqs)

    if spline_model is not None:
        if weights is None:
            # No weights supplied: let the model use its own (presumably
            # initial / zero) weights.
            ln_spline = spline_model(use_parametric_model=use_parametric_model)
        else:
            ndim = np.ndim(weights)
            if ndim == 1:
                # A single weight vector -- no band possible.
                ln_spline = spline_model(weights, use_parametric_model)
            elif ndim == 2:
                # One weight vector per row -- evaluate each, then summarise.
                ln_splines = jnp.array(
                    [spline_model(w, use_parametric_model) for w in weights]
                )
                if use_uniform_ci:
                    ln_ci = _get_uni_ci(ln_splines)
                else:
                    ln_ci = jnp.percentile(
                        ln_splines, q=jnp.array([16, 50, 84]), axis=0
                    )
                # _get_uni_ci returns a tuple; stack it to a (3, n) array.
                ln_ci = jnp.array(ln_ci)
                plt_dat.ci = np.exp(ln_ci, dtype=np.float64) * yscalar
                # The middle row (median) is the point estimate.
                ln_spline = ln_ci[1]
            else:
                # Previously rank-0/rank>2 inputs fell through to the 2-D
                # branch and failed obscurely; reject them explicitly.
                raise ValueError(
                    f"weights must be 1-D or 2-D, got a {ndim}-D input."
                )
        plt_dat.model = np.exp(ln_spline, dtype=np.float64) * yscalar

    if freqs is not None:
        # An explicit grid wins; store it as an ndarray like the pdgrm path.
        plt_dat.freqs = np.asarray(freqs)
    elif plt_dat.freqs is None:
        # No frequency information at all: fall back to a normalised grid.
        plt_dat.freqs = np.linspace(0, 1, plt_dat.n)

    return plt_dat
79

80

81
def _get_uni_ci(samples, alpha=0.1):
    """
    Compute a uniform (simultaneous) band for a set of function samples.

    The band is ``median +/- k_alpha * std`` at every point. ``k_alpha`` is
    the ``100 * (1 - alpha)`` percentile, taken across samples, of each
    sample's largest standardised absolute deviation from the pointwise
    median.

    Args:
        samples (jnp.ndarray): Shape (num_samples, num_points) array of function samples.
        alpha (float): Significance level (default 0.1 for a 90% band).

    Returns:
        tuple: (lower_bound, median, upper_bound) arrays, each of shape (num_points,).

    Raises:
        ValueError: if ``samples`` is not two-dimensional.
    """
    samples = jnp.asarray(samples)
    if samples.ndim != 2:
        raise ValueError(
            f"samples must have shape (num_samples, num_points), got {samples.shape}."
        )

    # Pointwise centre and spread.
    median = jnp.median(samples, axis=0)
    std = jnp.std(samples, axis=0)

    # Where all samples agree, std == 0 and the plain division would give
    # 0/0 = NaN. jnp.max propagates NaN, which would poison k_alpha and thus
    # the whole band. A unit divisor there yields a deviation of 0 (the
    # numerator is 0), so the band simply collapses onto the median.
    safe_std = jnp.where(std > 0, std, 1.0)
    deviations = (samples - median[None, :]) / safe_std[None, :]

    # Worst standardised deviation of each sample over all points.
    max_deviation = jnp.max(jnp.abs(deviations), axis=1)

    # Critical value: covers a (1 - alpha) fraction of whole sample curves.
    k_alpha = jnp.percentile(max_deviation, 100 * (1 - alpha))

    lower_bound = median - k_alpha * std
    upper_bound = median + k_alpha * std
    return lower_bound, median, upper_bound
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc