• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

nz-gravity / LogPSplinePSD / 18117111339

30 Sep 2025 02:49AM UTC coverage: 82.829% (+0.1%) from 82.708%
18117111339

push

github

avivajpeyi
Add sim study slurm

363 of 434 branches covered (83.64%)

Branch coverage included in aggregate %.

2893 of 3497 relevant lines covered (82.73%)

1.65 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

62.73
/src/log_psplines/plotting/diagnostics.py
1
import os
2✔
2
from dataclasses import dataclass
2✔
3
from functools import wraps
2✔
4
from typing import Callable, Optional
2✔
5

6
import arviz as az
2✔
7
import matplotlib.pyplot as plt
2✔
8
import numpy as np
2✔
9

10

11
@dataclass
class DiagnosticsConfig:
    """Settings shared by all MCMC diagnostic plots in this module."""

    # Base (width, height) in inches for diagnostic figures
    figsize: tuple = (12, 8)
    # Resolution passed to plt.savefig via the safe_plot decorator
    dpi: int = 150
    # Minimum effective sample size treated as reliable in the dashboards
    ess_threshold: int = 400
    # Maximum R-hat treated as converged in the dashboards
    rhat_threshold: float = 1.01
17

18

19
def safe_plot(filename: str, dpi: int = 150):
    """Decorator for safe plotting with error handling.

    The wrapped function is expected to draw onto the current matplotlib
    figure. On success the figure is saved to ``filename`` and closed and
    the wrapper returns ``True``. If drawing or saving raises, a warning
    is printed, every open figure is closed, and ``False`` is returned
    instead of propagating the error.
    """

    def decorator(plot_func: Callable):
        @wraps(plot_func)
        def wrapper(*args, **kwargs):
            try:
                plot_func(*args, **kwargs)
                # Save/close inside the try so save failures are also caught.
                plt.savefig(filename, dpi=dpi, bbox_inches="tight")
                plt.close()
            except Exception as e:
                print(
                    f"Warning: Failed to create {os.path.basename(filename)}: {e}"
                )
                plt.close("all")
                return False
            return True

        return wrapper

    return decorator
40

41

42
def plot_diagnostics(
    idata: az.InferenceData,
    outdir: str,
    n_channels: Optional[int] = None,
    n_freq: Optional[int] = None,
    runtime: Optional[float] = None,
    config: Optional[DiagnosticsConfig] = None,
) -> None:
    """Create essential MCMC diagnostics in organized subdirectories.

    Writes a text summary plus the essential diagnostic figures into
    ``<outdir>/diagnostics/``. A ``None`` outdir disables all output.

    Parameters
    ----------
    idata : posterior samples to diagnose
    outdir : output directory (created if missing); None means no-op
    n_channels, n_freq : optional metadata shown on the dashboard
    runtime : optional sampling runtime in seconds shown on the dashboard
    config : plotting/threshold settings; defaults to DiagnosticsConfig()
    """
    if outdir is None:
        return

    cfg = config if config is not None else DiagnosticsConfig()

    diag_dir = os.path.join(outdir, "diagnostics")
    os.makedirs(diag_dir, exist_ok=True)

    print("Generating MCMC diagnostics...")
    generate_diagnostics_summary(idata, diag_dir)
    _create_essential_diagnostics(
        idata, diag_dir, cfg, n_channels, n_freq, runtime
    )
    print(f"Diagnostics saved to {diag_dir}/")
74

75

76
def _create_essential_diagnostics(
    idata, diag_dir, config, n_channels, n_freq, runtime
):
    """Create only the essential diagnostic plots.

    Each plot is drawn inside a closure decorated with ``safe_plot`` so a
    failure in one figure prints a warning but does not abort the others.
    Figures written to ``diag_dir``: trace_plots.png, summary_dashboard.png,
    log_posterior.png, acceptance_diagnostics.png, plus whatever
    ``_create_sampler_diagnostics`` adds for the detected sampler.
    """

    # 1. ArviZ trace plots
    @safe_plot(f"{diag_dir}/trace_plots.png", config.dpi)
    def plot_trace():
        # Scale figure height based on number of parameters for better readability
        n_params = len(list(idata.posterior.data_vars))
        base_height = config.figsize[1]
        # Add ~1 inch of height per 5 parameters, but cap at reasonable maximum
        scaled_height = min(base_height + (n_params // 5), base_height * 3)
        trace_figsize = (config.figsize[0], scaled_height)

        az.plot_trace(idata, figsize=trace_figsize)
        plt.suptitle("Parameter Traces", fontsize=14)
        plt.tight_layout()

    plot_trace()

    # 2. Summary dashboard with key convergence metrics
    @safe_plot(f"{diag_dir}/summary_dashboard.png", config.dpi)
    def plot_summary():
        _plot_summary_dashboard(idata, config, n_channels, n_freq, runtime)

    plot_summary()

    # 3. Log posterior diagnostics
    @safe_plot(f"{diag_dir}/log_posterior.png", config.dpi)
    def plot_lp():
        _plot_log_posterior(idata, config)

    plot_lp()

    # 4. Acceptance rate diagnostics
    @safe_plot(f"{diag_dir}/acceptance_diagnostics.png", config.dpi)
    def plot_acceptance():
        _plot_acceptance_diagnostics(idata, config)

    plot_acceptance()

    # 5. Sampler-specific diagnostics
    _create_sampler_diagnostics(idata, diag_dir, config)
120

121

122
def _plot_summary_dashboard(idata, config, n_channels, n_freq, runtime):
    """Render the convergence summary dashboard onto a new figure.

    Uses a 2x3 layout when R-hat statistics can be computed, otherwise a
    reduced 2x2 layout. Panels: ESS histogram, R-hat histogram, ESS vs
    R-hat scatter, run metadata, parameter-group summary, and an overall
    convergence verdict.

    Fixes relative to the previous version: the two bare ``except:``
    clauses are narrowed to ``except Exception:`` (they previously also
    swallowed SystemExit/KeyboardInterrupt), the sampler type is read with
    ``attrs.get`` so a missing key no longer blanks the whole metadata
    panel, and unused locals were dropped.
    """

    # Check if R-hat is available (needs multiple chains; az.rhat may fail)
    rhat_available = False
    try:
        rhat = az.rhat(idata).to_array().values.flatten()
        rhat_values = rhat[~np.isnan(rhat)]
        rhat_available = len(rhat_values) > 0
    except Exception:
        pass

    # Create subplot layout based on data availability
    if rhat_available:
        # Full 2x3 layout when R-hat is available
        fig, axes = plt.subplots(2, 3, figsize=config.figsize)
        rhat_ax = axes[0, 1]
        scatter_ax = axes[0, 2]
        meta_ax = axes[1, 0]
        param_ax = axes[1, 1]
        status_ax = axes[1, 2]
    else:
        # Reduced 2x2 layout when R-hat is not available
        fig, axes = plt.subplots(
            2, 2, figsize=(config.figsize[0] * 0.8, config.figsize[1])
        )
        # Rearrange axes for 2x2 layout
        meta_ax = axes[0, 0]
        param_ax = axes[0, 1]
        status_ax = axes[1, 0]
        # Use bottom-right for additional info or leave empty
        axes[1, 1].axis("off")  # Hide unused subplot

    # ESS histogram (always in top-left)
    try:
        ess = az.ess(idata).to_array().values.flatten()
        ess_values = ess[~np.isnan(ess)]

        if len(ess_values) > 0:
            # Add color zones for ESS quality
            ess_thresholds = [
                (400, "red", "--", "Minimum reliable ESS"),
                (1000, "orange", "--", "Good ESS"),
                (
                    np.max(ess_values),
                    "green",
                    ":",
                    f"Max ESS = {np.max(ess_values):.0f}",
                ),
            ]

            ax_ess = axes[0, 0]  # Always available
            ax_ess.hist(ess_values, bins=30, alpha=0.7, edgecolor="black")

            # Add reference lines
            for threshold, color, style, label in ess_thresholds:
                ax_ess.axvline(
                    x=threshold,
                    color=color,
                    linestyle=style,
                    linewidth=2 if threshold < np.max(ess_values) else 1,
                    alpha=0.8,
                    label=label,
                )

            ax_ess.set_xlabel("ESS")
            ax_ess.set_ylabel("Count")
            ax_ess.set_title("ESS Distribution")
            ax_ess.legend(loc="upper right", fontsize="x-small")
            ax_ess.grid(True, alpha=0.3)

            pct_good = (ess_values >= config.ess_threshold).mean() * 100
            ax_ess.text(
                0.02,
                0.98,
                f"Min: {ess_values.min():.0f}\nMean: {ess_values.mean():.0f}\n≥{config.ess_threshold}: {pct_good:.1f}%",
                transform=ax_ess.transAxes,
                fontsize=10,
                verticalalignment="top",
                bbox=dict(boxstyle="round", facecolor="lightblue", alpha=0.7),
            )
    except Exception:
        axes[0, 0].text(0.5, 0.5, "ESS unavailable", ha="center", va="center")
        axes[0, 0].set_title("ESS Distribution")

    # R-hat histogram and scatter (only when R-hat is available)
    if rhat_available and "rhat_ax" in locals():
        # Add shaded regions for R-hat quality
        rhat_ax.axvspan(
            1.0,
            config.rhat_threshold,
            alpha=0.1,
            color="green",
            label="Converged (≤1.01)",
        )
        rhat_ax.axvspan(
            config.rhat_threshold,
            1.1,
            alpha=0.1,
            color="yellow",
            label="Concerning (1.01-1.10)",
        )
        rhat_ax.axvspan(
            1.1,
            rhat_values.max(),
            alpha=0.1,
            color="red",
            label="Not converged (>1.10)",
        )

        rhat_ax.hist(rhat_values, bins=30, alpha=0.7, edgecolor="black")
        rhat_ax.axvline(
            1.0,
            color="green",
            linestyle="--",
            linewidth=2,
            label="Perfectly mixed",
        )
        rhat_ax.axvline(
            config.rhat_threshold,
            color="orange",
            linestyle="--",
            linewidth=2,
            label="Acceptable",
        )
        rhat_ax.set_xlabel("R-hat")
        rhat_ax.set_ylabel("Count")
        rhat_ax.set_title("R-hat Distribution")
        rhat_ax.legend(loc="upper right", fontsize="x-small")
        rhat_ax.grid(True, alpha=0.3)

        pct_excellent = (rhat_values <= 1.01).mean() * 100
        pct_bad = (rhat_values > 1.1).mean() * 100
        rhat_ax.text(
            0.02,
            0.98,
            f"Max: {rhat_values.max():.3f}\nMean: {rhat_values.mean():.3f}\n≤1.01: {pct_excellent:.1f}%\n>1.10: {pct_bad:.1f}%",
            transform=rhat_ax.transAxes,
            fontsize=9,
            verticalalignment="top",
            bbox=dict(boxstyle="round", facecolor="lightgreen", alpha=0.7),
        )

    # ESS vs R-hat scatter (only when R-hat is available)
    if rhat_available and "scatter_ax" in locals():
        try:
            # NOTE: ess_values may be unbound if the ESS block failed above;
            # the resulting NameError is intentionally caught here.
            if len(ess_values) > 0 and len(rhat_values) > 0:
                ess_all = az.ess(idata).to_array().values.flatten()
                rhat_all = az.rhat(idata).to_array().values.flatten()
                valid_mask = ~(np.isnan(ess_all) | np.isnan(rhat_all))

                if np.sum(valid_mask) > 0:
                    scatter_ax.scatter(
                        rhat_all[valid_mask],
                        ess_all[valid_mask],
                        alpha=0.6,
                        s=20,
                    )
                    scatter_ax.axvline(
                        config.rhat_threshold,
                        color="red",
                        linestyle="--",
                        alpha=0.7,
                    )
                    scatter_ax.axhline(
                        config.ess_threshold,
                        color="orange",
                        linestyle="--",
                        alpha=0.7,
                    )
                    scatter_ax.set_xlabel("R-hat")
                    scatter_ax.set_ylabel("ESS")
                    scatter_ax.set_title("Convergence Overview")
                    scatter_ax.grid(True, alpha=0.3)
        except Exception:
            scatter_ax.text(
                0.5, 0.5, "Scatter unavailable", ha="center", va="center"
            )
            scatter_ax.set_title("Convergence Overview")

    # Analysis metadata
    try:
        n_samples = idata.posterior.sizes.get("draw", 0)
        n_chains = idata.posterior.sizes.get("chain", 1)
        n_params = len(list(idata.posterior.data_vars))
        # .get keeps the panel usable when the attribute is missing,
        # consistent with the guarded access used elsewhere in this module.
        sampler_type = idata.attrs.get("sampler_type", "unknown")

        metadata_lines = [
            f"Sampler: {sampler_type}",
            f"Samples: {n_samples} × {n_chains} chains",
            f"Parameters: {n_params}",
        ]
        if n_channels is not None:
            metadata_lines.append(f"Channels: {n_channels}")
        if n_freq is not None:
            metadata_lines.append(f"Frequencies: {n_freq}")
        if runtime is not None:
            metadata_lines.append(f"Runtime: {runtime:.2f}s")

        meta_ax.text(
            0.05,
            0.95,
            "\n".join(metadata_lines),
            transform=meta_ax.transAxes,
            fontsize=12,
            verticalalignment="top",
            fontfamily="monospace",
        )
        meta_ax.set_title("Analysis Summary")
        meta_ax.axis("off")
    except Exception:
        meta_ax.text(
            0.5, 0.5, "Metadata unavailable", ha="center", va="center"
        )
        meta_ax.set_title("Analysis Summary")
        meta_ax.axis("off")

    # Parameter counts (placeholder - turned off as not helpful per user feedback)
    try:
        # Just show a summary text instead of the full bar chart
        param_groups = _group_parameters_simple(idata)
        if param_groups:
            summary_text = "Parameter Summary:\n"
            for group_name, params in param_groups.items():
                if params:  # Only show non-empty groups
                    summary_text += f"{group_name}: {len(params)}\n"
            param_ax.text(
                0.05,
                0.95,
                summary_text.strip(),
                transform=param_ax.transAxes,
                fontsize=11,
                verticalalignment="top",
                fontfamily="monospace",
            )
        param_ax.set_title("Parameter Summary")
        param_ax.axis("off")  # Don't show axes for this simple text summary
    except Exception:
        param_ax.text(
            0.5,
            0.5,
            "Parameter summary\nunavailable",
            ha="center",
            va="center",
        )
        param_ax.set_title("Parameter Summary")
        param_ax.axis("off")

    # Convergence status
    try:
        ess_retrieved = []
        rhat_retrieved = []
        try:
            ess_retrieved = az.ess(idata).to_array().values.flatten()
            ess_retrieved = ess_retrieved[~np.isnan(ess_retrieved)]
        except Exception:
            pass
        try:
            rhat_retrieved = az.rhat(idata).to_array().values.flatten()
            rhat_retrieved = rhat_retrieved[~np.isnan(rhat_retrieved)]
        except Exception:
            pass

        status_lines = ["Convergence Status:"]

        if len(ess_retrieved) > 0:
            ess_good = (ess_retrieved >= config.ess_threshold).mean() * 100
            status_lines.append(
                f"ESS ≥ {config.ess_threshold}: {ess_good:.0f}%"
            )

        if len(rhat_retrieved) > 0:
            rhat_good = (rhat_retrieved <= config.rhat_threshold).mean() * 100
            status_lines.append(
                f"R-hat ≤ {config.rhat_threshold}: {rhat_good:.0f}%"
            )

        status_lines.append("")
        status_lines.append("Overall Status:")

        if len(ess_retrieved) > 0 and len(rhat_retrieved) > 0:
            if ess_good >= 90 and rhat_good >= 90:
                status_lines.append("✓ EXCELLENT")
                color = "green"
            elif ess_good >= 75 and rhat_good >= 75:
                status_lines.append("✓ GOOD")
                color = "orange"
            else:
                status_lines.append("⚠ NEEDS ATTENTION")
                color = "red"
        elif len(ess_retrieved) > 0:
            if ess_good >= 90:
                status_lines.append("✓ GOOD (based on ESS only)")
                color = "green"
            elif ess_good >= 75:
                status_lines.append("✓ ADEQUATE (based on ESS only)")
                color = "orange"
            else:
                status_lines.append("⚠ NEEDS ATTENTION")
                color = "red"
        else:
            status_lines.append("? UNABLE TO ASSESS")
            color = "gray"

        status_ax.text(
            0.05,
            0.95,
            "\n".join(status_lines),
            transform=status_ax.transAxes,
            fontsize=11,
            verticalalignment="top",
            fontfamily="monospace",
            color=color,
        )
        status_ax.set_title("Convergence Status")
        status_ax.axis("off")
    except Exception:
        status_ax.text(
            0.5, 0.5, "Status unavailable", ha="center", va="center"
        )
        status_ax.set_title("Convergence Status")
        status_ax.axis("off")

    plt.tight_layout()
451

452

453
def _plot_log_posterior(idata, config):
    """Log posterior diagnostics.

    Draws a 2x2 grid: trace with running mean, value distribution,
    step-to-step changes, and a text panel of summary statistics. Prefers
    sample_stats["lp"]; falls back to sample_stats["log_likelihood"]; if
    neither exists a single placeholder panel is drawn instead.
    """
    fig, axes = plt.subplots(2, 2, figsize=config.figsize)

    # Check for lp first, then log_likelihood
    if "lp" in idata.sample_stats:
        lp_values = idata.sample_stats["lp"].values.flatten()
        var_name = "lp"  # NOTE(review): var_name is currently unused
        title_prefix = "Log Posterior"
    elif "log_likelihood" in idata.sample_stats:
        lp_values = idata.sample_stats["log_likelihood"].values.flatten()
        var_name = "log_likelihood"
        title_prefix = "Log Likelihood"
    else:
        # Create a fallback layout when no posterior data available
        # (the 2x2 figure created above is replaced by a single panel)
        fig, axes = plt.subplots(1, 1, figsize=config.figsize)
        axes.text(
            0.5,
            0.5,
            "No log posterior\nor log likelihood\navailable",
            ha="center",
            va="center",
            fontsize=14,
        )
        axes.set_title("Log Posterior Diagnostics")
        axes.axis("off")
        plt.tight_layout()
        return

    # Trace plot with running mean overlaid
    axes[0, 0].plot(
        lp_values, alpha=0.7, linewidth=1, color="blue", label="Trace"
    )

    # Add running mean on the same plot; window is ~1% of the trace length
    window_size = max(10, len(lp_values) // 100)
    if len(lp_values) > window_size:
        running_mean = np.convolve(
            lp_values, np.ones(window_size) / window_size, mode="valid"
        )
        # Shift x by half a window so the mean is centered on its window
        axes[0, 0].plot(
            range(window_size // 2, window_size // 2 + len(running_mean)),
            running_mean,
            alpha=0.9,
            linewidth=3,
            color="red",
            label=f"Running mean (w={window_size})",
        )

    axes[0, 0].set_xlabel("Iteration")
    axes[0, 0].set_ylabel(title_prefix)
    axes[0, 0].set_title(f"{title_prefix} Trace with Running Mean")
    axes[0, 0].legend(loc="best", fontsize="small")
    axes[0, 0].grid(True, alpha=0.3)

    # Distribution
    axes[0, 1].hist(
        lp_values, bins=50, alpha=0.7, density=True, edgecolor="black"
    )
    axes[0, 1].axvline(
        np.mean(lp_values),
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Mean: {np.mean(lp_values):.1f}",
    )
    axes[0, 1].set_xlabel(title_prefix)
    axes[0, 1].set_ylabel("Density")
    axes[0, 1].set_title(f"{title_prefix} Distribution")
    axes[0, 1].legend(loc="best", fontsize="small")
    axes[0, 1].grid(True, alpha=0.3)

    # Step-to-step changes
    lp_diff = np.diff(lp_values)
    axes[1, 0].plot(lp_diff, alpha=0.5, linewidth=1)
    axes[1, 0].axhline(0, color="red", linestyle="--", alpha=0.7)
    axes[1, 0].axhline(
        np.mean(lp_diff),
        color="blue",
        linestyle="--",
        alpha=0.7,
        label=f"Mean change: {np.mean(lp_diff):.1f}",
    )
    axes[1, 0].set_xlabel("Iteration")
    axes[1, 0].set_ylabel(f"{title_prefix} Difference")
    axes[1, 0].set_title("Step-to-Step Changes")
    axes[1, 0].legend(loc="best", fontsize="small")
    axes[1, 0].grid(True, alpha=0.3)

    # Summary statistics ("Final variation" is the std of the last quarter)
    stats_lines = [
        f"Mean: {np.mean(lp_values):.2f}",
        f"Std: {np.std(lp_values):.2f}",
        f"Min: {np.min(lp_values):.2f}",
        f"Max: {np.max(lp_values):.2f}",
        f"Range: {np.max(lp_values) - np.min(lp_values):.2f}",
        "",
        "Stability:",
        f"Final variation: {np.std(lp_values[-len(lp_values)//4:]):.2f}",
    ]

    axes[1, 1].text(
        0.05,
        0.95,
        "\n".join(stats_lines),
        transform=axes[1, 1].transAxes,
        fontsize=10,
        verticalalignment="top",
        fontfamily="monospace",
        bbox=dict(boxstyle="round", facecolor="lightblue", alpha=0.8),
    )
    axes[1, 1].set_title("Posterior Statistics")
    axes[1, 1].axis("off")

    plt.tight_layout()
568

569

570
def _plot_acceptance_diagnostics(idata, config):
    """Acceptance rate diagnostics.

    Draws a 2x2 grid: acceptance trace with target/quality zones, rate
    distribution, rolling standard deviation, and a text summary panel.
    Reads sample_stats["accept_prob"] (preferred) or ["acceptance_rate"];
    draws a placeholder panel when neither exists.

    Bug fix: the target rate was previously read with
    ``getattr(idata.attrs, "target_accept_rate", 0.44)``. This module
    treats ``idata.attrs`` as a mapping (``in`` / ``[]`` access elsewhere),
    and mapping keys are not attributes, so getattr always returned the
    0.44 default. It is now read with ``attrs.get``.
    """
    accept_key = None
    if "accept_prob" in idata.sample_stats:
        accept_key = "accept_prob"
    elif "acceptance_rate" in idata.sample_stats:
        accept_key = "acceptance_rate"

    if accept_key is None:
        fig, ax = plt.subplots(figsize=config.figsize)
        ax.text(
            0.5,
            0.5,
            "Acceptance rate data unavailable",
            ha="center",
            va="center",
        )
        ax.set_title("Acceptance Rate Diagnostics")
        return

    fig, axes = plt.subplots(2, 2, figsize=config.figsize)

    accept_rates = idata.sample_stats[accept_key].values.flatten()
    # Mapping lookup (was a broken getattr that always returned 0.44)
    target_rate = idata.attrs.get("target_accept_rate", 0.44)
    sampler_type = (
        idata.attrs["sampler_type"].lower()
        if "sampler_type" in idata.attrs
        else "unknown"
    )
    sampler_type = "NUTS" if "nuts" in sampler_type else "MH"

    # Define good ranges based on sampler
    if target_rate > 0.5:  # NUTS
        good_range = (0.7, 0.9)
        low_range = (0.0, 0.6)
        high_range = (0.9, 1.0)
        concerning_range = (0.6, 0.7)
    else:  # MH
        good_range = (0.2, 0.5)
        low_range = (0.0, 0.2)
        high_range = (0.5, 1.0)
        concerning_range = (0.1, 0.2)  # MH can be lower than NUTS

    # Trace plot with color zones
    # Add background zones
    axes[0, 0].axhspan(
        good_range[0],
        good_range[1],
        alpha=0.1,
        color="green",
        label=f"Good ({good_range[0]:.1f}-{good_range[1]:.1f})",
    )
    axes[0, 0].axhspan(
        low_range[0], low_range[1], alpha=0.1, color="red", label="Too low"
    )
    axes[0, 0].axhspan(
        high_range[0],
        high_range[1],
        alpha=0.1,
        color="orange",
        label="Too high",
    )
    if concerning_range[1] > concerning_range[0]:
        axes[0, 0].axhspan(
            concerning_range[0],
            concerning_range[1],
            alpha=0.1,
            color="yellow",
            label="Concerning",
        )

    # Main trace plot
    axes[0, 0].plot(
        accept_rates, alpha=0.8, linewidth=1, color="blue", label="Trace"
    )
    axes[0, 0].axhline(
        target_rate,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Target ({target_rate})",
    )

    # Add running average on the same plot (~2% of trace length window)
    window_size = max(10, len(accept_rates) // 50)
    if len(accept_rates) > window_size:
        running_mean = np.convolve(
            accept_rates, np.ones(window_size) / window_size, mode="valid"
        )
        axes[0, 0].plot(
            range(window_size // 2, window_size // 2 + len(running_mean)),
            running_mean,
            alpha=0.9,
            linewidth=3,
            color="purple",
            label=f"Running mean (w={window_size})",
        )

    axes[0, 0].set_xlabel("Iteration")
    axes[0, 0].set_ylabel("Acceptance Rate")
    axes[0, 0].set_title(f"{sampler_type} Acceptance Rate Trace")
    axes[0, 0].legend(loc="best", fontsize="small")
    axes[0, 0].grid(True, alpha=0.3)

    # Add interpretation text
    interpretation = f"{sampler_type} aims for {target_rate:.2f}."
    if target_rate > 0.5:
        interpretation += " Green: efficient sampling."
    else:
        interpretation += " MH adapts to find optimal rate."
    axes[0, 0].text(
        0.02,
        0.02,
        interpretation,
        transform=axes[0, 0].transAxes,
        fontsize=9,
        bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.7),
    )

    # Distribution
    axes[0, 1].hist(
        accept_rates, bins=30, alpha=0.7, density=True, edgecolor="black"
    )
    axes[0, 1].axvline(
        target_rate,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Target ({target_rate})",
    )
    axes[0, 1].set_xlabel("Acceptance Rate")
    axes[0, 1].set_ylabel("Density")
    axes[0, 1].set_title("Acceptance Rate Distribution")
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

    # Since running means are already overlaid on the main plot, use the bottom row for additional info

    # Additional acceptance analysis - evolution over time
    if len(accept_rates) > 10:
        # Show moving standard deviation or coefficient of variation
        window_std = np.array(
            [
                np.std(accept_rates[max(0, i - 20) : i + 1])
                for i in range(len(accept_rates))
            ]
        )
        axes[1, 0].plot(window_std, alpha=0.7, color="green")
        axes[1, 0].set_xlabel("Iteration")
        axes[1, 0].set_ylabel("Rolling Std")
        axes[1, 0].set_title("Rolling Standard Deviation")
        axes[1, 0].grid(True, alpha=0.3)
    else:
        axes[1, 0].text(
            0.5,
            0.5,
            "Acceptance variability\nanalysis unavailable",
            ha="center",
            va="center",
        )
        axes[1, 0].set_title("Acceptance Stability")

    # Summary statistics (expanded)
    stats_text = [
        f"Sampler: {sampler_type}",
        f"Target: {target_rate:.3f}",
        f"Mean: {np.mean(accept_rates):.3f}",
        f"Std: {np.std(accept_rates):.3f}",
        f"CV: {np.std(accept_rates)/np.mean(accept_rates):.3f}",
        f"Min: {np.min(accept_rates):.3f}",
        f"Max: {np.max(accept_rates):.3f}",
        "",
        "Stability:",
        f"Final std: {np.std(accept_rates[-len(accept_rates)//4:]):.3f}",
    ]

    axes[1, 1].text(
        0.05,
        0.95,
        "\n".join(stats_text),
        transform=axes[1, 1].transAxes,
        fontsize=9,
        verticalalignment="top",
        fontfamily="monospace",
    )
    axes[1, 1].set_title("Acceptance Analysis")
    axes[1, 1].axis("off")

    plt.tight_layout()
759

760

761
def _create_sampler_diagnostics(idata, diag_dir, config):
    """Create sampler-specific diagnostics.

    Detects the sampler from the "sampler_type" attribute and from
    NUTS-only sample_stats fields, then writes either nuts_diagnostics.png
    or mh_step_sizes.png into ``diag_dir``. Nothing is written when
    neither sampler is recognised.
    """

    # Better sampler detection - check sampler type first
    sampler_type = (
        idata.attrs["sampler_type"].lower()
        if "sampler_type" in idata.attrs
        else "unknown"
    )

    # Check for NUTS-specific fields that MH definitely doesn't have
    nuts_specific_fields = [
        "energy",
        "num_steps",
        "tree_depth",
        "diverging",
        "energy_error",
    ]

    # NUTS wins if any NUTS-only field exists OR the attribute says so
    has_nuts = (
        any(field in idata.sample_stats for field in nuts_specific_fields)
        or "nuts" in sampler_type
    )

    # Check for MH-specific fields (exclude anything NUTS might have)
    has_mh = "step_size_mean" in idata.sample_stats and not has_nuts

    if has_nuts:

        @safe_plot(f"{diag_dir}/nuts_diagnostics.png", config.dpi)
        def plot_nuts():
            _plot_nuts_diagnostics(idata, config)

        plot_nuts()
    elif has_mh:

        @safe_plot(f"{diag_dir}/mh_step_sizes.png", config.dpi)
        def plot_mh():
            _plot_mh_step_sizes(idata, config)

        plot_mh()
802

803

804
def _plot_nuts_diagnostics(idata, config):
    """NUTS diagnostics with enhanced information.

    Draws a 2x2 diagnostic figure from ``idata.sample_stats``:

    * top-left: Hamiltonian/potential energy traces (both on one axis with
      their difference on a twin y-axis when both fields are present),
    * top-right: histogram of leapfrog steps with shaded efficiency zones,
    * bottom-left: acceptance-probability trace with guidance bands,
    * bottom-right: a text panel of summary statistics.

    Each panel degrades to an "unavailable" placeholder when the matching
    sample_stats field is missing.  The figure is neither saved nor closed
    here — that is handled by the ``safe_plot`` decorator at the call site.

    Args:
        idata: arviz.InferenceData with a ``sample_stats`` group.
        config: DiagnosticsConfig; only ``figsize`` is used here.
    """
    # Determine available data to decide layout
    has_energy = "energy" in idata.sample_stats
    has_potential = "potential_energy" in idata.sample_stats
    has_steps = "num_steps" in idata.sample_stats
    has_accept = "accept_prob" in idata.sample_stats
    has_divergences = "diverging" in idata.sample_stats
    has_tree_depth = "tree_depth" in idata.sample_stats
    has_energy_error = "energy_error" in idata.sample_stats

    # Create a 2x2 layout, potentially combining energy and potential on same plot
    fig, axes = plt.subplots(2, 2, figsize=config.figsize)

    # Top-left: Energy diagnostics (combine Hamiltonian and Potential if both available)
    energy_ax = axes[0, 0]

    if has_energy and has_potential:
        # Both available - plot them together on one plot
        energy = idata.sample_stats.energy.values.flatten()
        potential = idata.sample_stats.potential_energy.values.flatten()

        # Plot both energies on same axis
        energy_ax.plot(
            energy, alpha=0.7, linewidth=1, color="blue", label="Hamiltonian"
        )
        energy_ax.plot(
            potential,
            alpha=0.7,
            linewidth=1,
            color="orange",
            label="Potential",
        )

        # Add difference (which relates to kinetic energy)
        energy_diff = energy - potential
        # Create second y-axis for difference so the (smaller) kinetic-energy
        # scale does not flatten the main traces.
        ax2 = energy_ax.twinx()
        ax2.plot(
            energy_diff,
            alpha=0.5,
            linewidth=1,
            color="red",
            label="H - Potential (Kinetic)",
            linestyle="--",
        )
        ax2.set_ylabel("Energy Difference", color="red")
        ax2.tick_params(axis="y", labelcolor="red")

        energy_ax.set_xlabel("Iteration")
        energy_ax.set_ylabel("Energy", color="blue")
        energy_ax.tick_params(axis="y", labelcolor="blue")
        energy_ax.set_title("Hamiltonian & Potential Energy")
        energy_ax.legend(loc="best", fontsize="small")
        energy_ax.grid(True, alpha=0.3)

        # Add statistics (mean/std of each energy series) as an inset box.
        energy_ax.text(
            0.02,
            0.98,
            f"H: μ={np.mean(energy):.1f}, σ={np.std(energy):.1f}\nP: μ={np.mean(potential):.1f}, σ={np.std(potential):.1f}",
            transform=energy_ax.transAxes,
            fontsize=8,
            bbox=dict(boxstyle="round", facecolor="lightblue", alpha=0.8),
            verticalalignment="top",
        )

    elif has_energy:
        # Only Hamiltonian energy
        energy = idata.sample_stats.energy.values.flatten()
        energy_ax.plot(energy, alpha=0.7, linewidth=1, color="blue")
        energy_ax.set_xlabel("Iteration")
        energy_ax.set_ylabel("Hamiltonian Energy")
        energy_ax.set_title("Hamiltonian Energy Trace")
        energy_ax.grid(True, alpha=0.3)

    elif has_potential:
        # Only potential energy
        potential = idata.sample_stats.potential_energy.values.flatten()
        energy_ax.plot(potential, alpha=0.7, linewidth=1, color="orange")
        energy_ax.set_xlabel("Iteration")
        energy_ax.set_ylabel("Potential Energy")
        energy_ax.set_title("Potential Energy Trace")
        energy_ax.grid(True, alpha=0.3)

    else:
        # Placeholder panel when neither energy field exists.
        energy_ax.text(
            0.5,
            0.5,
            "Energy data\nunavailable",
            ha="center",
            va="center",
            transform=energy_ax.transAxes,
        )
        energy_ax.set_title("Energy Diagnostics")

    # Top-right: Sampling efficiency diagnostics
    if has_steps:
        steps_ax = axes[0, 1]
        num_steps = idata.sample_stats.num_steps.values.flatten()

        # Show histogram with color zones for step efficiency
        # (hist returns counts, bin edges, and the patch container — the
        # third value, despite being named "edges", is the patches; none of
        # the three are used afterwards).
        n, bins, edges = steps_ax.hist(
            num_steps, bins=20, alpha=0.7, edgecolor="black"
        )

        # Add shaded regions for different efficiency levels
        # Green: efficient (tree depth ≤5, ~32 steps)
        # Yellow: moderate (tree depth 6-8, ~64-256 steps)
        # Red: inefficient (tree depth >8, >256 steps)
        steps_ax.axvspan(
            0, 64, alpha=0.1, color="green", label="Efficient (≤64)"
        )
        steps_ax.axvspan(
            64, 256, alpha=0.1, color="yellow", label="Moderate (65-256)"
        )
        # NOTE(review): if max(num_steps) < 256 this span is degenerate
        # (xmax < xmin) — matplotlib renders nothing harmful, but verify.
        steps_ax.axvspan(
            256,
            np.max(num_steps),
            alpha=0.1,
            color="red",
            label="Inefficient (>256)",
        )

        # Add reference lines for different tree depths
        # (a tree depth d corresponds to at most 2**d leapfrog steps)
        for depth in [5, 7, 10]:  # Common tree depths
            max_steps = 2**depth
            steps_ax.axvline(
                x=max_steps,
                color="gray",
                linestyle=":",
                alpha=0.7,
                linewidth=1,
                label=f"2^{depth} ({max_steps})",
            )

        steps_ax.set_xlabel("Leapfrog Steps")
        steps_ax.set_ylabel("Trajectories")
        steps_ax.set_title("Leapfrog Steps Distribution")
        steps_ax.legend(loc="best", fontsize="small")
        steps_ax.grid(True, alpha=0.3)

        # Add efficiency statistics (fractions of trajectories per zone).
        pct_inefficient = (num_steps > 256).mean() * 100
        pct_moderate = ((num_steps > 64) & (num_steps <= 256)).mean() * 100
        pct_efficient = (num_steps <= 64).mean() * 100
        steps_ax.text(
            0.02,
            0.98,
            f"Efficient: {pct_efficient:.1f}%\nModerate: {pct_moderate:.1f}%\nInefficient: {pct_inefficient:.1f}%\nMean steps: {np.mean(num_steps):.1f}",
            transform=steps_ax.transAxes,
            fontsize=7,
            bbox=dict(boxstyle="round", facecolor="lightblue", alpha=0.8),
            verticalalignment="top",
        )

    else:
        # NOTE(review): unlike the other placeholders this text call omits
        # transform=...transAxes, so (0.5, 0.5) is in data coordinates; on a
        # fresh axis the default limits are 0-1 so it renders the same.
        axes[0, 1].text(
            0.5, 0.5, "Steps data\nunavailable", ha="center", va="center"
        )
        axes[0, 1].set_title("Sampling Steps")

    # Bottom-left: Acceptance and NS divergence diagnostics
    accept_ax = axes[1, 0]

    if has_accept:
        accept_prob = idata.sample_stats.accept_prob.values.flatten()

        # Plot acceptance probability with guidance zones
        accept_ax.fill_between(
            range(len(accept_prob)),
            0.7,
            0.9,
            alpha=0.1,
            color="green",
            label="Good (0.7-0.9)",
        )
        accept_ax.fill_between(
            range(len(accept_prob)),
            0,
            0.6,
            alpha=0.1,
            color="red",
            label="Too low",
        )
        accept_ax.fill_between(
            range(len(accept_prob)),
            0.9,
            1.0,
            alpha=0.1,
            color="orange",
            label="Too high",
        )

        accept_ax.plot(
            accept_prob,
            alpha=0.8,
            linewidth=1,
            color="blue",
            label="Acceptance prob",
        )
        # Reference line at the conventional NUTS target acceptance of 0.8.
        accept_ax.axhline(
            0.8,
            color="red",
            linestyle="--",
            linewidth=2,
            label="NUTS target (0.8)",
        )
        accept_ax.set_xlabel("Iteration")
        accept_ax.set_ylabel("Acceptance Probability")
        accept_ax.set_title("NUTS Acceptance Diagnostic")
        accept_ax.legend(loc="best", fontsize="small")
        accept_ax.set_ylim(0, 1)
        accept_ax.grid(True, alpha=0.3)

    else:
        accept_ax.text(
            0.5, 0.5, "Acceptance data\nunavailable", ha="center", va="center"
        )
        accept_ax.set_title("Acceptance Diagnostic")

    # Bottom-right: Summary statistics and additional diagnostics
    summary_ax = axes[1, 1]

    # Collect available statistics (one formatted line per available field).
    stats_lines = []

    if has_energy:
        energy = idata.sample_stats.energy.values.flatten()
        stats_lines.append(
            f"Energy: μ={np.mean(energy):.1f}, σ={np.std(energy):.1f}"
        )

    if has_potential:
        potential = idata.sample_stats.potential_energy.values.flatten()
        stats_lines.append(
            f"Potential: μ={np.mean(potential):.1f}, σ={np.std(potential):.1f}"
        )

    if has_steps:
        num_steps = idata.sample_stats.num_steps.values.flatten()
        stats_lines.append(
            f"Steps: μ={np.mean(num_steps):.1f}, max={np.max(num_steps):.0f}"
        )
        stats_lines.append("")

    if has_tree_depth:
        tree_depth = idata.sample_stats.tree_depth.values.flatten()
        stats_lines.append(f"Tree depth: μ={np.mean(tree_depth):.1f}")
        # Fraction of trajectories that hit the (assumed) max depth of 10.
        pct_max_depth = (tree_depth >= 10).mean() * 100
        stats_lines.append(f"Max depth (≥10): {pct_max_depth:.1f}%")

    if has_divergences:
        divergences = idata.sample_stats.diverging.values.flatten()
        n_divergences = np.sum(divergences)
        pct_divergent = n_divergences / len(divergences) * 100
        stats_lines.append(
            f"Divergent: {n_divergences}/{len(divergences)} ({pct_divergent:.2f}%)"
        )

    if has_energy_error:
        energy_error = idata.sample_stats.energy_error.values.flatten()
        stats_lines.append(
            f"Energy error: |μ|={np.mean(np.abs(energy_error)):.3f}"
        )

    if not stats_lines:
        summary_ax.text(
            0.5,
            0.5,
            "No diagnostics\ndata available",
            ha="center",
            va="center",
            transform=summary_ax.transAxes,
        )
        summary_ax.set_title("NUTS Statistics")
        summary_ax.axis("off")
    else:
        summary_text = "\n".join(["NUTS Diagnostics:"] + [""] + stats_lines)
        summary_ax.text(
            0.05,
            0.95,
            summary_text,
            transform=summary_ax.transAxes,
            fontsize=10,
            verticalalignment="top",
            fontfamily="monospace",
            bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.9),
        )
        summary_ax.set_title("NUTS Summary Statistics")
        summary_ax.axis("off")

    plt.tight_layout()
1099
def _plot_mh_step_sizes(idata, config):
    """MH step size diagnostics.

    Renders a 2x2 figure from the adaptive Metropolis-Hastings sample stats:
    step-size evolution, step-size histograms, the mean/std consistency
    ratio, and a monospace text panel of summary statistics.  Saving and
    closing the figure is left to the caller (see ``safe_plot``).
    """
    _, panels = plt.subplots(2, 2, figsize=config.figsize)
    ax_evo, ax_hist = panels[0, 0], panels[0, 1]
    ax_ratio, ax_stats = panels[1, 0], panels[1, 1]

    means = idata.sample_stats.step_size_mean.values.flatten()
    stds = idata.sample_stats.step_size_std.values.flatten()

    # The trace and histogram panels share the same (series, label, color)
    # pairing, so drive both from one table.
    series = ((means, "Mean", "blue"), (stds, "Std", "orange"))

    # Step-size evolution over iterations.
    for data, label, color in series:
        ax_evo.plot(data, alpha=0.7, linewidth=1, label=label, color=color)
    ax_evo.set_xlabel("Iteration")
    ax_evo.set_ylabel("Step Size")
    ax_evo.set_title("Step Size Evolution")
    ax_evo.legend()
    ax_evo.grid(True, alpha=0.3)

    # Marginal distributions of the two step-size summaries.
    for data, label, color in series:
        ax_hist.hist(data, bins=30, alpha=0.5, label=label, color=color)
    ax_hist.set_xlabel("Step Size")
    ax_hist.set_ylabel("Count")
    ax_hist.set_title("Step Size Distributions")
    ax_hist.legend()
    ax_hist.grid(True, alpha=0.3)

    # Adaptation quality: mean-to-std ratio per iteration.
    ax_ratio.plot(means / stds, alpha=0.7, linewidth=1)
    ax_ratio.set_xlabel("Iteration")
    ax_ratio.set_ylabel("Mean / Std")
    ax_ratio.set_title("Step Size Consistency")
    ax_ratio.grid(True, alpha=0.3)

    # Text-only summary panel.
    summary_lines = [
        "Step Size Summary:",
        f"Final mean: {means[-1]:.4f}",
        f"Final std: {stds[-1]:.4f}",
        f"Mean of means: {np.mean(means):.4f}",
        f"Mean of stds: {np.mean(stds):.4f}",
        "",
        "Adaptation Quality:",
        f"CV of means: {np.std(means)/np.mean(means):.3f}",
        f"CV of stds: {np.std(stds)/np.mean(stds):.3f}",
    ]
    ax_stats.text(
        0.05,
        0.95,
        "\n".join(summary_lines),
        transform=ax_stats.transAxes,
        fontsize=10,
        verticalalignment="top",
        fontfamily="monospace",
    )
    ax_stats.set_title("Step Size Statistics")
    ax_stats.axis("off")

    plt.tight_layout()
1163
def _plot_grouped_traces(idata, figsize):
    """Create grouped trace plots for delta, phi, and weights parameters.

    Draws three stacked panels (delta, phi, weights).  Every parameter's
    posterior samples are collapsed to a single 1-D trace: chain and draw
    dimensions are merged, and any extra trailing dimensions are averaged.
    At most 10 weight parameters are drawn; a note marks any truncation.
    Panels with no matching parameters show a placeholder message.

    The figure is neither saved nor closed here — that is the caller's job.

    Args:
        idata: arviz.InferenceData with a ``posterior`` group.
        figsize: (width, height) passed to ``plt.subplots``.
    """
    # Color cycle for multiple parameters within each panel.
    colors = [
        "blue",
        "red",
        "green",
        "orange",
        "purple",
        "brown",
        "pink",
        "gray",
        "olive",
        "cyan",
    ]

    def _params_with_prefix(prefix):
        # Posterior variable names starting with the given group prefix.
        return [
            param
            for param in idata.posterior.data_vars
            if param.startswith(prefix)
        ]

    def _merged_trace(values):
        # Collapse (chain, draw[, extra...]) samples into one 1-D trace.
        # Extra trailing dimensions (channels/weight dims) are averaged;
        # a trailing singleton is simply dropped.  This also fixes the old
        # bug where a squeezed (chain, draw) array was plotted as a 2-D
        # array (one line per draw-column) instead of a single trace.
        if values.ndim >= 3:
            flat = values.reshape(values.shape[0] * values.shape[1], -1)
            if flat.shape[-1] > 1:
                return flat.mean(axis=-1)
            return flat.ravel()
        return values.ravel()

    def _plot_group(ax, params, ylabel, title):
        # One colored line per parameter, plus shared axis cosmetics.
        for i, param in enumerate(params):
            trace = _merged_trace(idata.posterior[param].values)
            ax.plot(
                trace,
                color=colors[i % len(colors)],
                alpha=0.7,
                linewidth=1,
                label=param,
            )
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend(loc="upper right", fontsize="small")
        ax.grid(True, alpha=0.3)

    def _placeholder(ax, group):
        # Message panel shown when a group has no parameters.
        ax.text(
            0.5,
            0.5,
            f"No {group} parameters found",
            ha="center",
            va="center",
            transform=ax.transAxes,
        )
        ax.set_title(f"{group.capitalize()} Parameters")
        ax.axis("off")

    fig, axes = plt.subplots(3, 1, figsize=figsize)

    # Panel 1: delta parameters.
    delta_params = _params_with_prefix("delta")
    if delta_params:
        _plot_group(
            axes[0], delta_params, "Delta Parameters", "Delta Parameters Trace"
        )
    else:
        _placeholder(axes[0], "delta")

    # Panel 2: phi parameters.
    phi_params = _params_with_prefix("phi")
    if phi_params:
        _plot_group(
            axes[1], phi_params, "Phi Parameters", "Phi Parameters Trace"
        )
    else:
        _placeholder(axes[1], "phi")

    # Panel 3: weights parameters (higher dimensional; capped at 10 traces).
    weights_params = _params_with_prefix("weights")
    ax = axes[2]
    if weights_params:
        max_traces = min(10, len(weights_params))
        _plot_group(
            ax,
            weights_params[:max_traces],
            "Weights Parameters (mean)",
            "Weights Parameters Trace (averaged)",
        )
        if len(weights_params) > max_traces:
            ax.text(
                0.02,
                0.02,
                f"Showing {max_traces} of {len(weights_params)} weight parameters",
                transform=ax.transAxes,
                fontsize="small",
                bbox=dict(boxstyle="round", facecolor="lightcoral", alpha=0.7),
            )
        ax.set_xlabel("Iteration")
    else:
        _placeholder(ax, "weights")

    plt.tight_layout()
1330
def _group_parameters_simple(idata):
2✔
1331
    """Simple parameter grouping for counting."""
1332
    param_groups = {"phi": [], "delta": [], "weights": [], "other": []}
2✔
1333

1334
    for param in idata.posterior.data_vars:
2✔
1335
        if param.startswith("phi"):
2✔
1336
            param_groups["phi"].append(param)
2✔
1337
        elif param.startswith("delta"):
2✔
1338
            param_groups["delta"].append(param)
2✔
1339
        elif param.startswith("weights"):
2✔
1340
            param_groups["weights"].append(param)
2✔
1341
        else:
1342
            param_groups["other"].append(param)
×
1343

1344
    return {k: v for k, v in param_groups.items() if v}
2✔
1345

1346

1347
def generate_diagnostics_summary(idata, outdir):
    """Generate comprehensive text summary.

    Builds a human-readable convergence report: sampler/sample counts,
    parameter-group breakdown, ESS and R-hat summaries, acceptance rate,
    optional PSD-accuracy metrics (when ``true_psd`` is in the attrs), and
    an overall convergence verdict.  The report is written to
    ``{outdir}/diagnostics_summary.txt`` when ``outdir`` is truthy, printed
    to stdout, and returned as a string.

    Args:
        idata: arviz.InferenceData with a ``posterior`` group (and
            optionally ``sample_stats`` / ``attrs`` metadata).
        outdir: output directory for the text file, or falsy to skip saving.

    Returns:
        The full summary text.
    """
    summary = []
    summary.append("=== MCMC Diagnostics Summary ===\n")

    # Basic info — idata.attrs may be any mapping-like; normalize so .get works.
    attrs = getattr(idata, "attrs", {}) or {}
    if not hasattr(attrs, "get"):
        attrs = dict(attrs)

    n_samples = idata.posterior.sizes.get("draw", 0)
    n_chains = idata.posterior.sizes.get("chain", 1)
    n_params = len(list(idata.posterior.data_vars))
    sampler_type = attrs.get("sampler_type", "Unknown")

    summary.append(f"Sampler: {sampler_type}")
    summary.append(
        f"Samples: {n_samples} per chain × {n_chains} chains = {n_samples * n_chains} total"
    )
    summary.append(f"Parameters: {n_params}")

    # Parameter breakdown by name prefix.
    param_groups = _group_parameters_simple(idata)
    if param_groups:
        param_summary = ", ".join(
            f"{k}: {len(v)}" for k, v in param_groups.items()
        )
        summary.append(f"Parameter groups: {param_summary}")

    # Pre-initialize so the overall assessment below is always well-defined,
    # even when the ESS/R-hat computations fail (previously this relied on a
    # bare `except:` to swallow the resulting NameError).
    ess_values = np.array([])
    rhat_values = np.array([])

    # Effective sample size.
    try:
        ess = az.ess(idata).to_array().values.flatten()
        ess_values = ess[~np.isnan(ess)]

        if len(ess_values) > 0:
            summary.append(
                f"\nESS: min={ess_values.min():.0f}, mean={ess_values.mean():.0f}, max={ess_values.max():.0f}"
            )
            summary.append(f"ESS ≥ 400: {(ess_values >= 400).mean()*100:.1f}%")
    except Exception:
        summary.append("\nESS: unavailable")

    # R-hat convergence statistic.
    try:
        rhat = az.rhat(idata).to_array().values.flatten()
        rhat_values = rhat[~np.isnan(rhat)]

        if len(rhat_values) > 0:
            summary.append(
                f"R-hat: max={rhat_values.max():.3f}, mean={rhat_values.mean():.3f}"
            )
            summary.append(
                f"R-hat > 1.01: {(rhat_values > 1.01).mean()*100:.1f}%"
            )
    except Exception:
        summary.append("R-hat: unavailable")

    # Acceptance rate (field name differs between samplers).
    accept_key = None
    if "accept_prob" in idata.sample_stats:
        accept_key = "accept_prob"
    elif "acceptance_rate" in idata.sample_stats:
        accept_key = "acceptance_rate"

    if accept_key is not None:
        accept_rate = idata.sample_stats[accept_key].values.mean()
        # 0.44 is the classic 1-D random-walk MH target used as a fallback.
        target_rate = attrs.get(
            "target_accept_rate", attrs.get("target_accept_prob", 0.44)
        )
        summary.append(
            f"Acceptance rate: {accept_rate:.3f} (target: {target_rate:.3f})"
        )

    # PSD accuracy diagnostics (requires true_psd in attrs)
    if "true_psd" in attrs:
        coverage_level = attrs.get("coverage_level")
        coverage_label = (
            f"{int(round(coverage_level * 100))}% interval coverage"
            if coverage_level is not None
            else "Interval coverage"
        )

        def _format_riae_line(value, errorbars, prefix="  "):
            # RIAE line, with 5-95% quantile error bars when provided.
            line = f"{prefix}RIAE: {value:.3f}"
            if errorbars:
                q05, q25, median, q75, q95 = errorbars
                line += f" (median {median:.3f}, 5-95% [{q05:.3f}, {q95:.3f}])"
            summary.append(line)

        def _format_coverage_line(value, prefix="  "):
            if value is None:
                return
            summary.append(f"{prefix}{coverage_label}: {value * 100:.1f}%")

        summary.append("\nPSD accuracy diagnostics:")

        if "riae" in attrs:
            _format_riae_line(attrs["riae"], attrs.get("riae_errorbars"))
        if "coverage" in attrs:
            _format_coverage_line(attrs["coverage"])

        # Per-channel metrics are stored under riae_ch{i} / coverage_ch{i}.
        channel_indices = sorted(
            int(key.replace("riae_ch", ""))
            for key in attrs.keys()
            if key.startswith("riae_ch")
        )

        for idx in channel_indices:
            metrics = []
            riae_key = f"riae_ch{idx}"
            cov_key = f"coverage_ch{idx}"
            error_key = f"riae_errorbars_ch{idx}"

            if riae_key in attrs:
                riae_line = f"RIAE {attrs[riae_key]:.3f}"
                errorbars = attrs.get(error_key)
                if errorbars:
                    q05, _, median, _, q95 = errorbars
                    riae_line += (
                        f" (median {median:.3f}, 5-95% [{q05:.3f}, {q95:.3f}])"
                    )
                metrics.append(riae_line)

            if cov_key in attrs:
                metrics.append(f"{coverage_label} {attrs[cov_key] * 100:.1f}%")

            if metrics:
                summary.append(f"  Channel {idx}: " + "; ".join(metrics))

    # Overall assessment — only meaningful when both diagnostics succeeded.
    if len(ess_values) > 0 and len(rhat_values) > 0:
        ess_good = (ess_values >= 400).mean() * 100
        rhat_good = (rhat_values <= 1.01).mean() * 100

        summary.append("\nOverall Convergence Assessment:")
        if ess_good >= 90 and rhat_good >= 90:
            summary.append("  Status: EXCELLENT ✓")
        elif ess_good >= 75 and rhat_good >= 75:
            summary.append("  Status: GOOD ✓")
        else:
            summary.append("  Status: NEEDS ATTENTION ⚠")

    summary_text = "\n".join(summary)

    if outdir:
        with open(f"{outdir}/diagnostics_summary.txt", "w") as f:
            f.write(summary_text)

    print("\n" + summary_text + "\n")
    return summary_text
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc