• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

nz-gravity / LogPSplinePSD / 18150943472

01 Oct 2025 04:03AM UTC coverage: 78.59% (-0.3%) from 78.908%
18150943472

push

github

avivajpeyi
add fixes for plotters

396 of 498 branches covered (79.52%)

Branch coverage included in aggregate %.

126 of 165 new or added lines in 2 files covered. (76.36%)

54 existing lines in 3 files now uncovered.

3014 of 3841 relevant lines covered (78.47%)

1.57 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

65.33
/src/log_psplines/plotting/diagnostics.py
1
import os
2✔
2
from dataclasses import dataclass
2✔
3
from functools import wraps
2✔
4
from typing import Callable, Optional
2✔
5

6
import arviz as az
2✔
7
import matplotlib.pyplot as plt
2✔
8
import numpy as np
2✔
9

10

11
@dataclass
class DiagnosticsConfig:
    """Configuration for the MCMC diagnostic plots in this module.

    ``figsize``/``dpi`` control figure geometry and save resolution;
    ``ess_threshold`` and ``rhat_threshold`` are the convergence cut-offs
    used by the summary dashboard and status panels.
    """

    # (width, height) in inches, passed to plt.subplots for the dashboards.
    figsize: tuple = (12, 8)
    # Save resolution handed to safe_plot / plt.savefig.
    dpi: int = 150
    # Minimum effective sample size treated as reliable (per-parameter).
    ess_threshold: int = 400
    # Maximum R-hat treated as converged.
    rhat_threshold: float = 1.01
17

18

19
def safe_plot(filename: str, dpi: int = 150):
    """Decorator for safe plotting with error handling.

    Wraps a plotting function so that, after it runs, the current
    matplotlib figure is saved to ``filename`` and closed.  Any exception
    raised while plotting or saving is caught, reported as a warning, and
    all open figures are closed so one failed plot cannot break the rest
    of the diagnostics run.

    Parameters
    ----------
    filename : str
        Destination path for the saved figure.
    dpi : int
        Save resolution passed to ``plt.savefig``.

    Returns
    -------
    Callable
        A decorator whose wrapped function returns ``True`` on success
        and ``False`` on failure (the plot function's own return value is
        intentionally discarded).
    """

    def decorator(plot_func: Callable):
        @wraps(plot_func)
        def wrapper(*args, **kwargs):
            try:
                # The plot function draws onto the current figure; its
                # return value is not needed for saving.
                plot_func(*args, **kwargs)
                plt.savefig(filename, dpi=dpi, bbox_inches="tight")
                plt.close()
                return True
            except Exception as e:
                # Best-effort: report and clean up every open figure so a
                # partially-drawn canvas does not leak into later plots.
                print(
                    f"Warning: Failed to create {os.path.basename(filename)}: {e}"
                )
                plt.close("all")
                return False

        return wrapper

    return decorator
40

41

42
def plot_trace(idata: az.InferenceData, compact=True) -> plt.Figure:
    """Plot trace and marginal-density panels for the posterior samples.

    Variables are grouped by name prefix ("delta", "phi", "weights");
    each group occupies one figure row with the trace on the left and the
    marginal (KDE + histogram) on the right.  phi/delta panels use log
    scales, with a change-of-variables correction applied to the KDE.

    Parameters
    ----------
    idata : az.InferenceData
        Posterior samples; variables are read from ``idata.posterior``.
    compact : bool
        When True, all variables of a group are over-plotted on the same
        axes pair; when False every variable of a group shares the single
        row allocated to that group.

    Returns
    -------
    plt.Figure
    """
    groups = {
        "delta": [
            v for v in idata.posterior.data_vars if v.startswith("delta")
        ],
        "phi": [v for v in idata.posterior.data_vars if v.startswith("phi")],
        "weights": [
            v for v in idata.posterior.data_vars if v.startswith("weights")
        ],
    }

    # Both layouts currently produce one row per group; kept as separate
    # branches to preserve the original compact/non-compact distinction.
    if compact:
        nrows = 3
    else:
        nrows = len(groups)
    fig, axes = plt.subplots(nrows, 2, figsize=(7, 3 * nrows))

    for row, (group_name, var_names) in enumerate(groups.items()):
        if not var_names:
            # Skip empty groups: np.repeat over zero variables would
            # yield a (0, 2) axes array and indexing it would fail.
            continue

        # Normalize to a 2-D (k, 2) axes array so indexing below works in
        # both layouts.  Previously the non-compact branch kept a 1-D
        # slice, so group_axes[0, 0] raised IndexError.
        group_axes = axes[row, :].reshape(1, 2)
        if compact:
            # Repeat the row so every variable of the group can address
            # "its" axes pair, all aliasing the same two axes.
            group_axes = np.repeat(group_axes, len(var_names), axis=0)

        group_axes[0, 0].set_title(
            f"{group_name.capitalize()} Parameters", fontsize=14
        )

        for i, var in enumerate(var_names):
            data = idata.posterior[
                var
            ].values  # shape is (nchain, nsamples, ndim) if ndim>1 else (nchain, nsamples)
            if data.ndim == 3:
                data = data[0].T  # shape is now (ndim, nsamples)

            ax_trace = group_axes[i, 0] if compact else group_axes[0, 0]
            ax_hist = group_axes[i, 1] if compact else group_axes[0, 1]
            ax_trace.set_ylabel(group_name, fontsize=8)
            ax_trace.set_xlabel("MCMC Step", fontsize=8)
            ax_hist.set_xlabel(group_name, fontsize=8)
            # place ylabel on right side of hist
            ax_hist.yaxis.set_label_position("right")
            ax_hist.set_ylabel("Density", fontsize=8, rotation=270, labelpad=0)

            # remove axes yspine for hist
            ax_hist.spines["left"].set_visible(False)
            ax_hist.spines["right"].set_visible(False)
            ax_hist.spines["top"].set_visible(False)
            ax_hist.set_yticks([])  # remove y ticks
            ax_hist.yaxis.set_ticks_position("none")

            ax_trace.spines["right"].set_visible(False)
            ax_trace.spines["top"].set_visible(False)

            color = f"C{i}"
            label = f"{var}"
            if group_name in ["phi", "delta"]:
                # phi/delta are assumed strictly positive (log10 of the
                # minimum is taken below) — TODO confirm with the model.
                ax_trace.set_yscale("log")
                ax_hist.set_xscale("log")

            for p in data:
                ax_trace.plot(p, color=color, alpha=0.7, label=label)

                # if phi or delta, use log scale for hist-x, log for trace y
                if group_name in ["phi", "delta"]:
                    bins = np.logspace(
                        np.log10(np.min(p)), np.log10(np.max(p)), 30
                    )
                    # KDE in log-space, then transform the density back
                    # with the Jacobian 1/x (change of variables).
                    logp = np.log(p)
                    log_grid, log_pdf = az.kde(logp)
                    grid = np.exp(log_grid)
                    pdf = log_pdf / grid  # change of variables
                else:
                    bins = 30
                    grid, pdf = az.kde(p)
                ax_hist.plot(grid, pdf, color=color, label=label)
                ax_hist.hist(
                    p, bins=bins, density=True, color=color, alpha=0.3
                )

    plt.suptitle("Parameter Traces", fontsize=16)
    plt.tight_layout()
    return fig
129

130

131
def plot_diagnostics(
    idata: az.InferenceData,
    outdir: str,
    n_channels: Optional[int] = None,
    n_freq: Optional[int] = None,
    runtime: Optional[float] = None,
    config: Optional[DiagnosticsConfig] = None,
) -> None:
    """
    Create essential MCMC diagnostics in organized subdirectories.

    Writes a summary report plus the essential diagnostic figures into
    ``<outdir>/diagnostics``.  A ``None`` outdir disables all output.
    """
    # Nothing to do without an output directory.
    if outdir is None:
        return

    cfg = DiagnosticsConfig() if config is None else config

    # All diagnostics live in their own subdirectory of outdir.
    diag_dir = os.path.join(outdir, "diagnostics")
    os.makedirs(diag_dir, exist_ok=True)

    print("Generating MCMC diagnostics...")

    # Text summary first, then the figures.
    generate_diagnostics_summary(idata, diag_dir)
    _create_essential_diagnostics(
        idata, diag_dir, cfg, n_channels, n_freq, runtime
    )

    print(f"Diagnostics saved to {diag_dir}/")
163

164

165
def _create_essential_diagnostics(
    idata, diag_dir, config, n_channels, n_freq, runtime
):
    """Create only the essential diagnostic plots.

    Each figure is produced through the ``safe_plot`` decorator so a
    failure in one plot does not abort the rest of the diagnostics.
    """

    # 1. ArviZ trace plots
    @safe_plot(f"{diag_dir}/trace_plots.png", config.dpi)
    def _traces():
        return plot_trace(idata)

    # 2. Summary dashboard with key convergence metrics
    @safe_plot(f"{diag_dir}/summary_dashboard.png", config.dpi)
    def _dashboard():
        _plot_summary_dashboard(idata, config, n_channels, n_freq, runtime)

    # 3. Log posterior diagnostics
    @safe_plot(f"{diag_dir}/log_posterior.png", config.dpi)
    def _log_posterior():
        _plot_log_posterior(idata, config)

    # 4. Acceptance rate diagnostics
    @safe_plot(f"{diag_dir}/acceptance_diagnostics.png", config.dpi)
    def _acceptance():
        _plot_acceptance_diagnostics(idata, config)

    # Render in the same order the figures were listed above.
    _traces()
    _dashboard()
    _log_posterior()
    _acceptance()

    # 5. Sampler-specific diagnostics
    _create_sampler_diagnostics(idata, diag_dir, config)

    # 6. Divergences diagnostics (for NUTS only)
    _create_divergences_diagnostics(idata, diag_dir, config)
203

204

205
def _plot_summary_dashboard(idata, config, n_channels, n_freq, runtime):
    """Essential summary dashboard.

    Draws a dashboard of convergence panels: ESS distribution, R-hat
    distribution and ESS-vs-R-hat scatter (when R-hat can be computed,
    i.e. multiple chains), analysis metadata, a parameter-count summary,
    and an overall convergence status.  The layout is 2x3 with R-hat and
    2x2 without.  Every panel is wrapped in its own try/except so one
    missing statistic does not blank the whole dashboard.
    """

    # Check if R-hat is available (az.rhat fails for single-chain runs).
    rhat_available = False
    try:
        rhat = az.rhat(idata).to_array().values.flatten()
        rhat_values = rhat[~np.isnan(rhat)]
        rhat_available = len(rhat_values) > 0
    except Exception:
        pass

    # Create subplot layout based on data availability
    if rhat_available:
        # Full 2x3 layout when R-hat is available
        fig, axes = plt.subplots(2, 3, figsize=config.figsize)
        rhat_ax = axes[0, 1]
        scatter_ax = axes[0, 2]
        meta_ax = axes[1, 0]
        param_ax = axes[1, 1]
        status_ax = axes[1, 2]
    else:
        # Reduced 2x2 layout when R-hat is not available
        fig, axes = plt.subplots(
            2, 2, figsize=(config.figsize[0] * 0.8, config.figsize[1])
        )
        # Rearrange axes for 2x2 layout
        meta_ax = axes[0, 0]
        param_ax = axes[0, 1]
        status_ax = axes[1, 0]
        # Use bottom-right for additional info or leave empty
        axes[1, 1].axis("off")  # Hide unused subplot

    # ESS histogram (always in top-left)
    try:
        ess = az.ess(idata).to_array().values.flatten()
        ess_values = ess[~np.isnan(ess)]

        if len(ess_values) > 0:
            # Reference lines marking ESS quality levels.
            ess_thresholds = [
                (400, "red", "--", "Minimum reliable ESS"),
                (1000, "orange", "--", "Good ESS"),
                (
                    np.max(ess_values),
                    "green",
                    ":",
                    f"Max ESS = {np.max(ess_values):.0f}",
                ),
            ]

            ax_ess = axes[0, 0]  # Always available
            n, bins, patches = ax_ess.hist(
                ess_values, bins=30, alpha=0.7, edgecolor="black"
            )

            # Add reference lines
            for threshold, color, style, label in ess_thresholds:
                ax_ess.axvline(
                    x=threshold,
                    color=color,
                    linestyle=style,
                    linewidth=2 if threshold < np.max(ess_values) else 1,
                    alpha=0.8,
                    label=label,
                )

            ax_ess.set_xlabel("ESS")
            ax_ess.set_ylabel("Count")
            ax_ess.set_title("ESS Distribution")
            ax_ess.legend(loc="upper right", fontsize="x-small")
            ax_ess.grid(True, alpha=0.3)

            pct_good = (ess_values >= config.ess_threshold).mean() * 100
            ax_ess.text(
                0.02,
                0.98,
                f"Min: {ess_values.min():.0f}\nMean: {ess_values.mean():.0f}\n≥{config.ess_threshold}: {pct_good:.1f}%",
                transform=ax_ess.transAxes,
                fontsize=10,
                verticalalignment="top",
                bbox=dict(boxstyle="round", facecolor="lightblue", alpha=0.7),
            )
    except Exception:
        axes[0, 0].text(0.5, 0.5, "ESS unavailable", ha="center", va="center")
        axes[0, 0].set_title("ESS Distribution")

    # R-hat histogram and scatter (only when R-hat is available)
    if rhat_available and "rhat_ax" in locals():
        # Add shaded regions for R-hat quality
        rhat_ax.axvspan(
            1.0,
            config.rhat_threshold,
            alpha=0.1,
            color="green",
            label="Converged (≤1.01)",
        )
        rhat_ax.axvspan(
            config.rhat_threshold,
            1.1,
            alpha=0.1,
            color="yellow",
            label="Concerning (1.01-1.10)",
        )
        rhat_ax.axvspan(
            1.1,
            rhat_values.max(),
            alpha=0.1,
            color="red",
            label="Not converged (>1.10)",
        )

        rhat_ax.hist(rhat_values, bins=30, alpha=0.7, edgecolor="black")
        rhat_ax.axvline(
            1.0,
            color="green",
            linestyle="--",
            linewidth=2,
            label="Perfectly mixed",
        )
        rhat_ax.axvline(
            config.rhat_threshold,
            color="orange",
            linestyle="--",
            linewidth=2,
            label="Acceptable",
        )
        rhat_ax.set_xlabel("R-hat")
        rhat_ax.set_ylabel("Count")
        rhat_ax.set_title("R-hat Distribution")
        rhat_ax.legend(loc="upper right", fontsize="x-small")
        rhat_ax.grid(True, alpha=0.3)

        pct_excellent = (rhat_values <= 1.01).mean() * 100
        pct_concerning = (
            (rhat_values > 1.01) & (rhat_values <= 1.1)
        ).mean() * 100
        pct_bad = (rhat_values > 1.1).mean() * 100
        rhat_ax.text(
            0.02,
            0.98,
            f"Max: {rhat_values.max():.3f}\nMean: {rhat_values.mean():.3f}\n≤1.01: {pct_excellent:.1f}%\n>1.10: {pct_bad:.1f}%",
            transform=rhat_ax.transAxes,
            fontsize=9,
            verticalalignment="top",
            bbox=dict(boxstyle="round", facecolor="lightgreen", alpha=0.7),
        )

    # ESS vs R-hat scatter (only when R-hat is available)
    if rhat_available and "scatter_ax" in locals():
        try:
            if len(ess_values) > 0 and len(rhat_values) > 0:
                # Recompute un-filtered arrays so ESS and R-hat stay
                # aligned element-wise before masking NaNs jointly.
                ess_all = az.ess(idata).to_array().values.flatten()
                rhat_all = az.rhat(idata).to_array().values.flatten()
                valid_mask = ~(np.isnan(ess_all) | np.isnan(rhat_all))

                if np.sum(valid_mask) > 0:
                    scatter_ax.scatter(
                        rhat_all[valid_mask],
                        ess_all[valid_mask],
                        alpha=0.6,
                        s=20,
                    )
                    scatter_ax.axvline(
                        config.rhat_threshold,
                        color="red",
                        linestyle="--",
                        alpha=0.7,
                    )
                    scatter_ax.axhline(
                        config.ess_threshold,
                        color="orange",
                        linestyle="--",
                        alpha=0.7,
                    )
                    scatter_ax.set_xlabel("R-hat")
                    scatter_ax.set_ylabel("ESS")
                    scatter_ax.set_title("Convergence Overview")
                    scatter_ax.grid(True, alpha=0.3)
        except Exception:
            scatter_ax.text(
                0.5, 0.5, "Scatter unavailable", ha="center", va="center"
            )
            scatter_ax.set_title("Convergence Overview")

    # Analysis metadata
    try:
        n_samples = idata.posterior.sizes.get("draw", 0)
        n_chains = idata.posterior.sizes.get("chain", 1)
        n_params = len(list(idata.posterior.data_vars))
        sampler_type = idata.attrs["sampler_type"]

        metadata_lines = [
            f"Sampler: {sampler_type}",
            f"Samples: {n_samples} × {n_chains} chains",
            f"Parameters: {n_params}",
        ]
        if n_channels is not None:
            metadata_lines.append(f"Channels: {n_channels}")
        if n_freq is not None:
            metadata_lines.append(f"Frequencies: {n_freq}")
        if runtime is not None:
            metadata_lines.append(f"Runtime: {runtime:.2f}s")

        meta_ax.text(
            0.05,
            0.95,
            "\n".join(metadata_lines),
            transform=meta_ax.transAxes,
            fontsize=12,
            verticalalignment="top",
            fontfamily="monospace",
        )
        meta_ax.set_title("Analysis Summary")
        meta_ax.axis("off")
    except Exception:
        meta_ax.text(
            0.5, 0.5, "Metadata unavailable", ha="center", va="center"
        )
        meta_ax.set_title("Analysis Summary")
        meta_ax.axis("off")

    # Parameter counts (placeholder - turned off as not helpful per user feedback)
    try:
        # Just show a summary text instead of the full bar chart
        param_groups = _group_parameters_simple(idata)
        if param_groups:
            summary_text = "Parameter Summary:\n"
            for group_name, params in param_groups.items():
                if params:  # Only show non-empty groups
                    summary_text += f"{group_name}: {len(params)}\n"
            param_ax.text(
                0.05,
                0.95,
                summary_text.strip(),
                transform=param_ax.transAxes,
                fontsize=11,
                verticalalignment="top",
                fontfamily="monospace",
            )
        param_ax.set_title("Parameter Summary")
        param_ax.axis("off")  # Don't show axes for this simple text summary
    except Exception:
        param_ax.text(
            0.5,
            0.5,
            "Parameter summary\nunavailable",
            ha="center",
            va="center",
        )
        param_ax.set_title("Parameter Summary")
        param_ax.axis("off")

    # Convergence status
    try:
        ess_retrieved = []
        rhat_retrieved = []
        # NOTE(review): these were bare `except:` clauses, which also
        # swallowed KeyboardInterrupt/SystemExit; narrowed to Exception.
        try:
            ess_retrieved = az.ess(idata).to_array().values.flatten()
            ess_retrieved = ess_retrieved[~np.isnan(ess_retrieved)]
        except Exception:
            pass
        try:
            rhat_retrieved = az.rhat(idata).to_array().values.flatten()
            rhat_retrieved = rhat_retrieved[~np.isnan(rhat_retrieved)]
        except Exception:
            pass

        status_lines = ["Convergence Status:"]

        if len(ess_retrieved) > 0:
            ess_good = (ess_retrieved >= config.ess_threshold).mean() * 100
            status_lines.append(
                f"ESS ≥ {config.ess_threshold}: {ess_good:.0f}%"
            )

        if len(rhat_retrieved) > 0:
            rhat_good = (rhat_retrieved <= config.rhat_threshold).mean() * 100
            status_lines.append(
                f"R-hat ≤ {config.rhat_threshold}: {rhat_good:.0f}%"
            )

        status_lines.append("")
        status_lines.append("Overall Status:")

        # ess_good/rhat_good are guaranteed bound in each branch because
        # the same length checks guarded their assignment above.
        if len(ess_retrieved) > 0 and len(rhat_retrieved) > 0:
            if ess_good >= 90 and rhat_good >= 90:
                status_lines.append("✓ EXCELLENT")
                color = "green"
            elif ess_good >= 75 and rhat_good >= 75:
                status_lines.append("✓ GOOD")
                color = "orange"
            else:
                status_lines.append("⚠ NEEDS ATTENTION")
                color = "red"
        elif len(ess_retrieved) > 0:
            if ess_good >= 90:
                status_lines.append("✓ GOOD (based on ESS only)")
                color = "green"
            elif ess_good >= 75:
                status_lines.append("✓ ADEQUATE (based on ESS only)")
                color = "orange"
            else:
                status_lines.append("⚠ NEEDS ATTENTION")
                color = "red"
        else:
            status_lines.append("? UNABLE TO ASSESS")
            color = "gray"

        status_ax.text(
            0.05,
            0.95,
            "\n".join(status_lines),
            transform=status_ax.transAxes,
            fontsize=11,
            verticalalignment="top",
            fontfamily="monospace",
            color=color,
        )
        status_ax.set_title("Convergence Status")
        status_ax.axis("off")
    except Exception:
        status_ax.text(
            0.5, 0.5, "Status unavailable", ha="center", va="center"
        )
        status_ax.set_title("Convergence Status")
        status_ax.axis("off")

    plt.tight_layout()
534

535

536
def _plot_log_posterior(idata, config):
    """Log posterior diagnostics.

    Plots trace (with running mean), marginal distribution, step-to-step
    changes, and summary statistics of the log posterior (``lp``) — or of
    ``log_likelihood`` as a fallback.  If neither is present, a single
    placeholder panel is drawn instead.

    Fix: the data-availability check now happens *before* the 2x2 figure
    is created, so the fallback path no longer leaks an orphaned figure.
    """
    # Check for lp first, then log_likelihood
    if "lp" in idata.sample_stats:
        lp_values = idata.sample_stats["lp"].values.flatten()
        var_name = "lp"
        title_prefix = "Log Posterior"
    elif "log_likelihood" in idata.sample_stats:
        lp_values = idata.sample_stats["log_likelihood"].values.flatten()
        var_name = "log_likelihood"
        title_prefix = "Log Likelihood"
    else:
        # Create a fallback layout when no posterior data available
        fig, axes = plt.subplots(1, 1, figsize=config.figsize)
        axes.text(
            0.5,
            0.5,
            "No log posterior\nor log likelihood\navailable",
            ha="center",
            va="center",
            fontsize=14,
        )
        axes.set_title("Log Posterior Diagnostics")
        axes.axis("off")
        plt.tight_layout()
        return

    # Only build the full 2x2 layout once we know there is data to show.
    fig, axes = plt.subplots(2, 2, figsize=config.figsize)

    # Trace plot with running mean overlaid
    axes[0, 0].plot(
        lp_values, alpha=0.7, linewidth=1, color="blue", label="Trace"
    )

    # Add running mean on the same plot (window ~1% of the chain length).
    window_size = max(10, len(lp_values) // 100)
    if len(lp_values) > window_size:
        running_mean = np.convolve(
            lp_values, np.ones(window_size) / window_size, mode="valid"
        )
        # Shift x so the running mean is centered on its window.
        axes[0, 0].plot(
            range(window_size // 2, window_size // 2 + len(running_mean)),
            running_mean,
            alpha=0.9,
            linewidth=3,
            color="red",
            label=f"Running mean (w={window_size})",
        )

    axes[0, 0].set_xlabel("Iteration")
    axes[0, 0].set_ylabel(title_prefix)
    axes[0, 0].set_title(f"{title_prefix} Trace with Running Mean")
    axes[0, 0].legend(loc="best", fontsize="small")
    axes[0, 0].grid(True, alpha=0.3)

    # Distribution
    axes[0, 1].hist(
        lp_values, bins=50, alpha=0.7, density=True, edgecolor="black"
    )
    axes[0, 1].axvline(
        np.mean(lp_values),
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Mean: {np.mean(lp_values):.1f}",
    )
    axes[0, 1].set_xlabel(title_prefix)
    axes[0, 1].set_ylabel("Density")
    axes[0, 1].set_title(f"{title_prefix} Distribution")
    axes[0, 1].legend(loc="best", fontsize="small")
    axes[0, 1].grid(True, alpha=0.3)

    # Step-to-step changes
    lp_diff = np.diff(lp_values)
    axes[1, 0].plot(lp_diff, alpha=0.5, linewidth=1)
    axes[1, 0].axhline(0, color="red", linestyle="--", alpha=0.7)
    axes[1, 0].axhline(
        np.mean(lp_diff),
        color="blue",
        linestyle="--",
        alpha=0.7,
        label=f"Mean change: {np.mean(lp_diff):.1f}",
    )
    axes[1, 0].set_xlabel("Iteration")
    axes[1, 0].set_ylabel(f"{title_prefix} Difference")
    axes[1, 0].set_title("Step-to-Step Changes")
    axes[1, 0].legend(loc="best", fontsize="small")
    axes[1, 0].grid(True, alpha=0.3)

    # Summary statistics ("final variation" = std over the last quarter).
    stats_lines = [
        f"Mean: {np.mean(lp_values):.2f}",
        f"Std: {np.std(lp_values):.2f}",
        f"Min: {np.min(lp_values):.2f}",
        f"Max: {np.max(lp_values):.2f}",
        f"Range: {np.max(lp_values) - np.min(lp_values):.2f}",
        "",
        "Stability:",
        f"Final variation: {np.std(lp_values[-len(lp_values)//4:]):.2f}",
    ]

    axes[1, 1].text(
        0.05,
        0.95,
        "\n".join(stats_lines),
        transform=axes[1, 1].transAxes,
        fontsize=10,
        verticalalignment="top",
        fontfamily="monospace",
        bbox=dict(boxstyle="round", facecolor="lightblue", alpha=0.8),
    )
    axes[1, 1].set_title("Posterior Statistics")
    axes[1, 1].axis("off")

    plt.tight_layout()
651

652

653
def _plot_acceptance_diagnostics(idata, config):
    """Acceptance rate diagnostics.

    Plots the acceptance-rate trace (with good/low/high shading and a
    running mean), its distribution, a rolling standard deviation, and a
    text summary.  Reads ``accept_prob`` (or ``acceptance_rate``) from
    ``idata.sample_stats``; draws a placeholder figure if neither exists.

    Fix: ``target_accept_rate`` is now looked up with ``attrs.get`` —
    ``idata.attrs`` is a mapping (it is used with ``in`` / ``[]``
    elsewhere in this module), so the previous ``getattr`` call always
    returned the 0.44 default and NUTS runs were classified as MH.
    """
    accept_key = None
    if "accept_prob" in idata.sample_stats:
        accept_key = "accept_prob"
    elif "acceptance_rate" in idata.sample_stats:
        accept_key = "acceptance_rate"

    if accept_key is None:
        fig, ax = plt.subplots(figsize=config.figsize)
        ax.text(
            0.5,
            0.5,
            "Acceptance rate data unavailable",
            ha="center",
            va="center",
        )
        ax.set_title("Acceptance Rate Diagnostics")
        return

    fig, axes = plt.subplots(2, 2, figsize=config.figsize)

    accept_rates = idata.sample_stats[accept_key].values.flatten()
    # attrs is dict-like: use .get, not getattr (which always defaulted).
    target_rate = idata.attrs.get("target_accept_rate", 0.44)
    sampler_type = (
        idata.attrs["sampler_type"].lower()
        if "sampler_type" in idata.attrs
        else "unknown"
    )
    sampler_type = "NUTS" if "nuts" in sampler_type else "MH"

    # Define good ranges based on sampler
    if target_rate > 0.5:  # NUTS
        good_range = (0.7, 0.9)
        low_range = (0.0, 0.6)
        high_range = (0.9, 1.0)
        concerning_range = (0.6, 0.7)
    else:  # MH
        good_range = (0.2, 0.5)
        low_range = (0.0, 0.2)
        high_range = (0.5, 1.0)
        concerning_range = (0.1, 0.2)  # MH can be lower than NUTS

    # Trace plot with color zones
    # Add background zones
    axes[0, 0].axhspan(
        good_range[0],
        good_range[1],
        alpha=0.1,
        color="green",
        label=f"Good ({good_range[0]:.1f}-{good_range[1]:.1f})",
    )
    axes[0, 0].axhspan(
        low_range[0], low_range[1], alpha=0.1, color="red", label="Too low"
    )
    axes[0, 0].axhspan(
        high_range[0],
        high_range[1],
        alpha=0.1,
        color="orange",
        label="Too high",
    )
    if concerning_range[1] > concerning_range[0]:
        axes[0, 0].axhspan(
            concerning_range[0],
            concerning_range[1],
            alpha=0.1,
            color="yellow",
            label="Concerning",
        )

    # Main trace plot
    axes[0, 0].plot(
        accept_rates, alpha=0.8, linewidth=1, color="blue", label="Trace"
    )
    axes[0, 0].axhline(
        target_rate,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Target ({target_rate})",
    )

    # Add running average on the same plot (window ~2% of chain length).
    window_size = max(10, len(accept_rates) // 50)
    if len(accept_rates) > window_size:
        running_mean = np.convolve(
            accept_rates, np.ones(window_size) / window_size, mode="valid"
        )
        axes[0, 0].plot(
            range(window_size // 2, window_size // 2 + len(running_mean)),
            running_mean,
            alpha=0.9,
            linewidth=3,
            color="purple",
            label=f"Running mean (w={window_size})",
        )

    axes[0, 0].set_xlabel("Iteration")
    axes[0, 0].set_ylabel("Acceptance Rate")
    axes[0, 0].set_title(f"{sampler_type} Acceptance Rate Trace")
    axes[0, 0].legend(loc="best", fontsize="small")
    axes[0, 0].grid(True, alpha=0.3)

    # Add interpretation text
    interpretation = f"{sampler_type} aims for {target_rate:.2f}."
    if target_rate > 0.5:
        interpretation += " Green: efficient sampling."
    else:
        interpretation += " MH adapts to find optimal rate."
    axes[0, 0].text(
        0.02,
        0.02,
        interpretation,
        transform=axes[0, 0].transAxes,
        fontsize=9,
        bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.7),
    )

    # Distribution
    axes[0, 1].hist(
        accept_rates, bins=30, alpha=0.7, density=True, edgecolor="black"
    )
    axes[0, 1].axvline(
        target_rate,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Target ({target_rate})",
    )
    axes[0, 1].set_xlabel("Acceptance Rate")
    axes[0, 1].set_ylabel("Density")
    axes[0, 1].set_title("Acceptance Rate Distribution")
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

    # Since running means are already overlaid on the main plot, use the bottom row for additional info

    # Additional acceptance analysis - evolution over time
    if len(accept_rates) > 10:
        # Rolling std over a trailing window of up to 21 samples.
        window_std = np.array(
            [
                np.std(accept_rates[max(0, i - 20) : i + 1])
                for i in range(len(accept_rates))
            ]
        )
        axes[1, 0].plot(window_std, alpha=0.7, color="green")
        axes[1, 0].set_xlabel("Iteration")
        axes[1, 0].set_ylabel("Rolling Std")
        axes[1, 0].set_title("Rolling Standard Deviation")
        axes[1, 0].grid(True, alpha=0.3)
    else:
        axes[1, 0].text(
            0.5,
            0.5,
            "Acceptance variability\nanalysis unavailable",
            ha="center",
            va="center",
        )
        axes[1, 0].set_title("Acceptance Stability")

    # Summary statistics ("Final std" = std over the last quarter).
    stats_text = [
        f"Sampler: {sampler_type}",
        f"Target: {target_rate:.3f}",
        f"Mean: {np.mean(accept_rates):.3f}",
        f"Std: {np.std(accept_rates):.3f}",
        f"CV: {np.std(accept_rates)/np.mean(accept_rates):.3f}",
        f"Min: {np.min(accept_rates):.3f}",
        f"Max: {np.max(accept_rates):.3f}",
        "",
        "Stability:",
        f"Final std: {np.std(accept_rates[-len(accept_rates)//4:]):.3f}",
    ]

    axes[1, 1].text(
        0.05,
        0.95,
        "\n".join(stats_text),
        transform=axes[1, 1].transAxes,
        fontsize=9,
        verticalalignment="top",
        fontfamily="monospace",
    )
    axes[1, 1].set_title("Acceptance Analysis")
    axes[1, 1].axis("off")

    plt.tight_layout()
842

843

844
def _create_sampler_diagnostics(idata, diag_dir, config):
    """Dispatch to NUTS- or MH-specific diagnostic plots.

    The sampler is identified from the recorded ``sampler_type`` attribute
    when available, otherwise from NUTS-only fields in ``sample_stats``.
    Exactly one figure is written (or none if neither sampler is detected).
    """
    # Prefer the explicitly recorded sampler type; fall back to "unknown".
    if "sampler_type" in idata.attrs:
        sampler_type = idata.attrs["sampler_type"].lower()
    else:
        sampler_type = "unknown"

    # Fields only NUTS emits — their presence identifies the sampler even
    # when no sampler_type attribute was stored.
    nuts_fields = (
        "energy",
        "num_steps",
        "tree_depth",
        "diverging",
        "energy_error",
    )
    is_nuts = "nuts" in sampler_type or any(
        name in idata.sample_stats for name in nuts_fields
    )
    # MH is only claimed when NUTS evidence is absent.
    is_mh = not is_nuts and "step_size_mean" in idata.sample_stats

    if is_nuts:

        @safe_plot(f"{diag_dir}/nuts_diagnostics.png", config.dpi)
        def render():
            _plot_nuts_diagnostics(idata, config)

        render()
    elif is_mh:

        @safe_plot(f"{diag_dir}/mh_step_sizes.png", config.dpi)
        def render():
            _plot_mh_step_sizes(idata, config)

        render()
def _plot_nuts_diagnostics(idata, config):
    """Render a 2x2 figure of NUTS diagnostics on the current matplotlib figure.

    Panels: (0,0) energy traces (Hamiltonian and/or potential, with a twin
    y-axis for their difference when both exist), (0,1) leapfrog-steps
    histogram with efficiency zones, (1,0) acceptance-probability trace with
    guidance bands, (1,1) a monospace text summary. Each panel degrades to a
    placeholder message when its field is missing from ``idata.sample_stats``.

    Parameters
    ----------
    idata : arviz.InferenceData (or compatible) with a ``sample_stats`` group;
        every field is optional and probed individually below.
    config : object exposing ``figsize`` (e.g. ``DiagnosticsConfig``).

    Note: the figure is not saved here — callers wrap this in ``safe_plot``,
    which saves and closes the current figure.
    """
    # Determine available data to decide layout
    has_energy = "energy" in idata.sample_stats
    has_potential = "potential_energy" in idata.sample_stats
    has_steps = "num_steps" in idata.sample_stats
    has_accept = "accept_prob" in idata.sample_stats
    has_divergences = "diverging" in idata.sample_stats
    has_tree_depth = "tree_depth" in idata.sample_stats
    has_energy_error = "energy_error" in idata.sample_stats

    # Create a 2x2 layout, potentially combining energy and potential on same plot
    # (fig handle is unused directly; safe_plot saves the current figure)
    fig, axes = plt.subplots(2, 2, figsize=config.figsize)

    # Top-left: Energy diagnostics (combine Hamiltonian and Potential if both available)
    energy_ax = axes[0, 0]

    if has_energy and has_potential:
        # Both available - plot them together on one plot
        # (chains are flattened into one concatenated trace)
        energy = idata.sample_stats.energy.values.flatten()
        potential = idata.sample_stats.potential_energy.values.flatten()

        # Plot both energies on same axis
        energy_ax.plot(
            energy, alpha=0.7, linewidth=1, color="blue", label="Hamiltonian"
        )
        energy_ax.plot(
            potential,
            alpha=0.7,
            linewidth=1,
            color="orange",
            label="Potential",
        )

        # Add difference (which relates to kinetic energy)
        energy_diff = energy - potential
        # Create second y-axis for difference
        ax2 = energy_ax.twinx()
        ax2.plot(
            energy_diff,
            alpha=0.5,
            linewidth=1,
            color="red",
            label="H - Potential (Kinetic)",
            linestyle="--",
        )
        ax2.set_ylabel("Energy Difference", color="red")
        ax2.tick_params(axis="y", labelcolor="red")

        energy_ax.set_xlabel("Iteration")
        energy_ax.set_ylabel("Energy", color="blue")
        energy_ax.tick_params(axis="y", labelcolor="blue")
        energy_ax.set_title("Hamiltonian & Potential Energy")
        energy_ax.legend(loc="best", fontsize="small")
        energy_ax.grid(True, alpha=0.3)

        # Add statistics
        energy_ax.text(
            0.02,
            0.98,
            f"H: μ={np.mean(energy):.1f}, σ={np.std(energy):.1f}\nP: μ={np.mean(potential):.1f}, σ={np.std(potential):.1f}",
            transform=energy_ax.transAxes,
            fontsize=8,
            bbox=dict(boxstyle="round", facecolor="lightblue", alpha=0.8),
            verticalalignment="top",
        )

    elif has_energy:
        # Only Hamiltonian energy
        energy = idata.sample_stats.energy.values.flatten()
        energy_ax.plot(energy, alpha=0.7, linewidth=1, color="blue")
        energy_ax.set_xlabel("Iteration")
        energy_ax.set_ylabel("Hamiltonian Energy")
        energy_ax.set_title("Hamiltonian Energy Trace")
        energy_ax.grid(True, alpha=0.3)

    elif has_potential:
        # Only potential energy
        potential = idata.sample_stats.potential_energy.values.flatten()
        energy_ax.plot(potential, alpha=0.7, linewidth=1, color="orange")
        energy_ax.set_xlabel("Iteration")
        energy_ax.set_ylabel("Potential Energy")
        energy_ax.set_title("Potential Energy Trace")
        energy_ax.grid(True, alpha=0.3)

    else:
        # No energy fields at all: placeholder panel
        energy_ax.text(
            0.5,
            0.5,
            "Energy data\nunavailable",
            ha="center",
            va="center",
            transform=energy_ax.transAxes,
        )
        energy_ax.set_title("Energy Diagnostics")

    # Top-right: Sampling efficiency diagnostics
    if has_steps:
        steps_ax = axes[0, 1]
        num_steps = idata.sample_stats.num_steps.values.flatten()

        # Show histogram with color zones for step efficiency
        # (hist's return values are currently unused)
        n, bins, edges = steps_ax.hist(
            num_steps, bins=20, alpha=0.7, edgecolor="black"
        )

        # Add shaded regions for different efficiency levels
        # Green: efficient (tree depth ≤5, ~32 steps)
        # Yellow: moderate (tree depth 6-8, ~64-256 steps)
        # Red: inefficient (tree depth >8, >256 steps)
        steps_ax.axvspan(
            0, 64, alpha=0.1, color="green", label="Efficient (≤64)"
        )
        steps_ax.axvspan(
            64, 256, alpha=0.1, color="yellow", label="Moderate (65-256)"
        )
        steps_ax.axvspan(
            256,
            np.max(num_steps),
            alpha=0.1,
            color="red",
            label="Inefficient (>256)",
        )

        # Add reference lines for different tree depths
        for depth in [5, 7, 10]:  # Common tree depths
            max_steps = 2**depth
            steps_ax.axvline(
                x=max_steps,
                color="gray",
                linestyle=":",
                alpha=0.7,
                linewidth=1,
                label=f"2^{depth} ({max_steps})",
            )

        steps_ax.set_xlabel("Leapfrog Steps")
        steps_ax.set_ylabel("Trajectories")
        steps_ax.set_title("Leapfrog Steps Distribution")
        steps_ax.legend(loc="best", fontsize="small")
        steps_ax.grid(True, alpha=0.3)

        # Add efficiency statistics (fractions of trajectories per zone)
        pct_inefficient = (num_steps > 256).mean() * 100
        pct_moderate = ((num_steps > 64) & (num_steps <= 256)).mean() * 100
        pct_efficient = (num_steps <= 64).mean() * 100
        steps_ax.text(
            0.02,
            0.98,
            f"Efficient: {pct_efficient:.1f}%\nModerate: {pct_moderate:.1f}%\nInefficient: {pct_inefficient:.1f}%\nMean steps: {np.mean(num_steps):.1f}",
            transform=steps_ax.transAxes,
            fontsize=7,
            bbox=dict(boxstyle="round", facecolor="lightblue", alpha=0.8),
            verticalalignment="top",
        )

    else:
        axes[0, 1].text(
            0.5, 0.5, "Steps data\nunavailable", ha="center", va="center"
        )
        axes[0, 1].set_title("Sampling Steps")

    # Bottom-left: Acceptance and NS divergence diagnostics
    accept_ax = axes[1, 0]

    if has_accept:
        accept_prob = idata.sample_stats.accept_prob.values.flatten()

        # Plot acceptance probability with guidance zones
        accept_ax.fill_between(
            range(len(accept_prob)),
            0.7,
            0.9,
            alpha=0.1,
            color="green",
            label="Good (0.7-0.9)",
        )
        accept_ax.fill_between(
            range(len(accept_prob)),
            0,
            0.6,
            alpha=0.1,
            color="red",
            label="Too low",
        )
        accept_ax.fill_between(
            range(len(accept_prob)),
            0.9,
            1.0,
            alpha=0.1,
            color="orange",
            label="Too high",
        )

        accept_ax.plot(
            accept_prob,
            alpha=0.8,
            linewidth=1,
            color="blue",
            label="Acceptance prob",
        )
        accept_ax.axhline(
            0.8,
            color="red",
            linestyle="--",
            linewidth=2,
            label="NUTS target (0.8)",
        )
        accept_ax.set_xlabel("Iteration")
        accept_ax.set_ylabel("Acceptance Probability")
        accept_ax.set_title("NUTS Acceptance Diagnostic")
        accept_ax.legend(loc="best", fontsize="small")
        accept_ax.set_ylim(0, 1)
        accept_ax.grid(True, alpha=0.3)

    else:
        accept_ax.text(
            0.5, 0.5, "Acceptance data\nunavailable", ha="center", va="center"
        )
        accept_ax.set_title("Acceptance Diagnostic")

    # Bottom-right: Summary statistics and additional diagnostics
    summary_ax = axes[1, 1]

    # Collect available statistics (one formatted line per present field)
    stats_lines = []

    if has_energy:
        energy = idata.sample_stats.energy.values.flatten()
        stats_lines.append(
            f"Energy: μ={np.mean(energy):.1f}, σ={np.std(energy):.1f}"
        )

    if has_potential:
        potential = idata.sample_stats.potential_energy.values.flatten()
        stats_lines.append(
            f"Potential: μ={np.mean(potential):.1f}, σ={np.std(potential):.1f}"
        )

    if has_steps:
        num_steps = idata.sample_stats.num_steps.values.flatten()
        stats_lines.append(
            f"Steps: μ={np.mean(num_steps):.1f}, max={np.max(num_steps):.0f}"
        )
        stats_lines.append("")

    if has_tree_depth:
        tree_depth = idata.sample_stats.tree_depth.values.flatten()
        stats_lines.append(f"Tree depth: μ={np.mean(tree_depth):.1f}")
        pct_max_depth = (tree_depth >= 10).mean() * 100
        stats_lines.append(f"Max depth (≥10): {pct_max_depth:.1f}%")

    if has_divergences:
        divergences = idata.sample_stats.diverging.values.flatten()
        n_divergences = np.sum(divergences)
        pct_divergent = n_divergences / len(divergences) * 100
        stats_lines.append(
            f"Divergent: {n_divergences}/{len(divergences)} ({pct_divergent:.2f}%)"
        )

    if has_energy_error:
        energy_error = idata.sample_stats.energy_error.values.flatten()
        stats_lines.append(
            f"Energy error: |μ|={np.mean(np.abs(energy_error)):.3f}"
        )

    if not stats_lines:
        summary_ax.text(
            0.5,
            0.5,
            "No diagnostics\ndata available",
            ha="center",
            va="center",
            transform=summary_ax.transAxes,
        )
        summary_ax.set_title("NUTS Statistics")
        summary_ax.axis("off")
    else:
        summary_text = "\n".join(["NUTS Diagnostics:"] + [""] + stats_lines)
        summary_ax.text(
            0.05,
            0.95,
            summary_text,
            transform=summary_ax.transAxes,
            fontsize=10,
            verticalalignment="top",
            fontfamily="monospace",
            bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.9),
        )
        summary_ax.set_title("NUTS Summary Statistics")
        summary_ax.axis("off")

    plt.tight_layout()
def _plot_mh_step_sizes(idata, config):
    """Render a 2x2 figure of adaptive Metropolis-Hastings step-size diagnostics.

    Panels: step-size evolution, marginal distributions, mean/std ratio, and
    a monospace text summary. The figure is saved by the ``safe_plot``
    wrapper in the caller, not here.
    """
    fig, panels = plt.subplots(2, 2, figsize=config.figsize)

    means = idata.sample_stats.step_size_mean.values.flatten()
    stds = idata.sample_stats.step_size_std.values.flatten()

    # Top-left: how the adapted step size evolves over iterations.
    evo_ax = panels[0, 0]
    evo_ax.plot(means, alpha=0.7, linewidth=1, label="Mean", color="blue")
    evo_ax.plot(stds, alpha=0.7, linewidth=1, label="Std", color="orange")
    evo_ax.set_xlabel("Iteration")
    evo_ax.set_ylabel("Step Size")
    evo_ax.set_title("Step Size Evolution")
    evo_ax.legend()
    evo_ax.grid(True, alpha=0.3)

    # Top-right: marginal distributions of both step-size statistics.
    hist_ax = panels[0, 1]
    hist_ax.hist(means, bins=30, alpha=0.5, label="Mean", color="blue")
    hist_ax.hist(stds, bins=30, alpha=0.5, label="Std", color="orange")
    hist_ax.set_xlabel("Step Size")
    hist_ax.set_ylabel("Count")
    hist_ax.set_title("Step Size Distributions")
    hist_ax.legend()
    hist_ax.grid(True, alpha=0.3)

    # Bottom-left: mean/std ratio as a rough adaptation-consistency signal.
    ratio_ax = panels[1, 0]
    ratio_ax.plot(means / stds, alpha=0.7, linewidth=1)
    ratio_ax.set_xlabel("Iteration")
    ratio_ax.set_ylabel("Mean / Std")
    ratio_ax.set_title("Step Size Consistency")
    ratio_ax.grid(True, alpha=0.3)

    # Bottom-right: text panel with the summary numbers.
    summary_lines = [
        "Step Size Summary:",
        f"Final mean: {means[-1]:.4f}",
        f"Final std: {stds[-1]:.4f}",
        f"Mean of means: {np.mean(means):.4f}",
        f"Mean of stds: {np.mean(stds):.4f}",
        "",
        "Adaptation Quality:",
        f"CV of means: {np.std(means)/np.mean(means):.3f}",
        f"CV of stds: {np.std(stds)/np.mean(stds):.3f}",
    ]
    text_ax = panels[1, 1]
    text_ax.text(
        0.05,
        0.95,
        "\n".join(summary_lines),
        transform=text_ax.transAxes,
        fontsize=10,
        verticalalignment="top",
        fontfamily="monospace",
    )
    text_ax.set_title("Step Size Statistics")
    text_ax.axis("off")

    plt.tight_layout()
def _create_divergences_diagnostics(idata, diag_dir, config):
    """Write the divergence diagnostic figure if any divergence data exists."""
    stats = idata.sample_stats
    # Either a global NUTS `diverging` field or per-channel flags
    # (`diverging_channel_<i>`) from a blocked sampler qualifies.
    any_divergence_data = "diverging" in stats or any(
        name.startswith("diverging_channel_") for name in stats
    )
    if not any_divergence_data:
        # No divergence fields recorded — nothing to draw.
        return

    @safe_plot(f"{diag_dir}/divergences.png", config.dpi)
    def render():
        _plot_divergences(idata, config)

    render()
def _plot_divergences(idata, config):
    """Plot divergence indicators for NUTS-style samplers.

    Handles a global ``diverging`` field and/or any number of
    ``diverging_channel_<i>`` fields (blocked NUTS). Each divergence series
    gets its own trace panel; a text summary panel is drawn whenever an axis
    is available for it (always for a single series; for multiple series
    only when the grid has a spare axis, i.e. an odd series count).

    Bug fix: the trace-axis index now advances for *every* series.
    Previously it was only incremented in the channel branch, so when a
    "main" series coexisted with channel series, the first channel
    overplotted the main panel and the final trace axis stayed empty.

    The figure is saved by the ``safe_plot`` wrapper in the caller.
    """
    # Collect every divergence series, keyed "main" or by int channel index.
    divergences_data = {}

    if "diverging" in idata.sample_stats:
        divergences_data["main"] = (
            idata.sample_stats.diverging.values.flatten()
        )

    channel_divergences = {}
    for key in idata.sample_stats:
        if key.startswith("diverging_channel_"):
            channel_idx = key.replace("diverging_channel_", "")
            channel_divergences[int(channel_idx)] = idata.sample_stats[
                key
            ].values.flatten()

    if channel_divergences:
        divergences_data.update(channel_divergences)

    if not divergences_data:
        # Emit a placeholder figure so the decorated caller still saves a file.
        fig, ax = plt.subplots(figsize=config.figsize)
        ax.text(
            0.5, 0.5, "No divergence data available", ha="center", va="center"
        )
        ax.set_title("Divergences Diagnostics")
        return

    # Layout: one series -> trace + summary side by side; otherwise a
    # two-column grid of trace panels.
    n_plots = len(divergences_data)
    if n_plots == 1:
        fig, axes = plt.subplots(1, 2, figsize=config.figsize)
        trace_axes = axes[:1]
        summary_ax = axes[1]
    else:
        cols = 2
        rows = (n_plots + 1) // cols  # Ceiling division
        fig, axes = plt.subplots(rows, cols, figsize=config.figsize)
        if rows == 1:
            axes = axes.reshape(1, -1)
        axes = axes.flatten()

        if n_plots % 2 == 1:
            # Odd series count: the spare last axis hosts the text summary.
            trace_axes = axes[:-1]
            summary_ax = axes[-1]
        else:
            trace_axes = axes
            summary_ax = None

    total_divergences = 0
    total_iterations = 0

    # enumerate() guarantees one distinct axis per series (the bug fix).
    for plot_idx, (label, div_values) in enumerate(divergences_data.items()):
        title = (
            "NUTS Divergences"
            if label == "main"
            else f"Channel {label} Divergences"
        )
        ax = trace_axes[plot_idx]

        # Mark iterations where a divergence occurred.
        div_indices = np.where(div_values)[0]
        ax.scatter(
            div_indices,
            np.ones_like(div_indices),
            color="red",
            marker="x",
            s=50,
            linewidth=2,
            label="Divergent",
            alpha=0.8,
        )

        # Shade the background around each divergent iteration.
        if len(div_indices) > 0:
            for idx in div_indices:
                ax.axvspan(idx - 0.5, idx + 0.5, alpha=0.2, color="red")

        ax.set_xlabel("Iteration")
        ax.set_ylabel("Divergence Indicator")
        ax.set_title(title)
        ax.set_yticks([0, 1])
        ax.set_yticklabels(["No", "Yes"])
        ax.grid(True, alpha=0.3)

        # Per-series divergence count annotation.
        n_divergent = np.sum(div_values)
        pct_divergent = n_divergent / len(div_values) * 100
        stats_text = f"{n_divergent}/{len(div_values)} ({pct_divergent:.2f}%)"
        ax.text(
            0.02,
            0.98,
            stats_text,
            transform=ax.transAxes,
            fontsize=10,
            bbox=dict(boxstyle="round", facecolor="lightcoral", alpha=0.8),
            verticalalignment="top",
        )

        total_divergences += n_divergent
        total_iterations += len(div_values)

        # Legend only if there are divergences
        if n_divergent > 0:
            ax.legend(loc="upper right", fontsize="small")

    # Text summary panel (single-series case and odd multi-series case).
    if summary_ax is not None:
        summary_ax.text(
            0.05,
            0.95,
            _get_divergences_summary(divergences_data),
            transform=summary_ax.transAxes,
            fontsize=12,
            verticalalignment="top",
            fontfamily="monospace",
            bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.9),
        )
        summary_ax.set_title("Divergences Summary")
        summary_ax.axis("off")

    # Overall title with the aggregate divergence rate.
    overall_pct = (
        total_divergences / total_iterations * 100
        if total_iterations > 0
        else 0
    )
    fig.suptitle(f"Overall Divergences: {overall_pct:.2f}%")

    plt.tight_layout()
def _get_divergences_summary(divergences_data):
2✔
1417
    """Generate text summary of divergences."""
1418
    lines = ["Divergences Summary:", ""]
2✔
1419

1420
    total_divergences = 0
2✔
1421
    total_iterations = 0
2✔
1422

1423
    for label, div_values in divergences_data.items():
2✔
1424
        n_divergent = np.sum(div_values)
2✔
1425
        pct_divergent = n_divergent / len(div_values) * 100
2✔
1426

1427
        if label == "main":
2✔
1428
            lines.append(
2✔
1429
                f"NUTS: {n_divergent}/{len(div_values)} ({pct_divergent:.2f}%)"
1430
            )
1431
        else:
NEW
1432
            lines.append(
×
1433
                f"Channel {label}: {n_divergent}/{len(div_values)} ({pct_divergent:.2f}%)"
1434
            )
1435

1436
        total_divergences += n_divergent
2✔
1437
        total_iterations += len(div_values)
2✔
1438

1439
    lines.append("")
2✔
1440
    overall_pct = (
2✔
1441
        total_divergences / total_iterations * 100
1442
        if total_iterations > 0
1443
        else 0
1444
    )
1445
    lines.append(
2✔
1446
        f"Total: {total_divergences}/{total_iterations} ({overall_pct:.2f}%)"
1447
    )
1448

1449
    lines.append("")
2✔
1450
    lines.append("Interpretation:")
2✔
1451
    if overall_pct == 0:
2✔
1452
        lines.append("  ✓ No divergences detected")
2✔
1453
        lines.append("    Sampling appears well-behaved")
2✔
NEW
1454
    elif overall_pct < 0.1:
×
NEW
1455
        lines.append("  ~ Few divergences")
×
NEW
1456
        lines.append("    Generally good, but monitor")
×
NEW
1457
    elif overall_pct < 1.0:
×
NEW
1458
        lines.append("  ⚠ Some divergences detected")
×
NEW
1459
        lines.append("    May indicate sampling issues")
×
1460
    else:
NEW
1461
        lines.append("  ✗ Many divergences!")
×
NEW
1462
        lines.append("    Significant sampling problems")
×
NEW
1463
        lines.append("    Consider model reparameterization")
×
1464

1465
    return "\n".join(lines)
2✔
1466

1467

1468
def _collapse_posterior_values(values):
    """Collapse a posterior sample array to a single 1-D trace.

    Chains and draws are merged; a trailing channel dimension is squeezed
    (size 1) or averaged (size > 1).

    Fix: the size-1-channel branch now flattens after ``squeeze``, so a 2-D
    ``(chain, draw)`` array no longer reaches ``ax.plot`` (which would draw
    one line per column instead of one trace). Also removes the previously
    unreachable inner ``flatten`` branch.
    """
    if values.ndim == 3:  # (chain, draw, channel)
        if values.shape[-1] == 1:
            values = values.squeeze(-1).flatten()
        else:
            # Merge chain/draw dims, then average across channels.
            values = values.reshape(
                values.shape[0] * values.shape[1], -1
            ).mean(axis=-1)
    elif values.ndim == 2:  # (chain, draw)
        values = values.flatten()
    return values


def _plot_scalar_group(ax, idata, params, colors, group_name):
    """Plot one collapsed trace per parameter in *params* on *ax*.

    Shared helper for the delta and phi panels (previously two ~40-line
    verbatim copies). Renders a placeholder message when the group is empty.
    """
    if params:
        for i, param in enumerate(params):
            color = colors[i % len(colors)]
            values = _collapse_posterior_values(idata.posterior[param].values)
            ax.plot(values, color=color, alpha=0.7, linewidth=1, label=param)
        ax.set_ylabel(f"{group_name} Parameters")
        ax.set_title(f"{group_name} Parameters Trace")
        ax.legend(loc="upper right", fontsize="small")
        ax.grid(True, alpha=0.3)
    else:
        ax.text(
            0.5,
            0.5,
            f"No {group_name.lower()} parameters found",
            ha="center",
            va="center",
            transform=ax.transAxes,
        )
        ax.set_title(f"{group_name} Parameters")
        ax.axis("off")


def _plot_grouped_traces(idata, figsize):
    """Create grouped trace plots for delta, phi, and weights parameters.

    Three stacked panels: delta traces, phi traces, and (channel-averaged)
    weights traces capped at 10 parameters. The figure is saved by the
    caller.
    """
    # Color cycle shared by all three panels.
    colors = [
        "blue",
        "red",
        "green",
        "orange",
        "purple",
        "brown",
        "pink",
        "gray",
        "olive",
        "cyan",
    ]

    # Group parameters by name prefix.
    posterior_vars = list(idata.posterior.data_vars)
    delta_params = [p for p in posterior_vars if p.startswith("delta")]
    phi_params = [p for p in posterior_vars if p.startswith("phi")]
    weights_params = [p for p in posterior_vars if p.startswith("weights")]

    fig, axes = plt.subplots(3, 1, figsize=figsize)

    # Delta and phi panels share identical logic.
    _plot_scalar_group(axes[0], idata, delta_params, colors, "Delta")
    _plot_scalar_group(axes[1], idata, phi_params, colors, "Phi")

    # Weights panel: higher-dimensional, so average across weight dimensions
    # and cap the number of traces shown.
    ax = axes[2]
    if weights_params:
        max_traces = min(10, len(weights_params))
        for i, param in enumerate(weights_params[:max_traces]):
            color = colors[i % len(colors)]
            values = idata.posterior[param].values
            if values.ndim == 4:  # (chain, draw, dim1, dim2)
                values = values.mean(axis=-1).mean(axis=-1).flatten()
            elif values.ndim == 3:  # (chain, draw, weight_dim)
                values = values.mean(axis=-1).flatten()
            elif values.ndim == 2:  # (chain, draw)
                values = values.flatten()
            ax.plot(values, color=color, alpha=0.7, linewidth=1, label=param)

        if len(weights_params) > max_traces:
            ax.text(
                0.02,
                0.02,
                f"Showing {max_traces} of {len(weights_params)} weight parameters",
                transform=ax.transAxes,
                fontsize="small",
                bbox=dict(boxstyle="round", facecolor="lightcoral", alpha=0.7),
            )

        ax.set_xlabel("Iteration")
        ax.set_ylabel("Weights Parameters (mean)")
        ax.set_title("Weights Parameters Trace (averaged)")
        ax.legend(loc="upper right", fontsize="small")
        ax.grid(True, alpha=0.3)
    else:
        ax.text(
            0.5,
            0.5,
            "No weights parameters found",
            ha="center",
            va="center",
            transform=ax.transAxes,
        )
        ax.set_title("Weights Parameters")
        ax.axis("off")

    plt.tight_layout()
def _group_parameters_simple(idata):
2✔
1636
    """Simple parameter grouping for counting."""
1637
    param_groups = {"phi": [], "delta": [], "weights": [], "other": []}
2✔
1638

1639
    for param in idata.posterior.data_vars:
2✔
1640
        if param.startswith("phi"):
2✔
1641
            param_groups["phi"].append(param)
2✔
1642
        elif param.startswith("delta"):
2✔
1643
            param_groups["delta"].append(param)
2✔
1644
        elif param.startswith("weights"):
2✔
1645
            param_groups["weights"].append(param)
2✔
1646
        else:
1647
            param_groups["other"].append(param)
×
1648

1649
    return {k: v for k, v in param_groups.items() if v}
2✔
1650

1651

1652
def generate_diagnostics_summary(idata, outdir):
    """Generate a comprehensive text summary of MCMC diagnostics.

    The report covers: sampler/sample counts, parameter-group sizes, ESS,
    R-hat, acceptance rate, optional PSD-accuracy metrics (only when
    ``true_psd`` is present in ``idata.attrs``), and an overall convergence
    verdict.  The text is printed to stdout, written to
    ``<outdir>/diagnostics_summary.txt`` when ``outdir`` is truthy, and
    returned.

    Parameters
    ----------
    idata : arviz.InferenceData
        Posterior samples with ``sample_stats`` and (optionally) ``attrs``.
    outdir : str or None
        Directory to write the summary file into; skipped when falsy.

    Returns
    -------
    str
        The assembled summary text.
    """
    summary = ["=== MCMC Diagnostics Summary ===\n"]

    # Basic info.  ``attrs`` may be any mapping-like object; normalise to
    # something exposing ``.get`` so the lookups below are uniform.
    attrs = getattr(idata, "attrs", {}) or {}
    if not hasattr(attrs, "get"):
        attrs = dict(attrs)

    n_samples = idata.posterior.sizes.get("draw", 0)
    n_chains = idata.posterior.sizes.get("chain", 1)
    n_params = len(list(idata.posterior.data_vars))
    sampler_type = attrs.get("sampler_type", "Unknown")

    summary.append(f"Sampler: {sampler_type}")
    summary.append(
        f"Samples: {n_samples} per chain × {n_chains} chains = {n_samples * n_chains} total"
    )
    summary.append(f"Parameters: {n_params}")

    # Parameter breakdown by name prefix (phi/delta/weights/other).
    param_groups = _group_parameters_simple(idata)
    if param_groups:
        param_summary = ", ".join(
            f"{k}: {len(v)}" for k, v in param_groups.items()
        )
        summary.append(f"Parameter groups: {param_summary}")

    # Pre-initialise so the overall assessment at the end is well-defined
    # even when ESS/R-hat computation fails.  (Previously these names were
    # unbound on failure and a bare ``except:`` silently swallowed the
    # resulting NameError.)
    ess_values = np.array([])
    rhat_values = np.array([])

    # Effective sample size.  NOTE: the thresholds used here (400, 1.01)
    # mirror the DiagnosticsConfig defaults in this module; keep in sync.
    try:
        ess = az.ess(idata).to_array().values.flatten()
        ess_values = ess[~np.isnan(ess)]

        if len(ess_values) > 0:
            summary.append(
                f"\nESS: min={ess_values.min():.0f}, mean={ess_values.mean():.0f}, max={ess_values.max():.0f}"
            )
            summary.append(f"ESS ≥ 400: {(ess_values >= 400).mean()*100:.1f}%")
    except Exception:
        summary.append("\nESS: unavailable")

    # Gelman-Rubin convergence statistic.
    try:
        rhat = az.rhat(idata).to_array().values.flatten()
        rhat_values = rhat[~np.isnan(rhat)]

        if len(rhat_values) > 0:
            summary.append(
                f"R-hat: max={rhat_values.max():.3f}, mean={rhat_values.mean():.3f}"
            )
            summary.append(
                f"R-hat > 1.01: {(rhat_values > 1.01).mean()*100:.1f}%"
            )
    except Exception:
        summary.append("R-hat: unavailable")

    # Acceptance rate — the sample_stats key name differs between samplers.
    accept_key = None
    if "accept_prob" in idata.sample_stats:
        accept_key = "accept_prob"
    elif "acceptance_rate" in idata.sample_stats:
        accept_key = "acceptance_rate"

    if accept_key is not None:
        accept_rate = idata.sample_stats[accept_key].values.mean()
        target_rate = attrs.get(
            "target_accept_rate", attrs.get("target_accept_prob", 0.44)
        )
        summary.append(
            f"Acceptance rate: {accept_rate:.3f} (target: {target_rate:.3f})"
        )

    # PSD accuracy diagnostics (requires true_psd in attrs)
    has_true_psd = "true_psd" in attrs

    if has_true_psd:
        coverage_level = attrs.get("coverage_level")
        coverage_label = (
            f"{int(round(coverage_level * 100))}% interval coverage"
            if coverage_level is not None
            else "Interval coverage"
        )

        def _format_riae_line(value, errorbars, prefix="  "):
            # Append one RIAE line; errorbars (when given) are the
            # (q05, q25, median, q75, q95) quantiles.
            line = f"{prefix}RIAE: {value:.3f}"
            if errorbars:
                q05, _, median, _, q95 = errorbars
                line += f" (median {median:.3f}, 5-95% [{q05:.3f}, {q95:.3f}])"
            summary.append(line)

        def _format_coverage_line(value, prefix="  "):
            if value is None:
                return
            summary.append(f"{prefix}{coverage_label}: {value * 100:.1f}%")

        summary.append("\nPSD accuracy diagnostics:")

        if "riae" in attrs:
            _format_riae_line(attrs["riae"], attrs.get("riae_errorbars"))
        if "coverage" in attrs:
            _format_coverage_line(attrs["coverage"])

        # Per-channel metrics live under keys like "riae_ch0",
        # "coverage_ch0", "riae_errorbars_ch0".
        channel_indices = sorted(
            int(key.replace("riae_ch", ""))
            for key in attrs.keys()
            if key.startswith("riae_ch")
        )

        for idx in channel_indices:
            metrics = []
            riae_key = f"riae_ch{idx}"
            cov_key = f"coverage_ch{idx}"
            error_key = f"riae_errorbars_ch{idx}"

            if riae_key in attrs:
                riae_line = f"RIAE {attrs[riae_key]:.3f}"
                errorbars = attrs.get(error_key)
                if errorbars:
                    q05, _, median, _, q95 = errorbars
                    riae_line += (
                        f" (median {median:.3f}, 5-95% [{q05:.3f}, {q95:.3f}])"
                    )
                metrics.append(riae_line)

            if cov_key in attrs:
                metrics.append(f"{coverage_label} {attrs[cov_key] * 100:.1f}%")

            if metrics:
                summary.append(f"  Channel {idx}: " + "; ".join(metrics))

    # Overall assessment — only meaningful when both diagnostics computed.
    # ess_values/rhat_values are guaranteed bound (initialised above), so no
    # try/except is needed here.
    if len(ess_values) > 0 and len(rhat_values) > 0:
        ess_good = (ess_values >= 400).mean() * 100
        rhat_good = (rhat_values <= 1.01).mean() * 100

        summary.append("\nOverall Convergence Assessment:")
        if ess_good >= 90 and rhat_good >= 90:
            summary.append("  Status: EXCELLENT ✓")
        elif ess_good >= 75 and rhat_good >= 75:
            summary.append("  Status: GOOD ✓")
        else:
            summary.append("  Status: NEEDS ATTENTION ⚠")

    summary_text = "\n".join(summary)

    if outdir:
        # Explicit utf-8: the report contains non-ASCII characters
        # (×, ≥, ✓, ⚠) and would fail under a non-UTF-8 platform default.
        out_path = os.path.join(outdir, "diagnostics_summary.txt")
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(summary_text)

    print("\n" + summary_text + "\n")
    return summary_text
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc