
nz-gravity / LogPSplinePSD, build 17657445007

11 Sep 2025, 09:12 PM UTC. Coverage: 89.856% (down 0.06% from 89.917%)

Trigger: push (github), by avivajpeyi
Commit: Merge branch 'main' of github.com:nz-gravity/LogPSplinePSD

179 of 190 branches covered (94.21%); branch coverage is included in the aggregate %.

77 of 94 new or added lines in 8 files covered (81.91%).

9 existing lines in 3 files are now uncovered.

1566 of 1752 relevant lines covered (89.38%).

1.79 hits per line.
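Since branch coverage is folded into the aggregate, the headline 89.856% can be reproduced from the line and branch counts quoted above. A small sanity-check sketch in plain Python, assuming the aggregate is simply (covered lines + covered branches) / (relevant lines + total branches):

# Sketch: reproducing the headline percentages from the counts in this report.
covered_lines, relevant_lines = 1566, 1752
covered_branches, total_branches = 179, 190

print(f"line coverage:   {100 * covered_lines / relevant_lines:.2f}%")     # 89.38%
print(f"branch coverage: {100 * covered_branches / total_branches:.2f}%")  # 94.21%

aggregate = (covered_lines + covered_branches) / (relevant_lines + total_branches)
print(f"aggregate:       {100 * aggregate:.3f}%")                          # 89.856%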

Source File: /src/log_psplines/plotting/utils.py (76.81% covered)
import dataclasses

import arviz as az
import jax.numpy as jnp
import numpy as np

from ..datatypes import Periodogram
from ..psplines import LogPSplines

__all__ = ["unpack_data"]


@dataclasses.dataclass
class PlottingData:
    freqs: np.ndarray = None
    pdgrm: np.ndarray = None
    model: np.ndarray = None
    ci: np.ndarray = None

    @property
    def n(self):
        # Length of whichever array is available (property not exercised in this build).
        if self.freqs is not None:
            return len(self.freqs)
        elif self.pdgrm is not None:
            return len(self.pdgrm)
        elif self.model is not None:
            return len(self.model)
        else:
            raise ValueError("No data to get length from.")


def unpack_data(
    pdgrm: Periodogram = None,
    spline_model: LogPSplines = None,
    weights=None,
    yscalar=1.0,
    use_uniform_ci=True,
    use_parametric_model=True,
    freqs=None,
):
    """Package the periodogram, spline model and CI arrays into a PlottingData object."""
    plt_dat = PlottingData()
    if pdgrm is not None:
        plt_dat.pdgrm = np.array(pdgrm.power, dtype=np.float64) * yscalar
        plt_dat.freqs = np.array(pdgrm.freqs)

    if spline_model is not None:
        if weights is None:
            # just use the initial weights / zero weights
            ln_spline = spline_model(use_parametric_model=use_parametric_model)

        elif weights.ndim == 1:
            # only one set of weights -- no CI possible (branch uncovered in this build)
            ln_spline = spline_model(weights, use_parametric_model)

        else:  # weights.ndim == 2
            # multiple sets of weights -- CI possible
            if weights.shape[0] > 500:
                # subsample to speed up (new lines, not yet covered)
                idx = np.random.choice(
                    weights.shape[0], size=500, replace=False
                )
                weights = weights[idx]

            ln_splines = jnp.array(
                [spline_model(w, use_parametric_model) for w in weights]
            )

            if use_uniform_ci:
                ln_ci = _get_uni_ci(ln_splines)
            else:
                # pointwise percentile band (branch now uncovered)
                ln_ci = jnp.percentile(
                    ln_splines, q=jnp.array([16, 50, 84]), axis=0
                )
            ln_ci = jnp.array(ln_ci)
            plt_dat.ci = np.exp(ln_ci, dtype=np.float64) * yscalar
            ln_spline = ln_ci[1]

        plt_dat.model = np.exp(ln_spline, dtype=np.float64) * yscalar

    # Frequencies: use the ones passed in, else default to a [0, 1] grid
    # (neither branch covered in this build).
    if plt_dat.freqs is None and freqs is None:
        plt_dat.freqs = np.linspace(0, 1, plt_dat.n)
    elif freqs is not None:
        plt_dat.freqs = freqs

    return plt_dat


def _get_uni_ci(samples, alpha=0.1):
    """
    Compute a uniform (simultaneous) confidence band for a set of function samples.

    Args:
        samples (jnp.ndarray): Shape (num_samples, num_points) array of function samples.
        alpha (float): Significance level (default 0.1 for a 90% CI).

    Returns:
        tuple: (lower_bound, median, upper_bound) arrays.
    """
    num_samples, num_points = samples.shape

    # Pointwise median and standard deviation across samples
    median = jnp.median(samples, axis=0)
    std = jnp.std(samples, axis=0)

    # Normalise each sample's deviation from the median, then take the
    # maximum absolute deviation each sample attains across all points
    deviations = (samples - median[None, :]) / std[None, :]
    max_deviation = jnp.max(jnp.abs(deviations), axis=1)

    # Critical value: the (1 - alpha) quantile of the max-deviation distribution
    k_alpha = jnp.percentile(max_deviation, 100 * (1 - alpha))

    # Uniform confidence band
    lower_bound = median - k_alpha * std
    upper_bound = median + k_alpha * std

    return lower_bound, median, upper_bound
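As a quick illustration of the two CI options above (the simultaneous band from _get_uni_ci versus the pointwise 16/50/84 percentile band used when use_uniform_ci=False), here is a minimal sketch on synthetic data. It assumes _get_uni_ci is in scope (e.g. run from within this module) and stands in fake log-spline draws for the package's Periodogram / LogPSplines objects:

import jax.numpy as jnp
import numpy as np

# Synthetic "posterior draws" of a log-PSD curve: 200 draws over 64 frequencies.
rng = np.random.default_rng(0)
freqs = np.linspace(0.0, 1.0, 64)
true_curve = np.sin(2 * np.pi * freqs)
samples = jnp.array(true_curve + 0.1 * rng.standard_normal((200, 64)))

# Uniform (simultaneous) 90% band, as computed by _get_uni_ci.
lo, med, hi = _get_uni_ci(samples, alpha=0.1)

# Pointwise 16/50/84 percentile band, as in the non-uniform branch of unpack_data.
pt_lo, pt_med, pt_hi = jnp.percentile(samples, q=jnp.array([16, 50, 84]), axis=0)

# With these draws the simultaneous band comes out wider than the pointwise one,
# since it must bracket 90% of the entire curves, not individual points.
print(float(jnp.mean(hi - lo)), float(jnp.mean(pt_hi - pt_lo)))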