• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

nz-gravity / LogPSplinePSD / 17394129984

02 Sep 2025 05:25AM UTC coverage: 90.419% (+9.4%) from 81.047%
17394129984

push

github

avivajpeyi
run precommits

169 of 180 branches covered (93.89%)

Branch coverage included in aggregate %.

145 of 166 new or added lines in 15 files covered. (87.35%)

62 existing lines in 11 files now uncovered.

1492 of 1657 relevant lines covered (90.04%)

1.8 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.16
/src/log_psplines/benchmark/plotting.py
1
import json
2✔
2
import os
2✔
3
from typing import List
2✔
4

5
import matplotlib.pyplot as plt
2✔
6
import matplotlib.ticker as ticker
2✔
7
import numpy as np
2✔
8
from matplotlib.patches import Patch
2✔
9

10
MH_COLOR = "tab:blue"
NUTS_COLOR = "tab:orange"
# Line-style markers; not referenced by the box-plot helpers in this module
# as seen here — presumably kept for line plots elsewhere (verify).
GPU_MARKER = "o--"
CPU_MARKER = "s-"
CPU_ALPHA = 0.75
GPU_ALPHA = 1.0

# Style keyword groups merged by ``_get_kwgs``: the device (CPU/GPU) selects
# transparency and fill, the sampler (MH/NUTS) selects the colour.
# Alphas are taken from the named constants above so they cannot drift apart.
CPU_KWGS = dict(alpha=CPU_ALPHA, filled=False)
GPU_KWGS = dict(alpha=GPU_ALPHA, filled=True)
MH_KWGS = dict(color=MH_COLOR)
NUTS_KWGS = dict(color=NUTS_COLOR)
21

22

23
def logspace_widths(xs, log_width=0.1):
    """Return box widths that appear uniform on a log10 x-axis.

    Each width spans ``log_width`` decades, centred in log space on the
    corresponding entry of *xs*.

    Args:
        xs: Box positions (array-like); must be positive for ``log10``.
        log_width: Width of every box, measured in decades.

    Returns:
        numpy.ndarray of linear-space widths, one per position.
    """
    positions = np.array(xs)
    log_positions = np.log10(positions)
    half_width = log_width / 2
    upper_edge = 10 ** (log_positions + half_width)
    lower_edge = 10 ** (log_positions - half_width)
    return upper_edge - lower_edge
28

29

30
def plot_box(ax, xs, ys, color="C0", alpha=0.7, filled=True):
    """Draw one box (without fliers) per position on *ax* in a single colour.

    Args:
        ax: Matplotlib axes to draw on.
        xs: Box positions along the x-axis.
        ys: Samples for each box, one entry per position in *xs*.
        color: Colour applied to boxes, medians, whiskers and caps.
        alpha: Transparency applied to every styled element.
        filled: If True, boxes are solid fills with no outline; otherwise
            they are hollow with a thick coloured outline.
    """
    # Default box widths are linear; on a log axis they would look uneven,
    # so use widths spanning a fixed fraction of a decade instead.
    on_log_axis = ax.get_xscale() == "log"
    widths = logspace_widths(xs, log_width=0.1) if on_log_axis else None

    artists = ax.boxplot(
        ys,
        positions=xs,
        widths=widths,
        patch_artist=True,
        showfliers=False,
        label=None,
    )

    for patch in artists["boxes"]:
        if not filled:
            # Hollow box: transparent face, thick coloured edge.
            patch.set_facecolor("none")
            patch.set_edgecolor(color)
            patch.set_alpha(alpha)
            patch.set_linewidth(3)
            continue
        # Solid box: coloured face, no outline.
        patch.set_facecolor(color)
        patch.set_alpha(alpha)
        patch.set_linewidth(0)

    for key in ("medians", "whiskers", "caps"):
        for line in artists[key]:
            line.set_color(color)
            line.set_alpha(alpha)
66

67

68
def plot_ess(*args, **kwargs):
    """Draw box plots via ``plot_box`` and label the y-axis "ESS".

    All arguments are forwarded unchanged to ``plot_box``. The target axes
    may be passed positionally (first argument) or by keyword as ``ax=``.
    """
    # Call plot_box first so a missing axes argument still surfaces as
    # plot_box's own TypeError.
    plot_box(*args, **kwargs)
    # ``args[0]`` alone raised IndexError for ``plot_ess(ax=..., ...)``.
    ax = args[0] if args else kwargs["ax"]
    ax.set_ylabel("ESS")
71

72

73
def plot_runtimes(*args, **kwargs):
    """Draw box plots via ``plot_box`` on a log y-axis labelled as runtime.

    All arguments are forwarded unchanged to ``plot_box``. The target axes
    may be passed positionally (first argument) or by keyword as ``ax=``.
    """
    # Call plot_box first so a missing axes argument still surfaces as
    # plot_box's own TypeError.
    plot_box(*args, **kwargs)
    # ``args[0]`` alone raised IndexError for ``plot_runtimes(ax=..., ...)``.
    ax = args[0] if args else kwargs["ax"]
    ax.set_ylabel("Runtime (seconds)")
    ax.set_yscale("log")
77

78

79
def plot_data_size_results(filepaths: List[str]) -> None:
    """Plot ESS and runtime against data size N for several benchmark runs.

    Each JSON file must provide ``"ns"``, ``"ess"`` and ``"runtimes"`` keys.
    Box styling (sampler colour, device fill) is inferred from the file name
    via ``_get_kwgs``. Missing files are reported and skipped. The figure is
    saved as ``N_vs_runtime.png`` in the directory of the first file that
    was found. If no file is found, nothing is saved.

    Args:
        filepaths: Paths to benchmark result JSON files.
    """
    fig, axes = plt.subplots(2, 1, sharex=True)
    try:
        axes[0].set_xscale("log")

        plotted = []
        for filepath in filepaths:
            if not os.path.exists(filepath):
                print(f"Data file {filepath} not found")
                continue

            with open(filepath, "r") as f:
                data = json.load(f)

            kwgs = _get_kwgs(filepath)
            plot_ess(axes[0], data["ns"], data["ess"], **kwgs)
            plot_runtimes(axes[1], data["ns"], data["runtimes"], **kwgs)
            plotted.append(filepath)

        if not plotted:
            # Previously an empty list crashed on ``filepaths[0]`` and an
            # all-missing list saved an empty figure.
            print("No data files found; skipping N_vs_runtime.png")
            return

        axes[1].set_xlabel(r"$N$")

        axes[1].set_xscale("log")
        axes[1].xaxis.set_major_locator(ticker.LogLocator(base=10.0, numticks=10))
        axes[1].xaxis.set_minor_locator(
            ticker.LogLocator(base=10.0, subs="auto", numticks=10)
        )
        axes[1].xaxis.set_major_formatter(ticker.LogFormatterMathtext())
        axes[1].xaxis.set_minor_formatter(ticker.NullFormatter())

        # Only list runs that were actually drawn, so the legend matches.
        _add_legend(axes[0], [os.path.basename(f) for f in plotted])
        # Remove vertical space between the stacked subplots.
        fig.subplots_adjust(hspace=0.0)

        # os.path.join avoids writing to "/N_vs_runtime.png" (filesystem
        # root) when the data path has no directory component.
        fdir = os.path.dirname(plotted[0])
        fig.savefig(os.path.join(fdir, "N_vs_runtime.png"), dpi=150)
    finally:
        # Close even if loading/plotting raised, so figures do not leak.
        plt.close(fig)
114

115

116
def plot_knots_results(filepaths: List[str]) -> None:
    """Plot ESS and runtime against the number of knots K for several runs.

    Each JSON file must provide ``"ks"``, ``"ess"``, ``"runtimes"``,
    ``"sampler"`` and ``"device"`` keys; the last two select the box style.
    Missing files are reported and skipped. The figure is saved as
    ``K_vs_runtime.png`` in the directory of the first file that was found.
    If no file is found, nothing is saved.

    Args:
        filepaths: Paths to benchmark result JSON files.
    """
    fig, axes = plt.subplots(2, 1, sharex=True)
    try:
        plotted = []
        for filepath in filepaths:
            if not os.path.exists(filepath):
                print(f"Data file {filepath} not found")
                continue

            with open(filepath, "r") as f:
                data = json.load(f)

            kwgs = {
                **(MH_KWGS if data["sampler"] == "mh" else NUTS_KWGS),
                **(CPU_KWGS if data["device"] == "cpu" else GPU_KWGS),
            }
            plot_ess(axes[0], data["ks"], data["ess"], **kwgs)
            plot_runtimes(axes[1], data["ks"], data["runtimes"], **kwgs)
            plotted.append(filepath)

        if not plotted:
            # Previously an empty list crashed on ``filepaths[0]`` and an
            # all-missing list saved an empty figure.
            print("No data files found; skipping K_vs_runtime.png")
            return

        axes[1].set_xlabel(r"$K$")

        # The x-axis keeps its default scale: auto-placed ticks, plain labels.
        axes[1].xaxis.set_major_locator(ticker.AutoLocator())
        axes[1].xaxis.set_major_formatter(ticker.ScalarFormatter())

        # NOTE(review): box styles above come from the JSON "sampler"/"device"
        # fields, while _add_legend infers styles from the file names; if the
        # two disagree the legend will not match the plot — confirm naming.
        _add_legend(axes[0], [os.path.basename(f) for f in plotted])
        # Remove vertical space between the stacked subplots.
        fig.subplots_adjust(hspace=0.0)

        # os.path.join avoids writing to "/K_vs_runtime.png" (filesystem
        # root) when the data path has no directory component.
        fdir = os.path.dirname(plotted[0])
        fig.savefig(os.path.join(fdir, "K_vs_runtime.png"), dpi=150)
    finally:
        # Close even if loading/plotting raised, so figures do not leak.
        plt.close(fig)
148

149

150
def _get_kwgs(fname: str):
    """Return merged plot-style kwargs for a benchmark result file.

    The sampler (MH vs NUTS) selects the colour and the device (CPU vs GPU)
    selects alpha/fill. Both are inferred from the file's base name only.

    Args:
        fname: Path or base name of the result file.

    Returns:
        dict combining one of MH_KWGS/NUTS_KWGS with one of CPU_KWGS/GPU_KWGS.
    """
    # Inspect only the base name: callers pass both full paths and base
    # names, and a directory such as "cpu_runs/" must not change the style.
    name = os.path.basename(fname)
    stem = os.path.splitext(name)[0]
    # Pad with "_" so an "mh" token at the start or end of the stem
    # (e.g. "mh_cpu.json") matches, not only an interior "..._mh_...".
    is_mh = "_mh_" in f"_{stem}_"
    is_cpu = "cpu" in name
    return {
        **(MH_KWGS if is_mh else NUTS_KWGS),
        **(CPU_KWGS if is_cpu else GPU_KWGS),
    }
155

156

157
def _add_legend(ax, fnames: List[str]) -> None:
    """Add a legend to *ax* with one patch per benchmark result file.

    Both the patch style and the "<sampler> (<device>)" label are derived
    from ``_get_kwgs(fname)``, the same helper used to style the boxes.
    Each entry therefore describes the style actually drawn.

    Args:
        ax: Matplotlib axes that receives the legend.
        fnames: Result file names, one legend entry each.
    """
    patches, labels = [], []
    for fname in fnames:
        kwgs = _get_kwgs(fname)
        if kwgs["filled"]:
            p = Patch(color=kwgs["color"], alpha=kwgs["alpha"])
        else:
            # Hollow patch mirrors the unfilled box style.
            p = Patch(
                edgecolor=kwgs["color"], alpha=kwgs["alpha"], facecolor="none"
            )
        patches.append(p)

        # Derive the label from the style dict rather than re-parsing the
        # name. The old ``"mh" in fname`` check disagreed with the ``"_mh_"``
        # rule used for colouring (e.g. "nuts_mhz_gpu.json" was drawn in the
        # NUTS colour but labelled "MH").
        sampler = "MH" if kwgs["color"] == MH_COLOR else "NUTS"
        device = "GPU" if kwgs["filled"] else "CPU"
        labels.append(f"{sampler} ({device})")

    ax.legend(
        handles=patches,
        labels=labels,
        frameon=True,
        fontsize="small",
        loc="upper right",
    )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc