WenjieDu / PyPOTS, build 12015227701 (push via GitHub, committed by web-flow)
25 Nov 2024 05:13PM UTC. Coverage: 84.286% (+0.6%) from 83.684%.

Merge pull request #550 from WenjieDu/dev
Update the stale workflow and docs, add SegRNN tests

12047 of 14293 relevant lines covered (84.29%), 4.94 hits per line.

Source file: /pypots/classification/base.py (75.68% covered)

"""
The base classes for PyPOTS classification models.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause


import os
from abc import abstractmethod
from typing import Optional, Union

import numpy as np
import torch
from torch.utils.data import DataLoader

from ..base import BaseModel, BaseNNModel
from ..utils.logging import logger

# NNI is only needed for hyperparameter tuning, so its absence is tolerated
try:
    import nni
except ImportError:
    pass


class BaseClassifier(BaseModel):
    """The abstract class for all PyPOTS classification models.

    Parameters
    ----------
    n_classes :
        The number of classes in the classification task.

    device :
        The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them.
        If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple),
        then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models.
        If given a list of devices, e.g. ['cuda:0', 'cuda:1'] or [torch.device('cuda:0'), torch.device('cuda:1')],
        the model will be trained in parallel on the given devices (so far only parallel training on CUDA devices
        is supported). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future.

    saving_path :
        The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during
        training into a tensorboard file). Will not save if not given.

    model_saving_strategy :
        The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"].
        No model will be saved when it is set as None.
        The "best" strategy will only automatically save the best model after the training finishes.
        The "better" strategy will automatically save the model during training whenever the model performs
        better than in previous epochs.
        The "all" strategy will save the model after every training epoch.

    verbose :
        Whether to print out the training logs during the training process.
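
    Examples
    --------
    A minimal usage sketch. ``SomeClassifier`` below is hypothetical, standing in for any concrete
    subclass, and the data shapes are illustrative::

        import numpy as np

        X_train = np.random.randn(100, 24, 10)  # 100 samples, 24 time steps, 10 features
        y_train = np.random.randint(0, 2, 100)  # one binary label per sample

        model = SomeClassifier(n_classes=2, saving_path="runs/clf")
        model.fit(train_set={"X": X_train, "y": y_train})
        y_pred = model.classify(test_set={"X": X_train})  # array of shape [n_samples]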
    """
57

58
    def __init__(
6✔
59
        self,
60
        n_classes: int,
61
        device: Optional[Union[str, torch.device, list]] = None,
62
        saving_path: str = None,
63
        model_saving_strategy: Optional[str] = "best",
64
        verbose: bool = True,
65
    ):
66
        super().__init__(
×
67
            device,
68
            saving_path,
69
            model_saving_strategy,
70
            verbose,
71
        )
72
        self.n_classes = n_classes
×
73

74
    @abstractmethod
    def fit(
        self,
        train_set: Union[dict, str],
        val_set: Optional[Union[dict, str]] = None,
        file_type: str = "hdf5",
    ) -> None:
        """Train the classifier on the given data.

        Parameters
        ----------
        train_set :
            The dataset for model training, should be a dictionary including the keys 'X' and 'y',
            or a path string locating a data file.
            If it is a dict, X should be array-like with shape [n_samples, sequence length (n_steps), n_features],
            which is the time-series data for training and may contain missing values, and y should be array-like
            with shape [n_samples], holding the classification labels of X.
            If it is a path string, the path should point to a data file, e.g. an h5 file, which contains
            key-value pairs like a dict and must include the keys 'X' and 'y'.

        val_set :
            The dataset for model validation, should be a dictionary including the keys 'X' and 'y',
            or a path string locating a data file.
            If it is a dict, X should be array-like with shape [n_samples, sequence length (n_steps), n_features],
            which is the time-series data for validation and may contain missing values, and y should be array-like
            with shape [n_samples], holding the classification labels of X.
            If it is a path string, the path should point to a data file, e.g. an h5 file, which contains
            key-value pairs like a dict and must include the keys 'X' and 'y'.

        file_type :
            The type of the given file if train_set and val_set are path strings.
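
        Examples
        --------
        A sketch of an accepted ``train_set`` dict (the shapes are illustrative)::

            import numpy as np

            train_set = {
                "X": np.random.randn(100, 24, 10),  # 100 samples, 24 time steps, 10 features
                "y": np.random.randint(0, 2, 100),  # one class label per sample
            }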
        """
107
        raise NotImplementedError
×
108

109
    @abstractmethod
    def predict(
        self,
        test_set: Union[dict, str],
        file_type: str = "hdf5",
    ) -> dict:
        raise NotImplementedError

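    # Note for implementers: concrete subclasses typically implement predict() to return a result
    # dict, and classify() as a thin wrapper that picks the class predictions out of it, e.g.
    # (an illustrative sketch; the result key name is an assumption):
    #
    #     def classify(self, test_set, file_type="hdf5"):
    #         result_dict = self.predict(test_set, file_type=file_type)
    #         return result_dict["classification"]
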
    @abstractmethod
    def classify(
        self,
        test_set: Union[dict, str],
        file_type: str = "hdf5",
    ) -> np.ndarray:
        """Classify the input data with the trained model.

        Parameters
        ----------
        test_set :
            The data samples for testing, should be a dictionary including the key 'X', or a path string
            locating a data file, e.g. an h5 file. X should be array-like with shape
            [n_samples, sequence length (n_steps), n_features].

        file_type :
            The type of the given file if test_set is a path string.

        Returns
        -------
        array-like, shape [n_samples],
            Classification results of the given samples.
        """
        raise NotImplementedError


class BaseNNClassifier(BaseNNModel):
    """The abstract class for all neural-network classification models in PyPOTS.

    Parameters
    ----------
    n_classes :
        The number of classes in the classification task.

    batch_size :
        Size of the batch input into the model for one step.

    epochs :
        Training epochs, i.e. the maximum rounds the model will be trained for.

    patience :
        Number of epochs the training procedure will keep going if the loss doesn't decrease.
        Once this number is exceeded, the training will stop.
        Must be smaller than or equal to the value of ``epochs``.

    num_workers :
        The number of subprocesses to use for data loading.
        `0` means data loading will be in the main process, i.e. there won't be subprocesses.

    device :
        The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them.
        If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple),
        then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models.
        If given a list of devices, e.g. ['cuda:0', 'cuda:1'] or [torch.device('cuda:0'), torch.device('cuda:1')],
        the model will be trained in parallel on the given devices (so far only parallel training on CUDA devices
        is supported). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future.

    saving_path :
        The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during
        training into a tensorboard file). Will not save if not given.

    model_saving_strategy :
        The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"].
        No model will be saved when it is set as None.
        The "best" strategy will only automatically save the best model after the training finishes.
        The "better" strategy will automatically save the model during training whenever the model performs
        better than in previous epochs.
        The "all" strategy will save the model after every training epoch.

    verbose :
        Whether to print out the training logs during the training process.

    Notes
    -----
    Optimizers are necessary for training deep-learning neural networks, but we don't put a parameter ``optimizer``
    here because some models (e.g. GANs) need more than one optimizer (e.g. one for the generator, one for the
    discriminator), so a single ``optimizer`` parameter would be ambiguous for them. Therefore, we leave optimizers
    as parameters for concrete model implementations, and you can pass any number of optimizers to your model when
    implementing it, :class:`pypots.clustering.crli.CRLI` for example.
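
    Examples
    --------
    A sketch of how a concrete subclass typically wires in its own optimizer(s); the attribute names
    and the ``pypots.optim.Adam`` wrapper are illustrative assumptions, not a fixed contract::

        from pypots.optim import Adam

        class MyClassifier(BaseNNClassifier):
            def __init__(self, n_classes, batch_size, epochs, **kwargs):
                super().__init__(n_classes, batch_size, epochs, **kwargs)
                self.model = ...  # build the torch.nn.Module here
                self.optimizer = Adam()  # could be any number of optimizers
                self.optimizer.init_optimizer(self.model.parameters())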
    """
198

199
    def __init__(
        self,
        n_classes: int,
        batch_size: int,
        epochs: int,
        patience: Optional[int] = None,
        num_workers: int = 0,
        device: Optional[Union[str, torch.device, list]] = None,
        saving_path: Optional[str] = None,
        model_saving_strategy: Optional[str] = "best",
        verbose: bool = True,
    ):
        super().__init__(
            batch_size,
            epochs,
            patience,
            num_workers,
            device,
            saving_path,
            model_saving_strategy,
            verbose,
        )
        self.n_classes = n_classes

    @abstractmethod
    def _assemble_input_for_training(self, data: list) -> dict:
        """Assemble the given data into a dictionary for training input.

        Parameters
        ----------
        data :
            Input data from the dataloader, should be a list.

        Returns
        -------
        dict,
            A python dictionary containing the input data for model training.
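
        Examples
        --------
        A sketch of a typical implementation in a concrete subclass; the unpacked fields and the
        ``_send_data_to_given_device`` helper are illustrative assumptions::

            def _assemble_input_for_training(self, data: list) -> dict:
                # each element of `data` is a batch tensor produced by the Dataset's __getitem__
                indices, X, missing_mask, label = self._send_data_to_given_device(data)
                return {"X": X, "missing_mask": missing_mask, "label": label}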
        """
237
        raise NotImplementedError
×
238

239
    @abstractmethod
    def _assemble_input_for_validating(self, data: list) -> dict:
        """Assemble the given data into a dictionary for validating input.

        Parameters
        ----------
        data :
            Data output from the dataloader, should be a list.

        Returns
        -------
        dict,
            A python dictionary containing the input data for model validating.
        """
        raise NotImplementedError

    @abstractmethod
    def _assemble_input_for_testing(self, data: list) -> dict:
        """Assemble the given data into a dictionary for testing input.

        Notes
        -----
        The processing functions of the train/val/test stages are separated because the inputs of the
        three stages can differ. This usually happens when the Dataset/Dataloader classes used in the
        train/val/test stages are not the same, e.g. the training and validating data in a classification
        task contain labels, but the testing data (from the production environment) generally doesn't
        have labels.

        Parameters
        ----------
        data :
            Data output from the dataloader, should be a list.

        Returns
        -------
        dict,
            A python dictionary containing the input data for model testing.
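
        Examples
        --------
        A sketch of a typical implementation, the counterpart of the training version above; the
        unpacked fields and the ``_send_data_to_given_device`` helper are illustrative assumptions::

            def _assemble_input_for_testing(self, data: list) -> dict:
                # test data from production generally carries no labels
                indices, X, missing_mask = self._send_data_to_given_device(data)
                return {"X": X, "missing_mask": missing_mask}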
        """
277
        raise NotImplementedError
×
278

279
    def _train_model(
        self,
        training_loader: DataLoader,
        val_loader: DataLoader = None,
    ) -> None:
        # each training starts from the very beginning, so reset the loss and model dict here
        self.best_loss = float("inf")
        self.best_model_dict = None

        try:
            training_step = 0
            for epoch in range(1, self.epochs + 1):
                self.model.train()
                epoch_train_loss_collector = []
                for idx, data in enumerate(training_loader):
                    training_step += 1
                    inputs = self._assemble_input_for_training(data)
                    self.optimizer.zero_grad()
                    results = self.model.forward(inputs)
                    results["loss"].sum().backward()
                    self.optimizer.step()
                    epoch_train_loss_collector.append(results["loss"].sum().item())

                    # save training loss logs into the tensorboard file for every step if in need
                    if self.summary_writer is not None:
                        self._save_log_into_tb_file(training_step, "training", results)

                # mean training loss of the current epoch
                mean_train_loss = np.mean(epoch_train_loss_collector)

                if val_loader is not None:
                    self.model.eval()
                    epoch_val_loss_collector = []
                    with torch.no_grad():
                        for idx, data in enumerate(val_loader):
                            inputs = self._assemble_input_for_validating(data)
                            results = self.model.forward(inputs)
                            epoch_val_loss_collector.append(results["loss"].sum().item())

                    mean_val_loss = np.mean(epoch_val_loss_collector)

                    # save validation loss logs into the tensorboard file for every epoch if in need
                    if self.summary_writer is not None:
                        val_loss_dict = {
                            "classification_loss": mean_val_loss,
                        }
                        self._save_log_into_tb_file(epoch, "validating", val_loss_dict)

                    logger.info(
                        f"Epoch {epoch:03d} - "
                        f"training loss: {mean_train_loss:.4f}, "
                        f"validation loss: {mean_val_loss:.4f}"
                    )
                    mean_loss = mean_val_loss
                else:
                    logger.info(f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}")
                    mean_loss = mean_train_loss

                if np.isnan(mean_loss):
                    logger.warning(f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors.")

                if mean_loss < self.best_loss:
                    self.best_epoch = epoch
                    self.best_loss = mean_loss
                    self.best_model_dict = self.model.state_dict()
                    # the model improved, so reset the early-stopping counter
                    self.patience = self.original_patience
                else:
                    self.patience -= 1

                # save the model if necessary
                self._auto_save_model_if_necessary(
                    confirm_saving=self.best_epoch == epoch and self.model_saving_strategy == "better",
                    saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss:.4f}",
                )

                if os.getenv("enable_tuning", False):
                    nni.report_intermediate_result(mean_loss)
                    # report the final result once the last epoch or early stopping is reached
                    if epoch == self.epochs or self.patience == 0:
                        nni.report_final_result(self.best_loss)

                if self.patience == 0:
                    logger.info("Exceeded the training patience. Terminating the training procedure...")
                    break

        except KeyboardInterrupt:  # on keyboard interrupt, only warn the user
            logger.warning("‼️ Training got interrupted by the user. Exiting now ...")
        except Exception as e:  # other kinds of exceptions follow the processing below
            logger.error(f"❌ Exception: {e}")
            if self.best_model_dict is None:  # if there is no best model yet, raise an error
                raise RuntimeError(
                    "Training got interrupted. Model was not trained. Please investigate the error printed above."
                )
            else:
                logger.warning(
                    "‼️ Training got interrupted. Please investigate the error printed above.\n"
                    "Model got trained and will load the best checkpoint so far for testing.\n"
                    "If you don't want it, please try fit() again."
                )

        if np.isnan(self.best_loss):
            raise ValueError("Something is wrong. best_loss is NaN after training.")

        logger.info(f"Finished training. The best model is from epoch#{self.best_epoch}.")
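
    # A sketch of how hyperparameter tuning hooks into _train_model() above. Assumption for
    # illustration: the NNI experiment config sets the `enable_tuning` environment variable that
    # the loop checks, e.g. the experiment is launched like
    #
    #     enable_tuning=1 nnictl create --config tuning_config.yml
    #
    # so that each epoch's loss is reported to NNI as an intermediate result, and the best loss
    # is reported as the final result once the last epoch or early stopping is reached.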

    @abstractmethod
    def fit(
        self,
        train_set: Union[dict, str],
        val_set: Optional[Union[dict, str]] = None,
        file_type: str = "hdf5",
    ) -> None:
        """Train the classifier on the given data.

        Parameters
        ----------
        train_set :
            The dataset for model training, should be a dictionary including the keys 'X' and 'y',
            or a path string locating a data file.
            If it is a dict, X should be array-like with shape [n_samples, sequence length (n_steps), n_features],
            which is the time-series data for training and may contain missing values, and y should be array-like
            with shape [n_samples], holding the classification labels of X.
            If it is a path string, the path should point to a data file, e.g. an h5 file, which contains
            key-value pairs like a dict and must include the keys 'X' and 'y'.

        val_set :
            The dataset for model validation, should be a dictionary including the keys 'X' and 'y',
            or a path string locating a data file.
            If it is a dict, X should be array-like with shape [n_samples, sequence length (n_steps), n_features],
            which is the time-series data for validation and may contain missing values, and y should be array-like
            with shape [n_samples], holding the classification labels of X.
            If it is a path string, the path should point to a data file, e.g. an h5 file, which contains
            key-value pairs like a dict and must include the keys 'X' and 'y'.

        file_type :
            The type of the given file if train_set and val_set are path strings.
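
        Examples
        --------
        A sketch of path-based loading; the paths are placeholders, and the h5 files are assumed to
        store 'X' and 'y' as dict-like key-value pairs::

            model.fit(train_set="path/to/train.h5", val_set="path/to/val.h5", file_type="hdf5")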

        """
        raise NotImplementedError

    @abstractmethod
    def predict(
        self,
        test_set: Union[dict, str],
        file_type: str = "hdf5",
    ) -> dict:
        raise NotImplementedError

    @abstractmethod
    def classify(
        self,
        test_set: Union[dict, str],
        file_type: str = "hdf5",
    ) -> np.ndarray:
        """Classify the input data with the trained model.

        Parameters
        ----------
        test_set :
            The data samples for testing, should be a dictionary including the key 'X', or a path string
            locating a data file, e.g. an h5 file. X should be array-like with shape
            [n_samples, sequence length (n_steps), n_features].

        file_type :
            The type of the given file if test_set is a path string.

        Returns
        -------
        array-like, shape [n_samples],
            Classification results of the given samples.
        """
        raise NotImplementedError