• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

avivajpeyi / pywavelet / 21431587909

28 Jan 2026 08:53AM UTC coverage: 97.297% (+21.8%) from 75.545%
21431587909

push

github

avivajpeyi
Merge branch 'main' of github.com:pywavelet/pywavelet

230 of 248 branches covered (92.74%)

Branch coverage included in aggregate %.

1174 of 1195 relevant lines covered (98.24%)

0.98 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.68
/src/pywavelet/types/wavelet.py
1
from typing import List, Optional, Tuple
1✔
2

3
import matplotlib.pyplot as plt
1✔
4
import numpy as np
1✔
5

6
from .common import float_dtype, fmt_timerange, is_documented_by, xp
1✔
7
from .plotting import plot_wavelet_grid, plot_wavelet_trend
1✔
8
from .wavelet_bins import compute_bins
1✔
9

10

11
class Wavelet:
    """
    A wavelet transform result: a 2D grid of coefficients plus its time and
    frequency axes, with helpers for plotting, derived grid properties,
    element-wise arithmetic and noise-weighted inner products.

    Attributes
    ----------
    data : xp.ndarray
        2D array of wavelet coefficients with shape (Nf, Nt)
        (frequency x time).
    time : xp.ndarray
        Array of time points (length Nt).
    freq : xp.ndarray
        Array of frequency points (length Nf).
    """

    # Scalar operand types accepted by the arithmetic operators. ``np.number``
    # covers NumPy scalars such as ``np.float32`` / ``np.int64``, which are not
    # subclasses of the builtin ``float`` / ``int``.
    _SCALAR_TYPES = (int, float, complex, np.number)

    def __init__(
        self,
        data: xp.ndarray,
        time: xp.ndarray,
        freq: xp.ndarray,
    ):
        """
        Initialize the Wavelet object with data, time, and frequency arrays.

        Parameters
        ----------
        data : xp.ndarray
            2D array of wavelet coefficients with shape (Nf, Nt)
            (frequency x time).
        time : xp.ndarray
            Array of time points; its length must equal Nt.
        freq : xp.ndarray
            Array of frequency points; its length must equal Nf.

        Raises
        ------
        ValueError
            If `data` is not 2D (raised when unpacking ``data.shape``).
        AssertionError
            If the length of the time array does not match the number of time
            bins in `data`, or the length of the frequency array does not
            match the number of frequency bins in `data`.
        """
        nf, nt = data.shape
        # Raised explicitly rather than via ``assert`` so the checks survive
        # ``python -O``; AssertionError is kept for backward compatibility.
        if len(time) != nt:
            raise AssertionError(f"len(time)={len(time)} != nt={nt}")
        if len(freq) != nf:
            raise AssertionError(f"len(freq)={len(freq)} != nf={nf}")

        self.data = data
        self.time = time
        self.freq = freq

    @classmethod
    def zeros_from_grid(cls, time: xp.ndarray, freq: xp.ndarray) -> "Wavelet":
        """
        Create a zero-filled object on an existing time/frequency grid.

        Parameters
        ----------
        time : xp.ndarray
            Time points; sets Nt = len(time).
        freq : xp.ndarray
            Frequency points; sets Nf = len(freq).

        Returns
        -------
        Wavelet
            An instance of `cls` whose data is an (Nf, Nt) array of zeros
            with dtype `float_dtype`.
        """
        Nf, Nt = len(freq), len(time)
        return cls(
            data=xp.zeros((Nf, Nt), dtype=float_dtype), time=time, freq=freq
        )

    @classmethod
    def zeros(cls, Nf: int, Nt: int, T: float) -> "Wavelet":
        """
        Create a zero-filled object, building the grid with `compute_bins`.

        Parameters
        ----------
        Nf : int
            Number of frequency bins.
        Nt : int
            Number of time bins.
        T : float
            Passed to `compute_bins` together with `Nf` and `Nt`; presumably
            the total duration of the grid (see `compute_bins`).

        Returns
        -------
        Wavelet
            An instance of `cls` with a zero-filled data array.
        """
        # compute_bins returns the (time, freq) arrays expected by
        # zeros_from_grid, in that positional order.
        return cls.zeros_from_grid(*compute_bins(Nf, Nt, T))

    @is_documented_by(plot_wavelet_grid)
    def plot(self, ax=None, *args, **kwargs) -> Tuple[plt.Figure, plt.Axes]:
        # Default the plotting grid to this object's axes unless overridden.
        kwargs["time_grid"] = kwargs.get("time_grid", self.time)
        kwargs["freq_grid"] = kwargs.get("freq_grid", self.freq)
        # NOTE(review): extra positional ``args`` are forwarded ahead of the
        # keyword arguments; confirm they cannot collide with
        # ``wavelet_data`` / ``ax`` in plot_wavelet_grid's signature.
        return plot_wavelet_grid(
            wavelet_data=self.data, ax=ax, *args, **kwargs
        )

    @is_documented_by(plot_wavelet_trend)
    def plot_trend(
        self, ax=None, *args, **kwargs
    ) -> Tuple[plt.Figure, plt.Axes]:
        # Default the plotting grid to this object's axes unless overridden.
        kwargs["time_grid"] = kwargs.get("time_grid", self.time)
        kwargs["freq_grid"] = kwargs.get("freq_grid", self.freq)
        return plot_wavelet_trend(
            wavelet_data=self.data, ax=ax, *args, **kwargs
        )

    @property
    def Nt(self) -> int:
        """
        Number of time bins.

        Returns
        -------
        int
            Length of the time array.
        """
        return len(self.time)

    @property
    def Nf(self) -> int:
        """
        Number of frequency bins.

        Returns
        -------
        int
            Length of the frequency array.
        """
        return len(self.freq)

    @property
    def ND(self) -> int:
        """
        Total number of data points in the wavelet grid.

        Returns
        -------
        int
            The product of `Nt` and `Nf`.
        """
        return self.Nt * self.Nf

    @property
    def delta_T(self) -> float:
        """
        Time resolution (ΔT) of the wavelet grid.

        Returns
        -------
        float
            Difference between the first two time points. This is only
            representative of the whole grid if the time points are
            uniformly spaced; requires at least two time points.
        """
        return self.time[1] - self.time[0]

    @property
    def delta_F(self) -> float:
        """
        Frequency resolution (ΔF) of the wavelet grid.

        Returns
        -------
        float
            1 / (2 ΔT), the inverse of twice the time resolution.
        """
        return 1 / (2 * self.delta_T)

    @property
    def duration(self) -> float:
        """
        Duration of the wavelet grid.

        Returns
        -------
        float
            Nt * ΔT.
        """
        return float(self.Nt * self.delta_T)

    @property
    def delta_t(self) -> float:
        """
        Sampling interval of the corresponding time series.

        This is the value passed as ``dt`` to the inverse transforms in
        `to_timeseries` / `to_frequencyseries`.

        Returns
        -------
        float
            duration / (Nt * Nf).
        """
        return float(self.duration / self.ND)

    @property
    def delta_f(self) -> float:
        """
        Frequency resolution derived from `delta_t`.

        Returns
        -------
        float
            1 / (2 * delta_t).
        """
        return 1 / (2 * self.delta_t)

    @property
    def t0(self) -> float:
        """
        Initial time point of the wavelet grid.

        Returns
        -------
        float
            First element of the time array.
        """
        return float(self.time[0])

    @property
    def tend(self) -> float:
        """
        Final time point of the wavelet grid.

        Returns
        -------
        float
            Last element of the time array.
        """
        return float(self.time[-1])

    @property
    def shape(self) -> Tuple[int, int]:
        """
        Shape of the wavelet grid.

        Returns
        -------
        Tuple[int, int]
            Shape of the data array, (Nf, Nt).
        """
        return self.data.shape

    @property
    def sample_rate(self) -> float:
        """
        Sample rate of the corresponding time series.

        Returns
        -------
        float
            1 / delta_t.
        """
        return 1 / self.delta_t

    @property
    def fs(self) -> float:
        """
        Alias for `sample_rate`.

        Returns
        -------
        float
            The sample rate.
        """
        return self.sample_rate

    @property
    def nyquist_frequency(self) -> float:
        """
        Nyquist frequency of the wavelet grid.

        Returns
        -------
        float
            Half of the sample rate.
        """
        return self.sample_rate / 2

    def to_timeseries(self, nx: float = 4.0, mult: int = 32) -> "TimeSeries":
        """
        Convert the wavelet grid to a time-domain signal.

        Parameters
        ----------
        nx : float
            Forwarded unchanged to `from_wavelet_to_time`.
        mult : int
            Forwarded unchanged to `from_wavelet_to_time`.

        Returns
        -------
        TimeSeries
            A `TimeSeries` object representing the time-domain signal.
        """
        # Local import, presumably to avoid a circular import with transforms.
        from ..transforms import from_wavelet_to_time

        return from_wavelet_to_time(self, dt=self.delta_t, nx=nx, mult=mult)

    def to_frequencyseries(self, nx: float = 4.0) -> "FrequencySeries":
        """
        Convert the wavelet grid to a frequency-domain signal.

        Parameters
        ----------
        nx : float
            Forwarded unchanged to `from_wavelet_to_freq`.

        Returns
        -------
        FrequencySeries
            A `FrequencySeries` object representing the frequency-domain signal.
        """
        # Local import, presumably to avoid a circular import with transforms.
        from ..transforms import from_wavelet_to_freq

        return from_wavelet_to_freq(self, dt=self.delta_t, nx=nx)

    @staticmethod
    def _fmt_size(n: int) -> str:
        """Format a bin count as ``2^k`` when it is a power of two, else as-is."""
        n = int(n)
        if n > 0 and n & (n - 1) == 0:
            return f"2^{n.bit_length() - 1}"
        return str(n)

    def __repr__(self) -> str:
        """
        Return a string representation of the Wavelet object.

        Returns
        -------
        str
            Summary with the grid shape (Nf x Nt), the first/last frequency
            and the time range.
        """
        frange = ",".join(f"{f:.2e}" for f in (self.freq[0], self.freq[-1]))
        trange = fmt_timerange((self.t0, self.tend))
        nf, nt = self.shape
        shapef = f"NfxNt=[{self._fmt_size(nf)}, {self._fmt_size(nt)}]"
        return f"Wavelet({shapef}, [{frange}]Hz, {trange})"

    def _operand(self, other):
        """
        Return the value to combine with ``self.data`` in a binary operation:
        ``other.data`` for a Wavelet, ``other`` for a scalar, or None when the
        operand type is unsupported.
        """
        if isinstance(other, Wavelet):
            return other.data
        if isinstance(other, self._SCALAR_TYPES):
            return other
        return None

    def __add__(self, other):
        """Element-wise addition with another Wavelet or a scalar."""
        rhs = self._operand(other)
        if rhs is None:
            return NotImplemented
        return Wavelet(data=self.data + rhs, time=self.time, freq=self.freq)

    def __radd__(self, other):
        """Support ``scalar + wavelet`` (and ``sum([...])`` of wavelets)."""
        return self.__add__(other)

    def __sub__(self, other):
        """Element-wise subtraction of another Wavelet or a scalar."""
        rhs = self._operand(other)
        if rhs is None:
            return NotImplemented
        return Wavelet(data=self.data - rhs, time=self.time, freq=self.freq)

    def __rsub__(self, other):
        """Support ``scalar - wavelet``."""
        if not isinstance(other, self._SCALAR_TYPES):
            return NotImplemented
        return Wavelet(data=other - self.data, time=self.time, freq=self.freq)

    def _apply_mask(self, mask: "WaveletMask") -> "Wavelet":
        """Return a copy of this grid with entries where ``mask`` is False set to NaN."""
        data = self.data.copy()
        data[~mask.mask] = np.nan
        return Wavelet(data=data, time=self.time, freq=self.freq)

    def __mul__(self, other):
        """
        Element-wise multiplication.

        - ``wavelet * mask`` (or ``mask * wavelet``) returns a copy of the
          wavelet with entries outside the `WaveletMask` set to NaN.
        - ``wavelet * wavelet`` multiplies the coefficient arrays.
        - ``wavelet * scalar`` scales the coefficients.
        """
        if isinstance(other, WaveletMask):
            return self._apply_mask(other)
        if isinstance(self, WaveletMask) and isinstance(other, Wavelet):
            # Keep masking commutative: mask * wavelet == wavelet * mask.
            return other._apply_mask(self)
        rhs = self._operand(other)
        if rhs is None:
            return NotImplemented
        return Wavelet(data=self.data * rhs, time=self.time, freq=self.freq)

    def __rmul__(self, other):
        """Support ``scalar * wavelet``."""
        return self.__mul__(other)

    def __truediv__(self, other):
        """Element-wise division by another Wavelet or a scalar."""
        rhs = self._operand(other)
        if rhs is None:
            return NotImplemented
        return Wavelet(data=self.data / rhs, time=self.time, freq=self.freq)

    def __rtruediv__(self, other):
        """Support ``scalar / wavelet``."""
        if not isinstance(other, self._SCALAR_TYPES):
            return NotImplemented
        return Wavelet(data=other / self.data, time=self.time, freq=self.freq)

    def __eq__(self, other: object) -> bool:
        """
        Compare two Wavelet objects.

        Two grids are equal when they have the same shape, identical time and
        frequency arrays, and coefficients that are element-wise close
        (``xp.isclose`` defaults). Entries that are NaN in either grid
        (e.g. bins blanked out by a mask) are ignored.
        """
        if not isinstance(other, Wavelet):
            return NotImplemented
        if tuple(self.shape) != tuple(other.shape):
            return False
        # Element-wise comparison: summing signed differences (the previous
        # approach) let positive and negative errors cancel out.
        close = xp.isclose(self.data, other.data)
        ignored = xp.isnan(self.data) | xp.isnan(other.data)
        data_same = bool(xp.all(close | ignored))
        time_same = bool((self.time == other.time).all())
        freq_same = bool((self.freq == other.freq).all())
        return data_same and time_same and freq_same

    def noise_weighted_inner_product(
        self, other: "Wavelet", psd: "Wavelet"
    ) -> float:
        """
        Compute the noise-weighted inner product of two wavelet grids given a PSD.

        Parameters
        ----------
        other : Wavelet
            A `Wavelet` object representing the other wavelet grid.
        psd : Wavelet
            A `Wavelet` object representing the power spectral density.

        Returns
        -------
        float
            The noise-weighted inner product (self | other).
        """
        # Local import, presumably to avoid a circular import with utils.
        from ..utils import noise_weighted_inner_product

        return noise_weighted_inner_product(self, other, psd)

    def matched_filter_snr(self, template: "Wavelet", psd: "Wavelet") -> float:
        """
        Compute the matched-filter SNR of this grid (the data) against a template.

        Uses rho = (d | h) / sqrt((h | h)), where d is this grid and h is the
        template; i.e. the inner product is normalised by the *template's*
        optimal SNR.

        Parameters
        ----------
        template : Wavelet
            A `Wavelet` object representing the template h.
        psd : Wavelet
            A `Wavelet` object representing the power spectral density.

        Returns
        -------
        float
            The matched filter signal-to-noise ratio.
        """
        mf = self.noise_weighted_inner_product(template, psd)
        return mf / template.optimal_snr(psd)

    def optimal_snr(self, psd: "Wavelet") -> float:
        """
        Compute the optimal SNR of the wavelet grid given a PSD.

        Parameters
        ----------
        psd : Wavelet
            A `Wavelet` object representing the power spectral density.

        Returns
        -------
        float
            sqrt((self | self)), the optimal signal-to-noise ratio.
        """
        return xp.sqrt(self.noise_weighted_inner_product(self, psd))

    def __copy__(self):
        """Return a copy (same class) with copied data, time and freq arrays."""
        return type(self)(
            data=self.data.copy(), time=self.time.copy(), freq=self.freq.copy()
        )

    def copy(self):
        """Return a copy of this object; see `__copy__`."""
        return self.__copy__()
442

443

444
class WaveletMask(Wavelet):
    """
    A time-frequency mask on a wavelet grid.

    The mask values live in ``data`` and are exposed through `mask`.
    `from_restrictions` produces a boolean array (True = keep). Multiplying a
    `Wavelet` by a `WaveletMask` sets the entries where the mask is False
    to NaN.
    """

    @property
    def mask(self):
        """The mask array (an alias for ``data``), shape (Nf, Nt)."""
        return self.data

    def __repr__(self) -> str:
        """Return the base-class summary relabelled as ``WaveletMask``."""
        rpr = super().__repr__()
        # Skip relabelling if the base already reports the subclass name.
        if rpr.startswith("WaveletMask"):
            return rpr
        # Only the leading class name needs relabelling.
        return rpr.replace("Wavelet", "WaveletMask", 1)

    @classmethod
    def from_restrictions(
        cls,
        time_grid: xp.ndarray,
        freq_grid: xp.ndarray,
        frange: List[float],
        tgaps: Optional[List[Tuple[float, float]]] = None,
    ) -> "WaveletMask":
        """
        Create a WaveletMask object from restrictions on time and frequency.

        An entry (i, j) of the mask is True iff
        ``frange[0] <= freq_grid[i] <= frange[1]`` and ``time_grid[j]`` does
        not fall inside any (inclusive) gap in `tgaps`.

        Parameters
        ----------
        time_grid : xp.ndarray
            1D array of time points (length Nt).
        freq_grid : xp.ndarray
            1D array of frequency points (length Nf).
        frange : List[float]
            Inclusive ``[fmin, fmax]`` frequency range to keep.
        tgaps : List[Tuple[float, float]], optional
            Inclusive ``(tstart, tend)`` time intervals to exclude.
            Defaults to no gaps.

        Returns
        -------
        WaveletMask
            A boolean (Nf, Nt) mask with the specified restrictions.
        """
        # Default to no gaps (avoid a shared mutable default argument).
        if tgaps is None:
            tgaps = ()

        keep_freq = (freq_grid >= frange[0]) & (freq_grid <= frange[1])

        keep_time = xp.ones(len(time_grid), dtype=bool)
        for tgap in tgaps:
            in_gap = (time_grid >= tgap[0]) & (time_grid <= tgap[1])
            keep_time = keep_time & ~in_gap

        # Outer AND of the frequency rows and time columns -> (Nf, Nt) mask.
        data = keep_freq[:, None] & keep_time[None, :]
        return cls(data=data, time=time_grid, freq=freq_grid)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc