
WenjieDu / PyPOTS / build 4649236213 (GitHub, pending completion)

Pull Request #44: Make imputation models `val_X_intact` and `val_indicating_mask` be included in input `val_set` originally
Merge 0b894402b into 4646f5c3f

15 of 15 new or added lines in 6 files covered (100.0%)
2692 of 3170 relevant lines covered (84.92%)
0.85 hits per line
Source file: /pypots/imputation/base.py (78.89% covered)
Lines the tests never executed carry a trailing "# not covered" comment in the listing below.
"""
The base class for imputation models.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: GPL-v3


import os
import warnings
from abc import abstractmethod
from typing import Union, Optional

import numpy as np
import torch
from torch.utils.data import DataLoader

from pypots.base import BaseModel, BaseNNModel
from pypots.utils.logging import logger
from pypots.utils.metrics import cal_mae

try:
    import nni
except ImportError:
    pass


class BaseImputer(BaseModel):
    """Abstract class for all imputation models.

    Parameters
    ----------
    device : str or `torch.device`, default = None,
        The device for the model to run on.
        If not given, will try to use CUDA devices first, then CPUs. CUDA and CPU are so far the main devices
        for training ML models. Other devices like Google TPU and the Apple Silicon accelerator MPS may be
        added in the future.

    tb_file_saving_path : str, default = None,
        The path to save the tensorboard file, which contains the loss values recorded during training.
    """

    def __init__(
        self,
        device: Optional[Union[str, torch.device]] = None,
        tb_file_saving_path: Optional[str] = None,
    ):
        super().__init__(
            device,
            tb_file_saving_path,
        )

    @abstractmethod
    def fit(
        self,
        train_set: Union[dict, str],
        val_set: Optional[Union[dict, str]] = None,
        file_type: str = "h5py",
    ) -> None:
        """Train the imputer on the given data.

        Parameters
        ----------
        train_set : dict or str,
            The dataset for model training, either a dictionary including the key 'X',
            or a path string locating a data file.
            If it is a dict, X should be array-like with shape [n_samples, sequence length (time steps),
            n_features], i.e. time-series data for training, which may contain missing values.
            If it is a path string, it should point to a data file, e.g. an h5 file, which contains
            key-value pairs like a dict and must include the key 'X'.

        val_set : dict or str,
            The dataset for model validating, either a dictionary including the key 'X',
            or a path string locating a data file.
            If it is a dict, X should be array-like with shape [n_samples, sequence length (time steps),
            n_features], i.e. time-series data for validating, which may contain missing values.
            If it is a path string, it should point to a data file, e.g. an h5 file, which contains
            key-value pairs like a dict and must include the key 'X'.

        file_type : str, default = "h5py",
            The type of the given file if train_set and val_set are path strings.

        """
        pass  # not covered
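
    # A minimal usage sketch (not part of base.py), assuming the shapes described in the
    # docstring above. `X` holds NaNs at the missing positions; `model` stands for any
    # concrete BaseImputer subclass instance. Which extra keys a validation set needs
    # (e.g. artificially-masked ground truth) depends on the concrete model's Dataset
    # class, so treat everything here as illustrative:
    #
    #     import numpy as np
    #     X = np.random.randn(100, 24, 10)             # 100 samples, 24 time steps, 10 features
    #     X[np.random.rand(*X.shape) < 0.1] = np.nan   # toy 10% missingness
    #     train_set = {"X": X}
    #     model.fit(train_set)                         # or model.fit("path/to/data.h5")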

    @abstractmethod
    def impute(
        self,
        X: Union[dict, str],
        file_type: str = "h5py",
    ) -> np.ndarray:
        """Impute missing values in the given data with the trained model.

        Parameters
        ----------
        X : dict or str,
            The data samples for testing, either a dictionary including the key 'X'
            (array-like with shape [n_samples, sequence length (time steps), n_features]),
            or a path string locating a data file, e.g. an h5 file.

        file_type : str, default = "h5py",
            The type of the given file if X is a path string.

        Returns
        -------
        array-like, shape [n_samples, sequence length (time steps), n_features],
            Imputed data.
        """
        pass  # not covered
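
    # A hedged end-to-end sketch (not part of base.py) of driving a concrete subclass of
    # this interface. SAITS is a real model in pypots.imputation, but the hyperparameter
    # names and values below are illustrative assumptions; check the model's own
    # signature before copying:
    #
    #     from pypots.imputation import SAITS
    #     saits = SAITS(n_steps=24, n_features=10, n_layers=2, d_model=256,
    #                   d_inner=128, n_head=4, d_k=64, d_v=64, dropout=0.1, epochs=10)
    #     saits.fit(train_set, val_set)       # train, with early stopping on the val loss
    #     imputed = saits.impute(test_set)    # ndarray [n_samples, n_steps, n_features]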


class BaseNNImputer(BaseNNModel, BaseImputer):
    """Abstract class for all neural-network imputation models.

    Parameters
    ----------
    batch_size : int,
        Size of the batch input into the model for one step.

    epochs : int,
        Training epochs, i.e. the maximum number of rounds the model will be trained for.

    patience : int,
        Number of epochs to keep training for once the loss stops decreasing.
        Once this number is exceeded, training stops early.

    learning_rate : float,
        The learning rate of the optimizer.

    weight_decay : float,
        The weight decay of the optimizer.

    num_workers : int, default = 0,
        The number of subprocesses to use for data loading.

    device : str or `torch.device`, default = None,
        The device for the model to run on.
        If not given, will try to use CUDA devices first, then CPUs. CUDA and CPU are so far the main devices
        for training ML models. Other devices like Google TPU and the Apple Silicon accelerator MPS may be
        added in the future.

    tb_file_saving_path : str, default = None,
        The path to save the tensorboard file, which contains the loss values recorded during training.
    """

    def __init__(
        self,
        batch_size: int,
        epochs: int,
        patience: int,
        learning_rate: float,
        weight_decay: float,
        num_workers: int = 0,
        device: Optional[Union[str, torch.device]] = None,
        tb_file_saving_path: Optional[str] = None,
    ):
        super().__init__(
            batch_size,
            epochs,
            patience,
            learning_rate,
            weight_decay,
            num_workers,
            device,
            tb_file_saving_path,
        )

    @abstractmethod
    def _assemble_input_for_training(self, data: list) -> dict:
        """Assemble the given data into a dictionary for training input.

        Parameters
        ----------
        data : list,
            Input data from the dataloader, should be a list.

        Returns
        -------
        dict,
            A Python dictionary containing the input data for model training.
        """
        pass  # not covered

    @abstractmethod
    def _assemble_input_for_validating(self, data: list) -> dict:
        """Assemble the given data into a dictionary for validating input.

        Parameters
        ----------
        data : list,
            Data output from the dataloader, should be a list.

        Returns
        -------
        dict,
            A Python dictionary containing the input data for model validating.
        """
        pass  # not covered

    @abstractmethod
    def _assemble_input_for_testing(self, data: list) -> dict:
        """Assemble the given data into a dictionary for testing input.

        Notes
        -----
        The processing functions of the train/val/test stages are separated for the situation where the
        inputs of the three stages differ. This usually happens when the Dataset/Dataloader classes used
        in the train/val/test stages are not the same, e.g. the training and validating data in a
        classification task contain labels, but the testing data (from the production environment)
        generally doesn't; see the sketch after this method.

        Parameters
        ----------
        data : list,
            Data output from the dataloader, should be a list.

        Returns
        -------
        dict,
            A Python dictionary containing the input data for model testing.
        """
        pass  # not covered
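
    # A hedged sketch (not part of base.py) of how a subclass typically fills in these
    # three hooks. The field names and their order follow the validation loop in
    # _train_model below, but a real model's Dataset may differ, so treat them as
    # assumptions:
    #
    #     def _assemble_input_for_training(self, data: list) -> dict:
    #         indices, X, missing_mask, X_intact, indicating_mask = data
    #         return {"X": X, "missing_mask": missing_mask,
    #                 "X_intact": X_intact, "indicating_mask": indicating_mask}
    #
    #     def _assemble_input_for_testing(self, data: list) -> dict:
    #         indices, X, missing_mask = data  # production data has no held-out truth
    #         return {"X": X, "missing_mask": missing_mask}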

    def _train_model(
        self,
        training_loader: DataLoader,
        val_loader: DataLoader = None,
    ) -> None:
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )

        # each training run starts from the very beginning, so reset the loss and model dict here
        self.best_loss = float("inf")
        self.best_model_dict = None

        try:
            for epoch in range(self.epochs):
                self.model.train()
                epoch_train_loss_collector = []
                for idx, data in enumerate(training_loader):
                    inputs = self._assemble_input_for_training(data)
                    self.optimizer.zero_grad()
                    results = self.model.forward(inputs)
                    results["loss"].backward()
                    self.optimizer.step()
                    epoch_train_loss_collector.append(results["loss"].item())

                # mean training loss of the current epoch
                mean_train_loss = np.mean(epoch_train_loss_collector)
                self.logger["training_loss"].append(mean_train_loss)

                if val_loader is not None:
                    self.model.eval()
                    imputation_collector = []
                    with torch.no_grad():
                        for idx, data in enumerate(val_loader):
                            inputs = self._assemble_input_for_validating(data)
                            imputed_data, _ = self.model.impute(inputs)
                            imputation_collector.append(imputed_data)

                    imputation_collector = torch.cat(imputation_collector)
                    # move to CPU first in case the model ran on a CUDA device
                    imputation_collector = imputation_collector.cpu().numpy()

                    mean_val_loss = cal_mae(
                        imputation_collector,
                        # val_loader.dataset.data is a dict containing the validation dataset
                        val_loader.dataset.data["X_intact"],
                        val_loader.dataset.data["indicating_mask"],
                    )
                    self.logger["validating_loss"].append(mean_val_loss)
                    logger.info(
                        f"epoch {epoch}: training loss {mean_train_loss:.4f}, validating loss {mean_val_loss:.4f}"
                    )
                    mean_loss = mean_val_loss
                else:
                    logger.info(f"epoch {epoch}: training loss {mean_train_loss:.4f}")  # not covered
                    mean_loss = mean_train_loss  # not covered

                if mean_loss < self.best_loss:
                    self.best_loss = mean_loss
                    self.best_model_dict = self.model.state_dict()
                    self.patience = self.original_patience
                else:
                    self.patience -= 1  # not covered

                if os.getenv("enable_nni", False):
                    nni.report_intermediate_result(mean_loss)  # not covered
                    if epoch == self.epochs - 1 or self.patience == 0:  # not covered
                        nni.report_final_result(self.best_loss)  # not covered

                if self.patience == 0:
                    logger.info(  # not covered
                        "Exceeded the training patience. Terminating the training procedure..."
                    )
                    break  # not covered

        except Exception as e:  # not covered
            logger.info(f"Exception: {e}")  # not covered
            if self.best_model_dict is None:  # not covered
                raise RuntimeError(  # not covered
                    "Training got interrupted. The model was not trained. Please try fit() again."
                )
            else:
                warnings.warn(  # not covered
                    "Training got interrupted. "
                    "The model will load the best parameters so far for testing. "
                    "If you don't want this, please try fit() again.",
                    RuntimeWarning,
                )

        if np.isinf(self.best_loss):
            raise ValueError("Something is wrong. best_loss is inf after training.")  # not covered

        logger.info("Finished training.")
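
The validation loss above is a masked MAE: cal_mae compares the model's imputations with X_intact (the ground truth before artificial masking) only at the positions flagged by indicating_mask, so values that were genuinely missing never enter the metric. A minimal sketch of that computation, assuming the standard masked-MAE definition rather than pypots' exact implementation:

    import numpy as np

    def masked_mae(predictions: np.ndarray, targets: np.ndarray, mask: np.ndarray) -> float:
        """MAE over the artificially held-out positions only (mask == 1)."""
        return float(np.sum(np.abs(predictions - targets) * mask) / (np.sum(mask) + 1e-12))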