• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

nz-gravity / LogPSplinePSD / 19916965190

04 Dec 2025 03:46AM UTC coverage: 79.89% (-0.2%) from 80.121%
19916965190

push

github

avivajpeyi
VI improvements

842 of 998 branches covered (84.37%)

Branch coverage included in aggregate %.

84 of 142 new or added lines in 7 files covered. (59.15%)

3 existing lines in 3 files now uncovered.

5117 of 6461 relevant lines covered (79.2%)

1.58 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

73.0
/src/log_psplines/plotting/diagnostics.py
1
import os
2✔
2
from dataclasses import dataclass
2✔
3
from typing import Optional
2✔
4

5
import arviz as az
2✔
6
import matplotlib.pyplot as plt
2✔
7
import numpy as np
2✔
8

9
from ..logger import logger
2✔
10
from .base import PlotConfig, safe_plot, setup_plot_style
2✔
11

12
# Apply the shared plotting style (imported from .base) once at module import
# time so every diagnostics figure produced here is styled consistently.
setup_plot_style()
14

15

16
@dataclass
class DiagnosticsConfig:
    """Configuration for diagnostics plotting parameters."""

    figsize: tuple = (12, 8)  # default (width, height) in inches for figures
    dpi: int = 150  # resolution used when saving diagnostic figures
    ess_threshold: int = 400  # minimum effective sample size deemed reliable
    rhat_threshold: float = 1.01  # R-hat cutoff (not used in the code visible here)
    fontsize: int = 11  # base font size for text elements
    labelsize: int = 12  # axis-label font size
    titlesize: int = 12  # subplot-title font size
27

28

29
def plot_trace(idata: az.InferenceData, compact=True) -> plt.Figure:
    """Plot trace (left) and marginal-density (right) panels per parameter group.

    Variables in ``idata.posterior`` are grouped by name prefix into
    "delta", "phi" and "weights" rows. ``phi``/``delta`` panels use log
    scaling on the trace y-axis and histogram x-axis.

    Parameters
    ----------
    idata : az.InferenceData
        Posterior samples to plot.
    compact : bool, default True
        If True, every variable of a group is overlaid on the same pair of
        axes; otherwise each group still gets one row, drawn onto shared axes.

    Returns
    -------
    plt.Figure
        The assembled figure.

    Notes
    -----
    Fixes two crashes in the previous version:
    * with ``compact=False``, ``axes[row, :]`` is 1-D but was indexed as
      ``group_axes[0, 0]`` (IndexError);
    * an empty variable group produced a (0, 2) axes array via
      ``np.repeat(..., 0, ...)`` and then failed on ``group_axes[0, 0]``.
    """
    groups = {
        "delta": [
            v for v in idata.posterior.data_vars if v.startswith("delta")
        ],
        "phi": [v for v in idata.posterior.data_vars if v.startswith("phi")],
        "weights": [
            v for v in idata.posterior.data_vars if v.startswith("weights")
        ],
    }

    nrows = len(groups)
    fig, axes = plt.subplots(nrows, 2, figsize=(7, 3 * nrows))

    for row, (group_name, var_names) in enumerate(groups.items()):
        # Normalise to a (n, 2) grid of axes. In compact mode the row's pair
        # of axes is repeated so every variable draws onto the same panels;
        # max(..., 1) keeps at least one row so set_title below is safe even
        # for an empty group.
        if compact:
            group_axes = axes[row, :].reshape(1, 2)
            group_axes = np.repeat(group_axes, max(len(var_names), 1), axis=0)
        else:
            group_axes = axes[row, :].reshape(1, 2)

        group_axes[0, 0].set_title(
            f"{group_name.capitalize()} Parameters", fontsize=14
        )

        for i, var in enumerate(var_names):
            # shape is (nchain, nsamples, ndim) if ndim>1 else (nchain, nsamples)
            data = idata.posterior[var].values
            if data.ndim == 3:
                data = data[0].T  # shape is now (ndim, nsamples)

            ax_trace = group_axes[i, 0] if compact else group_axes[0, 0]
            ax_hist = group_axes[i, 1] if compact else group_axes[0, 1]
            ax_trace.set_ylabel(group_name, fontsize=8)
            ax_trace.set_xlabel("MCMC Step", fontsize=8)
            ax_hist.set_xlabel(group_name, fontsize=8)
            # place ylabel on right side of hist
            ax_hist.yaxis.set_label_position("right")
            ax_hist.set_ylabel("Density", fontsize=8, rotation=270, labelpad=0)

            # remove axes yspine for hist
            ax_hist.spines["left"].set_visible(False)
            ax_hist.spines["right"].set_visible(False)
            ax_hist.spines["top"].set_visible(False)
            ax_hist.set_yticks([])  # remove y ticks
            ax_hist.yaxis.set_ticks_position("none")

            ax_trace.spines["right"].set_visible(False)
            ax_trace.spines["top"].set_visible(False)

            color = f"C{i}"
            label = f"{var}"
            if group_name in ["phi", "delta"]:
                ax_trace.set_yscale("log")
                ax_hist.set_xscale("log")

            for p in data:
                ax_trace.plot(p, color=color, alpha=0.7, label=label)

                # if phi or delta, use log scale for hist-x, log for trace y
                if group_name in ["phi", "delta"]:
                    bins = np.logspace(
                        np.log10(np.min(p)), np.log10(np.max(p)), 30
                    )
                    logp = np.log(p)
                    log_grid, log_pdf = az.kde(logp)
                    grid = np.exp(log_grid)
                    pdf = log_pdf / grid  # change of variables
                else:
                    bins = 30
                    grid, pdf = az.kde(p)
                ax_hist.plot(grid, pdf, color=color, label=label)
                ax_hist.hist(
                    p, bins=bins, density=True, color=color, alpha=0.3
                )

    plt.suptitle("Parameter Traces", fontsize=16)
    plt.tight_layout()
    return fig
116

117

118
def plot_diagnostics(
    idata: az.InferenceData,
    outdir: str,
    n_channels: Optional[int] = None,
    n_freq: Optional[int] = None,
    runtime: Optional[float] = None,
    config: Optional[DiagnosticsConfig] = None,
) -> None:
    """
    Create essential MCMC diagnostics in organized subdirectories.

    Parameters
    ----------
    idata : az.InferenceData
        Posterior samples and sampler statistics to diagnose.
    outdir : str
        Output directory; a ``diagnostics/`` subdirectory is created inside
        it. If ``None``, the function returns without doing anything.
    n_channels, n_freq : Optional[int]
        Optional dataset metadata forwarded to the summary dashboard.
    runtime : Optional[float]
        Optional sampling runtime in seconds, shown on the dashboard.
    config : Optional[DiagnosticsConfig]
        Plotting configuration; defaults to ``DiagnosticsConfig()``.
    """
    if outdir is None:
        return

    if config is None:
        config = DiagnosticsConfig()

    # Create diagnostics subdirectory
    diag_dir = os.path.join(outdir, "diagnostics")
    os.makedirs(diag_dir, exist_ok=True)

    logger.info("Generating MCMC diagnostics...")

    # Generate summary report (helper defined elsewhere in this module),
    # then render all diagnostic figures into diag_dir.
    generate_diagnostics_summary(idata, diag_dir)
    _create_diagnostic_plots(
        idata, diag_dir, config, n_channels, n_freq, runtime
    )
146

147

148
def _create_diagnostic_plots(
    idata, diag_dir, config, n_channels, n_freq, runtime
):
    """Create only the essential diagnostic plots.

    Each figure-producing closure is wrapped with the ``safe_plot`` decorator
    from ``.base`` (presumably it saves the figure to the given path and
    guards against plotting failures — see that module for the contract).
    """
    logger.debug("Generating diagnostic plots...")

    # 1. ArviZ trace plots
    @safe_plot(f"{diag_dir}/trace_plots.png", config.dpi)
    def create_trace_plots():
        return plot_trace(idata)

    create_trace_plots()

    # 2. Summary dashboard with key convergence metrics
    @safe_plot(f"{diag_dir}/summary_dashboard.png", config.dpi)
    def plot_summary():
        _plot_summary_dashboard(idata, config, n_channels, n_freq, runtime)

    plot_summary()

    # 3. Log posterior diagnostics
    @safe_plot(f"{diag_dir}/log_posterior.png", config.dpi)
    def plot_lp():
        _plot_log_posterior(idata, config)

    plot_lp()

    # 4. Acceptance rate diagnostics (block-aware: handles per-channel series)
    @safe_plot(f"{diag_dir}/acceptance_diagnostics.png", config.dpi)
    def plot_acceptance():
        _plot_acceptance_diagnostics_blockaware(idata, config)

    plot_acceptance()

    # 5. Sampler-specific diagnostics
    _create_sampler_diagnostics(idata, diag_dir, config)

    # 6. Divergences diagnostics (for NUTS only)
    _create_divergences_diagnostics(idata, diag_dir, config)
187

188

189
def _plot_summary_dashboard(idata, config, n_channels, n_freq, runtime):
    """Draw the 2x2 summary dashboard: ESS histogram (top-left), run metadata
    (top-right), parameter counts (bottom-left) and an ESS-based convergence
    verdict (bottom-right)."""

    # Create 2x2 layout
    fig, axes = plt.subplots(
        2, 2, figsize=(config.figsize[0] * 0.8, config.figsize[1])
    )
    ess_ax = axes[0, 0]
    meta_ax = axes[0, 1]
    param_ax = axes[1, 0]
    status_ax = axes[1, 1]

    # Get ESS values once; "ess" is expected to be an array stored in
    # idata.attrs (NaN entries are dropped). Any failure — missing key,
    # None, non-array — leaves ess_values as None and the panels degrade
    # gracefully.
    ess_values = None
    try:
        ess = idata.attrs.get("ess")
        ess_values = ess[~np.isnan(ess)]
    except Exception:
        pass

    # 1. ESS Distribution
    _plot_ess_histogram(ess_ax, ess_values, config)

    # 2. Analysis Metadata
    _plot_metadata(meta_ax, idata, n_channels, n_freq, runtime)

    # 3. Parameter Summary
    _plot_parameter_summary(param_ax, idata)

    # 4. Convergence Status
    _plot_convergence_status(status_ax, ess_values, config)

    plt.tight_layout()
221

222

223
def _plot_ess_histogram(ax, ess_values, config):
    """Plot ESS distribution with quality thresholds.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw onto.
    ess_values : array-like or None
        Effective sample sizes (NaNs already removed). When ``None`` or
        empty, a placeholder message is drawn instead.
    config : DiagnosticsConfig
        Supplies ``ess_threshold``, used for both the "Minimum reliable"
        reference line and the "fraction above threshold" statistic.
    """
    if ess_values is None or len(ess_values) == 0:
        ax.text(0.5, 0.5, "ESS unavailable", ha="center", va="center")
        ax.set_title("ESS Distribution")
        return

    # Histogram
    ax.hist(ess_values, bins=30, alpha=0.7, edgecolor="black")

    # Reference lines. FIX: the "Minimum reliable" line previously hard-coded
    # 400 even when config.ess_threshold differed from the stats text below;
    # use the configured threshold so the line and the text agree.
    ess_max = np.max(ess_values)
    thresholds = [
        (config.ess_threshold, "red", "--", "Minimum reliable"),
        (1000, "orange", "--", "Good"),
        (ess_max, "green", ":", f"Max = {ess_max:.0f}"),
    ]

    for threshold, color, style, label in thresholds:
        ax.axvline(
            x=threshold,
            color=color,
            linestyle=style,
            # The max-value marker is drawn thinner than the fixed thresholds.
            linewidth=2 if threshold < ess_max else 1,
            alpha=0.8,
            label=label,
        )

    ax.set_xlabel("ESS")
    ax.set_ylabel("Count")
    ax.set_title("ESS Distribution")
    ax.legend(loc="upper right", fontsize="x-small")
    ax.grid(True, alpha=0.3)

    # Summary stats box in the upper-left corner.
    pct_good = (ess_values >= config.ess_threshold).mean() * 100
    stats_text = f"Min: {ess_values.min():.0f}\nMean: {ess_values.mean():.0f}\n≥{config.ess_threshold}: {pct_good:.1f}%"
    ax.text(
        0.02,
        0.98,
        stats_text,
        transform=ax.transAxes,
        fontsize=10,
        verticalalignment="top",
        bbox=dict(boxstyle="round", facecolor="lightblue", alpha=0.7),
    )
268

269

270
def _plot_metadata(ax, idata, n_channels, n_freq, runtime):
    """Display analysis metadata as a monospace text panel.

    FIX: previously ``idata.attrs["sampler_type"]`` raised ``KeyError`` when
    the attribute was absent, and the broad ``except`` then blanked the whole
    metadata panel even though sample/chain/parameter counts were available.
    Use ``.get`` with an "unknown" fallback instead.
    """
    try:
        n_samples = idata.posterior.sizes.get("draw", 0)
        n_chains = idata.posterior.sizes.get("chain", 1)
        n_params = len(list(idata.posterior.data_vars))
        sampler_type = idata.attrs.get("sampler_type", "unknown")

        metadata_lines = [
            f"Sampler: {sampler_type}",
            f"Samples: {n_samples} × {n_chains} chains",
            f"Parameters: {n_params}",
        ]
        # Optional extras are appended only when provided by the caller.
        if n_channels is not None:
            metadata_lines.append(f"Channels: {n_channels}")
        if n_freq is not None:
            metadata_lines.append(f"Frequencies: {n_freq}")
        if runtime is not None:
            metadata_lines.append(f"Runtime: {runtime:.2f}s")

        ax.text(
            0.05,
            0.95,
            "\n".join(metadata_lines),
            transform=ax.transAxes,
            fontsize=12,
            verticalalignment="top",
            fontfamily="monospace",
        )
    except Exception:
        # Any malformed idata still yields a readable placeholder panel.
        ax.text(0.5, 0.5, "Metadata unavailable", ha="center", va="center")

    ax.set_title("Analysis Summary")
    ax.axis("off")
304

305

306
def _plot_parameter_summary(ax, idata):
    """Render a text panel with the number of parameters in each group."""
    try:
        groups = _group_parameters_simple(idata)
        if groups:
            lines = ["Parameter Summary:"]
            lines += [
                f"{name}: {len(params)}"
                for name, params in groups.items()
                if params
            ]
            ax.text(
                0.05,
                0.95,
                "\n".join(lines),
                transform=ax.transAxes,
                fontsize=11,
                verticalalignment="top",
                fontfamily="monospace",
            )
    except Exception:
        # Grouping failed — show a placeholder rather than crash the dashboard.
        ax.text(
            0.5,
            0.5,
            "Parameter summary\nunavailable",
            ha="center",
            va="center",
        )

    ax.set_title("Parameter Summary")
    ax.axis("off")
335

336

337
def _plot_convergence_status(ax, ess_values, config):
    """Render an overall convergence verdict derived from ESS alone.

    The verdict color/text depends on the fraction of parameters whose ESS
    meets ``config.ess_threshold`` (>=90% excellent, >=75% adequate,
    otherwise needs attention; gray when ESS is unavailable).
    """
    try:
        lines = ["Convergence Status:"]

        if ess_values is not None and len(ess_values) > 0:
            pct_ok = (ess_values >= config.ess_threshold).mean() * 100
            lines.append(f"ESS ≥ {config.ess_threshold}: {pct_ok:.0f}%")
            lines.append("")
            lines.append("Overall Status:")

            if pct_ok >= 90:
                verdict, color = "✓ EXCELLENT", "green"
            elif pct_ok >= 75:
                verdict, color = "✓ ADEQUATE", "orange"
            else:
                verdict, color = "⚠ NEEDS ATTENTION", "red"
            lines.append(verdict)
        else:
            lines.append("? UNABLE TO ASSESS")
            color = "gray"

        ax.text(
            0.05,
            0.95,
            "\n".join(lines),
            transform=ax.transAxes,
            fontsize=11,
            verticalalignment="top",
            fontfamily="monospace",
            color=color,
        )
    except Exception:
        ax.text(0.5, 0.5, "Status unavailable", ha="center", va="center")

    ax.set_title("Convergence Status")
    ax.axis("off")
378

379

380
def _plot_log_posterior(idata, config):
    """Log posterior diagnostics.

    Draws a 2x2 figure: trace with running mean, value distribution,
    step-to-step changes, and a text panel of summary statistics. Uses
    ``sample_stats["lp"]`` when present, falling back to
    ``sample_stats["log_likelihood"]``; otherwise a placeholder panel is
    drawn and the function returns early.
    """
    fig, axes = plt.subplots(2, 2, figsize=config.figsize)

    # Check for lp first, then log_likelihood
    if "lp" in idata.sample_stats:
        lp_values = idata.sample_stats["lp"].values.flatten()
        var_name = "lp"  # NOTE(review): var_name is currently unused below
        title_prefix = "Log Posterior"
    elif "log_likelihood" in idata.sample_stats:
        lp_values = idata.sample_stats["log_likelihood"].values.flatten()
        var_name = "log_likelihood"
        title_prefix = "Log Likelihood"
    else:
        # Create a fallback layout when no posterior data available
        fig, axes = plt.subplots(1, 1, figsize=config.figsize)
        axes.text(
            0.5,
            0.5,
            "No log posterior\nor log likelihood\navailable",
            ha="center",
            va="center",
            fontsize=14,
        )
        axes.set_title("Log Posterior Diagnostics")
        axes.axis("off")
        plt.tight_layout()
        return

    # Trace plot with running mean overlaid
    axes[0, 0].plot(
        lp_values, alpha=0.7, linewidth=1, color="blue", label="Trace"
    )

    # Add running mean on the same plot; window is ~1% of the series length,
    # never below 10 iterations.
    window_size = max(10, len(lp_values) // 100)
    if len(lp_values) > window_size:
        running_mean = np.convolve(
            lp_values, np.ones(window_size) / window_size, mode="valid"
        )
        # "valid" convolution shortens the series, so offset the x range by
        # half a window to keep the running mean centred on the trace.
        axes[0, 0].plot(
            range(window_size // 2, window_size // 2 + len(running_mean)),
            running_mean,
            alpha=0.9,
            linewidth=3,
            color="red",
            label=f"Running mean (w={window_size})",
        )

    axes[0, 0].set_xlabel("Iteration")
    axes[0, 0].set_ylabel(title_prefix)
    axes[0, 0].set_title(f"{title_prefix} Trace with Running Mean")
    axes[0, 0].legend(loc="best", fontsize="small")
    axes[0, 0].grid(True, alpha=0.3)

    # Distribution
    axes[0, 1].hist(
        lp_values, bins=50, alpha=0.7, density=True, edgecolor="black"
    )
    axes[0, 1].axvline(
        np.mean(lp_values),
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Mean: {np.mean(lp_values):.1f}",
    )
    axes[0, 1].set_xlabel(title_prefix)
    axes[0, 1].set_ylabel("Density")
    axes[0, 1].set_title(f"{title_prefix} Distribution")
    axes[0, 1].legend(loc="best", fontsize="small")
    axes[0, 1].grid(True, alpha=0.3)

    # Step-to-step changes
    lp_diff = np.diff(lp_values)
    axes[1, 0].plot(lp_diff, alpha=0.5, linewidth=1)
    axes[1, 0].axhline(0, color="red", linestyle="--", alpha=0.7)
    axes[1, 0].axhline(
        np.mean(lp_diff),
        color="blue",
        linestyle="--",
        alpha=0.7,
        label=f"Mean change: {np.mean(lp_diff):.1f}",
    )
    axes[1, 0].set_xlabel("Iteration")
    axes[1, 0].set_ylabel(f"{title_prefix} Difference")
    axes[1, 0].set_title("Step-to-Step Changes")
    axes[1, 0].legend(loc="best", fontsize="small")
    axes[1, 0].grid(True, alpha=0.3)

    # Summary statistics; "Final variation" is the std of the last quarter
    # of the chain, a crude stationarity indicator.
    stats_lines = [
        f"Mean: {np.mean(lp_values):.2f}",
        f"Std: {np.std(lp_values):.2f}",
        f"Min: {np.min(lp_values):.2f}",
        f"Max: {np.max(lp_values):.2f}",
        f"Range: {np.max(lp_values) - np.min(lp_values):.2f}",
        "",
        "Stability:",
        f"Final variation: {np.std(lp_values[-len(lp_values)//4:]):.2f}",
    ]

    axes[1, 1].text(
        0.05,
        0.95,
        "\n".join(stats_lines),
        transform=axes[1, 1].transAxes,
        fontsize=10,
        verticalalignment="top",
        fontfamily="monospace",
        bbox=dict(boxstyle="round", facecolor="lightblue", alpha=0.8),
    )
    axes[1, 1].set_title("Posterior Statistics")
    axes[1, 1].axis("off")

    plt.tight_layout()
495

496

497
def _plot_acceptance_diagnostics(idata, config):
    """Acceptance rate diagnostics (overall series only).

    Draws a 2x2 figure: annotated acceptance trace, value distribution,
    rolling standard deviation, and a text summary. A placeholder panel is
    drawn when no acceptance series is present in ``idata.sample_stats``.

    FIX: ``idata.attrs`` is a mapping, so the previous
    ``getattr(idata.attrs, "target_accept_rate", 0.44)`` never found the key
    and always returned 0.44 — making the NUTS branch of the range logic
    unreachable. Use ``.get`` (mirroring
    ``_plot_acceptance_diagnostics_blockaware``) with sampler-appropriate
    defaults instead.
    """
    accept_key = None
    if "accept_prob" in idata.sample_stats:
        accept_key = "accept_prob"
    elif "acceptance_rate" in idata.sample_stats:
        accept_key = "acceptance_rate"

    if accept_key is None:
        fig, ax = plt.subplots(figsize=config.figsize)
        ax.text(
            0.5,
            0.5,
            "Acceptance rate data unavailable",
            ha="center",
            va="center",
        )
        ax.set_title("Acceptance Rate Diagnostics")
        return

    fig, axes = plt.subplots(2, 2, figsize=config.figsize)

    accept_rates = idata.sample_stats[accept_key].values.flatten()

    # Sampler detection and target rate (consistent with the blockaware
    # variant): explicit attrs win; otherwise 0.8 for NUTS, 0.44 for MH.
    sampler_type_attr = str(idata.attrs.get("sampler_type", "")).lower()
    is_nuts = "nuts" in sampler_type_attr
    target_rate = idata.attrs.get(
        "target_accept_rate",
        idata.attrs.get("target_accept_prob", 0.8 if is_nuts else 0.44),
    )
    sampler_type = "NUTS" if is_nuts else "MH"

    # Define good ranges based on sampler
    if is_nuts:
        good_range = (0.7, 0.9)
        low_range = (0.0, 0.6)
        high_range = (0.9, 1.0)
        concerning_range = (0.6, 0.7)
    else:  # MH
        good_range = (0.2, 0.5)
        low_range = (0.0, 0.2)
        high_range = (0.5, 1.0)
        concerning_range = (0.1, 0.2)  # MH can be lower than NUTS

    # Trace plot with color zones
    # Add background zones
    axes[0, 0].axhspan(
        good_range[0],
        good_range[1],
        alpha=0.1,
        color="green",
        label=f"Good ({good_range[0]:.1f}-{good_range[1]:.1f})",
    )
    axes[0, 0].axhspan(
        low_range[0], low_range[1], alpha=0.1, color="red", label="Too low"
    )
    axes[0, 0].axhspan(
        high_range[0],
        high_range[1],
        alpha=0.1,
        color="orange",
        label="Too high",
    )
    if concerning_range[1] > concerning_range[0]:
        axes[0, 0].axhspan(
            concerning_range[0],
            concerning_range[1],
            alpha=0.1,
            color="yellow",
            label="Concerning",
        )

    # Main trace plot
    axes[0, 0].plot(
        accept_rates, alpha=0.8, linewidth=1, color="blue", label="Trace"
    )
    axes[0, 0].axhline(
        target_rate,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Target ({target_rate})",
    )

    # Add running average on the same plot
    window_size = max(10, len(accept_rates) // 50)
    if len(accept_rates) > window_size:
        running_mean = np.convolve(
            accept_rates, np.ones(window_size) / window_size, mode="valid"
        )
        axes[0, 0].plot(
            range(window_size // 2, window_size // 2 + len(running_mean)),
            running_mean,
            alpha=0.9,
            linewidth=3,
            color="purple",
            label=f"Running mean (w={window_size})",
        )

    axes[0, 0].set_xlabel("Iteration")
    axes[0, 0].set_ylabel("Acceptance Rate")
    axes[0, 0].set_title(f"{sampler_type} Acceptance Rate Trace")
    axes[0, 0].legend(loc="best", fontsize="small")
    axes[0, 0].grid(True, alpha=0.3)

    # Add interpretation text
    interpretation = f"{sampler_type} aims for {target_rate:.2f}."
    if target_rate > 0.5:
        interpretation += " Green: efficient sampling."
    else:
        interpretation += " MH adapts to find optimal rate."
    axes[0, 0].text(
        0.02,
        0.02,
        interpretation,
        transform=axes[0, 0].transAxes,
        fontsize=9,
        bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.7),
    )

    # Distribution
    axes[0, 1].hist(
        accept_rates, bins=30, alpha=0.7, density=True, edgecolor="black"
    )
    axes[0, 1].axvline(
        target_rate,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Target ({target_rate})",
    )
    axes[0, 1].set_xlabel("Acceptance Rate")
    axes[0, 1].set_ylabel("Density")
    axes[0, 1].set_title("Acceptance Rate Distribution")
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

    # Additional acceptance analysis - evolution over time (rolling std over
    # a trailing 21-sample window).
    if len(accept_rates) > 10:
        window_std = np.array(
            [
                np.std(accept_rates[max(0, i - 20) : i + 1])
                for i in range(len(accept_rates))
            ]
        )
        axes[1, 0].plot(window_std, alpha=0.7, color="green")
        axes[1, 0].set_xlabel("Iteration")
        axes[1, 0].set_ylabel("Rolling Std")
        axes[1, 0].set_title("Rolling Standard Deviation")
        axes[1, 0].grid(True, alpha=0.3)
    else:
        axes[1, 0].text(
            0.5,
            0.5,
            "Acceptance variability\nanalysis unavailable",
            ha="center",
            va="center",
        )
        axes[1, 0].set_title("Acceptance Stability")

    # Summary statistics (expanded)
    stats_text = [
        f"Sampler: {sampler_type}",
        f"Target: {target_rate:.3f}",
        f"Mean: {np.mean(accept_rates):.3f}",
        f"Std: {np.std(accept_rates):.3f}",
        f"CV: {np.std(accept_rates)/np.mean(accept_rates):.3f}",
        f"Min: {np.min(accept_rates):.3f}",
        f"Max: {np.max(accept_rates):.3f}",
        "",
        "Stability:",
        f"Final std: {np.std(accept_rates[-len(accept_rates)//4:]):.3f}",
    ]

    axes[1, 1].text(
        0.05,
        0.95,
        "\n".join(stats_text),
        transform=axes[1, 1].transAxes,
        fontsize=9,
        verticalalignment="top",
        fontfamily="monospace",
    )
    axes[1, 1].set_title("Acceptance Analysis")
    axes[1, 1].axis("off")

    plt.tight_layout()
686

687

688
def _plot_acceptance_diagnostics_blockaware(idata, config):
    """Acceptance diagnostics that also handle per‑channel series from blocked NUTS.

    If keys like ``accept_prob_channel_0`` are found in ``idata.sample_stats``,
    they are overlaid on the overall trace and included in the summary.
    """
    # Detect overall series
    accept_key = None
    if "accept_prob" in idata.sample_stats:
        accept_key = "accept_prob"
    elif "acceptance_rate" in idata.sample_stats:
        accept_key = "acceptance_rate"

    # Collect per-channel series keyed by integer channel index
    channel_series = {}
    for key in idata.sample_stats:
        if isinstance(key, str) and key.startswith("accept_prob_channel_"):
            try:
                ch = int(key.rsplit("_", 1)[-1])
                channel_series[ch] = idata.sample_stats[key].values.flatten()
            except Exception:
                # Non-integer suffix: not a per-channel key, skip it.
                pass

    if accept_key is None and not channel_series:
        # Nothing to plot: draw a placeholder panel and bail out.
        fig, ax = plt.subplots(figsize=config.figsize)
        ax.text(
            0.5,
            0.5,
            "Acceptance rate data unavailable",
            ha="center",
            va="center",
        )
        ax.set_title("Acceptance Rate Diagnostics")
        return

    fig, axes = plt.subplots(2, 2, figsize=config.figsize)

    # Overall or concatenated series: without an overall series the
    # per-channel series are concatenated for the histogram / rolling-std /
    # summary panels.
    if accept_key is not None:
        accept_rates = idata.sample_stats[accept_key].values.flatten()
    else:
        accept_rates = np.concatenate(list(channel_series.values()))

    sampler_type_attr = idata.attrs.get("sampler_type", "").lower()
    is_nuts = "nuts" in sampler_type_attr
    # Target rate: explicit attrs win; otherwise 0.8 (NUTS) / 0.44 (MH).
    target_rate = idata.attrs.get(
        "target_accept_rate",
        idata.attrs.get("target_accept_prob", 0.8 if is_nuts else 0.44),
    )
    sampler_type = "NUTS" if is_nuts else "MH"

    # Quality ranges for the background shading, per sampler family
    if is_nuts:
        good_range = (0.7, 0.9)
        low_range = (0.0, 0.6)
        high_range = (0.9, 1.0)
        concerning_range = (0.6, 0.7)
    else:
        good_range = (0.2, 0.5)
        low_range = (0.0, 0.2)
        high_range = (0.5, 1.0)
        concerning_range = (0.1, 0.2)

    # Background zones
    axes[0, 0].axhspan(good_range[0], good_range[1], alpha=0.1, color="green")
    axes[0, 0].axhspan(low_range[0], low_range[1], alpha=0.1, color="red")
    axes[0, 0].axhspan(high_range[0], high_range[1], alpha=0.1, color="orange")
    if concerning_range[1] > concerning_range[0]:
        axes[0, 0].axhspan(
            concerning_range[0], concerning_range[1], alpha=0.1, color="yellow"
        )

    # Plot traces: overall + per-channel overlays
    if accept_key is not None:
        axes[0, 0].plot(
            accept_rates, alpha=0.8, linewidth=1, color="blue", label="overall"
        )
    for ch in sorted(channel_series):
        axes[0, 0].plot(
            channel_series[ch], alpha=0.6, linewidth=1, label=f"ch {ch}"
        )
    axes[0, 0].axhline(
        target_rate,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Target ({target_rate})",
    )

    # Running mean for overall; window is ~2% of the series, never below 10.
    window_size = max(10, len(accept_rates) // 50)
    if len(accept_rates) > window_size:
        running_mean = np.convolve(
            accept_rates, np.ones(window_size) / window_size, mode="valid"
        )
        # "valid" convolution shortens the series; offset x by half a window
        # to keep the running mean centred on the trace.
        axes[0, 0].plot(
            range(window_size // 2, window_size // 2 + len(running_mean)),
            running_mean,
            alpha=0.9,
            linewidth=3,
            color="purple",
            label=f"Running mean (w={window_size})",
        )

    axes[0, 0].set_xlabel("Iteration")
    axes[0, 0].set_ylabel("Acceptance Rate")
    axes[0, 0].set_title(f"{sampler_type} Acceptance Rate Trace")
    axes[0, 0].legend(loc="best", fontsize="small")
    axes[0, 0].grid(True, alpha=0.3)

    # Histogram and summary
    axes[0, 1].hist(
        accept_rates, bins=30, alpha=0.7, density=True, edgecolor="black"
    )
    axes[0, 1].axvline(
        target_rate,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Target ({target_rate})",
    )
    axes[0, 1].set_xlabel("Acceptance Rate")
    axes[0, 1].set_ylabel("Density")
    axes[0, 1].set_title("Acceptance Rate Distribution")
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

    # Rolling std over a trailing 21-sample window as a stability indicator
    if len(accept_rates) > 10:
        window_std = np.array(
            [
                np.std(accept_rates[max(0, i - 20) : i + 1])
                for i in range(len(accept_rates))
            ]
        )
        axes[1, 0].plot(window_std, alpha=0.7, color="green")
        axes[1, 0].set_xlabel("Iteration")
        axes[1, 0].set_ylabel("Rolling Std")
        axes[1, 0].set_title("Rolling Standard Deviation")
        axes[1, 0].grid(True, alpha=0.3)
    else:
        axes[1, 0].text(
            0.5,
            0.5,
            "Acceptance variability\nanalysis unavailable",
            ha="center",
            va="center",
        )
        axes[1, 0].set_title("Acceptance Stability")

    # Summary text
    stats_text = [
        f"Sampler: {sampler_type}",
        f"Target: {target_rate:.3f}",
        f"Mean: {np.mean(accept_rates):.3f}",
        f"Std: {np.std(accept_rates):.3f}",
        f"CV: {np.std(accept_rates)/np.mean(accept_rates):.3f}",
        f"Min: {np.min(accept_rates):.3f}",
        f"Max: {np.max(accept_rates):.3f}",
    ]
    if channel_series:
        stats_text.append("")
        stats_text.append("Per-channel means:")
        for ch in sorted(channel_series):
            stats_text.append(f"  ch {ch}: {np.mean(channel_series[ch]):.3f}")

    axes[1, 1].text(
        0.05,
        0.95,
        "\n".join(stats_text),
        transform=axes[1, 1].transAxes,
        fontsize=9,
        va="top",
        family="monospace",
    )
    axes[1, 1].set_title("Acceptance Analysis")
    axes[1, 1].axis("off")

    plt.tight_layout()
866

867

868
def _get_channel_indices(sample_stats, base_key: str) -> set:
2✔
869
    """Return set of channel indices for the given ``base_key`` prefix."""
870

871
    prefix = f"{base_key}_channel_"
2✔
872
    indices = set()
2✔
873
    for key in sample_stats:
2✔
874
        if isinstance(key, str) and key.startswith(prefix):
2✔
875
            try:
2✔
876
                indices.add(int(key.replace(prefix, "")))
2✔
877
            except Exception:
×
878
                continue
×
879
    return indices
2✔
880

881

882
def _plot_nuts_diagnostics_blockaware(idata, config):
    """NUTS diagnostics supporting per‑channel (blocked) diagnostics fields.

    Overlays per‑channel series when keys like ``energy_channel_{j}`` or
    ``num_steps_channel_{j}`` are present.

    Builds a 2x2 figure:

    * top-left     -- Hamiltonian/potential energy traces (overall + channels)
    * top-right    -- histogram of leapfrog steps per trajectory
    * bottom-left  -- acceptance-probability traces (overall + channels)
    * bottom-right -- monospace text summary of steps/acceptance statistics

    Parameters
    ----------
    idata
        Inference data whose ``sample_stats`` holds the NUTS statistics.
    config
        Plot configuration; only ``config.figsize`` is read here.
    """
    # Presence of overall arrays
    has_energy = "energy" in idata.sample_stats
    has_potential = "potential_energy" in idata.sample_stats
    has_steps = "num_steps" in idata.sample_stats
    has_accept = "accept_prob" in idata.sample_stats

    # Collect per-channel data
    def _collect(base):
        # Map channel index -> flattened series for keys named
        # f"{base}_channel_{ch}"; keys with a non-integer suffix are skipped.
        out = {}
        prefix = f"{base}_channel_"
        for key in idata.sample_stats:
            if isinstance(key, str) and key.startswith(prefix):
                try:
                    ch = int(key.replace(prefix, ""))
                    out[ch] = idata.sample_stats[key].values.flatten()
                except Exception:
                    pass
        return out

    energy_ch = _collect("energy")
    potential_ch = _collect("potential_energy")
    steps_ch = _collect("num_steps")
    accept_ch = _collect("accept_prob")

    fig, axes = plt.subplots(2, 2, figsize=config.figsize)

    # Energy / potential
    ax = axes[0, 0]
    plotted = False
    if has_energy:
        ax.plot(
            idata.sample_stats.energy.values.flatten(),
            alpha=0.7,
            lw=1,
            label="H",
        )
        plotted = True
    if has_potential:
        ax.plot(
            idata.sample_stats.potential_energy.values.flatten(),
            alpha=0.7,
            lw=1,
            label="P",
        )
        plotted = True
    for ch in sorted(energy_ch):
        ax.plot(energy_ch[ch], alpha=0.5, lw=1, label=f"H ch {ch}")
        plotted = True
    for ch in sorted(potential_ch):
        ax.plot(potential_ch[ch], alpha=0.5, lw=1, label=f"P ch {ch}")
        plotted = True
    if not plotted:
        ax.text(0.5, 0.5, "Energy data\nunavailable", ha="center", va="center")
    ax.set_title("Energy Diagnostics")
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Energy")
    ax.grid(True, alpha=0.3)
    if plotted:
        ax.legend(loc="best", fontsize="small")

    # Steps histogram
    ax = axes[0, 1]
    if has_steps:
        vals = idata.sample_stats.num_steps.values.flatten()
    else:
        # No overall array: pool the per-channel step counts (may be empty).
        vals = (
            np.concatenate(list(steps_ch.values()))
            if steps_ch
            else np.array([])
        )
    if vals.size:
        ax.hist(vals, bins=20, alpha=0.7, edgecolor="black")
        ax.set_title("Leapfrog Steps Distribution")
        ax.set_xlabel("Steps")
        ax.set_ylabel("Trajectories")
        ax.grid(True, alpha=0.3)
    else:
        ax.text(0.5, 0.5, "Steps data\nunavailable", ha="center", va="center")

    # Acceptance (overlay per-channel)
    ax = axes[1, 0]
    plotted = False
    if has_accept:
        ax.plot(
            idata.sample_stats.accept_prob.values.flatten(),
            alpha=0.8,
            lw=1,
            label="overall",
        )
        plotted = True
    for ch in sorted(accept_ch):
        ax.plot(accept_ch[ch], alpha=0.6, lw=1, label=f"ch {ch}")
        plotted = True
    if not plotted:
        ax.text(
            0.5, 0.5, "Acceptance data\nunavailable", ha="center", va="center"
        )
    ax.set_title("Acceptance Trace")
    ax.set_xlabel("Iteration")
    ax.set_ylabel("accept_prob")
    ax.grid(True, alpha=0.3)
    if plotted:
        ax.legend(loc="best", fontsize="small")

    # Summary text
    ax = axes[1, 1]
    lines = []
    if has_steps or steps_ch:
        lines.append("Steps summary:")
        if has_steps:
            s = idata.sample_stats.num_steps.values.flatten()
            lines.append(f"  overall μ={np.mean(s):.1f}, max={np.max(s):.0f}")
        for ch in sorted(steps_ch):
            s = steps_ch[ch]
            lines.append(f"  ch {ch} μ={np.mean(s):.1f}, max={np.max(s):.0f}")
        lines.append("")
    if has_accept or accept_ch:
        lines.append("Acceptance summary:")
        if has_accept:
            a = idata.sample_stats.accept_prob.values.flatten()
            lines.append(f"  overall μ={np.mean(a):.3f}")
        for ch in sorted(accept_ch):
            a = accept_ch[ch]
            lines.append(f"  ch {ch} μ={np.mean(a):.3f}")
    if lines:
        ax.text(
            0.05,
            0.95,
            "\n".join(lines),
            transform=ax.transAxes,
            va="top",
            family="monospace",
        )
    ax.set_title("NUTS Diagnostics Summary")
    ax.axis("off")

    plt.tight_layout()
2✔
1025

1026

1027
def _plot_single_nuts_block(idata, config, channel_idx: int):
    """Render a 2x2 NUTS diagnostics figure for one blocked channel.

    Pulls the per-channel sample statistics (``*_channel_{channel_idx}``)
    out of ``idata.sample_stats`` and plots the energy traces, the
    acceptance trace, the leapfrog-steps histogram, and a text summary.
    Missing statistics are replaced by an "unavailable" placeholder.
    """

    def _fetch(base_name):
        # Flattened per-channel statistic, or None when it was not recorded.
        key = f"{base_name}_channel_{channel_idx}"
        if key in idata.sample_stats:
            return idata.sample_stats[key].values.flatten()
        return None

    energy = _fetch("energy")
    potential = _fetch("potential_energy")
    num_steps = _fetch("num_steps")
    accept_prob = _fetch("accept_prob")

    fig, axes = plt.subplots(2, 2, figsize=config.figsize)

    # Top-left: Hamiltonian / potential energy traces.
    energy_ax = axes[0, 0]
    have_any = False
    if energy is not None:
        energy_ax.plot(energy, alpha=0.7, lw=1, label="H")
        have_any = True
    if potential is not None:
        energy_ax.plot(potential, alpha=0.7, lw=1, label="P")
        have_any = True
    if not have_any:
        energy_ax.text(
            0.5, 0.5, "Energy data\nunavailable", ha="center", va="center"
        )
    energy_ax.set_title(f"Channel {channel_idx} Energy")
    energy_ax.set_xlabel("Iteration")
    energy_ax.set_ylabel("Energy")
    energy_ax.grid(True, alpha=0.3)
    if have_any:
        energy_ax.legend(loc="best", fontsize="small")

    # Top-right: acceptance-probability trace with guidance bands
    # (green = healthy, red = too low, orange = too high).
    accept_ax = axes[0, 1]
    if accept_prob is None:
        accept_ax.text(
            0.5, 0.5, "Acceptance data\nunavailable", ha="center", va="center"
        )
    else:
        accept_ax.axhspan(0.7, 0.9, alpha=0.1, color="green")
        accept_ax.axhspan(0.0, 0.6, alpha=0.1, color="red")
        accept_ax.axhspan(0.9, 1.0, alpha=0.1, color="orange")
        accept_ax.plot(accept_prob, alpha=0.8, lw=1, color="purple")
        accept_ax.axhline(
            0.8, color="red", linestyle="--", lw=1.5, label="target"
        )
        accept_ax.set_ylim(0, 1)
        accept_ax.legend(loc="best", fontsize="small")
        accept_ax.grid(True, alpha=0.3)
    accept_ax.set_title(f"Channel {channel_idx} Acceptance")
    accept_ax.set_xlabel("Iteration")
    accept_ax.set_ylabel("accept_prob")

    # Bottom-left: leapfrog-steps histogram.
    steps_ax = axes[1, 0]
    if num_steps is not None and num_steps.size:
        steps_ax.hist(num_steps, bins=20, alpha=0.7, edgecolor="black")
        steps_ax.set_xlabel("Steps")
        steps_ax.set_ylabel("Trajectories")
        steps_ax.grid(True, alpha=0.3)
    else:
        steps_ax.text(
            0.5, 0.5, "Steps data\nunavailable", ha="center", va="center"
        )
    steps_ax.set_title(f"Channel {channel_idx} Leapfrog Steps")

    # Bottom-right: monospace text panel summarising whatever was recorded.
    summary_ax = axes[1, 1]
    summary = [f"Channel {channel_idx} summary:"]
    if energy is not None:
        summary.append(f"  H μ={np.mean(energy):.2f}, σ={np.std(energy):.2f}")
    if potential is not None:
        summary.append(
            f"  P μ={np.mean(potential):.2f}, σ={np.std(potential):.2f}"
        )
    if num_steps is not None:
        summary.append(
            f"  steps μ={np.mean(num_steps):.1f}, max={np.max(num_steps):.0f}"
        )
    if accept_prob is not None:
        summary.append(f"  accept μ={np.mean(accept_prob):.3f}")

    summary_ax.text(
        0.05,
        0.95,
        "\n".join(summary),
        transform=summary_ax.transAxes,
        va="top",
        family="monospace",
    )
    summary_ax.axis("off")
    summary_ax.set_title("Summary")

    plt.tight_layout()
2✔
1123

1124

1125
def _create_sampler_diagnostics(idata, diag_dir, config):
    """Create sampler-specific diagnostic figures under ``diag_dir``.

    Detects whether the run used NUTS or Metropolis-Hastings from the
    sample statistics (and the ``sampler_type`` attribute) and emits the
    matching diagnostic plots via ``safe_plot``.
    """
    # Prefer the explicit sampler_type attribute; fall back to "unknown".
    if "sampler_type" in idata.attrs:
        sampler_type = idata.attrs["sampler_type"].lower()
    else:
        sampler_type = "unknown"

    # Fields only a NUTS sampler records -- MH definitely lacks these.
    nuts_specific_fields = [
        "energy",
        "num_steps",
        "tree_depth",
        "diverging",
        "energy_error",
    ]

    has_nuts = "nuts" in sampler_type or any(
        field in idata.sample_stats for field in nuts_specific_fields
    )
    # MH detection is only trusted when no NUTS markers are present.
    has_mh = not has_nuts and "step_size_mean" in idata.sample_stats

    if has_nuts:

        @safe_plot(f"{diag_dir}/nuts_diagnostics.png", config.dpi)
        def plot_nuts():
            _plot_nuts_diagnostics_blockaware(idata, config)

        plot_nuts()

        # Per-channel NUTS diagnostics for blocked samplers: union of the
        # channel indices seen across all per-channel statistic names.
        channel_indices = set()
        for base_key in (
            "accept_prob",
            "energy",
            "potential_energy",
            "num_steps",
        ):
            channel_indices |= _get_channel_indices(
                idata.sample_stats, base_key
            )

        for channel_idx in sorted(channel_indices):

            @safe_plot(
                f"{diag_dir}/nuts_block_{channel_idx}_diagnostics.png",
                config.dpi,
            )
            def plot_nuts_block(channel_idx=channel_idx):
                # Default-arg binding pins the loop's current channel_idx.
                _plot_single_nuts_block(idata, config, channel_idx)

            plot_nuts_block()
    elif has_mh:

        @safe_plot(f"{diag_dir}/mh_step_sizes.png", config.dpi)
        def plot_mh():
            _plot_mh_step_sizes(idata, config)

        plot_mh()
2✔
1189

1190

1191
def _plot_nuts_diagnostics(idata, config):
    """NUTS diagnostics with enhanced information.

    Builds a single 2x2 figure from ``idata.sample_stats``: energy traces
    (Hamiltonian/potential, plus their difference on a twin axis when both
    exist), a leapfrog-steps histogram with efficiency zones, the
    acceptance-probability trace with guidance bands, and a text panel of
    summary statistics (tree depth, divergences, energy error when
    recorded). Each quadrant degrades to an "unavailable" placeholder
    when its data is missing.
    """
    # Determine available data to decide layout
    has_energy = "energy" in idata.sample_stats
    has_potential = "potential_energy" in idata.sample_stats
    has_steps = "num_steps" in idata.sample_stats
    has_accept = "accept_prob" in idata.sample_stats
    has_divergences = "diverging" in idata.sample_stats
    has_tree_depth = "tree_depth" in idata.sample_stats
    has_energy_error = "energy_error" in idata.sample_stats

    # Create a 2x2 layout, potentially combining energy and potential on same plot
    fig, axes = plt.subplots(2, 2, figsize=config.figsize)

    # Top-left: Energy diagnostics (combine Hamiltonian and Potential if both available)
    energy_ax = axes[0, 0]

    if has_energy and has_potential:
        # Both available - plot them together on one plot
        energy = idata.sample_stats.energy.values.flatten()
        potential = idata.sample_stats.potential_energy.values.flatten()

        # Plot both energies on same axis
        energy_ax.plot(
            energy, alpha=0.7, linewidth=1, color="blue", label="Hamiltonian"
        )
        energy_ax.plot(
            potential,
            alpha=0.7,
            linewidth=1,
            color="orange",
            label="Potential",
        )

        # Add difference (which relates to kinetic energy)
        energy_diff = energy - potential
        # Create second y-axis for difference
        ax2 = energy_ax.twinx()
        ax2.plot(
            energy_diff,
            alpha=0.5,
            linewidth=1,
            color="red",
            label="H - Potential (Kinetic)",
            linestyle="--",
        )
        ax2.set_ylabel("Energy Difference", color="red")
        ax2.tick_params(axis="y", labelcolor="red")

        energy_ax.set_xlabel("Iteration")
        energy_ax.set_ylabel("Energy", color="blue")
        energy_ax.tick_params(axis="y", labelcolor="blue")
        energy_ax.set_title("Hamiltonian & Potential Energy")
        energy_ax.legend(loc="best", fontsize="small")
        energy_ax.grid(True, alpha=0.3)

        # Add statistics
        energy_ax.text(
            0.02,
            0.98,
            f"H: μ={np.mean(energy):.1f}, σ={np.std(energy):.1f}\nP: μ={np.mean(potential):.1f}, σ={np.std(potential):.1f}",
            transform=energy_ax.transAxes,
            fontsize=8,
            bbox=dict(boxstyle="round", facecolor="lightblue", alpha=0.8),
            verticalalignment="top",
        )

    elif has_energy:
        # Only Hamiltonian energy
        energy = idata.sample_stats.energy.values.flatten()
        energy_ax.plot(energy, alpha=0.7, linewidth=1, color="blue")
        energy_ax.set_xlabel("Iteration")
        energy_ax.set_ylabel("Hamiltonian Energy")
        energy_ax.set_title("Hamiltonian Energy Trace")
        energy_ax.grid(True, alpha=0.3)

    elif has_potential:
        # Only potential energy
        potential = idata.sample_stats.potential_energy.values.flatten()
        energy_ax.plot(potential, alpha=0.7, linewidth=1, color="orange")
        energy_ax.set_xlabel("Iteration")
        energy_ax.set_ylabel("Potential Energy")
        energy_ax.set_title("Potential Energy Trace")
        energy_ax.grid(True, alpha=0.3)

    else:
        energy_ax.text(
            0.5,
            0.5,
            "Energy data\nunavailable",
            ha="center",
            va="center",
            transform=energy_ax.transAxes,
        )
        energy_ax.set_title("Energy Diagnostics")

    # Top-right: Sampling efficiency diagnostics
    if has_steps:
        steps_ax = axes[0, 1]
        num_steps = idata.sample_stats.num_steps.values.flatten()

        # Show histogram with color zones for step efficiency
        # NOTE(review): hist returns (counts, bin_edges, patches); the
        # names below are misleading ("edges" is the patches) but all
        # three are unused.
        n, bins, edges = steps_ax.hist(
            num_steps, bins=20, alpha=0.7, edgecolor="black"
        )

        # Add shaded regions for different efficiency levels
        # Green: efficient (tree depth ≤5, ~32 steps)
        # Yellow: moderate (tree depth 6-8, ~64-256 steps)
        # Red: inefficient (tree depth >8, >256 steps)
        steps_ax.axvspan(
            0, 64, alpha=0.1, color="green", label="Efficient (≤64)"
        )
        steps_ax.axvspan(
            64, 256, alpha=0.1, color="yellow", label="Moderate (65-256)"
        )
        steps_ax.axvspan(
            256,
            np.max(num_steps),
            alpha=0.1,
            color="red",
            label="Inefficient (>256)",
        )

        # Add reference lines for different tree depths
        for depth in [5, 7, 10]:  # Common tree depths
            max_steps = 2**depth
            steps_ax.axvline(
                x=max_steps,
                color="gray",
                linestyle=":",
                alpha=0.7,
                linewidth=1,
                label=f"2^{depth} ({max_steps})",
            )

        steps_ax.set_xlabel("Leapfrog Steps")
        steps_ax.set_ylabel("Trajectories")
        steps_ax.set_title("Leapfrog Steps Distribution")
        steps_ax.legend(loc="best", fontsize="small")
        steps_ax.grid(True, alpha=0.3)

        # Add efficiency statistics
        pct_inefficient = (num_steps > 256).mean() * 100
        pct_moderate = ((num_steps > 64) & (num_steps <= 256)).mean() * 100
        pct_efficient = (num_steps <= 64).mean() * 100
        steps_ax.text(
            0.02,
            0.98,
            f"Efficient: {pct_efficient:.1f}%\nModerate: {pct_moderate:.1f}%\nInefficient: {pct_inefficient:.1f}%\nMean steps: {np.mean(num_steps):.1f}",
            transform=steps_ax.transAxes,
            fontsize=7,
            bbox=dict(boxstyle="round", facecolor="lightblue", alpha=0.8),
            verticalalignment="top",
        )

    else:
        axes[0, 1].text(
            0.5, 0.5, "Steps data\nunavailable", ha="center", va="center"
        )
        axes[0, 1].set_title("Sampling Steps")

    # Bottom-left: Acceptance (and NUTS divergence) diagnostics
    accept_ax = axes[1, 0]

    if has_accept:
        accept_prob = idata.sample_stats.accept_prob.values.flatten()

        # Plot acceptance probability with guidance zones
        accept_ax.fill_between(
            range(len(accept_prob)),
            0.7,
            0.9,
            alpha=0.1,
            color="green",
            label="Good (0.7-0.9)",
        )
        accept_ax.fill_between(
            range(len(accept_prob)),
            0,
            0.6,
            alpha=0.1,
            color="red",
            label="Too low",
        )
        accept_ax.fill_between(
            range(len(accept_prob)),
            0.9,
            1.0,
            alpha=0.1,
            color="orange",
            label="Too high",
        )

        accept_ax.plot(
            accept_prob,
            alpha=0.8,
            linewidth=1,
            color="blue",
            label="Acceptance prob",
        )
        accept_ax.axhline(
            0.8,
            color="red",
            linestyle="--",
            linewidth=2,
            label="NUTS target (0.8)",
        )
        accept_ax.set_xlabel("Iteration")
        accept_ax.set_ylabel("Acceptance Probability")
        accept_ax.set_title("NUTS Acceptance Diagnostic")
        accept_ax.legend(loc="best", fontsize="small")
        accept_ax.set_ylim(0, 1)
        accept_ax.grid(True, alpha=0.3)

    else:
        accept_ax.text(
            0.5, 0.5, "Acceptance data\nunavailable", ha="center", va="center"
        )
        accept_ax.set_title("Acceptance Diagnostic")

    # Bottom-right: Summary statistics and additional diagnostics
    summary_ax = axes[1, 1]

    # Collect available statistics
    stats_lines = []

    if has_energy:
        energy = idata.sample_stats.energy.values.flatten()
        stats_lines.append(
            f"Energy: μ={np.mean(energy):.1f}, σ={np.std(energy):.1f}"
        )

    if has_potential:
        potential = idata.sample_stats.potential_energy.values.flatten()
        stats_lines.append(
            f"Potential: μ={np.mean(potential):.1f}, σ={np.std(potential):.1f}"
        )

    if has_steps:
        num_steps = idata.sample_stats.num_steps.values.flatten()
        stats_lines.append(
            f"Steps: μ={np.mean(num_steps):.1f}, max={np.max(num_steps):.0f}"
        )
        stats_lines.append("")

    if has_tree_depth:
        tree_depth = idata.sample_stats.tree_depth.values.flatten()
        stats_lines.append(f"Tree depth: μ={np.mean(tree_depth):.1f}")
        pct_max_depth = (tree_depth >= 10).mean() * 100
        stats_lines.append(f"Max depth (≥10): {pct_max_depth:.1f}%")

    if has_divergences:
        divergences = idata.sample_stats.diverging.values.flatten()
        n_divergences = np.sum(divergences)
        pct_divergent = n_divergences / len(divergences) * 100
        stats_lines.append(
            f"Divergent: {n_divergences}/{len(divergences)} ({pct_divergent:.2f}%)"
        )

    if has_energy_error:
        energy_error = idata.sample_stats.energy_error.values.flatten()
        stats_lines.append(
            f"Energy error: |μ|={np.mean(np.abs(energy_error)):.3f}"
        )

    if not stats_lines:
        summary_ax.text(
            0.5,
            0.5,
            "No diagnostics\ndata available",
            ha="center",
            va="center",
            transform=summary_ax.transAxes,
        )
        summary_ax.set_title("NUTS Statistics")
        summary_ax.axis("off")
    else:
        summary_text = "\n".join(["NUTS Diagnostics:"] + [""] + stats_lines)
        summary_ax.text(
            0.05,
            0.95,
            summary_text,
            transform=summary_ax.transAxes,
            fontsize=10,
            verticalalignment="top",
            fontfamily="monospace",
            bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.9),
        )
        summary_ax.set_title("NUTS Summary Statistics")
        summary_ax.axis("off")

    plt.tight_layout()
×
1484

1485

1486
def _plot_mh_step_sizes(idata, config):
    """Plot Metropolis-Hastings step-size adaptation diagnostics.

    Shows the evolution and distribution of the adaptive step-size mean
    and standard deviation, their ratio (a consistency check), and a
    monospace text panel summarising the adaptation quality.
    """
    fig, axes = plt.subplots(2, 2, figsize=config.figsize)

    step_means = idata.sample_stats.step_size_mean.values.flatten()
    step_stds = idata.sample_stats.step_size_std.values.flatten()

    # Top-left: step-size evolution over iterations.
    evo_ax = axes[0, 0]
    evo_ax.plot(step_means, alpha=0.7, linewidth=1, label="Mean", color="blue")
    evo_ax.plot(step_stds, alpha=0.7, linewidth=1, label="Std", color="orange")
    evo_ax.set_xlabel("Iteration")
    evo_ax.set_ylabel("Step Size")
    evo_ax.set_title("Step Size Evolution")
    evo_ax.legend()
    evo_ax.grid(True, alpha=0.3)

    # Top-right: marginal distributions of mean and std.
    hist_ax = axes[0, 1]
    hist_ax.hist(step_means, bins=30, alpha=0.5, label="Mean", color="blue")
    hist_ax.hist(step_stds, bins=30, alpha=0.5, label="Std", color="orange")
    hist_ax.set_xlabel("Step Size")
    hist_ax.set_ylabel("Count")
    hist_ax.set_title("Step Size Distributions")
    hist_ax.legend()
    hist_ax.grid(True, alpha=0.3)

    # Bottom-left: mean/std ratio as a proxy for adaptation stability.
    ratio_ax = axes[1, 0]
    ratio_ax.plot(step_means / step_stds, alpha=0.7, linewidth=1)
    ratio_ax.set_xlabel("Iteration")
    ratio_ax.set_ylabel("Mean / Std")
    ratio_ax.set_title("Step Size Consistency")
    ratio_ax.grid(True, alpha=0.3)

    # Bottom-right: textual summary (final values + coefficients of variation).
    summary_lines = [
        "Step Size Summary:",
        f"Final mean: {step_means[-1]:.4f}",
        f"Final std: {step_stds[-1]:.4f}",
        f"Mean of means: {np.mean(step_means):.4f}",
        f"Mean of stds: {np.mean(step_stds):.4f}",
        "",
        "Adaptation Quality:",
        f"CV of means: {np.std(step_means)/np.mean(step_means):.3f}",
        f"CV of stds: {np.std(step_stds)/np.mean(step_stds):.3f}",
    ]
    text_ax = axes[1, 1]
    text_ax.text(
        0.05,
        0.95,
        "\n".join(summary_lines),
        transform=text_ax.transAxes,
        fontsize=10,
        verticalalignment="top",
        fontfamily="monospace",
    )
    text_ax.set_title("Step Size Statistics")
    text_ax.axis("off")

    plt.tight_layout()
2✔
1548

1549

1550
def _create_divergences_diagnostics(idata, diag_dir, config):
    """Create divergences diagnostics for NUTS samplers.

    Saves ``divergences.png`` under ``diag_dir`` (via ``safe_plot``) when
    either an overall ``diverging`` array or any per-channel
    ``diverging_channel_{j}`` arrays exist in ``idata.sample_stats``;
    otherwise does nothing.
    """
    # Check if divergences data exists (overall and/or blocked per-channel).
    has_divergences = "diverging" in idata.sample_stats
    # Guard on str keys for consistency with _get_channel_indices --
    # non-string keys would make startswith raise.
    has_channel_divergences = any(
        isinstance(key, str) and key.startswith("diverging_channel_")
        for key in idata.sample_stats
    )

    if not has_divergences and not has_channel_divergences:
        return  # Nothing to plot

    @safe_plot(f"{diag_dir}/divergences.png", config.dpi)
    def plot_divergences():
        _plot_divergences(idata, config)

    plot_divergences()
2✔
1566

1567

1568
def _plot_divergences(idata, config):
    """Plot divergences diagnostics.

    Draws one trace per divergence series ("main" and/or per-channel for
    blocked NUTS) marking where divergent transitions occurred, plus a
    text summary panel when the grid layout leaves room for one.
    """
    # Collect all divergence data
    divergences_data = {}

    # Check for main divergences (single chain NUTS)
    if "diverging" in idata.sample_stats:
        divergences_data["main"] = (
            idata.sample_stats.diverging.values.flatten()
        )

    # Check for channel-specific divergences (blocked NUTS)
    channel_divergences = {}
    for key in idata.sample_stats:
        if isinstance(key, str) and key.startswith("diverging_channel_"):
            channel_idx = key.replace("diverging_channel_", "")
            channel_divergences[int(channel_idx)] = idata.sample_stats[
                key
            ].values.flatten()

    if channel_divergences:
        divergences_data.update(channel_divergences)

    if not divergences_data:
        fig, ax = plt.subplots(figsize=config.figsize)
        ax.text(
            0.5, 0.5, "No divergence data available", ha="center", va="center"
        )
        ax.set_title("Divergences Diagnostics")
        return

    # Create subplot layout
    n_plots = len(divergences_data)
    if n_plots == 1:
        fig, axes = plt.subplots(1, 2, figsize=config.figsize)
        trace_ax, summary_ax = axes
    else:
        # Multiple plots - arrange in grid
        cols = 2
        rows = (n_plots + 1) // cols  # Ceiling division
        fig, axes = plt.subplots(rows, cols, figsize=config.figsize)
        if rows == 1:
            axes = axes.reshape(1, -1)
        axes = axes.flatten()

        # Last plot goes in summary_ax if odd number
        if n_plots % 2 == 1:
            trace_axes = axes[:-1]
            summary_ax = axes[-1]
        else:
            trace_axes = axes
            summary_ax = None

    # Plot divergences traces
    total_divergences = 0
    total_iterations = 0

    plot_idx = 0
    for label, div_values in divergences_data.items():
        if label == "main":
            title = "NUTS Divergences"
        else:
            title = f"Channel {label} Divergences"
        ax = trace_axes[plot_idx] if n_plots > 1 else axes[0]
        # BUG FIX: advance the axis index for *every* series. Previously
        # the "main" branch did not increment plot_idx, so the first
        # channel trace reused axis 0 and overwrote the main trace.
        plot_idx += 1

        # Plot divergence indicators (where divergences occur)
        div_indices = np.where(div_values)[0]
        ax.scatter(
            div_indices,
            np.ones_like(div_indices),
            color="red",
            marker="x",
            s=50,
            linewidth=2,
            label="Divergent",
            alpha=0.8,
        )

        # Add background shading for divergent regions
        if len(div_indices) > 0:
            for idx in div_indices:
                ax.axvspan(idx - 0.5, idx + 0.5, alpha=0.2, color="red")

        ax.set_xlabel("Iteration")
        ax.set_ylabel("Divergence Indicator")
        ax.set_title(title)
        ax.set_yticks([0, 1])
        ax.set_yticklabels(["No", "Yes"])
        ax.grid(True, alpha=0.3)

        # Add statistics
        n_divergent = np.sum(div_values)
        pct_divergent = n_divergent / len(div_values) * 100
        stats_text = f"{n_divergent}/{len(div_values)} ({pct_divergent:.2f}%)"
        ax.text(
            0.02,
            0.98,
            stats_text,
            transform=ax.transAxes,
            fontsize=10,
            bbox=dict(boxstyle="round", facecolor="lightcoral", alpha=0.8),
            verticalalignment="top",
        )

        total_divergences += n_divergent
        total_iterations += len(div_values)

        # Legend only if there are divergences
        if n_divergent > 0:
            ax.legend(loc="upper right", fontsize="small")

    # Summary plot (only when the grid has a spare axis, or single-series
    # layout where the second column is reserved for the summary).
    if summary_ax is not None and n_plots > 1:
        summary_ax.text(
            0.05,
            0.95,
            _get_divergences_summary(divergences_data),
            transform=summary_ax.transAxes,
            fontsize=12,
            verticalalignment="top",
            fontfamily="monospace",
            bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.9),
        )
        summary_ax.set_title("Divergences Summary")
        summary_ax.axis("off")
    elif n_plots == 1:
        axes[1].text(
            0.05,
            0.95,
            _get_divergences_summary(divergences_data),
            transform=axes[1].transAxes,
            fontsize=12,
            verticalalignment="top",
            fontfamily="monospace",
            bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.9),
        )
        axes[1].set_title("Divergences Summary")
        axes[1].axis("off")

    # Overall title
    overall_pct = (
        total_divergences / total_iterations * 100
        if total_iterations > 0
        else 0
    )
    fig.suptitle(f"Overall Divergences: {overall_pct:.2f}%")

    plt.tight_layout()
2✔
1718

1719

1720
def _get_divergences_summary(divergences_data):
2✔
1721
    """Generate text summary of divergences."""
1722
    lines = ["Divergences Summary:", ""]
2✔
1723

1724
    total_divergences = 0
2✔
1725
    total_iterations = 0
2✔
1726

1727
    for label, div_values in divergences_data.items():
2✔
1728
        n_divergent = np.sum(div_values)
2✔
1729
        pct_divergent = n_divergent / len(div_values) * 100
2✔
1730

1731
        if label == "main":
2✔
1732
            lines.append(
2✔
1733
                f"NUTS: {n_divergent}/{len(div_values)} ({pct_divergent:.2f}%)"
1734
            )
1735
        else:
1736
            lines.append(
×
1737
                f"Channel {label}: {n_divergent}/{len(div_values)} ({pct_divergent:.2f}%)"
1738
            )
1739

1740
        total_divergences += n_divergent
2✔
1741
        total_iterations += len(div_values)
2✔
1742

1743
    lines.append("")
2✔
1744
    overall_pct = (
2✔
1745
        total_divergences / total_iterations * 100
1746
        if total_iterations > 0
1747
        else 0
1748
    )
1749
    lines.append(
2✔
1750
        f"Total: {total_divergences}/{total_iterations} ({overall_pct:.2f}%)"
1751
    )
1752

1753
    lines.append("")
2✔
1754
    lines.append("Interpretation:")
2✔
1755
    if overall_pct == 0:
2✔
1756
        lines.append("  ✓ No divergences detected")
2✔
1757
        lines.append("    Sampling appears well-behaved")
2✔
1758
    elif overall_pct < 0.1:
2✔
1759
        lines.append("  ~ Few divergences")
×
1760
        lines.append("    Generally good, but monitor")
×
1761
    elif overall_pct < 1.0:
2✔
1762
        lines.append("  ⚠ Some divergences detected")
×
1763
        lines.append("    May indicate sampling issues")
×
1764
    else:
1765
        lines.append("  ✗ Many divergences!")
2✔
1766
        lines.append("    Significant sampling problems")
2✔
1767
        lines.append("    Consider model reparameterization")
2✔
1768

1769
    return "\n".join(lines)
2✔
1770

1771

1772
def _group_parameters_simple(idata):
2✔
1773
    """Simple parameter grouping for counting."""
1774
    param_groups = {"phi": [], "delta": [], "weights": [], "other": []}
2✔
1775

1776
    for param in idata.posterior.data_vars:
2✔
1777
        if param.startswith("phi"):
2✔
1778
            param_groups["phi"].append(param)
2✔
1779
        elif param.startswith("delta"):
2✔
1780
            param_groups["delta"].append(param)
2✔
1781
        elif param.startswith("weights"):
2✔
1782
            param_groups["weights"].append(param)
2✔
1783
        else:
1784
            param_groups["other"].append(param)
×
1785

1786
    return {k: v for k, v in param_groups.items() if v}
2✔
1787

1788

1789
def generate_diagnostics_summary(idata, outdir):
    """Generate a comprehensive text summary of MCMC diagnostics.

    Summarises sampler configuration, ESS, Rhat, acceptance rates and
    (when a true PSD was recorded in the attrs) PSD accuracy metrics.
    The text is written to ``<outdir>/diagnostics_summary.txt`` (when
    ``outdir`` is truthy), logged, and returned.

    Parameters
    ----------
    idata : arviz.InferenceData
        Posterior samples with ``sample_stats`` and optional diagnostic
        attrs (``ess``, ``sampler_type``, ``riae``, ``coverage``, ...).
    outdir : str or None
        Output directory for the summary file; skipped when falsy.

    Returns
    -------
    str
        The assembled summary text.
    """
    summary = []
    summary.append("=== MCMC Diagnostics Summary ===\n")

    ess_values = None  # kept for the overall assessment at the end
    rhat_good = None  # % of params with Rhat <= 1.01 (None if unknown)

    # Basic info -- normalise attrs to a mapping that supports .get()
    attrs = getattr(idata, "attrs", {}) or {}
    if not hasattr(attrs, "get"):
        attrs = dict(attrs)

    n_samples = idata.posterior.sizes.get("draw", 0)
    n_chains = idata.posterior.sizes.get("chain", 1)
    n_params = len(list(idata.posterior.data_vars))
    sampler_type = attrs.get("sampler_type", "Unknown")

    summary.append(f"Sampler: {sampler_type}")
    summary.append(
        f"Samples: {n_samples} per chain × {n_chains} chains = {n_samples * n_chains} total"
    )
    summary.append(f"Parameters: {n_params}")

    # Parameter breakdown by group (phi / delta / weights / other)
    param_groups = _group_parameters_simple(idata)
    if param_groups:
        param_summary = ", ".join(
            f"{k}: {len(v)}" for k, v in param_groups.items()
        )
        summary.append(f"Parameter groups: {param_summary}")

    # ESS (precomputed array stored in attrs; use the normalised `attrs`
    # local rather than idata.attrs so dict-like containers also work)
    try:
        ess = attrs.get("ess")
        ess_values = ess[~np.isnan(ess)]

        if len(ess_values) > 0:
            summary.append(
                f"\nESS: min={ess_values.min():.0f}, mean={ess_values.mean():.0f}, max={ess_values.max():.0f}"
            )
            summary.append(f"ESS ≥ 400: {(ess_values >= 400).mean()*100:.1f}%")
    except Exception:
        # Missing or non-array "ess" attr -- report as unavailable.
        summary.append("\nESS: unavailable")

    # Rhat (requires >= 2 chains)
    try:
        rhat = az.rhat(idata)
        rhat_vals = np.asarray(rhat.to_array()).ravel()
        rhat_vals = rhat_vals[np.isfinite(rhat_vals)]
        if rhat_vals.size:
            summary.append(
                f"Rhat: min={rhat_vals.min():.3f}, mean={rhat_vals.mean():.3f}, max={rhat_vals.max():.3f}"
            )
            rhat_good = (rhat_vals <= 1.01).mean() * 100
            summary.append(f"Rhat ≤ 1.01: {rhat_good:.1f}%")
        else:
            summary.append("Rhat: unavailable (needs ≥2 chains)")
    except Exception:
        summary.append("Rhat: unavailable")

    # Acceptance rate: single-key samplers first, else blocked per-channel
    accept_key = None
    if "accept_prob" in idata.sample_stats:
        accept_key = "accept_prob"
    elif "acceptance_rate" in idata.sample_stats:
        accept_key = "acceptance_rate"

    if accept_key is not None:
        accept_rate = idata.sample_stats[accept_key].values.mean()
        target_rate = attrs.get(
            "target_accept_rate", attrs.get("target_accept_prob", 0.44)
        )
        summary.append(
            f"Acceptance rate: {accept_rate:.3f} (target: {target_rate:.3f})"
        )
    else:
        # Blocked NUTS: compute a combined mean from per-channel keys if present
        channel_means = []
        for key in idata.sample_stats:
            if isinstance(key, str) and key.startswith("accept_prob_channel_"):
                try:
                    channel_means.append(
                        float(idata.sample_stats[key].values.mean())
                    )
                except Exception:
                    # Best-effort: skip channels with malformed stats.
                    pass
        if channel_means:
            target_rate = attrs.get(
                "target_accept_rate", attrs.get("target_accept_prob", 0.8)
            )
            summary.append(
                f"Acceptance rate (per-channel mean): {np.mean(channel_means):.3f} (target: {target_rate:.3f})"
            )

    # PSD accuracy diagnostics (requires true_psd in attrs)
    has_true_psd = "true_psd" in attrs

    if has_true_psd:
        coverage_level = attrs.get("coverage_level")
        coverage_label = (
            f"{int(round(coverage_level * 100))}% interval coverage"
            if coverage_level is not None
            else "Interval coverage"
        )

        def _format_riae_line(value, errorbars, prefix="  "):
            # Append one RIAE line, with median/5-95% quantiles if known.
            line = f"{prefix}RIAE: {value:.3f}"
            if errorbars:
                q05, _, median, _, q95 = errorbars
                line += f" (median {median:.3f}, 5-95% [{q05:.3f}, {q95:.3f}])"
            summary.append(line)

        def _format_coverage_line(value, prefix="  "):
            # Append an interval-coverage line; skip silently if missing.
            if value is None:
                return
            summary.append(f"{prefix}{coverage_label}: {value * 100:.1f}%")

        summary.append("\nPSD accuracy diagnostics:")

        if "riae" in attrs:
            _format_riae_line(attrs["riae"], attrs.get("riae_errorbars"))
        if "coverage" in attrs:
            _format_coverage_line(attrs["coverage"])

        # Per-channel metrics are stored under riae_ch<N> / coverage_ch<N>.
        channel_indices = sorted(
            int(key.replace("riae_ch", ""))
            for key in attrs.keys()
            if key.startswith("riae_ch")
        )

        for idx in channel_indices:
            metrics = []
            riae_key = f"riae_ch{idx}"
            cov_key = f"coverage_ch{idx}"
            error_key = f"riae_errorbars_ch{idx}"

            if riae_key in attrs:
                riae_line = f"RIAE {attrs[riae_key]:.3f}"
                errorbars = attrs.get(error_key)
                if errorbars:
                    q05, _, median, _, q95 = errorbars
                    riae_line += (
                        f" (median {median:.3f}, 5-95% [{q05:.3f}, {q95:.3f}])"
                    )
                metrics.append(riae_line)

            if cov_key in attrs:
                metrics.append(f"{coverage_label} {attrs[cov_key] * 100:.1f}%")

            if metrics:
                summary.append(f"  Channel {idx}: " + "; ".join(metrics))

    # Overall assessment (best-effort; never let it break the summary)
    try:
        ess_good = (
            (ess_values >= 400).mean() * 100
            if ess_values is not None
            else None
        )
        summary.append("\nOverall Convergence Assessment:")
        if ess_good is None:
            summary.append("  Status: UNKNOWN (insufficient diagnostics)")
        else:
            meets_rhat = (
                rhat_good is None or rhat_good >= 90
            )  # treat missing rhat as neutral
            if ess_good >= 90 and meets_rhat:
                summary.append("  Status: EXCELLENT ✓")
            elif ess_good >= 75 and meets_rhat:
                summary.append("  Status: GOOD ✓")
            else:
                summary.append("  Status: NEEDS ATTENTION ⚠")
    except Exception:
        # Narrowed from a bare `except:` so Ctrl-C / SystemExit propagate.
        pass

    summary_text = "\n".join(summary)

    if outdir:
        # Explicit UTF-8: the summary contains non-ASCII glyphs (✓, ⚠, ≥)
        # which would fail under a cp1252 locale default.
        path = os.path.join(outdir, "diagnostics_summary.txt")
        with open(path, "w", encoding="utf-8") as f:
            f.write(summary_text)

    logger.info(f"\n{summary_text}\n")
    return summary_text
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc