WenjieDu / PyPOTS, build 4649236213 (pending completion)

Pull Request #44 (GitHub): Make imputation models `val_X_intact` and `val_indicating_mask` be included in input `val_set` originally
Merge 0b894402b into 4646f5c3f

15 of 15 new or added lines in 6 files covered (100.0%)
2692 of 3170 relevant lines covered (84.92%), 0.85 hits per line

Source File: /pypots/imputation/transformer.py (97.78% covered)

"""
PyTorch Transformer model for the time-series imputation task.

Notes
-----
Partial implementation uses code from https://github.com/WenjieDu/SAITS.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: GPL-v3

from typing import Tuple, Union, Optional

import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from pypots.data.base import BaseDataset
from pypots.data.dataset_for_mit import DatasetForMIT
from pypots.imputation.base import BaseNNImputer
from pypots.utils.metrics import cal_mae


class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention."""

    def __init__(self, temperature: float, attn_dropout: float = 0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
        if attn_mask is not None:
            attn = attn.masked_fill(attn_mask == 1, -1e9)
        attn = self.dropout(F.softmax(attn, dim=-1))
        output = torch.matmul(attn, v)
        return output, attn
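
# Editor's sketch (not part of the original file): a shape walk-through of the
# scaled dot-product attention above. With batch b=2, heads n=4, length l=10,
# and d_k=8:
#     q, k, v: [2, 4, 10, 8]
#     (q / temperature) @ k.transpose(2, 3)  ->  attn: [2, 4, 10, 10]
#     softmax(attn, dim=-1) @ v              ->  output: [2, 4, 10, 8]
# i.e. output = softmax(QK^T / temperature) V, with dropout applied to the
# attention weights before aggregating the values.
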
class MultiHeadAttention(nn.Module):
    """Original Transformer multi-head attention."""

    def __init__(
        self,
        n_head: int,
        d_model: int,
        d_k: int,
        d_v: int,
        attn_dropout: float,
    ):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)

        self.attention = ScaledDotProductAttention(d_k**0.5, attn_dropout)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)

        # Pass through the pre-attention projection: b x lq x (n*dv)
        # Separate the different heads: b x lq x n x dv
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        # Transpose for the attention dot product: b x n x lq x dv
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if attn_mask is not None:
            # this mask is not generated per batch, so it needs broadcasting
            # on the batch and head axes
            attn_mask = attn_mask.unsqueeze(0).unsqueeze(1)

        v, attn_weights = self.attention(q, k, v, attn_mask)

        # Transpose to move the head dimension back: b x lq x n x dv
        # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
        v = v.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
        v = self.fc(v)
        return v, attn_weights
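
# Editor's sketch (not part of the original file): the projection/reshape flow
# in MultiHeadAttention.forward, assuming b=2, len=10, d_model=64, n_head=4,
# d_k=d_v=16 (so n_head * d_k == d_model here):
#     w_qs(q):  [2, 10, 64] -> view [2, 10, 4, 16] -> transpose [2, 4, 10, 16]
#     attention runs on all heads in parallel      -> [2, 4, 10, 16]
#     transpose + view merge the heads back        -> [2, 10, 64], then fc -> [2, 10, 64]
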
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_in: int, d_hid: int, dropout: float = 0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_in, d_hid)
        self.w_2 = nn.Linear(d_hid, d_in)
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.layer_norm(x)
        x = self.w_2(F.relu(self.w_1(x)))
        x = self.dropout(x)
        x += residual
        return x
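
# Editor's note (not part of the original file): PositionWiseFeedForward is the
# pre-LN residual form of the Transformer feed-forward block,
#     x + Dropout(W2(ReLU(W1(LayerNorm(x))))),
# expanding d_in -> d_hid and projecting back, e.g. 64 -> 128 -> 64.
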
class EncoderLayer(nn.Module):
    def __init__(
        self,
        d_time: int,
        d_feature: int,
        d_model: int,
        d_inner: int,
        n_head: int,
        d_k: int,
        d_v: int,
        dropout: float = 0.1,
        attn_dropout: float = 0.1,
        diagonal_attention_mask: bool = False,
    ):
        super().__init__()

        self.diagonal_attention_mask = diagonal_attention_mask
        self.d_time = d_time
        self.d_feature = d_feature

        self.layer_norm = nn.LayerNorm(d_model)
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, attn_dropout)
        self.dropout = nn.Dropout(dropout)
        self.pos_ffn = PositionWiseFeedForward(d_model, d_inner, dropout)

    def forward(self, enc_input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.diagonal_attention_mask:
            mask_time = torch.eye(self.d_time).to(enc_input.device)
        else:
            mask_time = None

        residual = enc_input
        # apply LayerNorm before the attention calculation, i.e. Pre-LN;
        # refer to the paper https://arxiv.org/abs/2002.04745
        enc_input = self.layer_norm(enc_input)
        enc_output, attn_weights = self.slf_attn(
            enc_input, enc_input, enc_input, attn_mask=mask_time
        )
        enc_output = self.dropout(enc_output)
        enc_output += residual

        enc_output = self.pos_ffn(enc_output)
        return enc_output, attn_weights
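
# Editor's sketch (not part of the original file): with diagonal_attention_mask
# enabled, torch.eye(d_time) marks each step's own position, and
# masked_fill(attn_mask == 1, -1e9) drives the diagonal attention weights to ~0
# after softmax, so every time step must be reconstructed from the other steps.
# For d_time = 3:
#     mask_time = torch.eye(3)  # [[1., 0., 0.],
#                               #  [0., 1., 0.],
#                               #  [0., 0., 1.]]
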
class PositionalEncoding(nn.Module):
    def __init__(self, d_hid: int, n_position: int = 200):
        super().__init__()
        # Not a parameter
        self.register_buffer(
            "pos_table", self._get_sinusoid_encoding_table(n_position, d_hid)
        )

    @staticmethod
    def _get_sinusoid_encoding_table(n_position: int, d_hid: int) -> torch.Tensor:
        """Sinusoid position encoding table."""

        def get_position_angle_vec(position):
            return [
                position / np.power(10000, 2 * (hid_j // 2) / d_hid)
                for hid_j in range(d_hid)
            ]

        sinusoid_table = np.array(
            [get_position_angle_vec(pos_i) for pos_i in range(n_position)]
        )
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
        return torch.FloatTensor(sinusoid_table).unsqueeze(0)

    def forward(self, x):
        return x + self.pos_table[:, : x.size(1)].clone().detach()
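
# Editor's sketch (not part of the original file): the sinusoid table implements
#     PE[pos, 2i]   = sin(pos / 10000^(2i / d_hid))
#     PE[pos, 2i+1] = cos(pos / 10000^(2i / d_hid))
# e.g. for d_hid = 4 and pos = 1 the angle vector is [1.0, 1.0, 0.01, 0.01],
# giving PE[1] = [sin(1.0), cos(1.0), sin(0.01), cos(0.01)].
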
class _TransformerEncoder(nn.Module):
    def __init__(
        self,
        n_layers: int,
        d_time: int,
        d_feature: int,
        d_model: int,
        d_inner: int,
        n_head: int,
        d_k: int,
        d_v: int,
        dropout: float,
        ORT_weight: float = 1,
        MIT_weight: float = 1,
    ):
        super().__init__()
        self.n_layers = n_layers
        actual_d_feature = d_feature * 2
        self.ORT_weight = ORT_weight
        self.MIT_weight = MIT_weight

        self.layer_stack = nn.ModuleList(
            [
                EncoderLayer(
                    d_time,
                    actual_d_feature,
                    d_model,
                    d_inner,
                    n_head,
                    d_k,
                    d_v,
                    dropout,
                    0,
                    False,
                )
                for _ in range(n_layers)
            ]
        )

        self.embedding = nn.Linear(actual_d_feature, d_model)
        self.position_enc = PositionalEncoding(d_model, n_position=d_time)
        self.dropout = nn.Dropout(p=dropout)
        self.reduce_dim = nn.Linear(d_model, d_feature)

    def impute(self, inputs: dict) -> Tuple[torch.Tensor, torch.Tensor]:
        X, masks = inputs["X"], inputs["missing_mask"]
        input_X = torch.cat([X, masks], dim=2)
        input_X = self.embedding(input_X)
        enc_output = self.dropout(self.position_enc(input_X))

        for encoder_layer in self.layer_stack:
            enc_output, _ = encoder_layer(enc_output)

        learned_presentation = self.reduce_dim(enc_output)
        imputed_data = (
            masks * X + (1 - masks) * learned_presentation
        )  # replace the non-missing part with the original data
        return imputed_data, learned_presentation
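
    # Editor's sketch (not part of the original file): the blend above keeps the
    # observed values and only fills in the holes. With mask m in {0, 1}
    # (1 = observed) and reconstruction r:
    #     imputed = m * x + (1 - m) * r
    # e.g. x = [5.0, 0.0], m = [1, 0], r = [4.9, 3.2]  ->  imputed = [5.0, 3.2]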
    def forward(self, inputs: dict) -> dict:
        X, masks = inputs["X"], inputs["missing_mask"]
        imputed_data, learned_presentation = self.impute(inputs)
        reconstruction_loss = cal_mae(learned_presentation, X, masks)

        # the imputation loss is needed in the training and validation stages,
        # but not in the test stage
        imputation_loss = cal_mae(
            learned_presentation, inputs["X_intact"], inputs["indicating_mask"]
        )

        loss = self.ORT_weight * reconstruction_loss + self.MIT_weight * imputation_loss

        return {
            "imputed_data": imputed_data,
            "reconstruction_loss": reconstruction_loss,
            "imputation_loss": imputation_loss,
            "loss": loss,
        }
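
# Editor's note (not part of the original file): the loss above mirrors the
# joint training objective of SAITS, from which this implementation borrows:
# the reconstruction term is the ORT (observed reconstruction task) MAE on
# positions observed in the input, and the imputation term is the MIT (masked
# imputation task) MAE on positions artificially masked out and recorded in
# `indicating_mask`:
#     loss = ORT_weight * MAE(r, X, missing_mask)
#          + MIT_weight * MAE(r, X_intact, indicating_mask)
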
class Transformer(BaseNNImputer):
    def __init__(
        self,
        n_steps: int,
        n_features: int,
        n_layers: int,
        d_model: int,
        d_inner: int,
        n_head: int,
        d_k: int,
        d_v: int,
        dropout: float,
        ORT_weight: int = 1,
        MIT_weight: int = 1,
        batch_size: int = 32,
        epochs: int = 100,
        patience: int = 10,
        learning_rate: float = 1e-3,
        weight_decay: float = 1e-5,
        num_workers: int = 0,
        device: Optional[Union[str, torch.device]] = None,
        tb_file_saving_path: str = None,
    ):
        super().__init__(
            batch_size,
            epochs,
            patience,
            learning_rate,
            weight_decay,
            num_workers,
            device,
            tb_file_saving_path,
        )

        self.n_steps = n_steps
        self.n_features = n_features
        # model hyper-parameters
        self.n_layers = n_layers
        self.d_model = d_model
        self.d_inner = d_inner
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
        self.dropout = dropout
        self.ORT_weight = ORT_weight
        self.MIT_weight = MIT_weight

        self.model = _TransformerEncoder(
            self.n_layers,
            self.n_steps,
            self.n_features,
            self.d_model,
            self.d_inner,
            self.n_head,
            self.d_k,
            self.d_v,
            self.dropout,
            self.ORT_weight,
            self.MIT_weight,
        )
        self.model = self.model.to(self.device)
        self._print_model_size()

    def _assemble_input_for_training(self, data: list) -> dict:
        """Assemble the given data into a dictionary for training input.

        Parameters
        ----------
        data : list,
            A list containing data fetched from Dataset by Dataloader.

        Returns
        -------
        inputs : dict,
            A Python dictionary containing the input data for model training.
        """

        indices, X_intact, X, missing_mask, indicating_mask = data

        inputs = {
            "X": X,
            "X_intact": X_intact,
            "missing_mask": missing_mask,
            "indicating_mask": indicating_mask,
        }

        return inputs

    def _assemble_input_for_validating(self, data: list) -> dict:
        """Assemble the given data into a dictionary for validating input.

        Notes
        -----
        The validating data assembling process is similar to that of the training
        data, except that only `X` and `missing_mask` are needed here.

        Parameters
        ----------
        data : list,
            A list containing data fetched from Dataset by Dataloader.

        Returns
        -------
        inputs : dict,
            A Python dictionary containing the input data for model validating.
        """
        indices, X, missing_mask = data

        inputs = {
            "X": X,
            "missing_mask": missing_mask,
        }

        return inputs

    def _assemble_input_for_testing(self, data: list) -> dict:
        """Assemble the given data into a dictionary for testing input.

        Notes
        -----
        The testing data assembling process is the same as the validating data
        assembling.

        Parameters
        ----------
        data : list,
            A list containing data fetched from Dataset by Dataloader.

        Returns
        -------
        inputs : dict,
            A Python dictionary containing the input data for model testing.
        """
        return self._assemble_input_for_validating(data)

    def fit(
        self,
        train_set: Union[dict, str],
        val_set: Optional[Union[dict, str]] = None,
        file_type: str = "h5py",
    ) -> None:
        """Train the imputer on the given data.

        Parameters
        ----------
        train_set : dict or str,
            The dataset for model training, should be a dictionary including the key 'X',
            or a path string locating a data file.
            If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
            which is time-series data for training and can contain missing values.
            If it is a path string, the path should point to a data file, e.g. an h5 file, which contains
            key-value pairs like a dict and has to include the key 'X'.

        val_set : dict or str,
            The dataset for model validating, should be a dictionary including the key 'X',
            or a path string locating a data file.
            If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
            which is time-series data for validating and can contain missing values.
            If it is a path string, the path should point to a data file, e.g. an h5 file, which contains
            key-value pairs like a dict and has to include the key 'X'.

        file_type : str, default = "h5py",
            The type of the given file if train_set and val_set are path strings.
        """

        training_set = DatasetForMIT(train_set, file_type)
        training_loader = DataLoader(
            training_set,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )
        if val_set is None:
            self._train_model(training_loader)
        else:
            if isinstance(val_set, str):
                with h5py.File(val_set, "r") as hf:
                    # Here we read the whole validation set from the file to mask a portion for validation.
                    # In PyPOTS, a file is usually used because the data is too big to fit in memory. However,
                    # the validation set generally shouldn't be too large. For example, if we have 1 billion
                    # samples for model training, we won't take 20% of them as the validation set, because we
                    # want as much data as possible for the training stage to enhance the model's generalization
                    # ability. Therefore, 100,000 representative samples will be enough to validate the model.
                    val_set = {
                        "X": hf["X"][:],
                        "X_intact": hf["X_intact"][:],
                        "indicating_mask": hf["indicating_mask"][:],
                    }

            val_set = BaseDataset(val_set)
            val_loader = DataLoader(
                val_set,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=self.num_workers,
            )
            self._train_model(training_loader, val_loader)

        self.model.load_state_dict(self.best_model_dict)
        self.model.eval()  # set the model to eval status to freeze it

    def impute(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray:
        """Impute missing values in the given data with the trained model.

        Parameters
        ----------
        X : dict or str,
            The data samples for testing, should be a dictionary including the key 'X' with array-like values
            of shape [n_samples, sequence length (time steps), n_features], or a path string locating a data
            file, e.g. an h5 file.

        file_type : str, default = "h5py",
            The type of the given file if X is a path string.

        Returns
        -------
        array-like, shape [n_samples, sequence length (time steps), n_features],
            Imputed data.
        """
        self.model.eval()  # set the model to eval status to freeze it
        test_set = BaseDataset(X, file_type)
        test_loader = DataLoader(
            test_set,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )
        imputation_collector = []

        with torch.no_grad():
            for idx, data in enumerate(test_loader):
                inputs = {"X": data[1], "missing_mask": data[2]}
                imputed_data, _ = self.model.impute(inputs)
                imputation_collector.append(imputed_data)

        imputation_collector = torch.cat(imputation_collector)
        return imputation_collector.cpu().detach().numpy()
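

if __name__ == "__main__":
    # Editor's usage sketch (not part of the original file): train the imputer on
    # random data with artificial missingness, then impute. Shapes follow the
    # docstrings above: X is [n_samples, n_steps, n_features]; all sizes here are
    # illustrative.
    n_samples, n_steps, n_features = 64, 24, 5
    X = np.random.randn(n_samples, n_steps, n_features)
    X[np.random.rand(*X.shape) < 0.1] = np.nan  # knock out ~10% of the values
    model = Transformer(
        n_steps=n_steps,
        n_features=n_features,
        n_layers=2,
        d_model=64,
        d_inner=128,
        n_head=4,
        d_k=16,
        d_v=16,
        dropout=0.1,
        epochs=3,
    )
    model.fit(train_set={"X": X})
    imputed = model.impute({"X": X})  # ndarray of shape [64, 24, 5], NaNs filled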