• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

nz-gravity / LogPSplinePSD / 20735310563

06 Jan 2026 01:56AM UTC coverage: 69.693% (-6.1%) from 75.796%
20735310563

push

github

avivajpeyi
run slow on CI

928 of 1242 branches covered (74.72%)

Branch coverage included in aggregate %.

5352 of 7769 relevant lines covered (68.89%)

1.38 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

74.26
/src/log_psplines/plotting/diagnostics.py
import os

from dataclasses import dataclass

from typing import Optional


import arviz as az

import matplotlib.pyplot as plt

import numpy as np


# NOTE(review): ``run_all_diagnostics`` and ``PlotConfig`` are imported but
# not referenced in the visible portion of this file — presumably used
# further down or kept for the public API; confirm before removing.
from ..diagnostics import run_all_diagnostics

from ..logger import logger

from .base import PlotConfig, safe_plot, setup_plot_style


# Setup consistent styling for diagnostics plots
setup_plot_style()
15

16

17
@dataclass
class DiagnosticsConfig:
    """Configuration for diagnostics plotting parameters."""

    # Figure size in inches for the multi-panel diagnostic figures.
    figsize: tuple = (12, 8)
    # Resolution passed to the safe_plot save step.
    dpi: int = 150
    # Minimum effective sample size considered reliable (used for the
    # ESS histogram annotation and the convergence status panel).
    ess_threshold: int = 400
    # R-hat convergence threshold (not referenced in the visible code —
    # presumably consumed elsewhere; confirm before removing).
    rhat_threshold: float = 1.01
    # Font sizes; not referenced in the visible code — TODO confirm usage.
    fontsize: int = 11
    labelsize: int = 12
    titlesize: int = 12
28

29

30
def plot_trace(idata: az.InferenceData, compact=True) -> plt.Figure:
    """Plot per-group trace and marginal-density panels.

    Posterior variables are grouped by name prefix (``delta``, ``phi``,
    ``weights``); each group occupies one row with a trace panel (left)
    and a density/histogram panel (right).  ``phi``/``delta`` panels use
    log scales with a change-of-variables KDE.

    Parameters
    ----------
    idata : az.InferenceData
        Posterior samples.  For 3-D variables (chain, draw, dim) only the
        first chain is plotted.
    compact : bool, optional
        Kept for backward compatibility; both settings produce a 3-row
        layout with each group's variables overlaid on shared axes.

    Returns
    -------
    plt.Figure
        The assembled figure.
    """
    groups = {
        "delta": [
            v for v in idata.posterior.data_vars if v.startswith("delta")
        ],
        "phi": [v for v in idata.posterior.data_vars if v.startswith("phi")],
        "weights": [
            v for v in idata.posterior.data_vars if v.startswith("weights")
        ],
    }

    nrows = 3 if compact else len(groups)
    fig, axes = plt.subplots(nrows, 2, figsize=(7, 3 * nrows))

    for row, (group_name, var_names) in enumerate(groups.items()):
        # One (trace, hist) axis pair per row; repeat the *references* so
        # the same pair can be indexed per-variable below.  BUGFIX: the
        # original non-compact branch kept the 1-D slice ``axes[row, :]``
        # and then indexed it with two subscripts (``group_axes[0, 0]``),
        # which raised IndexError — reshaping unconditionally fixes that.
        # ``max(len, 1)`` also guards against an empty group.
        group_axes = axes[row, :].reshape(1, 2)
        group_axes = np.repeat(group_axes, max(len(var_names), 1), axis=0)

        group_axes[0, 0].set_title(
            f"{group_name.capitalize()} Parameters", fontsize=14
        )

        for i, var in enumerate(var_names):
            # shape is (nchain, nsamples, ndim) if ndim>1 else (nchain, nsamples)
            data = idata.posterior[var].values
            if data.ndim == 3:
                data = data[0].T  # first chain only -> (ndim, nsamples)

            ax_trace = group_axes[i, 0]
            ax_hist = group_axes[i, 1]
            ax_trace.set_ylabel(group_name, fontsize=8)
            ax_trace.set_xlabel("MCMC Step", fontsize=8)
            ax_hist.set_xlabel(group_name, fontsize=8)
            # place ylabel on right side of hist
            ax_hist.yaxis.set_label_position("right")
            ax_hist.set_ylabel("Density", fontsize=8, rotation=270, labelpad=0)

            # remove axes yspine for hist
            ax_hist.spines["left"].set_visible(False)
            ax_hist.spines["right"].set_visible(False)
            ax_hist.spines["top"].set_visible(False)
            ax_hist.set_yticks([])  # remove y ticks
            ax_hist.yaxis.set_ticks_position("none")

            ax_trace.spines["right"].set_visible(False)
            ax_trace.spines["top"].set_visible(False)

            color = f"C{i}"
            label = f"{var}"
            if group_name in ["phi", "delta"]:
                ax_trace.set_yscale("log")
                ax_hist.set_xscale("log")

            for p in data:
                ax_trace.plot(p, color=color, alpha=0.7, label=label)

                # if phi or delta, use log scale for hist-x, log for trace y
                if group_name in ["phi", "delta"]:
                    bins = np.logspace(
                        np.log10(np.min(p)), np.log10(np.max(p)), 30
                    )
                    # KDE in log space, then transform back so the density
                    # is correct on the log-x axis
                    logp = np.log(p)
                    log_grid, log_pdf = az.kde(logp)
                    grid = np.exp(log_grid)
                    pdf = log_pdf / grid  # change of variables
                else:
                    bins = 30
                    grid, pdf = az.kde(p)
                ax_hist.plot(grid, pdf, color=color, label=label)
                ax_hist.hist(
                    p, bins=bins, density=True, color=color, alpha=0.3
                )

    plt.suptitle("Parameter Traces", fontsize=16)
    plt.tight_layout()
    return fig
117

118

119
def plot_diagnostics(
    idata: az.InferenceData,
    outdir: str,
    n_channels: Optional[int] = None,
    n_freq: Optional[int] = None,
    runtime: Optional[float] = None,
    config: Optional[DiagnosticsConfig] = None,
) -> None:
    """
    Create essential MCMC diagnostics in organized subdirectories.
    """
    # Nothing to do without an output directory.
    if outdir is None:
        return

    config = config if config is not None else DiagnosticsConfig()

    # All diagnostic artefacts live under <outdir>/diagnostics.
    diag_dir = os.path.join(outdir, "diagnostics")
    os.makedirs(diag_dir, exist_ok=True)

    logger.info("Generating MCMC diagnostics...")

    # Text report first, then the figure set.
    generate_diagnostics_summary(idata, diag_dir)
    _create_diagnostic_plots(
        idata, diag_dir, config, n_channels, n_freq, runtime
    )
147

148

149
def _create_diagnostic_plots(
    idata, diag_dir, config, n_channels, n_freq, runtime
):
    """Create only the essential diagnostic plots.

    Each plot is wrapped with ``safe_plot`` (a project decorator that,
    presumably, saves the current figure to the given path at ``config.dpi``
    and swallows plotting errors — confirm in ``plotting.base``), so a
    failure in one figure does not abort the rest.
    """
    logger.debug("Generating diagnostic plots...")

    # 1. ArviZ trace plots
    @safe_plot(f"{diag_dir}/trace_plots.png", config.dpi)
    def create_trace_plots():
        return plot_trace(idata)

    create_trace_plots()

    # 2. Summary dashboard with key convergence metrics
    @safe_plot(f"{diag_dir}/summary_dashboard.png", config.dpi)
    def plot_summary():
        _plot_summary_dashboard(idata, config, n_channels, n_freq, runtime)

    plot_summary()

    # 3. Log posterior diagnostics
    @safe_plot(f"{diag_dir}/log_posterior.png", config.dpi)
    def plot_lp():
        _plot_log_posterior(idata, config)

    plot_lp()

    # 4. Acceptance rate diagnostics
    @safe_plot(f"{diag_dir}/acceptance_diagnostics.png", config.dpi)
    def plot_acceptance():
        _plot_acceptance_diagnostics_blockaware(idata, config)

    plot_acceptance()

    # 5. Sampler-specific diagnostics
    _create_sampler_diagnostics(idata, diag_dir, config)

    # 6. Divergences diagnostics (for NUTS only)
    _create_divergences_diagnostics(idata, diag_dir, config)
188

189

190
def _plot_summary_dashboard(idata, config, n_channels, n_freq, runtime):
    """Render the 2x2 summary dashboard: ESS histogram, run metadata,
    parameter counts and an overall convergence verdict."""
    width, height = config.figsize
    fig, panels = plt.subplots(2, 2, figsize=(width * 0.8, height))

    # Extract finite ESS values once; shared by two panels below.
    finite_ess = None
    try:
        raw_ess = idata.attrs.get("ess")
        finite_ess = raw_ess[~np.isnan(raw_ess)]
    except Exception:
        pass  # panels handle a missing ESS array themselves

    # 1. ESS Distribution
    _plot_ess_histogram(panels[0, 0], finite_ess, config)

    # 2. Analysis Metadata
    _plot_metadata(panels[0, 1], idata, n_channels, n_freq, runtime)

    # 3. Parameter Summary
    _plot_parameter_summary(panels[1, 0], idata)

    # 4. Convergence Status
    _plot_convergence_status(panels[1, 1], finite_ess, config)

    plt.tight_layout()
222

223

224
def _plot_ess_histogram(ax, ess_values, config):
    """Plot ESS distribution with quality thresholds."""
    # Gracefully degrade when no ESS information is available.
    if ess_values is None or len(ess_values) == 0:
        ax.text(0.5, 0.5, "ESS unavailable", ha="center", va="center")
        ax.set_title("ESS Distribution")
        return

    ax.hist(ess_values, bins=30, alpha=0.7, edgecolor="black")

    ess_max = np.max(ess_values)
    # Vertical reference lines: rule-of-thumb quality cut-offs plus the
    # observed maximum (drawn thinner, dotted).
    reference_lines = (
        (400, "red", "--", "Minimum reliable"),
        (1000, "orange", "--", "Good"),
        (ess_max, "green", ":", f"Max = {ess_max:.0f}"),
    )
    for value, color, style, label in reference_lines:
        ax.axvline(
            x=value,
            color=color,
            linestyle=style,
            linewidth=2 if value < ess_max else 1,
            alpha=0.8,
            label=label,
        )

    ax.set_xlabel("ESS")
    ax.set_ylabel("Count")
    ax.set_title("ESS Distribution")
    ax.legend(loc="upper right", fontsize="x-small")
    ax.grid(True, alpha=0.3)

    # Inset box with min/mean and the share of parameters above threshold.
    pct_good = (ess_values >= config.ess_threshold).mean() * 100
    stats_text = f"Min: {ess_values.min():.0f}\nMean: {ess_values.mean():.0f}\n≥{config.ess_threshold}: {pct_good:.1f}%"
    ax.text(
        0.02,
        0.98,
        stats_text,
        transform=ax.transAxes,
        fontsize=10,
        verticalalignment="top",
        bbox=dict(boxstyle="round", facecolor="lightblue", alpha=0.7),
    )
269

270

271
def _plot_metadata(ax, idata, n_channels, n_freq, runtime):
    """Display analysis metadata as a text panel on ``ax``.

    ``n_channels``, ``n_freq`` and ``runtime`` are optional context values;
    ``None`` entries are simply omitted from the listing.
    """
    try:
        n_samples = idata.posterior.sizes.get("draw", 0)
        n_chains = idata.posterior.sizes.get("chain", 1)
        n_params = len(list(idata.posterior.data_vars))
        # Robustness fix: the original ``idata.attrs["sampler_type"]``
        # raised KeyError when the attribute was missing, which hid ALL
        # metadata behind the fallback text.  Use .get so the remaining
        # fields are still shown.
        sampler_type = idata.attrs.get("sampler_type", "unknown")

        metadata_lines = [
            f"Sampler: {sampler_type}",
            f"Samples: {n_samples} × {n_chains} chains",
            f"Parameters: {n_params}",
        ]
        if n_channels is not None:
            metadata_lines.append(f"Channels: {n_channels}")
        if n_freq is not None:
            metadata_lines.append(f"Frequencies: {n_freq}")
        if runtime is not None:
            metadata_lines.append(f"Runtime: {runtime:.2f}s")

        ax.text(
            0.05,
            0.95,
            "\n".join(metadata_lines),
            transform=ax.transAxes,
            fontsize=12,
            verticalalignment="top",
            fontfamily="monospace",
        )
    except Exception:
        # Any unexpected structure in idata falls back to a placeholder.
        ax.text(0.5, 0.5, "Metadata unavailable", ha="center", va="center")

    ax.set_title("Analysis Summary")
    ax.axis("off")
305

306

307
def _plot_parameter_summary(ax, idata):
    """Display parameter count summary."""
    try:
        grouped = _group_parameters_simple(idata)
        if grouped:
            # One line per non-empty group: "<group>: <count>".
            lines = ["Parameter Summary:"]
            for group_name, members in grouped.items():
                if members:
                    lines.append(f"{group_name}: {len(members)}")
            ax.text(
                0.05,
                0.95,
                "\n".join(lines),
                transform=ax.transAxes,
                fontsize=11,
                verticalalignment="top",
                fontfamily="monospace",
            )
    except Exception:
        ax.text(
            0.5,
            0.5,
            "Parameter summary\nunavailable",
            ha="center",
            va="center",
        )

    ax.set_title("Parameter Summary")
    ax.axis("off")
336

337

338
def _plot_convergence_status(ax, ess_values, config):
    """Display convergence status based on ESS only."""
    try:
        lines = ["Convergence Status:"]

        if ess_values is not None and len(ess_values) > 0:
            pct_ok = (ess_values >= config.ess_threshold).mean() * 100
            lines.append(f"ESS ≥ {config.ess_threshold}: {pct_ok:.0f}%")
            lines.append("")
            lines.append("Overall Status:")

            # Verdict tiers: >=90% excellent, >=75% adequate, else flagged.
            if pct_ok >= 90:
                verdict, color = "✓ EXCELLENT", "green"
            elif pct_ok >= 75:
                verdict, color = "✓ ADEQUATE", "orange"
            else:
                verdict, color = "⚠ NEEDS ATTENTION", "red"
            lines.append(verdict)
        else:
            lines.append("? UNABLE TO ASSESS")
            color = "gray"

        ax.text(
            0.05,
            0.95,
            "\n".join(lines),
            transform=ax.transAxes,
            fontsize=11,
            verticalalignment="top",
            fontfamily="monospace",
            color=color,
        )
    except Exception:
        ax.text(0.5, 0.5, "Status unavailable", ha="center", va="center")

    ax.set_title("Convergence Status")
    ax.axis("off")
379

380

381
def _plot_log_posterior(idata, config):
    """Log posterior diagnostics.

    Plots trace + running mean, distribution, step-to-step changes and
    summary statistics of ``lp`` (preferred) or ``log_likelihood`` from
    ``idata.sample_stats``.  Falls back to a single placeholder panel when
    neither statistic is present.
    """
    # Resolve the statistic BEFORE building the 2x2 grid; the original
    # created the grid first and then replaced it in the fallback branch,
    # leaking an unused figure.
    if "lp" in idata.sample_stats:
        lp_values = idata.sample_stats["lp"].values.flatten()
        title_prefix = "Log Posterior"
    elif "log_likelihood" in idata.sample_stats:
        lp_values = idata.sample_stats["log_likelihood"].values.flatten()
        title_prefix = "Log Likelihood"
    else:
        # Fallback layout when no posterior data available
        fig, ax = plt.subplots(1, 1, figsize=config.figsize)
        ax.text(
            0.5,
            0.5,
            "No log posterior\nor log likelihood\navailable",
            ha="center",
            va="center",
            fontsize=14,
        )
        ax.set_title("Log Posterior Diagnostics")
        ax.axis("off")
        plt.tight_layout()
        return

    fig, axes = plt.subplots(2, 2, figsize=config.figsize)

    # Trace plot with running mean overlaid
    axes[0, 0].plot(
        lp_values, alpha=0.7, linewidth=1, color="blue", label="Trace"
    )

    # Running mean window: ~1% of the chain, at least 10 steps.
    window_size = max(10, len(lp_values) // 100)
    if len(lp_values) > window_size:
        running_mean = np.convolve(
            lp_values, np.ones(window_size) / window_size, mode="valid"
        )
        # Offset the x-range so the mean is centred on its window.
        axes[0, 0].plot(
            range(window_size // 2, window_size // 2 + len(running_mean)),
            running_mean,
            alpha=0.9,
            linewidth=3,
            color="red",
            label=f"Running mean (w={window_size})",
        )

    axes[0, 0].set_xlabel("Iteration")
    axes[0, 0].set_ylabel(title_prefix)
    axes[0, 0].set_title(f"{title_prefix} Trace with Running Mean")
    axes[0, 0].legend(loc="best", fontsize="small")
    axes[0, 0].grid(True, alpha=0.3)

    # Distribution
    axes[0, 1].hist(
        lp_values, bins=50, alpha=0.7, density=True, edgecolor="black"
    )
    axes[0, 1].axvline(
        np.mean(lp_values),
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Mean: {np.mean(lp_values):.1f}",
    )
    axes[0, 1].set_xlabel(title_prefix)
    axes[0, 1].set_ylabel("Density")
    axes[0, 1].set_title(f"{title_prefix} Distribution")
    axes[0, 1].legend(loc="best", fontsize="small")
    axes[0, 1].grid(True, alpha=0.3)

    # Step-to-step changes
    lp_diff = np.diff(lp_values)
    axes[1, 0].plot(lp_diff, alpha=0.5, linewidth=1)
    axes[1, 0].axhline(0, color="red", linestyle="--", alpha=0.7)
    axes[1, 0].axhline(
        np.mean(lp_diff),
        color="blue",
        linestyle="--",
        alpha=0.7,
        label=f"Mean change: {np.mean(lp_diff):.1f}",
    )
    axes[1, 0].set_xlabel("Iteration")
    axes[1, 0].set_ylabel(f"{title_prefix} Difference")
    axes[1, 0].set_title("Step-to-Step Changes")
    axes[1, 0].legend(loc="best", fontsize="small")
    axes[1, 0].grid(True, alpha=0.3)

    # Summary statistics ("Final variation" = std of the last quarter)
    stats_lines = [
        f"Mean: {np.mean(lp_values):.2f}",
        f"Std: {np.std(lp_values):.2f}",
        f"Min: {np.min(lp_values):.2f}",
        f"Max: {np.max(lp_values):.2f}",
        f"Range: {np.max(lp_values) - np.min(lp_values):.2f}",
        "",
        "Stability:",
        f"Final variation: {np.std(lp_values[-len(lp_values)//4:]):.2f}",
    ]

    axes[1, 1].text(
        0.05,
        0.95,
        "\n".join(stats_lines),
        transform=axes[1, 1].transAxes,
        fontsize=10,
        verticalalignment="top",
        fontfamily="monospace",
        bbox=dict(boxstyle="round", facecolor="lightblue", alpha=0.8),
    )
    axes[1, 1].set_title("Posterior Statistics")
    axes[1, 1].axis("off")

    plt.tight_layout()
496

497

498
def _plot_acceptance_diagnostics(idata, config):
    """Acceptance rate diagnostics.

    Plots the acceptance-rate trace with colour-coded quality zones, its
    distribution, rolling variability and a text summary.  The target
    rate and healthy ranges depend on the sampler (NUTS vs MH).
    """
    accept_key = None
    if "accept_prob" in idata.sample_stats:
        accept_key = "accept_prob"
    elif "acceptance_rate" in idata.sample_stats:
        accept_key = "acceptance_rate"

    if accept_key is None:
        fig, ax = plt.subplots(figsize=config.figsize)
        ax.text(
            0.5,
            0.5,
            "Acceptance rate data unavailable",
            ha="center",
            va="center",
        )
        ax.set_title("Acceptance Rate Diagnostics")
        return

    fig, axes = plt.subplots(2, 2, figsize=config.figsize)

    accept_rates = idata.sample_stats[accept_key].values.flatten()
    sampler_type_attr = (
        idata.attrs["sampler_type"].lower()
        if "sampler_type" in idata.attrs
        else "unknown"
    )
    is_nuts = "nuts" in sampler_type_attr
    # BUGFIX: ``idata.attrs`` is a mapping, so the original
    # ``getattr(idata.attrs, "target_accept_rate", 0.44)`` ALWAYS returned
    # the 0.44 default (mappings expose no such attribute), which forced the
    # MH zones even for NUTS.  Use ``.get`` with sampler-aware defaults,
    # consistent with _plot_acceptance_diagnostics_blockaware.
    target_rate = idata.attrs.get(
        "target_accept_rate",
        idata.attrs.get("target_accept_prob", 0.8 if is_nuts else 0.44),
    )
    sampler_type = "NUTS" if is_nuts else "MH"

    # Define good ranges based on sampler
    if is_nuts:
        good_range = (0.7, 0.9)
        low_range = (0.0, 0.6)
        high_range = (0.9, 1.0)
        concerning_range = (0.6, 0.7)
    else:  # MH
        good_range = (0.2, 0.5)
        low_range = (0.0, 0.2)
        high_range = (0.5, 1.0)
        concerning_range = (0.1, 0.2)  # MH can be lower than NUTS

    # Trace plot with color zones
    axes[0, 0].axhspan(
        good_range[0],
        good_range[1],
        alpha=0.1,
        color="green",
        label=f"Good ({good_range[0]:.1f}-{good_range[1]:.1f})",
    )
    axes[0, 0].axhspan(
        low_range[0], low_range[1], alpha=0.1, color="red", label="Too low"
    )
    axes[0, 0].axhspan(
        high_range[0],
        high_range[1],
        alpha=0.1,
        color="orange",
        label="Too high",
    )
    if concerning_range[1] > concerning_range[0]:
        axes[0, 0].axhspan(
            concerning_range[0],
            concerning_range[1],
            alpha=0.1,
            color="yellow",
            label="Concerning",
        )

    # Main trace plot
    axes[0, 0].plot(
        accept_rates, alpha=0.8, linewidth=1, color="blue", label="Trace"
    )
    axes[0, 0].axhline(
        target_rate,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Target ({target_rate})",
    )

    # Add running average on the same plot (window ~2% of chain, min 10)
    window_size = max(10, len(accept_rates) // 50)
    if len(accept_rates) > window_size:
        running_mean = np.convolve(
            accept_rates, np.ones(window_size) / window_size, mode="valid"
        )
        axes[0, 0].plot(
            range(window_size // 2, window_size // 2 + len(running_mean)),
            running_mean,
            alpha=0.9,
            linewidth=3,
            color="purple",
            label=f"Running mean (w={window_size})",
        )

    axes[0, 0].set_xlabel("Iteration")
    axes[0, 0].set_ylabel("Acceptance Rate")
    axes[0, 0].set_title(f"{sampler_type} Acceptance Rate Trace")
    axes[0, 0].legend(loc="best", fontsize="small")
    axes[0, 0].grid(True, alpha=0.3)

    # Add interpretation text
    interpretation = f"{sampler_type} aims for {target_rate:.2f}."
    if is_nuts:
        interpretation += " Green: efficient sampling."
    else:
        interpretation += " MH adapts to find optimal rate."
    axes[0, 0].text(
        0.02,
        0.02,
        interpretation,
        transform=axes[0, 0].transAxes,
        fontsize=9,
        bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.7),
    )

    # Distribution
    axes[0, 1].hist(
        accept_rates, bins=30, alpha=0.7, density=True, edgecolor="black"
    )
    axes[0, 1].axvline(
        target_rate,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Target ({target_rate})",
    )
    axes[0, 1].set_xlabel("Acceptance Rate")
    axes[0, 1].set_ylabel("Density")
    axes[0, 1].set_title("Acceptance Rate Distribution")
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

    # Additional acceptance analysis - evolution over time
    if len(accept_rates) > 10:
        # Rolling standard deviation over a trailing 21-step window
        window_std = np.array(
            [
                np.std(accept_rates[max(0, i - 20) : i + 1])
                for i in range(len(accept_rates))
            ]
        )
        axes[1, 0].plot(window_std, alpha=0.7, color="green")
        axes[1, 0].set_xlabel("Iteration")
        axes[1, 0].set_ylabel("Rolling Std")
        axes[1, 0].set_title("Rolling Standard Deviation")
        axes[1, 0].grid(True, alpha=0.3)
    else:
        axes[1, 0].text(
            0.5,
            0.5,
            "Acceptance variability\nanalysis unavailable",
            ha="center",
            va="center",
        )
        axes[1, 0].set_title("Acceptance Stability")

    # Summary statistics (expanded)
    stats_text = [
        f"Sampler: {sampler_type}",
        f"Target: {target_rate:.3f}",
        f"Mean: {np.mean(accept_rates):.3f}",
        f"Std: {np.std(accept_rates):.3f}",
        f"CV: {np.std(accept_rates)/np.mean(accept_rates):.3f}",
        f"Min: {np.min(accept_rates):.3f}",
        f"Max: {np.max(accept_rates):.3f}",
        "",
        "Stability:",
        f"Final std: {np.std(accept_rates[-len(accept_rates)//4:]):.3f}",
    ]

    axes[1, 1].text(
        0.05,
        0.95,
        "\n".join(stats_text),
        transform=axes[1, 1].transAxes,
        fontsize=9,
        verticalalignment="top",
        fontfamily="monospace",
    )
    axes[1, 1].set_title("Acceptance Analysis")
    axes[1, 1].axis("off")

    plt.tight_layout()
687

688

689
def _plot_acceptance_diagnostics_blockaware(idata, config):
    """Acceptance diagnostics that also handle per‑channel series from blocked NUTS.

    If keys like ``accept_prob_channel_0`` are found in ``idata.sample_stats``,
    they are overlaid on the overall trace and included in the summary.
    """
    # Detect overall series
    accept_key = None
    if "accept_prob" in idata.sample_stats:
        accept_key = "accept_prob"
    elif "acceptance_rate" in idata.sample_stats:
        accept_key = "acceptance_rate"

    # Collect per-channel series
    channel_series = {}
    for key in idata.sample_stats:
        if isinstance(key, str) and key.startswith("accept_prob_channel_"):
            try:
                # Channel index is the last "_"-separated token.
                ch = int(key.rsplit("_", 1)[-1])
                channel_series[ch] = idata.sample_stats[key].values.flatten()
            except Exception:
                # Non-integer suffix: skip silently.
                pass

    # Neither an overall series nor any per-channel data: placeholder panel.
    if accept_key is None and not channel_series:
        fig, ax = plt.subplots(figsize=config.figsize)
        ax.text(
            0.5,
            0.5,
            "Acceptance rate data unavailable",
            ha="center",
            va="center",
        )
        ax.set_title("Acceptance Rate Diagnostics")
        return

    fig, axes = plt.subplots(2, 2, figsize=config.figsize)

    # Overall or concatenated series
    if accept_key is not None:
        accept_rates = idata.sample_stats[accept_key].values.flatten()
    else:
        accept_rates = np.concatenate(list(channel_series.values()))

    sampler_type_attr = idata.attrs.get("sampler_type", "").lower()
    is_nuts = "nuts" in sampler_type_attr
    # Target rate: explicit attr wins, else the conventional default
    # (0.8 for NUTS, 0.44 for MH).
    target_rate = idata.attrs.get(
        "target_accept_rate",
        idata.attrs.get("target_accept_prob", 0.8 if is_nuts else 0.44),
    )
    sampler_type = "NUTS" if is_nuts else "MH"

    # Ranges (healthy acceptance windows differ by sampler family)
    if is_nuts:
        good_range = (0.7, 0.9)
        low_range = (0.0, 0.6)
        high_range = (0.9, 1.0)
        concerning_range = (0.6, 0.7)
    else:
        good_range = (0.2, 0.5)
        low_range = (0.0, 0.2)
        high_range = (0.5, 1.0)
        concerning_range = (0.1, 0.2)

    # Background zones
    axes[0, 0].axhspan(good_range[0], good_range[1], alpha=0.1, color="green")
    axes[0, 0].axhspan(low_range[0], low_range[1], alpha=0.1, color="red")
    axes[0, 0].axhspan(high_range[0], high_range[1], alpha=0.1, color="orange")
    if concerning_range[1] > concerning_range[0]:
        axes[0, 0].axhspan(
            concerning_range[0], concerning_range[1], alpha=0.1, color="yellow"
        )

    # Plot traces: overall + per-channel overlays
    if accept_key is not None:
        axes[0, 0].plot(
            accept_rates, alpha=0.8, linewidth=1, color="blue", label="overall"
        )
    for ch in sorted(channel_series):
        axes[0, 0].plot(
            channel_series[ch], alpha=0.6, linewidth=1, label=f"ch {ch}"
        )
    axes[0, 0].axhline(
        target_rate,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Target ({target_rate})",
    )

    # Running mean for overall (window ~2% of chain, min 10 steps)
    window_size = max(10, len(accept_rates) // 50)
    if len(accept_rates) > window_size:
        running_mean = np.convolve(
            accept_rates, np.ones(window_size) / window_size, mode="valid"
        )
        # x-range offset centres the mean on its window.
        axes[0, 0].plot(
            range(window_size // 2, window_size // 2 + len(running_mean)),
            running_mean,
            alpha=0.9,
            linewidth=3,
            color="purple",
            label=f"Running mean (w={window_size})",
        )

    axes[0, 0].set_xlabel("Iteration")
    axes[0, 0].set_ylabel("Acceptance Rate")
    axes[0, 0].set_title(f"{sampler_type} Acceptance Rate Trace")
    axes[0, 0].legend(loc="best", fontsize="small")
    axes[0, 0].grid(True, alpha=0.3)

    # Histogram and summary
    axes[0, 1].hist(
        accept_rates, bins=30, alpha=0.7, density=True, edgecolor="black"
    )
    axes[0, 1].axvline(
        target_rate,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Target ({target_rate})",
    )
    axes[0, 1].set_xlabel("Acceptance Rate")
    axes[0, 1].set_ylabel("Density")
    axes[0, 1].set_title("Acceptance Rate Distribution")
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

    # Rolling std over a trailing 21-step window (stability view)
    if len(accept_rates) > 10:
        window_std = np.array(
            [
                np.std(accept_rates[max(0, i - 20) : i + 1])
                for i in range(len(accept_rates))
            ]
        )
        axes[1, 0].plot(window_std, alpha=0.7, color="green")
        axes[1, 0].set_xlabel("Iteration")
        axes[1, 0].set_ylabel("Rolling Std")
        axes[1, 0].set_title("Rolling Standard Deviation")
        axes[1, 0].grid(True, alpha=0.3)
    else:
        axes[1, 0].text(
            0.5,
            0.5,
            "Acceptance variability\nanalysis unavailable",
            ha="center",
            va="center",
        )
        axes[1, 0].set_title("Acceptance Stability")

    # Summary text
    stats_text = [
        f"Sampler: {sampler_type}",
        f"Target: {target_rate:.3f}",
        f"Mean: {np.mean(accept_rates):.3f}",
        f"Std: {np.std(accept_rates):.3f}",
        f"CV: {np.std(accept_rates)/np.mean(accept_rates):.3f}",
        f"Min: {np.min(accept_rates):.3f}",
        f"Max: {np.max(accept_rates):.3f}",
    ]
    if channel_series:
        stats_text.append("")
        stats_text.append("Per-channel means:")
        for ch in sorted(channel_series):
            stats_text.append(f"  ch {ch}: {np.mean(channel_series[ch]):.3f}")

    axes[1, 1].text(
        0.05,
        0.95,
        "\n".join(stats_text),
        transform=axes[1, 1].transAxes,
        fontsize=9,
        va="top",
        family="monospace",
    )
    axes[1, 1].set_title("Acceptance Analysis")
    axes[1, 1].axis("off")

    plt.tight_layout()
867

868

869
def _get_channel_indices(sample_stats, base_key: str) -> set:
2✔
870
    """Return set of channel indices for the given ``base_key`` prefix."""
871

872
    prefix = f"{base_key}_channel_"
2✔
873
    indices = set()
2✔
874
    for key in sample_stats:
2✔
875
        if isinstance(key, str) and key.startswith(prefix):
2✔
876
            try:
2✔
877
                indices.add(int(key.replace(prefix, "")))
2✔
878
            except Exception:
×
879
                continue
×
880
    return indices
2✔
881

882

883
def _plot_nuts_diagnostics_blockaware(idata, config):
    """NUTS diagnostics supporting per-channel (blocked) diagnostics fields.

    Builds a 2x2 figure: energy traces, leapfrog-step histogram, acceptance
    traces, and a monospace text summary. Overlays per-channel series when
    keys like ``energy_channel_{j}`` or ``num_steps_channel_{j}`` are present
    in ``idata.sample_stats``, in addition to (or instead of) the overall
    ``energy`` / ``potential_energy`` / ``num_steps`` / ``accept_prob``
    arrays.

    Parameters
    ----------
    idata : arviz.InferenceData
        Posterior samples with a ``sample_stats`` group.
    config : DiagnosticsConfig
        Supplies the figure size (``figsize``).
    """
    # Presence of overall (non-blocked) diagnostic arrays.
    has_energy = "energy" in idata.sample_stats
    has_potential = "potential_energy" in idata.sample_stats
    has_steps = "num_steps" in idata.sample_stats
    has_accept = "accept_prob" in idata.sample_stats

    # Gather per-channel arrays keyed by integer channel index.
    def _collect(base):
        out = {}
        prefix = f"{base}_channel_"
        for key in idata.sample_stats:
            if isinstance(key, str) and key.startswith(prefix):
                try:
                    ch = int(key.replace(prefix, ""))
                    out[ch] = idata.sample_stats[key].values.flatten()
                except Exception:
                    # Best-effort: skip malformed keys or unreadable arrays.
                    pass
        return out

    energy_ch = _collect("energy")
    potential_ch = _collect("potential_energy")
    steps_ch = _collect("num_steps")
    accept_ch = _collect("accept_prob")

    fig, axes = plt.subplots(2, 2, figsize=config.figsize)

    # Top-left: Hamiltonian ("H") and potential ("P") energy traces,
    # overall series first, then per-channel overlays.
    ax = axes[0, 0]
    plotted = False
    if has_energy:
        ax.plot(
            idata.sample_stats.energy.values.flatten(),
            alpha=0.7,
            lw=1,
            label="H",
        )
        plotted = True
    if has_potential:
        ax.plot(
            idata.sample_stats.potential_energy.values.flatten(),
            alpha=0.7,
            lw=1,
            label="P",
        )
        plotted = True
    for ch in sorted(energy_ch):
        ax.plot(energy_ch[ch], alpha=0.5, lw=1, label=f"H ch {ch}")
        plotted = True
    for ch in sorted(potential_ch):
        ax.plot(potential_ch[ch], alpha=0.5, lw=1, label=f"P ch {ch}")
        plotted = True
    if not plotted:
        ax.text(0.5, 0.5, "Energy data\nunavailable", ha="center", va="center")
    ax.set_title("Energy Diagnostics")
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Energy")
    ax.grid(True, alpha=0.3)
    if plotted:
        ax.legend(loc="best", fontsize="small")

    # Top-right: histogram of leapfrog steps; falls back to pooling the
    # per-channel arrays when there is no overall `num_steps`.
    ax = axes[0, 1]
    if has_steps:
        vals = idata.sample_stats.num_steps.values.flatten()
    else:
        vals = (
            np.concatenate(list(steps_ch.values()))
            if steps_ch
            else np.array([])
        )
    if vals.size:
        ax.hist(vals, bins=20, alpha=0.7, edgecolor="black")
        ax.set_title("Leapfrog Steps Distribution")
        ax.set_xlabel("Steps")
        ax.set_ylabel("Trajectories")
        ax.grid(True, alpha=0.3)
    else:
        ax.text(0.5, 0.5, "Steps data\nunavailable", ha="center", va="center")

    # Bottom-left: acceptance-probability traces (overall + per channel).
    ax = axes[1, 0]
    plotted = False
    if has_accept:
        ax.plot(
            idata.sample_stats.accept_prob.values.flatten(),
            alpha=0.8,
            lw=1,
            label="overall",
        )
        plotted = True
    for ch in sorted(accept_ch):
        ax.plot(accept_ch[ch], alpha=0.6, lw=1, label=f"ch {ch}")
        plotted = True
    if not plotted:
        ax.text(
            0.5, 0.5, "Acceptance data\nunavailable", ha="center", va="center"
        )
    ax.set_title("Acceptance Trace")
    ax.set_xlabel("Iteration")
    ax.set_ylabel("accept_prob")
    ax.grid(True, alpha=0.3)
    if plotted:
        ax.legend(loc="best", fontsize="small")

    # Bottom-right: text summary of steps and acceptance statistics.
    ax = axes[1, 1]
    lines = []
    if has_steps or steps_ch:
        lines.append("Steps summary:")
        if has_steps:
            s = idata.sample_stats.num_steps.values.flatten()
            lines.append(f"  overall μ={np.mean(s):.1f}, max={np.max(s):.0f}")
        for ch in sorted(steps_ch):
            s = steps_ch[ch]
            lines.append(f"  ch {ch} μ={np.mean(s):.1f}, max={np.max(s):.0f}")
        lines.append("")
    if has_accept or accept_ch:
        lines.append("Acceptance summary:")
        if has_accept:
            a = idata.sample_stats.accept_prob.values.flatten()
            lines.append(f"  overall μ={np.mean(a):.3f}")
        for ch in sorted(accept_ch):
            a = accept_ch[ch]
            lines.append(f"  ch {ch} μ={np.mean(a):.3f}")
    if lines:
        ax.text(
            0.05,
            0.95,
            "\n".join(lines),
            transform=ax.transAxes,
            va="top",
            family="monospace",
        )
    ax.set_title("NUTS Diagnostics Summary")
    ax.axis("off")

    plt.tight_layout()
2✔
1026

1027

1028
def _plot_single_nuts_block(idata, config, channel_idx: int):
    """NUTS diagnostics for a single blocked channel.

    Builds a 2x2 figure for one channel of a blocked sampler: energy
    traces, acceptance trace with target bands, leapfrog-step histogram,
    and a text summary. Each diagnostic is read from
    ``{name}_channel_{channel_idx}`` in ``idata.sample_stats`` and is
    simply omitted from the figure when absent.

    Parameters
    ----------
    idata : arviz.InferenceData
        Posterior samples with a ``sample_stats`` group.
    config : DiagnosticsConfig
        Supplies the figure size (``figsize``).
    channel_idx : int
        Index of the blocked channel to plot.
    """

    # Fetch a flattened per-channel array, or None when the key is missing.
    def _get(key):
        full_key = f"{key}_channel_{channel_idx}"
        return (
            idata.sample_stats[full_key].values.flatten()
            if full_key in idata.sample_stats
            else None
        )

    energy = _get("energy")
    potential = _get("potential_energy")
    num_steps = _get("num_steps")
    accept_prob = _get("accept_prob")

    fig, axes = plt.subplots(2, 2, figsize=config.figsize)

    # Top-left: Hamiltonian ("H") and potential ("P") energy traces.
    ax = axes[0, 0]
    plotted = False
    if energy is not None:
        ax.plot(energy, alpha=0.7, lw=1, label="H")
        plotted = True
    if potential is not None:
        ax.plot(potential, alpha=0.7, lw=1, label="P")
        plotted = True
    if not plotted:
        ax.text(0.5, 0.5, "Energy data\nunavailable", ha="center", va="center")
    ax.set_title(f"Channel {channel_idx} Energy")
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Energy")
    ax.grid(True, alpha=0.3)
    if plotted:
        ax.legend(loc="best", fontsize="small")

    # Top-right: acceptance trace with shaded guidance zones
    # (green: good 0.7-0.9, red: too low, orange: too high) and the 0.8
    # target line.
    ax = axes[0, 1]
    if accept_prob is not None:
        ax.axhspan(0.7, 0.9, alpha=0.1, color="green")
        ax.axhspan(0.0, 0.6, alpha=0.1, color="red")
        ax.axhspan(0.9, 1.0, alpha=0.1, color="orange")
        ax.plot(accept_prob, alpha=0.8, lw=1, color="purple")
        ax.axhline(0.8, color="red", linestyle="--", lw=1.5, label="target")
        ax.set_ylim(0, 1)
        ax.legend(loc="best", fontsize="small")
        ax.grid(True, alpha=0.3)
    else:
        ax.text(
            0.5, 0.5, "Acceptance data\nunavailable", ha="center", va="center"
        )
    ax.set_title(f"Channel {channel_idx} Acceptance")
    ax.set_xlabel("Iteration")
    ax.set_ylabel("accept_prob")

    # Bottom-left: histogram of leapfrog steps per trajectory.
    ax = axes[1, 0]
    if num_steps is not None and num_steps.size:
        ax.hist(num_steps, bins=20, alpha=0.7, edgecolor="black")
        ax.set_xlabel("Steps")
        ax.set_ylabel("Trajectories")
        ax.grid(True, alpha=0.3)
    else:
        ax.text(0.5, 0.5, "Steps data\nunavailable", ha="center", va="center")
    ax.set_title(f"Channel {channel_idx} Leapfrog Steps")

    # Bottom-right: monospace text summary of whichever diagnostics exist.
    ax = axes[1, 1]
    stats_lines = [f"Channel {channel_idx} summary:"]
    if energy is not None:
        stats_lines.append(
            f"  H μ={np.mean(energy):.2f}, σ={np.std(energy):.2f}"
        )
    if potential is not None:
        stats_lines.append(
            f"  P μ={np.mean(potential):.2f}, σ={np.std(potential):.2f}"
        )
    if num_steps is not None:
        stats_lines.append(
            f"  steps μ={np.mean(num_steps):.1f}, max={np.max(num_steps):.0f}"
        )
    if accept_prob is not None:
        stats_lines.append(f"  accept μ={np.mean(accept_prob):.3f}")

    ax.text(
        0.05,
        0.95,
        "\n".join(stats_lines),
        transform=ax.transAxes,
        va="top",
        family="monospace",
    )
    ax.axis("off")
    ax.set_title("Summary")

    plt.tight_layout()
2✔
1124

1125

1126
def _create_sampler_diagnostics(idata, diag_dir, config):
    """Create sampler-specific diagnostics plots in ``diag_dir``.

    Detects whether the run used NUTS or Metropolis-Hastings and writes
    the matching diagnostic figures via the ``safe_plot`` decorator:
    ``nuts_diagnostics.png`` (plus one ``nuts_block_{j}_diagnostics.png``
    per blocked channel) for NUTS, or ``mh_step_sizes.png`` for MH.

    Parameters
    ----------
    idata : arviz.InferenceData
        Posterior samples with ``sample_stats`` and (optionally) a
        ``sampler_type`` attribute in ``idata.attrs``.
    diag_dir : str
        Output directory for the PNG files.
    config : DiagnosticsConfig
        Supplies figure size and DPI.
    """

    # Prefer the recorded sampler type; fall back to "unknown" so the
    # field-based detection below still works.
    sampler_type = (
        idata.attrs["sampler_type"].lower()
        if "sampler_type" in idata.attrs
        else "unknown"
    )

    # Fields that only NUTS produces (MH never writes these).
    nuts_specific_fields = [
        "energy",
        "num_steps",
        "tree_depth",
        "diverging",
        "energy_error",
    ]

    has_nuts = (
        any(field in idata.sample_stats for field in nuts_specific_fields)
        or "nuts" in sampler_type
    )

    # MH detection is subordinate to NUTS detection: `step_size_mean`
    # alone is only treated as MH when no NUTS evidence was found.
    has_mh = "step_size_mean" in idata.sample_stats and not has_nuts

    if has_nuts:

        @safe_plot(f"{diag_dir}/nuts_diagnostics.png", config.dpi)
        def plot_nuts():
            _plot_nuts_diagnostics_blockaware(idata, config)

        plot_nuts()

        # Per-channel NUTS diagnostics for blocked samplers: union of the
        # channel indices seen across every per-channel diagnostic key.
        channel_indices = _get_channel_indices(
            idata.sample_stats, "accept_prob"
        )
        channel_indices |= _get_channel_indices(idata.sample_stats, "energy")
        channel_indices |= _get_channel_indices(
            idata.sample_stats, "potential_energy"
        )
        channel_indices |= _get_channel_indices(
            idata.sample_stats, "num_steps"
        )

        for channel_idx in sorted(channel_indices):

            @safe_plot(
                f"{diag_dir}/nuts_block_{channel_idx}_diagnostics.png",
                config.dpi,
            )
            # Default argument binds the current loop value, avoiding the
            # late-binding closure pitfall.
            def plot_nuts_block(channel_idx=channel_idx):
                _plot_single_nuts_block(idata, config, channel_idx)

            plot_nuts_block()
    elif has_mh:

        @safe_plot(f"{diag_dir}/mh_step_sizes.png", config.dpi)
        def plot_mh():
            _plot_mh_step_sizes(idata, config)

        plot_mh()
2✔
1190

1191

1192
def _plot_nuts_diagnostics(idata, config):
    """NUTS diagnostics with enhanced information.

    Builds a single 2x2 figure from whatever NUTS fields are present in
    ``idata.sample_stats``: energy traces (top-left), leapfrog-step
    efficiency histogram (top-right), acceptance-probability trace with
    guidance zones (bottom-left), and a text summary panel (bottom-right).
    Panels degrade to an "unavailable" placeholder when their data are
    missing.

    Parameters
    ----------
    idata : arviz.InferenceData
        Posterior samples with a ``sample_stats`` group.
    config : DiagnosticsConfig
        Supplies the figure size (``figsize``).
    """
    # Determine available data to decide what each panel can show.
    has_energy = "energy" in idata.sample_stats
    has_potential = "potential_energy" in idata.sample_stats
    has_steps = "num_steps" in idata.sample_stats
    has_accept = "accept_prob" in idata.sample_stats
    has_divergences = "diverging" in idata.sample_stats
    has_tree_depth = "tree_depth" in idata.sample_stats
    has_energy_error = "energy_error" in idata.sample_stats

    # 2x2 layout; energy and potential share the top-left panel when both
    # are available.
    fig, axes = plt.subplots(2, 2, figsize=config.figsize)

    # Top-left: energy diagnostics.
    energy_ax = axes[0, 0]

    if has_energy and has_potential:
        # Both available - plot them together on one panel.
        energy = idata.sample_stats.energy.values.flatten()
        potential = idata.sample_stats.potential_energy.values.flatten()

        energy_ax.plot(
            energy, alpha=0.7, linewidth=1, color="blue", label="Hamiltonian"
        )
        energy_ax.plot(
            potential,
            alpha=0.7,
            linewidth=1,
            color="orange",
            label="Potential",
        )

        # H - potential (relates to the kinetic term) goes on a twin
        # y-axis since its scale differs from the raw energies.
        energy_diff = energy - potential
        ax2 = energy_ax.twinx()
        ax2.plot(
            energy_diff,
            alpha=0.5,
            linewidth=1,
            color="red",
            label="H - Potential (Kinetic)",
            linestyle="--",
        )
        ax2.set_ylabel("Energy Difference", color="red")
        ax2.tick_params(axis="y", labelcolor="red")

        energy_ax.set_xlabel("Iteration")
        energy_ax.set_ylabel("Energy", color="blue")
        energy_ax.tick_params(axis="y", labelcolor="blue")
        energy_ax.set_title("Hamiltonian & Potential Energy")
        energy_ax.legend(loc="best", fontsize="small")
        energy_ax.grid(True, alpha=0.3)

        # Inline mean/std annotation for both series.
        energy_ax.text(
            0.02,
            0.98,
            f"H: μ={np.mean(energy):.1f}, σ={np.std(energy):.1f}\nP: μ={np.mean(potential):.1f}, σ={np.std(potential):.1f}",
            transform=energy_ax.transAxes,
            fontsize=8,
            bbox=dict(boxstyle="round", facecolor="lightblue", alpha=0.8),
            verticalalignment="top",
        )

    elif has_energy:
        # Only Hamiltonian energy available.
        energy = idata.sample_stats.energy.values.flatten()
        energy_ax.plot(energy, alpha=0.7, linewidth=1, color="blue")
        energy_ax.set_xlabel("Iteration")
        energy_ax.set_ylabel("Hamiltonian Energy")
        energy_ax.set_title("Hamiltonian Energy Trace")
        energy_ax.grid(True, alpha=0.3)

    elif has_potential:
        # Only potential energy available.
        potential = idata.sample_stats.potential_energy.values.flatten()
        energy_ax.plot(potential, alpha=0.7, linewidth=1, color="orange")
        energy_ax.set_xlabel("Iteration")
        energy_ax.set_ylabel("Potential Energy")
        energy_ax.set_title("Potential Energy Trace")
        energy_ax.grid(True, alpha=0.3)

    else:
        energy_ax.text(
            0.5,
            0.5,
            "Energy data\nunavailable",
            ha="center",
            va="center",
            transform=energy_ax.transAxes,
        )
        energy_ax.set_title("Energy Diagnostics")

    # Top-right: sampling-efficiency diagnostics.
    if has_steps:
        steps_ax = axes[0, 1]
        num_steps = idata.sample_stats.num_steps.values.flatten()

        # Histogram with color zones for step efficiency.
        # NOTE(review): `plt.hist` returns (counts, bin_edges, patches);
        # the third name `edges` is misleading but unused.
        n, bins, edges = steps_ax.hist(
            num_steps, bins=20, alpha=0.7, edgecolor="black"
        )

        # Shaded efficiency regions:
        # green: efficient (tree depth ≤5, ~32 steps)
        # yellow: moderate (tree depth 6-8, ~64-256 steps)
        # red: inefficient (tree depth >8, >256 steps)
        steps_ax.axvspan(
            0, 64, alpha=0.1, color="green", label="Efficient (≤64)"
        )
        steps_ax.axvspan(
            64, 256, alpha=0.1, color="yellow", label="Moderate (65-256)"
        )
        # NOTE(review): if max(num_steps) < 256 this span is inverted and
        # draws nothing useful — harmless but worth confirming upstream.
        steps_ax.axvspan(
            256,
            np.max(num_steps),
            alpha=0.1,
            color="red",
            label="Inefficient (>256)",
        )

        # Reference lines at the step counts implied by common maximum
        # tree depths (2**depth leapfrog steps).
        for depth in [5, 7, 10]:  # Common tree depths
            max_steps = 2**depth
            steps_ax.axvline(
                x=max_steps,
                color="gray",
                linestyle=":",
                alpha=0.7,
                linewidth=1,
                label=f"2^{depth} ({max_steps})",
            )

        steps_ax.set_xlabel("Leapfrog Steps")
        steps_ax.set_ylabel("Trajectories")
        steps_ax.set_title("Leapfrog Steps Distribution")
        steps_ax.legend(loc="best", fontsize="small")
        steps_ax.grid(True, alpha=0.3)

        # Percentage of trajectories in each efficiency zone.
        pct_inefficient = (num_steps > 256).mean() * 100
        pct_moderate = ((num_steps > 64) & (num_steps <= 256)).mean() * 100
        pct_efficient = (num_steps <= 64).mean() * 100
        steps_ax.text(
            0.02,
            0.98,
            f"Efficient: {pct_efficient:.1f}%\nModerate: {pct_moderate:.1f}%\nInefficient: {pct_inefficient:.1f}%\nMean steps: {np.mean(num_steps):.1f}",
            transform=steps_ax.transAxes,
            fontsize=7,
            bbox=dict(boxstyle="round", facecolor="lightblue", alpha=0.8),
            verticalalignment="top",
        )

    else:
        axes[0, 1].text(
            0.5, 0.5, "Steps data\nunavailable", ha="center", va="center"
        )
        axes[0, 1].set_title("Sampling Steps")

    # Bottom-left: acceptance-probability diagnostics.
    accept_ax = axes[1, 0]

    if has_accept:
        accept_prob = idata.sample_stats.accept_prob.values.flatten()

        # Shaded guidance zones: good (0.7-0.9), too low (<0.6),
        # too high (>0.9).
        accept_ax.fill_between(
            range(len(accept_prob)),
            0.7,
            0.9,
            alpha=0.1,
            color="green",
            label="Good (0.7-0.9)",
        )
        accept_ax.fill_between(
            range(len(accept_prob)),
            0,
            0.6,
            alpha=0.1,
            color="red",
            label="Too low",
        )
        accept_ax.fill_between(
            range(len(accept_prob)),
            0.9,
            1.0,
            alpha=0.1,
            color="orange",
            label="Too high",
        )

        accept_ax.plot(
            accept_prob,
            alpha=0.8,
            linewidth=1,
            color="blue",
            label="Acceptance prob",
        )
        accept_ax.axhline(
            0.8,
            color="red",
            linestyle="--",
            linewidth=2,
            label="NUTS target (0.8)",
        )
        accept_ax.set_xlabel("Iteration")
        accept_ax.set_ylabel("Acceptance Probability")
        accept_ax.set_title("NUTS Acceptance Diagnostic")
        accept_ax.legend(loc="best", fontsize="small")
        accept_ax.set_ylim(0, 1)
        accept_ax.grid(True, alpha=0.3)

    else:
        accept_ax.text(
            0.5, 0.5, "Acceptance data\nunavailable", ha="center", va="center"
        )
        accept_ax.set_title("Acceptance Diagnostic")

    # Bottom-right: summary statistics and additional diagnostics
    # (tree depth, divergences, energy error) as a text panel.
    summary_ax = axes[1, 1]

    stats_lines = []

    if has_energy:
        energy = idata.sample_stats.energy.values.flatten()
        stats_lines.append(
            f"Energy: μ={np.mean(energy):.1f}, σ={np.std(energy):.1f}"
        )

    if has_potential:
        potential = idata.sample_stats.potential_energy.values.flatten()
        stats_lines.append(
            f"Potential: μ={np.mean(potential):.1f}, σ={np.std(potential):.1f}"
        )

    if has_steps:
        num_steps = idata.sample_stats.num_steps.values.flatten()
        stats_lines.append(
            f"Steps: μ={np.mean(num_steps):.1f}, max={np.max(num_steps):.0f}"
        )
        stats_lines.append("")

    if has_tree_depth:
        tree_depth = idata.sample_stats.tree_depth.values.flatten()
        stats_lines.append(f"Tree depth: μ={np.mean(tree_depth):.1f}")
        pct_max_depth = (tree_depth >= 10).mean() * 100
        stats_lines.append(f"Max depth (≥10): {pct_max_depth:.1f}%")

    if has_divergences:
        divergences = idata.sample_stats.diverging.values.flatten()
        n_divergences = np.sum(divergences)
        pct_divergent = n_divergences / len(divergences) * 100
        stats_lines.append(
            f"Divergent: {n_divergences}/{len(divergences)} ({pct_divergent:.2f}%)"
        )

    if has_energy_error:
        energy_error = idata.sample_stats.energy_error.values.flatten()
        stats_lines.append(
            f"Energy error: |μ|={np.mean(np.abs(energy_error)):.3f}"
        )

    if not stats_lines:
        summary_ax.text(
            0.5,
            0.5,
            "No diagnostics\ndata available",
            ha="center",
            va="center",
            transform=summary_ax.transAxes,
        )
        summary_ax.set_title("NUTS Statistics")
        summary_ax.axis("off")
    else:
        summary_text = "\n".join(["NUTS Diagnostics:"] + [""] + stats_lines)
        summary_ax.text(
            0.05,
            0.95,
            summary_text,
            transform=summary_ax.transAxes,
            fontsize=10,
            verticalalignment="top",
            fontfamily="monospace",
            bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.9),
        )
        summary_ax.set_title("NUTS Summary Statistics")
        summary_ax.axis("off")

    plt.tight_layout()
×
1485

1486

1487
def _plot_mh_step_sizes(idata, config):
    """Metropolis-Hastings step-size adaptation diagnostics.

    Four panels: evolution of the per-iteration step-size mean/std, their
    histograms, the mean/std ratio over time, and a monospace text summary
    of adaptation quality.

    Parameters
    ----------
    idata : arviz.InferenceData
        Posterior samples whose ``sample_stats`` contains
        ``step_size_mean`` and ``step_size_std``.
    config : DiagnosticsConfig
        Supplies the figure size (``figsize``).
    """
    fig, panels = plt.subplots(2, 2, figsize=config.figsize)

    means = idata.sample_stats.step_size_mean.values.flatten()
    stds = idata.sample_stats.step_size_std.values.flatten()

    # Step size evolution over iterations.
    evo = panels[0, 0]
    evo.plot(means, alpha=0.7, linewidth=1, label="Mean", color="blue")
    evo.plot(stds, alpha=0.7, linewidth=1, label="Std", color="orange")
    evo.set_xlabel("Iteration")
    evo.set_ylabel("Step Size")
    evo.set_title("Step Size Evolution")
    evo.legend()
    evo.grid(True, alpha=0.3)

    # Marginal distributions of the two series.
    dist = panels[0, 1]
    dist.hist(means, bins=30, alpha=0.5, label="Mean", color="blue")
    dist.hist(stds, bins=30, alpha=0.5, label="Std", color="orange")
    dist.set_xlabel("Step Size")
    dist.set_ylabel("Count")
    dist.set_title("Step Size Distributions")
    dist.legend()
    dist.grid(True, alpha=0.3)

    # Mean/std ratio as a rough adaptation-consistency signal.
    ratio = panels[1, 0]
    ratio.plot(means / stds, alpha=0.7, linewidth=1)
    ratio.set_xlabel("Iteration")
    ratio.set_ylabel("Mean / Std")
    ratio.set_title("Step Size Consistency")
    ratio.grid(True, alpha=0.3)

    # Text panel with summary statistics.
    report = [
        "Step Size Summary:",
        f"Final mean: {means[-1]:.4f}",
        f"Final std: {stds[-1]:.4f}",
        f"Mean of means: {np.mean(means):.4f}",
        f"Mean of stds: {np.mean(stds):.4f}",
        "",
        "Adaptation Quality:",
        f"CV of means: {np.std(means)/np.mean(means):.3f}",
        f"CV of stds: {np.std(stds)/np.mean(stds):.3f}",
    ]
    info = panels[1, 1]
    info.text(
        0.05,
        0.95,
        "\n".join(report),
        transform=info.transAxes,
        fontsize=10,
        verticalalignment="top",
        fontfamily="monospace",
    )
    info.set_title("Step Size Statistics")
    info.axis("off")

    plt.tight_layout()
2✔
1549

1550

1551
def _create_divergences_diagnostics(idata, diag_dir, config):
2✔
1552
    """Create divergences diagnostics for NUTS samplers."""
1553
    # Check if divergences data exists
1554
    has_divergences = "diverging" in idata.sample_stats
2✔
1555
    has_channel_divergences = any(
2✔
1556
        key.startswith("diverging_channel_") for key in idata.sample_stats
1557
    )
1558

1559
    if not has_divergences and not has_channel_divergences:
2✔
1560
        return  # Nothing to plot
2✔
1561

1562
    @safe_plot(f"{diag_dir}/divergences.png", config.dpi)
2✔
1563
    def plot_divergences():
2✔
1564
        _plot_divergences(idata, config)
2✔
1565

1566
    plot_divergences()
2✔
1567

1568

1569
def _plot_divergences(idata, config):
    """Plot divergences diagnostics.

    Draws one trace panel per divergence series — "main" for single-chain
    NUTS plus one per blocked channel — marking the iterations at which
    divergent transitions occurred, and a text summary panel when the
    layout has room for one.

    Parameters
    ----------
    idata : arviz.InferenceData
        Posterior samples whose ``sample_stats`` may contain ``diverging``
        and/or ``diverging_channel_{j}`` arrays.
    config : DiagnosticsConfig
        Supplies the figure size (``figsize``).
    """
    # Collect all divergence data keyed by "main" or the channel index.
    divergences_data = {}

    # Main divergences (single chain NUTS).
    if "diverging" in idata.sample_stats:
        divergences_data["main"] = (
            idata.sample_stats.diverging.values.flatten()
        )

    # Channel-specific divergences (blocked NUTS).
    channel_divergences = {}
    for key in idata.sample_stats:
        if key.startswith("diverging_channel_"):
            channel_idx = key.replace("diverging_channel_", "")
            channel_divergences[int(channel_idx)] = idata.sample_stats[
                key
            ].values.flatten()

    if channel_divergences:
        divergences_data.update(channel_divergences)

    if not divergences_data:
        fig, ax = plt.subplots(figsize=config.figsize)
        ax.text(
            0.5, 0.5, "No divergence data available", ha="center", va="center"
        )
        ax.set_title("Divergences Diagnostics")
        return

    # Choose the subplot layout: a 1x2 (trace + summary) for a single
    # series, otherwise a 2-column grid with the summary in the spare
    # cell when the series count is odd.
    n_plots = len(divergences_data)
    if n_plots == 1:
        fig, axes = plt.subplots(1, 2, figsize=config.figsize)
        trace_ax, summary_ax = axes
    else:
        cols = 2
        rows = (n_plots + 1) // cols  # Ceiling division
        fig, axes = plt.subplots(rows, cols, figsize=config.figsize)
        if rows == 1:
            axes = axes.reshape(1, -1)
        axes = axes.flatten()

        # Last cell hosts the summary when the series count is odd.
        if n_plots % 2 == 1:
            trace_axes = axes[:-1]
            summary_ax = axes[-1]
        else:
            trace_axes = axes
            summary_ax = None

    # Plot one divergence trace per series.
    total_divergences = 0
    total_iterations = 0

    plot_idx = 0
    for label, div_values in divergences_data.items():
        if label == "main":
            title = "NUTS Divergences"
        else:
            title = f"Channel {label} Divergences"
        ax = trace_axes[plot_idx] if n_plots > 1 else axes[0]
        # BUG FIX: advance the axis index for *every* series. Previously
        # only channel series advanced it, so "main" and the first channel
        # were drawn onto the same axis when both were present.
        plot_idx += 1

        # Mark iterations where divergences occurred.
        div_indices = np.where(div_values)[0]
        ax.scatter(
            div_indices,
            np.ones_like(div_indices),
            color="red",
            marker="x",
            s=50,
            linewidth=2,
            label="Divergent",
            alpha=0.8,
        )

        # Background shading around each divergent iteration.
        if len(div_indices) > 0:
            for idx in div_indices:
                ax.axvspan(idx - 0.5, idx + 0.5, alpha=0.2, color="red")

        ax.set_xlabel("Iteration")
        ax.set_ylabel("Divergence Indicator")
        ax.set_title(title)
        ax.set_yticks([0, 1])
        ax.set_yticklabels(["No", "Yes"])
        ax.grid(True, alpha=0.3)

        # Per-series count annotation.
        n_divergent = np.sum(div_values)
        pct_divergent = n_divergent / len(div_values) * 100
        stats_text = f"{n_divergent}/{len(div_values)} ({pct_divergent:.2f}%)"
        ax.text(
            0.02,
            0.98,
            stats_text,
            transform=ax.transAxes,
            fontsize=10,
            bbox=dict(boxstyle="round", facecolor="lightcoral", alpha=0.8),
            verticalalignment="top",
        )

        total_divergences += n_divergent
        total_iterations += len(div_values)

        # Legend only if there are divergences to label.
        if n_divergent > 0:
            ax.legend(loc="upper right", fontsize="small")

    # Summary panel (spare grid cell, or the second axis in the 1x2 case).
    if summary_ax is not None and n_plots > 1:
        summary_ax.text(
            0.05,
            0.95,
            _get_divergences_summary(divergences_data),
            transform=summary_ax.transAxes,
            fontsize=12,
            verticalalignment="top",
            fontfamily="monospace",
            bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.9),
        )
        summary_ax.set_title("Divergences Summary")
        summary_ax.axis("off")
    elif n_plots == 1:
        axes[1].text(
            0.05,
            0.95,
            _get_divergences_summary(divergences_data),
            transform=axes[1].transAxes,
            fontsize=12,
            verticalalignment="top",
            fontfamily="monospace",
            bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.9),
        )
        axes[1].set_title("Divergences Summary")
        axes[1].axis("off")

    # Overall rate in the figure title.
    overall_pct = (
        total_divergences / total_iterations * 100
        if total_iterations > 0
        else 0
    )
    fig.suptitle(f"Overall Divergences: {overall_pct:.2f}%")

    plt.tight_layout()
2✔
1719

1720

1721
def _get_divergences_summary(divergences_data):
2✔
1722
    """Generate text summary of divergences."""
1723
    lines = ["Divergences Summary:", ""]
2✔
1724

1725
    total_divergences = 0
2✔
1726
    total_iterations = 0
2✔
1727

1728
    for label, div_values in divergences_data.items():
2✔
1729
        n_divergent = np.sum(div_values)
2✔
1730
        pct_divergent = n_divergent / len(div_values) * 100
2✔
1731

1732
        if label == "main":
2✔
1733
            lines.append(
2✔
1734
                f"NUTS: {n_divergent}/{len(div_values)} ({pct_divergent:.2f}%)"
1735
            )
1736
        else:
1737
            lines.append(
×
1738
                f"Channel {label}: {n_divergent}/{len(div_values)} ({pct_divergent:.2f}%)"
1739
            )
1740

1741
        total_divergences += n_divergent
2✔
1742
        total_iterations += len(div_values)
2✔
1743

1744
    lines.append("")
2✔
1745
    overall_pct = (
2✔
1746
        total_divergences / total_iterations * 100
1747
        if total_iterations > 0
1748
        else 0
1749
    )
1750
    lines.append(
2✔
1751
        f"Total: {total_divergences}/{total_iterations} ({overall_pct:.2f}%)"
1752
    )
1753

1754
    lines.append("")
2✔
1755
    lines.append("Interpretation:")
2✔
1756
    if overall_pct == 0:
2✔
1757
        lines.append("  ✓ No divergences detected")
×
1758
        lines.append("    Sampling appears well-behaved")
×
1759
    elif overall_pct < 0.1:
2✔
1760
        lines.append("  ~ Few divergences")
×
1761
        lines.append("    Generally good, but monitor")
×
1762
    elif overall_pct < 1.0:
2✔
1763
        lines.append("  ⚠ Some divergences detected")
×
1764
        lines.append("    May indicate sampling issues")
×
1765
    else:
1766
        lines.append("  ✗ Many divergences!")
2✔
1767
        lines.append("    Significant sampling problems")
2✔
1768
        lines.append("    Consider model reparameterization")
2✔
1769

1770
    return "\n".join(lines)
2✔
1771

1772

1773
def _group_parameters_simple(idata):
2✔
1774
    """Simple parameter grouping for counting."""
1775
    param_groups = {"phi": [], "delta": [], "weights": [], "other": []}
2✔
1776

1777
    for param in idata.posterior.data_vars:
2✔
1778
        if param.startswith("phi"):
2✔
1779
            param_groups["phi"].append(param)
2✔
1780
        elif param.startswith("delta"):
2✔
1781
            param_groups["delta"].append(param)
2✔
1782
        elif param.startswith("weights"):
2✔
1783
            param_groups["weights"].append(param)
2✔
1784
        else:
1785
            param_groups["other"].append(param)
×
1786

1787
    return {k: v for k, v in param_groups.items() if v}
2✔
1788

1789

1790
def generate_diagnostics_summary(idata, outdir):
    """Generate comprehensive text summary using computed diagnostics.

    Parameters
    ----------
    idata : az.InferenceData
        Posterior samples; sampler metadata (``sampler_type``, ``true_psd``)
        is read best-effort from ``idata.attrs``.
    outdir : str or None
        If truthy, the summary is also written to
        ``<outdir>/diagnostics_summary.txt`` (the directory is created if
        it does not exist).

    Returns
    -------
    str
        The formatted summary text (also logged via ``logger.info``).
    """
    summary = []
    summary.append("=== MCMC Diagnostics Summary ===\n")

    # idata.attrs may be absent, None, or a non-dict mapping; normalise so
    # the .get() calls below are safe.
    attrs = getattr(idata, "attrs", {}) or {}
    if not hasattr(attrs, "get"):
        attrs = dict(attrs)

    n_samples = idata.posterior.sizes.get("draw", 0)
    n_chains = idata.posterior.sizes.get("chain", 1)
    n_params = len(list(idata.posterior.data_vars))
    sampler_type = attrs.get("sampler_type", "Unknown")

    summary.append(f"Sampler: {sampler_type}")
    summary.append(
        f"Samples: {n_samples} per chain × {n_chains} chains = {n_samples * n_chains} total"
    )
    summary.append(f"Parameters: {n_params}")

    param_groups = _group_parameters_simple(idata)
    if param_groups:
        param_summary = ", ".join(
            [f"{k}: {len(v)}" for k, v in param_groups.items()]
        )
        summary.append(f"Parameter groups: {param_summary}")

    # Run the full diagnostic suite; the true PSD (when stored in attrs)
    # doubles as both the truth and the reference spectrum.
    diag_results = run_all_diagnostics(
        idata=idata,
        truth=attrs.get("true_psd"),
        psd_ref=attrs.get("true_psd"),
    )

    mcmc_diag = diag_results.get("mcmc", {})
    if mcmc_diag:
        ess_min = mcmc_diag.get("ess_bulk_min")
        ess_med = mcmc_diag.get("ess_bulk_median")
        if ess_min is not None:
            summary.append(
                f"\nESS bulk: min={ess_min:.0f}"
                + (f", median={ess_med:.0f}" if ess_med is not None else "")
            )
        rhat_max = mcmc_diag.get("rhat_max")
        rhat_mean = mcmc_diag.get("rhat_mean")
        if rhat_max is not None:
            summary.append(
                f"Rhat: max={rhat_max:.3f}"
                + (f", mean={rhat_mean:.3f}" if rhat_mean is not None else "")
            )
        acc = mcmc_diag.get("acceptance_rate_mean")
        if acc is not None:
            summary.append(f"Acceptance rate: {acc:.3f}")
        div_frac = mcmc_diag.get("divergence_fraction")
        if div_frac is not None:
            summary.append(f"Divergence fraction: {div_frac*100:.2f}%")
        khat = mcmc_diag.get("psis_khat_max")
        if khat is not None:
            summary.append(f"PSIS k-hat (max): {khat:.3f}")

    psd_diag = diag_results.get("psd_compare", {})
    if psd_diag:
        summary.append("\nPSD accuracy diagnostics:")
        if "riae" in psd_diag:
            summary.append(f"  RIAE: {psd_diag['riae']:.3f}")
        if "riae_matrix" in psd_diag:
            summary.append(f"  RIAE (matrix): {psd_diag['riae_matrix']:.3f}")
        if "coverage" in psd_diag:
            summary.append(f"  Coverage: {psd_diag['coverage']*100:.1f}%")

    # Overall assessment (best-effort).
    # NOTE(review): the defaults are asymmetric by design — a missing
    # ess_bulk_min (0) fails the ESS check while a missing rhat_max (0)
    # passes the Rhat check, so partial diagnostics can still rate "GOOD".
    summary.append("\nOverall Convergence Assessment:")
    if mcmc_diag:
        ess_ok = mcmc_diag.get("ess_bulk_min", 0) >= 400
        rhat_ok = mcmc_diag.get("rhat_max", 0) <= 1.01
        if ess_ok and rhat_ok:
            summary.append("  Status: EXCELLENT ✓")
        elif ess_ok or rhat_ok:
            summary.append("  Status: GOOD ✓")
        else:
            summary.append("  Status: NEEDS ATTENTION ⚠")
    else:
        summary.append("  Status: UNKNOWN (insufficient diagnostics)")

    summary_text = "\n".join(summary)

    if outdir:
        # Consistent with generate_vi_diagnostics_summary: create the output
        # directory if needed and build the path portably instead of
        # concatenating with "/" in an f-string.
        os.makedirs(outdir, exist_ok=True)
        with open(os.path.join(outdir, "diagnostics_summary.txt"), "w") as f:
            f.write(summary_text)

    logger.info(f"\n{summary_text}\n")
    return summary_text
1881

1882

1883
def generate_vi_diagnostics_summary(
    diagnostics: dict, outdir: Optional[str] = None, log: bool = True
) -> str:
    """Log and optionally write a concise VI diagnostics summary.

    Parameters
    ----------
    diagnostics : dict
        Diagnostics payload from the VI pipeline: PSIS k-hat values, moment
        and correlation summaries, ELBO losses, posterior draws, and PSD
        accuracy metrics. All keys are optional; missing entries are skipped.
    outdir : str, optional
        If given, the summary is also written (best-effort) to
        ``<outdir>/vi_diagnostics_summary.txt``; write failures are only
        logged at debug level.
    log : bool
        When True, the summary is emitted via ``logger.info``.

    Returns
    -------
    str
        The formatted summary text, or "" when ``diagnostics`` is empty.
    """
    if not diagnostics:
        return ""

    lines = []
    lines.append("=== VI Diagnostics Summary ===")
    lines.append("")

    guide = diagnostics.get("guide", "vi")
    lines.append(f"Guide: {guide}")

    khat_max = diagnostics.get("psis_khat_max")
    if khat_max is not None and np.isfinite(khat_max):
        status = diagnostics.get("psis_status_message") or diagnostics.get(
            "psis_khat_status", ""
        )
        status_suffix = f" ({status})" if status else ""
        lines.append(f"PSIS k-hat (max): {float(khat_max):.3f}{status_suffix}")
        threshold = diagnostics.get("psis_khat_threshold", 0.7)
        if khat_max > threshold:
            lines.append(
                f"PSIS alert: k-hat exceeds {threshold:.1f} -> posterior may be unreliable"
            )

    moment_summary = diagnostics.get("psis_moment_summary") or {}
    # Thresholds are shared by the per-parameter report and the overall
    # quality check below; look them up once (they are loop-invariant).
    thresholds = moment_summary.get("thresholds", {})
    bias_thr = thresholds.get("bias_threshold", 0.05) * 100.0
    var_low = thresholds.get("var_low", 0.7)
    var_high = thresholds.get("var_high", 1.3)

    weight_stats = moment_summary.get("weights")
    if weight_stats:
        frac = weight_stats.get("frac_outside")
        lines.append(
            "Weight var_ratio "
            + ", ".join(
                [
                    f"min={weight_stats.get('var_ratio_min', np.nan):.2f}",
                    f"median={weight_stats.get('var_ratio_median', np.nan):.2f}",
                    f"max={weight_stats.get('var_ratio_max', np.nan):.2f}",
                    (
                        f"outside[0.7,1.3]={frac*100:.1f}%"
                        if frac is not None
                        else "outside[0.7,1.3]=n/a"
                    ),
                ]
            )
        )

    hyper_params = moment_summary.get("hyperparameters") or []
    if hyper_params:
        lines.append("PSIS moments (hyperparameters):")
        for entry in hyper_params:
            var_ratio = entry["var_ratio"]
            bias_pct = entry["bias_pct"]
            status_label = "OK"
            if abs(bias_pct) > bias_thr:
                status_label = f"⚠ bias>{bias_thr:.0f}%"
            # Dispersion flags take precedence over the bias flag.
            if var_ratio < var_low:
                status_label = "⚠ under-dispersed"
            elif var_ratio > var_high:
                status_label = "⚠ over-dispersed"
            lines.append(
                f"  {entry['param']}: "
                f"μ_vi={entry['vi_mean']:.3g}, μ_psis={entry['psis_mean']:.3g}, "
                f"bias={entry['bias_pct']:.1f}%, "
                f"σ_vi={entry['vi_std']:.3g}, σ_psis={entry['psis_std']:.3g}, "
                f"var_ratio={entry['var_ratio']:.2f} {status_label}"
            )

    corr_summary = diagnostics.get("psis_correlation_summary") or {}
    for label, stats in corr_summary.items():
        if not stats:
            continue
        line = (
            f"Corr ({label}): max|r|={stats.get('max_abs', np.nan):.3f}, "
            f"median|r|={stats.get('median_abs', np.nan):.3f}"
        )
        if "mean_corr_diff" in stats:
            line += f", mean|Δ| vs ref={stats['mean_corr_diff']:.3f}"
        lines.append(line)

    # Overall quality indicator: explicit PSIS flags win; otherwise escalate
    # to "USE WITH CAUTION" if any hyperparameter moment looks off.
    quality = "OK"
    if diagnostics.get("psis_flag_critical"):
        quality = "❌ NOT TRUSTWORTHY"
    elif diagnostics.get("psis_flag_warn"):
        quality = "⚠ USE WITH CAUTION"
    else:
        for entry in hyper_params:
            if (
                abs(entry["bias_pct"]) > bias_thr
                or entry["var_ratio"] < var_low
                or entry["var_ratio"] > var_high
            ):
                quality = "⚠ USE WITH CAUTION"
                break
    lines.append(f"Overall VI Quality: {quality}")

    losses = diagnostics.get("losses")
    if losses is not None:
        loss_arr = np.asarray(losses)
        if loss_arr.size:
            lines.append(f"Final ELBO: {float(loss_arr.reshape(-1)[-1]):.3f}")

    vi_samples = diagnostics.get("vi_samples")
    if vi_samples:
        # All sample arrays share a leading draw axis; use the first one.
        first = next(iter(vi_samples.values()))
        n_draws = np.asarray(first).shape[0]
        lines.append(f"Posterior draws (VI): {n_draws}")

    psd_shape = None
    if "psd_matrix" in diagnostics and diagnostics["psd_matrix"] is not None:
        psd_shape = np.asarray(diagnostics["psd_matrix"]).shape
    else:
        # Fall back to the median PSD quantile; "psd_quantiles" may be
        # absent or explicitly None, so guard with `or {}`.
        real_q = (diagnostics.get("psd_quantiles") or {}).get("real") or {}
        q50 = real_q.get("q50")
        if q50 is not None:
            psd_shape = np.asarray(q50).shape
    if psd_shape is not None and len(psd_shape) >= 3:
        lines.append(
            f"PSD shape: {psd_shape[0]} freq × {psd_shape[1]} × {psd_shape[2]}"
        )

    # Accuracy metrics
    riae_matrix = diagnostics.get("riae_matrix")
    riae_err = diagnostics.get("riae_matrix_errorbars")
    if riae_matrix is not None:
        line = f"RIAE (matrix): {float(riae_matrix):.3f}"
        # Error bars are assumed to be (at least) the 5/16/50/84/95
        # percentiles — TODO confirm ordering against the producer.
        if riae_err and len(riae_err) >= 5:
            line += f" (5-95% [{riae_err[0]:.3f}, {riae_err[4]:.3f}])"
        lines.append(line)

    per_ch = diagnostics.get("riae_per_channel")
    if per_ch:
        formatted = ", ".join(
            f"{idx}:{val:.3f}" for idx, val in enumerate(per_ch)
        )
        lines.append(f"RIAE per channel: {formatted}")

    offdiag = diagnostics.get("riae_offdiag")
    if offdiag is not None:
        lines.append(f"RIAE off-diagonal: {float(offdiag):.3f}")

    coh_riae = diagnostics.get("coherence_riae")
    if coh_riae is not None:
        lines.append(f"Coherence RIAE: {float(coh_riae):.3f}")

    bands = diagnostics.get("riae_bands")
    if bands:
        band_str = "; ".join(
            f"[{b['start']:.2e},{b['end']:.2e}]:{b['value']:.3f}"
            for b in bands
        )
        lines.append(f"RIAE by frequency bands: {band_str}")

    coverage = diagnostics.get("coverage") or diagnostics.get("ci_coverage")
    coverage_level = diagnostics.get("coverage_level")
    if coverage is not None:
        label = (
            f"{int(round(coverage_level * 100))}% interval coverage"
            if coverage_level is not None
            else "Interval coverage"
        )
        lines.append(f"{label}: {float(coverage) * 100:.1f}%")

    summary_text = "\n".join(lines)

    if outdir:
        try:
            os.makedirs(outdir, exist_ok=True)
            with open(
                os.path.join(outdir, "vi_diagnostics_summary.txt"), "w"
            ) as f:
                f.write(summary_text)
        except Exception:
            logger.debug(
                "Could not write VI diagnostics summary to disk.",
                exc_info=True,
            )

    if log:
        logger.info(f"\n{summary_text}\n")
    return summary_text
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc