• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

nz-gravity / LogPSplinePSD / 18113528024

29 Sep 2025 11:21PM UTC coverage: 81.475% (-2.8%) from 84.297%
18113528024

push

github

avivajpeyi
add GW tests

363 of 440 branches covered (82.5%)

Branch coverage included in aggregate %.

2830 of 3479 relevant lines covered (81.35%)

1.63 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

50.68
/src/log_psplines/plotting/utils.py
1
import dataclasses
2✔
2

3
import jax.numpy as jnp
2✔
4
import numpy as np
2✔
5

6
from ..datatypes import Periodogram
2✔
7
from ..psplines import LogPSplines
2✔
8

9
# Names exported by ``from ... import *``; only ``unpack_data`` is listed here.
__all__ = ["unpack_data"]
2✔
10

11

12
@dataclasses.dataclass
class PlottingData:
    """Bundle of the arrays a PSD plot needs; every field defaults to ``None``.

    ``unpack_data`` fills these in from a periodogram, a posterior PSD sample
    set, and/or a spline model.
    """

    freqs: np.ndarray = None  # frequency grid
    pdgrm: np.ndarray = None  # periodogram power (already multiplied by yscalar)
    model: np.ndarray = None  # single model PSD curve
    ci: np.ndarray = None  # stacked band rows (lower, middle, upper) as built by unpack_data

    @property
    def n(self):
        """Number of points, taken from the first non-``None`` array.

        ``freqs`` is checked first, then ``pdgrm``, then ``model``.

        Raises:
            ValueError: If none of those three arrays is set.
        """
        for candidate in (self.freqs, self.pdgrm, self.model):
            if candidate is not None:
                return len(candidate)
        raise ValueError("No data to get length from.")
×
29

30

31
def unpack_data(
    pdgrm: Periodogram = None,
    spline_model: LogPSplines = None,
    weights=None,
    yscalar=1.0,
    use_uniform_ci=True,
    use_parametric_model=True,
    freqs=None,
    posterior_psd=None,
):
    """Gather periodogram, model PSD and band arrays into a ``PlottingData``.

    Args:
        pdgrm: Optional periodogram. Its ``power`` (as float64, times
            ``yscalar``) and ``freqs`` are copied into the result.
        spline_model: Optional spline model, evaluated to produce the model
            PSD when ``posterior_psd`` is not given. Its output is treated as a
            log-PSD (it is exponentiated here).
        weights: Spline weights passed to ``spline_model``:
            ``None`` calls the model without weights; a 1-D array yields one
            curve and no band; a 2-D array (one row per draw) also yields a
            band. More than 500 rows are randomly subsampled to 500.
        yscalar: Factor applied to the periodogram power and to the
            spline-model PSD/band (not to ``posterior_psd``).
        use_uniform_ci: For 2-D ``weights``, build a uniform (simultaneous)
            band with ``_get_uni_ci``; otherwise use pointwise 16/50/84
            percentiles.
        use_parametric_model: Forwarded to ``spline_model``.
        freqs: Optional frequency grid; takes precedence over ``pdgrm.freqs``.
        posterior_psd: Optional PSD draws with draws along axis 0. Their
            16/50/84 percentiles become ``ci`` and the median becomes
            ``model``; the spline model is then not evaluated.

    Returns:
        PlottingData: ``freqs`` is always populated. Without ``pdgrm`` or
        ``freqs`` it defaults to ``np.linspace(0, 1, n)``, where ``n`` is the
        length of the computed model.

    Raises:
        ValueError: (via ``PlottingData.n``) if no frequency grid can be
            determined because no periodogram, frequencies, posterior PSD or
            spline model were supplied.
    """
    plt_dat = PlottingData()

    if pdgrm is not None:
        plt_dat.pdgrm = np.array(pdgrm.power, dtype=np.float64) * yscalar
        plt_dat.freqs = np.array(pdgrm.freqs)

    # An explicitly supplied grid wins over the periodogram's own one.
    if freqs is not None:
        plt_dat.freqs = np.asarray(freqs)

    if posterior_psd is not None:
        # NOTE(review): unlike the periodogram and the spline model, these
        # draws are not multiplied by ``yscalar`` -- confirm they are already
        # in the plotted units.
        ci = np.percentile(posterior_psd, q=[16, 50, 84], axis=0)
        plt_dat.ci = ci
        plt_dat.model = ci[1]

    if plt_dat.model is None and spline_model is not None:
        if weights is None:
            # No weights given: let the model fall back on its own
            # (initial) weights.
            ln_spline = spline_model(use_parametric_model=use_parametric_model)
        elif weights.ndim == 1:
            # A single weight vector -- one curve, no band possible.
            ln_spline = spline_model(weights, use_parametric_model)
        else:  # weights.ndim == 2
            # One weight vector per row -- a band is possible.
            if weights.shape[0] > 500:
                # Subsample draws to keep evaluation cheap.
                idx = np.random.choice(
                    weights.shape[0], size=500, replace=False
                )
                weights = weights[idx]

            ln_splines = jnp.array(
                [spline_model(w, use_parametric_model) for w in weights]
            )

            if use_uniform_ci:
                ln_ci = _get_uni_ci(ln_splines)
            else:  # pointwise percentiles
                ln_ci = jnp.percentile(
                    ln_splines, q=jnp.array([16, 50, 84]), axis=0
                )
            # _get_uni_ci returns a (lower, median, upper) tuple; stack it.
            ln_ci = jnp.array(ln_ci)
            plt_dat.ci = np.exp(ln_ci, dtype=np.float64) * yscalar
            ln_spline = ln_ci[1]
        plt_dat.model = np.exp(ln_spline, dtype=np.float64) * yscalar

    # Resolve the default grid last, so ``plt_dat.n`` can use the model's
    # length. Doing this before the model existed made ``n`` raise whenever
    # neither ``pdgrm`` nor ``freqs`` was given.
    if plt_dat.freqs is None:
        plt_dat.freqs = np.linspace(0, 1, plt_dat.n)

    return plt_dat
2✔
92

93

94
def _get_uni_ci(samples, alpha=0.1):
2✔
95
    """
96
    Compute a uniform (simultaneous) confidence band for a set of function samples.
97

98
    Args:
99
        samples (jnp.ndarray): Shape (num_samples, num_points) array of function samples.
100
        alpha (float): Significance level (default 0.1 for 90% CI).
101

102
    Returns:
103
        tuple: (lower_bound, median, upper_bound) arrays.
104
    """
105
    num_samples, num_points = samples.shape
×
106

107
    # Compute pointwise median and standard deviation
108
    median = jnp.median(samples, axis=0)
×
109
    std = jnp.std(samples, axis=0)
×
110

111
    # Compute the max deviation over all samples
112
    deviations = (samples - median[None, :]) / std[
×
113
        None, :
114
    ]  # Normalize deviations
115
    max_deviation = jnp.max(
×
116
        jnp.abs(deviations), axis=1
117
    )  # Max deviation per sample
118

119
    # Compute the scaling factor using the distribution of max deviations
120
    k_alpha = jnp.percentile(
×
121
        max_deviation, 100 * (1 - alpha)
122
    )  # Critical value
123

124
    # Compute uniform confidence bands
125
    lower_bound = median - k_alpha * std
×
126
    upper_bound = median + k_alpha * std
×
127

128
    return lower_bound, median, upper_bound
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc