• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

nz-gravity / LogPSplinePSD / 16308981650

16 Jul 2025 02:36AM UTC coverage: 74.293% (-1.4%) from 75.644%
16308981650

push

github

avivajpeyi
refactoring to use a common parent class

76 of 108 branches covered (70.37%)

Branch coverage included in aggregate %.

246 of 280 new or added lines in 7 files covered. (87.86%)

6 existing lines in 1 file now uncovered.

843 of 1129 relevant lines covered (74.67%)

1.49 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

57.53
/src/log_psplines/plotting/diagnostics.py
1
import arviz as az
2✔
2
import matplotlib.pyplot as plt
2✔
3
import numpy as np
2✔
4

5

6
def plot_diagnostics(
    idata: az.InferenceData,
    outdir: str,
    variables: "list | None" = None,
    figsize: tuple = (12, 8),
) -> None:
    """
    Plot MCMC diagnostics using arviz and save them as PNG files.

    Always writes ``trace_plots.png`` and prints an ``az.summary`` table for
    ``variables``. Depending on which statistics are present in
    ``idata.sample_stats`` it additionally writes:

    * ``acceptance_rate.png``      if ``acceptance_rate`` is recorded,
    * ``step_size_evolution.png``  if ``step_size_mean`` is recorded,
    * ``nuts_diagnostics.png``     if ``tree_depth`` is recorded.

    Every figure is closed after saving, so repeated calls do not
    accumulate open pyplot figures.

    Parameters
    ----------
    idata : az.InferenceData
        Inference data from adaptive MCMC (or NUTS).
    outdir : str
        Directory the PNG files are written into. It is not created here.
    variables : list, optional
        Variables to plot and summarise. Defaults to
        ``["phi", "delta", "weights"]``.
    figsize : tuple
        Figure size for the trace plots and the NUTS diagnostics grid.
    """
    # None sentinel instead of a shared mutable list default.
    if variables is None:
        variables = ["phi", "delta", "weights"]

    # Trace plots: az.plot_trace draws into a new figure, which becomes
    # pyplot's current figure.
    az.plot_trace(idata, var_names=variables, figsize=figsize)
    trace_fig = plt.gcf()
    trace_fig.suptitle("Trace plots - Adaptive MCMC")
    _save_and_close(trace_fig, f"{outdir}/trace_plots.png")

    # Summary statistics
    print("Summary Statistics:")
    print(az.summary(idata, var_names=variables))

    stats = idata.sample_stats

    # Acceptance rate plot
    if "acceptance_rate" in stats:
        _plot_acceptance_rate(idata, outdir)

    # Step size evolution
    if "step_size_mean" in stats:
        _plot_step_size(stats, outdir)

    # NUTS-specific plots
    if "tree_depth" in stats:
        _plot_nuts_diagnostics(stats, outdir, figsize)


def _save_and_close(fig, path: str) -> None:
    """Lay out *fig*, write it to *path*, and release it from pyplot."""
    fig.tight_layout()
    fig.savefig(path)
    # Without this every call leaves its figures open in pyplot's registry.
    plt.close(fig)


def _plot_acceptance_rate(idata, outdir: str) -> None:
    """Plot the recorded acceptance rate against the target rate."""
    fig, ax = plt.subplots(figsize=(10, 4))
    # NOTE(review): presumably shaped (chain, draw); flatten() concatenates
    # chains end to end on a single "Iteration" axis.
    accept_rates = idata.sample_stats.acceptance_rate.values.flatten()
    ax.plot(accept_rates, alpha=0.7)
    ax.axhline(
        idata.attrs.get("target_accept_rate", 0.44),
        color="red",
        linestyle="--",
        label="Target",
    )
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Acceptance Rate")
    ax.set_title("Acceptance Rate Over Time")
    ax.legend()
    ax.grid(True, alpha=0.3)
    _save_and_close(fig, f"{outdir}/acceptance_rate.png")


def _plot_step_size(stats, outdir: str) -> None:
    """Plot the step-size mean and, when recorded, its standard deviation."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    step_means = stats.step_size_mean.values.flatten()
    ax1.plot(step_means, alpha=0.7)
    ax1.set_xlabel("Iteration")
    ax1.set_ylabel("Mean Step Size")
    ax1.set_title("Step Size Evolution")
    ax1.grid(True, alpha=0.3)

    # step_size_std is not guaranteed to accompany step_size_mean.
    if "step_size_std" in stats:
        step_stds = stats.step_size_std.values.flatten()
        ax2.plot(step_stds, alpha=0.7, color="orange")
        ax2.set_xlabel("Iteration")
        ax2.set_ylabel("Step Size Std")
        ax2.set_title("Step Size Variability")
        ax2.grid(True, alpha=0.3)
    else:
        ax2.set_visible(False)

    _save_and_close(fig, f"{outdir}/step_size_evolution.png")


def _plot_nuts_diagnostics(stats, outdir: str, figsize: tuple) -> None:
    """Draw a 2x2 grid of NUTS diagnostics from whichever stats are present."""
    fig, axes = plt.subplots(2, 2, figsize=figsize)

    # Tree depths are integers: integer bin edges 0..max+1 give each depth
    # its own bar.
    tree_depths = stats.tree_depth.values.flatten()
    axes[0, 0].hist(
        tree_depths, bins=range(int(tree_depths.max()) + 2), alpha=0.7
    )
    axes[0, 0].set_xlabel("Tree Depth")
    axes[0, 0].set_ylabel("Count")
    axes[0, 0].set_title("Distribution of Tree Depths")

    if "num_steps" in stats:
        num_steps = stats.num_steps.values.flatten()
        axes[0, 1].hist(num_steps, bins=30, alpha=0.7)
        axes[0, 1].set_xlabel("Number of Steps")
        axes[0, 1].set_ylabel("Count")
        axes[0, 1].set_title("Distribution of Leapfrog Steps")
    else:
        axes[0, 1].axis("off")

    if "energy" in stats:
        energy = stats.energy.values.flatten()
        axes[1, 0].plot(energy, alpha=0.7)
        axes[1, 0].set_xlabel("Iteration")
        axes[1, 0].set_ylabel("Energy")
        axes[1, 0].set_title("Energy Over Time")
    else:
        axes[1, 0].axis("off")

    if "diverging" in stats:
        diverging = stats.diverging.values.flatten()
        n_divergent = np.sum(diverging)
        total_samples = len(diverging)
        # Guard against an empty diagnostic array (division by zero).
        divergent_rate = n_divergent / total_samples if total_samples else 0.0

        axes[1, 1].bar(
            ["Non-divergent", "Divergent"],
            [total_samples - n_divergent, n_divergent],
        )
        axes[1, 1].set_ylabel("Count")
        axes[1, 1].set_title(
            f"Divergent Transitions ({divergent_rate:.1%})"
        )
    else:
        axes[1, 1].axis("off")

    _save_and_close(fig, f"{outdir}/nuts_diagnostics.png")
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc