
nz-gravity / LogPSplinePSD / job 18732981084

22 Oct 2025 11:43PM UTC. Coverage: 78.994% (-0.3%) from 79.302%.

Push via GitHub (web-flow): Merge pull request #14 from nz-gravity/add_multivar_averaged ("add averaged data for multivar case")

657 of 792 branches covered (82.95%). Branch coverage is included in the aggregate %.
460 of 574 new or added lines in 15 files covered (80.14%).
34 existing lines in 4 files are now uncovered.
4273 of 5449 relevant lines covered (78.42%), at 1.57 hits per line.

Source file: /src/log_psplines/plotting/base.py (54.01% covered)
"""
Base plotting utilities for shared functionality across plotting modules.
"""

import os
from dataclasses import dataclass
from functools import wraps
from typing import Any, Callable, Dict, Optional, Tuple, Union

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from ..logger import logger

# Color constants used across plotting modules
COLORS = {
    "data": "#d3d3d3",  # lightgray
    "model": "#ff7f0e",  # tab:orange
    "knots": "#d62728",  # tab:red
    "true": "#000000",  # black
    "empirical": "#404040",  # dark gray
    "ci_fill": "#1f77b4",  # tab:blue
    "coherence": "#1f77b4",  # tab:blue
    "real": "#2ca02c",  # tab:green
    "imag": "#ff7f0e",  # tab:orange
}

@dataclass
class PlotConfig:
    """Configuration for plotting parameters."""

    figsize: tuple = (12, 8)
    dpi: int = 150
    fontsize: int = 11
    labelsize: int = 12
    titlesize: int = 12
    linewidth: float = 1.2
    markersize: float = 4.5
    alpha: float = 0.7

def safe_plot(filename: str, dpi: int = 150):
    """Decorator for safe plotting with error handling."""

    def decorator(plot_func: Callable):
        @wraps(plot_func)
        def wrapper(*args, **kwargs):
            try:
                logger.debug(f"--- plotting: {os.path.basename(filename)}")
                plot_func(*args, **kwargs)
                plt.savefig(filename, dpi=dpi, bbox_inches="tight")
                plt.close()
                return True
            except Exception as e:
                logger.warning(
                    f"Failed to create {os.path.basename(filename)}: {e}"
                )
                plt.close("all")
                return False

        return wrapper

    return decorator

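# Example usage (an editor's sketch, not part of the original module):
# `safe_plot` centralizes figure saving and error handling, so the wrapped
# function only needs to draw on the current figure. `_demo_line_plot` is a
# hypothetical function introduced here for illustration.
def _demo_safe_plot_usage():
    @safe_plot("demo_line.png", dpi=100)
    def _demo_line_plot(x, y):
        plt.figure()
        plt.plot(x, y, color=COLORS["model"])

    # Returns True if the figure was saved, False if plotting raised.
    return _demo_line_plot([0, 1, 2], [0, 1, 4])
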
def extract_plotting_data(
    idata, weights_key: Optional[str] = None
) -> Dict[str, Any]:
    """
    Extract common plotting data from an inference data object.

    Args:
        idata: ArviZ InferenceData object
        weights_key: Key for weights in posterior (optional)

    Returns:
        Dictionary containing extracted data
    """
    from ..arviz_utils import (
        get_periodogram,
        get_spline_model,
        get_weights,
    )

    data = {}

    # Extract core data - handle univariate, multivariate, and VI cases
    try:
        data["periodogram"] = get_periodogram(idata)
    except KeyError:
        # For multivariate or VI data, the periodogram might not be available
        # or might be stored differently
        data["periodogram"] = None

    # Extract spline model - handle the VI case where 'knots' might not exist
    try:
        data["spline_model"] = get_spline_model(idata)
    except KeyError:
        # For VI data, the spline model might be stored differently
        data["spline_model"] = None

    # Extract weights - handle different data structures
    try:
        data["weights"] = get_weights(idata, weights_key)
    except (KeyError, AttributeError):
        # For VI data, weights might be stored differently
        data["weights"] = None

    # Extract posterior PSD quantiles if available
    if hasattr(idata, "posterior_psd"):
        if "psd" in idata.posterior_psd:
            arr = idata.posterior_psd["psd"]
            data["posterior_psd_quantiles"] = {
                "percentile": np.asarray(arr.coords["percentile"].values),
                "values": np.asarray(arr.values),
            }
        if "psd_matrix_real" in idata.posterior_psd:
            data["posterior_psd_matrix_quantiles"] = {
                "percentile": np.asarray(
                    idata.posterior_psd["psd_matrix_real"]
                    .coords["percentile"]
                    .values
                ),
                "real": np.asarray(
                    idata.posterior_psd["psd_matrix_real"].values
                ),
                "imag": np.asarray(
                    idata.posterior_psd["psd_matrix_imag"].values
                ),
                "coherence": (
                    np.asarray(idata.posterior_psd["coherence"].values)
                    if "coherence" in idata.posterior_psd
                    else None
                ),
            }

    # Extract true PSD if available
    if hasattr(idata, "attrs") and "true_psd" in idata.attrs:
        data["true_psd"] = idata.attrs["true_psd"]

    # Extract frequencies if available
    if hasattr(idata, "attrs") and "frequencies" in idata.attrs:
        data["frequencies"] = idata.attrs["frequencies"]

    return data

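# Example usage (an editor's sketch, not part of the original module):
# every key in the returned dict may be None, so callers should check what
# is present before plotting. `idata` stands for an ArviZ InferenceData
# object produced elsewhere in the package.
def _demo_extract_plotting_data(idata):
    data = extract_plotting_data(idata)
    if validate_plotting_data(data, required_keys=["periodogram", "weights"]):
        pass  # safe to build univariate PSD plots from `data` here
    return data
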
def compute_confidence_intervals(
    samples: np.ndarray,
    quantiles: Tuple[float, float, float] = (16, 50, 84),
    method: str = "percentile",
    alpha: float = 0.1,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute confidence intervals from posterior samples.

    Args:
        samples: Array of posterior samples
        quantiles: Tuple of percentiles (0-100 scale) to compute
            (low, median, high)
        method: Method for CI computation ('percentile' or 'uniform')
        alpha: Significance level for uniform CI

    Returns:
        Tuple of (lower_bound, median, upper_bound)
    """
    if method == "percentile":
        return jnp.percentile(samples, q=jnp.array(quantiles), axis=0)
    elif method == "uniform":
        return _compute_uniform_ci(samples, alpha)
    else:
        raise ValueError(f"Unknown CI method: {method}")

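# Example usage (an editor's sketch, not part of the original module):
# with the default method="percentile", the three returned arrays are the
# pointwise 16th/50th/84th percentiles taken across samples (axis 0).
def _demo_confidence_intervals():
    rng = np.random.default_rng(0)
    samples = rng.normal(size=(200, 50))  # (n_samples, n_points)
    lower, median, upper = compute_confidence_intervals(samples)
    return lower, median, upper  # each of shape (50,)
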
def _compute_uniform_ci(samples: np.ndarray, alpha: float = 0.1):
    """
    Compute uniform (simultaneous) confidence intervals.

    Args:
        samples: Shape (num_samples, num_points) array of function samples
        alpha: Significance level

    Returns:
        Tuple of (lower_bound, median, upper_bound)
    """
    num_samples, num_points = samples.shape

    # Compute pointwise median and standard deviation
    median = jnp.median(samples, axis=0)
    std = jnp.std(samples, axis=0)

    # Compute the max deviation over all samples
    deviations = (samples - median[None, :]) / std[None, :]
    max_deviation = jnp.max(jnp.abs(deviations), axis=1)

    # Compute the scaling factor using the distribution of max deviations
    k_alpha = jnp.percentile(max_deviation, 100 * (1 - alpha))

    # Compute uniform confidence bands
    lower_bound = median - k_alpha * std
    upper_bound = median + k_alpha * std

    return lower_bound, median, upper_bound

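# Note (an editor's sketch, not part of the original module): the band is
# simultaneous because k_alpha is the (1 - alpha) quantile of the maximum
# standardized deviation over the whole grid, so roughly a fraction
# (1 - alpha) of sampled curves lie entirely inside
# [median - k_alpha * std, median + k_alpha * std]. The demo below compares
# band widths against a pointwise 90% band.
def _demo_uniform_vs_pointwise():
    rng = np.random.default_rng(1)
    samples = rng.normal(size=(500, 64))
    lo_u, _, hi_u = _compute_uniform_ci(samples, alpha=0.1)
    lo_p, _, hi_p = compute_confidence_intervals(samples, quantiles=(5, 50, 95))
    # The uniform band is wider than the pointwise band at the same level
    return float(np.mean(hi_u - lo_u)), float(np.mean(hi_p - lo_p))
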
def compute_coherence_ci(
    psd_samples: np.ndarray,
) -> Dict[Tuple[int, int], Tuple[np.ndarray, np.ndarray, np.ndarray]]:
    """
    Compute coherence confidence intervals from multivariate PSD samples.

    Args:
        psd_samples: Shape (n_samples, n_freq, n_channels, n_channels)

    Returns:
        Dictionary mapping (i,j) channel pairs to (q05, q50, q95) tuples
    """
    ci_dict = {}
    n_samples, n_freq, n_channels, _ = psd_samples.shape

    for i in range(n_channels):
        for j in range(n_channels):
            if i > j:  # Only the lower triangle (coherence is symmetric)
                coh = np.abs(psd_samples[:, :, i, j]) ** 2 / (
                    np.abs(psd_samples[:, :, i, i])
                    * np.abs(psd_samples[:, :, j, j])
                )
                q05 = np.percentile(coh, 5, axis=0)
                q50 = np.percentile(coh, 50, axis=0)
                q95 = np.percentile(coh, 95, axis=0)
                ci_dict[(i, j)] = (q05, q50, q95)

    return ci_dict

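# Example usage (an editor's sketch, not part of the original module):
# for two channels the only key produced is (1, 0), since the loop keeps
# i > j. The fake PSD draws below average 8 segments per draw so the
# coherence is not trivially 1.
def _demo_coherence_ci():
    rng = np.random.default_rng(2)
    x = rng.normal(size=(100, 32, 8, 2)) + 1j * rng.normal(size=(100, 32, 8, 2))
    psd = np.einsum("sfki,sfkj->sfij", x, x.conj()) / 8  # (100, 32, 2, 2)
    q05, q50, q95 = compute_coherence_ci(psd)[(1, 0)]  # each of shape (32,)
    return q05, q50, q95
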
def compute_cross_spectra_ci(psd_samples: np.ndarray) -> Tuple[Dict, Dict]:
    """
    Compute 5th/50th/95th percentile intervals for the real and imaginary
    parts of the cross-spectra.

    Args:
        psd_samples: Shape (n_samples, n_freq, n_channels, n_channels)

    Returns:
        Tuple of (real_ci_dict, imag_ci_dict)
    """
    real_dict = {}
    imag_dict = {}
    n_samples, n_freq, n_channels, _ = psd_samples.shape

    for i in range(n_channels):
        for j in range(n_channels):
            if i != j:
                # Real part
                re_q05 = np.percentile(psd_samples[:, :, i, j].real, 5, axis=0)
                re_q50 = np.percentile(
                    psd_samples[:, :, i, j].real, 50, axis=0
                )
                re_q95 = np.percentile(
                    psd_samples[:, :, i, j].real, 95, axis=0
                )
                real_dict[(i, j)] = (re_q05, re_q50, re_q95)

                # Imaginary part
                im_q05 = np.percentile(psd_samples[:, :, i, j].imag, 5, axis=0)
                im_q50 = np.percentile(
                    psd_samples[:, :, i, j].imag, 50, axis=0
                )
                im_q95 = np.percentile(
                    psd_samples[:, :, i, j].imag, 95, axis=0
                )
                imag_dict[(i, j)] = (im_q05, im_q50, im_q95)

    return real_dict, imag_dict


274

275
def setup_plot_style(config: Optional[PlotConfig] = None) -> PlotConfig:
2✔
276
    """Setup consistent matplotlib styling for plots."""
277
    if config is None:
2✔
278
        config = PlotConfig()
2✔
279

280
    plt.rcParams.update(
2✔
281
        {
282
            "font.size": config.fontsize,
283
            "axes.labelsize": config.labelsize,
284
            "axes.titlesize": config.titlesize,
285
            "xtick.labelsize": config.fontsize - 1,
286
            "ytick.labelsize": config.fontsize - 1,
287
            "legend.fontsize": config.fontsize - 1,
288
            "axes.linewidth": config.linewidth,
289
            "xtick.major.width": config.linewidth - 0.1,
290
            "ytick.major.width": config.linewidth - 0.1,
291
            "figure.dpi": config.dpi,
292
            "savefig.dpi": config.dpi * 2,
293
        }
294
    )
295

296
    return config
2✔
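# Example usage (an editor's sketch, not part of the original module):
# styling is applied globally through plt.rcParams, so one call near the top
# of a script affects every figure created afterwards.
def _demo_setup_plot_style():
    config = setup_plot_style(PlotConfig(fontsize=9, dpi=100))
    fig, ax = plt.subplots(figsize=config.figsize)
    ax.plot([0, 1], [1, 0], color=COLORS["model"], lw=config.linewidth)
    return fig
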
def validate_plotting_data(data: Dict[str, Any], required_keys: list) -> bool:
    """Validate that required data is available for plotting."""
    missing_keys = [
        key for key in required_keys if key not in data or data[key] is None
    ]
    if missing_keys:
        logger.warning(f"Missing required data for plotting: {missing_keys}")
        return False
    return True

def subsample_weights(
    weights: np.ndarray, max_samples: int = 500
) -> np.ndarray:
    """Subsample weights array if it's too large for efficient computation."""
    if weights.shape[0] > max_samples:
        idx = np.random.choice(
            weights.shape[0], size=max_samples, replace=False
        )
        return weights[idx]
    return weights

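# Example usage (an editor's sketch, not part of the original module):
# the row selection uses NumPy's global RNG, so call np.random.seed(...)
# first if a reproducible subsample is needed.
def _demo_subsample_weights():
    weights = np.ones((2000, 30))
    return subsample_weights(weights, max_samples=500).shape  # (500, 30)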