• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

tonegas / nnodely / 18321197581

07 Oct 2025 05:39PM UTC coverage: 97.731% (+0.05%) from 97.683%
18321197581

push

github

tonegas
Added some test for format in loadData

56 of 56 new or added lines in 2 files covered. (100.0%)

10 existing lines in 3 files now uncovered.

12794 of 13091 relevant lines covered (97.73%)

0.98 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.32
/nnodely/operators/trainer.py
1
import copy, torch, time, inspect
1✔
2

3
from collections.abc import Callable
1✔
4
from functools import wraps
1✔
5

6
from nnodely.basic.modeldef import ModelDef
1✔
7
from nnodely.basic.model import Model
1✔
8
from nnodely.basic.optimizer import Optimizer, SGD, Adam
1✔
9
from nnodely.basic.loss import CustomLoss
1✔
10
from nnodely.operators.network import Network
1✔
11
from nnodely.support.utils import check, enforce_types
1✔
12
from nnodely.basic.relation import Stream
1✔
13
from nnodely.layers.output import Output
1✔
14

15
from nnodely.support.logger import logging, nnLogger
1✔
16
# Module-level logger for this file, created with the INFO level
log = nnLogger(__name__, logging.INFO)
1✔
17

18

19
class Trainer(Network):
    """
    Operator mixin providing the training functionality (minimizers, optimizer setup and training loop).

    This class cannot be instantiated directly; it is meant to be inherited by a concrete model class.
    """
    def __init__(self):
        check(type(self) is not Trainer, TypeError, "Trainer class cannot be instantiated directly")
        super().__init__()

        ## User Parameters
        # Parameters of the last training, filled by trainModel (see getTrainingInfo)
        self.running_parameters = {}

        # Training Losses, one CustomLoss per minimizer name
        self.__loss_functions = {}

        # Optimizer, built by __initialize_optimizer at each training
        self.__optimizer = None

    @enforce_types
    def addMinimize(self, name:str, streamA:str|Stream|Output, streamB:str|Stream|Output, loss_function:str='mse') -> None:
        """
        Adds a minimize loss function to the model.

        Parameters
        ----------
        name : str
            The name of the cost function.
        streamA : str or Stream or Output
            The first relation stream for the minimize operation.
        streamB : str or Stream or Output
            The second relation stream for the minimize operation.
        loss_function : str, optional
            The loss function to use from the ones provided. Default is 'mse'.

        Example
        -------
        Example usage:
            >>> model.addMinimize('minimize_op', streamA, streamB, loss_function='mse')
        """
        self._model_def.addMinimize(name, streamA, streamB, loss_function)
        self.visualizer.showaddMinimize(name)

    @enforce_types
    def removeMinimize(self, name_list:list|str) -> None:
        """
        Removes minimize loss functions using the given list of names.

        Parameters
        ----------
        name_list : list of str or str
            The name, or the list of names, of the minimize operations to remove.

        Example
        -------
        Example usage:
            >>> model.removeMinimize(['minimize_op1', 'minimize_op2'])
        """
        self._model_def.removeMinimize(name_list)

    def __preliminary_checks(self, **kwargs):
        """
        Validates that a training can start.

        Checks that data is loaded, that at least one model is defined and has learnable parameters,
        that validation_dataset is not given without train_dataset and that every requested model exists.

        Raises
        ------
        RuntimeError
            If there is no data, no model or no learnable parameter.
        ValueError
            If the dataset arguments are inconsistent or a requested model is unknown.
        """
        check(self._data_loaded, RuntimeError, 'There is no data loaded! The Training will stop.')
        check('Models' in self._model_def.getJson(), RuntimeError, 'There are no models to train. Load a model using the addModel function.')
        check(list(self._model.parameters()), RuntimeError, 'There are no modules with learnable parameters! The Training will stop.')
        if kwargs.get('train_dataset') is None:
            check(kwargs.get('validation_dataset') is None, ValueError, 'If train_dataset is None, validation_dataset must also be None.')
        for model in kwargs['models']:
            check(model in kwargs['all_models'], ValueError, f'The model {model} is not in the model definition')

    def __fill_parameters(func):
        # Decorator: every argument left to None is filled with the value found in `training_params`
        # or, when missing there, with the value in `self._standard_train_parameters`.
        # `lr` and `lr_param` are skipped here because __initialize_optimizer resolves them itself.
        sig = inspect.signature(func)  # computed once, at decoration time
        skipped = {'self', 'lr', 'lr_param'}

        @wraps(func)
        def wrapper(self, *args, **kwargs):
            bound = sig.bind(self, *args, **kwargs)
            bound.apply_defaults()
            # Get standard parameters
            standard = self._standard_train_parameters
            # Get user_parameters (training_params may be explicitly None)
            users = bound.arguments.get('training_params') or {}
            # Fill missing (None) arguments
            for param in sig.parameters.values():
                if param.name in skipped:
                    continue
                if bound.arguments.get(param.name) is None:
                    if param.name in users:
                        bound.arguments[param.name] = users[param.name]
                    else:
                        bound.arguments[param.name] = standard.get(param.name, None)
            return func(**bound.arguments)
        return wrapper

    def __initialize_optimizer(self, models, optimizer, training_params, optimizer_params, optimizer_defaults, add_optimizer_defaults, add_optimizer_params, lr, lr_param):
        """
        Builds the optimizer for the requested models and stores it in ``self.__optimizer``.

        The learning-rate and option sources are applied in this order (later ones override earlier ones):
        standard lr, ``training_params['lr']``/``['lr_param']``, ``optimizer_defaults``/``optimizer_params``,
        ``add_optimizer_defaults``/``add_optimizer_params``, and finally the explicit ``lr``/``lr_param``.

        Raises
        ------
        ValueError
            If ``optimizer`` is a string that does not name a known optimizer.
        TypeError
            If ``optimizer`` is neither a string nor an Optimizer.
        """
        ## Get the parameters of the models to train
        models_are_named = isinstance(self._model_def['Models'], dict)
        params_to_train = set()
        for model in models:
            if models_are_named:
                params_to_train |= set(self._model_def['Models'][model]['Parameters'])
            else:
                params_to_train |= set(self._model_def['Parameters'].keys())

        # Get the optimizer
        if isinstance(optimizer, str):
            available_optimizers = {'SGD': SGD, 'Adam': Adam}
            check(optimizer in available_optimizers, ValueError,
                  f"Unknown optimizer '{optimizer}'. Use one of {list(available_optimizers)} or an Optimizer instance.")
            optimizer = available_optimizers[optimizer]({}, [])
        else:
            check(isinstance(optimizer, Optimizer), TypeError, "The optimizer must be an Optimizer or str")
            # Work on a copy so that the Optimizer given by the user is not modified
            optimizer = copy.deepcopy(optimizer)

        optimizer.set_params_to_train(self._model.all_parameters, params_to_train)

        optimizer.add_defaults('lr', self._standard_train_parameters['lr'])

        if training_params and 'lr' in training_params:
            optimizer.add_defaults('lr', training_params['lr'])
        if training_params and 'lr_param' in training_params:
            optimizer.add_option_to_params('lr', training_params['lr_param'])

        if optimizer_defaults:
            optimizer.set_defaults(optimizer_defaults)
        if optimizer_params:
            optimizer.set_params(optimizer_params)

        for key, value in (add_optimizer_defaults or {}).items():
            optimizer.add_defaults(key, value)

        for param in optimizer.unfold(add_optimizer_params or []):
            # Copy before removing 'params' so the caller's dictionaries are not mutated
            # (otherwise a second training with the same options would fail on the missing key)
            options = dict(param)
            par = options.pop('params')
            for key, value in options.items():
                optimizer.add_option_to_params(key, {par: value})

        # Modify the parameter: the explicit arguments have the highest priority
        optimizer.add_defaults('lr', lr)
        if lr_param:
            optimizer.add_option_to_params('lr', lr_param)

        self.__optimizer = optimizer

    def __initialize_loss(self):
        # Rebuild the loss functions from scratch so that removed minimizers do not linger
        self.__loss_functions = {name: CustomLoss(values['loss']) for name, values in self._model_def['Minimizers'].items()}

    def getTrainingInfo(self):
        """
        Returns a dictionary with the parameters and information of the last training.

        The dictionary is a deep copy of ``running_parameters`` (saved by ``trainModel``) without the
        datasets and the batch indexes. It is extended with the number of updates per epoch, the number
        of batch indexes of each split, the optimizer configuration, the name of the early stopping
        function and the definition of each minimizer.

        Returns
        -------
        dict
            A dictionary containing the training parameters and information.
        """
        to_remove = {'XY_train', 'XY_val', 'XY_test', 'train_indexes', 'val_indexes', 'test_indexes'}
        tp = copy.deepcopy({key: value for key, value in self.running_parameters.items() if key not in to_remove})

        ## training
        tp['update_per_epochs'] = len(self.running_parameters['train_indexes']) // (tp['train_batch_size'] + tp['step'])
        if tp['prediction_samples'] >= 0:  # TODO
            tp['n_first_samples_train'] = len(self.running_parameters['train_indexes'])
            if tp['n_samples_val'] > 0:
                tp['n_first_samples_val'] = len(self.running_parameters['val_indexes'])
            if tp['n_samples_test'] > 0:
                tp['n_first_samples_test'] = len(self.running_parameters['test_indexes'])

        ## optimizer
        tp['optimizer'] = self.__optimizer.name
        tp['optimizer_defaults'] = self.__optimizer.optimizer_defaults
        tp['optimizer_params'] = self.__optimizer.optimizer_params

        ## early stopping
        early_stopping = tp['early_stopping']
        if early_stopping:
            # Callable objects such as functools.partial have no __name__
            tp['early_stopping'] = getattr(early_stopping, '__name__', repr(early_stopping))

        ## Loss functions
        minimize_gain = tp['minimize_gain'] or {}
        tp['minimizers'] = {}
        for name, values in self._model_def['Minimizers'].items():
            tp['minimizers'][name] = {'A': values['A'], 'B': values['B'], 'loss': values['loss']}
            if name in minimize_gain:
                tp['minimizers'][name]['gain'] = minimize_gain[name]

        return tp

    def __check_needed_keys(self, train_data, connect, closed_loop):
        """
        Checks that the training data provides every key the network cannot produce by itself.

        Raises
        ------
        KeyError
            If a mandatory key is missing from ``train_data``.
        """
        minimizers = self._model_def['Minimizers'].values()
        # Needed keys: every model input and both sides of every minimizer
        keys = set(self._model_def['Inputs'].keys())
        keys |= {value['A'] for value in minimizers} | {value['B'] for value in minimizers}
        # Available keys: produced by the network, recurrent, or fed through connect/closed loop
        keys -= set(self._model_def['Outputs'].keys() | self._model_def['Relations'].keys())
        keys -= set(self._model_def.recurrentInputs().keys())
        keys -= set((connect or {}).keys() | (closed_loop or {}).keys())
        # Check if the keys are in the dataset
        check(keys.issubset(set(train_data.keys())), KeyError, f"Not all the mandatory keys {keys} are present in the training dataset {set(train_data.keys())}.")

    @enforce_types
    @__fill_parameters
    def trainModel(self, *,
                   name: str | None = None,
                   models: str | list | None = None,
                   train_dataset: str | list | dict | None = None, validation_dataset: str | list | dict | None = None,
                   dataset: str | list | None = None, splits: list | None = None,
                   closed_loop: dict | None = None, connect: dict | None = None, step: int | None = None, prediction_samples: int | None = None,
                   shuffle_data: bool | None = None,
                   early_stopping: Callable | None = None, early_stopping_params: dict | None = None,
                   select_model: Callable | None = None, select_model_params: dict | None = None,
                   minimize_gain: dict | None = None,
                   num_of_epochs: int | None = None,
                   train_batch_size: int | None = None, val_batch_size: int | None = None,
                   optimizer: str | Optimizer | None = None,
                   lr: int | float | None = None, lr_param: dict | None = None,
                   optimizer_params: list | None = None, optimizer_defaults: dict | None = None,
                   add_optimizer_params: list | None = None, add_optimizer_defaults: dict | None = None,
                   training_params: dict | None = None
                   ) -> None:
        """
        Trains the model using the provided datasets and parameters.

        Every argument left to None (except ``lr`` and ``lr_param``) is taken from ``training_params``
        when present there, otherwise from the standard training parameters.

        Notes
        -----
        .. note::
            If no datasets are provided, the model will use all the datasets loaded inside nnodely.

        Parameters
        ----------
        name : str or None, optional
            A name used to identify the training operation.
        models : str or list or None, optional
            A list or name of models to train. Default is all the models loaded.
        train_dataset : str or list or dict or None, optional
            The datasets to use for training.
        validation_dataset : str or list or dict or None, optional
            The datasets to use for validation. It can be given only together with train_dataset.
        dataset : str or list or None, optional
            The name of the datasets to use for training, validation and test.
        splits : list or None, optional
            A list of 3 elements specifying the percentage of splits for training, validation, and testing. The three elements must sum up to 100! default is [100, 0, 0]
            The parameter splits is only used when 'dataset' is not None.
        closed_loop : dict or None, optional
            A dictionary specifying closed loop connections. The keys are input names and the values are output names. Default is None.
        connect : dict or None, optional
            A dictionary specifying connections. The keys are input names and the values are output names. Default is None.
        step : int or None, optional
            The step size for training. A big value will result in less data used for each epochs and a faster train. Default is zero.
        prediction_samples : int or None, optional
            The size of the prediction horizon. Number of samples at each recurrent window Default is zero.
        shuffle_data : bool or None, optional
            Whether to shuffle the data during training. Default is True.
        early_stopping : Callable or None, optional
            A callable for early stopping, called each epoch as ``early_stopping(train_losses, val_losses, early_stopping_params)``. Default is None.
        early_stopping_params : dict or None, optional
            A dictionary of parameters for early stopping. Default is None.
        select_model : Callable or None, optional
            A callable for selecting the best model, called each epoch as ``select_model(train_losses, val_losses, select_model_params)``. Default is None.
        select_model_params : dict or None, optional
            A dictionary of parameters for selecting the best model. Default is None.
        minimize_gain : dict or None, optional
            A dictionary specifying the gain for each minimization loss function. Default is None.
        num_of_epochs : int or None, optional
            The number of epochs to train the model. Default is 100.
        train_batch_size : int or None, optional
            The batch size for training. Default is 128.
        val_batch_size : int or None, optional
            The batch size for validation. Default is 128.
        optimizer : str or Optimizer or None, optional
            The optimizer to use for training: 'SGD', 'Adam' or an Optimizer instance. Default is 'Adam'.
        lr : int or float or None, optional
            The learning rate. Default is 0.001
        lr_param : dict or None, optional
            A dictionary of learning rate parameters. Default is None.
        optimizer_params : list or None, optional
            The optimizer parameters (passed to the Optimizer ``set_params``). Default is None.
        optimizer_defaults : dict or None, optional
            A dictionary of default optimizer settings. Default is None.
        training_params : dict or None, optional
            A dictionary of training parameters used to fill the arguments that are not given explicitly. Default is None.
        add_optimizer_params : list or None, optional
            Additional optimizer parameters. Default is None.
        add_optimizer_defaults : dict or None, optional
            Additional default optimizer settings. Default is None.

        Examples
        --------
        .. image:: https://colab.research.google.com/assets/colab-badge.svg
            :target: https://colab.research.google.com/github/tonegas/nnodely/blob/main/examples/training.ipynb
            :alt: Open in Colab

        Example - basic feed-forward training:
            >>> x = Input('x')
            >>> F = Input('F')

            >>> xk1 = Output('x[k+1]', Fir()(x.tw(0.2))+Fir()(F.last()))

            >>> mass_spring_damper = Modely(seed=0)
            >>> mass_spring_damper.addModel('xk1',xk1)
            >>> mass_spring_damper.neuralizeModel(sample_time = 0.05)

            >>> data_struct = ['time','x','dx','F']
            >>> data_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)),'dataset','data')
            >>> mass_spring_damper.loadData(name='mass_spring_dataset', source=data_folder, format=data_struct, delimiter=';')

            >>> params = {'num_of_epochs': 100,'train_batch_size': 128,'lr':0.001}
            >>> mass_spring_damper.trainModel(splits=[70,20,10], training_params = params)

        Example - recurrent training:
            >>> x = Input('x')
            >>> F = Input('F')

            >>> xk1 = Output('x[k+1]', Fir()(x.tw(0.2))+Fir()(F.last()))

            >>> mass_spring_damper = Modely(seed=0)
            >>> mass_spring_damper.addModel('xk1',xk1)
            >>> mass_spring_damper.addClosedLoop(xk1, x)
            >>> mass_spring_damper.neuralizeModel(sample_time = 0.05)

            >>> data_struct = ['time','x','dx','F']
            >>> data_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)),'dataset','data')
            >>> mass_spring_damper.loadData(name='mass_spring_dataset', source=data_folder, format=data_struct, delimiter=';')

            >>> params = {'num_of_epochs': 100,'train_batch_size': 128,'lr':0.001}
            >>> mass_spring_damper.trainModel(splits=[70,20,10], prediction_samples=10, training_params = params)
        """
        # NOTE: every local variable defined before `running_parameters` is saved (below) ends up in it,
        # so adding or renaming locals in this first part changes what getTrainingInfo reports.

        ## Get model for train
        all_models = list(self._model_def['Models'].keys()) if isinstance(self._model_def['Models'], dict) else [self._model_def['Models']]
        if models is None:
            models = all_models
        if isinstance(models, str):
            models = [models]

        ## Preliminary Checks
        self.__preliminary_checks(models=models, all_models=all_models, train_dataset=train_dataset, validation_dataset=validation_dataset)

        ## Recurrent variables
        prediction_samples = self._setup_recurrent_variables(prediction_samples, closed_loop, connect)

        ## Get the dataset
        XY_train, XY_val, XY_test = self._setup_dataset(train_dataset, validation_dataset, None, dataset, splits)
        self.__check_needed_keys(train_data=XY_train, connect=connect, closed_loop=closed_loop)

        n_samples_train = next(iter(XY_train.values())).size(0)
        n_samples_val = next(iter(XY_val.values())).size(0) if XY_val else 0
        n_samples_test = next(iter(XY_test.values())).size(0) if XY_test else 0

        if train_dataset is not None:
            train_tag = self._get_tag(train_dataset)
            val_tag = self._get_tag(validation_dataset)
        else:  ## splits is used
            if dataset is None:
                dataset = list(self._data.keys())
            tag = self._get_tag(dataset)
            train_tag = f"{tag}_train"
            val_tag = f"{tag}_val" if n_samples_val > 0 else None
            test_tag = f"{tag}_test" if n_samples_test > 0 else None

        train_indexes, val_indexes = [], []
        if train_dataset is not None:
            train_indexes = self._get_batch_indexes(train_dataset, n_samples_train, prediction_samples)
            val_indexes = self._get_batch_indexes(validation_dataset, n_samples_val, prediction_samples)
        else:
            # `dataset` was already resolved to all the loaded datasets above when it was None
            train_indexes = self._get_batch_indexes(dataset, n_samples_train, prediction_samples)
            check(len(train_indexes) > 0, ValueError,
                  'The number of valid train samples is less than the number of prediction samples.')
            if n_samples_val > 0:
                # Indexes are computed on train+val and shifted to be relative to the validation split
                val_indexes = self._get_batch_indexes(dataset, n_samples_train + n_samples_val, prediction_samples)
                val_indexes = [i - n_samples_train for i in val_indexes if i >= n_samples_train]
                if not val_indexes:
                    log.warning('The number of valid validation samples is less than the number of prediction samples.')
            if n_samples_test > 0:
                # Indexes are computed on train+val+test and shifted to be relative to the test split
                test_indexes = self._get_batch_indexes(dataset, n_samples_train + n_samples_val + n_samples_test, prediction_samples)
                test_indexes = [i - (n_samples_train + n_samples_val) for i in test_indexes if i >= (n_samples_train + n_samples_val)]
                if not test_indexes:
                    log.warning('The number of valid test samples is less than the number of prediction samples.')

        ## clip batch size and step
        train_batch_size = self._clip_batch_size(len(train_indexes), train_batch_size)
        train_step = self._clip_step(step, train_indexes, train_batch_size)
        if n_samples_val > 0:
            val_batch_size = self._clip_batch_size(len(val_indexes), val_batch_size)
            val_step = self._clip_step(step, val_indexes, val_batch_size)

        ## Save the training parameters
        self.running_parameters = {key: value for key, value in locals().items() if key not in ['self', 'kwargs', 'training_params', 'lr', 'lr_param']}

        ## Define the optimizer
        self.__initialize_optimizer(models, optimizer, training_params, optimizer_params, optimizer_defaults, add_optimizer_defaults, add_optimizer_params, lr, lr_param)
        torch_optimizer = self.__optimizer.get_torch_optimizer()

        ## Define the loss functions
        self.__initialize_loss()

        ## Define mandatory inputs
        mandatory_inputs, non_mandatory_inputs = self._get_mandatory_inputs(connect, closed_loop)

        ## Check close loop and connect
        self._clean_log_internal()

        ## Create the train, validation and test loss dictionaries
        train_losses, val_losses = {}, {}
        for key in self._model_def['Minimizers'].keys():
            train_losses[key] = []
            if n_samples_val > 0:
                val_losses[key] = []

        ## Set the gradient to true if necessary
        model_inputs = self._model_def['Inputs']
        for key, input_def in model_inputs.items():
            if 'type' in input_def:
                if key in XY_train:
                    XY_train[key].requires_grad_(True)
                if XY_val and key in XY_val:
                    XY_val[key].requires_grad_(True)
        # Snapshot of the model definition, updated with the weights of the epoch chosen by select_model
        selected_model_def = ModelDef(self._model_def.getJson())

        ## Show the training parameters
        self.visualizer.showTrainParams()
        self.visualizer.showStartTraining()

        ## Update with virtual states
        if prediction_samples >= 0:
            self._model.update(closed_loop=closed_loop, connect=connect)
        else:
            self._model.update(disconnect=True)

        self.resetStates()  ## Reset the states

        # Stays None if select_model never selects an epoch
        best_model_epoch = None

        ## start the train timer
        start = time.time()
        for epoch in range(num_of_epochs):
            ## TRAIN
            self._model.train()
            if prediction_samples >= 0:
                losses = self._recurrent_inference(XY_train, train_indexes, train_batch_size, minimize_gain, prediction_samples, train_step, non_mandatory_inputs, mandatory_inputs, self.__loss_functions, shuffle=shuffle_data, optimizer=torch_optimizer)
            else:
                losses = self._inference(XY_train, n_samples_train, train_batch_size, minimize_gain, self.__loss_functions, shuffle=shuffle_data, optimizer=torch_optimizer)
            ## save the losses
            for ind, key in enumerate(self._model_def['Minimizers'].keys()):
                train_losses[key].append(torch.mean(losses[ind]).tolist())

            if n_samples_val > 0:
                ## VALIDATION
                self._model.eval()
                setted_log_internal = self._log_internal
                self._set_log_internal(False)  # TODO To remove when the function is moved outside the train
                try:
                    if prediction_samples >= 0:
                        losses = self._recurrent_inference(XY_val, val_indexes, val_batch_size, minimize_gain, prediction_samples, val_step,
                                                           non_mandatory_inputs, mandatory_inputs, self.__loss_functions)
                    else:
                        losses = self._inference(XY_val, n_samples_val, val_batch_size, minimize_gain, self.__loss_functions)
                finally:
                    # Restore the logging state even if the validation fails
                    self._set_log_internal(setted_log_internal)

                ## save the losses
                for ind, key in enumerate(self._model_def['Minimizers'].keys()):
                    val_losses[key].append(torch.mean(losses[ind]).tolist())

            if callable(select_model):
                if select_model(train_losses, val_losses, select_model_params):
                    best_model_epoch = epoch
                    selected_model_def.updateParameters(self._model)

            ## Early-stopping
            if callable(early_stopping):
                if early_stopping(train_losses, val_losses, early_stopping_params):
                    log.info(f'Stopping the training at epoch {epoch} due to early stopping.')
                    break

            ## Visualize the training...
            self.visualizer.showTraining(epoch, train_losses, val_losses)
            self.visualizer.showWeightsInTrain(epoch=epoch)

        ## Visualize the training time
        end = time.time()
        self.visualizer.showTrainingTime(end - start)

        for key in self._model_def['Minimizers'].keys():
            self._training[key] = {'train': train_losses[key]}
            if n_samples_val > 0:
                self._training[key]['val'] = val_losses[key]
        self.visualizer.showEndTraining(num_of_epochs - 1, train_losses, val_losses)

        ## Select the model
        if callable(select_model):
            if not val_losses:
                # The model selected is updated for the last time by the final batch;
                # so the minimum loss (selected model) is referred to the model before the last update.
                # If the batch is small compared to the dataset dimension the differences in the model are small.
                log.warning('If not validation set is provided the selected model can differ from the optimal.')
            if best_model_epoch is None:
                # Replacing the model here would restore the untrained weights of the snapshot
                log.warning('The select_model function never selected an epoch; the LAST model of the training is kept.')
            else:
                log.info(f'Selected the model at the epoch {best_model_epoch + 1}.')
                self._model = Model(selected_model_def)
        else:
            log.info('The selected model is the LAST model of the training.')

        ## Remove virtual states
        self._remove_virtual_states(connect, closed_loop)

        ## Get trained model from torch and set the model_def
        self._model_def.updateParameters(self._model)
1✔
518

519
#from 685
520
#from 840
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc