• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

tonegas / nnodely / 20079174979

09 Dec 2025 09:31PM UTC coverage: 98.1% (+0.3%) from 97.767%
20079174979

Pull #109

github

tonegas
Removed skipped test for windows
Pull Request #109: New version of nnodely

923 of 933 new or added lines in 38 files covered. (98.93%)

10 existing lines in 4 files now uncovered.

13266 of 13523 relevant lines covered (98.1%)

0.98 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

97.39
/nnodely/operators/network.py
1
import copy
1✔
2
from collections import defaultdict
1✔
3

4
import numpy as np
1✔
5
import  torch, random
1✔
6

7
from nnodely.support.utils import TORCH_DTYPE, NP_DTYPE, check, enforce_types, tensor_to_list
1✔
8
from nnodely.basic.modeldef import ModelDef
1✔
9

10
from nnodely.support.logger import logging, nnLogger
1✔
11
log = nnLogger(__name__, logging.WARNING)
1✔
12

13
class Network:
1✔
14
    @enforce_types
    def __init__(self):
        """
        Initialize the model, dataset, training and logging attributes.

        Raises
        ------
        TypeError
            If ``Network`` itself is instantiated instead of a subclass.
        """
        # Network is only usable through a subclass; the message now names the right class.
        check(type(self) is not Network, TypeError, "Network class cannot be instantiated directly")

        # Models definition
        self._model_def = ModelDef()
        self._model = None
        self._neuralized = False
        self._traced = False

        # Model components
        self._states = {}
        self._input_n_samples = {}
        self._input_ns_backward = {}
        self._input_ns_forward = {}
        self._max_samples_backward = None
        self._max_samples_forward = None
        self._max_n_samples = 0

        # Dataset information
        self._data_loaded = False
        self._file_count = 0
        self._num_of_samples = {}
        self._data = {}
        self._multifile = {}

        # Training information: default value of every training parameter
        self._standard_train_parameters = {
            'models': None,
            'train_dataset': None, 'validation_dataset': None,
            'dataset': None, 'splits': [100, 0, 0],
            'closed_loop': {}, 'connect': {}, 'step': 0, 'prediction_samples': 0,
            'shuffle_data': True,
            'early_stopping': None, 'early_stopping_params': {},
            'select_model': 'last', 'select_model_params': {},
            'minimize_gain': {},
            'num_of_epochs': 100,
            'train_batch_size': 128, 'val_batch_size': 128,
            'optimizer': 'Adam',
            'lr': 0.001, 'lr_param': {},
            'optimizer_params': [], 'add_optimizer_params': [],
            'optimizer_defaults': {}, 'add_optimizer_defaults': {}
        }
        self._training = {}

        # Save internal: when _log_internal is True the inference loops store their
        # intermediate values in _internals
        self._log_internal = False
        self._internals = {}
1✔
62

63
    def _save_internal(self, key, value):
        """Store ``value`` in the internals log under ``key`` after converting it with ``tensor_to_list``."""
        converted = tensor_to_list(value)
        self._internals[key] = converted
1✔
65

66
    def _set_log_internal(self, log_internal:bool):
1✔
67
        self._log_internal = log_internal
1✔
68

69
    def _clean_log_internal(self):
1✔
70
        self._internals = {}
1✔
71

72
    def _remove_virtual_states(self, connect, closed_loop):
1✔
73
        if connect or closed_loop:
1✔
74
            for key in (connect.keys() | closed_loop.keys()):
1✔
75
                if key in self._states.keys():
1✔
76
                    del self._states[key]
1✔
77

78
    def _update_state(self, X, out_closed_loop, out_connect):
1✔
79
        for key, value in out_connect.items():
1✔
80
            X[key] = torch.roll(value, shifts=-1, dims=1)  ## Roll the time window
1✔
81
            X[key][:, -1, :] = float('inf') ## inf value to make clear that the last state value
1✔
82
            self._states[key] = X[key].clone().detach()
1✔
83
        for key, val in out_closed_loop.items():
1✔
84
            shift = val.shape[1] #+ self._input_ns_forward[key]  ## take the output time dimension + forward samples
1✔
85
            X[key] = torch.roll(X[key], shifts=-1, dims=1)  ## Roll the time window
1✔
86
            X[key][:, -shift:, :] = val  ## substitute with the predicted value
1✔
87
            self._states[key] = X[key].clone().detach()
1✔
88

89
    def _get_gradient_on_inference(self):
1✔
90
        for key, value in self._model_def['Inputs'].items():
1✔
91
            if 'type' in value.keys():
1✔
92
                return True
1✔
93
        return False
1✔
94

95
    def _get_mandatory_inputs(self, connect, closed_loop):
1✔
96
        model_inputs = list(self._model_def['Inputs'].keys())
1✔
97
        non_mandatory_inputs = list(closed_loop.keys()) + list(connect.keys()) + list(self._model_def.recurrentInputs().keys())
1✔
98
        mandatory_inputs = list(set(model_inputs) - set(non_mandatory_inputs))
1✔
99
        return mandatory_inputs, non_mandatory_inputs
1✔
100
    
101
    def _get_batch_indexes(self, datasets:str|list|dict|None, n_samples:int=0, prediction_samples:int=0):
1✔
102
        if datasets is None:
1✔
103
            return []
1✔
104
        batch_indexes = list(range(n_samples))
1✔
105
        if prediction_samples > 0 and not isinstance(datasets, dict):
1✔
106
            datasets = [datasets] if type(datasets) is str else datasets
1✔
107
            forbidden_idxs = []
1✔
108
            n_samples_count = 0
1✔
109
            for dataset in datasets:
1✔
110
                if dataset in self._multifile.keys(): ## i have some forbidden indexes
1✔
111
                    for i in self._multifile[dataset]:
1✔
112
                        if i+n_samples_count < batch_indexes[-1]:
1✔
113
                            forbidden_idxs.extend(range((i+n_samples_count) - prediction_samples, (i+n_samples_count), 1))
1✔
114
                n_samples_count += self._num_of_samples[dataset]
1✔
115
            batch_indexes = [idx for idx in batch_indexes if idx not in forbidden_idxs]
1✔
116
            batch_indexes = batch_indexes[:-prediction_samples]
1✔
117
        return batch_indexes
1✔
118
    
119
    def _get_data(self, dataset:str|list|dict|None):
        """
        Return the data of one or more datasets as a dictionary of tensors.

        Parameters
        ----------
        dataset : str or list or dict or None
            ``None`` returns an empty dictionary. A dictionary is checked for integrity and
            returned unchanged. A name or list of names is looked up in the loaded data;
            the arrays of the listed datasets are concatenated per key.

        Returns
        -------
        dict
            For named datasets, each key maps to a tensor converted to ``TORCH_DTYPE``.

        Raises
        ------
        KeyError
            If none of the requested datasets is loaded.

        Notes
        -----
        Names that are not loaded are logged and removed in place from the list that was
        passed in, as in the original implementation.
        """
        if dataset is None:
            return {}
        if isinstance(dataset, dict):
            self.__check_data_integrity(dataset)
            return dataset
        dataset = [dataset] if isinstance(dataset, str) else dataset
        check(any(name in self._data for name in dataset), KeyError, f'the datasets: {dataset} are not loaded!')
        total_data = defaultdict(list)
        ## Iterate over a snapshot: removing from the list being iterated would skip the
        ## entry that follows each unloaded dataset.
        for name in list(dataset):
            if name not in self._data:
                log.warning(f'{name} is not loaded. Ignoring this dataset...')
                dataset.remove(name)
                continue
            for key, array in self._data[name].items():
                total_data[key].append(array)
        return {key: torch.from_numpy(np.concatenate(arrays)).to(TORCH_DTYPE) for key, arrays in total_data.items()}
1✔
139

140
    def _clip_step(self, step, batch_indexes, batch_size):
        """
        Clamp ``step`` to the range ``[0, len(batch_indexes) - batch_size]``.

        A warning is logged whenever the value is changed.

        Raises
        ------
        ValueError
            If ``batch_size`` plus the clipped step is not greater than 0.
        """
        clipped_step = copy.deepcopy(step)
        if clipped_step < 0:  ## clip the step to zero
            log.warning(f"The step is negative ({clipped_step}). The step is set to zero.", stacklevel=5)
            clipped_step = 0
        max_step = len(batch_indexes) - batch_size
        if clipped_step > max_step:  ## Clip the step to the maximum number of samples
            log.warning(f"The step ({clipped_step}) is greater than the number of available samples ({max_step}). The step is set to the maximum number.", stacklevel=5)
            clipped_step = max_step
        check((batch_size + clipped_step) > 0, ValueError, f"The sum of batch_size={batch_size} and the step={clipped_step} must be greater than 0.")
        return clipped_step
1✔
150

151
    def _clip_batch_size(self, n_samples, batch_size=None):
        """
        Limit ``batch_size`` to the number of available samples.

        Raises
        ------
        ValueError
            If no sample is available or the resulting batch size is not positive.
        """
        if not batch_size <= n_samples:
            batch_size = max(0, n_samples)
        available = n_samples - batch_size + 1
        check(available > 0, ValueError, f"The number of available sample are {available}")
        check(batch_size > 0, ValueError, 'The batch_size must be greater than 0.')
        return batch_size
1✔
156
    
157
    def __split_dataset(self, dataset:str|list|dict, splits:list):
        """
        Split a dataset into training, validation and test partitions.

        Parameters
        ----------
        dataset : str or list or dict
            A custom dataset dictionary of tensors, or the name(s) of loaded datasets whose
            arrays are concatenated per key before splitting.
        splits : list
            Three percentages (training, validation, test) that must sum to 100.

        Returns
        -------
        tuple
            ``(XY_train, XY_val, XY_test)``, three dictionaries of tensors. A partition
            whose percentage is zero is returned as an empty dictionary for named datasets.

        Raises
        ------
        ValueError
            If ``splits`` is malformed or the training partition would be empty.
        KeyError
            If none of the named datasets is loaded.

        Notes
        -----
        Names that are not loaded are logged and removed in place from the list that was
        passed in, as in the original implementation.
        """
        check(len(splits) == 3, ValueError, '3 elements must be inserted for the dataset split in training, validation and test')
        check(sum(splits) == 100, ValueError, 'Training, Validation and Test splits must sum up to 100.')
        check(splits[0] > 0, ValueError, 'The training split cannot be zero.')
        train_size, val_size, test_size = splits[0] / 100, splits[1] / 100, splits[2] / 100
        XY_train, XY_val, XY_test = {}, {}, {}
        if isinstance(dataset, dict):
            self.__check_data_integrity(dataset)
            num_of_samples = next(iter(dataset.values())).size(0)
            train_end = round(num_of_samples * train_size)
            val_end = round(num_of_samples * (train_size + val_size))
            XY_train = {key: value[:train_end, :, :] for key, value in dataset.items()}
            XY_val = {key: value[train_end:val_end, :, :] for key, value in dataset.items()}
            XY_test = {key: value[val_end:, :, :] for key, value in dataset.items()}
        else:
            dataset = [dataset] if isinstance(dataset, str) else dataset
            check(any(data in self._data for data in dataset), KeyError, f'the datasets: {dataset} are not loaded!')
            ## Iterate over a snapshot: removing from the list being iterated would skip the
            ## following entry and could leave an unloaded name in the list.
            for data in list(dataset):
                if data not in self._data:
                    log.warning(f'{data} is not loaded. The training will continue without this dataset.')
                    dataset.remove(data)

            num_of_samples = sum(self._num_of_samples[data] for data in dataset)
            n_samples_train, n_samples_val = round(num_of_samples * train_size), round(num_of_samples * val_size)
            check(n_samples_train > 0, ValueError, f'The number of train samples {n_samples_train} must be greater than 0.')
            ## Explicit end of the validation slice: a negative index built from the test
            ## size would turn into ``-0`` (an empty slice) when rounding leaves no test samples.
            val_end = n_samples_train + n_samples_val
            total_data = defaultdict(list)
            for data in dataset:
                for k, v in self._data[data].items():
                    total_data[k].append(v)
            total_data = {key: np.concatenate(arrays, dtype=NP_DTYPE) for key, arrays in total_data.items()}
            for key, samples in total_data.items():
                if val_size == 0.0 and test_size == 0.0:  ## we have only training set
                    XY_train[key] = torch.from_numpy(samples).to(TORCH_DTYPE)
                elif val_size == 0.0 and test_size != 0.0:  ## we have only training and test set
                    XY_train[key] = torch.from_numpy(samples[:n_samples_train]).to(TORCH_DTYPE)
                    XY_test[key] = torch.from_numpy(samples[n_samples_train:]).to(TORCH_DTYPE)
                elif val_size != 0.0 and test_size == 0.0:  ## we have only training and validation set
                    XY_train[key] = torch.from_numpy(samples[:n_samples_train]).to(TORCH_DTYPE)
                    XY_val[key] = torch.from_numpy(samples[n_samples_train:]).to(TORCH_DTYPE)
                else:  ## we have training, validation and test set
                    XY_train[key] = torch.from_numpy(samples[:n_samples_train]).to(TORCH_DTYPE)
                    XY_val[key] = torch.from_numpy(samples[n_samples_train:val_end]).to(TORCH_DTYPE)
                    XY_test[key] = torch.from_numpy(samples[val_end:]).to(TORCH_DTYPE)
        return XY_train, XY_val, XY_test
1✔
200

201
    def _get_tag(self, dataset: str | list | dict | None) -> str:
1✔
202
        """
203
        Helper function to get the tag for a dataset.
204
        """
205
        if isinstance(dataset, str):
1✔
206
            return dataset
1✔
207
        elif isinstance(dataset, list):
1✔
208
            return f"{dataset[0]}_{len(dataset)}" if len(dataset) > 1 else f"{dataset[0]}"
1✔
209
        elif isinstance(dataset, dict):
1✔
210
            return "custom_dataset"
1✔
211
        return dataset
1✔
212

213
    def _setup_dataset(self, train_dataset:str|list|dict, validation_dataset:str|list|dict, test_dataset:str|list|dict, dataset:str|list|dict, splits:list):
1✔
214
        if train_dataset is None: ## use the splits
1✔
215
            train_dataset = list(self._data.keys()) if dataset is None else dataset
1✔
216
            return self.__split_dataset(train_dataset, splits)
1✔
217
        else: ## use each dataset
218
            return self._get_data(train_dataset), self._get_data(validation_dataset), self._get_data(test_dataset)
1✔
219

220
    def __check_data_integrity(self, dataset:dict):
        """
        Validate a custom dataset dictionary against the model inputs.

        An empty dictionary is accepted without checks. Keys that are not model inputs
        only produce a warning.

        Raises
        ------
        ValueError
            If the tensors do not share the same number of samples, or if the time size or
            the dimension of an input differs from the model definition.
        TypeError
            If the value of a model input is not a ``torch.Tensor``.
        """
        if not bool(dataset):
            return
        sample_counts = {t.size(0) for t in dataset.values()}
        check(len(sample_counts) == 1, ValueError, "All the tensors in the dataset must have the same number of samples.")
        #TODO a check that every model input is present in the dataset was disabled upstream; reason not visible here
        for key, value in dataset.items():
            if key not in self._model_def['Inputs']:
                log.warning(f"The key '{key}' is not an input of the network. It will be ignored.")
                continue
            spec = self._model_def['Inputs'][key]
            check(isinstance(value, torch.Tensor), TypeError, f"The value of the input '{key}' must be a torch.Tensor.")
            check(value.size(1) == spec['ntot'], ValueError, f"The time size of the input '{key}' is not correct. Expected {spec['ntot']}, got {value.size(1)}.")
            check(value.size(2) == spec['dim'], ValueError, f"The dimension of the input '{key}' is not correct. Expected {spec['dim']}, got {value.size(2)}.")
1✔
232

233
    def _get_not_mandatory_inputs(self, data, X, non_mandatory_inputs, remaning_indexes, batch_size, step, shuffle = False):
1✔
234
        related_indexes = random.sample(remaning_indexes, batch_size) if shuffle else remaning_indexes[:batch_size]
1✔
235
        for num in related_indexes:
1✔
236
            remaning_indexes.remove(num)
1✔
237
        if step > 0:
1✔
238
            if len(remaning_indexes) >= step:
1✔
239
                step_idxs = random.sample(remaning_indexes, step) if shuffle else remaning_indexes[:step]
1✔
240
                for num in step_idxs:
1✔
241
                    remaning_indexes.remove(num)
1✔
242
            else:
243
                remaning_indexes.clear()
1✔
244
        for key in non_mandatory_inputs:
1✔
245
            if key in data.keys(): ## with data
1✔
246
                X[key] = data[key][related_indexes]
1✔
247
            else:  ## with zeros
248
                window_size = self._input_n_samples[key]
1✔
249
                dim = self._model_def['Inputs'][key]['dim']
1✔
250
                if 'type' in self._model_def['Inputs'][key]:
1✔
251
                    X[key] = torch.zeros(size=(batch_size, window_size, dim), dtype=TORCH_DTYPE, requires_grad=True)
1✔
252
                else:
253
                    X[key] = torch.zeros(size=(batch_size, window_size, dim), dtype=TORCH_DTYPE, requires_grad=False)
1✔
254
                self._states[key] = X[key]
1✔
255
        return related_indexes
1✔
256

257
    def _inference(self, data, n_samples, batch_size, loss_gains, loss_functions,
1✔
258
                    shuffle = False, optimizer = None,
259
                    total_losses = None, A = None, B = None):
260
        if shuffle:
1✔
261
            randomize = torch.randperm(n_samples)
1✔
262
            data = {key: val[randomize] for key, val in data.items()}
1✔
263

264
        ## Initialize the train losses vector
265
        aux_losses = torch.zeros([len(self._model_def['Minimizers']), n_samples // batch_size])
1✔
266
        for idx in range(0, (n_samples - batch_size + 1), batch_size):
1✔
267
            ## Build the input tensor
268
            XY = {key: val[idx:idx + batch_size] for key, val in data.items()}
1✔
269
            for key in self._model_def.recurrentInputs().keys():
1✔
270
                if key not in XY.keys():
1✔
271
                    window_size = self._input_n_samples[key]
1✔
272
                    dim = self._model_def['Inputs'][key]['dim']
1✔
273
                    XY[key] = torch.zeros(size=(batch_size, window_size, dim), dtype=TORCH_DTYPE, requires_grad=True)
1✔
274
            ## Reset gradient
275
            if optimizer:
1✔
276
                optimizer.zero_grad()
1✔
277
            ## Model Forward
278
            out, minimize_out, _, _ = self._model(XY)  ## Forward pass
1✔
279

280
            if self._log_internal:
1✔
281
                internals_dict = {'XY': tensor_to_list(XY), 'out': out, 'param': self._model.all_parameters}
1✔
282

283
            ## Loss Calculation
284
            total_loss = 0
1✔
285
            for ind, (key, value) in enumerate(self._model_def['Minimizers'].items()):
1✔
286
                if A is not None:
1✔
287
                    A[key].append(minimize_out[value['A']].detach().numpy())
1✔
288
                if B is not None:
1✔
289
                    B[key].append(minimize_out[value['B']].detach().numpy())
1✔
290
                loss = loss_functions[key](minimize_out[value['A']], minimize_out[value['B']])
1✔
291
                loss = (loss * loss_gains[key]) if key in loss_gains.keys() else loss
1✔
292
                if total_losses is not None:
1✔
293
                    total_losses[key].append(loss.detach().numpy())
1✔
294
                aux_losses[ind][idx // batch_size] = loss.item()
1✔
295
                total_loss += loss
1✔
296

297
            if self._log_internal:
1✔
298
                self._save_internal('inout_' + str(idx), internals_dict)
1✔
299

300
            ## Gradient step
301
            if optimizer:
1✔
302
                total_loss.backward()
1✔
303
                optimizer.step()
1✔
304
                self.visualizer.showWeightsInTrain(batch=idx // batch_size)
1✔
305

306
        ## return the losses
307
        return aux_losses
1✔
308

309
    def _recurrent_inference(self, data, batch_indexes, batch_size, loss_gains, prediction_samples,
1✔
310
                             step, non_mandatory_inputs, mandatory_inputs, loss_functions,
311
                             shuffle = False, optimizer = None,
312
                             total_losses = None, A = None, B = None, idxs = None):
313
        indexes = copy.deepcopy(batch_indexes)
1✔
314
        aux_losses = torch.zeros([len(self._model_def['Minimizers']), round((len(indexes) + step) / (batch_size + step))])
1✔
315
        X = {}
1✔
316
        batch_idx = 0
1✔
317
        while len(indexes) >= batch_size:
1✔
318
            selected_indexes = self._get_not_mandatory_inputs(data, X, non_mandatory_inputs, indexes, batch_size, step, shuffle)
1✔
319
            horizon_losses = {ind: [] for ind in range(len(self._model_def['Minimizers']))}
1✔
320
            if optimizer:
1✔
321
                optimizer.zero_grad()  ## Reset the gradient
1✔
322

323
            for horizon_idx in range(prediction_samples + 1):
1✔
324
                # Save the indexes
325
                if idxs is not None:
1✔
326
                    idxs[horizon_idx].append([idx + horizon_idx for idx in selected_indexes])
1✔
327
                ## Get data
328
                for key in mandatory_inputs:
1✔
329
                    X[key] = data[key][[idx + horizon_idx for idx in selected_indexes]]
1✔
330
                ## Forward pass
331
                out, minimize_out, out_closed_loop, out_connect = self._model(X)
1✔
332

333
                if self._log_internal:
1✔
334
                    #assert (check_gradient_operations(self._states) == 0)
335
                    #assert (check_gradient_operations(data) == 0)
336
                    internals_dict = {'XY': tensor_to_list(X), 'out': out, 'param': self._model.all_parameters,
1✔
337
                                      'closedLoop': self._model.closed_loop_update, 'connect': self._model.connect_update}
338

339
                ## Loss Calculation
340
                for ind, (key, value) in enumerate(self._model_def['Minimizers'].items()):
1✔
341
                    if A is not None:
1✔
342
                        A[key][horizon_idx].append(minimize_out[value['A']].detach().numpy())
1✔
343
                    if B is not None:
1✔
344
                        B[key][horizon_idx].append(minimize_out[value['B']].detach().numpy())
1✔
345
                    loss = loss_functions[key](minimize_out[value['A']], minimize_out[value['B']])
1✔
346
                    loss = (loss * loss_gains[key]) if key in loss_gains.keys() else loss
1✔
347
                    horizon_losses[ind].append(loss)
1✔
348

349
                ## Update
350
                self._update_state(X, out_closed_loop, out_connect)
1✔
351

352
                if self._log_internal:
1✔
353
                    internals_dict['state'] = self._states
1✔
354
                    self._save_internal('inout_' + str(batch_idx) + '_' + str(horizon_idx), internals_dict)
1✔
355

356
            ## Calculate the total loss
357
            total_loss = 0
1✔
358
            for ind, key in enumerate(self._model_def['Minimizers'].keys()):
1✔
359
                loss = sum(horizon_losses[ind]) / (prediction_samples + 1)
1✔
360
                aux_losses[ind][batch_idx] = loss.item()
1✔
361
                if total_losses is not None:
1✔
362
                    total_losses[key].append(loss.detach().numpy())
1✔
363
                total_loss += loss
1✔
364

365
            ## Gradient Step
366
            if optimizer:
1✔
367
                total_loss.backward()  ## Backpropagate the error
1✔
368
                optimizer.step()
1✔
369
                self.visualizer.showWeightsInTrain(batch=batch_idx)
1✔
370
            batch_idx += 1
1✔
371

372
        ## return the losses
373
        return aux_losses
1✔
374

375
    def _setup_recurrent_variables(self, prediction_samples, closed_loop, connect):
        """
        Validate the recurrent-training settings and return the prediction horizon to use.

        Every key of ``closed_loop`` and ``connect`` must be a model input and every value
        a model output. A warning is logged for inputs that have forward samples.

        Returns
        -------
        int or str
            ``prediction_samples`` unchanged, or ``-1`` when there is no closed loop,
            connect or recurrent input at all.

        Raises
        ------
        KeyError
            If ``prediction_samples`` is neither ``'auto'`` nor a number >= -1.
        ValueError
            If a tag is not a valid input or output of the network.
        """
        ## Prediction samples
        check(prediction_samples == 'auto' or prediction_samples >= -1, KeyError, "The sample horizon must be positive, -1, 'auto', for disconnect connection!")
        ## Close loop information
        for in_tag, out_tag in closed_loop.items():
            check(in_tag in self._model_def['Inputs'], ValueError, f'the tag {in_tag} is not an input variable.')
            check(out_tag in self._model_def['Outputs'], ValueError, f'the tag {out_tag} is not an output of the network')
            log.info(f'Recurrent train: closing the loop between the the input ports {in_tag} and the output ports {out_tag} for {prediction_samples} samples')
            if self._input_ns_forward[in_tag] > 0:
                log.warning(f"Closed loop on variable '{in_tag}' with sample in the future.")
        ## Connect information
        for in_tag, out_tag in connect.items():
            check(in_tag in self._model_def['Inputs'], ValueError, f'the tag {in_tag} is not an input variable.')
            check(out_tag in self._model_def['Outputs'], ValueError, f'the tag {out_tag} is not an output of the network')
            log.info(f'Recurrent train: connecting the input ports {in_tag} with output ports {out_tag} for {prediction_samples} samples')
            if self._input_ns_forward[in_tag] > 0:
                log.warning(f"Connect on variable '{in_tag}' with sample in the future.")
        ## Keep the requested horizon when at least one recurrent variable exists
        if len(connect|closed_loop|self._model_def.recurrentInputs()) != 0:
            return prediction_samples
        ## Disable recurrent training if there are no recurrent variables
        if type(prediction_samples) is not str and prediction_samples >= 0:
            log.warning(f"The value of the prediction_samples={prediction_samples} but the network has no recurrent variables.")
        return -1
1✔
398

399
    @enforce_types
    def resetStates(self, states:set={}, *, batch:int=1) -> None:
        """
        Reset recurrent input states to zero tensors.

        Parameters
        ----------
        states : set, optional
            Names of the inputs to reset. When empty (the default), the stored states are
            discarded and every recurrent input of the model is reset.
        batch : int, optional
            Size of the first dimension of the zero tensors. Default is 1.
        """
        if not states: ## reset all states
            self._states = {}
            for key, state in self._model_def.recurrentInputs().items():
                shape = (batch, self._input_n_samples[key], state['dim'])
                self._states[key] = torch.zeros(size=shape, dtype=TORCH_DTYPE, requires_grad=False)
            return
        ## reset only specific states, leaving the others untouched
        for key in states:
            shape = (batch, self._input_n_samples[key], self._model_def['Inputs'][key]['dim'])
            self._states[key] = torch.zeros(size=shape, dtype=TORCH_DTYPE, requires_grad=False)
1✔
421

STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc