• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

WenjieDu / PyPOTS / 3911954423

pending completion
3911954423

push

github

Wenjie Du
fix: add the dependencies of PyPOTS into the doc building requirement file;

2110 of 2800 relevant lines covered (75.36%)

0.76 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

28.99
/pypots/clustering/base.py
1
"""
1✔
2
The base class for clustering models.
3
"""
4

5
# Created by Wenjie Du <wenjay.du@gmail.com>
6
# License: GPL-v3
7

8

9
import copy
import warnings
from abc import abstractmethod

import numpy as np
import torch

from pypots.base import BaseModel, BaseNNModel
1✔
15

16

17
class BaseClusterer(BaseModel):
    """Abstract base class that every clustering model in this package derives from."""

    def __init__(self, device):
        # All device handling is delegated to the parent class.
        super().__init__(device)

    @abstractmethod
    def fit(self, train_X):
        """Fit the clustering model on the given data.

        Parameters
        ----------
        train_X : array-like of shape [n_samples, sequence length (time steps), n_features],
            Time-series data used for training; missing values are allowed.

        Returns
        -------
        self : object,
            The trained clusterer.
        """
        return self

    @abstractmethod
    def cluster(self, X):
        """Assign the input samples to clusters using the trained model.

        Parameters
        ----------
        X : array-like of shape [n_samples, sequence length (time steps), n_features],
            Time-series data, possibly containing missing values.

        Returns
        -------
        array-like,
            Clustering results. NOTE(review): an earlier version of this docstring
            gave the shape as [n_samples, sequence length (time steps), n_features];
            cluster assignments are presumably one per sample — confirm against
            the concrete implementations.
        """
        return None
54

55

56
class BaseNNClusterer(BaseNNModel, BaseClusterer):
    """Base class for clustering models backed by a PyTorch neural network.

    ``self.model`` is expected to be set before training (presumably by the
    subclass). It must support ``parameters()``, ``train()``, ``eval()`` and
    ``state_dict()``, and calling it must return a dict holding a scalar
    ``"loss"`` tensor. Subclasses must implement :meth:`assemble_input_data`.
    """

    def __init__(
        self,
        n_clusters,
        learning_rate,
        epochs,
        patience,
        batch_size,
        weight_decay,
        device,
    ):
        # The training hyper-parameters are forwarded up the MRO (to BaseNNModel).
        super().__init__(
            learning_rate, epochs, patience, batch_size, weight_decay, device
        )
        self.n_clusters = n_clusters

    @abstractmethod
    def assemble_input_data(self, data):
        """Turn one batch yielded by a data loader into the model's input.

        Parameters
        ----------
        data :
            One item yielded by the training or validation data loader.

        Returns
        -------
        object
            The input passed to ``self.model(...)``.
        """
        pass

    def _train_one_epoch(self, training_loader):
        """Run one optimisation pass over ``training_loader``; return the mean batch loss."""
        self.model.train()
        epoch_train_loss_collector = []
        for data in training_loader:
            inputs = self.assemble_input_data(data)
            self.optimizer.zero_grad()
            # Call the module (not .forward) so registered hooks are run.
            results = self.model(inputs)
            results["loss"].backward()
            self.optimizer.step()
            epoch_train_loss_collector.append(results["loss"].item())
        # np.mean of an empty list is NaN, which never counts as an improvement.
        return np.mean(epoch_train_loss_collector)

    def _validate_one_epoch(self, val_loader):
        """Return the mean batch loss over ``val_loader``, computed without gradients."""
        self.model.eval()
        epoch_val_loss_collector = []
        with torch.no_grad():
            for data in val_loader:
                inputs = self.assemble_input_data(data)
                results = self.model(inputs)
                epoch_val_loss_collector.append(results["loss"].item())
        return np.mean(epoch_val_loss_collector)

    def _train_model(self, training_loader, val_loader=None):
        """Train ``self.model`` with Adam and patience-based early stopping.

        The monitored loss is the mean validation loss when ``val_loader`` is
        given, otherwise the mean training loss. Each epoch's losses are
        appended to ``self.logger["training_loss"]`` and
        ``self.logger["validating_loss"]``. A copy of the weights from the
        best epoch is kept in ``self.best_model_dict``; this method does not
        load it back into the model.

        Parameters
        ----------
        training_loader : iterable
            Yields batches that are passed to :meth:`assemble_input_data`.
        val_loader : iterable, optional
            Validation batches; if None, the training loss is monitored instead.

        Raises
        ------
        RuntimeError
            If an exception interrupts training before any best model was recorded.
        ValueError
            If no epoch ever improved on an infinite loss (e.g. every epoch
            loss was NaN, or ``self.epochs`` is 0).
        """
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )

        # each training starts from the very beginning, so reset the loss, the
        # best model dict and the patience counter here
        self.best_loss = float("inf")
        self.best_model_dict = None
        # Without this reset, a previous early-stopped fit() would leave
        # patience at 0 and the counter could go negative, never stopping again.
        self.patience = self.original_patience

        try:
            for epoch in range(self.epochs):
                mean_train_loss = self._train_one_epoch(training_loader)
                self.logger["training_loss"].append(mean_train_loss)

                if val_loader is not None:
                    mean_val_loss = self._validate_one_epoch(val_loader)
                    self.logger["validating_loss"].append(mean_val_loss)
                    print(
                        f"epoch {epoch}: training loss {mean_train_loss:.4f}, validating loss {mean_val_loss:.4f}"
                    )
                    mean_loss = mean_val_loss
                else:
                    print(f"epoch {epoch}: training loss {mean_train_loss:.4f}")
                    mean_loss = mean_train_loss

                if mean_loss < self.best_loss:
                    self.best_loss = mean_loss
                    # state_dict() returns tensors sharing storage with the live
                    # parameters; deep-copy so later optimizer steps don't
                    # overwrite the snapshot of the best epoch.
                    self.best_model_dict = copy.deepcopy(self.model.state_dict())
                    self.patience = self.original_patience
                else:
                    # NOTE(review): a non-positive original_patience never hits 0
                    # here, which effectively disables early stopping.
                    self.patience -= 1
                    if self.patience == 0:
                        print(
                            "Exceeded the training patience. Terminating the training procedure..."
                        )
                        break
        except Exception as e:
            print(f"Exception: {e}")
            if self.best_model_dict is None:
                raise RuntimeError(
                    "Training got interrupted. Model was not get trained. Please try fit() again."
                ) from e
            # Best-effort recovery: keep the best weights found so far, but tell
            # the user (the original only constructed the warning object).
            warnings.warn(
                "Training got interrupted. "
                "Model will load the best parameters so far for testing. "
                "If you don't want it, please try fit() again.",
                RuntimeWarning,
            )

        # best_loss stays at +inf only if no epoch produced a loss below it,
        # e.g. when every epoch loss was NaN.
        if self.best_loss == float("inf"):
            raise ValueError("Something is wrong. best_loss is Nan after training.")

        print("Finished training.")
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc