• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

nz-gravity / LogPSplinePSD / 18209569030

03 Oct 2025 12:42AM UTC coverage: 80.878% (+1.6%) from 79.246%
18209569030

push

github

avivajpeyi
fix logger

549 of 654 branches covered (83.94%)

Branch coverage included in aggregate %.

2 of 2 new or added lines in 1 file covered. (100.0%)

469 existing lines in 17 files now uncovered.

3799 of 4722 relevant lines covered (80.45%)

1.61 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.22
/src/log_psplines/benchmark/plotting.py
1
import json
2✔
2
import os
2✔
3
from typing import List
2✔
4

5
import matplotlib.pyplot as plt
2✔
6
import matplotlib.ticker as ticker
2✔
7
import numpy as np
2✔
8
from matplotlib.patches import Patch
2✔
9

10
from ..logger import logger
2✔
11

12
# Sampler colours used for boxes, medians, whiskers and caps.
MH_COLOR = "tab:blue"
NUTS_COLOR = "tab:orange"
# Matplotlib line-format markers per device.
# NOTE(review): not referenced by the plotting functions in this module;
# they may be used elsewhere — confirm before removing.
GPU_MARKER = "o--"
CPU_MARKER = "s-"
# Transparency per device.
CPU_ALPHA = 0.75
GPU_ALPHA = 1.0

# Keyword sets passed through to plot_box. Device selects alpha/fill and
# sampler selects colour. They are built from the constants above so the
# alpha values cannot drift out of sync with CPU_ALPHA/GPU_ALPHA.
CPU_KWGS = dict(alpha=CPU_ALPHA, filled=False)
GPU_KWGS = dict(alpha=GPU_ALPHA, filled=True)
MH_KWGS = dict(color=MH_COLOR)
NUTS_KWGS = dict(color=NUTS_COLOR)
23

24

25
def logspace_widths(xs, log_width=0.1):
    """Return linear box widths that each span ``log_width`` decades.

    For every position ``x`` the width is the distance between
    ``10**(log10(x) - log_width/2)`` and ``10**(log10(x) + log_width/2)``.
    Every box therefore covers the same extent on a log-scaled axis.

    Args:
        xs: Box centre positions (array-like, expected to be positive).
        log_width: Width of each box in decades (base-10 log units).

    Returns:
        numpy array of widths, one per entry of ``xs``.
    """
    log_centres = np.log10(np.array(xs))
    half_span = log_width / 2
    upper_edges = 10 ** (log_centres + half_span)
    lower_edges = 10 ** (log_centres - half_span)
    return upper_edges - lower_edges
30

31

32
def plot_box(ax, xs, ys, color="C0", alpha=0.7, filled=True):
    """Draw one box per position in ``xs`` on ``ax`` and colour it uniformly.

    Args:
        ax: Axes to draw on. Its current x-scale decides the box widths: a
            fixed 0.1-decade width on a log axis, matplotlib's default
            otherwise. Set the scale before calling.
        xs: Box positions along the x-axis.
        ys: Data forwarded to ``ax.boxplot`` (presumably one sequence of
            samples per entry of ``xs`` — TODO confirm with callers).
        color: Colour applied to boxes, medians, whiskers and caps.
        alpha: Transparency applied to the same artists.
        filled: If true, boxes are solid with no outline. Otherwise they are
            hollow with a thick coloured outline.
    """
    on_log_axis = ax.get_xscale() == "log"
    box_widths = logspace_widths(xs, log_width=0.1) if on_log_axis else None

    artists = ax.boxplot(
        ys,
        positions=xs,
        widths=box_widths,
        patch_artist=True,
        showfliers=False,  # outlier markers are not drawn
        label=None,
    )

    for patch in artists["boxes"]:
        if not filled:
            # Hollow box: transparent face, thick coloured outline.
            patch.set_facecolor("none")
            patch.set_edgecolor(color)
            patch.set_alpha(alpha)
            patch.set_linewidth(3)
            continue
        # Solid box: coloured face, no outline.
        patch.set_facecolor(color)
        patch.set_alpha(alpha)
        patch.set_linewidth(0)

    # Medians, whiskers and caps share the box colour and transparency.
    for line in (*artists["medians"], *artists["whiskers"], *artists["caps"]):
        line.set_color(color)
        line.set_alpha(alpha)
68

69

70
def plot_ess(ax, *args, **kwargs):
    """Draw ESS box plots on ``ax`` and label its y-axis "ESS".

    ``ax`` and all remaining arguments are forwarded unchanged to
    ``plot_box``. ``ax`` is named explicitly rather than read back as
    ``args[0]``. That way it may also be passed by keyword, as ``plot_box``
    itself allows, and the y-label still goes on the correct axes.
    """
    plot_box(ax, *args, **kwargs)
    ax.set_ylabel("ESS")
73

74

75
def plot_runtimes(ax, *args, **kwargs):
    """Draw runtime box plots on ``ax`` with a log-scaled y-axis.

    ``ax`` and all remaining arguments are forwarded unchanged to
    ``plot_box``. ``ax`` is named explicitly rather than read back as
    ``args[0]``. That way it may also be passed by keyword, as ``plot_box``
    itself allows, and the label/scale still go on the correct axes.
    """
    plot_box(ax, *args, **kwargs)
    ax.set_ylabel("Runtime (seconds)")
    ax.set_yscale("log")
79

80

81
def plot_data_size_results(filepaths: List[str]) -> None:
    """Plot ESS and runtime against data size ``N`` for several benchmark runs.

    Each existing JSON file in ``filepaths`` must provide ``"ns"``, ``"ess"``
    and ``"runtimes"`` entries. The sampler/device style of each run is
    inferred from its base file name via ``_get_kwgs``. The two stacked
    panels are saved as ``N_vs_runtime.png`` in the directory of the first
    path.

    Args:
        filepaths: Paths to benchmark JSON files. Missing files are logged
            and skipped. If no files are given or none exist, nothing is
            saved.
    """
    if not filepaths:
        logger.info("No data files given; skipping N_vs_runtime plot")
        return

    fig, axes = plt.subplots(2, 1, sharex=True)
    # plot_box picks box widths from the current x-scale, so the log scale
    # must be set before any boxes are drawn.
    axes[0].set_xscale("log")

    plotted = []
    for filepath in filepaths:
        if not os.path.exists(filepath):
            logger.info(f"Data file {filepath} not found")
            continue

        with open(filepath, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Classify by base name only (as the legend does), so a directory
        # name containing e.g. "cpu" cannot change a run's style.
        fname = os.path.basename(filepath)
        kwgs = _get_kwgs(fname)
        plot_ess(axes[0], data["ns"], data["ess"], **kwgs)
        plot_runtimes(axes[1], data["ns"], data["runtimes"], **kwgs)
        plotted.append(fname)

    if not plotted:
        logger.info("None of the data files exist; skipping N_vs_runtime plot")
        plt.close(fig)
        return

    axes[1].set_xlabel(r"$N$")

    # Log-scaled N axis: major ticks at powers of ten, unlabelled minor ticks.
    axes[1].set_xscale("log")
    axes[1].xaxis.set_major_locator(ticker.LogLocator(base=10.0, numticks=10))
    axes[1].xaxis.set_minor_locator(
        ticker.LogLocator(base=10.0, subs="auto", numticks=10)
    )
    axes[1].xaxis.set_major_formatter(ticker.LogFormatterMathtext())
    axes[1].xaxis.set_minor_formatter(ticker.NullFormatter())

    # The legend lists only the runs that were actually drawn.
    _add_legend(axes[0], plotted)
    # Remove the vertical space between the stacked subplots.
    fig.subplots_adjust(hspace=0.0)

    fdir = os.path.dirname(filepaths[0])
    # os.path.join: an f"{fdir}/..." path would target "/" when fdir == "".
    fig.savefig(os.path.join(fdir, "N_vs_runtime.png"), dpi=150)
    plt.close(fig)
116

117

118
def plot_knots_results(filepaths: List[str]) -> None:
    """Plot ESS and runtime against the number of knots ``K`` for several runs.

    Each existing JSON file must provide ``"ks"``, ``"ess"``, ``"runtimes"``,
    ``"sampler"`` and ``"device"`` entries. A run's colour comes from its
    ``"sampler"`` field (``"mh"`` gives the MH colour, anything else NUTS).
    Its fill/alpha comes from its ``"device"`` field (``"cpu"`` gives the CPU
    style, anything else GPU). The two stacked panels are saved as
    ``K_vs_runtime.png`` in the directory of the first path.

    Args:
        filepaths: Paths to benchmark JSON files. Missing files are logged
            and skipped. If no files are given or none exist, nothing is
            saved.
    """
    if not filepaths:
        logger.info("No data files given; skipping K_vs_runtime plot")
        return

    fig, axes = plt.subplots(2, 1, sharex=True)

    plotted = []
    for filepath in filepaths:
        if not os.path.exists(filepath):
            logger.info(f"Data file {filepath} not found")
            continue

        with open(filepath, "r", encoding="utf-8") as f:
            data = json.load(f)

        kwgs = {
            **(MH_KWGS if data["sampler"] == "mh" else NUTS_KWGS),
            **(CPU_KWGS if data["device"] == "cpu" else GPU_KWGS),
        }
        plot_ess(axes[0], data["ks"], data["ess"], **kwgs)
        plot_runtimes(axes[1], data["ks"], data["runtimes"], **kwgs)
        plotted.append(os.path.basename(filepath))

    if not plotted:
        logger.info("None of the data files exist; skipping K_vs_runtime plot")
        plt.close(fig)
        return

    axes[1].set_xlabel(r"$K$")

    # Linear K axis with automatic tick placement and plain number labels.
    axes[1].xaxis.set_major_locator(ticker.AutoLocator())
    axes[1].xaxis.set_major_formatter(ticker.ScalarFormatter())

    # NOTE(review): _add_legend infers styles from the file names, whereas
    # the boxes above use the JSON "sampler"/"device" fields. The legend is
    # only correct when the file names agree with those fields.
    _add_legend(axes[0], plotted)
    # Remove the vertical space between the stacked subplots.
    fig.subplots_adjust(hspace=0.0)

    fdir = os.path.dirname(filepaths[0])
    # os.path.join: an f"{fdir}/..." path would target "/" when fdir == "".
    fig.savefig(os.path.join(fdir, "K_vs_runtime.png"), dpi=150)
    plt.close(fig)
150

151

152
def _get_kwgs(fname: str):
    """Return plot_box style keywords inferred from a benchmark file name.

    The colour is MH when ``"_mh_"`` occurs in ``fname`` and NUTS otherwise.
    The alpha/fill style is CPU when ``"cpu"`` occurs and GPU otherwise.
    A new dict is built on every call, so callers may mutate it freely.
    """
    sampler_style = MH_KWGS if "_mh_" in fname else NUTS_KWGS
    device_style = CPU_KWGS if "cpu" in fname else GPU_KWGS
    style = dict(sampler_style)
    style.update(device_style)
    return style
157

158

159
def _add_legend(ax, fnames: List[str]) -> None:
    """Add a sampler/device legend to ``ax``, one entry per file name.

    Both the patch style and the label of each entry come from the same
    ``_get_kwgs(fname)`` result, so the legend matches how that file's boxes
    are styled by ``_get_kwgs``. Filled patches stand for GPU runs and
    outlined patches for CPU runs.

    Args:
        ax: Axes to attach the legend to (placed in the upper right).
        fnames: File names to describe, in legend order.
    """
    patches, labels = [], []
    for fname in fnames:
        kwgs = _get_kwgs(fname)
        if kwgs["filled"]:
            patch = Patch(color=kwgs["color"], alpha=kwgs["alpha"])
        else:
            patch = Patch(
                edgecolor=kwgs["color"], alpha=kwgs["alpha"], facecolor="none"
            )
        patches.append(patch)

        # Derive the label from the style dict rather than re-testing the
        # name. A bare "mh" substring also matches names that _get_kwgs
        # (which requires "_mh_") styles as NUTS, e.g. "mh_cpu.json".
        sampler = "MH" if kwgs["color"] == MH_COLOR else "NUTS"
        device = "GPU" if kwgs["filled"] else "CPU"
        labels.append(f"{sampler} ({device})")

    ax.legend(
        handles=patches,
        labels=labels,
        frameon=True,
        fontsize="small",
        loc="upper right",
    )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc