• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

nz-gravity / LogPSplinePSD / 18113528024

29 Sep 2025 11:21PM UTC coverage: 81.475% (-2.8%) from 84.297%
18113528024

push

github

avivajpeyi
add GW tests

363 of 440 branches covered (82.5%)

Branch coverage included in aggregate %.

2830 of 3479 relevant lines covered (81.35%)

1.63 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.15
/src/log_psplines/example_datasets/varma_data.py
1
from typing import Optional
2✔
2

3
import matplotlib.pyplot as plt
2✔
4
import numpy as np
2✔
5
from numpy.fft import rfft
2✔
6

7

8
class VARMAData:
    """
    Simulate Vector Autoregressive Moving Average (VARMA) processes and compute related spectral properties.

    The simulated process is
    ``x_t = sum_{k=1}^{p} A_k x_{t-k} + sum_{k=0}^{q-1} B_k e_{t-k}`` with
    ``e_t ~ N(0, Sigma)``, where ``A = var_coeffs``, ``B = vma_coeffs`` and
    ``Sigma`` is derived from ``sigma``.
    """

    # Number of leading samples generated from zero initial conditions and
    # discarded by ``resimulate`` so the returned series is past the transient.
    _BURN_IN = 101

    def __init__(
        self,
        n_samples: int = 1024,
        sigma: Optional[np.ndarray] = None,
        var_coeffs: Optional[np.ndarray] = None,
        vma_coeffs: Optional[np.ndarray] = None,
        seed: Optional[int] = None,
    ):
        """
        Initialize the VARMAData instance: simulate data and compute the true PSD.

        Args:
            n_samples (int): Number of samples to generate.
            sigma (np.ndarray | float, optional): Innovation covariance matrix of
                shape (dim, dim), or a scalar / single-element array giving an
                isotropic variance. Defaults to ``[[1.0, 0.9], [0.9, 1.0]]``.
            var_coeffs (np.ndarray, optional): VAR coefficients of shape
                (p, dim, dim); entry ``k - 1`` multiplies lag ``k``. Defaults to
                ``[[[0.5, 0.0], [0.0, -0.3]], [[0.0, 0.0], [0.0, -0.5]]]``.
            vma_coeffs (np.ndarray, optional): VMA coefficients of shape
                (q, dim, dim); entry ``k`` multiplies the innovation at lag ``k``.
                Defaults to the 2x2 identity (q = 1).
            seed (int, optional): If given, passed to ``np.random.seed`` before
                simulating.
        """
        # Build fresh default arrays per call: array literals in the signature
        # would be a single mutable object shared by every instance.
        if sigma is None:
            sigma = np.array([[1.0, 0.9], [0.9, 1.0]])
        if var_coeffs is None:
            var_coeffs = np.array(
                [[[0.5, 0.0], [0.0, -0.3]], [[0.0, 0.0], [0.0, -0.5]]]
            )
        if vma_coeffs is None:
            vma_coeffs = np.array([[[1.0, 0.0], [0.0, 1.0]]])

        self.n_samples = n_samples
        self.var_coeffs = var_coeffs
        self.vma_coeffs = vma_coeffs
        self.sigma = sigma
        self.dim = vma_coeffs.shape[1]
        self.psd_scaling = 1
        self.n_freq_samples = n_samples // 2

        # Angular "sampling rate": self.freq is in radians per sample.
        self.fs = 2 * np.pi
        self.freq = (
            np.linspace(0, 0.5, self.n_freq_samples, endpoint=False)[1:]
            * self.fs
        )
        self.time = np.arange(n_samples) / self.fs
        self.data = None  # set in "resimulate"
        # Placeholders kept for attribute compatibility; no method of this
        # class populates them.
        self.periodogram = None
        self.welch_psd = None
        self.welch_f = None
        self.resimulate(seed=seed)

        self.psd = _calculate_true_varma_psd(
            self.n_freq_samples,
            self.dim,
            self.var_coeffs,
            self.vma_coeffs,
            self._covariance_matrix(),
        )

    def _covariance_matrix(self) -> np.ndarray:
        """
        Return the innovation covariance implied by ``self.sigma``.

        A scalar (0-d) ``sigma`` or one whose first dimension has length 1 is
        treated as an isotropic variance and broadcast against the
        ``dim x dim`` identity; otherwise ``sigma`` is returned unchanged.
        """
        sigma = np.asarray(self.sigma)
        if sigma.ndim == 0 or sigma.shape[0] == 1:
            return np.identity(self.dim) * sigma
        return sigma

    def resimulate(self, seed: Optional[int] = None) -> np.ndarray:
        """
        Simulate a fresh realisation of the VARMA process into ``self.data``.

        ``self._BURN_IN`` extra samples are generated starting from zero initial
        conditions and then discarded.

        Args:
            seed (int, optional): If given, passed to ``np.random.seed``; note
                this reseeds NumPy's global random state.

        Returns:
            np.ndarray: The simulated data of shape (n_samples, dim), also
            stored in ``self.data``.
        """
        if seed is not None:
            np.random.seed(seed)

        lag_ma = self.vma_coeffs.shape[0]
        lag_ar = self.var_coeffs.shape[0]
        cov_matrix = self._covariance_matrix()
        mean = np.zeros(self.dim)

        x = np.full((self.n_samples + self._BURN_IN, self.dim), np.nan)
        x[: lag_ar + 1] = 0.0
        # epsilon[k] is the innovation at lag k (epsilon[0] is the current one).
        epsilon = np.random.multivariate_normal(mean, cov_matrix, size=[lag_ma])

        for i in range(lag_ar + 1, x.shape[0]):
            epsilon = np.concatenate(
                [
                    np.random.multivariate_normal(mean, cov_matrix, size=[1]),
                    epsilon[:-1],
                ]
            )
            # x[i-1], x[i-2], ..., x[i-lag_ar]: most recent first, matching the
            # order of var_coeffs.
            past = x[i - 1 : i - lag_ar - 1 : -1]
            ar_term = np.sum(
                np.matmul(self.var_coeffs, past[..., np.newaxis]), axis=(0, -1)
            )
            ma_term = np.sum(
                np.matmul(self.vma_coeffs, epsilon[..., np.newaxis]), axis=(0, -1)
            )
            x[i] = ar_term + ma_term

        self.data = x[self._BURN_IN :]
        return self.data

    def get_periodogram(self) -> np.ndarray:
        """
        Return the consistently scaled periodogram (PSD matrix) of the data.

        Scaling: ``2 * FFT_i * conj(FFT_j) / (N * 2 * pi)`` computed on the
        mean-removed data, keeping the first ``n_freq_samples`` rfft bins.
        Entries with magnitude below 1e-12 are replaced by 1e-12 to avoid
        log(0). The first (zero-frequency) bin is dropped.

        Returns:
            np.ndarray: Complex array of shape (n_freq_samples - 1, dim, dim).

        Raises:
            ValueError: If no data has been simulated yet.
        """
        if self.data is None:
            raise ValueError("No data available. Run resimulate first.")
        n_samples = self.data.shape[0]
        centred = self.data - np.mean(self.data, axis=0)
        fft_data = rfft(centred, axis=0)[: self.n_freq_samples]
        # Outer product over channels at every frequency:
        # (n_freq, dim, 1) * (n_freq, 1, dim) -> (n_freq, dim, dim).
        periodogram = (
            2
            * (fft_data[:, :, np.newaxis] * np.conj(fft_data[:, np.newaxis, :]))
            / (n_samples * 2 * np.pi)
        )
        eps = 1e-12
        periodogram = np.where(np.abs(periodogram) < eps, eps, periodogram)
        return periodogram[1:, ...]

    def get_true_psd(self) -> np.ndarray:
        """
        Return the true PSD matrix with tiny entries floored.

        ``self.psd`` is already normalised by (2 * pi) and already excludes the
        zero-frequency bin. Entries with magnitude below 1e-12 are replaced by
        1e-12 so the result can be plotted on a log scale.
        """
        eps = 1e-12
        return np.where(np.abs(self.psd) < eps, eps, self.psd)

    def plot(self, axs=None, fname: Optional[str] = None):
        """
        Matrix plot: diagonal is the PSD, below diagonal is real CSD, above diagonal is imag CSD.

        Plots both the true PSD (thin line) and the periodogram (thick,
        translucent line) using consistent scaling.

        Args:
            axs: Optional 2-D array of matplotlib Axes indexable as
                ``axs[i, j]`` for ``0 <= i, j < dim``. A new figure is created
                when omitted.
            fname (str, optional): If given, the figure is saved to this path.

        Returns:
            The 2-D array of Axes that was drawn on.

        Raises:
            ValueError: If no data has been simulated yet.
        """
        if self.data is None:
            raise ValueError("No data to plot. Run resimulate first.")
        dim = self.dim
        periodogram = self.get_periodogram()
        true_psd = self.get_true_psd()
        freq = self.freq
        if axs is None:
            # squeeze=False keeps a 2-D Axes array even when dim == 1, so the
            # axs[i, j] indexing below always works.
            fig, axs = plt.subplots(
                dim, dim, figsize=(4 * dim, 4 * dim), sharex=True, squeeze=False
            )
        else:
            fig = axs[0, 0].figure
        data_kwgs = dict(alpha=0.3, lw=2, zorder=-10, color="k")
        true_kwgs = dict(lw=1, zorder=10, color="k")
        for i in range(dim):
            for j in range(dim):
                ax = axs[i, j]
                if i == j:
                    part = np.real
                    true_label, data_label = "True PSD", "Periodogram"
                    title = f"PSD: channel {i + 1}"
                elif i > j:
                    part = np.real
                    true_label = "True Re(CSD)"
                    data_label = "Periodogram Re(CSD)"
                    title = f"Re(CSD): {i + 1},{j + 1}"
                else:
                    part = np.imag
                    true_label = "True Im(CSD)"
                    data_label = "Periodogram Im(CSD)"
                    title = f"Im(CSD): {i + 1},{j + 1}"
                ax.plot(freq, part(true_psd[:, i, j]), label=true_label, **true_kwgs)
                ax.plot(
                    freq, part(periodogram[:, i, j]), label=data_label, **data_kwgs
                )
                ax.set_title(title)
                if i == j:
                    ax.set_yscale("log")
                if i == dim - 1:
                    ax.set_xlabel("Frequency (rad)")
                if j == 0:
                    ax.set_ylabel("Power / CSD")
                ax.legend(fontsize=8)
        fig.tight_layout()
        if fname:
            fig.savefig(fname, bbox_inches="tight")
        return axs
223

224

225
def _calculate_true_varma_psd(
    n_samples: int,
    dim: int,
    var_coeffs: np.ndarray,
    vma_coeffs: np.ndarray,
    sigma: np.ndarray,
) -> np.ndarray:
    """
    Calculate the theoretical VARMA spectral matrix on a frequency grid.

    The grid is ``np.linspace(0, 0.5, n_samples, endpoint=False)[1:]``
    (cycles per sample), i.e. ``n_samples`` points in [0, 0.5) with the
    zero-frequency point removed.

    Args:
        n_samples (int): Number of grid points in [0, 0.5) before the
            zero-frequency point is dropped.
        dim (int): Number of channels.
        var_coeffs (np.ndarray): VAR coefficient array.
        vma_coeffs (np.ndarray): VMA coefficient array.
        sigma (np.ndarray): Covariance matrix or scalar variance.

    Returns:
        np.ndarray: Complex array of shape (max(n_samples - 1, 0), dim, dim)
        holding the spectral matrix at each grid frequency, divided by 2 * pi.
    """
    freq = np.linspace(0, 0.5, n_samples, endpoint=False)[1:]
    if freq.size == 0:
        # np.apply_along_axis / np.stack cannot handle an empty grid; return an
        # empty stack of matrices instead of raising.
        return np.empty((0, dim, dim), dtype=np.complex128)
    spec_matrix = np.stack(
        [
            _calculate_spec_matrix_helper(f, dim, var_coeffs, vma_coeffs, sigma)
            for f in freq
        ]
    )
    return spec_matrix / (2 * np.pi)
253

254

255
def _calculate_spec_matrix_helper(f, dim, var_coeffs, vma_coeffs, sigma):
2✔
256
    """
257
    Helper function to calculate spectral matrix for a single frequency.
258

259
    Args:
260
        f (float): Single frequency value.
261
        var_coeffs (np.ndarray): VAR coefficient array.
262
        vma_coeffs (np.ndarray): VMA coefficient array.
263
        sigma (np.ndarray): Covariance matrix or scalar variance.
264

265
    Returns:
266
        np.ndarray: Calculated spectral matrix for the given frequency.
267
    """
268
    if sigma.shape[0] == 1:
2✔
269
        cov_matrix = np.identity(dim) * sigma
×
270
    else:
271
        cov_matrix = sigma
2✔
272

273
    k_ar = np.arange(1, var_coeffs.shape[0] + 1)
2✔
274
    A_f_re_ar = np.sum(
2✔
275
        var_coeffs * np.cos(np.pi * 2 * k_ar * f)[:, np.newaxis, np.newaxis],
276
        axis=0,
277
    )
278
    A_f_im_ar = -np.sum(
2✔
279
        var_coeffs * np.sin(np.pi * 2 * k_ar * f)[:, np.newaxis, np.newaxis],
280
        axis=0,
281
    )
282
    A_f_ar = A_f_re_ar + 1j * A_f_im_ar
2✔
283
    A_bar_f_ar = np.identity(dim) - A_f_ar
2✔
284
    H_f_ar = np.linalg.inv(A_bar_f_ar)
2✔
285

286
    k_ma = np.arange(vma_coeffs.shape[0])
2✔
287
    A_f_re_ma = np.sum(
2✔
288
        vma_coeffs * np.cos(np.pi * 2 * k_ma * f)[:, np.newaxis, np.newaxis],
289
        axis=0,
290
    )
291
    A_f_im_ma = -np.sum(
2✔
292
        vma_coeffs * np.sin(np.pi * 2 * k_ma * f)[:, np.newaxis, np.newaxis],
293
        axis=0,
294
    )
295
    A_f_ma = A_f_re_ma + 1j * A_f_im_ma
2✔
296
    A_bar_f_ma = A_f_ma
2✔
297
    H_f_ma = A_bar_f_ma
2✔
298

299
    spec_mat = H_f_ar @ H_f_ma @ cov_matrix @ H_f_ma.conj().T @ H_f_ar.conj().T
2✔
300
    return spec_mat
2✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc