
nz-gravity / LogPSplinePSD, build 19686623502 (push, via GitHub)

25 Nov 2025 10:58PM UTC. Coverage: 80.0% (+2.2% from 77.849%)
Commit by avivajpeyi: "fix: CI fixes"

789 of 929 branches covered (84.93%); branch coverage is included in the aggregate %.
23 of 23 new or added lines in 5 files covered (100.0%).
227 existing lines in 8 files are now uncovered.
4907 of 6191 relevant lines covered (79.26%), at 1.59 hits per line.

Source file: /src/log_psplines/plotting/base.py (60.67% covered)
"""
Base plotting utilities for shared functionality across plotting modules.
"""

import copy
import os
from dataclasses import dataclass
from functools import wraps
from typing import Any, Callable, Dict, Optional, Tuple, Union

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from ..logger import logger

# Color constants used across plotting modules
COLORS = {
    "data": "#d3d3d3",  # lightgray
    "model": "#ff7f0e",  # tab:orange
    "knots": "#d62728",  # tab:red
    "true": "#000000",  # black
    "empirical": "#404040",  # dark gray
    "ci_fill": "#1f77b4",  # tab:blue
    "coherence": "#1f77b4",  # tab:blue
    "real": "#2ca02c",  # tab:green
    "imag": "#ff7f0e",  # tab:orange
}


@dataclass
class PlotConfig:
    """Configuration for plotting parameters."""

    figsize: tuple = (12, 8)
    dpi: int = 150
    fontsize: int = 11
    labelsize: int = 12
    titlesize: int = 12
    linewidth: float = 1.2
    markersize: float = 4.5
    alpha: float = 0.7


def safe_plot(filename: str, dpi: int = 150):
    """Decorator for safe plotting with error handling."""

    def decorator(plot_func: Callable):
        @wraps(plot_func)
        def wrapper(*args, **kwargs):
            try:
                logger.debug(f"--- plotting: {os.path.basename(filename)}")
                result = plot_func(*args, **kwargs)
                plt.savefig(filename, dpi=dpi, bbox_inches="tight")
                plt.close()
                return True
            except Exception as e:
                logger.warning(
                    f"Failed to create {os.path.basename(filename)}: {e}"
                )
                plt.close("all")
                return False

        return wrapper

    return decorator
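
# Usage sketch for `safe_plot` (the wrapped function and filename here are
# hypothetical, not part of this module):
#
#     @safe_plot("psd_comparison.png", dpi=100)
#     def plot_psd(freqs, psd):
#         plt.loglog(freqs, psd, color=COLORS["model"])
#
#     ok = plot_psd(freqs, psd)  # True if saved, False if a failure was logged
#
# Because the wrapper swallows exceptions, callers branch on the returned
# boolean instead of wrapping every plot call in try/except.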


def extract_plotting_data(
    idata, weights_key: Optional[str] = None
) -> Dict[str, Any]:
    """
    Extract common plotting data from an inference data object.

    Args:
        idata: ArviZ InferenceData object
        weights_key: Key for weights in posterior (optional)

    Returns:
        Dictionary containing extracted data
    """
    from ..arviz_utils import (
        get_periodogram,
        get_spline_model,
        get_weights,
    )

    data = {}

    # Extract core data - handle univariate, multivariate, and VI cases
    try:
        data["periodogram"] = get_periodogram(idata)
    except KeyError:
        # For multivariate or VI data, the periodogram might not be available
        # or might be stored differently
        data["periodogram"] = None

    # Extract spline model - handle VI case where 'knots' might not exist
    try:
        data["spline_model"] = get_spline_model(idata)
    except KeyError:
        # For VI data, the spline model might be stored differently
        data["spline_model"] = None

    # Extract weights - handle different data structures
    try:
        data["weights"] = get_weights(idata, weights_key)
    except (KeyError, AttributeError):
        # For VI data, weights might be stored differently
        data["weights"] = None

    def _maybe_set_psd_quantiles(dataset, prefix: str) -> bool:
        if dataset is None:
            return False

        added = False
        if "psd" in dataset:
            arr = dataset["psd"]
            if "freq" in arr.coords and "frequencies" not in data:
                data["frequencies"] = np.asarray(arr.coords["freq"].values)
            data[f"{prefix}_psd_quantiles"] = {
                "percentile": np.asarray(arr.coords["percentile"].values),
                "values": np.asarray(arr.values),
            }
            added = True
        if "psd_matrix_real" in dataset:
            freq_coord = dataset["psd_matrix_real"].coords
            if (
                "freq" in freq_coord
                and "frequencies" not in data
                and freq_coord["freq"] is not None
            ):
                data["frequencies"] = np.asarray(
                    freq_coord["freq"].values, dtype=float
                )
            data[f"{prefix}_psd_matrix_quantiles"] = {
                "percentile": np.asarray(
                    dataset["psd_matrix_real"].coords["percentile"].values
                ),
                "real": np.asarray(dataset["psd_matrix_real"].values),
                "imag": np.asarray(dataset["psd_matrix_imag"].values),
                "coherence": (
                    np.asarray(dataset["coherence"].values)
                    if "coherence" in dataset
                    else None
                ),
            }
            added = True
        return added

    if hasattr(idata, "posterior_psd"):
        _maybe_set_psd_quantiles(idata.posterior_psd, "posterior")
    if hasattr(idata, "vi_posterior_psd"):
        _maybe_set_psd_quantiles(idata.vi_posterior_psd, "vi")

    idata_attrs = getattr(idata, "attrs", {}) or {}
    only_vi_mode = bool(idata_attrs.get("only_vi"))

    # Backwards compatibility: fall back to VI quantiles when the posterior
    # quantiles are absent (or when running in VI-only mode)
    if "posterior_psd_quantiles" not in data and "vi_psd_quantiles" in data:
        data["posterior_psd_quantiles"] = copy.deepcopy(
            data["vi_psd_quantiles"]
        )
    elif only_vi_mode and "vi_psd_quantiles" in data:
        data["posterior_psd_quantiles"] = copy.deepcopy(
            data["vi_psd_quantiles"]
        )
    if (
        "posterior_psd_matrix_quantiles" not in data
        and "vi_psd_matrix_quantiles" in data
    ):
        data["posterior_psd_matrix_quantiles"] = copy.deepcopy(
            data["vi_psd_matrix_quantiles"]
        )
    elif only_vi_mode and "vi_psd_matrix_quantiles" in data:
        data["posterior_psd_matrix_quantiles"] = copy.deepcopy(
            data["vi_psd_matrix_quantiles"]
        )
    elif (
        "posterior_psd_matrix_quantiles" in data
        and "vi_psd_matrix_quantiles" in data
    ):
        posterior_q = data["posterior_psd_matrix_quantiles"]
        vi_q = data["vi_psd_matrix_quantiles"]
        if (
            posterior_q.get("coherence") is None
            and vi_q.get("coherence") is not None
        ):
            posterior_q["coherence"] = np.asarray(vi_q["coherence"])

    # Extract the true PSD if available
    if "true_psd" in idata_attrs:
        data["true_psd"] = idata_attrs["true_psd"]

    # Extract frequencies if available
    if "frequencies" in idata_attrs:
        data["frequencies"] = idata_attrs["frequencies"]

    return data
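
# Sketch of how the extracted dictionary might be consumed downstream; the
# key names are the ones populated above, but this calling code is
# illustrative only and assumes `idata` is an ArviZ InferenceData object:
#
#     data = extract_plotting_data(idata)
#     keys = ["frequencies", "posterior_psd_quantiles"]
#     if validate_plotting_data(data, keys):
#         q = data["posterior_psd_quantiles"]
#         # q["values"] holds one curve per entry in q["percentile"]
#         plt.fill_between(data["frequencies"], q["values"][0], q["values"][-1])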


def compute_confidence_intervals(
    samples: np.ndarray,
    quantiles: Tuple[float, float, float] = (16, 50, 84),
    method: str = "percentile",
    alpha: float = 0.1,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute confidence intervals from posterior samples.

    Args:
        samples: Array of posterior samples
        quantiles: Percentiles to compute, as (low, median, high)
        method: Method for CI computation ('percentile' or 'uniform')
        alpha: Significance level for the uniform CI

    Returns:
        Tuple of (lower_bound, median, upper_bound)
    """
    if method == "percentile":
        return jnp.percentile(samples, q=jnp.array(quantiles), axis=0)
    elif method == "uniform":
        return _compute_uniform_ci(samples, alpha)
    else:
        raise ValueError(f"Unknown CI method: {method}")
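
# Quick self-contained check of the percentile path (synthetic samples; not
# part of the module's API):
#
#     rng = np.random.default_rng(0)
#     samples = rng.normal(size=(1000, 64))  # (n_samples, n_points)
#     lo, med, hi = compute_confidence_intervals(samples)
#     assert lo.shape == med.shape == hi.shape == (64,)
#
# The percentile branch returns a single (3, n_points) array, so the tuple
# unpacking above splits it along the first axis into the three curves.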


def _compute_uniform_ci(samples: np.ndarray, alpha: float = 0.1):
    """
    Compute uniform (simultaneous) confidence intervals.

    Args:
        samples: Shape (num_samples, num_points) array of function samples
        alpha: Significance level

    Returns:
        Tuple of (lower_bound, median, upper_bound)
    """
    num_samples, num_points = samples.shape

    # Compute the pointwise median and standard deviation
    median = jnp.median(samples, axis=0)
    std = jnp.std(samples, axis=0)

    # Compute each sample's maximum standardized deviation from the median
    deviations = (samples - median[None, :]) / std[None, :]
    max_deviation = jnp.max(jnp.abs(deviations), axis=1)

    # Compute the scaling factor from the distribution of max deviations
    k_alpha = jnp.percentile(max_deviation, 100 * (1 - alpha))

    # Compute the uniform confidence bands
    lower_bound = median - k_alpha * std
    upper_bound = median + k_alpha * std

    return lower_bound, median, upper_bound
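
# The construction above is a standard simultaneous-band recipe: with
# pointwise median m(x) and scale s(x), k_alpha is the (1 - alpha) quantile
# over posterior draws f_i of max_x |f_i(x) - m(x)| / s(x), so the band
# m(x) +/- k_alpha * s(x) is meant to cover entire sampled curves with
# probability about 1 - alpha, not just each point marginally.
#
# Illustrative comparison against a pointwise band (synthetic draws):
#
#     rng = np.random.default_rng(1)
#     samples = rng.normal(size=(2000, 32))
#     lo_u, med, hi_u = _compute_uniform_ci(samples, alpha=0.1)
#     lo_p, _, hi_p = compute_confidence_intervals(samples, (5, 50, 95))
#     # the uniform band is typically wider: (hi_u - lo_u) >= (hi_p - lo_p)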


def compute_coherence_ci(
    psd_samples: np.ndarray,
) -> Dict[Tuple[int, int], Tuple[np.ndarray, np.ndarray, np.ndarray]]:
    """
    Compute coherence confidence intervals from multivariate PSD samples.

    Args:
        psd_samples: Shape (n_samples, n_freq, n_channels, n_channels)

    Returns:
        Dictionary mapping (i, j) channel pairs to (q05, q50, q95) tuples
    """
    ci_dict = {}
    n_samples, n_freq, n_channels, _ = psd_samples.shape

    for i in range(n_channels):
        for j in range(n_channels):
            if i > j:  # Only compute for the lower triangle (i > j)
                coh = np.abs(psd_samples[:, :, i, j]) ** 2 / (
                    np.abs(psd_samples[:, :, i, i])
                    * np.abs(psd_samples[:, :, j, j])
                )
                q05 = np.percentile(coh, 5, axis=0)
                q50 = np.percentile(coh, 50, axis=0)
                q95 = np.percentile(coh, 95, axis=0)
                ci_dict[(i, j)] = (q05, q50, q95)

    return ci_dict
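
# Shape-convention check with synthetic Hermitian PSD-like draws
# (illustrative only):
#
#     rng = np.random.default_rng(2)
#     x = rng.normal(size=(100, 16, 3, 3)) + 1j * rng.normal(size=(100, 16, 3, 3))
#     psd = x @ np.conj(np.swapaxes(x, -1, -2))  # positive semi-definite matrices
#     ci = compute_coherence_ci(psd)
#     sorted(ci)  # [(1, 0), (2, 0), (2, 1)], i.e. the lower-triangle pairs
#     q05, q50, q95 = ci[(1, 0)]  # each has shape (16,), one value per frequency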


def compute_cross_spectra_ci(psd_samples: np.ndarray) -> Tuple[Dict, Dict]:
    """
    Compute confidence intervals for the real and imaginary parts of the
    cross-spectra.

    Args:
        psd_samples: Shape (n_samples, n_freq, n_channels, n_channels)

    Returns:
        Tuple of (real_ci_dict, imag_ci_dict)
    """
    real_dict = {}
    imag_dict = {}
    n_samples, n_freq, n_channels, _ = psd_samples.shape

    for i in range(n_channels):
        for j in range(n_channels):
            if i != j:
                # Real part
                re_q05 = np.percentile(psd_samples[:, :, i, j].real, 5, axis=0)
                re_q50 = np.percentile(
                    psd_samples[:, :, i, j].real, 50, axis=0
                )
                re_q95 = np.percentile(
                    psd_samples[:, :, i, j].real, 95, axis=0
                )
                real_dict[(i, j)] = (re_q05, re_q50, re_q95)

                # Imaginary part
                im_q05 = np.percentile(psd_samples[:, :, i, j].imag, 5, axis=0)
                im_q50 = np.percentile(
                    psd_samples[:, :, i, j].imag, 50, axis=0
                )
                im_q95 = np.percentile(
                    psd_samples[:, :, i, j].imag, 95, axis=0
                )
                imag_dict[(i, j)] = (im_q05, im_q50, im_q95)

    return real_dict, imag_dict
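
# The cross-spectra helper uses the same (n_samples, n_freq, n_channels,
# n_channels) convention; a minimal plotting sketch (`freqs` and the channel
# pair are illustrative):
#
#     real_ci, imag_ci = compute_cross_spectra_ci(psd)
#     re_q05, re_q50, re_q95 = real_ci[(0, 1)]  # any off-diagonal pair, i != j
#     plt.plot(freqs, re_q50, color=COLORS["real"])
#     plt.fill_between(freqs, re_q05, re_q95, color=COLORS["real"], alpha=0.3)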


def setup_plot_style(config: Optional[PlotConfig] = None) -> PlotConfig:
    """Set up consistent matplotlib styling for plots."""
    if config is None:
        config = PlotConfig()

    plt.rcParams.update(
        {
            "font.size": config.fontsize,
            "axes.labelsize": config.labelsize,
            "axes.titlesize": config.titlesize,
            "xtick.labelsize": config.fontsize - 1,
            "ytick.labelsize": config.fontsize - 1,
            "legend.fontsize": config.fontsize - 1,
            "axes.linewidth": config.linewidth,
            "xtick.major.width": config.linewidth - 0.1,
            "ytick.major.width": config.linewidth - 0.1,
            "figure.dpi": config.dpi,
            "savefig.dpi": config.dpi * 2,
        }
    )

    return config
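
# Typical call pattern (sketch): apply the shared style once, then reuse the
# returned config for figure-level parameters.
#
#     config = setup_plot_style()
#     fig, ax = plt.subplots(figsize=config.figsize)
#
# Note that this updates the global matplotlib rcParams, so it affects every
# figure created afterwards in the same process.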


def validate_plotting_data(data: Dict[str, Any], required_keys: list) -> bool:
    """Validate that the required data is available for plotting."""
    missing_keys = [
        key for key in required_keys if key not in data or data[key] is None
    ]
    if missing_keys:
        logger.warning(f"Missing required data for plotting: {missing_keys}")
        return False
    return True


def subsample_weights(
    weights: np.ndarray, max_samples: int = 500
) -> np.ndarray:
    """Subsample weights array if it's too large for efficient computation."""
    if weights.shape[0] > max_samples:
        idx = np.random.choice(
            weights.shape[0], size=max_samples, replace=False
        )
        return weights[idx]
    return weights
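
# Sketch: cap the number of posterior draws before an expensive per-draw
# computation (the array here is synthetic):
#
#     weights = np.zeros((10_000, 40))
#     subsample_weights(weights, max_samples=500).shape  # (500, 40)
#
# Subsampling draws from the global NumPy RNG via np.random.choice, so seed
# with np.random.seed(...) if reproducible figures are needed.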