
WenjieDu / PyPOTS / build 10097470274 (push, via GitHub, committed by web-flow)
25 Jul 2024 04:00PM UTC — coverage: 83.904% (-0.1%) from 84.035%

Merge pull request #475 from WenjieDu/dev
Add attention map visualization func

5 of 27 new or added lines in 3 files covered (18.52%)
3 existing lines in 2 files now uncovered
10582 of 12612 relevant lines covered (83.9%), 5.03 hits per line

Source file: /pypots/nn/modules/fedformer/layers.py — 46.17% covered
"""
Layers for FEDformer: multiwavelet and Fourier attention operators, plus the
multi-kernel series decomposition block.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import math
from functools import partial
from typing import List, Tuple, Optional

import numpy as np
import torch
import torch.nn.functional as F
from scipy.special import eval_legendre
from sympy import Poly, legendre, Symbol, chebyshevt
from torch import Tensor
from torch import nn

from ..autoformer.layers import MovingAvgBlock
from ..transformer.attention import AttentionOperator


def legendreDer(k, x):
    # derivative of the k-th Legendre polynomial, via
    # P_k'(x) = sum_{i = k-1, k-3, ...} (2i + 1) * P_i(x)
    def _legendre(k, x):
        return (2 * k + 1) * eval_legendre(k, x)

    out = 0
    for i in np.arange(k - 1, -1, -2):
        out += _legendre(i, x)
    return out
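

# Illustrative sketch (added for this write-up, not part of the upstream file):
# legendreDer(k, x) above implements the identity
# P_k'(x) = sum_{i = k-1, k-3, ...} (2i + 1) * P_i(x); this hypothetical helper
# checks it against a central finite difference.
def _demo_legendre_der(k: int = 4) -> None:
    xs = np.linspace(-0.9, 0.9, 7)
    eps = 1e-6
    numeric = (eval_legendre(k, xs + eps) - eval_legendre(k, xs - eps)) / (2 * eps)
    assert np.allclose(legendreDer(k, xs), numeric, atol=1e-4)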


def phi_(phi_c, x, lb=0, ub=1):
    # evaluate the polynomial with coefficients phi_c at x, zeroed outside [lb, ub]
    mask = np.logical_or(x < lb, x > ub) * 1.0
    return np.polynomial.polynomial.Polynomial(phi_c)(x) * (1 - mask)


def get_phi_psi(k, base):
    # compute the k scaling functions (phi) and the two halves of the k wavelet
    # functions (psi1 used on [0, 0.5], psi2 on (0.5, 1]) of the multiwavelet basis
    x = Symbol("x")
    phi_coeff = np.zeros((k, k))
    phi_2x_coeff = np.zeros((k, k))
    if base == "legendre":
        for ki in range(k):
            coeff_ = Poly(legendre(ki, 2 * x - 1), x).all_coeffs()
            phi_coeff[ki, : ki + 1] = np.flip(
                np.sqrt(2 * ki + 1) * np.array(coeff_).astype(np.float64)
            )
            coeff_ = Poly(legendre(ki, 4 * x - 1), x).all_coeffs()
            phi_2x_coeff[ki, : ki + 1] = np.flip(
                np.sqrt(2) * np.sqrt(2 * ki + 1) * np.array(coeff_).astype(np.float64)
            )

        psi1_coeff = np.zeros((k, k))
        psi2_coeff = np.zeros((k, k))
        for ki in range(k):
            psi1_coeff[ki, :] = phi_2x_coeff[ki, :]
            for i in range(k):
                a = phi_2x_coeff[ki, : ki + 1]
                b = phi_coeff[i, : i + 1]
                prod_ = np.convolve(a, b)
                prod_[np.abs(prod_) < 1e-8] = 0
                proj_ = (
                    prod_
                    * 1
                    / (np.arange(len(prod_)) + 1)
                    * np.power(0.5, 1 + np.arange(len(prod_)))
                ).sum()
                psi1_coeff[ki, :] -= proj_ * phi_coeff[i, :]
                psi2_coeff[ki, :] -= proj_ * phi_coeff[i, :]
            for j in range(ki):
                a = phi_2x_coeff[ki, : ki + 1]
                b = psi1_coeff[j, :]
                prod_ = np.convolve(a, b)
                prod_[np.abs(prod_) < 1e-8] = 0
                proj_ = (
                    prod_
                    * 1
                    / (np.arange(len(prod_)) + 1)
                    * np.power(0.5, 1 + np.arange(len(prod_)))
                ).sum()
                psi1_coeff[ki, :] -= proj_ * psi1_coeff[j, :]
                psi2_coeff[ki, :] -= proj_ * psi2_coeff[j, :]

            a = psi1_coeff[ki, :]
            prod_ = np.convolve(a, a)
            prod_[np.abs(prod_) < 1e-8] = 0
            norm1 = (
                prod_
                * 1
                / (np.arange(len(prod_)) + 1)
                * np.power(0.5, 1 + np.arange(len(prod_)))
            ).sum()

            a = psi2_coeff[ki, :]
            prod_ = np.convolve(a, a)
            prod_[np.abs(prod_) < 1e-8] = 0
            norm2 = (
                prod_
                * 1
                / (np.arange(len(prod_)) + 1)
                * (1 - np.power(0.5, 1 + np.arange(len(prod_))))
            ).sum()
            norm_ = np.sqrt(norm1 + norm2)
            psi1_coeff[ki, :] /= norm_
            psi2_coeff[ki, :] /= norm_
            psi1_coeff[np.abs(psi1_coeff) < 1e-8] = 0
            psi2_coeff[np.abs(psi2_coeff) < 1e-8] = 0

        phi = [np.poly1d(np.flip(phi_coeff[i, :])) for i in range(k)]
        psi1 = [np.poly1d(np.flip(psi1_coeff[i, :])) for i in range(k)]
        psi2 = [np.poly1d(np.flip(psi2_coeff[i, :])) for i in range(k)]

    elif base == "chebyshev":
        for ki in range(k):
            if ki == 0:
                phi_coeff[ki, : ki + 1] = np.sqrt(2 / np.pi)
                phi_2x_coeff[ki, : ki + 1] = np.sqrt(2 / np.pi) * np.sqrt(2)
            else:
                coeff_ = Poly(chebyshevt(ki, 2 * x - 1), x).all_coeffs()
                phi_coeff[ki, : ki + 1] = np.flip(
                    2 / np.sqrt(np.pi) * np.array(coeff_).astype(np.float64)
                )
                coeff_ = Poly(chebyshevt(ki, 4 * x - 1), x).all_coeffs()
                phi_2x_coeff[ki, : ki + 1] = np.flip(
                    np.sqrt(2)
                    * 2
                    / np.sqrt(np.pi)
                    * np.array(coeff_).astype(np.float64)
                )

        phi = [partial(phi_, phi_coeff[i, :]) for i in range(k)]

        x = Symbol("x")
        kUse = 2 * k
        roots = Poly(chebyshevt(kUse, 2 * x - 1)).all_roots()
        x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
        # x_m[x_m==0.5] = 0.5 + 1e-8  # add small noise to avoid 0.5 belonging to both phi(2x) and phi(2x-1);
        # not needed here since an even k is always used, which avoids this case
        wm = np.pi / kUse / 2

        psi1_coeff = np.zeros((k, k))
        psi2_coeff = np.zeros((k, k))

        psi1 = [[] for _ in range(k)]
        psi2 = [[] for _ in range(k)]

        for ki in range(k):
            psi1_coeff[ki, :] = phi_2x_coeff[ki, :]
            for i in range(k):
                proj_ = (wm * phi[i](x_m) * np.sqrt(2) * phi[ki](2 * x_m)).sum()
                psi1_coeff[ki, :] -= proj_ * phi_coeff[i, :]
                psi2_coeff[ki, :] -= proj_ * phi_coeff[i, :]

            for j in range(ki):
                proj_ = (wm * psi1[j](x_m) * np.sqrt(2) * phi[ki](2 * x_m)).sum()
                psi1_coeff[ki, :] -= proj_ * psi1_coeff[j, :]
                psi2_coeff[ki, :] -= proj_ * psi2_coeff[j, :]

            psi1[ki] = partial(phi_, psi1_coeff[ki, :], lb=0, ub=0.5)
            psi2[ki] = partial(phi_, psi2_coeff[ki, :], lb=0.5, ub=1)

            norm1 = (wm * psi1[ki](x_m) * psi1[ki](x_m)).sum()
            norm2 = (wm * psi2[ki](x_m) * psi2[ki](x_m)).sum()

            norm_ = np.sqrt(norm1 + norm2)
            psi1_coeff[ki, :] /= norm_
            psi2_coeff[ki, :] /= norm_
            psi1_coeff[np.abs(psi1_coeff) < 1e-8] = 0
            psi2_coeff[np.abs(psi2_coeff) < 1e-8] = 0

            psi1[ki] = partial(phi_, psi1_coeff[ki, :], lb=0, ub=0.5 + 1e-16)
            psi2[ki] = partial(phi_, psi2_coeff[ki, :], lb=0.5 + 1e-16, ub=1)

    return phi, psi1, psi2
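

# Illustrative sketch (added here, not upstream): get_phi_psi returns k scaling
# functions phi and the two wavelet halves psi1 (used on [0, 0.5]) and psi2
# (used on (0.5, 1]); with base="legendre" they are callable np.poly1d objects.
def _demo_get_phi_psi(k: int = 3) -> None:
    phi, psi1, psi2 = get_phi_psi(k, "legendre")
    assert len(phi) == len(psi1) == len(psi2) == k
    xs = np.linspace(0, 1, 5)
    assert phi[0](xs).shape == xs.shape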


def get_filter(base, k):
    def psi(psi1, psi2, i, inp):
        mask = (inp <= 0.5) * 1.0
        return psi1[i](inp) * mask + psi2[i](inp) * (1 - mask)

    if base not in ["legendre", "chebyshev"]:
        raise Exception("Base not supported")

    x = Symbol("x")
    H0 = np.zeros((k, k))
    H1 = np.zeros((k, k))
    G0 = np.zeros((k, k))
    G1 = np.zeros((k, k))
    PHI0 = np.zeros((k, k))
    PHI1 = np.zeros((k, k))
    phi, psi1, psi2 = get_phi_psi(k, base)
    if base == "legendre":
        roots = Poly(legendre(k, 2 * x - 1)).all_roots()
        x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
        wm = 1 / k / legendreDer(k, 2 * x_m - 1) / eval_legendre(k - 1, 2 * x_m - 1)

        for ki in range(k):
            for kpi in range(k):
                H0[ki, kpi] = (
                    1 / np.sqrt(2) * (wm * phi[ki](x_m / 2) * phi[kpi](x_m)).sum()
                )
                G0[ki, kpi] = (
                    1
                    / np.sqrt(2)
                    * (wm * psi(psi1, psi2, ki, x_m / 2) * phi[kpi](x_m)).sum()
                )
                H1[ki, kpi] = (
                    1 / np.sqrt(2) * (wm * phi[ki]((x_m + 1) / 2) * phi[kpi](x_m)).sum()
                )
                G1[ki, kpi] = (
                    1
                    / np.sqrt(2)
                    * (wm * psi(psi1, psi2, ki, (x_m + 1) / 2) * phi[kpi](x_m)).sum()
                )

        PHI0 = np.eye(k)
        PHI1 = np.eye(k)

    elif base == "chebyshev":
        x = Symbol("x")
        kUse = 2 * k
        roots = Poly(chebyshevt(kUse, 2 * x - 1)).all_roots()
        x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
        # x_m[x_m==0.5] = 0.5 + 1e-8  # add small noise to avoid 0.5 belonging to both phi(2x) and phi(2x-1);
        # not needed here since an even k is always used, which avoids this case
        wm = np.pi / kUse / 2

        for ki in range(k):
            for kpi in range(k):
                H0[ki, kpi] = (
                    1 / np.sqrt(2) * (wm * phi[ki](x_m / 2) * phi[kpi](x_m)).sum()
                )
                G0[ki, kpi] = (
                    1
                    / np.sqrt(2)
                    * (wm * psi(psi1, psi2, ki, x_m / 2) * phi[kpi](x_m)).sum()
                )
                H1[ki, kpi] = (
                    1 / np.sqrt(2) * (wm * phi[ki]((x_m + 1) / 2) * phi[kpi](x_m)).sum()
                )
                G1[ki, kpi] = (
                    1
                    / np.sqrt(2)
                    * (wm * psi(psi1, psi2, ki, (x_m + 1) / 2) * phi[kpi](x_m)).sum()
                )

                PHI0[ki, kpi] = (wm * phi[ki](2 * x_m) * phi[kpi](2 * x_m)).sum() * 2
                PHI1[ki, kpi] = (
                    wm * phi[ki](2 * x_m - 1) * phi[kpi](2 * x_m - 1)
                ).sum() * 2

        PHI0[np.abs(PHI0) < 1e-8] = 0
        PHI1[np.abs(PHI1) < 1e-8] = 0

    H0[np.abs(H0) < 1e-8] = 0
    H1[np.abs(H1) < 1e-8] = 0
    G0[np.abs(G0) < 1e-8] = 0
    G1[np.abs(G1) < 1e-8] = 0

    return H0, H1, G0, G1, PHI0, PHI1
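

# Illustrative sketch (added, not upstream): build the k x k multiwavelet
# filter bank consumed by MWT_CZ1d below. Only shapes are asserted because the
# numerical values come from SymPy root finding.
def _demo_get_filter(k: int = 4) -> None:
    H0, H1, G0, G1, PHI0, PHI1 = get_filter("legendre", k)
    for mat in (H0, H1, G0, G1, PHI0, PHI1):
        assert mat.shape == (k, k)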


class sparseKernelFT1d(nn.Module):
    def __init__(self, k, alpha, c=1, nl=1, initializer=None, **kwargs):
        super().__init__()

        self.modes1 = alpha
        self.scale = 1 / (c * k * c * k)
        self.weights1 = nn.Parameter(
            self.scale * torch.rand(c * k, c * k, self.modes1, dtype=torch.float)
        )
        self.weights2 = nn.Parameter(
            self.scale * torch.rand(c * k, c * k, self.modes1, dtype=torch.float)
        )
        self.weights1.requires_grad = True
        self.weights2.requires_grad = True
        self.k = k

    def compl_mul1d(self, order, x, weights):
        x_flag = True
        w_flag = True
        if not torch.is_complex(x):
            x_flag = False
            x = torch.complex(x, torch.zeros_like(x).to(x.device))
        if not torch.is_complex(weights):
            w_flag = False
            weights = torch.complex(
                weights, torch.zeros_like(weights).to(weights.device)
            )
        if x_flag or w_flag:
            return torch.complex(
                torch.einsum(order, x.real, weights.real)
                - torch.einsum(order, x.imag, weights.imag),
                torch.einsum(order, x.real, weights.imag)
                + torch.einsum(order, x.imag, weights.real),
            )
        else:
            return torch.einsum(order, x.real, weights.real)

    def forward(self, x):
        B, N, c, k = x.shape  # (B, N, c, k)

        x = x.view(B, N, -1)
        x = x.permute(0, 2, 1)
        x_fft = torch.fft.rfft(x)
        # Multiply relevant Fourier modes
        mode = min(self.modes1, N // 2 + 1)
        out_ft = torch.zeros(B, c * k, N // 2 + 1, device=x.device, dtype=torch.cfloat)
        out_ft[:, :, :mode] = self.compl_mul1d(
            "bix,iox->box",
            x_fft[:, :, :mode],
            torch.complex(self.weights1, self.weights2)[:, :, :mode],
        )
        x = torch.fft.irfft(out_ft, n=N)
        x = x.permute(0, 2, 1).view(B, N, c, k)
        return x
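

# Illustrative sketch (added, not upstream): sparseKernelFT1d keeps at most
# `alpha` low-frequency Fourier modes, multiplies them by a learned complex
# weight, and returns to the time domain; the output shape matches the input.
def _demo_sparse_kernel_ft1d() -> None:
    layer = sparseKernelFT1d(k=3, alpha=8, c=2)
    x = torch.randn(4, 16, 2, 3)  # (batch, n_steps, c, k)
    assert layer(x).shape == x.shape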


class MWT_CZ1d(nn.Module):
    def __init__(
        self, k=3, alpha=64, L=0, c=1, base="legendre", initializer=None, **kwargs
    ):
        super().__init__()

        self.k = k
        self.L = L
        H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)
        H0r = H0 @ PHI0
        G0r = G0 @ PHI0
        H1r = H1 @ PHI1
        G1r = G1 @ PHI1

        H0r[np.abs(H0r) < 1e-8] = 0
        H1r[np.abs(H1r) < 1e-8] = 0
        G0r[np.abs(G0r) < 1e-8] = 0
        G1r[np.abs(G1r) < 1e-8] = 0
        self.max_item = 3

        self.A = sparseKernelFT1d(k, alpha, c)
        self.B = sparseKernelFT1d(k, alpha, c)
        self.C = sparseKernelFT1d(k, alpha, c)

        self.T0 = nn.Linear(k, k)

        self.register_buffer("ec_s", torch.Tensor(np.concatenate((H0.T, H1.T), axis=0)))
        self.register_buffer("ec_d", torch.Tensor(np.concatenate((G0.T, G1.T), axis=0)))

        self.register_buffer("rc_e", torch.Tensor(np.concatenate((H0r, G0r), axis=0)))
        self.register_buffer("rc_o", torch.Tensor(np.concatenate((H1r, G1r), axis=0)))

    def forward(self, x):
        B, N, c, k = x.shape  # (B, N, c, k)
        ns = math.floor(np.log2(N))
        nl = pow(2, math.ceil(np.log2(N)))
        extra_x = x[:, 0 : nl - N, :, :]
        x = torch.cat([x, extra_x], 1)
        Ud = torch.jit.annotate(List[Tensor], [])
        Us = torch.jit.annotate(List[Tensor], [])
        for i in range(ns - self.L):
            d, x = self.wavelet_transform(x)
            Ud += [self.A(d) + self.B(x)]
            Us += [self.C(d)]
        x = self.T0(x)  # coarsest scale transform

        # reconstruct
        for i in range(ns - 1 - self.L, -1, -1):
            x = x + Us[i]
            x = torch.cat((x, Ud[i]), -1)
            x = self.evenOdd(x)
        x = x[:, :N, :, :]

        return x

    def wavelet_transform(self, x):
        xa = torch.cat(
            [
                x[:, ::2, :, :],
                x[:, 1::2, :, :],
            ],
            -1,
        )
        d = torch.matmul(xa, self.ec_d)
        s = torch.matmul(xa, self.ec_s)
        return d, s

    def evenOdd(self, x):
        B, N, c, ich = x.shape  # (B, N, c, k)
        assert ich == 2 * self.k
        x_e = torch.matmul(x, self.rc_e)
        x_o = torch.matmul(x, self.rc_o)

        x = torch.zeros(B, N * 2, c, self.k, device=x.device)
        x[..., ::2, :, :] = x_e
        x[..., 1::2, :, :] = x_o
        return x
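

# Illustrative sketch (added, not upstream): one multiwavelet
# decomposition/reconstruction cycle. The forward pass pads n_steps up to the
# next power of two internally, so any length works, but a power of two keeps
# the example exact.
def _demo_mwt_cz1d() -> None:
    block = MWT_CZ1d(k=3, alpha=8, c=2, base="legendre")
    x = torch.randn(4, 16, 2, 3)  # (batch, n_steps, c, k)
    assert block(x).shape == x.shape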


class MultiWaveletTransform(AttentionOperator):
    """
    1D multiwavelet block.
    """

    def __init__(
        self,
        ich=1,
        k=8,
        alpha=16,
        c=128,
        nCZ=1,
        L=0,
        base="legendre",
        attention_dropout=0.1,
    ):
        super().__init__()
        self.k = k
        self.c = c
        self.L = L
        self.nCZ = nCZ
        self.Lk0 = nn.Linear(ich, c * k)
        self.Lk1 = nn.Linear(c * k, ich)
        self.ich = ich
        self.MWT_CZ = nn.ModuleList(MWT_CZ1d(k, alpha, L, c, base) for i in range(nCZ))

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, None]:
        # q, k, v all have 4 dimensions [batch_size, n_steps, n_heads, d_tensor]
        # d_tensor could be d_q, d_k, d_v

        B, L, H, E = q.shape
        _, S, _, D = v.shape
        if L > S:
            zeros = torch.zeros_like(q[:, : (L - S), :]).float()
            v = torch.cat([v, zeros], dim=1)
            # k = torch.cat([k, zeros], dim=1)
        else:
            v = v[:, :L, :, :]
            # k = k[:, :L, :, :]
        v = v.reshape(B, L, -1)

        V = self.Lk0(v).view(B, L, self.c, -1)
        for i in range(self.nCZ):
            V = self.MWT_CZ[i](V)
            if i < self.nCZ - 1:
                V = F.relu(V)

        V = self.Lk1(V.view(B, L, -1))
        V = V.view(B, L, -1, D)
        return V.contiguous(), None
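

# Illustrative usage sketch (added, not upstream): as an AttentionOperator,
# the block consumes (q, k, v) shaped (batch, n_steps, n_heads, d_tensor) and
# only actually transforms v; `ich` must equal n_heads * d_tensor.
# Hypothetical sizes below.
def _demo_multi_wavelet_transform() -> None:
    attn = MultiWaveletTransform(ich=16, k=4, alpha=8, c=4)
    q = torch.randn(2, 32, 8, 2)  # (batch, n_steps, n_heads, d_tensor)
    out, attn_weights = attn(q, q, q)
    assert out.shape == q.shape and attn_weights is None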


class FourierCrossAttentionW(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        seq_len_q,
        seq_len_kv,
        modes=16,
        activation="tanh",
        mode_select_method="random",
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes
        self.activation = activation

    def compl_mul1d(self, order, x, weights):
        x_flag = True
        w_flag = True
        if not torch.is_complex(x):
            x_flag = False
            x = torch.complex(x, torch.zeros_like(x).to(x.device))
        if not torch.is_complex(weights):
            w_flag = False
            weights = torch.complex(
                weights, torch.zeros_like(weights).to(weights.device)
            )
        if x_flag or w_flag:
            return torch.complex(
                torch.einsum(order, x.real, weights.real)
                - torch.einsum(order, x.imag, weights.imag),
                torch.einsum(order, x.real, weights.imag)
                + torch.einsum(order, x.imag, weights.real),
            )
        else:
            return torch.einsum(order, x.real, weights.real)

    def forward(self, q, k, v, mask):
        B, L, E, H = q.shape

        xq = q.permute(0, 3, 2, 1)  # size = [B, H, E, L]
        xk = k.permute(0, 3, 2, 1)
        xv = v.permute(0, 3, 2, 1)
        self.index_q = list(range(0, min(int(L // 2), self.modes1)))
        self.index_k_v = list(range(0, min(int(xv.shape[3] // 2), self.modes1)))

        # Compute Fourier coefficients
        xq_ft_ = torch.zeros(
            B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat
        )
        xq_ft = torch.fft.rfft(xq, dim=-1)
        for i, j in enumerate(self.index_q):
            xq_ft_[:, :, :, i] = xq_ft[:, :, :, j]

        xk_ft_ = torch.zeros(
            B, H, E, len(self.index_k_v), device=xq.device, dtype=torch.cfloat
        )
        xk_ft = torch.fft.rfft(xk, dim=-1)
        for i, j in enumerate(self.index_k_v):
            xk_ft_[:, :, :, i] = xk_ft[:, :, :, j]
        xqk_ft = self.compl_mul1d("bhex,bhey->bhxy", xq_ft_, xk_ft_)
        if self.activation == "tanh":
            xqk_ft = torch.complex(xqk_ft.real.tanh(), xqk_ft.imag.tanh())
        elif self.activation == "softmax":
            xqk_ft = torch.softmax(abs(xqk_ft), dim=-1)
            xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft))
        else:
            raise Exception(
                "{} activation function is not implemented".format(self.activation)
            )
        xqkv_ft = self.compl_mul1d("bhxy,bhey->bhex", xqk_ft, xk_ft_)

        xqkvw = xqkv_ft
        out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat)
        for i, j in enumerate(self.index_q):
            out_ft[:, :, :, j] = xqkvw[:, :, :, i]

        out = torch.fft.irfft(
            out_ft / self.in_channels / self.out_channels, n=xq.size(-1)
        ).permute(0, 3, 2, 1)
        # size = [B, L, H, E]
        return (out, None)


class MultiWaveletCross(AttentionOperator):
    """
    1D multiwavelet cross attention layer.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        seq_len_q,
        seq_len_kv,
        modes,
        c=64,
        k=8,
        ich=512,
        L=0,
        base="legendre",
        mode_select_method="random",
        initializer=None,
        activation="tanh",
        **kwargs,
    ):
        super().__init__()

        self.c = c
        self.k = k
        self.L = L
        H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)
        H0r = H0 @ PHI0
        G0r = G0 @ PHI0
        H1r = H1 @ PHI1
        G1r = G1 @ PHI1

        H0r[np.abs(H0r) < 1e-8] = 0
        H1r[np.abs(H1r) < 1e-8] = 0
        G0r[np.abs(G0r) < 1e-8] = 0
        G1r[np.abs(G1r) < 1e-8] = 0
        self.max_item = 3

        self.attn1 = FourierCrossAttentionW(
            in_channels=in_channels,
            out_channels=out_channels,
            seq_len_q=seq_len_q,
            seq_len_kv=seq_len_kv,
            modes=modes,
            activation=activation,
            mode_select_method=mode_select_method,
        )
        self.attn2 = FourierCrossAttentionW(
            in_channels=in_channels,
            out_channels=out_channels,
            seq_len_q=seq_len_q,
            seq_len_kv=seq_len_kv,
            modes=modes,
            activation=activation,
            mode_select_method=mode_select_method,
        )
        self.attn3 = FourierCrossAttentionW(
            in_channels=in_channels,
            out_channels=out_channels,
            seq_len_q=seq_len_q,
            seq_len_kv=seq_len_kv,
            modes=modes,
            activation=activation,
            mode_select_method=mode_select_method,
        )
        self.attn4 = FourierCrossAttentionW(
            in_channels=in_channels,
            out_channels=out_channels,
            seq_len_q=seq_len_q,
            seq_len_kv=seq_len_kv,
            modes=modes,
            activation=activation,
            mode_select_method=mode_select_method,
        )
        self.T0 = nn.Linear(k, k)
        self.register_buffer("ec_s", torch.Tensor(np.concatenate((H0.T, H1.T), axis=0)))
        self.register_buffer("ec_d", torch.Tensor(np.concatenate((G0.T, G1.T), axis=0)))

        self.register_buffer("rc_e", torch.Tensor(np.concatenate((H0r, G0r), axis=0)))
        self.register_buffer("rc_o", torch.Tensor(np.concatenate((H1r, G1r), axis=0)))

        self.Lk = nn.Linear(ich, c * k)
        self.Lq = nn.Linear(ich, c * k)
        self.Lv = nn.Linear(ich, c * k)
        self.out = nn.Linear(c * k, ich)
        self.modes1 = modes

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, None]:
        # q, k, v all have 4 dimensions [batch_size, n_steps, n_heads, d_tensor]
        # d_tensor could be d_q, d_k, d_v

        B, N, H, E = q.shape  # (B, N, H, E)
        _, S, _, _ = k.shape  # (B, S, H, E)

        q = q.view(q.shape[0], q.shape[1], -1)
        k = k.view(k.shape[0], k.shape[1], -1)
        v = v.view(v.shape[0], v.shape[1], -1)
        q = self.Lq(q)
        q = q.view(q.shape[0], q.shape[1], self.c, self.k)
        k = self.Lk(k)
        k = k.view(k.shape[0], k.shape[1], self.c, self.k)
        v = self.Lv(v)
        v = v.view(v.shape[0], v.shape[1], self.c, self.k)

        if N > S:
            zeros = torch.zeros_like(q[:, : (N - S), :]).float()
            v = torch.cat([v, zeros], dim=1)
            k = torch.cat([k, zeros], dim=1)
        else:
            v = v[:, :N, :, :]
            k = k[:, :N, :, :]

        ns = math.floor(np.log2(N))
        nl = pow(2, math.ceil(np.log2(N)))
        extra_q = q[:, 0 : nl - N, :, :]
        extra_k = k[:, 0 : nl - N, :, :]
        extra_v = v[:, 0 : nl - N, :, :]
        q = torch.cat([q, extra_q], 1)
        k = torch.cat([k, extra_k], 1)
        v = torch.cat([v, extra_v], 1)

        Ud_q = torch.jit.annotate(List[Tuple[Tensor]], [])
        Ud_k = torch.jit.annotate(List[Tuple[Tensor]], [])
        Ud_v = torch.jit.annotate(List[Tuple[Tensor]], [])

        Us_q = torch.jit.annotate(List[Tensor], [])
        Us_k = torch.jit.annotate(List[Tensor], [])
        Us_v = torch.jit.annotate(List[Tensor], [])

        Ud = torch.jit.annotate(List[Tensor], [])
        Us = torch.jit.annotate(List[Tensor], [])

        # decompose
        for i in range(ns - self.L):
            d, q = self.wavelet_transform(q)
            Ud_q += [tuple([d, q])]
            Us_q += [d]
        for i in range(ns - self.L):
            d, k = self.wavelet_transform(k)
            Ud_k += [tuple([d, k])]
            Us_k += [d]
        for i in range(ns - self.L):
            d, v = self.wavelet_transform(v)
            Ud_v += [tuple([d, v])]
            Us_v += [d]
        for i in range(ns - self.L):
            dk, sk = Ud_k[i], Us_k[i]
            dq, sq = Ud_q[i], Us_q[i]
            dv, sv = Ud_v[i], Us_v[i]
            Ud += [
                self.attn1(dq[0], dk[0], dv[0], attn_mask)[0]
                + self.attn2(dq[1], dk[1], dv[1], attn_mask)[0]
            ]
            Us += [self.attn3(sq, sk, sv, attn_mask)[0]]
        v = self.attn4(q, k, v, attn_mask)[0]

        # reconstruct
        for i in range(ns - 1 - self.L, -1, -1):
            v = v + Us[i]
            v = torch.cat((v, Ud[i]), -1)
            v = self.evenOdd(v)
        v = self.out(v[:, :N, :, :].contiguous().view(B, N, -1))
        return v.contiguous(), None

    def wavelet_transform(self, x):
        xa = torch.cat(
            [
                x[:, ::2, :, :],
                x[:, 1::2, :, :],
            ],
            -1,
        )
        d = torch.matmul(xa, self.ec_d)
        s = torch.matmul(xa, self.ec_s)
        return d, s

    def evenOdd(self, x):
        B, N, c, ich = x.shape  # (B, N, c, k)
        assert ich == 2 * self.k
        x_e = torch.matmul(x, self.rc_e)
        x_o = torch.matmul(x, self.rc_o)

        x = torch.zeros(B, N * 2, c, self.k, device=x.device)
        x[..., ::2, :, :] = x_e
        x[..., 1::2, :, :] = x_o
        return x

742

743
def get_frequency_modes(seq_len, modes=64, mode_select_method="random"):
6✔
744
    """
745
    get modes on frequency domain:
746
    'random' means sampling randomly;
747
    'else' means sampling the lowest modes;
748
    """
749
    modes = min(modes, seq_len // 2)
×
750
    if mode_select_method == "random":
×
751
        index = list(range(0, seq_len // 2))
×
752
        np.random.shuffle(index)
×
753
        index = index[:modes]
×
754
    else:
755
        index = list(range(0, modes))
×
756
    index.sort()
×
757
    return index
×
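

# Illustrative sketch (added, not upstream): the mode selection used by the
# Fourier blocks below; indices are always returned sorted.
def _demo_get_frequency_modes() -> None:
    random_modes = get_frequency_modes(seq_len=96, modes=8, mode_select_method="random")
    assert len(random_modes) == 8 and random_modes == sorted(random_modes)
    lowest_modes = get_frequency_modes(seq_len=96, modes=8, mode_select_method="lowest")
    assert lowest_modes == list(range(8))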


# ########## fourier layer #############
class FourierBlock(AttentionOperator):
    """
    1D Fourier block. It performs representation learning in the frequency
    domain: it does FFT, a linear transform, and an inverse FFT.
    """

    def __init__(
        self, in_channels, out_channels, seq_len, modes=0, mode_select_method="random"
    ):
        super().__init__()
        # get modes on frequency domain
        self.index = get_frequency_modes(
            seq_len, modes=modes, mode_select_method=mode_select_method
        )

        self.scale = 1 / (in_channels * out_channels)
        self.weights1 = nn.Parameter(
            self.scale
            * torch.rand(
                8,
                in_channels // 8,
                out_channels // 8,
                len(self.index),
                dtype=torch.cfloat,
            )
        )

    # Complex multiplication
    def compl_mul1d(self, input, weights):
        # (batch, in_channel, x), (in_channel, out_channel, x) -> (batch, out_channel, x)
        return torch.einsum("bhi,hio->bho", input, weights)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, None]:
        # q, k, v all have 4 dimensions [batch_size, n_steps, n_heads, d_tensor]
        # d_tensor could be d_q, d_k, d_v

        B, L, H, E = q.shape
        x = q.permute(0, 2, 3, 1)
        # Compute Fourier coefficients
        x_ft = torch.fft.rfft(x, dim=-1)
        # Perform Fourier neural operations
        out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat)
        for wi, i in enumerate(self.index):
            out_ft[:, :, :, wi] = self.compl_mul1d(
                x_ft[:, :, :, i], self.weights1[:, :, :, wi]
            )
        # Return to time domain
        x = torch.fft.irfft(out_ft, n=x.size(-1))
        return x, None
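

# Illustrative usage sketch (added, not upstream): FourierBlock only uses q and
# returns its output still shaped (batch, n_heads, d_tensor, n_steps), i.e. not
# permuted back; in_channels must be divisible by the hard-coded 8 heads.
# Hypothetical sizes below.
def _demo_fourier_block() -> None:
    block = FourierBlock(in_channels=16, out_channels=16, seq_len=32, modes=4)
    q = torch.randn(2, 32, 8, 2)  # (batch, n_steps, n_heads, d_tensor)
    out, _ = block(q, q, q)
    assert out.shape == (2, 8, 2, 32)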


# ########## Fourier Cross Former ####################
class FourierCrossAttention(AttentionOperator):
    """
    1D Fourier cross attention layer. It does FFT, a linear transform, an
    attention mechanism, and an inverse FFT.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        seq_len_q,
        seq_len_kv,
        modes=64,
        mode_select_method="random",
        activation="tanh",
        policy=0,
        num_heads=8,
    ):
        super().__init__()
        self.activation = activation
        self.in_channels = in_channels
        self.out_channels = out_channels
        # get modes for queries and keys (& values) on frequency domain
        self.index_q = get_frequency_modes(
            seq_len_q, modes=modes, mode_select_method=mode_select_method
        )
        self.index_kv = get_frequency_modes(
            seq_len_kv, modes=modes, mode_select_method=mode_select_method
        )

        self.scale = 1 / (in_channels * out_channels)
        self.weights1 = nn.Parameter(
            self.scale
            * torch.rand(
                num_heads,
                in_channels // num_heads,
                out_channels // num_heads,
                len(self.index_q),
                dtype=torch.float,
            )
        )
        self.weights2 = nn.Parameter(
            self.scale
            * torch.rand(
                num_heads,
                in_channels // num_heads,
                out_channels // num_heads,
                len(self.index_q),
                dtype=torch.float,
            )
        )

    # Complex multiplication
    def compl_mul1d(self, order, x, weights):
        x_flag = True
        w_flag = True
        if not torch.is_complex(x):
            x_flag = False
            x = torch.complex(x, torch.zeros_like(x).to(x.device))
        if not torch.is_complex(weights):
            w_flag = False
            weights = torch.complex(
                weights, torch.zeros_like(weights).to(weights.device)
            )
        if x_flag or w_flag:
            return torch.complex(
                torch.einsum(order, x.real, weights.real)
                - torch.einsum(order, x.imag, weights.imag),
                torch.einsum(order, x.real, weights.imag)
                + torch.einsum(order, x.imag, weights.real),
            )
        else:
            return torch.einsum(order, x.real, weights.real)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, None]:
        # q, k, v all have 4 dimensions [batch_size, n_steps, n_heads, d_tensor]
        # d_tensor could be d_q, d_k, d_v

        B, L, H, E = q.shape
        xq = q.permute(0, 2, 3, 1)  # size = [B, H, E, L]
        xk = k.permute(0, 2, 3, 1)
        # xv = v.permute(0, 2, 3, 1)

        # Compute Fourier coefficients
        xq_ft_ = torch.zeros(
            B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat
        )
        xq_ft = torch.fft.rfft(xq, dim=-1)
        for i, j in enumerate(self.index_q):
            if j >= xq_ft.shape[3]:
                continue
            xq_ft_[:, :, :, i] = xq_ft[:, :, :, j]
        xk_ft_ = torch.zeros(
            B, H, E, len(self.index_kv), device=xq.device, dtype=torch.cfloat
        )
        xk_ft = torch.fft.rfft(xk, dim=-1)
        for i, j in enumerate(self.index_kv):
            if j >= xk_ft.shape[3]:
                continue
            xk_ft_[:, :, :, i] = xk_ft[:, :, :, j]

        # perform attention mechanism on frequency domain
        xqk_ft = self.compl_mul1d("bhex,bhey->bhxy", xq_ft_, xk_ft_)
        if self.activation == "tanh":
            xqk_ft = torch.complex(xqk_ft.real.tanh(), xqk_ft.imag.tanh())
        elif self.activation == "softmax":
            xqk_ft = torch.softmax(abs(xqk_ft), dim=-1)
            xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft))
        else:
            raise Exception(
                "{} activation function is not implemented".format(self.activation)
            )
        xqkv_ft = self.compl_mul1d("bhxy,bhey->bhex", xqk_ft, xk_ft_)
        xqkvw = self.compl_mul1d(
            "bhex,heox->bhox", xqkv_ft, torch.complex(self.weights1, self.weights2)
        )
        out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat)
        for i, j in enumerate(self.index_q):
            if i >= xqkvw.shape[3] or j >= out_ft.shape[3]:
                continue
            out_ft[:, :, :, j] = xqkvw[:, :, :, i]
        # Return to time domain
        out = torch.fft.irfft(
            out_ft / self.in_channels / self.out_channels, n=xq.size(-1)
        )
        return out, None
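

# Illustrative usage sketch (added, not upstream): frequency-domain cross
# attention between q and k (v is ignored here; k doubles as the values). As
# with FourierBlock, the output stays in (batch, n_heads, d_tensor, n_steps)
# layout. Hypothetical sizes below.
def _demo_fourier_cross_attention() -> None:
    attn = FourierCrossAttention(
        in_channels=16, out_channels=16, seq_len_q=32, seq_len_kv=32, modes=4
    )
    q = torch.randn(2, 32, 8, 2)  # (batch, n_steps, n_heads, d_tensor)
    k = torch.randn(2, 32, 8, 2)
    out, _ = attn(q, k, k)
    assert out.shape == (2, 8, 2, 32)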


class SeriesDecompositionMultiBlock(nn.Module):
    """
    Series decomposition block from FEDformer,
    i.e. series_decomp_multi from https://github.com/MAZiqing/FEDformer
    """

    def __init__(self, kernel_size):
        super().__init__()
        self.moving_avg = [MovingAvgBlock(kernel, stride=1) for kernel in kernel_size]
        self.layer = torch.nn.Linear(1, len(kernel_size))

    def forward(self, x):
        moving_mean = []
        for func in self.moving_avg:
            moving_avg = func(x)
            moving_mean.append(moving_avg.unsqueeze(-1))
        moving_mean = torch.cat(moving_mean, dim=-1)
        moving_mean = torch.sum(
            moving_mean * nn.Softmax(-1)(self.layer(x.unsqueeze(-1))), dim=-1
        )
        res = x - moving_mean
        return res, moving_mean
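

# Illustrative usage sketch (added, not upstream): decompose a series into a
# seasonal residual and a trend, the trend being a softmax-weighted mixture of
# moving averages. Odd kernel sizes are assumed here so that MovingAvgBlock
# preserves the sequence length.
def _demo_series_decomposition_multi() -> None:
    block = SeriesDecompositionMultiBlock(kernel_size=[5, 9])
    x = torch.randn(2, 48, 7)  # (batch, n_steps, n_features)
    res, trend = block(x)
    assert res.shape == x.shape and trend.shape == x.shape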