WenjieDu / PyPOTS, build 10736679932
06 Sep 2024 10:20AM UTC. Coverage: 83.273% (+0.2%) from 83.123%.

Pull Request #505: Add TEFN model (github / web-flow)
Merge bfc8a18e1 into 66da59c96

129 of 132 new or added lines in 8 files covered (97.73%).
2 existing lines in 2 files now uncovered.
11261 of 13523 relevant lines covered (83.27%), 4.99 hits per line.

Source file: /pypots/classification/base.py (75.68% covered)
"""
The base classes for PyPOTS classification models.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause


import os
from abc import abstractmethod
from typing import Optional, Union

import numpy as np
import torch
from torch.utils.data import DataLoader

from ..base import BaseModel, BaseNNModel
from ..utils.logging import logger

# nni is an optional dependency, only needed for hyper-parameter tuning
try:
    import nni
except ImportError:
    pass


class BaseClassifier(BaseModel):
    """The abstract class for all PyPOTS classification models.

    Parameters
    ----------
    n_classes :
        The number of classes in the classification task.

    device :
        The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them.
        If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple),
        then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models.
        If given a list of devices, e.g. ['cuda:0', 'cuda:1'] or [torch.device('cuda:0'), torch.device('cuda:1')],
        the model will be trained in parallel on the multiple devices (so far, parallel training is only supported
        on CUDA devices). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future.

    saving_path :
        The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during
        training into a tensorboard file). Will not save if not given.

    model_saving_strategy :
        The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"].
        No model will be saved when it is set as None.
        The "best" strategy will only automatically save the best model after the training finishes.
        The "better" strategy will automatically save the model during training whenever the model performs
        better than in previous epochs.
        The "all" strategy will save a model after every training epoch.

    verbose :
        Whether to print out the training logs during the training process.
    """

    def __init__(
        self,
        n_classes: int,
        device: Optional[Union[str, torch.device, list]] = None,
        saving_path: Optional[str] = None,
        model_saving_strategy: Optional[str] = "best",
        verbose: bool = True,
    ):
        super().__init__(
            device,
            saving_path,
            model_saving_strategy,
            verbose,
        )
        self.n_classes = n_classes
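
For illustration (not part of base.py): the device argument documented above accepts several equivalent forms. MyClassifier is a hypothetical concrete subclass standing in for any PyPOTS classifier.

    clf = MyClassifier(n_classes=2)                               # auto-select: CUDA if available, else CPU
    clf = MyClassifier(n_classes=2, device="cuda:0")              # a device string
    clf = MyClassifier(n_classes=2, device=torch.device("cpu"))   # a torch.device object
    clf = MyClassifier(n_classes=2, device=["cuda:0", "cuda:1"])  # parallel training on multiple CUDA devices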

    @abstractmethod
    def fit(
        self,
        train_set: Union[dict, str],
        val_set: Optional[Union[dict, str]] = None,
        file_type: str = "hdf5",
    ) -> None:
        """Train the classifier on the given data.

        Parameters
        ----------
        train_set :
            The dataset for model training, should be a dictionary including the keys 'X' and 'y',
            or a path string locating a data file.
            If it is a dict, X should be array-like with shape [n_samples, sequence length (n_steps), n_features],
            which is the time-series data for training and can contain missing values, and y should be array-like
            with shape [n_samples], which holds the classification labels of X.
            If it is a path string, the path should point to a data file, e.g. an h5 file, which contains
            key-value pairs like a dict, and it has to include the keys 'X' and 'y'.

        val_set :
            The dataset for model validation, should be a dictionary including the keys 'X' and 'y',
            or a path string locating a data file.
            If it is a dict, X should be array-like with shape [n_samples, sequence length (n_steps), n_features],
            which is the time-series data for validation and can contain missing values, and y should be array-like
            with shape [n_samples], which holds the classification labels of X.
            If it is a path string, the path should point to a data file, e.g. an h5 file, which contains
            key-value pairs like a dict, and it has to include the keys 'X' and 'y'.

        file_type :
            The type of the given file if train_set and val_set are path strings.

        """
        raise NotImplementedError
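
For illustration (not part of base.py): a minimal sketch of the train_set / val_set dict format described above, using numpy with NaN marking missing values.

    import numpy as np

    # 100 samples, 24 time steps (n_steps), 5 features; NaN marks missing values
    X = np.random.randn(100, 24, 5)
    X[X < -1.5] = np.nan                   # artificially introduce missingness
    y = np.random.randint(0, 2, size=100)  # one class label per sample

    train_set = {"X": X, "y": y}
    # alternatively, a path string to an h5 file holding the same 'X'/'y' keys:
    # train_set = "path/to/train_set.h5"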

    @abstractmethod
    def predict(
        self,
        test_set: Union[dict, str],
        file_type: str = "hdf5",
    ) -> dict:
        raise NotImplementedError

    @abstractmethod
    def classify(
        self,
        test_set: Union[dict, str],
        file_type: str = "hdf5",
    ) -> np.ndarray:
        """Classify the input data with the trained model.

        Parameters
        ----------
        test_set :
            The dataset for testing, should be a dictionary including the key 'X', whose value is array-like with
            shape [n_samples, sequence length (n_steps), n_features], or a path string locating a data file,
            e.g. an h5 file.

        file_type :
            The type of the given file if test_set is a path string.

        Returns
        -------
        array-like, shape [n_samples],
            Classification results of the given samples.
        """
        raise NotImplementedError

143
class BaseNNClassifier(BaseNNModel):
6✔
144
    """The abstract class for all neural-network classification models in PyPOTS.
6✔
145

146
    Parameters
147
    ----------
148
    n_classes :
149
        The number of classes in the classification task.
150

151
    batch_size :
152
        Size of the batch input into the model for one step.
153

154
    epochs :
155
        Training epochs, i.e. the maximum rounds of the model to be trained with.
156

157
    patience :
158
        Number of epochs the training procedure will keep if loss doesn't decrease.
159
        Once exceeding the number, the training will stop.
160
        Must be smaller than or equal to the value of ``epochs``.
161

162
    num_workers :
163
        The number of subprocesses to use for data loading.
164
        `0` means data loading will be in the main process, i.e. there won't be subprocesses.
165

166
    device :
167
        The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them.
168
        If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple),
169
        then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models.
170
        If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the
171
        model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices).
172
        Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future.
173

174
    saving_path :
175
        The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during
176
        training into a tensorboard file). Will not save if not given.
177

178
    model_saving_strategy :
179
        The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"].
180
        No model will be saved when it is set as None.
181
        The "best" strategy will only automatically save the best model after the training finished.
182
        The "better" strategy will automatically save the model during training whenever the model performs
183
        better than in previous epochs.
184
        The "all" strategy will save every model after each epoch training.
185

186
    verbose :
187
        Whether to print out the training logs during the training process.
188

189
    Notes
190
    -----
191
    Optimizers are necessary for training deep-learning neural networks, but we don't put  a parameter ``optimizer``
192
    here because some models (e.g. GANs) need more than one optimizer (e.g. one for generator, one for discriminator),
193
    and ``optimizer`` is ambiguous for them. Therefore, we leave optimizers as parameters for concrete model
194
    implementations, and you can pass any number of optimizers to your model when implementing it,
195
    :class:`pypots.clustering.crli.CRLI` for example.
196

197
    """
198

199
    def __init__(
6✔
200
        self,
201
        n_classes: int,
202
        batch_size: int,
203
        epochs: int,
204
        patience: Optional[int] = None,
205
        num_workers: int = 0,
206
        device: Optional[Union[str, torch.device, list]] = None,
207
        saving_path: str = None,
208
        model_saving_strategy: Optional[str] = "best",
209
        verbose: bool = True,
210
    ):
211
        super().__init__(
6✔
212
            batch_size,
213
            epochs,
214
            patience,
215
            num_workers,
216
            device,
217
            saving_path,
218
            model_saving_strategy,
219
            verbose,
220
        )
221
        self.n_classes = n_classes
6✔
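
For illustration (not part of base.py): a hedged sketch of how a concrete subclass handles the optimizer question raised in the Notes above. The class name MyNNClassifier is hypothetical, and pypots.optim.Adam with its init_optimizer() method is an assumption about PyPOTS' optimizer-wrapper API.

    import torch.nn as nn

    from pypots.optim import Adam  # assumed import path for PyPOTS' optimizer wrapper

    class MyNNClassifier(BaseNNClassifier):  # hypothetical concrete model
        def __init__(self, n_features, n_classes, batch_size, epochs, optimizer=None, **kwargs):
            super().__init__(n_classes, batch_size, epochs, **kwargs)
            # a trivial stand-in network; real models build their own architecture here
            self.model = nn.Linear(n_features, n_classes)
            # each concrete model declares its own optimizer(s), as the Notes explain
            self.optimizer = optimizer if optimizer is not None else Adam()
            self.optimizer.init_optimizer(self.model.parameters())  # assumed wrapper API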

    @abstractmethod
    def _assemble_input_for_training(self, data: list) -> dict:
        """Assemble the given data into a dictionary for training input.

        Parameters
        ----------
        data :
            Input data from the dataloader, should be a list.

        Returns
        -------
        dict,
            A Python dictionary containing the input data for model training.
        """
        raise NotImplementedError

    @abstractmethod
    def _assemble_input_for_validating(self, data: list) -> dict:
        """Assemble the given data into a dictionary for validating input.

        Parameters
        ----------
        data :
            Data output from the dataloader, should be a list.

        Returns
        -------
        dict,
            A Python dictionary containing the input data for model validation.
        """
        raise NotImplementedError

    @abstractmethod
    def _assemble_input_for_testing(self, data: list) -> dict:
        """Assemble the given data into a dictionary for testing input.

        Notes
        -----
        The processing functions of the train/val/test stages are separated for the situation where the inputs of
        the three stages are different, which usually happens when the Dataset/Dataloader classes used in the
        train/val/test stages are not the same, e.g. the training and validation data in a classification task
        contain labels, but the testing data (from the production environment) generally don't have labels.

        Parameters
        ----------
        data :
            Data output from the dataloader, should be a list.

        Returns
        -------
        dict,
            A Python dictionary containing the input data for model testing.
        """
        raise NotImplementedError
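
For illustration (not part of base.py): a hedged sketch of what a concrete implementation of these assembly hooks typically looks like. The tensor order yielded by the dataloader and the _send_data_to_given_device helper are assumptions, not guarantees of this base class.

    def _assemble_input_for_training(self, data: list) -> dict:
        # assumed tensor order from the Dataset; real models define their own
        indices, X, missing_mask, y = self._send_data_to_given_device(data)  # assumed helper
        return {
            "X": X,                        # time series, possibly containing missing values
            "missing_mask": missing_mask,  # 1 where observed, 0 where missing
            "y": y,                        # classification labels, train/val only
        }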

    def _train_model(
        self,
        training_loader: DataLoader,
        val_loader: DataLoader = None,
    ) -> None:
        # each training run starts from the very beginning, so reset the loss and model dict here
        self.best_loss = float("inf")
        self.best_model_dict = None

        try:
            training_step = 0
            for epoch in range(1, self.epochs + 1):
                self.model.train()
                epoch_train_loss_collector = []
                for idx, data in enumerate(training_loader):
                    training_step += 1
                    inputs = self._assemble_input_for_training(data)
                    self.optimizer.zero_grad()
                    results = self.model.forward(inputs)
                    results["loss"].sum().backward()
                    self.optimizer.step()
                    epoch_train_loss_collector.append(results["loss"].sum().item())

                    # save training loss logs into the tensorboard file for every step if needed
                    if self.summary_writer is not None:
                        self._save_log_into_tb_file(training_step, "training", results)

                # mean training loss of the current epoch
                mean_train_loss = np.mean(epoch_train_loss_collector)

                if val_loader is not None:
                    self.model.eval()
                    epoch_val_loss_collector = []
                    with torch.no_grad():
                        for idx, data in enumerate(val_loader):
                            inputs = self._assemble_input_for_validating(data)
                            results = self.model.forward(inputs)
                            epoch_val_loss_collector.append(results["loss"].sum().item())

                    mean_val_loss = np.mean(epoch_val_loss_collector)

                    # save validation loss logs into the tensorboard file for every epoch if needed
                    if self.summary_writer is not None:
                        val_loss_dict = {
                            "classification_loss": mean_val_loss,
                        }
                        self._save_log_into_tb_file(epoch, "validating", val_loss_dict)

                    logger.info(
                        f"Epoch {epoch:03d} - "
                        f"training loss: {mean_train_loss:.4f}, "
                        f"validation loss: {mean_val_loss:.4f}"
                    )
                    mean_loss = mean_val_loss
                else:
                    logger.info(
                        f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}"
                    )
                    mean_loss = mean_train_loss

                if np.isnan(mean_loss):
                    logger.warning(
                        f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors."
                    )

                if mean_loss < self.best_loss:
                    self.best_epoch = epoch
                    self.best_loss = mean_loss
                    self.best_model_dict = self.model.state_dict()
                    self.patience = self.original_patience
                else:
                    self.patience -= 1

                # save the model if necessary
                self._auto_save_model_if_necessary(
                    confirm_saving=self.best_epoch == epoch,
                    saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
                )

                if os.getenv("enable_tuning", False):
                    nni.report_intermediate_result(mean_loss)
                    # report the final result at the last epoch or when early stopping triggers
                    if epoch == self.epochs or self.patience == 0:
                        nni.report_final_result(self.best_loss)

                if self.patience == 0:
                    logger.info(
                        "Exceeded the training patience. Terminating the training procedure..."
                    )
                    break

        except KeyboardInterrupt:  # if interrupted from the keyboard, only warn
            logger.warning("‼️ Training got interrupted by the user. Exiting now ...")
        except Exception as e:  # other kinds of exceptions follow the processing below
            logger.error(f"❌ Exception: {e}")
            if self.best_model_dict is None:  # if there is no best model, raise the error
                raise RuntimeError(
                    "Training got interrupted. Model was not trained. Please investigate the error printed above."
                )
            else:
                logger.warning(
                    "‼️ Training got interrupted. Please investigate the error printed above.\n"
                    "Model got trained and will load the best checkpoint so far for testing.\n"
                    "If you don't want it, please try fit() again."
                )

        if np.isnan(self.best_loss):
            raise ValueError("Something is wrong. best_loss is NaN after training.")

        logger.info(
            f"Finished training. The best model is from epoch#{self.best_epoch}."
        )
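
For illustration (not part of base.py): the nni reporting above only runs when the enable_tuning environment variable is set. os.getenv returns a string once the variable exists, and any non-empty string is truthy, so a tuning harness would be expected to switch it on like this:

    import os

    os.environ["enable_tuning"] = "1"          # set by the tuning launcher before fit()
    assert os.getenv("enable_tuning", False)   # truthy, so intermediate/final results reach nni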

    @abstractmethod
    def fit(
        self,
        train_set: Union[dict, str],
        val_set: Optional[Union[dict, str]] = None,
        file_type: str = "hdf5",
    ) -> None:
        """Train the classifier on the given data.

        Parameters
        ----------
        train_set :
            The dataset for model training, should be a dictionary including the keys 'X' and 'y',
            or a path string locating a data file.
            If it is a dict, X should be array-like with shape [n_samples, sequence length (n_steps), n_features],
            which is the time-series data for training and can contain missing values, and y should be array-like
            with shape [n_samples], which holds the classification labels of X.
            If it is a path string, the path should point to a data file, e.g. an h5 file, which contains
            key-value pairs like a dict, and it has to include the keys 'X' and 'y'.

        val_set :
            The dataset for model validation, should be a dictionary including the keys 'X' and 'y',
            or a path string locating a data file.
            If it is a dict, X should be array-like with shape [n_samples, sequence length (n_steps), n_features],
            which is the time-series data for validation and can contain missing values, and y should be array-like
            with shape [n_samples], which holds the classification labels of X.
            If it is a path string, the path should point to a data file, e.g. an h5 file, which contains
            key-value pairs like a dict, and it has to include the keys 'X' and 'y'.

        file_type :
            The type of the given file if train_set and val_set are path strings.

        """
        raise NotImplementedError

    @abstractmethod
    def predict(
        self,
        test_set: Union[dict, str],
        file_type: str = "hdf5",
    ) -> dict:
        raise NotImplementedError

    @abstractmethod
    def classify(
        self,
        test_set: Union[dict, str],
        file_type: str = "hdf5",
    ) -> np.ndarray:
        """Classify the input data with the trained model.

        Parameters
        ----------
        test_set :
            The dataset for testing, should be a dictionary including the key 'X', whose value is array-like with
            shape [n_samples, sequence length (n_steps), n_features], or a path string locating a data file,
            e.g. an h5 file.

        file_type :
            The type of the given file if test_set is a path string.

        Returns
        -------
        array-like, shape [n_samples],
            Classification results of the given samples.
        """
        raise NotImplementedError
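
Putting it together, for illustration (not part of base.py): a hedged end-to-end sketch of how a concrete classifier built on these base classes is meant to be used. MyNNClassifier is the hypothetical subclass sketched earlier, assumed here to implement the abstract fit()/predict()/classify() methods per the docstrings above.

    import numpy as np

    # toy data: 200 samples, 24 time steps, 5 features, with NaN as missing values
    X = np.random.randn(200, 24, 5)
    X[np.random.rand(*X.shape) < 0.1] = np.nan
    y = np.random.randint(0, 2, size=200)

    train_set = {"X": X[:160], "y": y[:160]}
    val_set = {"X": X[160:], "y": y[160:]}

    model = MyNNClassifier(                  # hypothetical concrete subclass
        n_features=5,
        n_classes=2,
        batch_size=32,
        epochs=100,
        patience=10,                         # stop after 10 epochs without a new best loss
        saving_path="runs/classification",   # checkpoints and tensorboard logs land here
    )
    model.fit(train_set, val_set)            # runs _train_model() with early stopping
    labels = model.classify({"X": X[160:]})  # array-like of shape [n_samples]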