• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

nz-gravity / LogPSplinePSD / 16306634153

15 Jul 2025 11:36PM UTC coverage: 68.483%. Remained the same
16306634153

Pull #3

github

web-flow
Merge c41038cdc into 328e854df
Pull Request #3: Adding adaptive MCMC

51 of 78 branches covered (65.38%)

Branch coverage included in aggregate %.

1 of 1 new or added line in 1 file covered. (100.0%)

7 existing lines in 1 file now uncovered.

527 of 766 relevant lines covered (68.8%)

1.38 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

77.94
/src/log_psplines/plotting/utils.py
1
import dataclasses
2✔
2

3
import arviz as az
2✔
4
import jax.numpy as jnp
2✔
5
import numpy as np
6

2✔
7
from ..datatypes import Periodogram
2✔
8
from ..psplines import LogPSplines
9

2✔
10
# Only unpack_data is exported by a star-import; the other public names in
# this module (PlottingData, plot_diagnostics, get_weights,
# get_psd_samples_arviz) must be imported explicitly by name.
__all__ = ["unpack_data"]
11

12

2✔
13
@dataclasses.dataclass
class PlottingData:
    """Container for the arrays used when plotting a PSD.

    Every field is optional and defaults to ``None``; ``unpack_data`` fills
    in whichever ones it can derive from its inputs.
    """

    # Frequency grid for the x-axis.
    freqs: jnp.ndarray = None
    # Periodogram power (unpack_data stores it as float64, scaled by yscalar).
    pdgrm: jnp.ndarray = None
    # Model PSD curve (unpack_data stores exp(log-spline + log-parametric)).
    model: jnp.ndarray = None
    # Stacked interval curves; unpack_data fills three rows
    # (lower, centre, upper) when given multiple weight samples.
    ci: jnp.ndarray = None

    @property
    def n(self):
        """Number of points, taken from the first available data array.

        ``freqs`` is consulted first, then ``pdgrm``, then ``model``;
        ``ci`` is never used.

        Raises:
            ValueError: if ``freqs``, ``pdgrm`` and ``model`` are all None.
        """
        for candidate in (self.freqs, self.pdgrm, self.model):
            if candidate is not None:
                return len(candidate)
        raise ValueError("No data to get length from.")
30

31

2✔
32
def unpack_data(
    pdgrm: Periodogram = None,
    spline_model: LogPSplines = None,
    weights=None,
    yscalar=1.0,
    use_uniform_ci=True,
    use_parametric_model=True,
    freqs=None,
):
    """Gather periodogram, model curve and interval bands into PlottingData.

    Parameters
    ----------
    pdgrm : Periodogram, optional
        Supplies ``power`` (stored as float64 times ``yscalar``) and ``freqs``.
    spline_model : LogPSplines, optional
        Evaluated to obtain the log-spline curve. Its ``log_parametric_model``
        is added in log space before exponentiating.
    weights : array, optional
        ``None`` evaluates ``spline_model()`` with its own weights. A 1-D
        array is a single weight vector (no interval). Any other rank is
        treated as a stack of weight vectors, one per row, from which an
        interval is derived.
    yscalar : float
        Multiplicative factor applied to pdgrm, model and interval.
    use_uniform_ci : bool
        For stacked weights: use the uniform band from ``_get_uni_ci`` when
        True, otherwise the pointwise 16th/50th/84th percentiles.
    use_parametric_model : bool
        When False the parametric log-model is replaced by zeros.
    freqs : array, optional
        Overrides any frequencies taken from ``pdgrm``. If neither provides
        frequencies, ``np.linspace(0, 1, n)`` is used.

    Returns
    -------
    PlottingData
    """
    out = PlottingData()

    if pdgrm is not None:
        out.pdgrm = np.array(pdgrm.power, dtype=np.float64) * yscalar
        out.freqs = pdgrm.freqs

    if spline_model is not None:
        ln_param = spline_model.log_parametric_model
        if not use_parametric_model:
            ln_param = jnp.zeros_like(ln_param)

        def _to_linear(log_values):
            # Add the parametric part, exponentiate in float64, then scale.
            return np.exp(log_values + ln_param, dtype=np.float64) * yscalar

        if weights is None:
            # Evaluate with the model's own (initial) weights.
            centre = spline_model()
        elif weights.ndim == 1:
            # A single weight vector: no interval can be formed.
            centre = spline_model(weights)
        else:
            # One row per weight sample -> interval is possible.
            sampled = jnp.array([spline_model(w) for w in weights])
            if use_uniform_ci:
                bands = _get_uni_ci(sampled)
            else:
                bands = jnp.percentile(
                    sampled, q=jnp.array([16, 50, 84]), axis=0
                )
            bands = jnp.array(bands)
            out.ci = _to_linear(bands)
            # The middle row (median) is used as the model curve.
            centre = bands[1]

        out.model = _to_linear(centre)

    if freqs is not None:
        out.freqs = freqs
    elif out.freqs is None:
        out.freqs = np.linspace(0, 1, out.n)

    return out
82

83

2✔
84
def _get_uni_ci(samples, alpha=0.1):
    """
    Compute a uniform (simultaneous) band for a set of function samples.

    The band is ``median +/- k_alpha * std``. Here ``median`` and ``std``
    are pointwise over the samples. ``k_alpha`` is the ``100 * (1 - alpha)``
    percentile of each sample's maximum absolute standardised deviation from
    the median.

    Args:
        samples (jnp.ndarray): Shape (num_samples, num_points) array of function samples.
        alpha (float): Significance level (default 0.1 for a 90% band).

    Returns:
        tuple: (lower_bound, median, upper_bound) arrays, each of shape (num_points,).

    Raises:
        ValueError: if ``samples`` is not 2-D.
    """
    if samples.ndim != 2:
        raise ValueError(
            "samples must be 2-D (num_samples, num_points); "
            f"got shape {samples.shape}"
        )

    # Pointwise median and standard deviation across samples.
    median = jnp.median(samples, axis=0)
    std = jnp.std(samples, axis=0)

    # Where every sample agrees, std == 0 and the deviation would be 0/0 = NaN.
    # jnp.max propagates NaN, so a single such point would make k_alpha, and
    # therefore the entire band, NaN. Dividing by 1 there yields a deviation
    # of exactly 0; the band still collapses to the median at that point
    # because k_alpha * std == 0.
    safe_std = jnp.where(std > 0, std, 1.0)
    deviations = (samples - median[None, :]) / safe_std[None, :]

    # Largest standardised deviation of each sample over all points.
    max_deviation = jnp.max(jnp.abs(deviations), axis=1)

    # Critical value from the distribution of per-sample max deviations.
    k_alpha = jnp.percentile(max_deviation, 100 * (1 - alpha))

    lower_bound = median - k_alpha * std
    upper_bound = median + k_alpha * std

    return lower_bound, median, upper_bound
119

120

121
def plot_diagnostics(
    idata: az.InferenceData,
    outdir: str,
    variables: list = None,
    figsize: tuple = (12, 8),
) -> None:
    """
    Plot MCMC diagnostics using arviz and save them as PNG files.

    Always writes ``trace_plots.png`` and prints ``az.summary`` for
    ``variables``. Also writes ``acceptance_rate.png`` and
    ``step_size_evolution.png`` when the corresponding variables are present
    in ``idata.sample_stats``. Every figure created here is closed after
    saving.

    Parameters
    ----------
    idata : az.InferenceData
        Inference data from adaptive MCMC
    outdir : str
        Directory the PNG files are written to (it is not created here)
    variables : list, optional
        Variables to plot; defaults to ``["phi", "delta"]``
    figsize : tuple
        Figure size of the trace plot
    """
    import os

    import matplotlib.pyplot as plt

    # None sentinel instead of a mutable list shared across calls.
    if variables is None:
        variables = ["phi", "delta"]

    # Trace plots: az.plot_trace draws onto the current pyplot figure.
    az.plot_trace(idata, var_names=variables, figsize=figsize)
    fig = plt.gcf()
    fig.suptitle("Trace plots - Adaptive MCMC")
    fig.tight_layout()
    fig.savefig(os.path.join(outdir, "trace_plots.png"))
    # Close each figure after saving so repeated calls do not accumulate
    # open figures (and their memory) in pyplot's figure manager.
    plt.close(fig)

    # Summary statistics
    print("Summary Statistics:")
    print(az.summary(idata, var_names=variables))

    # Sampler diagnostics live in the optional sample_stats group.
    sample_stats = getattr(idata, "sample_stats", None)
    if sample_stats is None:
        return

    # Acceptance rate plot
    if "acceptance_rate" in sample_stats:
        fig, ax = plt.subplots(figsize=(10, 4))
        # flatten() concatenates chains end to end along the x-axis.
        accept_rates = sample_stats.acceptance_rate.values.flatten()
        ax.plot(accept_rates, alpha=0.7)
        ax.axhline(
            idata.attrs.get("target_accept_rate", 0.44),
            color="red",
            linestyle="--",
            label="Target",
        )
        ax.set_xlabel("Iteration")
        ax.set_ylabel("Acceptance Rate")
        ax.set_title("Acceptance Rate Over Time")
        ax.legend()
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(os.path.join(outdir, "acceptance_rate.png"))
        plt.close(fig)

    # Step size evolution; the std panel is drawn only if that stat exists.
    if "step_size_mean" in sample_stats:
        has_std = "step_size_std" in sample_stats
        fig, axes = plt.subplots(
            1, 2 if has_std else 1, figsize=(12, 4), squeeze=False
        )

        ax1 = axes[0, 0]
        step_means = sample_stats.step_size_mean.values.flatten()
        ax1.plot(step_means, alpha=0.7)
        ax1.set_xlabel("Iteration")
        ax1.set_ylabel("Mean Step Size")
        ax1.set_title("Step Size Evolution")
        ax1.grid(True, alpha=0.3)

        if has_std:
            ax2 = axes[0, 1]
            step_stds = sample_stats.step_size_std.values.flatten()
            ax2.plot(step_stds, alpha=0.7, color="orange")
            ax2.set_xlabel("Iteration")
            ax2.set_ylabel("Step Size Std")
            ax2.set_title("Step Size Variability")
            ax2.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(os.path.join(outdir, "step_size_evolution.png"))
        plt.close(fig)
191

192

193
def get_weights(
    idata: az.InferenceData,
    thin: int = 10,
) -> jnp.ndarray:
    """
    Extract weight samples from arviz InferenceData, pooling all chains.

    Parameters
    ----------
    idata : az.InferenceData
        Inference data whose ``posterior`` group holds a ``weights`` variable,
        presumably dimensioned (chain, draw, n_weights) -- TODO confirm.
    thin : int
        Keep every ``thin``-th pooled sample; must be >= 1.

    Returns
    -------
    array
        Weight samples, shape (n_samples_thinned, n_weights). This is the
        array returned by xarray's ``.values`` (NumPy), not a JAX array.

    Raises
    ------
    ValueError
        If ``thin`` is less than 1. A zero step would otherwise fail with an
        opaque slicing error, and a negative one would silently reverse the
        samples.
    """
    if thin < 1:
        raise ValueError(f"thin must be a positive integer, got {thin!r}")

    weight_samples = idata.posterior.weights.values
    # Pool chains: (chains, draws, n_weights) -> (chains * draws, n_weights).
    weight_samples = weight_samples.reshape(-1, weight_samples.shape[-1])

    return weight_samples[::thin]
222

223

224
def get_psd_samples_arviz(
    idata: az.InferenceData, spline_model: LogPSplines, thin: int = 10
) -> jnp.ndarray:
    """
    Extract PSD samples from arviz InferenceData.

    Each pooled, thinned weight vector ``w`` is mapped to
    ``exp(basis.T @ w + log_parametric_model)``.

    Parameters
    ----------
    idata : az.InferenceData
        Inference data containing weight samples
    spline_model : LogPSplines
        Spline model for reconstruction (``basis`` and
        ``log_parametric_model`` are read)
    thin : int
        Thinning factor, forwarded to ``get_weights``

    Returns
    -------
    jnp.ndarray
        PSD samples, shape (n_samples_thinned, n_frequencies). This is
        (0, n_frequencies) when there are no samples.
    """
    weight_samples = jnp.asarray(get_weights(idata, thin=thin))

    # One matrix product for all samples instead of a Python loop. Column j
    # of basis.T @ W.T equals basis.T @ W[j], the per-sample log-spline.
    # Keeping basis.T as the left operand preserves the original product form.
    ln_splines = (spline_model.basis.T @ weight_samples.T).T

    # The parametric log-model broadcasts across the sample axis.
    return jnp.exp(ln_splines + spline_model.log_parametric_model)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc