• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

nz-gravity / LogPSplinePSD / 17935883473

23 Sep 2025 04:36AM UTC coverage: 84.453% (-2.2%) from 86.608%
17935883473

push

github

avivajpeyi
improve diagnostics

316 of 355 branches covered (89.01%)

Branch coverage included in aggregate %.

168 of 294 new or added lines in 5 files covered. (57.14%)

78 existing lines in 5 files now uncovered.

2449 of 2919 relevant lines covered (83.9%)

1.68 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

68.17
/src/log_psplines/plotting/diagnostics.py
1
import arviz as az
2✔
2
import matplotlib.pyplot as plt
3
import numpy as np
2✔
4
from typing import Optional
2✔
5

2✔
6

7

8

2✔
9

10

11
def create_fallback_plot(
    outdir: str,
    n_channels: Optional[int] = None,
    n_freq: Optional[int] = None,
    sampler_type: str = "Unknown",
    runtime: Optional[float] = None,
    figsize: tuple = (8, 6),
) -> None:
    """
    Create basic fallback plot when main plotting fails.

    Writes a text-only summary figure to ``{outdir}/analysis_summary.png``.
    Does nothing when ``outdir`` is None.

    Parameters
    ----------
    outdir : str
        Output directory for saving plot
    n_channels : int, optional
        Number of channels (for multivariate); omitted from the text if None
    n_freq : int, optional
        Number of frequencies; omitted from the text if None
    sampler_type : str
        Type of sampler used
    runtime : float, optional
        Runtime in seconds; omitted from the text if None
    figsize : tuple
        Figure size for plot
    """
    if outdir is None:
        return

    text_parts = ["Analysis Complete"]
    if n_channels is not None:
        text_parts.append(f"Channels: {n_channels}")
    if n_freq is not None:
        text_parts.append(f"Frequencies: {n_freq}")
    text_parts.append(f"Sampler: {sampler_type}")
    if runtime is not None:
        text_parts.append(f"Runtime: {runtime:.2f}s")

    fig, ax = plt.subplots(figsize=figsize)
    try:
        ax.text(
            0.5,
            0.5,
            "\n".join(text_parts),
            ha="center",
            va="center",
            fontsize=14,
            transform=ax.transAxes,
        )
        ax.set_title("Analysis Summary")
        ax.axis("off")
        # Save through the figure handle rather than pyplot's implicit
        # "current figure", so an unrelated open figure can never be written.
        fig.savefig(f"{outdir}/analysis_summary.png", bbox_inches="tight")
    finally:
        # Close even if saving fails (e.g. missing directory) so the figure
        # does not leak in long-running processes.
        plt.close(fig)
55

UNCOV
56

×
UNCOV
57
def plot_diagnostics(
    idata: az.InferenceData,
    outdir: str,
    n_channels: Optional[int] = None,
    n_freq: Optional[int] = None,
    runtime: Optional[float] = None,
    figsize: tuple = (12, 8),
) -> None:
    """
    Plot comprehensive MCMC diagnostics including trace plots, summary statistics,
    and sampler-specific metrics.

    Each figure is produced independently. A failure in one figure is printed
    to stdout and does not prevent the remaining figures from being written.
    Figures whose input fields are absent from ``idata.sample_stats`` are
    skipped silently.

    Files written to ``outdir`` (each only when its data are available):
    trace_plots.png, acceptance_rate.png, step_size_evolution.png,
    summary_diagnostics.png, divergences.png, energy_diagnostics.png,
    nuts_trajectory_stats.png, energy_error_diagnostics.png and
    nuts_extra_fields_diagnostics.png.

    Parameters
    ----------
    idata : az.InferenceData
        Inference data from MCMC analysis
    outdir : str
        Output directory for saving plots; nothing is done if None
    n_channels : int, optional
        Number of channels (for multivariate)
    n_freq : int, optional
        Number of frequencies
    runtime : float, optional
        Runtime in seconds
    figsize : tuple
        Figure size for plots
    """
    if outdir is None:
        return

    general_error = "Error generating MCMC diagnostics"
    nuts_error = "Error generating NUTS-specific diagnostics"

    _run_diagnostic(general_error, _plot_trace_diagnostics, idata, outdir, figsize)
    _run_diagnostic(general_error, _plot_acceptance_rate, idata, outdir)
    _run_diagnostic(general_error, _plot_step_size, idata, outdir)
    _run_diagnostic(
        general_error,
        _plot_summary_diagnostics,
        idata,
        outdir,
        n_channels,
        n_freq,
        runtime,
        figsize,
    )

    _run_diagnostic(nuts_error, _plot_divergences, idata, outdir)
    _run_diagnostic(nuts_error, _plot_energy_diagnostics, idata, outdir, figsize)
    _run_diagnostic(nuts_error, _plot_trajectory_stats, idata, outdir, figsize)
    _run_diagnostic(nuts_error, _plot_energy_error, idata, outdir, figsize)
    _run_diagnostic(nuts_error, _plot_nuts_extra_fields, idata, outdir, figsize)


def _run_diagnostic(error_prefix, plot_fn, *args):
    """Run one plotting routine, printing (not raising) any failure."""
    figures_before = set(plt.get_fignums())
    try:
        plot_fn(*args)
    except Exception as e:
        print(f"{error_prefix}: {e}")
        # Close figures the failed routine left open so they do not accumulate.
        for num in set(plt.get_fignums()) - figures_before:
            plt.close(num)


def _has_stat(idata, name):
    """Return True if ``idata`` has a ``sample_stats`` group containing ``name``."""
    stats = getattr(idata, "sample_stats", None)
    return stats is not None and name in stats


def _flat_stat(idata, name):
    """Return ``idata.sample_stats[name]`` flattened to a 1-D numpy array."""
    return idata.sample_stats[name].values.flatten()


def _plot_trace_diagnostics(idata, outdir, figsize):
    """Save ArviZ trace plots to trace_plots.png."""
    az.plot_trace(idata, figsize=figsize)
    fig = plt.gcf()
    fig.suptitle("Trace plots")
    fig.tight_layout()
    fig.savefig(f"{outdir}/trace_plots.png")
    plt.close(fig)


def _plot_acceptance_rate(idata, outdir):
    """Plot the acceptance-rate trace against its target band (acceptance_rate.png).

    The sampler is labelled NUTS when ``idata.attrs["target_accept_rate"]``
    (default 0.44) exceeds 0.5, and MH otherwise.
    """
    if not _has_stat(idata, "acceptance_rate"):
        return

    accept_rates = _flat_stat(idata, "acceptance_rate")
    target_rate = idata.attrs.get("target_accept_rate", 0.44)
    is_nuts = target_rate > 0.5
    sampler_type = "NUTS" if is_nuts else "MH"
    low, high = (0.7, 0.9) if is_nuts else (0.2, 0.5)

    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(accept_rates, alpha=0.7, color="blue")
    ax.axhline(
        y=target_rate,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Target ({target_rate})",
    )
    # Shade the recommended band and the regions on either side of it.
    ax.axhspan(
        low, high, alpha=0.1, color="green", label=f"Optimal ({low:.1f}-{high:.1f})"
    )
    ax.axhspan(0, low, alpha=0.1, color="red", label="Too low")
    if high < 1.0:
        ax.axhspan(high, 1.0, alpha=0.1, color="orange", label="Too high")

    ax.set_xlabel("Iteration")
    ax.set_ylabel("Acceptance Rate")
    ax.set_title(f"{sampler_type} Acceptance Rate Over Time")
    ax.legend(loc="best", fontsize="small")
    ax.grid(True, alpha=0.3)

    aim = "NUTS aims for 0.7-0.9" if is_nuts else "MH aims for 0.2-0.5"
    ax.text(
        0.02,
        0.02,
        f"{sampler_type} target: {target_rate}. {aim}",
        transform=ax.transAxes,
        fontsize="small",
        bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
    )

    fig.tight_layout()
    fig.savefig(f"{outdir}/acceptance_rate.png")
    plt.close(fig)


def _plot_step_size(idata, outdir):
    """Plot the per-iteration step-size mean and, if recorded, std (step_size_evolution.png)."""
    if not _has_stat(idata, "step_size_mean"):
        return

    fig, (ax_mean, ax_std) = plt.subplots(1, 2, figsize=(12, 4))

    ax_mean.plot(_flat_stat(idata, "step_size_mean"), alpha=0.7)
    ax_mean.set_xlabel("Iteration")
    ax_mean.set_ylabel("Mean Step Size")
    ax_mean.set_title("Step Size Evolution")
    ax_mean.grid(True, alpha=0.3)

    # step_size_std is optional. Its absence used to raise here and abort
    # every later general diagnostic, including the summary figure.
    if _has_stat(idata, "step_size_std"):
        ax_std.plot(_flat_stat(idata, "step_size_std"), alpha=0.7, color="orange")
        ax_std.set_xlabel("Iteration")
        ax_std.set_ylabel("Step Size Std")
        ax_std.set_title("Step Size Variability")
        ax_std.grid(True, alpha=0.3)
    else:
        ax_std.axis("off")

    fig.tight_layout()
    fig.savefig(f"{outdir}/step_size_evolution.png")
    plt.close(fig)


def _plot_summary_diagnostics(idata, outdir, n_channels, n_freq, runtime, figsize):
    """Write a 2x2 overview figure (summary_diagnostics.png).

    Panels: log-likelihood trace, run summary text, posterior-variable count
    grouped by name prefix, and an ESS histogram with reference lines.
    """
    fig, axes = plt.subplots(2, 2, figsize=figsize)

    # Panel 1: log-likelihood trace (left empty when not recorded).
    if _has_stat(idata, "log_likelihood"):
        ax = axes[0, 0]
        ax.plot(_flat_stat(idata, "log_likelihood"))
        ax.set_title("Log Likelihood Trace")
        ax.set_xlabel("Iteration")
        ax.set_ylabel("Log Likelihood")

    _summary_text_panel(axes[0, 1], idata, n_channels, n_freq, runtime)
    _param_count_panel(axes[1, 0], idata)
    _ess_panel(axes[1, 1], idata)

    fig.tight_layout()
    fig.savefig(f"{outdir}/summary_diagnostics.png", dpi=150, bbox_inches="tight")
    plt.close(fig)


def _summary_text_panel(ax, idata, n_channels, n_freq, runtime):
    """Write parameter/channel/frequency/runtime counts as text on ``ax``."""
    try:
        summary_df = az.summary(idata)
    except Exception:
        ax.text(
            0.5,
            0.5,
            "Summary unavailable",
            ha="center",
            va="center",
            transform=ax.transAxes,
        )
    else:
        # Fixed vertical slots, one per optional line.
        entries = [(0.9, f"Parameters: {len(summary_df)}")]
        if n_channels is not None:
            entries.append((0.8, f"Channels: {n_channels}"))
        if n_freq is not None:
            entries.append((0.7, f"Frequencies: {n_freq}"))
        if runtime is not None:
            entries.append((0.6, f"Runtime: {runtime:.2f}s"))
        for y, text in entries:
            ax.text(0.1, y, text, transform=ax.transAxes)
    ax.set_title("Summary Statistics")
    ax.axis("off")


def _param_count_panel(ax, idata):
    """Bar chart of posterior variables grouped by the name prefix before the first '_'."""
    param_counts = {}
    for param in idata.posterior.data_vars:
        param_type = param.split("_")[0]  # e.g. delta, phi, weights
        param_counts[param_type] = param_counts.get(param_type, 0) + 1

    if param_counts:
        ax.bar(list(param_counts.keys()), list(param_counts.values()))
        ax.set_title("Parameter Count by Type")
        ax.set_ylabel("Count")


def _ess_panel(ax, idata):
    """Histogram of non-NaN ESS values with 400/1000/max reference lines."""
    try:
        ess_values = az.ess(idata).to_array().values.flatten()
        ess_values = ess_values[~np.isnan(ess_values)]
    except Exception:
        ess_values = np.empty(0)

    if ess_values.size == 0:
        ax.text(
            0.5, 0.5, "ESS unavailable", ha="center", va="center", transform=ax.transAxes
        )
        return

    ax.hist(ess_values, bins=20, alpha=0.7)

    max_ess = np.max(ess_values)
    ess_thresholds = (
        (400, "red", "--", "Minimum reliable ESS"),
        (1000, "orange", "--", "Good ESS"),
        (max_ess, "green", ":", f"Max ESS = {max_ess:.0f}"),
    )
    for threshold, color, style, label in ess_thresholds:
        ax.axvline(
            x=threshold, color=color, linestyle=style, linewidth=2, alpha=0.8, label=label
        )

    ax.set_title("Effective Sample Size Distribution")
    ax.set_xlabel("ESS")
    ax.set_ylabel("Count")
    ax.legend(loc="upper right", fontsize="x-small")

    min_ess = np.min(ess_values)
    mean_ess = np.mean(ess_values)
    pct_good = np.mean(ess_values >= 400) * 100
    ax.text(
        0.02,
        0.98,
        f"Min ESS: {min_ess:.0f}\nMean ESS: {mean_ess:.0f}\n≥400: {pct_good:.1f}%",
        transform=ax.transAxes,
        fontsize="small",
        verticalalignment="top",
        bbox=dict(boxstyle="round", facecolor="lightblue", alpha=0.5),
    )


def _plot_divergences(idata, outdir):
    """Plot cumulative and rolling-window divergence counts (divergences.png)."""
    if not _has_stat(idata, "diverging"):
        return

    divergences = _flat_stat(idata, "diverging").astype(int)
    n_draws = divergences.size
    cumulative = np.cumsum(divergences)

    fig, (ax_cum, ax_rate) = plt.subplots(1, 2, figsize=(12, 4))

    ax_cum.plot(cumulative, alpha=0.7)
    ax_cum.set_xlabel("Iteration")
    ax_cum.set_ylabel("Cumulative Divergences")
    ax_cum.set_title("Cumulative Divergent Transitions")
    ax_cum.grid(True, alpha=0.3)

    window_size = min(100, n_draws // 10)
    if window_size > 0:
        # rate[i] = mean(divergences[i - w:i]) for i in [w, n), taken from
        # prefix sums instead of re-averaging every window.
        prefix = np.concatenate(([0], cumulative))
        rolling_rate = (
            prefix[window_size:n_draws] - prefix[: n_draws - window_size]
        ) / window_size
        ax_rate.plot(np.arange(window_size, n_draws), rolling_rate, alpha=0.7)
        ax_rate.set_xlabel("Iteration")
        ax_rate.set_ylabel(f"Divergence Rate (last {window_size} steps)")
        ax_rate.set_title("Rolling Divergence Rate")
        ax_rate.grid(True, alpha=0.3)
    else:
        # Too few draws for a meaningful window.
        ax_rate.axis("off")

    fig.tight_layout()
    fig.savefig(f"{outdir}/divergences.png")
    plt.close(fig)


def _plot_energy_diagnostics(idata, outdir, figsize):
    """Plot Hamiltonian and/or potential energy panels (energy_diagnostics.png).

    Each available quantity gets a 2x2 block: trace, histogram, trace of
    step-to-step changes, and histogram of those changes. When both are
    present, the blocks sit side by side in a 2x4 grid.
    """
    has_energy = _has_stat(idata, "energy")
    has_potential = _has_stat(idata, "potential_energy")
    if not (has_energy or has_potential):
        return

    if has_energy and has_potential:
        fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    else:
        fig, axes = plt.subplots(2, 2, figsize=figsize)

    col = 0
    if has_energy:
        _energy_quad(
            axes[:, col : col + 2],
            _flat_stat(idata, "energy"),
            color=None,
            value_label="Hamiltonian Energy",
            change_label="Energy Change",
            titles=(
                "Hamiltonian Energy Trace",
                "Energy Distribution",
                "Energy Changes Between Steps",
                "Energy Change Distribution",
            ),
        )
        col += 2
    if has_potential:
        _energy_quad(
            axes[:, col : col + 2],
            _flat_stat(idata, "potential_energy"),
            color="orange",
            value_label="Potential Energy",
            change_label="Potential Energy Change",
            titles=(
                "Potential Energy Trace",
                "Potential Energy Distribution",
                "Potential Energy Changes Between Steps",
                "Potential Energy Change Distribution",
            ),
        )

    fig.tight_layout()
    fig.savefig(f"{outdir}/energy_diagnostics.png")
    plt.close(fig)


def _energy_quad(axes, values, *, color, value_label, change_label, titles):
    """Fill a 2x2 axes block with trace/histogram of ``values`` and of ``np.diff(values)``.

    ``color=None`` lets matplotlib use its default colour cycle.
    """
    trace_title, hist_title, diff_title, diff_hist_title = titles
    (ax_trace, ax_hist), (ax_diff, ax_diff_hist) = axes
    diffs = np.diff(values)

    ax_trace.plot(values, alpha=0.7, color=color)
    ax_trace.set_xlabel("Iteration")
    ax_trace.set_ylabel(value_label)
    ax_trace.set_title(trace_title)

    ax_hist.hist(values, bins=30, alpha=0.7, density=True, color=color)
    ax_hist.set_xlabel(value_label)
    ax_hist.set_ylabel("Density")
    ax_hist.set_title(hist_title)

    ax_diff.plot(diffs, alpha=0.7, color=color)
    ax_diff.set_xlabel("Iteration")
    ax_diff.set_ylabel(change_label)
    ax_diff.set_title(diff_title)

    ax_diff_hist.hist(diffs, bins=30, alpha=0.7, density=True, color=color)
    ax_diff_hist.set_xlabel(change_label)
    ax_diff_hist.set_ylabel("Density")
    ax_diff_hist.set_title(diff_hist_title)

    for ax in (ax_trace, ax_hist, ax_diff, ax_diff_hist):
        ax.grid(True, alpha=0.3)


def _plot_trajectory_stats(idata, outdir, figsize):
    """Plot NUTS tree depth and leapfrog-step counts (nuts_trajectory_stats.png).

    Requires ``tree_depth`` plus a step count. The step count is read from
    ``n_steps`` or, failing that, ``num_steps`` (the field name used by the
    extra-fields figure).
    """
    if not _has_stat(idata, "tree_depth"):
        return
    steps_key = next(
        (key for key in ("n_steps", "num_steps") if _has_stat(idata, key)), None
    )
    if steps_key is None:
        return

    tree_depth = _flat_stat(idata, "tree_depth")
    n_steps = _flat_stat(idata, steps_key)

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=figsize)

    # Top row: per-iteration traces.
    for ax, series, ylabel, title in (
        (ax1, tree_depth, "Tree Depth", "NUTS Tree Depth"),
        (ax2, n_steps, "Leapfrog Steps", "Leapfrog Steps Per Trajectory"),
    ):
        ax.plot(series, alpha=0.7)
        ax.set_xlabel("Iteration")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True, alpha=0.3)

    # Bottom row: counts of each distinct (integer) value.
    for ax, series, xlabel, title in (
        (ax3, tree_depth, "Tree Depth", "Tree Depth Distribution"),
        (ax4, n_steps, "Leapfrog Steps", "Leapfrog Steps Distribution"),
    ):
        unique_values, counts = np.unique(series, return_counts=True)
        ax.bar(unique_values, counts, alpha=0.7)
        ax.set_xlabel(xlabel)
        ax.set_ylabel("Count")
        ax.set_title(title)
        ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(f"{outdir}/nuts_trajectory_stats.png")
    plt.close(fig)


def _plot_energy_error(idata, outdir, figsize):
    """Plot energy-error traces, distribution and relation to divergences.

    Writes energy_error_diagnostics.png.
    """
    if not _has_stat(idata, "energy_error"):
        return

    energy_error = _flat_stat(idata, "energy_error")
    abs_energy_error = np.abs(energy_error)
    error_threshold = 0.1

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=figsize)

    ax1.plot(energy_error, alpha=0.7)
    ax1.set_xlabel("Iteration")
    ax1.set_ylabel("Energy Error")
    ax1.set_title("Energy Error Trace")
    ax1.grid(True, alpha=0.3)

    ax2.plot(abs_energy_error, alpha=0.7)
    ax2.set_xlabel("Iteration")
    ax2.set_ylabel("Absolute Energy Error")
    ax2.set_title("Absolute Energy Error Trace")
    ax2.grid(True, alpha=0.3)

    ax3.hist(energy_error, bins=30, alpha=0.7, density=True)
    ax3.set_xlabel("Energy Error")
    ax3.set_ylabel("Density")
    ax3.set_title("Energy Error Distribution")
    ax3.grid(True, alpha=0.3)

    if _has_stat(idata, "diverging"):
        # Cast explicitly: `~` on an integer array is bitwise NOT, not a mask.
        diverging = _flat_stat(idata, "diverging").astype(bool)
        n_divergent = np.count_nonzero(diverging)
        # Plot only divergent draws at y=1 (previously every draw was drawn
        # in red and labelled "Divergent").
        ax4.scatter(
            abs_energy_error[diverging],
            np.ones(n_divergent),
            alpha=0.5,
            s=2,
            c="red",
            label="Divergent",
        )
        ax4.scatter(
            abs_energy_error[~diverging],
            np.zeros(diverging.size - n_divergent),
            alpha=0.5,
            s=2,
            c="blue",
            label="Non-divergent",
        )
        # Make the threshold named in the title visible on the plot.
        ax4.axvline(
            x=error_threshold,
            color="black",
            linestyle="--",
            linewidth=1,
            label=f"Threshold ({error_threshold})",
        )
        ax4.set_xlabel("Absolute Energy Error")
        ax4.set_ylabel("Divergence (0/1)")
        ax4.set_title(f"Energy Error vs Divergence (> {error_threshold})")
        ax4.legend()
        ax4.grid(True, alpha=0.3)
    else:
        ax4.hist(abs_energy_error, bins=30, alpha=0.7, density=True, log=True)
        ax4.set_xlabel("Absolute Energy Error")
        ax4.set_ylabel("Log Density")
        ax4.set_title("Energy Error Distribution (log scale)")
        ax4.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(f"{outdir}/energy_error_diagnostics.png")
    plt.close(fig)


def _nuts_field_spec(field_name, data):
    """Return display settings for one NUTS extra field.

    Returns ``(title, color, hlines, curves, zones)`` where:

    - ``hlines`` are ``(value, color, style, label)`` scalar references.
    - ``curves`` are ``(x, y, color, style, label)`` overlay lines.
    - ``zones`` are ``(low, high, color, label)`` shaded bands.
    """
    if field_name == "potential_energy":
        curves = []
        window_size = min(50, len(data) // 4)
        if window_size > 1:
            rolling_mean = np.convolve(
                data, np.ones(window_size) / window_size, mode="valid"
            )
            start = window_size // 2
            curves.append(
                (
                    np.arange(start, start + len(rolling_mean)),
                    rolling_mean,
                    "blue",
                    "--",
                    f"Rolling Mean (w={window_size})",
                )
            )
        # Bands relative to the sample mean, in units of sample std.
        mean, std = np.mean(data), np.std(data)
        zones = [
            (mean - 2 * std, mean + 2 * std, "green", "Normal range"),
            (mean - 3 * std, mean - 2 * std, "yellow", "Concerning"),
            (mean + 2 * std, mean + 3 * std, "yellow", "Concerning"),
        ]
        return "Potential Energy", "orange", [], curves, zones

    if field_name == "num_steps":
        # NOTE: assumes max_tree_depth=10, i.e. at most 2**10 leapfrog steps.
        max_steps = 2**10
        hlines = [(max_steps, "red", "--", "Max Steps (2^10)")]
        zones = [
            (1, max_steps * 0.5, "green", "Efficient"),
            (max_steps * 0.5, max_steps * 0.8, "yellow", "Moderate"),
            (max_steps * 0.8, max_steps, "red", "Inefficient"),
        ]
        return "Number of Steps", "green", hlines, [], zones

    # accept_prob
    target_accept = 0.8
    hlines = [(target_accept, "red", "--", "Target (0.8)")]
    zones = [
        (0.7, 0.9, "green", "Good range"),
        (0.6, 0.7, "yellow", "Borderline"),
        (0.9, 1.0, "yellow", "Too high"),
        (0.0, 0.6, "red", "Too low"),
    ]
    return "Acceptance Probability", "purple", hlines, [], zones


def _plot_nuts_extra_fields(idata, outdir, figsize):
    """Plot one row (trace + histogram) per available NUTS extra field.

    Considers potential_energy, num_steps and accept_prob, in that order.
    Writes nuts_extra_fields_diagnostics.png.
    """
    fields = [
        name
        for name in ("potential_energy", "num_steps", "accept_prob")
        if _has_stat(idata, name)
    ]
    if not fields:
        return

    # One row per field. The previous 2x(2k) grid blanked the third field's
    # panels when all three fields were present.
    fig, axes = plt.subplots(
        len(fields),
        2,
        figsize=(figsize[0], figsize[1] / 2 * len(fields)),
        squeeze=False,
    )

    for row, field_name in enumerate(fields):
        data = _flat_stat(idata, field_name)
        title, color, hlines, curves, zones = _nuts_field_spec(field_name, data)
        ax_trace, ax_hist = axes[row]

        # Trace: shaded zones, data, then reference lines.
        # Repeated zone labels (e.g. two "Concerning" bands) are shown in the
        # legend once; labels starting with "_" are ignored by legend().
        seen_labels = set()
        for low, high, zone_color, label in zones:
            shown = "_nolegend_" if label in seen_labels else label
            seen_labels.add(label)
            ax_trace.axhspan(low, high, alpha=0.1, color=zone_color, label=shown)
        ax_trace.plot(data, alpha=0.7, color=color, linewidth=1)
        for value, line_color, style, label in hlines:
            ax_trace.axhline(
                y=value,
                color=line_color,
                linestyle=style,
                linewidth=2,
                alpha=0.8,
                label=label,
            )
        for x, y, line_color, style, label in curves:
            ax_trace.plot(
                x, y, color=line_color, linestyle=style, linewidth=2, alpha=0.8, label=label
            )
        ax_trace.set_xlabel("Iteration")
        ax_trace.set_ylabel(title)
        ax_trace.set_title(f"{title} Trace")
        ax_trace.grid(True, alpha=0.3)
        if hlines or curves:
            ax_trace.legend(loc="best", fontsize="small")

        # Histogram: scalar references become vertical lines; zones shade the x-axis.
        ax_hist.hist(data, bins=30, alpha=0.7, density=True, color=color)
        for value, line_color, style, label in hlines:
            ax_hist.axvline(
                x=value,
                color=line_color,
                linestyle=style,
                linewidth=2,
                alpha=0.8,
                label=label,
            )
        for low, high, zone_color, _label in zones:
            ax_hist.axvspan(low, high, alpha=0.05, color=zone_color)
        ax_hist.set_xlabel(title)
        ax_hist.set_ylabel("Density")
        ax_hist.set_title(f"{title} Distribution")
        ax_hist.grid(True, alpha=0.3)
        # Only the vertical lines carry labels here; avoid an empty legend.
        if hlines:
            ax_hist.legend(loc="best", fontsize="small")

        stats_fmt = ".2e" if field_name == "potential_energy" else ".3f"
        mean_val = np.mean(data)
        std_val = np.std(data)
        ax_hist.text(
            0.02,
            0.98,
            f"Mean: {mean_val:{stats_fmt}}\nStd: {std_val:{stats_fmt}}",
            transform=ax_hist.transAxes,
            fontsize="small",
            verticalalignment="top",
            bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
        )

    fig.tight_layout()
    fig.savefig(
        f"{outdir}/nuts_extra_fields_diagnostics.png", dpi=150, bbox_inches="tight"
    )
    plt.close(fig)
2✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc