
tonegas / nnodely / build 16502811447 (push via GitHub)

24 Jul 2025 04:44PM UTC coverage: 97.767% (+0.1%) from 97.651%
New version 1.5.0

This pull request introduces version 1.5.0 of **nnodely**, featuring several updates (a short usage sketch follows the list):
1. Improved clarity of documentation and examples.
2. Support for managing multiple datasets is now available.
3. DataFrames can now be used to create datasets.
4. Datasets can now be resampled.
5. Random data training has been fixed for both classic and recurrent training.
6. The `state` variable has been removed.
7. It is now possible to add or remove a connection or a closed loop.
8. Partial models can now be exported.
9. The `train` function and the result analysis have been separated.
10. A new function, `trainAndAnalyse`, is now available.
11. The report now works across all network types.
12. The training function code has been reorganized.
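
A minimal sketch of how the new pieces fit together, reusing the toy model from the `trainModel` docstring further down. Everything beyond that docstring is hedged: the top-level imports, passing a pandas DataFrame as `source` (item 3), and the exact signature of `trainAndAnalyse` (item 10) are assumptions drawn from this changelog, not from documented API:

```python
import numpy as np, pandas as pd
from nnodely import Modely, Input, Output, Fir  # top-level import paths are an assumption

# Same toy mass-spring model used in the trainModel docstring below
x, F = Input('x'), Input('F')
xk1 = Output('x[k+1]', Fir()(x.tw(0.2)) + Fir()(F.last()))

model = Modely(seed=0)
model.addModel('xk1', xk1)
model.neuralizeModel(sample_time=0.05)

# Item 3: create a dataset from a DataFrame (a DataFrame as `source` is assumed)
t = np.arange(0.0, 10.0, 0.05)
df = pd.DataFrame({'time': t, 'x': np.sin(t), 'F': np.ones_like(t)})
model.loadData(name='df_dataset', source=df)

# Items 9-10: training and result analysis are now separate steps;
# trainAndAnalyse (named in item 10) presumably runs both in one call
model.trainAndAnalyse(splits=[70, 20, 10],
                      training_params={'num_of_epochs': 50, 'lr': 0.001})
```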

2901 of 2967 new or added lines in 53 files covered. (97.78%)

16 existing lines in 6 files now uncovered.

12652 of 12941 relevant lines covered (97.77%)

0.98 hits per line

Source File

/nnodely/operators/trainer.py (97.01% of lines covered)
```python
import copy, torch, time, inspect

from collections.abc import Callable
from functools import wraps

from nnodely.basic.modeldef import ModelDef
from nnodely.basic.model import Model
from nnodely.basic.optimizer import Optimizer, SGD, Adam
from nnodely.basic.loss import CustomLoss
from nnodely.operators.network import Network
from nnodely.support.utils import check, enforce_types
from nnodely.basic.relation import Stream
from nnodely.layers.output import Output

from nnodely.support.logger import logging, nnLogger
log = nnLogger(__name__, logging.CRITICAL)

class Trainer(Network):
    def __init__(self):
        check(type(self) is not Trainer, TypeError, "Trainer class cannot be instantiated directly")
        super().__init__()

        ## User Parameters
        self.running_parameters = {}

        # Training Losses
        self.__loss_functions = {}

        # Optimizer
        self.__optimizer = None

    @enforce_types
    def addMinimize(self, name:str, streamA:Stream|Output, streamB:Stream|Output, loss_function:str='mse') -> None:
        """
        Adds a minimize loss function to the model.

        Parameters
        ----------
        name : str
            The name of the cost function.
        streamA : Stream
            The first relation stream for the minimize operation.
        streamB : Stream
            The second relation stream for the minimize operation.
        loss_function : str, optional
            The loss function to use from the ones provided. Default is 'mse'.

        Example
        -------
        Example usage:
            >>> model.addMinimize('minimize_op', streamA, streamB, loss_function='mse')
        """
        self._model_def.addMinimize(name, streamA, streamB, loss_function)
        self.visualizer.showaddMinimize(name)

    @enforce_types
    def removeMinimize(self, name_list:list|str) -> None:
        """
        Removes minimize loss functions using the given list of names.

        Parameters
        ----------
        name_list : list of str
            The list of minimize operation names to remove.

        Example
        -------
        Example usage:
            >>> model.removeMinimize(['minimize_op1', 'minimize_op2'])
        """
        self._model_def.removeMinimize(name_list)

    def __preliminary_checks(self, **kwargs):
        check(self._data_loaded, RuntimeError, 'There is no data loaded! The Training will stop.')
        check('Models' in self._model_def.getJson(), RuntimeError, 'There are no models to train. Load a model using the addModel function.')
        check(list(self._model.parameters()), RuntimeError, 'There are no modules with learnable parameters! The Training will stop.')
        if kwargs.get('train_dataset', None) is None:
            check(kwargs.get('validation_dataset', None) is None, ValueError, 'If train_dataset is None, validation_dataset must also be None.')
        for model in kwargs['models']:
            check(model in kwargs['all_models'], ValueError, f'The model {model} is not in the model definition')

    def fill_parameters(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            sig = inspect.signature(func)
            bound = sig.bind(self, *args, **kwargs)
            bound.apply_defaults()
            # Get standard parameters
            standard = self._standard_train_parameters
            # Get user parameters (guard against an explicit training_params=None)
            users = bound.arguments.get('training_params', None) or {}
            # Fill missing (None) arguments
            for param in sig.parameters.values():
                if param.name == 'self' or param.name == 'lr' or param.name == 'lr_param':
                    continue
                if bound.arguments.get(param.name, None) is None:
                    if param.name in users.keys():
                        bound.arguments[param.name] = users[param.name]
                    else:
                        bound.arguments[param.name] = standard.get(param.name, None)
            return func(**bound.arguments)
        return wrapper
```
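
The `fill_parameters` decorator above is plain standard-library machinery (`inspect.signature`, `bind`, `apply_defaults`). A self-contained sketch of the same pattern, with a made-up `STANDARD` dict standing in for `self._standard_train_parameters`:

```python
import inspect
from functools import wraps

STANDARD = {'num_of_epochs': 100, 'train_batch_size': 128}  # stand-in defaults

def fill_parameters(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        sig = inspect.signature(func)
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        users = bound.arguments.get('training_params') or {}
        for param in sig.parameters.values():
            if param.name == 'training_params':
                continue
            if bound.arguments.get(param.name) is None:
                # user-supplied dict wins over the standard defaults
                bound.arguments[param.name] = users.get(param.name, STANDARD.get(param.name))
        return func(**bound.arguments)
    return wrapper

@fill_parameters
def train(num_of_epochs=None, train_batch_size=None, training_params=None):
    return num_of_epochs, train_batch_size

print(train(training_params={'num_of_epochs': 10}))  # -> (10, 128)
print(train(num_of_epochs=5))                        # -> (5, 128)
```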

```python
# class Trainer, continued
    def __initialize_optimizer(self, models, optimizer, training_params, optimizer_params, optimizer_defaults, add_optimizer_defaults, add_optimizer_params, lr, lr_param):
        ## Collect the parameters of the models to train
        params_to_train = set()
        for model in models:
            if type(self._model_def['Models']) is dict:
                params_to_train |= set(self._model_def['Models'][model]['Parameters'])
            else:
                params_to_train |= set(self._model_def['Parameters'].keys())

        # Get the optimizer
        if type(optimizer) is str:
            if optimizer == 'SGD':
                optimizer = SGD({}, [])
            elif optimizer == 'Adam':
                optimizer = Adam({}, [])
        else:
            optimizer = copy.deepcopy(optimizer)
            check(issubclass(type(optimizer), Optimizer), TypeError, "The optimizer must be an Optimizer or str")

        optimizer.set_params_to_train(self._model.all_parameters, params_to_train)

        # Learning rate precedence: the standard default is registered first, ...
        optimizer.add_defaults('lr', self._standard_train_parameters['lr'])

        # ... then the values from training_params, ...
        if training_params and 'lr' in training_params:
            optimizer.add_defaults('lr', training_params['lr'])
        if training_params and 'lr_param' in training_params:
            optimizer.add_option_to_params('lr', training_params['lr_param'])

        if optimizer_defaults != {}:
            optimizer.set_defaults(optimizer_defaults)
        if optimizer_params != []:
            optimizer.set_params(optimizer_params)

        for key, value in add_optimizer_defaults.items():
            optimizer.add_defaults(key, value)

        add_optimizer_params = optimizer.unfold(add_optimizer_params)
        for param in add_optimizer_params:
            par = param['params']
            del param['params']
            for key, value in param.items():
                optimizer.add_option_to_params(key, {par: value})

        # ... and the explicit lr/lr_param arguments last, so they win
        optimizer.add_defaults('lr', lr)
        if lr_param:
            optimizer.add_option_to_params('lr', lr_param)

        self.__optimizer = optimizer
```
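
Note the call order above: the standard learning rate is registered first, then `training_params['lr']`, then the explicit `lr` argument, so later registrations win (assuming `add_defaults` ignores a `None` value). A plain-dict sketch of that precedence, with no nnodely types involved:

```python
def resolve_lr(standard_lr, training_params=None, lr=None):
    """Last write wins, mirroring the add_defaults('lr', ...) call order above."""
    value = standard_lr                      # 1. standard default
    if training_params and 'lr' in training_params:
        value = training_params['lr']        # 2. training_params dict
    if lr is not None:
        value = lr                           # 3. explicit lr argument
    return value

assert resolve_lr(0.001) == 0.001
assert resolve_lr(0.001, {'lr': 0.01}) == 0.01
assert resolve_lr(0.001, {'lr': 0.01}, lr=0.1) == 0.1
```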

```python
# class Trainer, continued
    def __initialize_loss(self):
        for name, values in self._model_def['Minimizers'].items():
            self.__loss_functions[name] = CustomLoss(values['loss'])

    def get_training_info(self):
        """
        Returns a dictionary with the training parameters and information.

        Returns
        -------
        dict
            A dictionary containing the training parameters and information.
        """
        to_remove = ['XY_train','XY_val','XY_test','train_indexes','val_indexes','test_indexes']
        tp = copy.deepcopy({key:value for key, value in self.running_parameters.items() if key not in to_remove})

        ## training
        tp['update_per_epochs'] = len(self.running_parameters['train_indexes']) // (tp['train_batch_size'] + tp['step'])
        if tp['prediction_samples'] >= 0: # TODO
            tp['n_first_samples_train'] = len(self.running_parameters['train_indexes'])
            if tp['n_samples_val'] > 0:
                tp['n_first_samples_val'] = len(self.running_parameters['val_indexes'])
            if tp['n_samples_test'] > 0:
                tp['n_first_samples_test'] = len(self.running_parameters['test_indexes'])

        ## optimizer
        tp['optimizer'] = self.__optimizer.name
        tp['optimizer_defaults'] = self.__optimizer.optimizer_defaults
        tp['optimizer_params'] = self.__optimizer.optimizer_params

        ## early stopping
        early_stopping = tp['early_stopping']
        if early_stopping:
            tp['early_stopping'] = early_stopping.__name__

        ## Loss functions
        tp['minimizers'] = {}
        for name, values in self._model_def['Minimizers'].items():
            tp['minimizers'][name] = {}
            tp['minimizers'][name]['A'] = values['A']
            tp['minimizers'][name]['B'] = values['B']
            tp['minimizers'][name]['loss'] = values['loss']
            if name in tp['minimize_gain']:
                tp['minimizers'][name]['gain'] = tp['minimize_gain'][name]

        return tp

    def __check_needed_keys(self, train_data, connect, closed_loop):
        # Needed keys: all inputs plus the endpoints of every minimizer
        keys = set(self._model_def['Inputs'].keys())
        keys |= ({value['A'] for value in self._model_def['Minimizers'].values()} | {value['B'] for value in self._model_def['Minimizers'].values()})
        # Remove the keys the network itself provides or that a connection supplies
        keys -= set(self._model_def['Outputs'].keys() | self._model_def['Relations'].keys())
        keys -= set(self._model_def.recurrentInputs().keys())
        keys -= set(connect.keys() | closed_loop.keys())
        # Check that the remaining keys are present in the dataset
        check(set(keys).issubset(set(train_data.keys())), KeyError, f"Not all the mandatory keys {keys} are present in the training dataset {set(train_data.keys())}.")
```
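
The set algebra in `__check_needed_keys` boils down to: model inputs plus minimizer endpoints, minus everything the network produces or that a connect/closed-loop supplies. A toy illustration with plain sets (all names invented):

```python
inputs     = {'x', 'F'}         # model inputs
minimizers = {'x[k+1]', 'dx'}   # A/B streams of the minimizers
produced   = {'x[k+1]'}         # outputs and relations
recurrent  = set()              # recurrent inputs
connected  = {'F'}              # keys fed by connect/closed_loop

needed = (inputs | minimizers) - produced - recurrent - connected
assert needed == {'x', 'dx'}    # these must come from the dataset
```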

```python
# class Trainer, continued
    @enforce_types
    @fill_parameters
    def trainModel(self, *,
                   name: str | None = None,
                   models: str | list | None = None,
                   train_dataset: str | list | dict | None = None, validation_dataset: str | list | dict | None = None,
                   dataset: str | list | None = None, splits: list | None = None,
                   closed_loop: dict | None = None, connect: dict | None = None, step: int | None = None, prediction_samples: int | None = None,
                   shuffle_data: bool | None = None,
                   early_stopping: Callable | None = None, early_stopping_params: dict | None = None,
                   select_model: Callable | None = None, select_model_params: dict | None = None,
                   minimize_gain: dict | None = None,
                   num_of_epochs: int | None = None,
                   train_batch_size: int | None = None, val_batch_size: int | None = None,
                   optimizer: str | Optimizer | None = None,
                   lr: int | float | None = None, lr_param: dict | None = None,
                   optimizer_params: list | None = None, optimizer_defaults: dict | None = None,
                   add_optimizer_params: list | None = None, add_optimizer_defaults: dict | None = None,
                   training_params: dict | None = {}
                   ) -> dict:
        """
        Trains the model using the provided datasets and parameters.

        Notes
        -----
        .. note::
            If no datasets are provided, the model will use all the datasets loaded inside nnodely.

        Parameters
        ----------
        name : str or None, optional
            A name used to identify the training operation.
        models : str or list or None, optional
            A list or name of models to train. Default is all the loaded models.
        train_dataset : str or list or dict or None, optional
            The name of the datasets to use for training.
        validation_dataset : str or list or dict or None, optional
            The name of the datasets to use for validation.
        dataset : str or list or None, optional
            The name of the datasets to use for training, validation, and test.
        splits : list or None, optional
            A list of 3 elements specifying the percentage splits for training, validation, and testing. The three elements must sum to 100. Default is [100, 0, 0].
            The splits parameter is only used when 'dataset' is not None.
        closed_loop : dict or None, optional
            A dictionary specifying closed-loop connections. The keys are input names and the values are output names. Default is None.
        connect : dict or None, optional
            A dictionary specifying connections. The keys are input names and the values are output names. Default is None.
        step : int or None, optional
            The step size for training. A larger value uses less data per epoch and speeds up training. Default is zero.
        prediction_samples : int or None, optional
            The size of the prediction horizon, i.e. the number of samples in each recurrent window. Default is zero.
        shuffle_data : bool or None, optional
            Whether to shuffle the data during training. Default is True.
        early_stopping : Callable or None, optional
            A callable for early stopping. Default is None.
        early_stopping_params : dict or None, optional
            A dictionary of parameters for early stopping. Default is None.
        select_model : Callable or None, optional
            A callable for selecting the best model. Default is None.
        select_model_params : dict or None, optional
            A dictionary of parameters for selecting the best model. Default is None.
        minimize_gain : dict or None, optional
            A dictionary specifying the gain for each minimization loss function. Default is None.
        num_of_epochs : int or None, optional
            The number of epochs to train the model. Default is 100.
        train_batch_size : int or None, optional
            The batch size for training. Default is 128.
        val_batch_size : int or None, optional
            The batch size for validation. Default is 128.
        optimizer : str or Optimizer or None, optional
            The optimizer to use for training. Default is 'Adam'.
        lr : float or None, optional
            The learning rate. Default is 0.001.
        lr_param : dict or None, optional
            A dictionary of learning rate parameters. Default is None.
        optimizer_params : list or None, optional
            A list of optimizer parameters. Default is None.
        optimizer_defaults : dict or None, optional
            A dictionary of default optimizer settings. Default is None.
        training_params : dict or None, optional
            A dictionary of training parameters. Default is None.
        add_optimizer_params : list or None, optional
            Additional optimizer parameters. Default is None.
        add_optimizer_defaults : dict or None, optional
            Additional default optimizer settings. Default is None.

        Returns
        -------
        dict
            The training information dictionary produced by `get_training_info`.

        Examples
        --------
        .. image:: https://colab.research.google.com/assets/colab-badge.svg
            :target: https://colab.research.google.com/github/tonegas/nnodely/blob/main/examples/training.ipynb
            :alt: Open in Colab

        Example - basic feed-forward training:
            >>> x = Input('x')
            >>> F = Input('F')

            >>> xk1 = Output('x[k+1]', Fir()(x.tw(0.2))+Fir()(F.last()))

            >>> mass_spring_damper = Modely(seed=0)
            >>> mass_spring_damper.addModel('xk1',xk1)
            >>> mass_spring_damper.neuralizeModel(sample_time = 0.05)

            >>> data_struct = ['time','x','dx','F']
            >>> data_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)),'dataset','data')
            >>> mass_spring_damper.loadData(name='mass_spring_dataset', source=data_folder, format=data_struct, delimiter=';')

            >>> params = {'num_of_epochs': 100,'train_batch_size': 128,'lr':0.001}
            >>> mass_spring_damper.trainModel(splits=[70,20,10], training_params = params)

        Example - recurrent training:
            >>> x = Input('x')
            >>> F = Input('F')

            >>> xk1 = Output('x[k+1]', Fir()(x.tw(0.2))+Fir()(F.last()))

            >>> mass_spring_damper = Modely(seed=0)
            >>> mass_spring_damper.addModel('xk1',xk1)
            >>> mass_spring_damper.addClosedLoop(xk1, x)
            >>> mass_spring_damper.neuralizeModel(sample_time = 0.05)

            >>> data_struct = ['time','x','dx','F']
            >>> data_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)),'dataset','data')
            >>> mass_spring_damper.loadData(name='mass_spring_dataset', source=data_folder, format=data_struct, delimiter=';')

            >>> params = {'num_of_epochs': 100,'train_batch_size': 128,'lr':0.001}
            >>> mass_spring_damper.trainModel(splits=[70,20,10], prediction_samples=10, training_params = params)
        """
        ## Get the models to train
        all_models = list(self._model_def['Models'].keys()) if type(self._model_def['Models']) is dict else [self._model_def['Models']]
        if models is None:
            models = all_models
        if isinstance(models, str):
            models = [models]

        ## Preliminary checks
        self.__preliminary_checks(models = models, all_models = all_models, train_dataset = train_dataset, validation_dataset = validation_dataset)

        ## Recurrent variables
        prediction_samples = self._setup_recurrent_variables(prediction_samples, closed_loop, connect)

        ## Get the dataset
        XY_train, XY_val, XY_test = self._setup_dataset(train_dataset, validation_dataset, None, dataset, splits)
        self.__check_needed_keys(train_data=XY_train, connect=connect, closed_loop=closed_loop)

        n_samples_train = next(iter(XY_train.values())).size(0)
        n_samples_val = next(iter(XY_val.values())).size(0) if XY_val else 0
        n_samples_test = next(iter(XY_test.values())).size(0) if XY_test else 0

        if train_dataset is not None:
            train_tag = self._get_tag(train_dataset)
            val_tag = self._get_tag(validation_dataset)
        else: ## splits is used
            if dataset is None:
                dataset = list(self._data.keys())
            tag = self._get_tag(dataset)
            train_tag = f"{tag}_train"
            val_tag = f"{tag}_val" if n_samples_val > 0 else None
            test_tag = f"{tag}_test" if n_samples_test > 0 else None

        train_indexes, val_indexes = [], []
        if train_dataset is not None:
            train_indexes, val_indexes = self._get_batch_indexes(train_dataset, n_samples_train, prediction_samples), self._get_batch_indexes(validation_dataset, n_samples_val, prediction_samples)
        else:
            dataset = list(self._data.keys()) if dataset is None else dataset
            train_indexes = self._get_batch_indexes(dataset, n_samples_train, prediction_samples)
            check(len(train_indexes) > 0, ValueError,
                  'The number of valid train samples is less than the number of prediction samples.')
            if n_samples_val > 0:
                val_indexes = self._get_batch_indexes(dataset, n_samples_train + n_samples_val, prediction_samples)
                val_indexes = [i - n_samples_train for i in val_indexes if i >= n_samples_train]
                if len(val_indexes) == 0:
                    log.warning('The number of valid validation samples is less than the number of prediction samples.')
            if n_samples_test > 0:
                test_indexes = self._get_batch_indexes(dataset, n_samples_train + n_samples_val + n_samples_test, prediction_samples)
                test_indexes = [i - (n_samples_train + n_samples_val) for i in test_indexes if i >= (n_samples_train + n_samples_val)]
                if len(test_indexes) == 0:
                    log.warning('The number of valid test samples is less than the number of prediction samples.')

        ## Clip batch size and step
        train_batch_size = self._clip_batch_size(len(train_indexes), train_batch_size)
        train_step = self._clip_step(step, train_indexes, train_batch_size)
        if n_samples_val > 0:
            val_batch_size = self._clip_batch_size(len(val_indexes), val_batch_size)
            val_step = self._clip_step(step, val_indexes, val_batch_size)

        ## Save the training parameters
        self.running_parameters = {key:value for key,value in locals().items() if key not in ['self', 'kwargs', 'training_params', 'lr', 'lr_param']}

        ## Define the optimizer
        self.__initialize_optimizer(models, optimizer, training_params, optimizer_params, optimizer_defaults, add_optimizer_defaults, add_optimizer_params, lr, lr_param)
        torch_optimizer = self.__optimizer.get_torch_optimizer()

        ## Define the loss functions
        self.__initialize_loss()

        ## Define mandatory inputs
        mandatory_inputs, non_mandatory_inputs = self._get_mandatory_inputs(connect, closed_loop)

        ## Check closed loop and connect
        self._clean_log_internal()

        ## Create the train and validation loss dictionaries
        train_losses, val_losses = {}, {}
        for key in self._model_def['Minimizers'].keys():
            train_losses[key] = []
            if n_samples_val > 0:
                val_losses[key] = []

        ## Enable gradients where necessary
        model_inputs = self._model_def['Inputs']
        for key in model_inputs.keys():
            if 'type' in model_inputs[key]:
                if key in XY_train:
                    XY_train[key].requires_grad_(True)
                if key in XY_val:
                    XY_val[key].requires_grad_(True)
        selected_model_def = ModelDef(self._model_def.getJson())
        best_model_epoch = 0  # fallback in case select_model never fires

        ## Show the training parameters
        self.visualizer.showTrainParams()
        self.visualizer.showStartTraining()

        ## Update with virtual states
        if prediction_samples >= 0:
            self._model.update(closed_loop=closed_loop, connect=connect)
        else:
            self._model.update(disconnect=True)

        self.resetStates()  ## Reset the states

        ## Start the train timer
        start = time.time()
        for epoch in range(num_of_epochs):
            ## TRAIN
            self._model.train()
            if prediction_samples >= 0:
                losses = self._recurrent_inference(XY_train, train_indexes, train_batch_size, minimize_gain, prediction_samples, train_step, non_mandatory_inputs, mandatory_inputs, self.__loss_functions, shuffle=shuffle_data, optimizer=torch_optimizer)
            else:
                losses = self._inference(XY_train, n_samples_train, train_batch_size, minimize_gain, self.__loss_functions, shuffle=shuffle_data, optimizer=torch_optimizer)
            ## Save the losses
            for ind, key in enumerate(self._model_def['Minimizers'].keys()):
                train_losses[key].append(torch.mean(losses[ind]).tolist())

            if n_samples_val > 0:
                ## VALIDATION
                self._model.eval()
                setted_log_internal = self._log_internal
                self._set_log_internal(False)  # TODO To remove when the function is moved outside the train
                if prediction_samples >= 0:
                    losses = self._recurrent_inference(XY_val, val_indexes, val_batch_size, minimize_gain, prediction_samples, val_step,
                                                       non_mandatory_inputs, mandatory_inputs, self.__loss_functions)
                else:
                    losses = self._inference(XY_val, n_samples_val, val_batch_size, minimize_gain, self.__loss_functions)
                self._set_log_internal(setted_log_internal)

                ## Save the losses
                for ind, key in enumerate(self._model_def['Minimizers'].keys()):
                    val_losses[key].append(torch.mean(losses[ind]).tolist())

            ## Early stopping
            if callable(early_stopping):
                if early_stopping(train_losses, val_losses, early_stopping_params):
                    log.info(f'Stopping the training at epoch {epoch} due to early stopping.')
                    break

            if callable(select_model):
                if select_model(train_losses, val_losses, select_model_params):
                    best_model_epoch = epoch
                    selected_model_def.updateParameters(self._model)

            ## Visualize the training progress
            self.visualizer.showTraining(epoch, train_losses, val_losses)
            self.visualizer.showWeightsInTrain(epoch=epoch)

        ## Visualize the training time
        end = time.time()
        self.visualizer.showTrainingTime(end - start)

        for key in self._model_def['Minimizers'].keys():
            self._training[key] = {'train': train_losses[key]}
            if n_samples_val > 0:
                self._training[key]['val'] = val_losses[key]
        self.visualizer.showEndTraining(num_of_epochs - 1, train_losses, val_losses)

        ## Select the model
        if callable(select_model):
            log.info(f'Selected the model at the epoch {best_model_epoch + 1}.')
            self._model = Model(selected_model_def)
        else:
            log.info('The selected model is the LAST model of the training.')

        ## Remove virtual states
        self._removeVirtualStates(connect, closed_loop)

        ## Get the trained parameters from torch and update the model definition
        self._model_def.updateParameters(self._model)
        return self.get_training_info()
#from 685
#from 840
```
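
From the call sites above, `early_stopping` and `select_model` only need to be callables of the form `f(train_losses, val_losses, params) -> bool`, where the loss arguments are dicts mapping each minimizer name to its per-epoch loss list. A sketch of two compatible callbacks; the patience and best-so-far logic is illustrative, not part of nnodely:

```python
def patience_early_stopping(train_losses, val_losses, params):
    """Stop when the first minimizer's validation loss has not improved for `patience` epochs."""
    patience = (params or {}).get('patience', 10)
    history = next(iter(val_losses.values()), [])  # first minimizer only, for brevity
    if len(history) <= patience:
        return False
    return min(history[-patience:]) >= min(history[:-patience])

def best_val_select_model(train_losses, val_losses, params):
    """Snapshot the model whenever the summed validation loss reaches a new minimum."""
    if not val_losses:
        return False
    per_epoch = [sum(vals) for vals in zip(*val_losses.values())]
    return per_epoch[-1] == min(per_epoch)
```

These would be passed as `trainModel(..., early_stopping=patience_early_stopping, early_stopping_params={'patience': 5}, select_model=best_val_select_model)`; `get_training_info` then reports the early-stopping callback by its `__name__`.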