• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

nz-gravity / LogPSplinePSD / 18303071607

07 Oct 2025 05:32AM UTC coverage: 80.332% (-0.6%) from 80.952%
18303071607

push

github

web-flow
Merge pull request #10 from nz-gravity/save_vi_diagnostics_before_sampling

save VI plots at the start

576 of 694 branches covered (83.0%)

Branch coverage included in aggregate %.

63 of 80 new or added lines in 7 files covered. (78.75%)

498 existing lines in 11 files now uncovered.

3925 of 4909 relevant lines covered (79.96%)

1.6 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

22.03
/src/log_psplines/benchmark/plotting.py
1
import json
2✔
2
import os
2✔
3
from typing import List
2✔
4

5
import matplotlib.pyplot as plt
2✔
6
import matplotlib.ticker as ticker
2✔
7
import numpy as np
2✔
8
from matplotlib.patches import Patch
2✔
9

10
from ..logger import logger
2✔
11

12
# Colours distinguishing the two samplers in the benchmark plots.
MH_COLOR = "tab:blue"
NUTS_COLOR = "tab:orange"
# Matplotlib format strings per device.
# NOTE(review): not referenced in this file; possibly used by importers —
# confirm before removing.
GPU_MARKER = "o--"
CPU_MARKER = "s-"
# Transparency per device: CPU results are drawn slightly faded.
CPU_ALPHA = 0.75
GPU_ALPHA = 1.0

# Style keyword bundles forwarded to ``plot_box``. The device bundle controls
# alpha/fill and the sampler bundle controls colour; a complete style is
# ``{**sampler_kwgs, **device_kwgs}``. The alphas reference the constants
# above so the two definitions cannot drift apart.
CPU_KWGS = dict(alpha=CPU_ALPHA, filled=False)
GPU_KWGS = dict(alpha=GPU_ALPHA, filled=True)
MH_KWGS = dict(color=MH_COLOR)
NUTS_KWGS = dict(color=NUTS_COLOR)
2✔
23

24

25
def logspace_widths(xs, log_width=0.1):
    """Return the linear-space width of intervals that are ``log_width`` wide in log10.

    For each centre ``x`` the result is
    ``10**(log10(x) + log_width/2) - 10**(log10(x) - log_width/2)``.

    Parameters
    ----------
    xs : array-like
        Interval centres. These should be positive, because ``log10`` is taken.
    log_width : float
        Width of each interval in log10 units.

    Returns
    -------
    numpy.ndarray
        One width per entry of ``xs``.
    """
    log_centres = np.log10(np.array(xs))
    half_width = log_width / 2
    upper_edges = 10 ** (log_centres + half_width)
    lower_edges = 10 ** (log_centres - half_width)
    return upper_edges - lower_edges
30

31

32
def plot_box(ax, xs, ys, color="C0", alpha=0.7, filled=True):
    """Draw one box-and-whisker per x position, styled with a single colour.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on. Its current x-scale decides the box widths.
    xs : sequence
        Box positions along the x-axis.
    ys : sequence
        Samples for each box, passed straight to ``ax.boxplot``.
    color : str
        Colour applied to boxes, medians, whiskers and caps.
    alpha : float
        Transparency applied to every styled artist.
    filled : bool
        If True, boxes are solid with zero line width. Otherwise they are
        hollow outlines of line width 3.
    """
    # On a log x-axis, size every box to span 0.1 decade; otherwise let
    # matplotlib choose its default widths.
    box_widths = (
        logspace_widths(xs, log_width=0.1) if ax.get_xscale() == "log" else None
    )

    artists = ax.boxplot(
        ys,
        positions=xs,
        widths=box_widths,
        patch_artist=True,
        showfliers=False,
        label=None,
    )

    for patch in artists["boxes"]:
        if not filled:
            patch.set_facecolor("none")
            patch.set_edgecolor(color)
            patch.set_alpha(alpha)
            patch.set_linewidth(3)
            continue
        patch.set_facecolor(color)
        patch.set_alpha(alpha)
        patch.set_linewidth(0)

    # Medians, whiskers and caps all share the same colour and transparency.
    for group in ("medians", "whiskers", "caps"):
        for artist in artists[group]:
            artist.set_color(color)
            artist.set_alpha(alpha)
×
68

69

70
def plot_ess(ax, *args, **kwargs):
    """Draw ESS box plots on ``ax`` via ``plot_box`` and label the y-axis "ESS".

    ``ax`` is a named parameter rather than being read from ``args[0]``. That
    lets the axes also be passed by keyword (``plot_ess(ax=axes[0], ...)``),
    which previously raised IndexError. All remaining positional and keyword
    arguments are forwarded unchanged to ``plot_box``.
    """
    plot_box(ax, *args, **kwargs)
    ax.set_ylabel("ESS")
×
73

74

75
def plot_runtimes(ax, *args, **kwargs):
    """Draw runtime box plots on ``ax`` via ``plot_box``, with a log-scaled y-axis.

    ``ax`` is a named parameter rather than being read from ``args[0]``. That
    lets the axes also be passed by keyword (``plot_runtimes(ax=axes[1], ...)``),
    which previously raised IndexError. All remaining positional and keyword
    arguments are forwarded unchanged to ``plot_box``.
    """
    plot_box(ax, *args, **kwargs)
    ax.set_ylabel("Runtime (seconds)")
    ax.set_yscale("log")
×
79

80

81
def plot_data_size_results(filepaths: List[str]) -> None:
    """Plot ESS and runtime against data size ``N`` and save the figure.

    Each JSON results file is read for its ``ns``, ``ess`` and ``runtimes``
    keys. Its plot style (sampler colour, device fill) is inferred from the
    file's base name via ``_get_kwgs``. Missing files are logged and skipped.

    The figure is saved as ``N_vs_runtime.png`` in the directory of
    ``filepaths[0]``, or in the current directory when that path has no
    directory component. Nothing is saved if no file could be plotted.

    Parameters
    ----------
    filepaths : list of str
        Paths to JSON results files.
    """
    if not filepaths:
        logger.info("No data-size result files given; skipping N-vs-runtime plot")
        return

    fig, axes = plt.subplots(2, 1, sharex=True)
    try:
        # Set the log scale before plotting: plot_box sizes boxes from
        # ax.get_xscale().
        axes[0].set_xscale("log")

        plotted = []
        for filepath in filepaths:
            if not os.path.exists(filepath):
                logger.info(f"Data file {filepath} not found")
                continue

            with open(filepath, "r", encoding="utf-8") as f:
                data = json.load(f)

            # Style from the base name only, so directory names cannot change
            # it. This matches the base names that _add_legend labels below.
            kwgs = _get_kwgs(os.path.basename(filepath))
            plot_ess(axes[0], data["ns"], data["ess"], **kwgs)
            plot_runtimes(axes[1], data["ns"], data["runtimes"], **kwgs)
            plotted.append(filepath)

        if not plotted:
            logger.info("None of the data-size result files were found; skipping plot")
            return

        axes[1].set_xlabel(r"$N$")

        axes[1].set_xscale("log")
        axes[1].xaxis.set_major_locator(ticker.LogLocator(base=10.0, numticks=10))
        axes[1].xaxis.set_minor_locator(
            ticker.LogLocator(base=10.0, subs="auto", numticks=10)
        )
        axes[1].xaxis.set_major_formatter(ticker.LogFormatterMathtext())
        axes[1].xaxis.set_minor_formatter(ticker.NullFormatter())

        # The legend lists only the files that were actually plotted.
        _add_legend(axes[0], [os.path.basename(f) for f in plotted])
        # Remove vertical space between the stacked panels.
        fig.subplots_adjust(hspace=0.0)

        # os.path.join avoids writing to "/N_vs_runtime.png" when the first
        # path has no directory component (dirname returns "").
        fdir = os.path.dirname(filepaths[0])
        fig.savefig(os.path.join(fdir, "N_vs_runtime.png"), dpi=150)
    finally:
        # Always release this figure, even if loading or plotting raised.
        plt.close(fig)
×
116

117

118
def plot_knots_results(filepaths: List[str]) -> None:
    """Plot ESS and runtime against the number of knots ``K`` and save the figure.

    Each JSON results file is read for its ``ks``, ``ess``, ``runtimes``,
    ``sampler`` and ``device`` keys:
    - ``sampler == "mh"`` selects the MH colour, and anything else selects NUTS.
    - ``device == "cpu"`` selects the CPU style, and anything else selects GPU.

    Missing files are logged and skipped.

    The figure is saved as ``K_vs_runtime.png`` in the directory of
    ``filepaths[0]``, or in the current directory when that path has no
    directory component. Nothing is saved if no file could be plotted.

    Parameters
    ----------
    filepaths : list of str
        Paths to JSON results files.
    """
    if not filepaths:
        logger.info("No knots result files given; skipping K-vs-runtime plot")
        return

    fig, axes = plt.subplots(2, 1, sharex=True)
    try:
        plotted = []
        for filepath in filepaths:
            if not os.path.exists(filepath):
                logger.info(f"Data file {filepath} not found")
                continue

            with open(filepath, "r", encoding="utf-8") as f:
                data = json.load(f)

            # NOTE(review): the style comes from the JSON fields, but
            # _add_legend labels entries from the file *name*. If a file's
            # name and contents disagree, the legend will not match the boxes.
            kwgs = {
                **(MH_KWGS if data["sampler"] == "mh" else NUTS_KWGS),
                **(CPU_KWGS if data["device"] == "cpu" else GPU_KWGS),
            }
            plot_ess(axes[0], data["ks"], data["ess"], **kwgs)
            plot_runtimes(axes[1], data["ks"], data["runtimes"], **kwgs)
            plotted.append(filepath)

        if not plotted:
            logger.info("None of the knots result files were found; skipping plot")
            return

        axes[1].set_xlabel(r"$K$")

        # No x-scale is set here, so use matplotlib's default linear ticks.
        axes[1].xaxis.set_major_locator(ticker.AutoLocator())
        axes[1].xaxis.set_major_formatter(ticker.ScalarFormatter())

        # The legend lists only the files that were actually plotted.
        _add_legend(axes[0], [os.path.basename(f) for f in plotted])
        # Remove vertical space between the stacked panels.
        fig.subplots_adjust(hspace=0.0)

        # os.path.join avoids writing to "/K_vs_runtime.png" when the first
        # path has no directory component (dirname returns "").
        fdir = os.path.dirname(filepaths[0])
        fig.savefig(os.path.join(fdir, "K_vs_runtime.png"), dpi=150)
    finally:
        # Always release this figure, even if loading or plotting raised.
        plt.close(fig)
×
150

151

152
def _get_kwgs(fname: str):
    """Infer ``plot_box`` style keywords from a results file name.

    Only the base name is inspected, so directory components cannot affect
    the result. Matching is case-insensitive.

    - Sampler: MH when ``mh`` appears as a separate word, i.e. bounded by
      non-letters or the ends of the name (``mh_cpu.json``,
      ``run-mh-gpu.json``, ``x_mh_y.json``). Otherwise NUTS.
    - Device: CPU when ``cpu`` appears anywhere in the name. Otherwise GPU.

    Returns
    -------
    dict
        The sampler colour keywords merged with the device alpha/fill keywords.
    """
    name = os.path.basename(fname).lower()
    # Turn every non-letter into "_" and pad both ends. A substring test for
    # "_mh_" then matches "mh" as a whole word anywhere, including at the
    # start or end of the name; a bare "_mh_" test would miss "mh_cpu.json".
    words = "_" + "".join(ch if ch.isalpha() else "_" for ch in name) + "_"
    is_mh = "_mh_" in words
    is_cpu = "cpu" in name
    return {
        **(MH_KWGS if is_mh else NUTS_KWGS),
        **(CPU_KWGS if is_cpu else GPU_KWGS),
    }
157

158

159
def _add_legend(ax, fnames: List[str]) -> None:
    """Add a legend to ``ax`` with one entry per results file.

    Each entry's patch style and text label both come from
    ``_get_kwgs(fname)``, so the legend describes the style actually used for
    that file's boxes. Filled patches mean GPU and outlines mean CPU; the
    colour is MH or NUTS. If ``fnames`` is empty, no legend is added.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes that receives the legend.
    fnames : list of str
        Results file names (base names or paths).
    """
    if not fnames:
        return

    patches, labels = [], []
    for fname in fnames:
        kwgs = _get_kwgs(fname)
        if kwgs["filled"]:
            patch = Patch(color=kwgs["color"], alpha=kwgs["alpha"])
        else:
            patch = Patch(
                edgecolor=kwgs["color"], alpha=kwgs["alpha"], facecolor="none"
            )
        patches.append(patch)

        # Derive the label from the same style dict as the patch. Inferring
        # it separately from the file name could make label and colour
        # disagree.
        sampler = "MH" if kwgs["color"] == MH_COLOR else "NUTS"
        device = "GPU" if kwgs["filled"] else "CPU"
        labels.append(f"{sampler} ({device})")

    ax.legend(
        handles=patches,
        labels=labels,
        frameon=True,
        fontsize="small",
        loc="upper right",
    )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc