• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

avivajpeyi / pywavelet / 21430410834

28 Jan 2026 08:14AM UTC coverage: 75.545% (+6.5%) from 69.075%
21430410834

push

github

avivajpeyi
fix: update offs calculation in transform_wavelet_freq_helper to prevent overwriting center term

192 of 308 branches covered (62.34%)

Branch coverage included in aggregate %.

3 of 3 new or added lines in 1 file covered. (100.0%)

164 existing lines in 16 files now uncovered.

1229 of 1573 relevant lines covered (78.13%)

0.78 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

70.95
/src/pywavelet/types/plotting.py
1
import warnings
1✔
2
from typing import Optional, Tuple, Union
1✔
3

4
import matplotlib.pyplot as plt
1✔
5
import numpy as np
1✔
6
from matplotlib.colors import LogNorm, TwoSlopeNorm
1✔
7
from scipy.interpolate import interp1d
1✔
8
from scipy.signal import savgol_filter, spectrogram
1✔
9

10
# Sizes of time units expressed in seconds; _fmt_time_axis uses them to
# decide whether the time axis is shown in seconds, minutes, hours or days.
MIN_S = 60
HOUR_S = MIN_S * 60
DAY_S = HOUR_S * 24
13

14

15
def plot_wavelet_trend(
    wavelet_data: np.ndarray,
    time_grid: np.ndarray,
    freq_grid: np.ndarray,
    ax: Optional[plt.Axes] = None,
    freq_scale: str = "linear",
    freq_range: Optional[Tuple[float, float]] = None,
    color: str = "black",
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Plot the frequency of peak |wavelet amplitude| as a function of time.

    For each time bin, the frequency of the largest absolute coefficient is
    located and the resulting curve is smoothed by ``__get_smoothed_y``.

    Parameters
    ----------
    wavelet_data : np.ndarray
        2D array of wavelet coefficients with shape (Nf, Nt).
    time_grid : np.ndarray
        1D array of the Nt time values (treated as seconds when choosing the
        time-axis units).
    freq_grid : np.ndarray
        1D array of the Nf frequency values [Hz].
    ax : plt.Axes, optional
        Axes to draw on. If None, a new figure and axes are created.
    freq_scale : str, optional
        Scale of the frequency (y) axis, e.g. 'linear' or 'log'.
    freq_range : tuple of float, optional
        (min, max) frequency limits. Defaults to (freq_grid[0], freq_grid[-1]).
    color : str, optional
        Colour of the trend line. Default is 'black'.

    Returns
    -------
    Tuple[plt.Figure, plt.Axes]
        The figure and axes containing the trend line (consistent with the
        other plotting helpers in this module).
    """
    x = time_grid
    y = __get_smoothed_y(x, np.abs(wavelet_data), freq_grid)
    if ax is None:
        _, ax = plt.subplots()
    ax.plot(x, y, color=color)

    # Configure axes scales
    ax.set_yscale(freq_scale)
    _fmt_time_axis(time_grid, ax)
    ax.set_ylabel("Frequency [Hz]")

    # Set frequency range if specified
    freq_range = freq_range or (freq_grid[0], freq_grid[-1])
    ax.set_ylim(freq_range)
    return ax.figure, ax
38

39

40
def __get_smoothed_y(x, z, y_grid):
    """
    Return a smoothed curve of the y value at which each column of ``z`` peaks.

    Parameters
    ----------
    x : array-like
        1D array of length Nt giving the abscissa of each column of ``z``;
        used to interpolate over columns that have no data.
    z : array-like
        2D array of shape (Nf, Nt). NaNs are ignored when locating maxima.
    y_grid : array-like
        1D array of length Nf giving the y value of each row of ``z``.

    Returns
    -------
    np.ndarray
        Float array of length Nt. Columns of ``z`` that are entirely NaN are
        NaN in the output. The other values are Savitzky-Golay smoothed
        (cubic polynomial, odd window of at most 51 samples) when the series
        is long enough; shorter series are returned unsmoothed.
    """
    polyorder = 3
    max_window = 51
    x = np.asarray(x)
    z = np.asarray(z)
    y_grid = np.asarray(y_grid)
    _, Nt = z.shape

    # Peak location per column; all-NaN columns have no peak and stay NaN.
    y = np.full(Nt, np.nan)
    has_data = ~np.all(np.isnan(z), axis=0)
    if has_data.any():
        y[has_data] = y_grid[np.nanargmax(z[:, has_data], axis=0)]

    valid = ~np.isnan(y)
    n_valid = int(valid.sum())
    if n_valid == 0:
        return y

    # Fill the gaps so that the smoothing filter sees a contiguous series.
    if n_valid == Nt:
        filled = y.copy()
    elif n_valid == 1:
        # interp1d needs at least two points; a single value is held constant.
        filled = np.full(Nt, y[valid][0])
    else:
        # A cubic spline needs >= 4 points; fall back to linear otherwise.
        kind = "cubic" if n_valid >= 4 else "linear"
        interpolator = interp1d(
            x[valid],
            y[valid],
            kind=kind,
            bounds_error=False,
            fill_value="extrapolate",
        )
        filled = interpolator(x)

    # Savitzky-Golay needs an odd window strictly larger than polyorder.
    window_length = min(max_window, Nt if Nt % 2 == 1 else Nt - 1)
    if window_length > polyorder:
        filled = savgol_filter(filled, window_length, polyorder)

    smoothed = np.asarray(filled, dtype=float)
    smoothed[~valid] = np.nan  # re-open the gaps that were only filled for smoothing
    return smoothed
69

70

71
def plot_wavelet_grid(
    wavelet_data: np.ndarray,
    time_grid: np.ndarray,
    freq_grid: np.ndarray,
    ax: Optional[plt.Axes] = None,
    zscale: str = "linear",
    freq_scale: str = "linear",
    absolute: bool = False,
    freq_range: Optional[Tuple[float, float]] = None,
    show_colorbar: bool = True,
    cmap: Optional[str] = None,
    norm: Optional[Union[LogNorm, TwoSlopeNorm]] = None,
    cbar_label: Optional[str] = None,
    nan_color: Optional[str] = "black",
    detailed_axes: bool = False,
    show_gridinfo: bool = True,
    txtbox_kwargs: Optional[dict] = None,
    trend_color: Optional[str] = None,
    whiten_by: Optional[np.ndarray] = None,
    **kwargs,
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Plot a 2D grid of wavelet coefficients.

    Parameters
    ----------
    wavelet_data : np.ndarray
        A 2D array containing the wavelet coefficients with shape (Nf, Nt),
        where Nf is the number of frequency bins and Nt is the number of time bins.

    time_grid : np.ndarray, optional
        1D array of time values corresponding to the time bins. If None, uses np.arange(Nt).

    freq_grid : np.ndarray, optional
        1D array of frequency values corresponding to the frequency bins. If None, uses np.arange(Nf).

    ax : plt.Axes, optional
        Matplotlib Axes object to plot on. If None, creates a new figure and axes.

    zscale : str, optional
        Scale for the color mapping. Options are 'linear' or 'log'. Default is 'linear'.

    freq_scale : str, optional
        Scale for the frequency axis. Options are 'linear' or 'log'. Default is 'linear'.

    absolute : bool, optional
        If True, plots the absolute value of the wavelet coefficients. Default is False.

    freq_range : tuple of float, optional
        Tuple specifying the (min, max) frequency range to display. If None, displays the full range.

    show_colorbar : bool, optional
        If True, displays a colorbar next to the plot. Default is True.

    cmap : str or matplotlib.colors.Colormap, optional
        Colormap to use for the plot. If None, uses 'viridis' for absolute values or 'bwr' for signed values.
        A copy is always used, so setting the NaN colour never alters a shared colormap.

    norm : matplotlib.colors.Normalize, optional
        Normalization instance to scale data values. If None: a LogNorm over the
        positive finite values when ``zscale='log'``; otherwise a TwoSlopeNorm
        centred on 0 for signed data; otherwise matplotlib's default linear
        scaling. If building the norm fails, a warning is issued and linear
        scaling is used.

    cbar_label : str, optional
        Label for the colorbar. If None, a default label is used based on the `absolute` parameter.

    nan_color : str, optional
        Color to use for NaN values. Default is 'black'. If None, the
        colormap's own "bad" colour is kept.

    detailed_axes : bool, optional
        If True, label the axes with the grid spacing (difference of the first
        two grid entries) and the number of bins.

    show_gridinfo : bool, optional
        If True, show "NfxNt" in a text box in the upper-left corner.

    txtbox_kwargs : dict, optional
        ``bbox`` properties for the text box. Missing keys default to
        boxstyle='round', facecolor='white', alpha=0.2. The dict passed in is
        not modified.

    trend_color : str, optional
        Color to use for the trend line. Not shown if None.

    whiten_by : np.ndarray, optional
        If given, the data are divided by this array before plotting.

    **kwargs
        Additional keyword arguments passed to `ax.imshow()`. A ``label``
        entry is also written in the text box.

    Returns
    -------
    Tuple[plt.Figure, plt.Axes]
        The figure and axes objects of the plot.

    Raises
    ------
    ValueError
        If the dimensions of `wavelet_data` do not match the lengths of `freq_grid` and `time_grid`.
    """
    import copy

    wavelet_data = np.asarray(wavelet_data)
    Nf, Nt = wavelet_data.shape

    # Fall back to bin indices when no grids are supplied (see docstring).
    time_grid = np.arange(Nt) if time_grid is None else np.asarray(time_grid)
    freq_grid = np.arange(Nf) if freq_grid is None else np.asarray(freq_grid)

    # Validate the dimensions
    if (Nf, Nt) != (len(freq_grid), len(time_grid)):
        raise ValueError(
            f"Wavelet shape {Nf, Nt} does not match provided grids {(len(freq_grid), len(time_grid))}."
        )

    # Prepare the data for plotting
    z = wavelet_data.copy()
    if whiten_by is not None:
        z = z / whiten_by
    if absolute:
        z = np.abs(z)

    # Determine normalization (best effort: fall back to linear on failure)
    if norm is None:
        try:
            if np.all(np.isnan(z)):
                raise ValueError("All wavelet data is NaN.")
            if zscale == "log":
                # LogNorm needs strictly positive, finite limits.
                vmin = np.nanmin(z[z > 0])
                vmax = np.nanmax(z[z < np.inf])
                if vmin > vmax:
                    raise ValueError("vmin > vmax... something wrong")
                norm = LogNorm(vmin=vmin, vmax=vmax)
            elif not absolute:
                # Diverging norm so that zero maps to the colormap centre.
                vmin, vmax = np.nanmin(z), np.nanmax(z)
                if vmin > vmax:
                    raise ValueError("vmin > vmax... something wrong")
                norm = TwoSlopeNorm(vmin=vmin, vcenter=0.0, vmax=vmax)
        except Exception as e:
            warnings.warn(
                f"Error in determining normalization: {e}. Using default linear scaling.",
                stacklevel=2,
            )
            norm = None

    # Resolve the colormap; copy it so set_bad never mutates a shared instance.
    if cmap is None:
        cmap = "viridis" if absolute else "bwr"
    cmap = copy.copy(plt.get_cmap(cmap))
    if nan_color is not None:
        cmap.set_bad(color=nan_color)

    # Set up the plot
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()

    # Plot the data
    im = ax.imshow(
        z,
        aspect="auto",
        extent=[time_grid[0], time_grid[-1], freq_grid[0], freq_grid[-1]],
        origin="lower",
        cmap=cmap,
        norm=norm,
        interpolation="nearest",
        **kwargs,
    )
    if trend_color is not None:
        plot_wavelet_trend(
            wavelet_data,
            time_grid,
            freq_grid,
            ax,
            color=trend_color,
            freq_range=freq_range,
            freq_scale=freq_scale,
        )

    # Add colorbar if requested
    if show_colorbar:
        cbar = fig.colorbar(im, ax=ax)
        if cbar_label is None:
            cbar_label = (
                "Absolute Wavelet Amplitude"
                if absolute
                else "Wavelet Amplitude"
            )
        cbar.set_label(cbar_label)

    # Configure axes scales
    ax.set_yscale(freq_scale)
    _fmt_time_axis(time_grid, ax)
    ax.set_ylabel("Frequency [Hz]")

    # Set frequency range if specified
    freq_range = freq_range or (freq_grid[0], freq_grid[-1])
    ax.set_ylim(freq_range)

    if detailed_axes:
        # Report the actual grid spacing rather than 1/N (only valid for unit duration).
        delta_t = time_grid[1] - time_grid[0] if Nt > 1 else np.nan
        delta_f = freq_grid[1] - freq_grid[0] if Nf > 1 else np.nan
        ax.set_xlabel(r"Time Bins [$\Delta T$=" + f"{delta_t:.4f}s, Nt={Nt}]")
        ax.set_ylabel(r"Freq Bins [$\Delta F$=" + f"{delta_f:.4f}Hz, Nf={Nf}]")

    label = kwargs.get("label", "")
    NfNt_label = f"{Nf}x{Nt}" if show_gridinfo else ""
    txt = f"{label}\n{NfNt_label}" if label else NfNt_label
    if txt:
        # Copy so neither the caller's dict nor a shared default is mutated.
        bbox = dict(txtbox_kwargs or {})
        bbox.setdefault("boxstyle", "round")
        bbox.setdefault("facecolor", "white")
        bbox.setdefault("alpha", 0.2)
        ax.text(
            0.05,
            0.95,
            txt,
            transform=ax.transAxes,
            fontsize=14,
            verticalalignment="top",
            bbox=bbox,
        )

    # Adjust layout
    fig.tight_layout()

    return fig, ax
274

275

276
def plot_freqseries(
    data: np.ndarray,
    freq: np.ndarray,
    nyquist_frequency: float,
    ax=None,
    **kwargs,
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Plot a frequency series on linear axes from 0 up to the Nyquist frequency.

    Parameters
    ----------
    data : np.ndarray
        Values to plot, one per entry of ``freq``.
    freq : np.ndarray
        Frequency values [Hz].
    nyquist_frequency : float
        Upper x-axis limit [Hz].
    ax : plt.Axes, optional
        Axes to draw on. If None, a new figure and axes are created.
    **kwargs
        Passed through to ``ax.plot``.

    Returns
    -------
    Tuple[plt.Figure, plt.Axes]
        The figure and axes of the plot.
    """
    if ax is None:
        _, ax = plt.subplots()
    ax.plot(freq, data, **kwargs)
    ax.set_xlabel("Frequency Bin [Hz]")
    ax.set_ylabel("Amplitude")
    ax.set_xlim(0, nyquist_frequency)
    return ax.figure, ax
290

291

292
def plot_periodogram(
    data: np.ndarray,
    freq: np.ndarray,
    nyquist_frequency: float,
    ax=None,
    **kwargs,
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Plot |data|**2 against frequency on log-log axes.

    Parameters
    ----------
    data : np.ndarray
        Frequency-domain values (possibly complex); their squared magnitude
        is plotted.
    freq : np.ndarray
        Frequency values [Hz].
    nyquist_frequency : float
        Upper x-axis limit [Hz], matching ``plot_freqseries``.
    ax : plt.Axes, optional
        Axes to draw on. If None, a new figure and axes are created.
    **kwargs
        Passed through to ``ax.loglog``.

    Returns
    -------
    Tuple[plt.Figure, plt.Axes]
        The figure and axes of the plot.
    """
    if ax is None:
        _, ax = plt.subplots()

    ax.loglog(freq, np.abs(data) ** 2, **kwargs)

    # A log axis cannot start at 0: use the smallest positive |frequency|,
    # or leave the left limit unchanged (None) when there is none.
    abs_freq = np.abs(np.asarray(freq))
    positive = abs_freq[abs_freq > 0]
    flow = positive.min() if positive.size else None

    ax.set_xlabel("Frequency [Hz]")
    ax.set_ylabel("Periodogram")
    ax.set_xlim(left=flow, right=nyquist_frequency)
    return ax.figure, ax
308

309

310
def plot_timeseries(
    data: np.ndarray, time: np.ndarray, ax=None, **kwargs
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Plot a time series against its sample times.

    Parameters
    ----------
    data : np.ndarray
        Amplitude values, one per entry of ``time``.
    time : np.ndarray
        Sample times (treated as seconds by ``_fmt_time_axis`` when choosing
        the displayed time unit).
    ax : plt.Axes, optional
        Axes to draw on. If None, a new figure and axes are created.
    **kwargs
        Passed through to ``ax.plot``.

    Returns
    -------
    Tuple[plt.Figure, plt.Axes]
        The figure and axes of the plot.
    """
    if ax is None:
        _, ax = plt.subplots()
    ax.plot(time, data, **kwargs)

    ax.set_ylabel("Amplitude")
    ax.set_xlim(left=time[0], right=time[-1])

    _fmt_time_axis(time, ax)

    return ax.figure, ax
324

325

326
def plot_spectrogram(
    timeseries_data: np.ndarray,
    fs: float,
    ax=None,
    spec_kwargs: Optional[dict] = None,
    plot_kwargs: Optional[dict] = None,
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Compute a spectrogram with ``scipy.signal.spectrogram`` and plot it.

    Parameters
    ----------
    timeseries_data : np.ndarray
        Time-domain samples.
    fs : float
        Sampling frequency [Hz]; the y axis is capped at ``fs / 2``.
    ax : plt.Axes, optional
        Axes to draw on. If None, a new figure and axes are created.
    spec_kwargs : dict, optional
        Extra keyword arguments for ``scipy.signal.spectrogram``.
    plot_kwargs : dict, optional
        Extra keyword arguments for ``ax.pcolormesh``. ``cmap`` defaults to
        'Reds'. The dict passed in is not modified.

    Returns
    -------
    Tuple[plt.Figure, plt.Axes]
        The figure and axes of the plot.
    """
    # Copy so neither shared defaults nor the caller's dicts are mutated.
    spec_kwargs = dict(spec_kwargs or {})
    plot_kwargs = dict(plot_kwargs or {})
    plot_kwargs.setdefault("cmap", "Reds")

    f, t, Sxx = spectrogram(timeseries_data, fs=fs, **spec_kwargs)
    if ax is None:
        _, ax = plt.subplots()

    mesh = ax.pcolormesh(t, f, Sxx, shading="nearest", **plot_kwargs)

    _fmt_time_axis(t, ax)

    ax.set_ylabel("Frequency [Hz]")
    ax.set_ylim(top=fs / 2.0)
    # Attach the colorbar to the axes' own figure, not the "current" figure.
    cbar = ax.figure.colorbar(mesh, ax=ax)
    cbar.set_label("Spectrogram Amplitude")
    return ax.figure, ax
349

350

351
def _fmt_time_axis(t, axes, t0=None, tmax=None):
    """
    Label the x (time) axis and set its limits, choosing units from ``t[-1]``.

    ``t`` is treated as seconds. If its last value exceeds a day, an hour or
    a minute (tested in that order), the tick labels are rescaled to that
    unit with one decimal place. Otherwise the axis is labelled in seconds
    and the existing tick formatter is left alone.

    Parameters
    ----------
    t : array-like
        Time values; only ``t[0]`` and ``t[-1]`` are read.
    axes : plt.Axes
        Axes whose x axis is formatted.
    t0, tmax : float, optional
        Explicit x limits; they default to ``t[0]`` and ``t[-1]``.
    """
    last_time = t[-1]
    unit_table = (
        (DAY_S, "Time [days]"),
        (HOUR_S, "Time [hr]"),
        (MIN_S, "Time [min]"),
    )
    for unit_seconds, unit_label in unit_table:
        if last_time > unit_seconds:
            # Bind the scale as a default argument to avoid late binding.
            axes.xaxis.set_major_formatter(
                plt.FuncFormatter(
                    lambda x, _, scale=unit_seconds: f"{x / scale:.1f}"
                )
            )
            axes.set_xlabel(unit_label)
            break
    else:
        axes.set_xlabel("Time [s]")

    left = t0 if t0 is not None else t[0]
    right = tmax if tmax is not None else t[-1]
    axes.set_xlim(left, right)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc