/nnodely/operators/trainer.py
import copy, torch, time, inspect

from collections.abc import Callable
from functools import wraps

from nnodely.basic.modeldef import ModelDef
from nnodely.basic.model import Model
from nnodely.basic.optimizer import Optimizer, SGD, Adam
from nnodely.basic.loss import CustomLoss
from nnodely.operators.network import Network
from nnodely.support.utils import check, enforce_types
from nnodely.basic.relation import Stream
from nnodely.layers.output import Output

from nnodely.support.logger import logging, nnLogger
log = nnLogger(__name__, logging.CRITICAL)


class Trainer(Network):
    def __init__(self):
        check(type(self) is not Trainer, TypeError, "Trainer class cannot be instantiated directly")
        super().__init__()

        ## User Parameters
        self.running_parameters = {}

        # Training Losses
        self.__loss_functions = {}

        # Optimizer
        self.__optimizer = None

    @enforce_types
    def addMinimize(self, name:str, streamA:str|Stream|Output, streamB:str|Stream|Output, loss_function:str='mse') -> None:
        """
        Adds a minimize loss function to the model.

        Parameters
        ----------
        name : str
            The name of the cost function.
        streamA : str, Stream or Output
            The first relation stream for the minimize operation.
        streamB : str, Stream or Output
            The second relation stream for the minimize operation.
        loss_function : str, optional
            The loss function to use from the ones provided. Default is 'mse'.

        Example
        -------
        Example usage:
            >>> model.addMinimize('minimize_op', streamA, streamB, loss_function='mse')
        """
        self._model_def.addMinimize(name, streamA, streamB, loss_function)
        self.visualizer.showaddMinimize(name)

    @enforce_types
    def removeMinimize(self, name_list:list|str) -> None:
        """
        Removes minimize loss functions using the given list of names.

        Parameters
        ----------
        name_list : list or str
            The name, or list of names, of the minimize operations to remove.

        Example
        -------
        Example usage:
            >>> model.removeMinimize(['minimize_op1', 'minimize_op2'])
        """
        self._model_def.removeMinimize(name_list)

    def __preliminary_checks(self, **kwargs):
        check(self._data_loaded, RuntimeError, 'There is no data loaded! The Training will stop.')
        check('Models' in self._model_def.getJson(), RuntimeError, 'There are no models to train. Load a model using the addModel function.')
        check(list(self._model.parameters()), RuntimeError, 'There are no modules with learnable parameters! The Training will stop.')
        if kwargs.get('train_dataset', None) is None:
            check(kwargs.get('validation_dataset', None) is None, ValueError, 'If train_dataset is None, validation_dataset must also be None.')
        for model in kwargs['models']:
            check(model in kwargs['all_models'], ValueError, f'The model {model} is not in the model definition')

    def __fill_parameters(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            sig = inspect.signature(func)
            bound = sig.bind(self, *args, **kwargs)
            bound.apply_defaults()
            # Get standard parameters
            standard = self._standard_train_parameters
            # Get user parameters (guard against an explicit None)
            users = bound.arguments.get('training_params') or {}
            # Fill missing (None) arguments
            for param in sig.parameters.values():
                if param.name in ('self', 'lr', 'lr_param'):
                    continue
                if bound.arguments.get(param.name, None) is None:
                    if param.name in users:
                        bound.arguments[param.name] = users[param.name]
                    else:
                        bound.arguments[param.name] = standard.get(param.name, None)
            return func(**bound.arguments)
        return wrapper
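
    # Illustrative sketch of the decorator's behavior (the parameter names are
    # the real ones from trainModel below; the equivalence is an assumption
    # based on the filling logic above): every argument left as None is filled
    # first from the user 'training_params' dict and then from
    # self._standard_train_parameters, so
    #     model.trainModel(training_params={'num_of_epochs': 10})
    # behaves like
    #     model.trainModel(num_of_epochs=10)
    # while explicitly passed keyword arguments always take precedence, since
    # only None values are filled in.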

    def __initialize_optimizer(self, models, optimizer, training_params, optimizer_params, optimizer_defaults, add_optimizer_defaults, add_optimizer_params, lr, lr_param):
        ## Get the parameters of the models to train
        params_to_train = set()
        for model in models:
            if type(self._model_def['Models']) is dict:
                params_to_train |= set(self._model_def['Models'][model]['Parameters'])
            else:
                params_to_train |= set(self._model_def['Parameters'].keys())

        # Get the optimizer
        if type(optimizer) is str:
            if optimizer == 'SGD':
                optimizer = SGD({}, [])
            elif optimizer == 'Adam':
                optimizer = Adam({}, [])
        else:
            optimizer = copy.deepcopy(optimizer)
            check(issubclass(type(optimizer), Optimizer), TypeError, "The optimizer must be an Optimizer or str")

        optimizer.set_params_to_train(self._model.all_parameters, params_to_train)

        optimizer.add_defaults('lr', self._standard_train_parameters['lr'])

        if training_params and 'lr' in training_params:
            optimizer.add_defaults('lr', training_params['lr'])
        if training_params and 'lr_param' in training_params:
            optimizer.add_option_to_params('lr', training_params['lr_param'])

        if optimizer_defaults != {}:
            optimizer.set_defaults(optimizer_defaults)
        if optimizer_params != []:
            optimizer.set_params(optimizer_params)

        for key, value in add_optimizer_defaults.items():
            optimizer.add_defaults(key, value)

        add_optimizer_params = optimizer.unfold(add_optimizer_params)
        for param in add_optimizer_params:
            par = param['params']
            del param['params']
            for key, value in param.items():
                optimizer.add_option_to_params(key, {par: value})

        # The explicit lr and lr_param arguments take precedence over all defaults
        optimizer.add_defaults('lr', lr)
        if lr_param:
            optimizer.add_option_to_params('lr', lr_param)

        self.__optimizer = optimizer
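
    # Illustrative sketch (the calls below are assumptions built from how this
    # method is used, not documented API): trainModel accepts the optimizer by
    # name or by instance, so configurations such as
    #     model.trainModel(optimizer='SGD', lr=0.01)
    #     model.trainModel(optimizer=Adam({}, []), lr=0.001)
    # are resolved here; per-parameter learning rates can be set via
    #     model.trainModel(lr=0.001, lr_param={'W': 0.01})
    # where 'W' is a hypothetical parameter name of the model.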

    def __initialize_loss(self):
        for name, values in self._model_def['Minimizers'].items():
            self.__loss_functions[name] = CustomLoss(values['loss'])

    def getTrainingInfo(self):
        """
        Returns a dictionary with the training parameters and information.

        Returns
        -------
        dict
            A dictionary containing the training parameters and information.
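
        Example
        -------
        Example usage (the printed value assumes the default of 100 epochs
        documented in trainModel and is illustrative only):
            >>> model.trainModel(splits=[70,20,10])
            >>> info = model.getTrainingInfo()
            >>> info['num_of_epochs']
            100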
        """
        to_remove = ['XY_train','XY_val','XY_test','train_indexes','val_indexes','test_indexes']
        tp = copy.deepcopy({key:value for key, value in self.running_parameters.items() if key not in to_remove})

        ## training
        tp['update_per_epochs'] = len(self.running_parameters['train_indexes']) // (tp['train_batch_size'] + tp['step'])
        if tp['prediction_samples'] >= 0: # TODO
            tp['n_first_samples_train'] = len(self.running_parameters['train_indexes'])
            if tp['n_samples_val'] > 0:
                tp['n_first_samples_val'] = len(self.running_parameters['val_indexes'])
            if tp['n_samples_test'] > 0:
                tp['n_first_samples_test'] = len(self.running_parameters['test_indexes'])

        ## optimizer
        tp['optimizer'] = self.__optimizer.name
        tp['optimizer_defaults'] = self.__optimizer.optimizer_defaults
        tp['optimizer_params'] = self.__optimizer.optimizer_params

        ## early stopping
        early_stopping = tp['early_stopping']
        if early_stopping:
            tp['early_stopping'] = early_stopping.__name__

        ## Loss functions
        tp['minimizers'] = {}
        for name, values in self._model_def['Minimizers'].items():
            tp['minimizers'][name] = {}
            tp['minimizers'][name]['A'] = values['A']
            tp['minimizers'][name]['B'] = values['B']
            tp['minimizers'][name]['loss'] = values['loss']
            if name in tp['minimize_gain']:
                tp['minimizers'][name]['gain'] = tp['minimize_gain'][name]

        return tp

    def __check_needed_keys(self, train_data, connect, closed_loop):
        # Needed keys
        keys = set(self._model_def['Inputs'].keys())
        keys |= ({value['A'] for value in self._model_def['Minimizers'].values()} | {value['B'] for value in self._model_def['Minimizers'].values()})
        # Available keys
        keys -= set(self._model_def['Outputs'].keys() | self._model_def['Relations'].keys())
        keys -= set(self._model_def.recurrentInputs().keys())
        keys -= set(connect.keys() | closed_loop.keys())
        # Check if the keys are in the dataset
        check(keys.issubset(set(train_data.keys())), KeyError, f"Not all the mandatory keys {keys} are present in the training dataset {set(train_data.keys())}.")
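
    # Illustrative note (an assumed reading of the set algebra above): the
    # mandatory dataset keys are the model inputs plus the minimizers' A/B
    # targets, minus everything the network produces itself (outputs and
    # relations) and minus recurrent or connected inputs, so only genuinely
    # external signals must appear as columns of the training dataset.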

    @enforce_types
    @__fill_parameters
    def trainModel(self, *,
                   name: str | None = None,
                   models: str | list | None = None,
                   train_dataset: str | list | dict | None = None, validation_dataset: str | list | dict | None = None,
                   dataset: str | list | None = None, splits: list | None = None,
                   closed_loop: dict | None = None, connect: dict | None = None, step: int | None = None, prediction_samples: int | None = None,
                   shuffle_data: bool | None = None,
                   early_stopping: Callable | None = None, early_stopping_params: dict | None = None,
                   select_model: Callable | None = None, select_model_params: dict | None = None,
                   minimize_gain: dict | None = None,
                   num_of_epochs: int | None = None,
                   train_batch_size: int | None = None, val_batch_size: int | None = None,
                   optimizer: str | Optimizer | None = None,
                   lr: int | float | None = None, lr_param: dict | None = None,
                   optimizer_params: list | None = None, optimizer_defaults: dict | None = None,
                   add_optimizer_params: list | None = None, add_optimizer_defaults: dict | None = None,
                   training_params: dict | None = None
                   ) -> None:
        """
        Trains the model using the provided datasets and parameters.

        Notes
        -----
        .. note::
            If no datasets are provided, the model will use all the datasets loaded inside nnodely.

        Parameters
        ----------
        name : str or None, optional
            A name used to identify the training operation.
        models : str or list or None, optional
            A list or name of models to train. Default is all the loaded models.
        train_dataset : str or list or dict or None, optional
            The name of the datasets to use for training.
        validation_dataset : str or list or dict or None, optional
            The name of the datasets to use for validation.
        dataset : str or list or None, optional
            The name of the datasets to use for training, validation and test.
        splits : list or None, optional
            A list of 3 elements specifying the percentage splits for training, validation, and testing. The three elements must sum to 100. Default is [100, 0, 0].
            The splits parameter is only used when 'dataset' is not None.
        closed_loop : dict or None, optional
            A dictionary specifying closed-loop connections. The keys are input names and the values are output names. Default is None.
        connect : dict or None, optional
            A dictionary specifying connections. The keys are input names and the values are output names. Default is None.
        step : int or None, optional
            The step size for training. A larger value results in less data used in each epoch and faster training. Default is zero.
        prediction_samples : int or None, optional
            The size of the prediction horizon, i.e. the number of samples in each recurrent window. Default is zero.
        shuffle_data : bool or None, optional
            Whether to shuffle the data during training. Default is True.
        early_stopping : Callable or None, optional
            A callable for early stopping. Default is None. See the early-stopping example below.
        early_stopping_params : dict or None, optional
            A dictionary of parameters for early stopping. Default is None.
        select_model : Callable or None, optional
            A callable for selecting the best model. Default is None.
        select_model_params : dict or None, optional
            A dictionary of parameters for selecting the best model. Default is None.
        minimize_gain : dict or None, optional
            A dictionary specifying the gain for each minimization loss function. Default is None.
        num_of_epochs : int or None, optional
            The number of epochs to train the model. Default is 100.
        train_batch_size : int or None, optional
            The batch size for training. Default is 128.
        val_batch_size : int or None, optional
            The batch size for validation. Default is 128.
        optimizer : str or Optimizer or None, optional
            The optimizer to use for training. Default is 'Adam'.
        lr : int or float or None, optional
            The learning rate. Default is 0.001.
        lr_param : dict or None, optional
            A dictionary of learning rate parameters. Default is None.
        optimizer_params : list or None, optional
            A list of optimizer parameters. Default is None.
        optimizer_defaults : dict or None, optional
            A dictionary of default optimizer settings. Default is None.
        training_params : dict or None, optional
            A dictionary of training parameters. Default is None.
        add_optimizer_params : list or None, optional
            Additional optimizer parameters. Default is None.
        add_optimizer_defaults : dict or None, optional
            Additional default optimizer settings. Default is None.

        Examples
        --------
        .. image:: https://colab.research.google.com/assets/colab-badge.svg
            :target: https://colab.research.google.com/github/tonegas/nnodely/blob/main/examples/training.ipynb
            :alt: Open in Colab

        Example - basic feed-forward training:
            >>> x = Input('x')
            >>> F = Input('F')

            >>> xk1 = Output('x[k+1]', Fir()(x.tw(0.2))+Fir()(F.last()))

            >>> mass_spring_damper = Modely(seed=0)
            >>> mass_spring_damper.addModel('xk1',xk1)
            >>> mass_spring_damper.neuralizeModel(sample_time = 0.05)

            >>> data_struct = ['time','x','dx','F']
            >>> data_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)),'dataset','data')
            >>> mass_spring_damper.loadData(name='mass_spring_dataset', source=data_folder, format=data_struct, delimiter=';')

            >>> params = {'num_of_epochs': 100,'train_batch_size': 128,'lr':0.001}
            >>> mass_spring_damper.trainModel(splits=[70,20,10], training_params = params)

        Example - recurrent training:
            >>> x = Input('x')
            >>> F = Input('F')

            >>> xk1 = Output('x[k+1]', Fir()(x.tw(0.2))+Fir()(F.last()))

            >>> mass_spring_damper = Modely(seed=0)
            >>> mass_spring_damper.addModel('xk1',xk1)
            >>> mass_spring_damper.addClosedLoop(xk1, x)
            >>> mass_spring_damper.neuralizeModel(sample_time = 0.05)

            >>> data_struct = ['time','x','dx','F']
            >>> data_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)),'dataset','data')
            >>> mass_spring_damper.loadData(name='mass_spring_dataset', source=data_folder, format=data_struct, delimiter=';')

            >>> params = {'num_of_epochs': 100,'train_batch_size': 128,'lr':0.001}
            >>> mass_spring_damper.trainModel(splits=[70,20,10], prediction_samples=10, training_params = params)
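
        Example - early stopping with a hypothetical callable (its signature is
        an assumption inferred from how the callable is invoked in the training
        loop: early_stopping(train_losses, val_losses, early_stopping_params)):
            >>> def val_stalled(train_losses, val_losses, params):
            ...     patience = (params or {}).get('patience', 5)
            ...     losses = next(iter(val_losses.values()), [])
            ...     return len(losses) > patience and min(losses[-patience:]) > min(losses)

            >>> mass_spring_damper.trainModel(splits=[70,20,10], early_stopping=val_stalled, early_stopping_params={'patience': 5}, training_params = params)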
        """
        ## Get the models to train
        all_models = list(self._model_def['Models'].keys()) if type(self._model_def['Models']) is dict else [self._model_def['Models']]
        if models is None:
            models = all_models
        if isinstance(models, str):
            models = [models]

        ## Preliminary checks
        self.__preliminary_checks(models = models, all_models = all_models, train_dataset = train_dataset, validation_dataset = validation_dataset)

        ## Recurrent variables
        prediction_samples = self._setup_recurrent_variables(prediction_samples, closed_loop, connect)

        ## Get the dataset
        XY_train, XY_val, XY_test = self._setup_dataset(train_dataset, validation_dataset, None, dataset, splits)
        self.__check_needed_keys(train_data=XY_train, connect=connect, closed_loop=closed_loop)

        n_samples_train = next(iter(XY_train.values())).size(0)
        n_samples_val = next(iter(XY_val.values())).size(0) if XY_val else 0
        n_samples_test = next(iter(XY_test.values())).size(0) if XY_test else 0

        if train_dataset is not None:
            train_tag = self._get_tag(train_dataset)
            val_tag = self._get_tag(validation_dataset)
        else: ## splits is used
            if dataset is None:
                dataset = list(self._data.keys())
            tag = self._get_tag(dataset)
            train_tag = f"{tag}_train"
            val_tag = f"{tag}_val" if n_samples_val > 0 else None
            test_tag = f"{tag}_test" if n_samples_test > 0 else None

        train_indexes, val_indexes = [], []
        if train_dataset is not None:
            train_indexes, val_indexes = self._get_batch_indexes(train_dataset, n_samples_train, prediction_samples), self._get_batch_indexes(validation_dataset, n_samples_val, prediction_samples)
        else:
            dataset = list(self._data.keys()) if dataset is None else dataset
            train_indexes = self._get_batch_indexes(dataset, n_samples_train, prediction_samples)
            check(len(train_indexes) > 0, ValueError,
                  'The number of valid train samples is less than the number of prediction samples.')
            if n_samples_val > 0:
                val_indexes = self._get_batch_indexes(dataset, n_samples_train + n_samples_val, prediction_samples)
                val_indexes = [i - n_samples_train for i in val_indexes if i >= n_samples_train]
                if len(val_indexes) == 0:
                    log.warning('The number of valid validation samples is less than the number of prediction samples.')
            if n_samples_test > 0:
                test_indexes = self._get_batch_indexes(dataset, n_samples_train + n_samples_val + n_samples_test, prediction_samples)
                test_indexes = [i - (n_samples_train + n_samples_val) for i in test_indexes if i >= (n_samples_train + n_samples_val)]
                if len(test_indexes) == 0:
                    log.warning('The number of valid test samples is less than the number of prediction samples.')

        ## Clip batch size and step
        train_batch_size = self._clip_batch_size(len(train_indexes), train_batch_size)
        train_step = self._clip_step(step, train_indexes, train_batch_size)
        if n_samples_val > 0:
            val_batch_size = self._clip_batch_size(len(val_indexes), val_batch_size)
            val_step = self._clip_step(step, val_indexes, val_batch_size)

        ## Save the training parameters
        self.running_parameters = {key:value for key,value in locals().items() if key not in ['self', 'kwargs', 'training_params', 'lr', 'lr_param']}

        ## Define the optimizer
        self.__initialize_optimizer(models, optimizer, training_params, optimizer_params, optimizer_defaults, add_optimizer_defaults, add_optimizer_params, lr, lr_param)
        torch_optimizer = self.__optimizer.get_torch_optimizer()

        ## Define the loss functions
        self.__initialize_loss()

        ## Define mandatory inputs
        mandatory_inputs, non_mandatory_inputs = self._get_mandatory_inputs(connect, closed_loop)

        ## Check close loop and connect
        self._clean_log_internal()

        ## Create the train and validation loss dictionaries
        train_losses, val_losses = {}, {}
        for key in self._model_def['Minimizers'].keys():
            train_losses[key] = []
            if n_samples_val > 0:
                val_losses[key] = []

        ## Set the gradient to true if necessary
        model_inputs = self._model_def['Inputs']
        for key in model_inputs.keys():
            if 'type' in model_inputs[key]:
                if key in XY_train:
                    XY_train[key].requires_grad_(True)
                if key in XY_val:
                    XY_val[key].requires_grad_(True)

        selected_model_def = ModelDef(self._model_def.getJson())

        ## Show the training parameters
        self.visualizer.showTrainParams()
        self.visualizer.showStartTraining()

        ## Update with virtual states
        if prediction_samples >= 0:
            self._model.update(closed_loop=closed_loop, connect=connect)
        else:
            self._model.update(disconnect=True)

        self.resetStates()  ## Reset the states

        ## Start the train timer
        start = time.time()
        for epoch in range(num_of_epochs):
            ## TRAIN
            self._model.train()
            if prediction_samples >= 0:
                losses = self._recurrent_inference(XY_train, train_indexes, train_batch_size, minimize_gain, prediction_samples, train_step, non_mandatory_inputs, mandatory_inputs, self.__loss_functions, shuffle=shuffle_data, optimizer=torch_optimizer)
            else:
                losses = self._inference(XY_train, n_samples_train, train_batch_size, minimize_gain, self.__loss_functions, shuffle=shuffle_data, optimizer=torch_optimizer)
            ## Save the losses
            for ind, key in enumerate(self._model_def['Minimizers'].keys()):
                train_losses[key].append(torch.mean(losses[ind]).tolist())

            if n_samples_val > 0:
                ## VALIDATION
                self._model.eval()
                setted_log_internal = self._log_internal
                self._set_log_internal(False)  # TODO To remove when the function is moved outside the train
                if prediction_samples >= 0:
                    losses = self._recurrent_inference(XY_val, val_indexes, val_batch_size, minimize_gain, prediction_samples, val_step,
                                                       non_mandatory_inputs, mandatory_inputs, self.__loss_functions)
                else:
                    losses = self._inference(XY_val, n_samples_val, val_batch_size, minimize_gain, self.__loss_functions)
                self._set_log_internal(setted_log_internal)

                ## Save the losses
                for ind, key in enumerate(self._model_def['Minimizers'].keys()):
                    val_losses[key].append(torch.mean(losses[ind]).tolist())

            ## Early stopping
            if callable(early_stopping):
                if early_stopping(train_losses, val_losses, early_stopping_params):
                    log.info(f'Stopping the training at epoch {epoch} due to early stopping.')
                    break

            ## Model selection
            if callable(select_model):
                if select_model(train_losses, val_losses, select_model_params):
                    best_model_epoch = epoch
                    selected_model_def.updateParameters(self._model)

            ## Visualize the training...
            self.visualizer.showTraining(epoch, train_losses, val_losses)
            self.visualizer.showWeightsInTrain(epoch=epoch)

        ## Visualize the training time
        end = time.time()
        self.visualizer.showTrainingTime(end - start)

        ## Store the training losses
        for key in self._model_def['Minimizers'].keys():
            self._training[key] = {'train': train_losses[key]}
            if n_samples_val > 0:
                self._training[key]['val'] = val_losses[key]
        self.visualizer.showEndTraining(num_of_epochs - 1, train_losses, val_losses)

        ## Select the model
        if callable(select_model):
            log.info(f'Selected the model at the epoch {best_model_epoch + 1}.')
            self._model = Model(selected_model_def)
        else:
            log.info('The selected model is the LAST model of the training.')

        ## Remove virtual states
        self._remove_virtual_states(connect, closed_loop)

        ## Get trained model from torch and set the model_def
        self._model_def.updateParameters(self._model)

#from 685
#from 840