
tonegas / nnodely / 16502811447

24 Jul 2025 04:44PM UTC coverage: 97.767% (+0.1%) from 97.651%
push · github · web-flow

New version 1.5.0

This pull request introduces version 1.5.0 of **nnodely**, featuring several updates:
1. Improved clarity of documentation and examples.
2. Support for managing multi-dataset features is now available.
3. DataFrames can now be used to create datasets.
4. Datasets can now be resampled.
5. Training on random data has been fixed for both classic and recurrent training.
6. The `state` variable has been removed.
7. It is now possible to add or remove a connection or a closed loop.
8. Partial models can now be exported.
9. The `train` function and the result analysis have been separated.
10. A new function, `trainAndAnalyse`, is now available (see the usage sketch after this list).
11. The report now works across all network types.
12. The training function code has been reorganized.
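
Items 3, 9 and 10 change the typical training call. The snippet below is a minimal, illustrative sketch only: it assumes `model` is an already defined and neuralized nnodely model, the `loadData` call and all dataset and column names are placeholders not confirmed by this page, and the keyword arguments mirror the standard training parameters defined in `network.py` further down.

```python
import pandas as pd

# Item 3 (hypothetical names): a pandas DataFrame can now be used to create a dataset.
df = pd.DataFrame({'x': [float(i) for i in range(1000)],
                   'y': [float(2 * i) for i in range(1000)]})
model.loadData(name='df_data', source=df)  # loader name assumed

# Items 9-10: training and result analysis are now separate steps; the keyword
# arguments below come from _standard_train_parameters in network.py.
model.train(dataset='df_data', splits=[70, 20, 10],
            num_of_epochs=100, train_batch_size=128, lr=0.001)

# ...or both steps can be run together with the new convenience function.
model.trainAndAnalyse(dataset='df_data', splits=[70, 20, 10])
```

Recurrent options such as `closed_loop`, `connect` and `prediction_samples` are accepted through the same parameter set, as the standard parameter dictionary in the listing below shows.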

2901 of 2967 new or added lines in 53 files covered. (97.78%)

16 existing lines in 6 files now uncovered.

12652 of 12941 relevant lines covered (97.77%)

0.98 hits per line

Source File: /nnodely/operators/network.py (97.61% covered)

import copy

from collections import defaultdict

import numpy as np
import torch, random

from nnodely.support.utils import TORCH_DTYPE, NP_DTYPE, check, enforce_types, tensor_to_list
from nnodely.basic.modeldef import ModelDef

from nnodely.support.logger import logging, nnLogger
log = nnLogger(__name__, logging.CRITICAL)

class Network:
    @enforce_types
    def __init__(self):
        check(type(self) is not Network, TypeError, "Network class cannot be instantiated directly")

        # Models definition
        self._model_def = ModelDef()
        self._model = None
        self._neuralized = False
        self._traced = False

        # Model components
        self._states = {}
        self._input_n_samples = {}
        self._input_ns_backward = {}
        self._input_ns_forward = {}
        self._max_samples_backward = None
        self._max_samples_forward = None
        self._max_n_samples = 0

        # Dataset information
        self._data_loaded = False
        self._file_count = 0
        self._num_of_samples = {}
        self._data = {}
        self._multifile = {}

        # Training information
        self._standard_train_parameters = {
            'models': None,
            'train_dataset': None, 'validation_dataset': None,
            'dataset': None, 'splits': [100, 0, 0],
            'closed_loop': {}, 'connect': {}, 'step': 0, 'prediction_samples': 0,
            'shuffle_data': True,
            'early_stopping': None, 'early_stopping_params': {},
            'select_model': 'last', 'select_model_params': {},
            'minimize_gain': {},
            'num_of_epochs': 100,
            'train_batch_size': 128, 'val_batch_size': 128,
            'optimizer': 'Adam',
            'lr': 0.001, 'lr_param': {},
            'optimizer_params': [], 'add_optimizer_params': [],
            'optimizer_defaults': {}, 'add_optimizer_defaults': {}
        }
        self._training = {}

        # Save internal
        self._log_internal = False
        self._internals = {}

    def _save_internal(self, key, value):
        self._internals[key] = tensor_to_list(value)

    def _set_log_internal(self, log_internal:bool):
        self._log_internal = log_internal

    def _clean_log_internal(self):
        self._internals = {}

    def _removeVirtualStates(self, connect, closed_loop):
        if connect or closed_loop:
            for key in (connect.keys() | closed_loop.keys()):
                if key in self._states.keys():
                    del self._states[key]

    def _updateState(self, X, out_closed_loop, out_connect):
        for key, value in out_connect.items():
            X[key] = value
            self._states[key] = X[key].clone().detach()
        for key, val in out_closed_loop.items():
            shift = val.shape[1]  ## take the output time dimension
            X[key] = torch.roll(X[key], shifts=-1, dims=1)  ## Roll the time window
            X[key][:, -shift:, :] = val  ## substitute with the predicted value
            self._states[key] = X[key].clone().detach()

    def _get_gradient_on_inference(self):
        for key, value in self._model_def['Inputs'].items():
            if 'type' in value.keys():
                return True
        return False

    def _get_mandatory_inputs(self, connect, closed_loop):
        model_inputs = list(self._model_def['Inputs'].keys())
        non_mandatory_inputs = list(closed_loop.keys()) + list(connect.keys()) + list(self._model_def.recurrentInputs().keys())
        mandatory_inputs = list(set(model_inputs) - set(non_mandatory_inputs))
        return mandatory_inputs, non_mandatory_inputs

    def _get_batch_indexes(self, datasets:str|list|dict|None, n_samples:int=0, prediction_samples:int=0):
        ## Build the list of valid batch starting indexes; indexes whose prediction window
        ## would cross a file boundary of a multi-file dataset are excluded.
        if datasets is None:
            return []
        batch_indexes = list(range(n_samples))
        if prediction_samples > 0 and not isinstance(datasets, dict):
            datasets = [datasets] if type(datasets) is str else datasets
            forbidden_idxs = []
            n_samples_count = 0
            for dataset in datasets:
                if dataset in self._multifile.keys(): ## multi-file dataset: some starting indexes are forbidden
                    for i in self._multifile[dataset]:
                        if i+n_samples_count < batch_indexes[-1]:
                            forbidden_idxs.extend(range((i+n_samples_count) - prediction_samples, (i+n_samples_count), 1))
                n_samples_count += self._num_of_samples[dataset]
            batch_indexes = [idx for idx in batch_indexes if idx not in forbidden_idxs]
            batch_indexes = batch_indexes[:-prediction_samples]
        return batch_indexes

    def _get_data(self, dataset:str|list|dict|None):
        if dataset is None:
            return {}
        if isinstance(dataset, dict):
            self.__check_data_integrity(dataset)
            return dataset
        dataset = [dataset] if type(dataset) is str else dataset
        loaded_datasets = list(self._data.keys())
        check(len([data for data in dataset if data in loaded_datasets]) > 0, KeyError, f'the datasets: {dataset} are not loaded!')
        total_data = defaultdict(list)
        for data in dataset:
            if data not in loaded_datasets:
                log.warning(f'{data} is not loaded. Ignoring this dataset...')
                dataset.remove(data)
                continue
            for k, v in self._data[data].items():
                total_data[k].append(v)
        total_data = {key: np.concatenate(arrays) for key, arrays in total_data.items()}
        total_data = {key: torch.from_numpy(val).to(TORCH_DTYPE) for key, val in total_data.items()}
        return total_data

    def _clip_step(self, step, batch_indexes, batch_size):
        clipped_step = copy.deepcopy(step)
        if clipped_step < 0:  ## clip the step to zero
            log.warning(f"The step is negative ({clipped_step}). The step is set to zero.", stacklevel=5)
            clipped_step = 0
        if clipped_step > (len(batch_indexes) - batch_size):  ## Clip the step to the maximum number of samples
            log.warning(f"The step ({clipped_step}) is greater than the number of available samples ({len(batch_indexes) - batch_size}). The step is set to the maximum number.", stacklevel=5)
            clipped_step = len(batch_indexes) - batch_size
        check((batch_size + clipped_step) > 0, ValueError, f"The sum of batch_size={batch_size} and the step={clipped_step} must be greater than 0.")
        return clipped_step

    def _clip_batch_size(self, n_samples, batch_size=None):
        batch_size = batch_size if batch_size <= n_samples else max(0, n_samples)
        check((n_samples - batch_size + 1) > 0, ValueError, f"The number of available samples is {n_samples - batch_size + 1}")
        check(batch_size > 0, ValueError, f'The batch_size must be greater than 0.')
        return batch_size

    def __split_dataset(self, dataset:str|list|dict, splits:list):
        check(len(splits) == 3, ValueError, '3 elements must be inserted for the dataset split in training, validation and test')
        check(sum(splits) == 100, ValueError, 'Training, Validation and Test splits must sum up to 100.')
        check(splits[0] > 0, ValueError, 'The training split cannot be zero.')
        train_size, val_size, test_size = splits[0] / 100, splits[1] / 100, splits[2] / 100
        XY_train, XY_val, XY_test = {}, {}, {}
        if isinstance(dataset, dict):
            ## The next five lines are new in this build and not yet covered by tests (×).
            self.__check_data_integrity(dataset)
            num_of_samples = next(iter(dataset.values())).size(0)
            XY_train = {key: value[:round(num_of_samples*train_size), :, :] for key, value in dataset.items()}
            XY_val = {key: value[round(num_of_samples*train_size):round(num_of_samples*(train_size + val_size)), :, :] for key, value in dataset.items()}
            XY_test = {key: value[round(num_of_samples*(train_size + val_size)):, :, :] for key, value in dataset.items()}
        else:
            dataset = [dataset] if type(dataset) is str else dataset
            check(len([data for data in dataset if data in self._data.keys()]) > 0, KeyError, f'the datasets: {dataset} are not loaded!')
            for data in dataset:
                if data not in self._data.keys():
                    ## The next two lines are new in this build and not yet covered by tests (×).
                    log.warning(f'{data} is not loaded. The training will continue without this dataset.')
                    dataset.remove(data)

            num_of_samples = sum([self._num_of_samples[data] for data in dataset])
            n_samples_train, n_samples_val = round(num_of_samples * train_size), round(num_of_samples * val_size)
            n_samples_test = num_of_samples - n_samples_train - n_samples_val
            check(n_samples_train > 0, ValueError, f'The number of train samples {n_samples_train} must be greater than 0.')
            total_data = defaultdict(list)
            for data in dataset:
                for k, v in self._data[data].items():
                    total_data[k].append(v)
            total_data = {key: np.concatenate(arrays, dtype=NP_DTYPE) for key, arrays in total_data.items()}
            for key, samples in total_data.items():
                if val_size == 0.0 and test_size == 0.0:  ## we have only training set
                    XY_train[key] = torch.from_numpy(samples).to(TORCH_DTYPE)
                elif val_size == 0.0 and test_size != 0.0:  ## we have only training and test set
                    XY_train[key] = torch.from_numpy(samples[:n_samples_train]).to(TORCH_DTYPE)
                    XY_test[key] = torch.from_numpy(samples[n_samples_train:]).to(TORCH_DTYPE)
                elif val_size != 0.0 and test_size == 0.0:  ## we have only training and validation set
                    XY_train[key] = torch.from_numpy(samples[:n_samples_train]).to(TORCH_DTYPE)
                    XY_val[key] = torch.from_numpy(samples[n_samples_train:]).to(TORCH_DTYPE)
                else:  ## we have training, validation and test set
                    XY_train[key] = torch.from_numpy(samples[:n_samples_train]).to(TORCH_DTYPE)
                    XY_val[key] = torch.from_numpy(samples[n_samples_train:-n_samples_test]).to(TORCH_DTYPE)
                    XY_test[key] = torch.from_numpy(samples[n_samples_train + n_samples_val:]).to(TORCH_DTYPE)
        return XY_train, XY_val, XY_test

    def _get_tag(self, dataset: str | list | dict | None) -> str:
        """
        Helper function to get the tag for a dataset.
        """
        if isinstance(dataset, str):
            return dataset
        elif isinstance(dataset, list):
            return f"{dataset[0]}_{len(dataset)}" if len(dataset) > 1 else f"{dataset[0]}"
        elif isinstance(dataset, dict):
            return "custom_dataset"
        return dataset

    def _setup_dataset(self, train_dataset:str|list|dict, validation_dataset:str|list|dict, test_dataset:str|list|dict, dataset:str|list|dict, splits:list):
        if train_dataset is None: ## use the splits
            train_dataset = list(self._data.keys()) if dataset is None else dataset
            return self.__split_dataset(train_dataset, splits)
        else: ## use each dataset
            return self._get_data(train_dataset), self._get_data(validation_dataset), self._get_data(test_dataset)

    def __check_data_integrity(self, dataset:dict):
        if bool(dataset):
            check(len(set([t.size(0) for t in dataset.values()])) == 1, ValueError, "All the tensors in the dataset must have the same number of samples.")
            check(len([t for t in self._model_def['Inputs'].keys() if t in dataset.keys()]) == len(list(self._model_def['Inputs'].keys())), ValueError, "Some inputs are missing.")
            for key, value in dataset.items():
                if key not in self._model_def['Inputs']:
                    log.warning(f"The key '{key}' is not an input of the network. It will be ignored.")
                else:
                    check(isinstance(value, torch.Tensor), TypeError, f"The value of the input '{key}' must be a torch.Tensor.")
                    check(value.size(1) == self._model_def['Inputs'][key]['ntot'], ValueError, f"The time size of the input '{key}' is not correct. Expected {self._model_def['Inputs'][key]['ntot']}, got {value.size(1)}.")
                    check(value.size(2) == self._model_def['Inputs'][key]['dim'], ValueError, f"The dimension of the input '{key}' is not correct. Expected {self._model_def['Inputs'][key]['dim']}, got {value.size(2)}.")

    def _get_not_mandatory_inputs(self, data, X, non_mandatory_inputs, remaning_indexes, batch_size, step, shuffle = False):
        ## Pick the indexes for the next batch and initialize the non-mandatory
        ## (recurrent, connected or closed-loop) inputs, from data when available or with zeros.
        related_indexes = random.sample(remaning_indexes, batch_size) if shuffle else remaning_indexes[:batch_size]
        for num in related_indexes:
            remaning_indexes.remove(num)
        if step > 0:
            if len(remaning_indexes) >= step:
                step_idxs = random.sample(remaning_indexes, step) if shuffle else remaning_indexes[:step]
                for num in step_idxs:
                    remaning_indexes.remove(num)
            else:
                remaning_indexes.clear()
        for key in non_mandatory_inputs:
            if key in data.keys(): ## with data
                X[key] = data[key][related_indexes]
            else:  ## with zeros
                window_size = self._input_n_samples[key]
                dim = self._model_def['Inputs'][key]['dim']
                if 'type' in self._model_def['Inputs'][key]:
                    X[key] = torch.zeros(size=(batch_size, window_size, dim), dtype=TORCH_DTYPE, requires_grad=True)
                else:
                    X[key] = torch.zeros(size=(batch_size, window_size, dim), dtype=TORCH_DTYPE, requires_grad=False)
                self._states[key] = X[key]
        return related_indexes

    def _inference(self, data, n_samples, batch_size, loss_gains, loss_functions,
                    shuffle = False, optimizer = None,
                    total_losses = None, A = None, B = None):
        ## One pass over the data in mini-batches (non-recurrent case): a training epoch
        ## when an optimizer is given, a plain evaluation otherwise.
        if shuffle:
            randomize = torch.randperm(n_samples)
            data = {key: val[randomize] for key, val in data.items()}
        ## Initialize the train losses vector
        aux_losses = torch.zeros([len(self._model_def['Minimizers']), n_samples // batch_size])
        for idx in range(0, (n_samples - batch_size + 1), batch_size):
            ## Build the input tensor
            XY = {key: val[idx:idx + batch_size] for key, val in data.items()}
            ## Reset gradient
            if optimizer:
                optimizer.zero_grad()
            ## Model Forward
            _, minimize_out, _, _ = self._model(XY)  ## Forward pass
            ## Loss Calculation
            total_loss = 0
            for ind, (key, value) in enumerate(self._model_def['Minimizers'].items()):
                if A is not None:
                    A[key].append(minimize_out[value['A']].detach().numpy())
                if B is not None:
                    B[key].append(minimize_out[value['B']].detach().numpy())
                loss = loss_functions[key](minimize_out[value['A']], minimize_out[value['B']])
                loss = (loss * loss_gains[key]) if key in loss_gains.keys() else loss
                if total_losses is not None:
                    total_losses[key].append(loss.detach().numpy())
                aux_losses[ind][idx // batch_size] = loss.item()
                total_loss += loss
            ## Gradient step
            if optimizer:
                total_loss.backward()
                optimizer.step()
                self.visualizer.showWeightsInTrain(batch=idx // batch_size)

        ## return the losses
        return aux_losses

    def _recurrent_inference(self, data, batch_indexes, batch_size, loss_gains, prediction_samples,
                             step, non_mandatory_inputs, mandatory_inputs, loss_functions,
                             shuffle = False, optimizer = None,
                             total_losses = None, A = None, B = None):
        ## Recurrent counterpart of _inference: every batch is rolled forward for
        ## prediction_samples + 1 steps, feeding closed-loop and connect outputs back into the inputs.
        indexes = copy.deepcopy(batch_indexes)
        aux_losses = torch.zeros([len(self._model_def['Minimizers']), round((len(indexes) + step) / (batch_size + step))])
        X = {}
        batch_val = 0
        while len(indexes) >= batch_size:
            selected_indexes = self._get_not_mandatory_inputs(data, X, non_mandatory_inputs, indexes, batch_size, step, shuffle)
            horizon_losses = {ind: [] for ind in range(len(self._model_def['Minimizers']))}
            if optimizer:
                optimizer.zero_grad()  ## Reset the gradient

            for horizon_idx in range(prediction_samples + 1):
                ## Get data
                for key in mandatory_inputs:
                    X[key] = data[key][[idx + horizon_idx for idx in selected_indexes]]
                ## Forward pass
                out, minimize_out, out_closed_loop, out_connect = self._model(X)

                if self._log_internal:
                    #assert (check_gradient_operations(self._states) == 0)
                    #assert (check_gradient_operations(data) == 0)
                    internals_dict = {'XY': tensor_to_list(X), 'out': out, 'param': self._model.all_parameters,
                                      'closedLoop': self._model.closed_loop_update, 'connect': self._model.connect_update}

                ## Loss Calculation
                for ind, (key, value) in enumerate(self._model_def['Minimizers'].items()):
                    if A is not None:
                        A[key][horizon_idx].append(minimize_out[value['A']].detach().numpy())
                    if B is not None:
                        B[key][horizon_idx].append(minimize_out[value['B']].detach().numpy())
                    loss = loss_functions[key](minimize_out[value['A']], minimize_out[value['B']])
                    loss = (loss * loss_gains[key]) if key in loss_gains.keys() else loss
                    horizon_losses[ind].append(loss)

                ## Update
                self._updateState(X, out_closed_loop, out_connect)

                if self._log_internal:
                    internals_dict['state'] = self._states
                    self._save_internal('inout_' + str(batch_val) + '_' + str(horizon_idx), internals_dict)

            ## Calculate the total loss
            total_loss = 0
            for ind, key in enumerate(self._model_def['Minimizers'].keys()):
                loss = sum(horizon_losses[ind]) / (prediction_samples + 1)
                aux_losses[ind][batch_val] = loss.item()
                if total_losses is not None:
                    total_losses[key].append(loss.detach().numpy())
                total_loss += loss

            ## Gradient Step
            if optimizer:
                total_loss.backward()  ## Backpropagate the error
                optimizer.step()
                self.visualizer.showWeightsInTrain(batch=batch_val)
            batch_val += 1

        ## return the losses
        return aux_losses

    def _setup_recurrent_variables(self, prediction_samples, closed_loop, connect):
        ## Prediction samples
        check(prediction_samples >= -1, KeyError, 'The sample horizon must be non-negative, or -1 to disconnect the recurrent connections!')
        ## Close loop information
        for input, output in closed_loop.items():
            check(input in self._model_def['Inputs'], ValueError, f'the tag {input} is not an input variable.')
            check(output in self._model_def['Outputs'], ValueError, f'the tag {output} is not an output of the network')
            log.warning(f'Recurrent train: closing the loop between the input port {input} and the output port {output} for {prediction_samples} samples')
        ## Connect information
        for connect_in, connect_out in connect.items():
            check(connect_in in self._model_def['Inputs'], ValueError, f'the tag {connect_in} is not an input variable.')
            check(connect_out in self._model_def['Outputs'], ValueError, f'the tag {connect_out} is not an output of the network')
            log.warning(f'Recurrent train: connecting the input port {connect_in} with the output port {connect_out} for {prediction_samples} samples')
        ## Disable recurrent training if there are no recurrent variables
        if len(connect|closed_loop|self._model_def.recurrentInputs()) == 0:
            if prediction_samples >= 0:
                log.warning(f"The value of prediction_samples={prediction_samples} but the network has no recurrent variables.")
            prediction_samples = -1
        return prediction_samples

    @enforce_types
    def resetStates(self, states:set={}, *, batch:int=1) -> None:
        """
        Resets the state of all the recurrent inputs of the network to zero.

        Parameters
        ----------
        states : set, optional
            A set of recurrent input names to reset. If provided, only those inputs will be reset.
        batch : int, optional
            The batch size for the reset states. Default is 1.
        """
        if states: ## reset only specific states
            for key in states:
                window_size = self._input_n_samples[key]
                dim = self._model_def['Inputs'][key]['dim']
                self._states[key] = torch.zeros(size=(batch, window_size, dim), dtype=TORCH_DTYPE, requires_grad=False)
        else: ## reset all states
            self._states = {}
            for key, state in self._model_def.recurrentInputs().items():
                window_size = self._input_n_samples[key]
                dim = state['dim']
                self._states[key] = torch.zeros(size=(batch, window_size, dim), dtype=TORCH_DTYPE, requires_grad=False)

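For reference, `resetStates` is the only public method in the listing above. A minimal usage sketch follows, assuming `model` is a neuralized nnodely model and `'xstate'` is the name of one of its recurrent inputs (both placeholders):

```python
# Reset every recurrent input of the network to a zero state.
model.resetStates()

# Reset only the recurrent input named 'xstate', allocating its zero state with batch size 4.
model.resetStates({'xstate'}, batch=4)
```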