• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

avivajpeyi / pywavelet / 13961110627

20 Mar 2025 03:06AM UTC coverage: 69.205% (-0.8%) from 70.039%
13961110627

push

github

avivajpeyi
increase the py version

150 of 244 branches covered (61.48%)

Branch coverage included in aggregate %.

868 of 1227 relevant lines covered (70.74%)

0.71 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

70.53
/src/pywavelet/types/plotting.py
1
import warnings
1✔
2
from typing import Optional, Tuple, Union
1✔
3

4
import matplotlib.pyplot as plt
1✔
5
import numpy as np
1✔
6
from matplotlib.colors import LogNorm, TwoSlopeNorm
1✔
7
from scipy.interpolate import interp1d
1✔
8
from scipy.signal import savgol_filter, spectrogram
1✔
9

10
# Seconds per time unit; _fmt_time_axis compares against these to choose
# the x-axis unit (min / hr / days).
MIN_S = 60  # seconds in one minute
HOUR_S = MIN_S * 60  # seconds in one hour
DAY_S = HOUR_S * 24  # seconds in one day
13

14

15
def plot_wavelet_trend(
    wavelet_data: np.ndarray,
    time_grid: np.ndarray,
    freq_grid: np.ndarray,
    ax: Optional[plt.Axes] = None,
    freq_scale: str = "linear",
    freq_range: Optional[Tuple[float, float]] = None,
    color: str = "black",
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Draw the smoothed peak-amplitude frequency track of a wavelet grid.

    For each time bin the frequency with the largest ``|wavelet_data|`` is
    selected, the resulting curve is smoothed, and it is plotted as a line.

    Parameters
    ----------
    wavelet_data : np.ndarray
        2D array of wavelet coefficients with shape (Nf, Nt).
    time_grid : np.ndarray
        1D array of Nt time values used as the x coordinates.
    freq_grid : np.ndarray
        1D array of Nf frequency values used as the y coordinates.
    ax : plt.Axes, optional
        Axes to draw on. If None, a new figure and axes are created.
    freq_scale : str, optional
        Scale of the frequency axis, passed to ``ax.set_yscale``. Default 'linear'.
    freq_range : tuple of float, optional
        (min, max) frequency limits. If None, uses the first and last
        entries of ``freq_grid``.
    color : str, optional
        Line color. Default 'black'.

    Returns
    -------
    Tuple[plt.Figure, plt.Axes]
        The figure and axes containing the trend line.
    """
    # Compute the curve before touching any figure so a failure does not
    # leave an empty figure behind.
    y = __get_smoothed_y(time_grid, np.abs(wavelet_data), freq_grid)

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()

    ax.plot(time_grid, y, color=color)

    # Configure axes scales
    ax.set_yscale(freq_scale)
    _fmt_time_axis(time_grid, ax)
    ax.set_ylabel("Frequency [Hz]")

    # `is None` rather than `or`: an ndarray range would make `or` ambiguous.
    if freq_range is None:
        freq_range = (freq_grid[0], freq_grid[-1])
    ax.set_ylim(freq_range)

    return fig, ax
38

39

40
def __get_smoothed_y(x, z, y_grid):
    """
    Return a smoothed track of the ``y_grid`` value at the peak of each column of ``z``.

    Parameters
    ----------
    x : np.ndarray
        1D array of length Nt giving the column coordinates. It is used to
        interpolate across columns of ``z`` that are entirely NaN.
    z : np.ndarray
        2D array of shape (Nf, Nt). The argmax is taken down each column,
        ignoring NaNs.
    y_grid : array-like
        Length-Nf sequence mapping row indices of ``z`` to y values.

    Returns
    -------
    np.ndarray
        Float array of length Nt. Columns of ``z`` that are entirely NaN give
        NaN. When there are enough columns for a cubic Savitzky-Golay window,
        the remaining values are smoothed with it; otherwise they are
        returned unsmoothed.
    """
    polyorder = 3
    max_window = 51

    _, n_cols = z.shape
    has_data = ~np.all(np.isnan(z), axis=0)

    y = np.full(n_cols, np.nan)
    if not has_data.any():
        return y

    # nanargmax picks the first maximal row of each column, ignoring NaNs.
    y[has_data] = np.asarray(y_grid)[np.nanargmax(z[:, has_data], axis=0)]

    n_valid = int(has_data.sum())
    if n_valid < n_cols:
        # Temporarily fill the gaps so the filter sees a continuous curve.
        # The gaps are reset to NaN after smoothing.
        if n_valid == 1:
            y = np.full(n_cols, y[has_data][0])
        else:
            # A cubic interp1d spline needs at least 4 points; use linear below that.
            kind = "cubic" if n_valid >= 4 else "linear"
            interpolator = interp1d(
                x[has_data],
                y[has_data],
                kind=kind,
                bounds_error=False,
                fill_value="extrapolate",
            )
            y = interpolator(x)

    # Largest odd window that fits in n_cols, capped at max_window.
    window_length = min(max_window, n_cols - 1 if n_cols % 2 == 0 else n_cols)
    # savgol_filter requires polyorder < window_length. Skip smoothing for very
    # short series instead of raising.
    if window_length > polyorder:
        y = savgol_filter(y, window_length, polyorder)

    y[~has_data] = np.nan
    return y
69

70

71
def plot_wavelet_grid(
    wavelet_data: np.ndarray,
    time_grid: np.ndarray,
    freq_grid: np.ndarray,
    ax: Optional[plt.Axes] = None,
    zscale: str = "linear",
    freq_scale: str = "linear",
    absolute: bool = False,
    freq_range: Optional[Tuple[float, float]] = None,
    show_colorbar: bool = True,
    cmap: Optional[str] = None,
    norm: Optional[Union[LogNorm, TwoSlopeNorm]] = None,
    cbar_label: Optional[str] = None,
    nan_color: Optional[str] = "black",
    detailed_axes: bool = False,
    show_gridinfo: bool = True,
    trend_color: Optional[str] = None,
    whiten_by: Optional[np.ndarray] = None,
    **kwargs,
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Plot a 2D grid of wavelet coefficients.

    Parameters
    ----------
    wavelet_data : np.ndarray
        A 2D array containing the wavelet coefficients with shape (Nf, Nt),
        where Nf is the number of frequency bins and Nt is the number of time bins.

    time_grid : np.ndarray or None
        1D array of Nt time values [s]. If None, uses np.arange(Nt).

    freq_grid : np.ndarray or None
        1D array of Nf frequency values [Hz]. If None, uses np.arange(Nf).

    ax : plt.Axes, optional
        Matplotlib Axes object to plot on. If None, creates a new figure and axes.

    zscale : str, optional
        Colour scaling used when ``norm`` is None. 'log' gives a LogNorm spanning
        the positive finite data. Otherwise signed data get a zero-centred
        TwoSlopeNorm, and ``absolute=True`` data use matplotlib's default
        linear scaling. Default is 'linear'.

    freq_scale : str, optional
        Scale for the frequency axis, passed to ``ax.set_yscale``. Default is 'linear'.

    absolute : bool, optional
        If True, plots the absolute value of the wavelet coefficients. Default is False.

    freq_range : tuple of float, optional
        (min, max) frequency range to display. If None, uses the first and last
        entries of ``freq_grid``.

    show_colorbar : bool, optional
        If True, displays a colorbar next to the plot. Default is True.

    cmap : str or matplotlib.colors.Colormap, optional
        Colormap to use. If None, uses 'viridis' for absolute values or 'bwr' for
        signed values. A copy is used, so registered or caller-owned colormaps
        are never modified.

    norm : matplotlib.colors.Normalize, optional
        Normalization instance. If None, one is chosen from ``zscale`` and
        ``absolute``. If that fails, a warning is issued and linear scaling is used.

    cbar_label : str, optional
        Colorbar label. If None, 'Absolute Wavelet Amplitude' or
        'Wavelet Amplitude' is used depending on ``absolute``.

    nan_color : str, optional
        Color used for NaN values, applied to whichever colormap is used. If None,
        the colormap's own "bad" color is kept. Default is 'black'.

    detailed_axes : bool, optional
        If True, replaces the axis labels with bin-count labels. Default is False.

    show_gridinfo : bool, optional
        If True, writes "NfxNt" in the top-left corner. Default is True.

    trend_color : str, optional
        Color of the peak-frequency trend line. Not shown if None.

    whiten_by : np.ndarray, optional
        If given, the plotted data are ``wavelet_data / whiten_by`` (numpy
        broadcasting applies). The trend line uses the un-whitened data.

    **kwargs
        Additional keyword arguments passed to `ax.imshow()`. A ``label`` entry
        is also shown in the top-left text box.

    Returns
    -------
    Tuple[plt.Figure, plt.Axes]
        The figure and axes objects of the plot.

    Raises
    ------
    ValueError
        If the dimensions of `wavelet_data` do not match the lengths of `freq_grid` and `time_grid`.
    """
    Nf, Nt = wavelet_data.shape

    # Index-based grids when none are supplied, as documented above.
    if time_grid is None:
        time_grid = np.arange(Nt)
    if freq_grid is None:
        freq_grid = np.arange(Nf)

    if (Nf, Nt) != (len(freq_grid), len(time_grid)):
        raise ValueError(
            f"Wavelet shape {Nf, Nt} does not match provided grids {(len(freq_grid), len(time_grid))}."
        )

    # Work on a copy so the caller's array is never modified.
    z = wavelet_data.copy()
    if whiten_by is not None:
        z = z / whiten_by
    if absolute:
        z = np.abs(z)

    if norm is None:
        norm = _default_wavelet_norm(z, zscale, absolute)

    if cmap is None:
        cmap = "viridis" if absolute else "bwr"
    # Copy so set_bad never mutates a registered or caller-owned colormap.
    cmap = plt.get_cmap(cmap).copy()
    if nan_color is not None:
        cmap.set_bad(color=nan_color)

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()

    im = ax.imshow(
        z,
        aspect="auto",
        extent=[time_grid[0], time_grid[-1], freq_grid[0], freq_grid[-1]],
        origin="lower",
        cmap=cmap,
        norm=norm,
        interpolation="nearest",
        **kwargs,
    )
    if trend_color is not None:
        plot_wavelet_trend(
            wavelet_data,
            time_grid,
            freq_grid,
            ax,
            color=trend_color,
            freq_range=freq_range,
            freq_scale=freq_scale,
        )

    if show_colorbar:
        cbar = fig.colorbar(im, ax=ax)
        if cbar_label is None:
            cbar_label = (
                "Absolute Wavelet Amplitude" if absolute else "Wavelet Amplitude"
            )
        cbar.set_label(cbar_label)

    # Configure axes scales
    ax.set_yscale(freq_scale)
    _fmt_time_axis(time_grid, ax)
    ax.set_ylabel("Frequency [Hz]")

    # `is None` rather than `or`: an ndarray range would make `or` ambiguous.
    if freq_range is None:
        freq_range = (freq_grid[0], freq_grid[-1])
    ax.set_ylim(freq_range)

    if detailed_axes:
        # NOTE(review): the spacings shown are 1/Nt and 1/Nf, not values derived
        # from time_grid / freq_grid. Confirm the intended units.
        ax.set_xlabel(r"Time Bins [$\Delta T$=" + f"{1 / Nt:.4f}s, Nt={Nt}]")
        ax.set_ylabel(r"Freq Bins [$\Delta F$=" + f"{1 / Nf:.4f}Hz, Nf={Nf}]")

    label = kwargs.get("label", "")
    NfNt_label = f"{Nf}x{Nt}" if show_gridinfo else ""
    txt = f"{label}\n{NfNt_label}" if label else NfNt_label
    if txt:
        ax.text(
            0.05,
            0.95,
            txt,
            transform=ax.transAxes,
            fontsize=14,
            verticalalignment="top",
            bbox=dict(boxstyle="round", facecolor=None, alpha=0.2),
        )

    fig.tight_layout()

    return fig, ax


def _default_wavelet_norm(
    z: np.ndarray, zscale: str, absolute: bool
) -> Optional[Union[LogNorm, TwoSlopeNorm]]:
    """
    Choose a colour normalisation for ``z``.

    This is best-effort. Any failure is reported with a warning, and None
    (matplotlib's default linear scaling) is returned.
    """
    try:
        if np.all(np.isnan(z)):
            raise ValueError("All wavelet data is NaN.")
        if zscale == "log":
            vmin = np.nanmin(z[z > 0])
            vmax = np.nanmax(z[z < np.inf])
            if vmin > vmax:
                raise ValueError("vmin > vmax... something wrong")
            return LogNorm(vmin=vmin, vmax=vmax)
        if not absolute:
            vmin, vmax = np.nanmin(z), np.nanmax(z)
            if vmin > vmax:
                raise ValueError("vmin > vmax... something wrong")
            # TwoSlopeNorm raises ValueError unless vmin < 0 < vmax (e.g. for
            # all-positive signed data). That case ends in the warning below.
            return TwoSlopeNorm(vmin=vmin, vcenter=0.0, vmax=vmax)
        return None  # Default linear scaling
    except Exception as e:
        warnings.warn(
            f"Error in determining normalization: {e}. Using default linear scaling."
        )
        return None
270

271

272
def plot_freqseries(
    data: np.ndarray,
    freq: np.ndarray,
    nyquist_frequency: float,
    ax=None,
    **kwargs,
):
    """
    Plot a frequency series as a line from 0 Hz up to the Nyquist frequency.

    Parameters
    ----------
    data : np.ndarray
        Values to plot (y-axis).
    freq : np.ndarray
        Frequencies matching ``data`` (x-axis).
    nyquist_frequency : float
        Upper x-limit of the plot.
    ax : plt.Axes, optional
        Axes to draw on. If None, a new figure and axes are created.
    **kwargs
        Passed through to ``ax.plot``.

    Returns
    -------
    Tuple[plt.Figure, plt.Axes]
        The figure owning ``ax`` and ``ax`` itself.
    """
    if ax is None:
        _, ax = plt.subplots()
    ax.plot(freq, data, **kwargs)
    ax.set_xlabel("Frequency Bin [Hz]")
    ax.set_ylabel("Amplitude")
    ax.set_xlim(0, nyquist_frequency)
    return ax.figure, ax
286

287

288
def plot_periodogram(
    data: np.ndarray,
    freq: np.ndarray,
    nyquist_frequency: float,
    ax=None,
    **kwargs,
):
    """
    Plot the periodogram ``|data|**2`` against frequency on log-log axes.

    The x-range runs from the lowest strictly positive ``|freq|`` up to
    ``nyquist_frequency``, the same upper limit ``plot_freqseries`` uses.

    Parameters
    ----------
    data : np.ndarray
        Frequency-domain values. Their squared magnitude is plotted.
    freq : np.ndarray
        Frequencies matching ``data``.
    nyquist_frequency : float
        Upper x-limit of the plot.
    ax : plt.Axes, optional
        Axes to draw on. If None, a new figure and axes are created.
    **kwargs
        Passed through to ``ax.loglog``.

    Returns
    -------
    Tuple[plt.Figure, plt.Axes]
        The figure owning ``ax`` and ``ax`` itself.
    """
    if ax is None:
        _, ax = plt.subplots()

    ax.loglog(freq, np.abs(data) ** 2, **kwargs)

    # A log axis cannot start at 0 Hz: matplotlib warns and ignores a
    # non-positive limit. Start at the smallest positive |freq| instead,
    # or leave the limit automatic if there is none.
    abs_freq = np.abs(np.asarray(freq))
    positive = abs_freq[abs_freq > 0]
    flow = positive.min() if positive.size else None

    ax.set_xlabel("Frequency [Hz]")
    ax.set_ylabel("Periodogram")
    ax.set_xlim(left=flow, right=nyquist_frequency)
    return ax.figure, ax
304

305

306
def plot_timeseries(
    data: np.ndarray, time: np.ndarray, ax=None, **kwargs
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Plot a time series with a unit-scaled time axis.

    Parameters
    ----------
    data : np.ndarray
        Amplitude values (y-axis).
    time : np.ndarray
        Time values in seconds matching ``data``. ``_fmt_time_axis`` picks the
        displayed unit (s/min/hr/days) from the last entry.
    ax : plt.Axes, optional
        Axes to draw on. If None, a new figure and axes are created.
    **kwargs
        Passed through to ``ax.plot``.

    Returns
    -------
    Tuple[plt.Figure, plt.Axes]
        The figure owning ``ax`` and ``ax`` itself.
    """
    if ax is None:
        _, ax = plt.subplots()
    ax.plot(time, data, **kwargs)

    ax.set_ylabel("Amplitude")
    ax.set_xlim(left=time[0], right=time[-1])

    _fmt_time_axis(time, ax)

    return ax.figure, ax
320

321

322
def plot_spectrogram(
    timeseries_data: np.ndarray,
    fs: float,
    ax=None,
    spec_kwargs=None,
    plot_kwargs=None,
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Compute and plot a spectrogram of a time series.

    Parameters
    ----------
    timeseries_data : np.ndarray
        Input samples, passed to ``scipy.signal.spectrogram``.
    fs : float
        Sampling frequency [Hz]. The y-axis is capped at ``fs / 2``.
    ax : plt.Axes, optional
        Axes to draw on. If None, a new figure and axes are created.
    spec_kwargs : dict, optional
        Extra keyword arguments for ``scipy.signal.spectrogram``.
    plot_kwargs : dict, optional
        Extra keyword arguments for ``ax.pcolormesh``. ``cmap`` defaults to
        'Reds'. The caller's dict is not modified.

    Returns
    -------
    Tuple[plt.Figure, plt.Axes]
        The figure owning ``ax`` and ``ax`` itself.
    """
    # Copy before adding defaults: the old `={}` defaults were shared between
    # calls, and setting "cmap" in place also mutated the caller's dict.
    spec_kwargs = {} if spec_kwargs is None else dict(spec_kwargs)
    plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)
    plot_kwargs.setdefault("cmap", "Reds")

    f, t, Sxx = spectrogram(timeseries_data, fs=fs, **spec_kwargs)
    if ax is None:
        _, ax = plt.subplots()

    mesh = ax.pcolormesh(t, f, Sxx, shading="nearest", **plot_kwargs)

    _fmt_time_axis(t, ax)

    ax.set_ylabel("Frequency [Hz]")
    ax.set_ylim(top=fs / 2.0)
    # Attach the colorbar to ax's own figure rather than pyplot's current figure.
    cbar = ax.figure.colorbar(mesh, ax=ax)
    cbar.set_label("Spectrogram Amplitude")
    return ax.figure, ax
345

346

347
def _fmt_time_axis(t, axes, t0=None, tmax=None):
    """
    Choose a time unit for the x-axis of ``axes`` from the last entry of ``t``.

    Values of ``t`` are treated as seconds. If ``t[-1]`` exceeds a day, tick
    values are divided by DAY_S and labelled in days. Past an hour they are
    shown in hours, and past a minute in minutes. Otherwise the tick
    formatter is left alone and the axis is labelled in seconds. Finally the
    x-limits are set to ``(t0, tmax)``, each defaulting to the first/last
    entry of ``t``.
    """
    t_end = t[-1]
    unit_table = (
        (DAY_S, "Time [days]"),
        (HOUR_S, "Time [hr]"),
        (MIN_S, "Time [min]"),
    )
    for seconds_per_unit, xlabel in unit_table:
        if t_end > seconds_per_unit:
            # Bind the divisor as a default argument so the formatter keeps it.
            axes.xaxis.set_major_formatter(
                plt.FuncFormatter(
                    lambda value, _pos, scale=seconds_per_unit: f"{value / scale:.1f}"
                )
            )
            axes.set_xlabel(xlabel)
            break
    else:
        axes.set_xlabel("Time [s]")

    left = t[0] if t0 is None else t0
    right = t_end if tmax is None else tmax
    axes.set_xlim(left, right)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc