• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

WenjieDu / PyPOTS / 8614163418

09 Apr 2024 10:24AM UTC coverage: 81.03% (+0.2%) from 80.813%
8614163418

Pull #343

github

web-flow
Merge 1fd684f5b into 93062a244
Pull Request #343: Apply SAITS embedding strategy to new added models

79 of 80 new or added lines in 10 files covered. (98.75%)

2 existing lines in 1 file now uncovered.

6847 of 8450 relevant lines covered (81.03%)

4.85 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.36
/pypots/imputation/dlinear/modules/core.py
1
"""
6✔
2

3
"""
4

5
# Created by Wenjie Du <wenjay.du@gmail.com>
6
# License: BSD-3-Clause
7

8
from typing import Optional
6✔
9

10
import torch
6✔
11
import torch.nn as nn
6✔
12

13
from ...autoformer.modules.submodules import SeriesDecompositionBlock
6✔
14
from ....utils.metrics import calc_mse
6✔
15

16

17
class _DLinear(nn.Module):
    """Core DLinear network used for imputation.

    ``SeriesDecompositionBlock`` splits the input into a seasonal and a trend
    component. Each component is mapped along the time axis by a linear layer
    of shape ``n_steps -> n_steps`` (initialised to a uniform average). The
    two results are summed to form the reconstruction.

    Two modes are supported:

    * ``individual=True``: one seasonal and one trend time-axis layer per
      feature. The network itself does not look at the missing mask; the mask
      is only used afterwards to merge observed values into the output.
      ``d_model`` is ignored in this mode.
    * ``individual=False``: one shared seasonal and one shared trend
      time-axis layer. Each component is first concatenated with the missing
      mask and embedded into ``d_model`` channels, then passed through the
      time-axis layer, then projected back to ``n_features`` channels.

    Parameters
    ----------
    n_steps :
        Number of time steps per sample (width of the time-axis layers).
    n_features :
        Number of features per time step.
    moving_avg_window_size :
        Window size forwarded to ``SeriesDecompositionBlock``.
    individual :
        Whether to use per-feature time-axis layers.
    d_model :
        Hidden size of the mask-aware embedding. Required when
        ``individual`` is False.

    Raises
    ------
    ValueError
        If ``individual`` is False and ``d_model`` is None.
    """

    def __init__(
        self,
        n_steps: int,
        n_features: int,
        moving_avg_window_size: int,
        individual: bool = False,
        d_model: Optional[int] = None,
    ):
        super().__init__()

        self.n_steps = n_steps
        self.n_features = n_features
        self.series_decomp = SeriesDecompositionBlock(moving_avg_window_size)
        self.individual = individual

        if individual:
            # Create linear layers for each feature individually. Seasonal and
            # trend layers are created alternately, as before, so seeded
            # runs draw the same random biases.
            self.linear_seasonal = nn.ModuleList()
            self.linear_trend = nn.ModuleList()
            for _ in range(n_features):
                self.linear_seasonal.append(self._averaging_linear(n_steps))
                self.linear_trend.append(self._averaging_linear(n_steps))
        else:
            if d_model is None:
                raise ValueError(
                    "The argument d_model is necessary for DLinear in the non-individual mode."
                )
            self.linear_seasonal = self._averaging_linear(n_steps)
            self.linear_trend = self._averaging_linear(n_steps)

            # Input is the component concatenated with the mask, hence 2 * n_features.
            self.linear_seasonal_embedding = nn.Linear(n_features * 2, d_model)
            self.linear_trend_embedding = nn.Linear(n_features * 2, d_model)
            self.linear_seasonal_output = nn.Linear(d_model, n_features)
            self.linear_trend_output = nn.Linear(d_model, n_features)

    @staticmethod
    def _averaging_linear(n_steps: int) -> nn.Linear:
        """Build an ``n_steps -> n_steps`` linear layer whose weight is filled with ``1 / n_steps``.

        The bias keeps ``nn.Linear``'s default initialisation.
        """
        layer = nn.Linear(n_steps, n_steps)
        layer.weight = nn.Parameter((1 / n_steps) * torch.ones([n_steps, n_steps]))
        return layer

    @staticmethod
    def _apply_per_feature(
        layers: nn.ModuleList, component: torch.Tensor
    ) -> torch.Tensor:
        """Apply ``layers[i]`` along the time axis of feature ``i``.

        The per-feature results are stacked back on the last dimension. This
        gives the same values as permuting, filling a preallocated buffer and
        permuting back. It avoids creating a host-side buffer and copying it
        to the device on every call.
        """
        return torch.stack(
            [layer(component[:, :, i]) for i, layer in enumerate(layers)], dim=2
        )

    @staticmethod
    def _mask_aware_projection(
        component: torch.Tensor,
        masks: torch.Tensor,
        embedding: nn.Linear,
        temporal: nn.Linear,
        output: nn.Linear,
    ) -> torch.Tensor:
        """Embed ``[component, masks]``, mix along the time axis, and project back to ``n_features``."""
        hidden = embedding(torch.cat([component, masks], dim=2))
        # The temporal layer acts on the time axis, so move it to the last dim and back.
        hidden = temporal(hidden.permute(0, 2, 1)).permute(0, 2, 1)
        return output(hidden)

    def forward(self, inputs: dict, training: bool = True) -> dict:
        """Run DLinear on ``inputs["X"]`` and impute its missing entries.

        Parameters
        ----------
        inputs :
            Dict with ``"X"`` and ``"missing_mask"``. Their shape is
            presumably ``[batch, n_steps, n_features]``, which is what the
            layers built in ``__init__`` are sized for. When ``training`` is
            True it must also contain ``"X_ori"`` and ``"indicating_mask"``
            for the loss.
        training :
            Whether to compute and return ``"loss"``.

        Returns
        -------
        dict
            ``"imputed_data"``: ``X`` where ``missing_mask`` is 1 and the
            model reconstruction elsewhere. ``"loss"`` (only when
            ``training``): ``calc_mse`` between the reconstruction and
            ``X_ori`` restricted by ``indicating_mask``.
        """
        X, masks = inputs["X"], inputs["missing_mask"]

        # input preprocessing and embedding for DLinear
        seasonal_init, trend_init = self.series_decomp(X)

        # DLinear processing
        if self.individual:
            seasonal_output = self._apply_per_feature(self.linear_seasonal, seasonal_init)
            trend_output = self._apply_per_feature(self.linear_trend, trend_init)
        else:
            # WDU: the original DLinear paper isn't proposed for imputation task. Hence the model doesn't take
            # the missing mask into account, which means, in the process, the model doesn't know which part of
            # the input data is missing, and this may hurt the model's imputation performance. Therefore, I add the
            # embedding layers to project the concatenation of features and masks into a hidden space, as well as
            # the output layers to project the seasonal and trend from the hidden space to the original space.
            # But this is only for the non-individual mode.
            seasonal_output = self._mask_aware_projection(
                seasonal_init,
                masks,
                self.linear_seasonal_embedding,
                self.linear_seasonal,
                self.linear_seasonal_output,
            )
            trend_output = self._mask_aware_projection(
                trend_init,
                masks,
                self.linear_trend_embedding,
                self.linear_trend,
                self.linear_trend_output,
            )

        output = seasonal_output + trend_output

        # Where masks == 1 keep the given X; elsewhere use the reconstruction.
        imputed_data = masks * X + (1 - masks) * output
        results = {
            "imputed_data": imputed_data,
        }

        if training:
            # `loss` is always the item for backward propagating to update the model
            results["loss"] = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"])

        return results
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc