WenjieDu / PyPOTS, build 4649236213 (pending completion)

Pull Request #44 (GitHub): Make imputation models `val_X_intact` and `val_indicating_mask` be included in input `val_set` originally
Merge 0b894402b into 4646f5c3f

15 of 15 new or added lines in 6 files covered (100.0%)
2692 of 3170 relevant lines covered (84.92%)
0.85 hits per line

Source file: /pypots/clustering/crli.py (88.52% covered)

"""
Torch implementation of CRLI (Clustering Representation Learning on Incomplete time-series data).

Please refer to :cite:`ma2021CRLI`.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: GPL-v3

from typing import Tuple, Union, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.cluster import KMeans
from torch.utils.data import DataLoader

from pypots.clustering.base import BaseNNClusterer
from pypots.data.dataset_for_grud import DatasetForGRUD
from pypots.utils.logging import logger
from pypots.utils.metrics import cal_mse

RNN_CELL = {
    "LSTM": nn.LSTMCell,
    "GRU": nn.GRUCell,
}


def reverse_tensor(tensor_: torch.Tensor) -> torch.Tensor:
    """Reverse the given tensor along dimension 1 (the time-step dimension)."""
    if tensor_.dim() <= 1:
        return tensor_
    indices = range(tensor_.size()[1])[::-1]
    indices = torch.tensor(
        indices, dtype=torch.long, device=tensor_.device, requires_grad=False
    )
    return tensor_.index_select(1, indices)
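
# A minimal sketch of what ``reverse_tensor`` does (illustrative only, not part
# of the library):
# >>> t = torch.arange(6).reshape(1, 3, 2)  # (batch, time, features)
# >>> reverse_tensor(t)[0, :, 0]
# tensor([4, 2, 0])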


class MultiRNNCell(nn.Module):
    def __init__(
        self,
        cell_type: str,
        n_layer: int,
        d_input: int,
        d_hidden: int,
        device: Union[str, torch.device],
    ):
        super().__init__()
        self.cell_type = cell_type
        self.n_layer = n_layer
        self.d_input = d_input
        self.d_hidden = d_hidden
        self.device = device

        self.model = nn.ModuleList()
        if cell_type in ["LSTM", "GRU"]:
            for i in range(n_layer):
                if i == 0:
                    self.model.append(RNN_CELL[cell_type](d_input, d_hidden))
                else:
                    self.model.append(RNN_CELL[cell_type](d_hidden, d_hidden))

        self.output_layer = nn.Linear(d_hidden, d_input)

    def forward(self, inputs: dict) -> Tuple[torch.Tensor, torch.Tensor]:
        X, missing_mask = inputs["X"], inputs["missing_mask"]
        bz, n_steps, _ = X.shape
        hidden_state = torch.zeros((bz, self.d_hidden), device=self.device)
        hidden_state_collector = torch.empty(
            (bz, n_steps, self.d_hidden), device=self.device
        )
        output_collector = torch.empty((bz, n_steps, self.d_input), device=self.device)
        if self.cell_type == "LSTM":
            # each LSTM layer keeps its own cell state of shape (bz, d_hidden)
            cell_states = [
                torch.zeros((bz, self.d_hidden), device=self.device)
                for _ in range(self.n_layer)
            ]
            for step in range(n_steps):
                x = X[:, step, :]
                estimation = self.output_layer(hidden_state)
                output_collector[:, step] = estimation
                imputed_x = (
                    missing_mask[:, step] * x + (1 - missing_mask[:, step]) * estimation
                )
                for i in range(self.n_layer):
                    if i == 0:
                        hidden_state, cell_states[i] = self.model[i](
                            imputed_x, (hidden_state, cell_states[i])
                        )
                    else:
                        hidden_state, cell_states[i] = self.model[i](
                            hidden_state, (hidden_state, cell_states[i])
                        )
                hidden_state_collector[:, step, :] = hidden_state

        elif self.cell_type == "GRU":
            for step in range(n_steps):
                x = X[:, step, :]
                estimation = self.output_layer(hidden_state)
                output_collector[:, step] = estimation
                imputed_x = (
                    missing_mask[:, step] * x + (1 - missing_mask[:, step]) * estimation
                )
                for i in range(self.n_layer):
                    if i == 0:
                        hidden_state = self.model[i](imputed_x, hidden_state)
                    else:
                        hidden_state = self.model[i](hidden_state, hidden_state)

                hidden_state_collector[:, step, :] = hidden_state

        # drop the estimation made from the all-zero initial state and append
        # the one produced by the final hidden state
        output_collector = output_collector[:, 1:]
        estimation = self.output_layer(hidden_state).unsqueeze(1)
        output_collector = torch.concat([output_collector, estimation], dim=1)
        return output_collector, hidden_state
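
# A quick shape check for ``MultiRNNCell`` (illustrative only):
# >>> cell = MultiRNNCell("GRU", n_layer=2, d_input=4, d_hidden=8, device="cpu")
# >>> out, h = cell({"X": torch.randn(3, 5, 4), "missing_mask": torch.ones(3, 5, 4)})
# >>> out.shape, h.shape
# (torch.Size([3, 5, 4]), torch.Size([3, 8]))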


class Generator(nn.Module):
    def __init__(
        self,
        n_layers: int,
        n_features: int,
        d_hidden: int,
        cell_type: str,
        device: Union[str, torch.device],
    ):
        super().__init__()
        self.f_rnn = MultiRNNCell(cell_type, n_layers, n_features, d_hidden, device)
        self.b_rnn = MultiRNNCell(cell_type, n_layers, n_features, d_hidden, device)

    def forward(self, inputs: dict) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        f_outputs, f_final_hidden_state = self.f_rnn(inputs)
        b_outputs, b_final_hidden_state = self.b_rnn(inputs)
        b_outputs = reverse_tensor(b_outputs)  # reverse the output of the backward rnn
        imputation = (f_outputs + b_outputs) / 2
        imputed_X = inputs["X"] * inputs["missing_mask"] + imputation * (
            1 - inputs["missing_mask"]
        )
        fb_final_hidden_states = torch.concat(
            [f_final_hidden_state, b_final_hidden_state], dim=-1
        )
        return imputation, imputed_X, fb_final_hidden_states
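
# Shape sketch for ``Generator`` (illustrative only): the imputation averages
# the forward and backward estimations, and the two final hidden states are
# concatenated for the decoder.
# >>> gen = Generator(n_layers=2, n_features=4, d_hidden=8, cell_type="GRU", device="cpu")
# >>> imp, imp_X, fb = gen({"X": torch.randn(3, 5, 4), "missing_mask": torch.ones(3, 5, 4)})
# >>> imp.shape, imp_X.shape, fb.shape
# (torch.Size([3, 5, 4]), torch.Size([3, 5, 4]), torch.Size([3, 16]))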


class Discriminator(nn.Module):
    def __init__(
        self,
        cell_type: str,
        d_input: int,
        device: Union[str, torch.device],
    ):
        super().__init__()
        self.cell_type = cell_type
        self.device = device
        # the layer sizes below are the same as in the official implementation
        self.rnn_cell_module_list = nn.ModuleList(
            [
                RNN_CELL[cell_type](d_input, 32),
                RNN_CELL[cell_type](32, 16),
                RNN_CELL[cell_type](16, 8),
                RNN_CELL[cell_type](8, 16),
                RNN_CELL[cell_type](16, 32),
            ]
        )
        self.output_layer = nn.Linear(32, d_input)

    def forward(self, inputs: dict) -> torch.Tensor:
        imputed_X = inputs["imputed_X"]
        bz, n_steps, _ = imputed_X.shape
        hidden_states = [
            torch.zeros((bz, 32), device=self.device),
            torch.zeros((bz, 16), device=self.device),
            torch.zeros((bz, 8), device=self.device),
            torch.zeros((bz, 16), device=self.device),
            torch.zeros((bz, 32), device=self.device),
        ]
        hidden_state_collector = torch.empty((bz, n_steps, 32), device=self.device)
        if self.cell_type == "LSTM":
            # each LSTM layer needs its own cell state, matching its hidden size
            cell_states = [torch.zeros_like(h) for h in hidden_states]
            for step in range(n_steps):
                x = imputed_X[:, step, :]
                for i, rnn_cell in enumerate(self.rnn_cell_module_list):
                    if i == 0:
                        hidden_state, cell_state = rnn_cell(
                            x, (hidden_states[i], cell_states[i])
                        )
                    else:
                        hidden_state, cell_state = rnn_cell(
                            hidden_states[i - 1], (hidden_states[i], cell_states[i])
                        )
                    hidden_states[i] = hidden_state
                    cell_states[i] = cell_state
                hidden_state_collector[:, step, :] = hidden_state

        elif self.cell_type == "GRU":
            for step in range(n_steps):
                x = imputed_X[:, step, :]
                for i, rnn_cell in enumerate(self.rnn_cell_module_list):
                    if i == 0:
                        hidden_state = rnn_cell(x, hidden_states[i])
                    else:
                        hidden_state = rnn_cell(hidden_states[i - 1], hidden_states[i])
                    hidden_states[i] = hidden_state
                hidden_state_collector[:, step, :] = hidden_state

        output_collector = self.output_layer(hidden_state_collector)
        return output_collector
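
# ``Discriminator`` maps an imputed series to per-value logits of the same
# shape; a sketch (illustrative only):
# >>> disc = Discriminator("GRU", d_input=4, device="cpu")
# >>> disc({"imputed_X": torch.randn(3, 5, 4)}).shape
# torch.Size([3, 5, 4])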


class Decoder(nn.Module):
    def __init__(
        self,
        n_steps: int,
        d_input: int,
        d_output: int,
        fcn_output_dims: list = None,
        device: Union[str, torch.device] = "cpu",
    ):
        super().__init__()
        self.n_steps = n_steps
        self.d_output = d_output
        self.device = device

        if fcn_output_dims is None:
            fcn_output_dims = [d_input]
        self.fcn_output_dims = fcn_output_dims

        self.fcn = nn.ModuleList()
        for output_dim in fcn_output_dims:
            self.fcn.append(nn.Linear(d_input, output_dim))
            d_input = output_dim

        self.rnn_cell = nn.GRUCell(fcn_output_dims[-1], fcn_output_dims[-1])
        self.output_layer = nn.Linear(fcn_output_dims[-1], d_output)

    def forward(self, inputs: dict) -> Tuple[torch.Tensor, torch.Tensor]:
        generator_fb_hidden_states = inputs["generator_fb_hidden_states"]
        bz, _ = generator_fb_hidden_states.shape
        fcn_latent = generator_fb_hidden_states
        for layer in self.fcn:
            fcn_latent = layer(fcn_latent)
        hidden_state = fcn_latent
        hidden_state_collector = torch.empty(
            (bz, self.n_steps, self.fcn_output_dims[-1]), device=self.device
        )
        for i in range(self.n_steps):
            # feed the hidden state back in as the input at every step
            hidden_state = self.rnn_cell(hidden_state, hidden_state)
            hidden_state_collector[:, i, :] = hidden_state
        reconstruction = self.output_layer(hidden_state_collector)
        return reconstruction, fcn_latent
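
# ``Decoder`` shape sketch (illustrative only): ``d_input`` must match the
# generator's concatenated hidden size, i.e. 2 * rnn_hidden_size.
# >>> dec = Decoder(n_steps=5, d_input=16, d_output=4)
# >>> rec, lat = dec({"generator_fb_hidden_states": torch.randn(3, 16)})
# >>> rec.shape, lat.shape
# (torch.Size([3, 5, 4]), torch.Size([3, 16]))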


class _CRLI(nn.Module):
    def __init__(
        self,
        n_steps: int,
        n_features: int,
        n_clusters: int,
        n_generator_layers: int,
        rnn_hidden_size: int,
        decoder_fcn_output_dims: list,
        lambda_kmeans: float,
        rnn_cell_type: str = "GRU",
        device: Union[str, torch.device] = "cpu",
    ):
        super().__init__()
        self.generator = Generator(
            n_generator_layers, n_features, rnn_hidden_size, rnn_cell_type, device
        )
        self.discriminator = Discriminator(rnn_cell_type, n_features, device)
        self.decoder = Decoder(
            n_steps, rnn_hidden_size * 2, n_features, decoder_fcn_output_dims, device
        )  # the fully connected network is included in Decoder
        self.kmeans = KMeans(
            n_clusters=n_clusters
        )  # TODO: implement k-means with torch for GPU acceleration

        self.n_clusters = n_clusters
        self.lambda_kmeans = lambda_kmeans
        self.device = device

    def cluster(self, inputs: dict, training_object: str = "generator") -> dict:
        # concatenate the generator's forward and backward final states and
        # feed them to the decoder as its initial state
        imputation, imputed_X, generator_fb_hidden_states = self.generator(inputs)
        inputs["imputation"] = imputation
        inputs["imputed_X"] = imputed_X
        inputs["generator_fb_hidden_states"] = generator_fb_hidden_states
        if training_object == "discriminator":
            discrimination = self.discriminator(inputs)
            inputs["discrimination"] = discrimination
            return inputs  # when only training the discriminator, no need to run the decoder

        reconstruction, fcn_latent = self.decoder(inputs)
        inputs["reconstruction"] = reconstruction
        inputs["fcn_latent"] = fcn_latent
        return inputs

    def forward(self, inputs: dict, training_object: str = "generator") -> dict:
        assert training_object in [
            "generator",
            "discriminator",
        ], 'training_object should be "generator" or "discriminator"'

        X = inputs["X"]
        missing_mask = inputs["missing_mask"]
        batch_size, n_steps, n_features = X.shape
        losses = {}
        inputs = self.cluster(inputs, training_object)
        if training_object == "discriminator":
            l_D = F.binary_cross_entropy_with_logits(
                inputs["discrimination"], missing_mask
            )
            losses["l_disc"] = l_D
        else:
            # reuse the discrimination produced by the preceding discriminator
            # step on the same ``inputs`` dict; detach it so no gradient flows
            # back into the discriminator
            inputs["discrimination"] = inputs["discrimination"].detach()
            l_G = F.binary_cross_entropy_with_logits(
                inputs["discrimination"], 1 - missing_mask, weight=1 - missing_mask
            )
            l_pre = cal_mse(inputs["imputation"], X, missing_mask)
            l_rec = cal_mse(inputs["reconstruction"], X, missing_mask)
            HTH = torch.matmul(inputs["fcn_latent"], inputs["fcn_latent"].permute(1, 0))
            term_F = torch.nn.init.orthogonal_(
                torch.randn(batch_size, self.n_clusters, device=self.device), gain=1
            )
            FTHTHF = torch.matmul(torch.matmul(term_F.permute(1, 0), HTH), term_F)
            l_kmeans = torch.trace(HTH) - torch.trace(FTHTHF)  # k-means loss
            loss_gene = l_G + l_pre + l_rec + l_kmeans * self.lambda_kmeans
            losses["l_gene"] = loss_gene
        return losses
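
# Note on ``l_kmeans`` above, added for clarity: with the batch latent matrix
# H (one row per sample) and its Gram matrix G = H @ H.T, the k-means
# objective admits the well-known spectral relaxation
#     min_F  tr(G) - tr(F.T @ G @ F)   subject to  F.T @ F = I,
# where F is an (n_samples, n_clusters) column-orthogonal surrogate for the
# cluster-indicator matrix. The code instantiates F as a random orthogonal
# matrix rather than optimizing it, so the term acts as a regularizer that
# nudges the latent space toward cluster-friendly structure.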


class CRLI(BaseNNClusterer):
    def __init__(
        self,
        n_steps: int,
        n_features: int,
        n_clusters: int,
        n_generator_layers: int,
        rnn_hidden_size: int,
        decoder_fcn_output_dims: list = None,
        lambda_kmeans: float = 1,
        rnn_cell_type: str = "GRU",
        G_steps: int = 1,
        D_steps: int = 1,
        batch_size: int = 32,
        epochs: int = 100,
        patience: int = 10,
        learning_rate: float = 1e-3,
        weight_decay: float = 1e-5,
        num_workers: int = 0,
        device: Optional[Union[str, torch.device]] = None,
        tb_file_saving_path: str = None,
    ):
        super().__init__(
            n_clusters,
            batch_size,
            epochs,
            patience,
            learning_rate,
            weight_decay,
            num_workers,
            device,
            tb_file_saving_path,
        )
        assert G_steps > 0 and D_steps > 0, "G_steps and D_steps should both be >0"

        self.n_steps = n_steps
        self.n_features = n_features
        self.G_steps = G_steps
        self.D_steps = D_steps

        self.model = _CRLI(
            n_steps,
            n_features,
            n_clusters,
            n_generator_layers,
            rnn_hidden_size,
            decoder_fcn_output_dims,
            lambda_kmeans,
            rnn_cell_type,
            self.device,
        )
        self.model = self.model.to(self.device)
        self._print_model_size()
        self.logger = {"training_loss_generator": [], "training_loss_discriminator": []}

    def _assemble_input_for_training(self, data: list) -> dict:
        """Assemble the given data into a dictionary for training input.

        Parameters
        ----------
        data : list,
            A list containing data fetched from Dataset by Dataloader.

        Returns
        -------
        inputs : dict,
            A Python dictionary containing the input data for model training.
        """

        # fetch data
        indices, X, _, missing_mask, _, _ = data

        inputs = {
            "X": X,
            "missing_mask": missing_mask,
        }

        return inputs

    def _assemble_input_for_validating(self, data: list) -> dict:
        """Assemble the given data into a dictionary for validating input.

        Notes
        -----
        Validation data is assembled in the same way as training data.

        Parameters
        ----------
        data : list,
            A list containing data fetched from Dataset by Dataloader.

        Returns
        -------
        inputs : dict,
            A Python dictionary containing the input data for model validating.
        """
        return self._assemble_input_for_training(data)

    def _assemble_input_for_testing(self, data: list) -> dict:
        """Assemble the given data into a dictionary for testing input.

        Notes
        -----
        Testing data is assembled in the same way as training data.

        Parameters
        ----------
        data : list,
            A list containing data fetched from Dataset by Dataloader.

        Returns
        -------
        inputs : dict,
            A Python dictionary containing the input data for model testing.
        """
        return self._assemble_input_for_validating(data)

    def _train_model(
        self,
        training_loader: DataLoader,
        val_loader: DataLoader = None,
    ) -> None:
        # the generator and the decoder are updated together; the discriminator
        # has its own optimizer
        self.G_optimizer = torch.optim.Adam(
            [
                {"params": self.model.generator.parameters()},
                {"params": self.model.decoder.parameters()},
            ],
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
        self.D_optimizer = torch.optim.Adam(
            self.model.discriminator.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )

        # each training run starts from the very beginning, so reset the loss and model dict here
        self.best_loss = float("inf")
        self.best_model_dict = None

        try:
            for epoch in range(self.epochs):
                self.model.train()
                epoch_train_loss_G_collector = []
                epoch_train_loss_D_collector = []
                for idx, data in enumerate(training_loader):
                    inputs = self._assemble_input_for_training(data)

                    for _ in range(self.D_steps):
                        self.D_optimizer.zero_grad()
                        results = self.model.forward(
                            inputs, training_object="discriminator"
                        )
                        results["l_disc"].backward(retain_graph=True)
                        self.D_optimizer.step()
                        epoch_train_loss_D_collector.append(results["l_disc"].item())

                    for _ in range(self.G_steps):
                        self.G_optimizer.zero_grad()
                        results = self.model.forward(inputs, training_object="generator")
                        results["l_gene"].backward()
                        self.G_optimizer.step()
                        epoch_train_loss_G_collector.append(results["l_gene"].item())

                # mean training losses of the current epoch
                mean_train_G_loss = np.mean(epoch_train_loss_G_collector)
                mean_train_D_loss = np.mean(epoch_train_loss_D_collector)
                self.logger["training_loss_generator"].append(mean_train_G_loss)
                self.logger["training_loss_discriminator"].append(mean_train_D_loss)
                logger.info(
                    f"epoch {epoch}: "
                    f"training loss_generator {mean_train_G_loss:.4f}, "
                    f"training loss_discriminator {mean_train_D_loss:.4f}"
                )
                mean_loss = mean_train_G_loss

                if mean_loss < self.best_loss:
                    self.best_loss = mean_loss
                    self.best_model_dict = self.model.state_dict()
                    self.patience = self.original_patience
                else:
                    self.patience -= 1
                    if self.patience == 0:
                        logger.info(
                            "Exceeded the training patience. Terminating the training procedure..."
                        )
                        break
        except Exception as e:
            logger.info(f"Exception: {e}")
            if self.best_model_dict is None:
                raise RuntimeError(
                    "Training got interrupted. The model was not trained. Please try fit() again."
                )
            else:
                logger.warning(
                    "Training got interrupted. "
                    "The model will load the best parameters so far for testing. "
                    "If you don't want this, please try fit() again."
                )

        if np.equal(self.best_loss, float("inf")):
            raise ValueError(
                "Something is wrong. best_loss is still infinite after training."
            )

        logger.info("Finished training.")

    def fit(
        self,
        train_set: Union[dict, str],
        file_type: str = "h5py",
    ) -> None:
        """Train the clusterer.

        Parameters
        ----------
        train_set : dict or str,
            The dataset for model training, should be a dictionary including the key 'X',
            or a path string locating a data file.
            If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
            which is time-series data for training and can contain missing values.
            If it is a path string, the path should point to a data file, e.g. an h5 file, which contains
            key-value pairs like a dict and has to include the key 'X'.

        file_type : str, default = "h5py"
            The type of the given file if train_set is a path string.
        """
        training_set = DatasetForGRUD(train_set, file_type)
        training_loader = DataLoader(
            training_set,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )
        self._train_model(training_loader)
        self.model.load_state_dict(self.best_model_dict)
        self.model.eval()  # set the model to eval status to freeze it

    def cluster(
        self,
        X: Union[dict, str],
        file_type: str = "h5py",
    ) -> np.ndarray:
        """Cluster the input with the trained model.

        Parameters
        ----------
        X : array-like or str,
            The data samples for testing, should be array-like of shape [n_samples, sequence length (time steps),
            n_features], or a path string locating a data file, e.g. an h5 file.

        file_type : str, default = "h5py"
            The type of the given file if X is a path string.

        Returns
        -------
        array-like, shape [n_samples],
            Clustering results.
        """
        self.model.eval()  # set the model to eval status to freeze it
        test_set = DatasetForGRUD(X, file_type)
        test_loader = DataLoader(
            test_set,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )
        latent_collector = []

        with torch.no_grad():
            for idx, data in enumerate(test_loader):
                inputs = self._assemble_input_for_testing(data)
                inputs = self.model.cluster(inputs)
                latent_collector.append(inputs["fcn_latent"])

        # run k-means on the collected latent representations of all samples
        latent_collector = torch.cat(latent_collector).cpu().detach().numpy()
        clustering = self.model.kmeans.fit_predict(latent_collector)

        return clustering
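
# ----------------------------------------------------------------------------
# An illustrative usage sketch (not part of the library): fit CRLI on random
# data and cluster it. ``train_set`` only needs the key "X" with shape
# [n_samples, n_steps, n_features]; NaNs mark missing values. The
# hyperparameters below are arbitrary choices for demonstration.
# ----------------------------------------------------------------------------
# >>> X = np.random.randn(64, 24, 5)
# >>> X[np.random.rand(*X.shape) < 0.1] = np.nan  # ~10% missingness
# >>> model = CRLI(
# ...     n_steps=24,
# ...     n_features=5,
# ...     n_clusters=3,
# ...     n_generator_layers=2,
# ...     rnn_hidden_size=32,
# ...     epochs=5,
# ... )
# >>> model.fit({"X": X})
# >>> clustering = model.cluster({"X": X})  # array of shape (64,)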