• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

nz-gravity / LogPSplinePSD / 19384398576

15 Nov 2025 04:24AM UTC coverage: 79.66% (-0.1%) from 79.76%
19384398576

push

github

avivajpeyi
fix scaling and one sided psd

727 of 864 branches covered (84.14%)

Branch coverage included in aggregate %.

42 of 47 new or added lines in 6 files covered. (89.36%)

6 existing lines in 2 files now uncovered.

4662 of 5901 relevant lines covered (79.0%)

1.58 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

70.1
/src/log_psplines/plotting/diagnostics.py
1
import os
2✔
2
from dataclasses import dataclass
2✔
3
from typing import Optional
2✔
4

5
import arviz as az
2✔
6
import matplotlib.pyplot as plt
2✔
7
import numpy as np
2✔
8

9
from ..logger import logger
2✔
10
from .base import PlotConfig, safe_plot, setup_plot_style
2✔
11

12
# Setup consistent styling for diagnostics plots
13
setup_plot_style()
2✔
14

15

16
@dataclass
class DiagnosticsConfig:
    """Configuration for diagnostics plotting parameters."""

    # (width, height) in inches for full diagnostic figures.
    figsize: tuple = (12, 8)
    # Resolution passed to safe_plot when saving figures.
    dpi: int = 150
    # Minimum effective sample size considered reliable (used by the
    # ESS histogram and convergence-status panels).
    ess_threshold: int = 400
    # R-hat convergence threshold.  NOTE(review): not referenced in this
    # chunk of the file — confirm it is used elsewhere before removing.
    rhat_threshold: float = 1.01
    # Font sizes; NOTE(review): fontsize/labelsize/titlesize are not read
    # by the functions visible here — presumably consumed by other plots.
    fontsize: int = 11
    labelsize: int = 12
    titlesize: int = 12
27

28

29
def plot_trace(idata: az.InferenceData, compact=True) -> plt.Figure:
    """Plot MCMC traces and marginal densities for parameter groups.

    Variables in ``idata.posterior`` are grouped by name prefix
    ("delta", "phi", "weights").  Each group gets a (trace, density)
    pair of axes; in compact mode all variables of a group share the
    same pair.

    Parameters
    ----------
    idata : az.InferenceData
        Posterior samples.
    compact : bool
        If True (default), overlay all variables of a group on one row.

    Returns
    -------
    plt.Figure
    """
    groups = {
        "delta": [
            v for v in idata.posterior.data_vars if v.startswith("delta")
        ],
        "phi": [v for v in idata.posterior.data_vars if v.startswith("phi")],
        "weights": [
            v for v in idata.posterior.data_vars if v.startswith("weights")
        ],
    }

    nrows = 3 if compact else len(groups)
    fig, axes = plt.subplots(nrows, 2, figsize=(7, 3 * nrows))

    for row, (group_name, var_names) in enumerate(groups.items()):
        # Guard: an empty group would make the axes indexing below fail
        # (np.repeat with length 0 yields a (0, 2) array).
        if not var_names:
            continue

        # In compact mode every variable of the group reuses the same
        # (trace, hist) axes pair, so repeat the row per variable.
        # BUGFIX: the non-compact branch previously kept a 1-D slice but
        # was then indexed with two subscripts; reshape to (1, 2) so both
        # branches share the same indexing scheme.
        if compact:
            group_axes = axes[row, :].reshape(1, 2)
            group_axes = np.repeat(group_axes, len(var_names), axis=0)
        else:
            group_axes = axes[row, :].reshape(1, 2)

        group_axes[0, 0].set_title(
            f"{group_name.capitalize()} Parameters", fontsize=14
        )

        for i, var in enumerate(var_names):
            # shape is (nchain, nsamples, ndim) if ndim>1 else (nchain, nsamples)
            data = idata.posterior[var].values
            if data.ndim == 3:
                # NOTE(review): only the first chain is shown for
                # multi-dimensional variables.
                data = data[0].T  # shape is now (ndim, nsamples)

            ax_trace = group_axes[i, 0] if compact else group_axes[0, 0]
            ax_hist = group_axes[i, 1] if compact else group_axes[0, 1]
            ax_trace.set_ylabel(group_name, fontsize=8)
            ax_trace.set_xlabel("MCMC Step", fontsize=8)
            ax_hist.set_xlabel(group_name, fontsize=8)
            # place ylabel on right side of hist
            ax_hist.yaxis.set_label_position("right")
            ax_hist.set_ylabel("Density", fontsize=8, rotation=270, labelpad=0)

            # Strip the histogram down to just the curve: no spines/ticks.
            ax_hist.spines["left"].set_visible(False)
            ax_hist.spines["right"].set_visible(False)
            ax_hist.spines["top"].set_visible(False)
            ax_hist.set_yticks([])
            ax_hist.yaxis.set_ticks_position("none")

            ax_trace.spines["right"].set_visible(False)
            ax_trace.spines["top"].set_visible(False)

            color = f"C{i}"
            label = f"{var}"
            # phi/delta are positive scale parameters: log axes read better.
            if group_name in ["phi", "delta"]:
                ax_trace.set_yscale("log")
                ax_hist.set_xscale("log")

            for p in data:
                ax_trace.plot(p, color=color, alpha=0.7, label=label)

                if group_name in ["phi", "delta"]:
                    # KDE in log space, then change of variables back to p:
                    # f_p(p) = f_logp(log p) / p.
                    bins = np.logspace(
                        np.log10(np.min(p)), np.log10(np.max(p)), 30
                    )
                    logp = np.log(p)
                    log_grid, log_pdf = az.kde(logp)
                    grid = np.exp(log_grid)
                    pdf = log_pdf / grid  # change of variables
                else:
                    bins = 30
                    grid, pdf = az.kde(p)
                ax_hist.plot(grid, pdf, color=color, label=label)
                ax_hist.hist(
                    p, bins=bins, density=True, color=color, alpha=0.3
                )

    plt.suptitle("Parameter Traces", fontsize=16)
    plt.tight_layout()
    return fig
116

117

118
def plot_diagnostics(
    idata: az.InferenceData,
    outdir: str,
    n_channels: Optional[int] = None,
    n_freq: Optional[int] = None,
    runtime: Optional[float] = None,
    config: Optional[DiagnosticsConfig] = None,
) -> None:
    """
    Create essential MCMC diagnostics in organized subdirectories.
    """
    # No output directory means diagnostics are disabled.
    if outdir is None:
        return

    cfg = DiagnosticsConfig() if config is None else config

    # Everything lands under <outdir>/diagnostics.
    diag_dir = os.path.join(outdir, "diagnostics")
    os.makedirs(diag_dir, exist_ok=True)

    logger.info("Generating MCMC diagnostics...")

    # Text summary first, then the full plot suite.
    generate_diagnostics_summary(idata, diag_dir)
    _create_diagnostic_plots(
        idata, diag_dir, cfg, n_channels, n_freq, runtime
    )
146

147

148
def _create_diagnostic_plots(
    idata, diag_dir, config, n_channels, n_freq, runtime
):
    """Create only the essential diagnostic plots."""
    logger.debug("Generating diagnostic plots...")

    dpi = config.dpi

    # Each plotting closure is wrapped by safe_plot so a failure in one
    # figure cannot abort the rest of the suite.  Names are kept stable
    # in case safe_plot reports the wrapped function's name.
    def create_trace_plots():
        return plot_trace(idata)

    def plot_summary():
        _plot_summary_dashboard(idata, config, n_channels, n_freq, runtime)

    def plot_lp():
        _plot_log_posterior(idata, config)

    def plot_acceptance():
        _plot_acceptance_diagnostics_blockaware(idata, config)

    # 1. ArviZ trace plots
    safe_plot(f"{diag_dir}/trace_plots.png", dpi)(create_trace_plots)()
    # 2. Summary dashboard with key convergence metrics
    safe_plot(f"{diag_dir}/summary_dashboard.png", dpi)(plot_summary)()
    # 3. Log posterior diagnostics
    safe_plot(f"{diag_dir}/log_posterior.png", dpi)(plot_lp)()
    # 4. Acceptance rate diagnostics
    safe_plot(f"{diag_dir}/acceptance_diagnostics.png", dpi)(plot_acceptance)()

    # 5. Sampler-specific diagnostics
    _create_sampler_diagnostics(idata, diag_dir, config)

    # 6. Divergences diagnostics (for NUTS only)
    _create_divergences_diagnostics(idata, diag_dir, config)
187

188

189
def _plot_summary_dashboard(idata, config, n_channels, n_freq, runtime):
    """Four-panel overview: ESS histogram, metadata, parameter counts,
    and convergence status."""
    width, height = config.figsize
    fig, axes = plt.subplots(2, 2, figsize=(width * 0.8, height))

    # Pull the ESS array once; the panels below tolerate ``None``.
    ess_values = None
    try:
        ess = idata.attrs.get("ess")
        ess_values = ess[~np.isnan(ess)]
    except Exception:
        pass

    # 1. ESS Distribution
    _plot_ess_histogram(axes[0, 0], ess_values, config)
    # 2. Analysis Metadata
    _plot_metadata(axes[0, 1], idata, n_channels, n_freq, runtime)
    # 3. Parameter Summary
    _plot_parameter_summary(axes[1, 0], idata)
    # 4. Convergence Status
    _plot_convergence_status(axes[1, 1], ess_values, config)

    plt.tight_layout()
221

222

223
def _plot_ess_histogram(ax, ess_values, config):
2✔
224
    """Plot ESS distribution with quality thresholds."""
225
    if ess_values is None or len(ess_values) == 0:
2✔
226
        ax.text(0.5, 0.5, "ESS unavailable", ha="center", va="center")
×
227
        ax.set_title("ESS Distribution")
×
228
        return
×
229

230
    # Histogram
231
    ax.hist(ess_values, bins=30, alpha=0.7, edgecolor="black")
2✔
232

233
    # Reference lines
234
    thresholds = [
2✔
235
        (400, "red", "--", "Minimum reliable"),
236
        (1000, "orange", "--", "Good"),
237
        (np.max(ess_values), "green", ":", f"Max = {np.max(ess_values):.0f}"),
238
    ]
239

240
    for threshold, color, style, label in thresholds:
2✔
241
        ax.axvline(
2✔
242
            x=threshold,
243
            color=color,
244
            linestyle=style,
245
            linewidth=2 if threshold < np.max(ess_values) else 1,
246
            alpha=0.8,
247
            label=label,
248
        )
249

250
    ax.set_xlabel("ESS")
2✔
251
    ax.set_ylabel("Count")
2✔
252
    ax.set_title("ESS Distribution")
2✔
253
    ax.legend(loc="upper right", fontsize="x-small")
2✔
254
    ax.grid(True, alpha=0.3)
2✔
255

256
    # Summary stats
257
    pct_good = (ess_values >= config.ess_threshold).mean() * 100
2✔
258
    stats_text = f"Min: {ess_values.min():.0f}\nMean: {ess_values.mean():.0f}\n≥{config.ess_threshold}: {pct_good:.1f}%"
2✔
259
    ax.text(
2✔
260
        0.02,
261
        0.98,
262
        stats_text,
263
        transform=ax.transAxes,
264
        fontsize=10,
265
        verticalalignment="top",
266
        bbox=dict(boxstyle="round", facecolor="lightblue", alpha=0.7),
267
    )
268

269

270
def _plot_metadata(ax, idata, n_channels, n_freq, runtime):
2✔
271
    """Display analysis metadata."""
272
    try:
2✔
273
        n_samples = idata.posterior.sizes.get("draw", 0)
2✔
274
        n_chains = idata.posterior.sizes.get("chain", 1)
2✔
275
        n_params = len(list(idata.posterior.data_vars))
2✔
276
        sampler_type = idata.attrs["sampler_type"]
2✔
277

278
        metadata_lines = [
2✔
279
            f"Sampler: {sampler_type}",
280
            f"Samples: {n_samples} × {n_chains} chains",
281
            f"Parameters: {n_params}",
282
        ]
283
        if n_channels is not None:
2✔
284
            metadata_lines.append(f"Channels: {n_channels}")
×
285
        if n_freq is not None:
2✔
286
            metadata_lines.append(f"Frequencies: {n_freq}")
×
287
        if runtime is not None:
2✔
288
            metadata_lines.append(f"Runtime: {runtime:.2f}s")
×
289

290
        ax.text(
2✔
291
            0.05,
292
            0.95,
293
            "\n".join(metadata_lines),
294
            transform=ax.transAxes,
295
            fontsize=12,
296
            verticalalignment="top",
297
            fontfamily="monospace",
298
        )
299
    except Exception:
×
300
        ax.text(0.5, 0.5, "Metadata unavailable", ha="center", va="center")
×
301

302
    ax.set_title("Analysis Summary")
2✔
303
    ax.axis("off")
2✔
304

305

306
def _plot_parameter_summary(ax, idata):
2✔
307
    """Display parameter count summary."""
308
    try:
2✔
309
        param_groups = _group_parameters_simple(idata)
2✔
310
        if param_groups:
2✔
311
            summary_text = "Parameter Summary:\n"
2✔
312
            for group_name, params in param_groups.items():
2✔
313
                if params:
2✔
314
                    summary_text += f"{group_name}: {len(params)}\n"
2✔
315
            ax.text(
2✔
316
                0.05,
317
                0.95,
318
                summary_text.strip(),
319
                transform=ax.transAxes,
320
                fontsize=11,
321
                verticalalignment="top",
322
                fontfamily="monospace",
323
            )
324
    except Exception:
×
325
        ax.text(
×
326
            0.5,
327
            0.5,
328
            "Parameter summary\nunavailable",
329
            ha="center",
330
            va="center",
331
        )
332

333
    ax.set_title("Parameter Summary")
2✔
334
    ax.axis("off")
2✔
335

336

337
def _plot_convergence_status(ax, ess_values, config):
2✔
338
    """Display convergence status based on ESS only."""
339
    try:
2✔
340
        status_lines = ["Convergence Status:"]
2✔
341

342
        if ess_values is not None and len(ess_values) > 0:
2✔
343
            ess_good = (ess_values >= config.ess_threshold).mean() * 100
2✔
344
            status_lines.append(
2✔
345
                f"ESS ≥ {config.ess_threshold}: {ess_good:.0f}%"
346
            )
347
            status_lines.append("")
2✔
348
            status_lines.append("Overall Status:")
2✔
349

350
            if ess_good >= 90:
2✔
351
                status_lines.append("✓ EXCELLENT")
×
352
                color = "green"
×
353
            elif ess_good >= 75:
2✔
354
                status_lines.append("✓ ADEQUATE")
×
355
                color = "orange"
×
356
            else:
357
                status_lines.append("⚠ NEEDS ATTENTION")
2✔
358
                color = "red"
2✔
359
        else:
360
            status_lines.append("? UNABLE TO ASSESS")
×
361
            color = "gray"
×
362

363
        ax.text(
2✔
364
            0.05,
365
            0.95,
366
            "\n".join(status_lines),
367
            transform=ax.transAxes,
368
            fontsize=11,
369
            verticalalignment="top",
370
            fontfamily="monospace",
371
            color=color,
372
        )
373
    except Exception:
×
374
        ax.text(0.5, 0.5, "Status unavailable", ha="center", va="center")
×
375

376
    ax.set_title("Convergence Status")
2✔
377
    ax.axis("off")
2✔
378

379

380
def _plot_log_posterior(idata, config):
    """Log posterior (or log likelihood) diagnostics.

    Panels: trace with running mean, histogram, step-to-step changes,
    and summary statistics.  Falls back to a single placeholder panel
    when neither ``lp`` nor ``log_likelihood`` is in ``sample_stats``.
    """
    # Resolve the data source BEFORE creating the 2x2 grid.  The original
    # created the grid first and then a second fallback figure in the
    # unavailable case, leaking the unused 2x2 figure.
    if "lp" in idata.sample_stats:
        lp_values = idata.sample_stats["lp"].values.flatten()
        title_prefix = "Log Posterior"
    elif "log_likelihood" in idata.sample_stats:
        lp_values = idata.sample_stats["log_likelihood"].values.flatten()
        title_prefix = "Log Likelihood"
    else:
        fig, ax = plt.subplots(1, 1, figsize=config.figsize)
        ax.text(
            0.5,
            0.5,
            "No log posterior\nor log likelihood\navailable",
            ha="center",
            va="center",
            fontsize=14,
        )
        ax.set_title("Log Posterior Diagnostics")
        ax.axis("off")
        plt.tight_layout()
        return

    fig, axes = plt.subplots(2, 2, figsize=config.figsize)

    # Trace plot with running mean overlaid
    axes[0, 0].plot(
        lp_values, alpha=0.7, linewidth=1, color="blue", label="Trace"
    )

    # Boxcar window scales with chain length (minimum width 10).
    window_size = max(10, len(lp_values) // 100)
    if len(lp_values) > window_size:
        running_mean = np.convolve(
            lp_values, np.ones(window_size) / window_size, mode="valid"
        )
        axes[0, 0].plot(
            range(window_size // 2, window_size // 2 + len(running_mean)),
            running_mean,
            alpha=0.9,
            linewidth=3,
            color="red",
            label=f"Running mean (w={window_size})",
        )

    axes[0, 0].set_xlabel("Iteration")
    axes[0, 0].set_ylabel(title_prefix)
    axes[0, 0].set_title(f"{title_prefix} Trace with Running Mean")
    axes[0, 0].legend(loc="best", fontsize="small")
    axes[0, 0].grid(True, alpha=0.3)

    # Distribution
    axes[0, 1].hist(
        lp_values, bins=50, alpha=0.7, density=True, edgecolor="black"
    )
    axes[0, 1].axvline(
        np.mean(lp_values),
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Mean: {np.mean(lp_values):.1f}",
    )
    axes[0, 1].set_xlabel(title_prefix)
    axes[0, 1].set_ylabel("Density")
    axes[0, 1].set_title(f"{title_prefix} Distribution")
    axes[0, 1].legend(loc="best", fontsize="small")
    axes[0, 1].grid(True, alpha=0.3)

    # Step-to-step changes
    lp_diff = np.diff(lp_values)
    axes[1, 0].plot(lp_diff, alpha=0.5, linewidth=1)
    axes[1, 0].axhline(0, color="red", linestyle="--", alpha=0.7)
    axes[1, 0].axhline(
        np.mean(lp_diff),
        color="blue",
        linestyle="--",
        alpha=0.7,
        label=f"Mean change: {np.mean(lp_diff):.1f}",
    )
    axes[1, 0].set_xlabel("Iteration")
    axes[1, 0].set_ylabel(f"{title_prefix} Difference")
    axes[1, 0].set_title("Step-to-Step Changes")
    axes[1, 0].legend(loc="best", fontsize="small")
    axes[1, 0].grid(True, alpha=0.3)

    # Summary statistics; "Final variation" is the std of the last quarter.
    stats_lines = [
        f"Mean: {np.mean(lp_values):.2f}",
        f"Std: {np.std(lp_values):.2f}",
        f"Min: {np.min(lp_values):.2f}",
        f"Max: {np.max(lp_values):.2f}",
        f"Range: {np.max(lp_values) - np.min(lp_values):.2f}",
        "",
        "Stability:",
        f"Final variation: {np.std(lp_values[-len(lp_values)//4:]):.2f}",
    ]

    axes[1, 1].text(
        0.05,
        0.95,
        "\n".join(stats_lines),
        transform=axes[1, 1].transAxes,
        fontsize=10,
        verticalalignment="top",
        fontfamily="monospace",
        bbox=dict(boxstyle="round", facecolor="lightblue", alpha=0.8),
    )
    axes[1, 1].set_title("Posterior Statistics")
    axes[1, 1].axis("off")

    plt.tight_layout()
495

496

497
def _plot_acceptance_diagnostics(idata, config):
    """Acceptance rate diagnostics: trace with quality bands, histogram,
    rolling variability, and a text summary.

    Renders a placeholder figure when no acceptance series is available.
    """
    accept_key = None
    if "accept_prob" in idata.sample_stats:
        accept_key = "accept_prob"
    elif "acceptance_rate" in idata.sample_stats:
        accept_key = "acceptance_rate"

    if accept_key is None:
        fig, ax = plt.subplots(figsize=config.figsize)
        ax.text(
            0.5,
            0.5,
            "Acceptance rate data unavailable",
            ha="center",
            va="center",
        )
        ax.set_title("Acceptance Rate Diagnostics")
        return

    fig, axes = plt.subplots(2, 2, figsize=config.figsize)

    accept_rates = idata.sample_stats[accept_key].values.flatten()

    # BUGFIX: idata.attrs is a mapping, so the original
    # getattr(idata.attrs, "target_accept_rate", 0.44) never found the key
    # and always fell back to 0.44 — forcing the MH bands even for NUTS.
    # Use .get and detect the sampler family, consistent with
    # _plot_acceptance_diagnostics_blockaware.
    sampler_type_attr = idata.attrs.get("sampler_type", "").lower()
    is_nuts = "nuts" in sampler_type_attr
    target_rate = idata.attrs.get(
        "target_accept_rate", 0.8 if is_nuts else 0.44
    )
    sampler_type = "NUTS" if is_nuts else "MH"

    # Define good ranges based on sampler
    if is_nuts:
        good_range = (0.7, 0.9)
        low_range = (0.0, 0.6)
        high_range = (0.9, 1.0)
        concerning_range = (0.6, 0.7)
    else:  # MH
        good_range = (0.2, 0.5)
        low_range = (0.0, 0.2)
        high_range = (0.5, 1.0)
        concerning_range = (0.1, 0.2)  # MH can be lower than NUTS

    # Trace plot with color zones
    axes[0, 0].axhspan(
        good_range[0],
        good_range[1],
        alpha=0.1,
        color="green",
        label=f"Good ({good_range[0]:.1f}-{good_range[1]:.1f})",
    )
    axes[0, 0].axhspan(
        low_range[0], low_range[1], alpha=0.1, color="red", label="Too low"
    )
    axes[0, 0].axhspan(
        high_range[0],
        high_range[1],
        alpha=0.1,
        color="orange",
        label="Too high",
    )
    if concerning_range[1] > concerning_range[0]:
        axes[0, 0].axhspan(
            concerning_range[0],
            concerning_range[1],
            alpha=0.1,
            color="yellow",
            label="Concerning",
        )

    # Main trace plot
    axes[0, 0].plot(
        accept_rates, alpha=0.8, linewidth=1, color="blue", label="Trace"
    )
    axes[0, 0].axhline(
        target_rate,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Target ({target_rate})",
    )

    # Running average overlaid on the same panel
    window_size = max(10, len(accept_rates) // 50)
    if len(accept_rates) > window_size:
        running_mean = np.convolve(
            accept_rates, np.ones(window_size) / window_size, mode="valid"
        )
        axes[0, 0].plot(
            range(window_size // 2, window_size // 2 + len(running_mean)),
            running_mean,
            alpha=0.9,
            linewidth=3,
            color="purple",
            label=f"Running mean (w={window_size})",
        )

    axes[0, 0].set_xlabel("Iteration")
    axes[0, 0].set_ylabel("Acceptance Rate")
    axes[0, 0].set_title(f"{sampler_type} Acceptance Rate Trace")
    axes[0, 0].legend(loc="best", fontsize="small")
    axes[0, 0].grid(True, alpha=0.3)

    # Interpretation helper text
    interpretation = f"{sampler_type} aims for {target_rate:.2f}."
    if is_nuts:
        interpretation += " Green: efficient sampling."
    else:
        interpretation += " MH adapts to find optimal rate."
    axes[0, 0].text(
        0.02,
        0.02,
        interpretation,
        transform=axes[0, 0].transAxes,
        fontsize=9,
        bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.7),
    )

    # Distribution
    axes[0, 1].hist(
        accept_rates, bins=30, alpha=0.7, density=True, edgecolor="black"
    )
    axes[0, 1].axvline(
        target_rate,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Target ({target_rate})",
    )
    axes[0, 1].set_xlabel("Acceptance Rate")
    axes[0, 1].set_ylabel("Density")
    axes[0, 1].set_title("Acceptance Rate Distribution")
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

    # Rolling variability (21-sample trailing window)
    if len(accept_rates) > 10:
        window_std = np.array(
            [
                np.std(accept_rates[max(0, i - 20) : i + 1])
                for i in range(len(accept_rates))
            ]
        )
        axes[1, 0].plot(window_std, alpha=0.7, color="green")
        axes[1, 0].set_xlabel("Iteration")
        axes[1, 0].set_ylabel("Rolling Std")
        axes[1, 0].set_title("Rolling Standard Deviation")
        axes[1, 0].grid(True, alpha=0.3)
    else:
        axes[1, 0].text(
            0.5,
            0.5,
            "Acceptance variability\nanalysis unavailable",
            ha="center",
            va="center",
        )
        axes[1, 0].set_title("Acceptance Stability")

    # Summary statistics (expanded)
    stats_text = [
        f"Sampler: {sampler_type}",
        f"Target: {target_rate:.3f}",
        f"Mean: {np.mean(accept_rates):.3f}",
        f"Std: {np.std(accept_rates):.3f}",
        f"CV: {np.std(accept_rates)/np.mean(accept_rates):.3f}",
        f"Min: {np.min(accept_rates):.3f}",
        f"Max: {np.max(accept_rates):.3f}",
        "",
        "Stability:",
        f"Final std: {np.std(accept_rates[-len(accept_rates)//4:]):.3f}",
    ]

    axes[1, 1].text(
        0.05,
        0.95,
        "\n".join(stats_text),
        transform=axes[1, 1].transAxes,
        fontsize=9,
        verticalalignment="top",
        fontfamily="monospace",
    )
    axes[1, 1].set_title("Acceptance Analysis")
    axes[1, 1].axis("off")

    plt.tight_layout()
686

687

688
def _plot_acceptance_diagnostics_blockaware(idata, config):
    """Acceptance diagnostics that also handle per‑channel series from blocked NUTS.

    If keys like ``accept_prob_channel_0`` are found in ``idata.sample_stats``,
    they are overlaid on the overall trace and included in the summary.

    Parameters
    ----------
    idata : InferenceData with ``sample_stats`` and ``attrs``.
    config : DiagnosticsConfig; only ``figsize`` is read here.
    """
    # Detect overall series (two naming conventions are supported)
    accept_key = None
    if "accept_prob" in idata.sample_stats:
        accept_key = "accept_prob"
    elif "acceptance_rate" in idata.sample_stats:
        accept_key = "acceptance_rate"

    # Collect per-channel series, keyed by the integer channel index parsed
    # from the key suffix; malformed keys are skipped silently.
    channel_series = {}
    for key in idata.sample_stats:
        if isinstance(key, str) and key.startswith("accept_prob_channel_"):
            try:
                ch = int(key.rsplit("_", 1)[-1])
                channel_series[ch] = idata.sample_stats[key].values.flatten()
            except Exception:
                pass

    # No data at all: placeholder panel instead of an empty 2x2 grid.
    if accept_key is None and not channel_series:
        fig, ax = plt.subplots(figsize=config.figsize)
        ax.text(
            0.5,
            0.5,
            "Acceptance rate data unavailable",
            ha="center",
            va="center",
        )
        ax.set_title("Acceptance Rate Diagnostics")
        return

    fig, axes = plt.subplots(2, 2, figsize=config.figsize)

    # Overall series if present, otherwise the concatenation of all
    # per-channel series (used for the histogram/summary statistics).
    if accept_key is not None:
        accept_rates = idata.sample_stats[accept_key].values.flatten()
    else:
        accept_rates = np.concatenate(list(channel_series.values()))

    # Sampler family decides the target default (0.8 NUTS, 0.44 MH) and
    # the quality bands below.
    sampler_type_attr = idata.attrs.get("sampler_type", "").lower()
    is_nuts = "nuts" in sampler_type_attr
    target_rate = idata.attrs.get(
        "target_accept_rate",
        idata.attrs.get("target_accept_prob", 0.8 if is_nuts else 0.44),
    )
    sampler_type = "NUTS" if is_nuts else "MH"

    # Ranges
    if is_nuts:
        good_range = (0.7, 0.9)
        low_range = (0.0, 0.6)
        high_range = (0.9, 1.0)
        concerning_range = (0.6, 0.7)
    else:
        good_range = (0.2, 0.5)
        low_range = (0.0, 0.2)
        high_range = (0.5, 1.0)
        concerning_range = (0.1, 0.2)

    # Background zones
    axes[0, 0].axhspan(good_range[0], good_range[1], alpha=0.1, color="green")
    axes[0, 0].axhspan(low_range[0], low_range[1], alpha=0.1, color="red")
    axes[0, 0].axhspan(high_range[0], high_range[1], alpha=0.1, color="orange")
    if concerning_range[1] > concerning_range[0]:
        axes[0, 0].axhspan(
            concerning_range[0], concerning_range[1], alpha=0.1, color="yellow"
        )

    # Plot traces: overall + per-channel overlays
    if accept_key is not None:
        axes[0, 0].plot(
            accept_rates, alpha=0.8, linewidth=1, color="blue", label="overall"
        )
    for ch in sorted(channel_series):
        axes[0, 0].plot(
            channel_series[ch], alpha=0.6, linewidth=1, label=f"ch {ch}"
        )
    axes[0, 0].axhline(
        target_rate,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Target ({target_rate})",
    )

    # Running mean for overall (boxcar; window scales with series length)
    window_size = max(10, len(accept_rates) // 50)
    if len(accept_rates) > window_size:
        running_mean = np.convolve(
            accept_rates, np.ones(window_size) / window_size, mode="valid"
        )
        axes[0, 0].plot(
            range(window_size // 2, window_size // 2 + len(running_mean)),
            running_mean,
            alpha=0.9,
            linewidth=3,
            color="purple",
            label=f"Running mean (w={window_size})",
        )

    axes[0, 0].set_xlabel("Iteration")
    axes[0, 0].set_ylabel("Acceptance Rate")
    axes[0, 0].set_title(f"{sampler_type} Acceptance Rate Trace")
    axes[0, 0].legend(loc="best", fontsize="small")
    axes[0, 0].grid(True, alpha=0.3)

    # Histogram and summary
    axes[0, 1].hist(
        accept_rates, bins=30, alpha=0.7, density=True, edgecolor="black"
    )
    axes[0, 1].axvline(
        target_rate,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Target ({target_rate})",
    )
    axes[0, 1].set_xlabel("Acceptance Rate")
    axes[0, 1].set_ylabel("Density")
    axes[0, 1].set_title("Acceptance Rate Distribution")
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

    # Rolling variability over a 21-sample trailing window
    if len(accept_rates) > 10:
        window_std = np.array(
            [
                np.std(accept_rates[max(0, i - 20) : i + 1])
                for i in range(len(accept_rates))
            ]
        )
        axes[1, 0].plot(window_std, alpha=0.7, color="green")
        axes[1, 0].set_xlabel("Iteration")
        axes[1, 0].set_ylabel("Rolling Std")
        axes[1, 0].set_title("Rolling Standard Deviation")
        axes[1, 0].grid(True, alpha=0.3)
    else:
        axes[1, 0].text(
            0.5,
            0.5,
            "Acceptance variability\nanalysis unavailable",
            ha="center",
            va="center",
        )
        axes[1, 0].set_title("Acceptance Stability")

    # Summary text (per-channel means appended when available)
    stats_text = [
        f"Sampler: {sampler_type}",
        f"Target: {target_rate:.3f}",
        f"Mean: {np.mean(accept_rates):.3f}",
        f"Std: {np.std(accept_rates):.3f}",
        f"CV: {np.std(accept_rates)/np.mean(accept_rates):.3f}",
        f"Min: {np.min(accept_rates):.3f}",
        f"Max: {np.max(accept_rates):.3f}",
    ]
    if channel_series:
        stats_text.append("")
        stats_text.append("Per-channel means:")
        for ch in sorted(channel_series):
            stats_text.append(f"  ch {ch}: {np.mean(channel_series[ch]):.3f}")

    axes[1, 1].text(
        0.05,
        0.95,
        "\n".join(stats_text),
        transform=axes[1, 1].transAxes,
        fontsize=9,
        va="top",
        family="monospace",
    )
    axes[1, 1].set_title("Acceptance Analysis")
    axes[1, 1].axis("off")

    plt.tight_layout()
866

867

868
def _plot_nuts_diagnostics_blockaware(idata, config):
    """NUTS diagnostics supporting per‑channel (blocked) diagnostics fields.

    Overlays per‑channel series when keys like ``energy_channel_{j}`` or
    ``num_steps_channel_{j}`` are present.

    Parameters
    ----------
    idata : object with ``sample_stats`` (e.g. arviz.InferenceData)
        Sampler output; read-only access to diagnostic arrays.
    config : DiagnosticsConfig
        Supplies ``figsize`` for the 2x2 panel layout.
    """
    # Presence of overall arrays
    has_energy = "energy" in idata.sample_stats
    has_potential = "potential_energy" in idata.sample_stats
    has_steps = "num_steps" in idata.sample_stats
    has_accept = "accept_prob" in idata.sample_stats

    # Collect per-channel data
    def _collect(base):
        """Gather ``{base}_channel_{j}`` arrays keyed by channel index j."""
        out = {}
        prefix = f"{base}_channel_"
        for key in idata.sample_stats:
            if isinstance(key, str) and key.startswith(prefix):
                try:
                    ch = int(key.replace(prefix, ""))
                    out[ch] = idata.sample_stats[key].values.flatten()
                except Exception:
                    # Non-integer suffix: not a channel key; skip silently.
                    pass
        return out

    energy_ch = _collect("energy")
    potential_ch = _collect("potential_energy")
    steps_ch = _collect("num_steps")
    accept_ch = _collect("accept_prob")

    fig, axes = plt.subplots(2, 2, figsize=config.figsize)

    # Energy / potential: overall traces plus any per-channel overlays.
    ax = axes[0, 0]
    plotted = False
    if has_energy:
        ax.plot(
            idata.sample_stats.energy.values.flatten(),
            alpha=0.7,
            lw=1,
            label="H",
        )
        plotted = True
    if has_potential:
        ax.plot(
            idata.sample_stats.potential_energy.values.flatten(),
            alpha=0.7,
            lw=1,
            label="P",
        )
        plotted = True
    for ch in sorted(energy_ch):
        ax.plot(energy_ch[ch], alpha=0.5, lw=1, label=f"H ch {ch}")
        plotted = True
    for ch in sorted(potential_ch):
        ax.plot(potential_ch[ch], alpha=0.5, lw=1, label=f"P ch {ch}")
        plotted = True
    if not plotted:
        ax.text(0.5, 0.5, "Energy data\nunavailable", ha="center", va="center")
    ax.set_title("Energy Diagnostics")
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Energy")
    ax.grid(True, alpha=0.3)
    if plotted:
        ax.legend(loc="best", fontsize="small")

    # Steps histogram: prefer the overall array, otherwise pool channels.
    ax = axes[0, 1]
    if has_steps:
        vals = idata.sample_stats.num_steps.values.flatten()
    else:
        vals = (
            np.concatenate(list(steps_ch.values()))
            if steps_ch
            else np.array([])
        )
    if vals.size:
        ax.hist(vals, bins=20, alpha=0.7, edgecolor="black")
        ax.set_title("Leapfrog Steps Distribution")
        ax.set_xlabel("Steps")
        ax.set_ylabel("Trajectories")
        ax.grid(True, alpha=0.3)
    else:
        ax.text(0.5, 0.5, "Steps data\nunavailable", ha="center", va="center")

    # Acceptance (overlay per-channel)
    ax = axes[1, 0]
    plotted = False
    if has_accept:
        ax.plot(
            idata.sample_stats.accept_prob.values.flatten(),
            alpha=0.8,
            lw=1,
            label="overall",
        )
        plotted = True
    for ch in sorted(accept_ch):
        ax.plot(accept_ch[ch], alpha=0.6, lw=1, label=f"ch {ch}")
        plotted = True
    if not plotted:
        ax.text(
            0.5, 0.5, "Acceptance data\nunavailable", ha="center", va="center"
        )
    ax.set_title("Acceptance Trace")
    ax.set_xlabel("Iteration")
    ax.set_ylabel("accept_prob")
    ax.grid(True, alpha=0.3)
    if plotted:
        ax.legend(loc="best", fontsize="small")

    # Summary text panel (monospace, top-left anchored).
    ax = axes[1, 1]
    lines = []
    if has_steps or steps_ch:
        lines.append("Steps summary:")
        if has_steps:
            s = idata.sample_stats.num_steps.values.flatten()
            lines.append(f"  overall μ={np.mean(s):.1f}, max={np.max(s):.0f}")
        for ch in sorted(steps_ch):
            s = steps_ch[ch]
            lines.append(f"  ch {ch} μ={np.mean(s):.1f}, max={np.max(s):.0f}")
        lines.append("")
    if has_accept or accept_ch:
        lines.append("Acceptance summary:")
        if has_accept:
            a = idata.sample_stats.accept_prob.values.flatten()
            lines.append(f"  overall μ={np.mean(a):.3f}")
        for ch in sorted(accept_ch):
            a = accept_ch[ch]
            lines.append(f"  ch {ch} μ={np.mean(a):.3f}")
    if lines:
        ax.text(
            0.05,
            0.95,
            "\n".join(lines),
            transform=ax.transAxes,
            va="top",
            family="monospace",
        )
    ax.set_title("NUTS Diagnostics Summary")
    ax.axis("off")

    plt.tight_layout()

1012

1013
def _create_sampler_diagnostics(idata, diag_dir, config):
    """Create sampler-specific diagnostics.

    Detects whether the run looks like NUTS (via NUTS-only sample_stats
    fields or the recorded ``sampler_type`` attr) or adaptive MH (via
    ``step_size_mean``) and renders the matching diagnostics figure.
    """
    # Lower-cased sampler label if recorded, else "unknown".
    if "sampler_type" in idata.attrs:
        sampler_type = idata.attrs["sampler_type"].lower()
    else:
        sampler_type = "unknown"

    # Fields that only a NUTS-style sampler records (MH never has these).
    nuts_specific_fields = [
        "energy",
        "num_steps",
        "tree_depth",
        "diverging",
        "energy_error",
    ]

    is_nuts = "nuts" in sampler_type or any(
        field in idata.sample_stats for field in nuts_specific_fields
    )

    # MH only when step-size adaptation stats exist AND nothing NUTS-like.
    is_mh = not is_nuts and "step_size_mean" in idata.sample_stats

    if is_nuts:

        @safe_plot(f"{diag_dir}/nuts_diagnostics.png", config.dpi)
        def _render_nuts():
            _plot_nuts_diagnostics_blockaware(idata, config)

        _render_nuts()
    elif is_mh:

        @safe_plot(f"{diag_dir}/mh_step_sizes.png", config.dpi)
        def _render_mh():
            _plot_mh_step_sizes(idata, config)

        _render_mh()

1056
def _plot_nuts_diagnostics(idata, config):
    """NUTS diagnostics with enhanced information.

    Builds a 2x2 figure: energy traces (top-left), leapfrog-step histogram
    with efficiency zones (top-right), acceptance-probability trace with
    guidance bands (bottom-left), and a text summary panel (bottom-right).
    Each panel degrades gracefully when its data is absent.
    """
    # Determine available data to decide layout
    has_energy = "energy" in idata.sample_stats
    has_potential = "potential_energy" in idata.sample_stats
    has_steps = "num_steps" in idata.sample_stats
    has_accept = "accept_prob" in idata.sample_stats
    has_divergences = "diverging" in idata.sample_stats
    has_tree_depth = "tree_depth" in idata.sample_stats
    has_energy_error = "energy_error" in idata.sample_stats

    # Create a 2x2 layout, potentially combining energy and potential on same plot
    fig, axes = plt.subplots(2, 2, figsize=config.figsize)

    # Top-left: Energy diagnostics (combine Hamiltonian and Potential if both available)
    energy_ax = axes[0, 0]

    if has_energy and has_potential:
        # Both available - plot them together on one plot
        energy = idata.sample_stats.energy.values.flatten()
        potential = idata.sample_stats.potential_energy.values.flatten()

        # Plot both energies on same axis
        energy_ax.plot(
            energy, alpha=0.7, linewidth=1, color="blue", label="Hamiltonian"
        )
        energy_ax.plot(
            potential,
            alpha=0.7,
            linewidth=1,
            color="orange",
            label="Potential",
        )

        # Add difference (which relates to kinetic energy)
        energy_diff = energy - potential
        # Create second y-axis for difference
        ax2 = energy_ax.twinx()
        ax2.plot(
            energy_diff,
            alpha=0.5,
            linewidth=1,
            color="red",
            label="H - Potential (Kinetic)",
            linestyle="--",
        )
        ax2.set_ylabel("Energy Difference", color="red")
        ax2.tick_params(axis="y", labelcolor="red")

        energy_ax.set_xlabel("Iteration")
        energy_ax.set_ylabel("Energy", color="blue")
        energy_ax.tick_params(axis="y", labelcolor="blue")
        energy_ax.set_title("Hamiltonian & Potential Energy")
        energy_ax.legend(loc="best", fontsize="small")
        energy_ax.grid(True, alpha=0.3)

        # Add statistics
        energy_ax.text(
            0.02,
            0.98,
            f"H: μ={np.mean(energy):.1f}, σ={np.std(energy):.1f}\nP: μ={np.mean(potential):.1f}, σ={np.std(potential):.1f}",
            transform=energy_ax.transAxes,
            fontsize=8,
            bbox=dict(boxstyle="round", facecolor="lightblue", alpha=0.8),
            verticalalignment="top",
        )

    elif has_energy:
        # Only Hamiltonian energy
        energy = idata.sample_stats.energy.values.flatten()
        energy_ax.plot(energy, alpha=0.7, linewidth=1, color="blue")
        energy_ax.set_xlabel("Iteration")
        energy_ax.set_ylabel("Hamiltonian Energy")
        energy_ax.set_title("Hamiltonian Energy Trace")
        energy_ax.grid(True, alpha=0.3)

    elif has_potential:
        # Only potential energy
        potential = idata.sample_stats.potential_energy.values.flatten()
        energy_ax.plot(potential, alpha=0.7, linewidth=1, color="orange")
        energy_ax.set_xlabel("Iteration")
        energy_ax.set_ylabel("Potential Energy")
        energy_ax.set_title("Potential Energy Trace")
        energy_ax.grid(True, alpha=0.3)

    else:
        # Neither energy array present: placeholder panel.
        energy_ax.text(
            0.5,
            0.5,
            "Energy data\nunavailable",
            ha="center",
            va="center",
            transform=energy_ax.transAxes,
        )
        energy_ax.set_title("Energy Diagnostics")

    # Top-right: Sampling efficiency diagnostics
    if has_steps:
        steps_ax = axes[0, 1]
        num_steps = idata.sample_stats.num_steps.values.flatten()

        # Show histogram with color zones for step efficiency
        n, bins, edges = steps_ax.hist(
            num_steps, bins=20, alpha=0.7, edgecolor="black"
        )

        # Add shaded regions for different efficiency levels
        # Green: efficient (tree depth ≤5, ~32 steps)
        # Yellow: moderate (tree depth 6-8, ~64-256 steps)
        # Red: inefficient (tree depth >8, >256 steps)
        steps_ax.axvspan(
            0, 64, alpha=0.1, color="green", label="Efficient (≤64)"
        )
        steps_ax.axvspan(
            64, 256, alpha=0.1, color="yellow", label="Moderate (65-256)"
        )
        steps_ax.axvspan(
            256,
            np.max(num_steps),
            alpha=0.1,
            color="red",
            label="Inefficient (>256)",
        )

        # Add reference lines for different tree depths
        for depth in [5, 7, 10]:  # Common tree depths
            max_steps = 2**depth
            steps_ax.axvline(
                x=max_steps,
                color="gray",
                linestyle=":",
                alpha=0.7,
                linewidth=1,
                label=f"2^{depth} ({max_steps})",
            )

        steps_ax.set_xlabel("Leapfrog Steps")
        steps_ax.set_ylabel("Trajectories")
        steps_ax.set_title("Leapfrog Steps Distribution")
        steps_ax.legend(loc="best", fontsize="small")
        steps_ax.grid(True, alpha=0.3)

        # Add efficiency statistics
        pct_inefficient = (num_steps > 256).mean() * 100
        pct_moderate = ((num_steps > 64) & (num_steps <= 256)).mean() * 100
        pct_efficient = (num_steps <= 64).mean() * 100
        steps_ax.text(
            0.02,
            0.98,
            f"Efficient: {pct_efficient:.1f}%\nModerate: {pct_moderate:.1f}%\nInefficient: {pct_inefficient:.1f}%\nMean steps: {np.mean(num_steps):.1f}",
            transform=steps_ax.transAxes,
            fontsize=7,
            bbox=dict(boxstyle="round", facecolor="lightblue", alpha=0.8),
            verticalalignment="top",
        )

    else:
        axes[0, 1].text(
            0.5, 0.5, "Steps data\nunavailable", ha="center", va="center"
        )
        axes[0, 1].set_title("Sampling Steps")

    # Bottom-left: Acceptance and NS divergence diagnostics
    accept_ax = axes[1, 0]

    if has_accept:
        accept_prob = idata.sample_stats.accept_prob.values.flatten()

        # Plot acceptance probability with guidance zones
        accept_ax.fill_between(
            range(len(accept_prob)),
            0.7,
            0.9,
            alpha=0.1,
            color="green",
            label="Good (0.7-0.9)",
        )
        accept_ax.fill_between(
            range(len(accept_prob)),
            0,
            0.6,
            alpha=0.1,
            color="red",
            label="Too low",
        )
        accept_ax.fill_between(
            range(len(accept_prob)),
            0.9,
            1.0,
            alpha=0.1,
            color="orange",
            label="Too high",
        )

        accept_ax.plot(
            accept_prob,
            alpha=0.8,
            linewidth=1,
            color="blue",
            label="Acceptance prob",
        )
        accept_ax.axhline(
            0.8,
            color="red",
            linestyle="--",
            linewidth=2,
            label="NUTS target (0.8)",
        )
        accept_ax.set_xlabel("Iteration")
        accept_ax.set_ylabel("Acceptance Probability")
        accept_ax.set_title("NUTS Acceptance Diagnostic")
        accept_ax.legend(loc="best", fontsize="small")
        accept_ax.set_ylim(0, 1)
        accept_ax.grid(True, alpha=0.3)

    else:
        accept_ax.text(
            0.5, 0.5, "Acceptance data\nunavailable", ha="center", va="center"
        )
        accept_ax.set_title("Acceptance Diagnostic")

    # Bottom-right: Summary statistics and additional diagnostics
    summary_ax = axes[1, 1]

    # Collect available statistics
    stats_lines = []

    if has_energy:
        energy = idata.sample_stats.energy.values.flatten()
        stats_lines.append(
            f"Energy: μ={np.mean(energy):.1f}, σ={np.std(energy):.1f}"
        )

    if has_potential:
        potential = idata.sample_stats.potential_energy.values.flatten()
        stats_lines.append(
            f"Potential: μ={np.mean(potential):.1f}, σ={np.std(potential):.1f}"
        )

    if has_steps:
        num_steps = idata.sample_stats.num_steps.values.flatten()
        stats_lines.append(
            f"Steps: μ={np.mean(num_steps):.1f}, max={np.max(num_steps):.0f}"
        )
        stats_lines.append("")

    if has_tree_depth:
        tree_depth = idata.sample_stats.tree_depth.values.flatten()
        stats_lines.append(f"Tree depth: μ={np.mean(tree_depth):.1f}")
        pct_max_depth = (tree_depth >= 10).mean() * 100
        stats_lines.append(f"Max depth (≥10): {pct_max_depth:.1f}%")

    if has_divergences:
        divergences = idata.sample_stats.diverging.values.flatten()
        n_divergences = np.sum(divergences)
        pct_divergent = n_divergences / len(divergences) * 100
        stats_lines.append(
            f"Divergent: {n_divergences}/{len(divergences)} ({pct_divergent:.2f}%)"
        )

    if has_energy_error:
        energy_error = idata.sample_stats.energy_error.values.flatten()
        stats_lines.append(
            f"Energy error: |μ|={np.mean(np.abs(energy_error)):.3f}"
        )

    if not stats_lines:
        summary_ax.text(
            0.5,
            0.5,
            "No diagnostics\ndata available",
            ha="center",
            va="center",
            transform=summary_ax.transAxes,
        )
        summary_ax.set_title("NUTS Statistics")
        summary_ax.axis("off")
    else:
        summary_text = "\n".join(["NUTS Diagnostics:"] + [""] + stats_lines)
        summary_ax.text(
            0.05,
            0.95,
            summary_text,
            transform=summary_ax.transAxes,
            fontsize=10,
            verticalalignment="top",
            fontfamily="monospace",
            bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.9),
        )
        summary_ax.set_title("NUTS Summary Statistics")
        summary_ax.axis("off")

    plt.tight_layout()

1351
def _plot_mh_step_sizes(idata, config):
    """MH step size diagnostics.

    Renders a 2x2 figure from the adaptive-MH ``step_size_mean`` /
    ``step_size_std`` traces: evolution, histograms, mean/std ratio,
    and a monospace statistics panel.
    """
    fig, axes = plt.subplots(2, 2, figsize=config.figsize)

    means = idata.sample_stats.step_size_mean.values.flatten()
    stds = idata.sample_stats.step_size_std.values.flatten()
    # (data, label, color) triples shared by the trace and histogram panels.
    series = [(means, "Mean", "blue"), (stds, "Std", "orange")]

    # Evolution of the adapted step size over iterations.
    evo_ax = axes[0, 0]
    for data, label, color in series:
        evo_ax.plot(data, alpha=0.7, linewidth=1, label=label, color=color)
    evo_ax.set_xlabel("Iteration")
    evo_ax.set_ylabel("Step Size")
    evo_ax.set_title("Step Size Evolution")
    evo_ax.legend()
    evo_ax.grid(True, alpha=0.3)

    # Marginal distributions of the two summaries.
    hist_ax = axes[0, 1]
    for data, label, color in series:
        hist_ax.hist(data, bins=30, alpha=0.5, label=label, color=color)
    hist_ax.set_xlabel("Step Size")
    hist_ax.set_ylabel("Count")
    hist_ax.set_title("Step Size Distributions")
    hist_ax.legend()
    hist_ax.grid(True, alpha=0.3)

    # Mean/std ratio as a rough adaptation-consistency indicator.
    ratio_ax = axes[1, 0]
    ratio_ax.plot(means / stds, alpha=0.7, linewidth=1)
    ratio_ax.set_xlabel("Iteration")
    ratio_ax.set_ylabel("Mean / Std")
    ratio_ax.set_title("Step Size Consistency")
    ratio_ax.grid(True, alpha=0.3)

    # Text panel with summary statistics.
    summary_lines = [
        "Step Size Summary:",
        f"Final mean: {means[-1]:.4f}",
        f"Final std: {stds[-1]:.4f}",
        f"Mean of means: {np.mean(means):.4f}",
        f"Mean of stds: {np.mean(stds):.4f}",
        "",
        "Adaptation Quality:",
        f"CV of means: {np.std(means)/np.mean(means):.3f}",
        f"CV of stds: {np.std(stds)/np.mean(stds):.3f}",
    ]
    text_ax = axes[1, 1]
    text_ax.text(
        0.05,
        0.95,
        "\n".join(summary_lines),
        transform=text_ax.transAxes,
        fontsize=10,
        verticalalignment="top",
        fontfamily="monospace",
    )
    text_ax.set_title("Step Size Statistics")
    text_ax.axis("off")

    plt.tight_layout()

1415
def _create_divergences_diagnostics(idata, diag_dir, config):
    """Create divergences diagnostics for NUTS samplers.

    Skips rendering entirely when neither the overall ``diverging`` field
    nor any per-channel ``diverging_channel_{j}`` field is present.
    """
    any_channel_divergences = any(
        key.startswith("diverging_channel_") for key in idata.sample_stats
    )
    if "diverging" not in idata.sample_stats and not any_channel_divergences:
        return  # Nothing to plot

    @safe_plot(f"{diag_dir}/divergences.png", config.dpi)
    def _render_divergences():
        _plot_divergences(idata, config)

    _render_divergences()
1433
def _plot_divergences(idata, config):
    """Plot divergences diagnostics.

    Collects the overall ``diverging`` indicator and any per-channel
    ``diverging_channel_{j}`` indicators from ``idata.sample_stats``,
    draws one trace panel per series, and (layout permitting) a text
    summary panel.
    """
    # Collect all divergence data
    divergences_data = {}

    # Check for main divergences (single chain NUTS)
    if "diverging" in idata.sample_stats:
        divergences_data["main"] = (
            idata.sample_stats.diverging.values.flatten()
        )

    # Check for channel-specific divergences (blocked NUTS)
    channel_divergences = {}
    for key in idata.sample_stats:
        if key.startswith("diverging_channel_"):
            channel_idx = key.replace("diverging_channel_", "")
            channel_divergences[int(channel_idx)] = idata.sample_stats[
                key
            ].values.flatten()

    if channel_divergences:
        divergences_data.update(channel_divergences)

    if not divergences_data:
        fig, ax = plt.subplots(figsize=config.figsize)
        ax.text(
            0.5, 0.5, "No divergence data available", ha="center", va="center"
        )
        ax.set_title("Divergences Diagnostics")
        return

    # Create subplot layout: trace axes plus (optionally) a summary axis.
    n_plots = len(divergences_data)
    if n_plots == 1:
        fig, axes = plt.subplots(1, 2, figsize=config.figsize)
        trace_axes = axes[:1]
        summary_ax = axes[1]
    else:
        cols = 2
        rows = (n_plots + 1) // cols  # Ceiling division
        fig, axes = plt.subplots(rows, cols, figsize=config.figsize)
        if rows == 1:
            axes = axes.reshape(1, -1)
        axes = axes.flatten()

        # Odd number of traces: the spare axis becomes the summary panel.
        if n_plots % 2 == 1:
            trace_axes = axes[:-1]
            summary_ax = axes[-1]
        else:
            trace_axes = axes
            summary_ax = None

    # Plot divergences traces
    total_divergences = 0
    total_iterations = 0

    # BUGFIX: advance the axis index for EVERY entry via enumerate.
    # Previously the index was only incremented for channel entries, so
    # when "main" and channel divergences coexisted the first channel
    # trace was drawn over the main trace on the same axis.
    for plot_idx, (label, div_values) in enumerate(divergences_data.items()):
        if label == "main":
            title = "NUTS Divergences"
        else:
            title = f"Channel {label} Divergences"
        ax = trace_axes[plot_idx]

        # Plot divergence indicators (where divergences occur)
        div_indices = np.where(div_values)[0]
        ax.scatter(
            div_indices,
            np.ones_like(div_indices),
            color="red",
            marker="x",
            s=50,
            linewidth=2,
            label="Divergent",
            alpha=0.8,
        )

        # Add background shading for divergent regions
        for idx in div_indices:
            ax.axvspan(idx - 0.5, idx + 0.5, alpha=0.2, color="red")

        ax.set_xlabel("Iteration")
        ax.set_ylabel("Divergence Indicator")
        ax.set_title(title)
        ax.set_yticks([0, 1])
        ax.set_yticklabels(["No", "Yes"])
        ax.grid(True, alpha=0.3)

        # Add per-series statistics box
        n_divergent = np.sum(div_values)
        pct_divergent = n_divergent / len(div_values) * 100
        stats_text = f"{n_divergent}/{len(div_values)} ({pct_divergent:.2f}%)"
        ax.text(
            0.02,
            0.98,
            stats_text,
            transform=ax.transAxes,
            fontsize=10,
            bbox=dict(boxstyle="round", facecolor="lightcoral", alpha=0.8),
            verticalalignment="top",
        )

        total_divergences += n_divergent
        total_iterations += len(div_values)

        # Legend only if there are divergences
        if n_divergent > 0:
            ax.legend(loc="upper right", fontsize="small")

    # Summary panel (unified branch; even n_plots > 1 has no spare axis).
    if summary_ax is not None:
        summary_ax.text(
            0.05,
            0.95,
            _get_divergences_summary(divergences_data),
            transform=summary_ax.transAxes,
            fontsize=12,
            verticalalignment="top",
            fontfamily="monospace",
            bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.9),
        )
        summary_ax.set_title("Divergences Summary")
        summary_ax.axis("off")

    # Overall title
    overall_pct = (
        total_divergences / total_iterations * 100
        if total_iterations > 0
        else 0
    )
    fig.suptitle(f"Overall Divergences: {overall_pct:.2f}%")

    plt.tight_layout()

1585
def _get_divergences_summary(divergences_data):
2✔
1586
    """Generate text summary of divergences."""
1587
    lines = ["Divergences Summary:", ""]
2✔
1588

1589
    total_divergences = 0
2✔
1590
    total_iterations = 0
2✔
1591

1592
    for label, div_values in divergences_data.items():
2✔
1593
        n_divergent = np.sum(div_values)
2✔
1594
        pct_divergent = n_divergent / len(div_values) * 100
2✔
1595

1596
        if label == "main":
2✔
1597
            lines.append(
2✔
1598
                f"NUTS: {n_divergent}/{len(div_values)} ({pct_divergent:.2f}%)"
1599
            )
1600
        else:
1601
            lines.append(
×
1602
                f"Channel {label}: {n_divergent}/{len(div_values)} ({pct_divergent:.2f}%)"
1603
            )
1604

1605
        total_divergences += n_divergent
2✔
1606
        total_iterations += len(div_values)
2✔
1607

1608
    lines.append("")
2✔
1609
    overall_pct = (
2✔
1610
        total_divergences / total_iterations * 100
1611
        if total_iterations > 0
1612
        else 0
1613
    )
1614
    lines.append(
2✔
1615
        f"Total: {total_divergences}/{total_iterations} ({overall_pct:.2f}%)"
1616
    )
1617

1618
    lines.append("")
2✔
1619
    lines.append("Interpretation:")
2✔
1620
    if overall_pct == 0:
2✔
1621
        lines.append("  ✓ No divergences detected")
2✔
1622
        lines.append("    Sampling appears well-behaved")
2✔
UNCOV
1623
    elif overall_pct < 0.1:
×
1624
        lines.append("  ~ Few divergences")
×
1625
        lines.append("    Generally good, but monitor")
×
UNCOV
1626
    elif overall_pct < 1.0:
×
1627
        lines.append("  ⚠ Some divergences detected")
×
1628
        lines.append("    May indicate sampling issues")
×
1629
    else:
UNCOV
1630
        lines.append("  ✗ Many divergences!")
×
UNCOV
1631
        lines.append("    Significant sampling problems")
×
UNCOV
1632
        lines.append("    Consider model reparameterization")
×
1633

1634
    return "\n".join(lines)
2✔
1635

1636

1637
def _group_parameters_simple(idata):
2✔
1638
    """Simple parameter grouping for counting."""
1639
    param_groups = {"phi": [], "delta": [], "weights": [], "other": []}
2✔
1640

1641
    for param in idata.posterior.data_vars:
2✔
1642
        if param.startswith("phi"):
2✔
1643
            param_groups["phi"].append(param)
2✔
1644
        elif param.startswith("delta"):
2✔
1645
            param_groups["delta"].append(param)
2✔
1646
        elif param.startswith("weights"):
2✔
1647
            param_groups["weights"].append(param)
2✔
1648
        else:
1649
            param_groups["other"].append(param)
×
1650

1651
    return {k: v for k, v in param_groups.items() if v}
2✔
1652

1653

1654
def generate_diagnostics_summary(idata, outdir):
    """Generate a comprehensive text summary of MCMC diagnostics.

    Builds a human-readable report covering sampler metadata, parameter
    counts, effective sample size (ESS), acceptance rates, optional PSD
    accuracy metrics (RIAE / interval coverage, when a ``true_psd`` is
    recorded in ``idata.attrs``), and an overall convergence verdict.

    Parameters
    ----------
    idata : arviz.InferenceData
        Posterior samples plus ``sample_stats`` and diagnostic ``attrs``.
    outdir : str or None
        If truthy, the summary is also written to
        ``{outdir}/diagnostics_summary.txt``.

    Returns
    -------
    str
        The full summary text (also logged via ``logger.info``).
    """
    summary = []
    summary.append("=== MCMC Diagnostics Summary ===\n")

    # Basic info. Normalize attrs to a mapping with .get() — some backends
    # expose attrs as a non-dict mapping type.
    attrs = getattr(idata, "attrs", {}) or {}
    if not hasattr(attrs, "get"):
        attrs = dict(attrs)

    n_samples = idata.posterior.sizes.get("draw", 0)
    n_chains = idata.posterior.sizes.get("chain", 1)
    n_params = len(list(idata.posterior.data_vars))
    sampler_type = attrs.get("sampler_type", "Unknown")

    summary.append(f"Sampler: {sampler_type}")
    summary.append(
        f"Samples: {n_samples} per chain × {n_chains} chains = {n_samples * n_chains} total"
    )
    summary.append(f"Parameters: {n_params}")

    # Parameter breakdown
    param_groups = _group_parameters_simple(idata)
    if param_groups:
        param_summary = ", ".join(
            [f"{k}: {len(v)}" for k, v in param_groups.items()]
        )
        summary.append(f"Parameter groups: {param_summary}")

    # ESS. Initialize to an empty array so the overall-assessment section
    # below works even when ESS extraction fails (previously that section
    # depended on a bare ``except:`` swallowing a NameError).
    ess_values = np.asarray([])
    try:
        ess = attrs.get("ess")
        ess_values = ess[~np.isnan(ess)]

        if len(ess_values) > 0:
            summary.append(
                f"\nESS: min={ess_values.min():.0f}, mean={ess_values.mean():.0f}, max={ess_values.max():.0f}"
            )
            summary.append(f"ESS ≥ 400: {(ess_values >= 400).mean()*100:.1f}%")
    except Exception:
        # ess may be None or a non-array; treat as unavailable.
        summary.append("\nESS: unavailable")

    # Acceptance: prefer the standard single-key stats; otherwise fall back
    # to per-channel keys emitted by the blocked NUTS sampler.
    accept_key = None
    if "accept_prob" in idata.sample_stats:
        accept_key = "accept_prob"
    elif "acceptance_rate" in idata.sample_stats:
        accept_key = "acceptance_rate"

    if accept_key is not None:
        accept_rate = idata.sample_stats[accept_key].values.mean()
        target_rate = attrs.get(
            "target_accept_rate", attrs.get("target_accept_prob", 0.44)
        )
        summary.append(
            f"Acceptance rate: {accept_rate:.3f} (target: {target_rate:.3f})"
        )
    else:
        # Blocked NUTS: compute a combined mean from per‑channel keys if present
        channel_means = []
        for key in idata.sample_stats:
            if isinstance(key, str) and key.startswith("accept_prob_channel_"):
                try:
                    channel_means.append(
                        float(idata.sample_stats[key].values.mean())
                    )
                except Exception:
                    pass
        if channel_means:
            # NOTE: blocked NUTS targets 0.8 by default, vs 0.44 above.
            target_rate = attrs.get(
                "target_accept_rate", attrs.get("target_accept_prob", 0.8)
            )
            summary.append(
                f"Acceptance rate (per-channel mean): {np.mean(channel_means):.3f} (target: {target_rate:.3f})"
            )

    # PSD accuracy diagnostics (requires true_psd in attrs)
    has_true_psd = "true_psd" in attrs

    if has_true_psd:
        coverage_level = attrs.get("coverage_level")
        coverage_label = (
            f"{int(round(coverage_level * 100))}% interval coverage"
            if coverage_level is not None
            else "Interval coverage"
        )

        def _format_riae_line(value, errorbars, prefix="  "):
            # errorbars, when present, is (q05, q25, median, q75, q95).
            line = f"{prefix}RIAE: {value:.3f}"
            if errorbars:
                q05, _, median, _, q95 = errorbars
                line += f" (median {median:.3f}, 5-95% [{q05:.3f}, {q95:.3f}])"
            summary.append(line)

        def _format_coverage_line(value, prefix="  "):
            if value is None:
                return
            summary.append(f"{prefix}{coverage_label}: {value * 100:.1f}%")

        summary.append("\nPSD accuracy diagnostics:")

        if "riae" in attrs:
            _format_riae_line(attrs["riae"], attrs.get("riae_errorbars"))
        if "coverage" in attrs:
            _format_coverage_line(attrs["coverage"])

        # Multichannel runs store per-channel metrics under riae_ch{idx}.
        channel_indices = sorted(
            int(key.replace("riae_ch", ""))
            for key in attrs.keys()
            if key.startswith("riae_ch")
        )

        for idx in channel_indices:
            metrics = []
            riae_key = f"riae_ch{idx}"
            cov_key = f"coverage_ch{idx}"
            error_key = f"riae_errorbars_ch{idx}"

            if riae_key in attrs:
                riae_line = f"RIAE {attrs[riae_key]:.3f}"
                errorbars = attrs.get(error_key)
                if errorbars:
                    q05, _, median, _, q95 = errorbars
                    riae_line += (
                        f" (median {median:.3f}, 5-95% [{q05:.3f}, {q95:.3f}])"
                    )
                metrics.append(riae_line)

            if cov_key in attrs:
                metrics.append(f"{coverage_label} {attrs[cov_key] * 100:.1f}%")

            if metrics:
                summary.append(f"  Channel {idx}: " + "; ".join(metrics))

    # Overall assessment, based on the fraction of parameters with ESS ≥ 400.
    # NOTE(review): R-hat is not currently folded into the verdict.
    try:
        if len(ess_values) > 0:
            ess_good = (ess_values >= 400).mean() * 100
            summary.append("\nOverall Convergence Assessment:")
            if ess_good >= 90:
                summary.append("  Status: EXCELLENT ✓")
            elif ess_good >= 75:
                summary.append("  Status: GOOD ✓")
            else:
                summary.append("  Status: NEEDS ATTENTION ⚠")
    except Exception:
        # Best-effort: never let the assessment block break summary output.
        pass

    summary_text = "\n".join(summary)

    if outdir:
        # Explicit UTF-8: the summary contains non-ASCII glyphs (≥, ✓, ⚠, ×)
        # which would fail under a non-UTF-8 platform default encoding.
        path = os.path.join(outdir, "diagnostics_summary.txt")
        with open(path, "w", encoding="utf-8") as f:
            f.write(summary_text)

    logger.info(f"\n{summary_text}\n")
    return summary_text
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc