• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

nz-gravity / LogPSplinePSD / 16407223521

21 Jul 2025 02:31AM UTC coverage: 80.997% (+3.1%) from 77.939%
16407223521

push

github

avivajpeyi
fix: add benchmarking fix for cli

96 of 123 branches covered (78.05%)

Branch coverage included in aggregate %.

143 of 151 new or added lines in 4 files covered. (94.7%)

7 existing lines in 2 files now uncovered.

1106 of 1361 relevant lines covered (81.26%)

1.63 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.16
/src/log_psplines/benchmark/plotting.py
1
import matplotlib.pyplot as plt
2✔
2
import os
2✔
3
import json
2✔
4
from typing import List
5
import numpy as np
2✔
6
import matplotlib.ticker as ticker
2✔
7
from matplotlib.patches import Patch
2✔
8

2✔
9
# Colour assigned to each sampler.
MH_COLOR = "tab:blue"
NUTS_COLOR = "tab:orange"

# Per-device marker format strings and opacities.
# NOTE(review): the *_MARKER constants are not referenced by the code
# visible in this file -- confirm whether they are still needed.
CPU_MARKER = "s-"
GPU_MARKER = "o--"
CPU_ALPHA = 0.75
GPU_ALPHA = 1.0

# Keyword bundles that are merged and unpacked into the box-plot helpers:
# the device picks opacity and fill, the sampler picks the colour.
CPU_KWGS = {"alpha": CPU_ALPHA, "filled": False}
GPU_KWGS = {"alpha": GPU_ALPHA, "filled": True}
MH_KWGS = {"color": MH_COLOR}
NUTS_KWGS = {"color": NUTS_COLOR}
2✔
20

2✔
21

22
def logspace_widths(xs, log_width=0.1):
    """Return one width per position, each spanning ``log_width`` decades.

    For every position ``x`` the result is
    ``10**(log10(x) + log_width/2) - 10**(log10(x) - log_width/2)``,
    i.e. the linear distance between the two points lying half of
    ``log_width`` below and above ``x`` in log10 space.

    Args:
        xs: Positions (anything ``np.array`` accepts).
        log_width: Full extent of each interval, in log10 units.

    Returns:
        The widths, with the same shape as ``np.array(xs)``.
    """
    log_positions = np.log10(np.array(xs))
    half_span = log_width / 2
    upper_edges = 10 ** (log_positions + half_span)
    lower_edges = 10 ** (log_positions - half_span)
    return upper_edges - lower_edges
2✔
25

2✔
26

27
def plot_box(ax, xs, ys, color='C0', alpha=0.7, filled=True):
    """Draw a single-colour box plot on ``ax`` with boxes placed at ``xs``.

    Outliers are hidden. Boxes, medians, whiskers and caps all take
    ``color`` and ``alpha``.

    Args:
        ax: Axes-like object; ``get_xscale`` and ``boxplot`` are called on it.
        xs: Box positions, forwarded to ``boxplot`` as ``positions``.
        ys: Data forwarded to ``boxplot`` as its first argument
            (presumably one sample set per entry of ``xs`` -- confirm
            against the callers' JSON schema).
        color: Colour applied to every artist of the plot.
        alpha: Opacity applied to every artist of the plot.
        filled: If true, boxes are solid with no outline; otherwise they
            are hollow with a thick outline.
    """
    if ax.get_xscale() == "log":
        # Widths must grow with x on a log axis, so compute them per position.
        widths = logspace_widths(xs, log_width=0.1)
    else:
        widths = None  # fall back to boxplot's own default widths

    # ``label`` is intentionally not passed: ``None`` is already the default,
    # and the keyword only exists on Matplotlib >= 3.9, so passing it raised
    # TypeError on older releases.
    bp = ax.boxplot(
        ys,
        positions=xs,
        widths=widths,
        patch_artist=True,
        showfliers=False,
    )

    for box in bp['boxes']:
        if filled:
            box.set_facecolor(color)
            box.set_linewidth(0)
        else:
            box.set_facecolor('none')
            box.set_edgecolor(color)
            box.set_linewidth(3)
        box.set_alpha(alpha)

    # Medians, whiskers and caps are plain lines and share the same styling.
    for element in ('medians', 'whiskers', 'caps'):
        for line in bp[element]:
            line.set_color(color)
            line.set_alpha(alpha)
2✔
63

2✔
64

2✔
65
def plot_ess(ax, *args, **kwargs):
    """Box-plot values on ``ax`` and label its y axis "ESS".

    ``ax`` is named explicitly so that it can be supplied positionally or
    as ``ax=...``; previously a keyword ``ax`` drew the plot and then
    failed with IndexError when the axis label was set.

    Args:
        ax: Axes to draw on.
        *args: Remaining positional arguments, forwarded to ``plot_box``.
        **kwargs: Keyword arguments, forwarded to ``plot_box``.
    """
    plot_box(ax, *args, **kwargs)
    ax.set_ylabel("ESS")
68

2✔
69

2✔
70
def plot_runtimes(ax, *args, **kwargs):
    """Box-plot values on ``ax`` with a log y axis labelled in seconds.

    ``ax`` is named explicitly so that it can be supplied positionally or
    as ``ax=...``; previously a keyword ``ax`` drew the plot and then
    failed with IndexError when the axis label was set.

    Args:
        ax: Axes to draw on.
        *args: Remaining positional arguments, forwarded to ``plot_box``.
        **kwargs: Keyword arguments, forwarded to ``plot_box``.
    """
    plot_box(ax, *args, **kwargs)
    ax.set_ylabel("Runtime (seconds)")
    ax.set_yscale("log")
2✔
74

2✔
75

2✔
76
def plot_data_size_results(filepaths: List[str]) -> None:
    """Plot data size analysis results.

    Each existing JSON file contributes one box-plot series per panel:
    its ``"ess"`` values on the top panel and its ``"runtimes"`` values on
    the bottom panel, both positioned at its ``"ns"`` values on a shared
    log x axis. Missing files are reported on stdout and skipped. The
    figure is written as ``N_vs_runtime.png`` next to the first path.

    Args:
        filepaths: Paths of the JSON result files. Must be non-empty.

    Raises:
        IndexError: If ``filepaths`` is empty.
    """
    # Resolve the output path first: an empty ``filepaths`` then fails before
    # a figure exists, and os.path.join keeps a bare filename (empty dirname)
    # in the working directory instead of producing "/N_vs_runtime.png".
    out_path = os.path.join(os.path.dirname(filepaths[0]), "N_vs_runtime.png")

    fig, axes = plt.subplots(2, 1, sharex=True)
    try:
        axes[0].set_xscale("log")

        plotted = []  # basenames of the files that were actually drawn
        for filepath in filepaths:
            if not os.path.exists(filepath):
                print(f"Data file {filepath} not found")
                continue

            with open(filepath, "r") as f:
                data = json.load(f)

            kwgs = _get_kwgs(filepath)
            plot_ess(axes[0], data["ns"], data["ess"], **kwgs)
            plot_runtimes(axes[1], data["ns"], data["runtimes"], **kwgs)
            plotted.append(os.path.basename(filepath))

        axes[1].set_xlabel(r"$N$")

        # Re-install log locators/formatters on the x axis after plotting.
        axes[1].set_xscale("log")
        axes[1].xaxis.set_major_locator(
            ticker.LogLocator(base=10.0, numticks=10)
        )
        axes[1].xaxis.set_minor_locator(
            ticker.LogLocator(base=10.0, subs='auto', numticks=10)
        )
        axes[1].xaxis.set_major_formatter(ticker.LogFormatterMathtext())
        axes[1].xaxis.set_minor_formatter(ticker.NullFormatter())

        # Only files that were drawn get a legend entry.
        _add_legend(axes[0], plotted)

        # remove vertical space between subplots
        fig.subplots_adjust(hspace=0.)
        fig.savefig(out_path, dpi=150)
    finally:
        # Release the figure even when loading or plotting fails.
        plt.close(fig)
2✔
109

2✔
110

2✔
111
def plot_knots_results(filepaths: List[str]) -> None:
    """Plot knots analysis results.

    Each existing JSON file contributes one box-plot series per panel:
    its ``"ess"`` values on the top panel and its ``"runtimes"`` values on
    the bottom panel, both positioned at its ``"ks"`` values. The series
    style is chosen from the file's ``"sampler"`` and ``"device"`` fields.
    Missing files are reported on stdout and skipped. The figure is
    written as ``K_vs_runtime.png`` next to the first path.

    Args:
        filepaths: Paths of the JSON result files. Must be non-empty.

    Raises:
        IndexError: If ``filepaths`` is empty.
    """
    # Resolve the output path first: an empty ``filepaths`` then fails before
    # a figure exists, and os.path.join keeps a bare filename (empty dirname)
    # in the working directory instead of producing "/K_vs_runtime.png".
    out_path = os.path.join(os.path.dirname(filepaths[0]), "K_vs_runtime.png")

    fig, axes = plt.subplots(2, 1, sharex=True)
    try:
        plotted = []  # basenames of the files that were actually drawn
        for filepath in filepaths:
            if not os.path.exists(filepath):
                print(f"Data file {filepath} not found")
                continue

            with open(filepath, "r") as f:
                data = json.load(f)

            # NOTE(review): the style here comes from the JSON fields, while
            # the legend derives it from the file name -- confirm that the
            # two always agree for the benchmark outputs.
            kwgs = {
                **(MH_KWGS if data["sampler"] == "mh" else NUTS_KWGS),
                **(CPU_KWGS if data["device"] == "cpu" else GPU_KWGS)
            }
            plot_ess(axes[0], data["ks"], data["ess"], **kwgs)
            plot_runtimes(axes[1], data["ks"], data["runtimes"], **kwgs)
            plotted.append(os.path.basename(filepath))

        axes[1].set_xlabel(r"$K$")

        # use autoformatter for x ticks
        axes[1].xaxis.set_major_locator(ticker.AutoLocator())
        axes[1].xaxis.set_major_formatter(ticker.ScalarFormatter())

        # Only files that were drawn get a legend entry.
        _add_legend(axes[0], plotted)

        # remove vertical space between subplots
        fig.subplots_adjust(hspace=0.)
        fig.savefig(out_path, dpi=150)
    finally:
        # Release the figure even when loading or plotting fails.
        plt.close(fig)
2✔
143

2✔
144

2✔
145
def _get_kwgs(fname: str):
    """Build the box-plot style keywords implied by a result file name.

    The sampler part is ``MH_KWGS`` when ``fname`` contains ``"_mh_"`` and
    ``NUTS_KWGS`` otherwise; the device part is ``CPU_KWGS`` when ``fname``
    contains ``"cpu"`` and ``GPU_KWGS`` otherwise. A new dict is returned,
    so the module-level bundles are never mutated.
    """
    sampler_style = MH_KWGS if "_mh_" in fname else NUTS_KWGS
    device_style = CPU_KWGS if "cpu" in fname else GPU_KWGS

    # Device entries are applied last, matching the original merge order.
    style = dict(sampler_style)
    style.update(device_style)
    return style
150

2✔
151

2✔
152
def _add_legend(ax, fnames: List[str]) -> None:
    """Add legend to the axes.

    One swatch is created per file name, styled with the keywords that
    ``_get_kwgs`` resolves for that name: solid for filled styles, hollow
    with a coloured edge otherwise. Each label reads ``"<sampler> (<device>)"``.

    Args:
        ax: Axes that receives the legend.
        fnames: Result file names, one legend entry each.
    """
    patches, labels = [], []
    for fname in fnames:
        kwgs = _get_kwgs(fname)
        if kwgs['filled']:
            p = Patch(
                color=kwgs['color'],
                alpha=kwgs['alpha']
            )
        else:
            p = Patch(
                edgecolor=kwgs['color'],
                alpha=kwgs['alpha'],
                facecolor='none'
            )
        patches.append(p)

        # Derive the label from the resolved style rather than re-parsing the
        # name: the old check used "mh" while the colour lookup uses "_mh_",
        # so a swatch could be drawn in one sampler's colour and labelled as
        # the other.
        sampler = "MH" if kwgs['color'] == MH_COLOR else "NUTS"
        device = "GPU" if kwgs['filled'] else "CPU"
        labels.append(f"{sampler} ({device})")

    ax.legend(
        handles=patches,
        labels=labels,
        frameon=True,
        fontsize='small',
        loc='upper right',
    )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc