tonegas / nnodely, build 20074796492
09 Dec 2025 06:47PM UTC coverage: 96.807% (-1.0%) from 97.767%
Pull Request #109: New version of nnodely ("Fixes of tests", tonegas, via github)

867 of 887 new or added lines in 37 files covered (97.75%).
157 existing lines in 5 files now uncovered.
13066 of 13497 relevant lines covered (96.81%).
0.97 hits per line.

Source file: /nnodely/operators/trainer.py (95.36% covered)
Lines uncovered in this build are annotated "# uncovered"; lines added in this PR that are uncovered are annotated "# new, uncovered".

import copy, torch, time, inspect

from collections.abc import Callable
from functools import wraps

from nnodely.basic.modeldef import ModelDef
from nnodely.basic.model import Model
from nnodely.basic.optimizer import Optimizer, SGD, Adam
from nnodely.basic.loss import CustomLoss
from nnodely.operators.network import Network
from nnodely.support.utils import check, enforce_types
from nnodely.basic.relation import Stream
from nnodely.layers.output import Output

from nnodely.support.logger import logging, nnLogger
log = nnLogger(__name__, logging.INFO)


class Trainer(Network):
    def __init__(self):
        check(type(self) is not Trainer, TypeError, "Trainer class cannot be instantiated directly")
        super().__init__()

        ## User Parameters
        self.running_parameters = {}

        ## Training Losses
        self.__loss_functions = {}

        ## Optimizer
        self.__optimizer = None

    @enforce_types
    def addMinimize(self, name:str, streamA:str|Stream|Output, streamB:str|Stream|Output, loss_function:str='mse') -> None:
        """
        Adds a minimize loss function to the model.

        Parameters
        ----------
        name : str
            The name of the cost function.
        streamA : str, Stream or Output
            The first relation stream for the minimize operation.
        streamB : str, Stream or Output
            The second relation stream for the minimize operation.
        loss_function : str, optional
            The loss function to use, among the ones provided. Default is 'mse'.

        Example
        -------
        Example usage:
            >>> model.addMinimize('minimize_op', streamA, streamB, loss_function='mse')
        """
        self._model_def.addMinimize(name, streamA, streamB, loss_function)
        self.visualizer.showaddMinimize(name)
        self._neuralized = False

    @enforce_types
    def removeMinimize(self, name_list:list|str) -> None:
        """
        Removes minimize loss functions using the given list of names.

        Parameters
        ----------
        name_list : str or list of str
            The name or list of names of the minimize operations to remove.

        Example
        -------
        Example usage:
            >>> model.removeMinimize(['minimize_op1', 'minimize_op2'])
        """
        self._model_def.removeMinimize(name_list)
        self._neuralized = False

    def __preliminary_checks(self, **kwargs):
        check(self._data_loaded, RuntimeError, 'There is no data loaded! The training will stop.')
        check('Models' in self._model_def.getJson(), RuntimeError, 'There are no models to train. Load a model using the addModel function.')
        check(list(self._model.parameters()), RuntimeError, 'There are no modules with learnable parameters! The training will stop.')
        if kwargs.get('train_dataset', None) is None:
            check(kwargs.get('validation_dataset', None) is None, ValueError, 'If train_dataset is None, validation_dataset must also be None.')
        for model in kwargs['models']:
            check(model in kwargs['all_models'], ValueError, f'The model {model} is not in the model definition.')
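
    # The check(...) helper from nnodely.support.utils is used throughout this file;
    # a minimal sketch of its assumed contract (raise the given exception type with
    # the message when the condition is falsy), for illustration only:
    #
    #   def check(condition, exception, message):
    #       if not condition:
    #           raise exception(message)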

    def __fill_parameters(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            sig = inspect.signature(func)
            bound = sig.bind(self, *args, **kwargs)
            bound.apply_defaults()
            # Get standard parameters
            standard = self._standard_train_parameters
            # Get user parameters (guard against an explicit training_params=None)
            users = bound.arguments.get('training_params', None) or {}
            # Fill missing (None) arguments
            for param in sig.parameters.values():
                if param.name in ('self', 'lr', 'lr_param'):
                    continue
                if bound.arguments.get(param.name, None) is None:
                    if param.name in users:
                        bound.arguments[param.name] = users[param.name]
                    else:
                        bound.arguments[param.name] = standard.get(param.name, None)
            return func(**bound.arguments)
        return wrapper
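
    # A standalone sketch of the fill pattern used by __fill_parameters above,
    # assuming only the standard library; the names are illustrative, not nnodely API:
    #
    #   import inspect
    #
    #   def fill_missing(func, defaults):
    #       def wrapper(*args, **kwargs):
    #           bound = inspect.signature(func).bind(*args, **kwargs)
    #           bound.apply_defaults()
    #           for name, value in bound.arguments.items():
    #               if value is None and name in defaults:
    #                   bound.arguments[name] = defaults[name]
    #           return func(**bound.arguments)
    #       return wrapper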

    def __initialize_optimizer(self, models, optimizer, training_params, optimizer_params, optimizer_defaults, add_optimizer_defaults, add_optimizer_params, lr, lr_param):
        ## Collect the parameters to train from the selected models
        params_to_train = set()
        for model in models:
            if type(self._model_def['Models']) is dict:
                params_to_train |= set(self._model_def['Models'][model]['Parameters'])
            else:
                params_to_train |= set(self._model_def['Parameters'].keys())

        # Get the optimizer
        if type(optimizer) is str:
            if optimizer == 'SGD':
                optimizer = SGD({}, [])
            elif optimizer == 'Adam':
                optimizer = Adam({}, [])
        else:
            optimizer = copy.deepcopy(optimizer)
            check(issubclass(type(optimizer), Optimizer), TypeError, "The optimizer must be an Optimizer or a str")

        optimizer.set_params_to_train(self._model.all_parameters, params_to_train)

        optimizer.add_defaults('lr', self._standard_train_parameters['lr'])

        if training_params and 'lr' in training_params:
            optimizer.add_defaults('lr', training_params['lr'])
        if training_params and 'lr_param' in training_params:
            optimizer.add_option_to_params('lr', training_params['lr_param'])

        if optimizer_defaults != {}:
            optimizer.set_defaults(optimizer_defaults)
        if optimizer_params != []:
            optimizer.set_params(optimizer_params)

        for key, value in add_optimizer_defaults.items():
            optimizer.add_defaults(key, value)

        add_optimizer_params = optimizer.unfold(add_optimizer_params)
        for param in add_optimizer_params:
            par = param['params']
            del param['params']
            for key, value in param.items():
                optimizer.add_option_to_params(key, {par: value})

        # The explicit lr and lr_param arguments take precedence over the values above
        optimizer.add_defaults('lr', lr)
        if lr_param:
            optimizer.add_option_to_params('lr', lr_param)

        self.__optimizer = optimizer
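
    # Judging from the loop above, each add_optimizer_params entry is a dict with a
    # 'params' key plus per-parameter options; a hypothetical example (the parameter
    # name 'gravity' is illustrative, not part of nnodely):
    #
    #   add_optimizer_params = [{'params': 'gravity', 'lr': 0.05}]
    #
    # which is applied as optimizer.add_option_to_params('lr', {'gravity': 0.05}).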

    def __initialize_loss(self):
        for name, values in self._model_def['Minimizers'].items():
            self.__loss_functions[name] = CustomLoss(values['loss'])

    def getTrainingInfo(self):
        """
        Returns a dictionary with the training parameters and information.

        Returns
        -------
        dict
            A dictionary containing the training parameters and information.
        """
        to_remove = ['XY_train','XY_val','XY_test','train_indexes','val_indexes','test_indexes']
        tp = copy.deepcopy({key:value for key, value in self.running_parameters.items() if key not in to_remove})

        ## training
        tp['update_per_epochs'] = len(self.running_parameters['train_indexes']) // (tp['train_batch_size'] + tp['step'])
        if tp['prediction_samples'] >= 0: # TODO
            tp['n_first_samples_train'] = len(self.running_parameters['train_indexes'])
            if tp['n_samples_val'] > 0:
                tp['n_first_samples_val'] = len(self.running_parameters['val_indexes'])
            if tp['n_samples_test'] > 0:
                tp['n_first_samples_test'] = len(self.running_parameters['test_indexes'])

        ## optimizer
        tp['optimizer'] = self.__optimizer.name
        tp['optimizer_defaults'] = self.__optimizer.optimizer_defaults
        tp['optimizer_params'] = self.__optimizer.optimizer_params

        ## early stopping
        early_stopping = tp['early_stopping']
        if early_stopping:
            tp['early_stopping'] = early_stopping.__name__  # uncovered

        ## Loss functions
        tp['minimizers'] = {}
        for name, values in self._model_def['Minimizers'].items():
            tp['minimizers'][name] = {}
            tp['minimizers'][name]['A'] = values['A']
            tp['minimizers'][name]['B'] = values['B']
            tp['minimizers'][name]['loss'] = values['loss']
            if name in tp['minimize_gain']:
                tp['minimizers'][name]['gain'] = tp['minimize_gain'][name]  # uncovered

        return tp
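
    # A worked example of the update_per_epochs formula above, with illustrative
    # numbers: 1000 train indexes, train_batch_size=128 and step=0 give
    # update_per_epochs = 1000 // (128 + 0) = 7.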

    def __check_needed_keys(self, train_data, connect, closed_loop):
        # Needed keys
        keys = set(self._model_def['Inputs'].keys())
        keys |= ({value['A'] for value in self._model_def['Minimizers'].values()} | {value['B'] for value in self._model_def['Minimizers'].values()})
        # Remove the keys that are produced internally
        keys -= set(self._model_def['Outputs'].keys() | self._model_def['Relations'].keys())
        keys -= set(self._model_def.recurrentInputs().keys())
        keys -= (set(connect.keys() | closed_loop.keys()))
        # Check that the remaining keys are in the dataset
        check(set(keys).issubset(set(train_data.keys())), KeyError, f"Not all the mandatory keys {keys} are present in the training dataset {set(train_data.keys())}.")
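
    # A sketch of the key derivation above with hypothetical names: if the inputs are
    # {'x', 'F'}, the minimizers reference {'x[k+1]', 'dx'}, 'x[k+1]' is an output and
    # 'x' is a recurrent input, then the dataset must provide at least {'F', 'dx'}.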

    @enforce_types
    @__fill_parameters
    def trainModel(self, *,
                   name: str | None = None,
                   models: str | list | None = None,
                   train_dataset: str | list | dict | None = None, validation_dataset: str | list | dict | None = None,
                   dataset: str | list | None = None, splits: list | None = None,
                   closed_loop: dict | None = None, connect: dict | None = None, step: int | None = None, prediction_samples: int | None = None,
                   shuffle_data: bool | None = None,
                   early_stopping: Callable | None = None, early_stopping_params: dict | None = None,
                   select_model: Callable | None = None, select_model_params: dict | None = None,
                   minimize_gain: dict | None = None,
                   num_of_epochs: int | None = None,
                   train_batch_size: int | None = None, val_batch_size: int | None = None,
                   optimizer: str | Optimizer | None = None,
                   lr: int | float | None = None, lr_param: dict | None = None,
                   optimizer_params: list | None = None, optimizer_defaults: dict | None = None,
                   add_optimizer_params: list | None = None, add_optimizer_defaults: dict | None = None,
                   training_params: dict | None = {}
                   ) -> None:
        """
        Trains the model using the provided datasets and parameters.

        Notes
        -----
        .. note::
            If no datasets are provided, the model will use all the datasets loaded inside nnodely.

        Parameters
        ----------
        name : str or None, optional
            A name used to identify the training operation.
        models : str or list or None, optional
            A name or list of names of the models to train. Default is all the loaded models.
        train_dataset : str or None, optional
            The name of the dataset to use for training.
        validation_dataset : str or None, optional
            The name of the dataset to use for validation.
        dataset : str or None, optional
            The name of the dataset to use for training, validation and test.
            If dataset is None and train_dataset is None, the model will use all the datasets loaded inside nnodely.
        splits : list or None, optional
            A list of 3 elements specifying the percentage of splits for training, validation, and testing. The three elements must sum to 100. Default is [100, 0, 0].
            The splits parameter is used only when 'dataset' is not None.
        closed_loop : dict or None, optional
            A dictionary specifying closed-loop connections. The keys are input names and the values are output names. Default is None.
        connect : dict or None, optional
            A dictionary specifying connections. The keys are input names and the values are output names. Default is None.
        step : int or None, optional
            The step size for training. A larger value results in less data being used in each epoch and in faster training. Default is zero.
        prediction_samples : int or None, optional
            The size of the prediction horizon, i.e. the number of samples of each recurrent window. Default is zero.
        shuffle_data : bool or None, optional
            Whether to shuffle the data during training. Default is True.
        early_stopping : Callable or None, optional
            A callable for early stopping. Default is None.
        early_stopping_params : dict or None, optional
            A dictionary of parameters for early stopping. Default is None.
        select_model : Callable or None, optional
            A callable for selecting the best model. Default is None.
        select_model_params : dict or None, optional
            A dictionary of parameters for selecting the best model. Default is None.
        minimize_gain : dict or None, optional
            A dictionary specifying the gain of each minimization loss function. Default is None.
        num_of_epochs : int or None, optional
            The number of epochs to train the model. Default is 100.
        train_batch_size : int or None, optional
            The batch size for training. Default is 128.
        val_batch_size : int or None, optional
            The batch size for validation. Default is 128.
        optimizer : str or Optimizer or None, optional
            The optimizer to use for training. Default is 'Adam'.
        lr : float or None, optional
            The learning rate. Default is 0.001.
        lr_param : dict or None, optional
            A dictionary of learning rate parameters. Default is None.
        optimizer_params : list or None, optional
            A list of optimizer parameters. Default is None.
        optimizer_defaults : dict or None, optional
            A dictionary of default optimizer settings. Default is None.
        training_params : dict or None, optional
            A dictionary of training parameters. Default is None.
        add_optimizer_params : list or None, optional
            Additional optimizer parameters. Default is None.
        add_optimizer_defaults : dict or None, optional
            Additional default optimizer settings. Default is None.

        Examples
        --------
        .. image:: https://colab.research.google.com/assets/colab-badge.svg
            :target: https://colab.research.google.com/github/tonegas/nnodely/blob/main/examples/training.ipynb
            :alt: Open in Colab

        Example - basic feed-forward training:
            >>> x = Input('x')
            >>> F = Input('F')

            >>> xk1 = Output('x[k+1]', Fir()(x.tw(0.2))+Fir()(F.last()))

            >>> mass_spring_damper = Modely(seed=0)
            >>> mass_spring_damper.addModel('xk1', xk1)
            >>> mass_spring_damper.neuralizeModel(sample_time=0.05)

            >>> data_struct = ['time','x','dx','F']
            >>> data_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)),'dataset','data')
            >>> mass_spring_damper.loadData(name='mass_spring_dataset', source=data_folder, format=data_struct, delimiter=';')

            >>> params = {'num_of_epochs': 100, 'train_batch_size': 128, 'lr': 0.001}
            >>> mass_spring_damper.trainModel(splits=[70,20,10], training_params=params)

        Example - recurrent training:
            >>> x = Input('x')
            >>> F = Input('F')

            >>> xk1 = Output('x[k+1]', Fir()(x.tw(0.2))+Fir()(F.last()))

            >>> mass_spring_damper = Modely(seed=0)
            >>> mass_spring_damper.addModel('xk1', xk1)
            >>> mass_spring_damper.addClosedLoop(xk1, x)
            >>> mass_spring_damper.neuralizeModel(sample_time=0.05)

            >>> data_struct = ['time','x','dx','F']
            >>> data_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)),'dataset','data')
            >>> mass_spring_damper.loadData(name='mass_spring_dataset', source=data_folder, format=data_struct, delimiter=';')

            >>> params = {'num_of_epochs': 100, 'train_batch_size': 128, 'lr': 0.001}
            >>> mass_spring_damper.trainModel(splits=[70,20,10], prediction_samples=10, training_params=params)
        """
        ## Get the models to train
        all_models = list(self._model_def['Models'].keys()) if type(self._model_def['Models']) is dict else [self._model_def['Models']]
        if models is None:
            models = all_models
        if isinstance(models, str):
            models = [models]

        ## Preliminary checks
        self.__preliminary_checks(models=models, all_models=all_models, train_dataset=train_dataset, validation_dataset=validation_dataset)

        ## Recurrent variables
        prediction_samples = self._setup_recurrent_variables(prediction_samples, closed_loop, connect)

        ## Get the dataset
        XY_train, XY_val, XY_test = self._setup_dataset(train_dataset, validation_dataset, None, dataset, splits)
        self.__check_needed_keys(train_data=XY_train, connect=connect, closed_loop=closed_loop)

        n_samples_train = next(iter(XY_train.values())).size(0)
        n_samples_val = next(iter(XY_val.values())).size(0) if XY_val else 0
        n_samples_test = next(iter(XY_test.values())).size(0) if XY_test else 0

        if train_dataset is not None:
            train_tag = self._get_tag(train_dataset)
            val_tag = self._get_tag(validation_dataset)
        else:  ## splits is used
            if dataset is None:
                dataset = list(self._data.keys())
            tag = self._get_tag(dataset)
            train_tag = f"{tag}_train"
            val_tag = f"{tag}_val" if n_samples_val > 0 else None
            test_tag = f"{tag}_test" if n_samples_test > 0 else None

        train_indexes, val_indexes = [], []
        if train_dataset is not None:
            train_indexes, val_indexes = self._get_batch_indexes(train_dataset, n_samples_train, prediction_samples), self._get_batch_indexes(validation_dataset, n_samples_val, prediction_samples)
        else:
            dataset = list(self._data.keys()) if dataset is None else dataset
            train_indexes = self._get_batch_indexes(dataset, n_samples_train, prediction_samples)
            check(len(train_indexes) > 0, ValueError,
                  'The number of valid train samples is less than the number of prediction samples.')
            if n_samples_val > 0:
                val_indexes = self._get_batch_indexes(dataset, n_samples_train + n_samples_val, prediction_samples)
                val_indexes = [i - n_samples_train for i in val_indexes if i >= n_samples_train]
                if len(val_indexes) == 0:
                    log.warning('The number of valid validation samples is less than the number of prediction samples.')  # uncovered
            if n_samples_test > 0:
                test_indexes = self._get_batch_indexes(dataset, n_samples_train + n_samples_val + n_samples_test, prediction_samples)
                test_indexes = [i - (n_samples_train + n_samples_val) for i in test_indexes if i >= (n_samples_train + n_samples_val)]
                if len(test_indexes) == 0:
                    log.warning('The number of valid test samples is less than the number of prediction samples.')  # uncovered

        ## Clip batch size and step
        train_batch_size = self._clip_batch_size(len(train_indexes), train_batch_size)
        train_step = self._clip_step(step, train_indexes, train_batch_size)
        if n_samples_val > 0:
            val_batch_size = self._clip_batch_size(len(val_indexes), val_batch_size)
            val_step = self._clip_step(step, val_indexes, val_batch_size)

        ## Save the training parameters
        self.running_parameters = {key: value for key, value in locals().items() if key not in ['self', 'kwargs', 'training_params', 'lr', 'lr_param']}

        ## Define the optimizer
        self.__initialize_optimizer(models, optimizer, training_params, optimizer_params, optimizer_defaults, add_optimizer_defaults, add_optimizer_params, lr, lr_param)
        torch_optimizer = self.__optimizer.get_torch_optimizer()

        ## Define the loss functions
        self.__initialize_loss()

        ## Define the mandatory inputs
        mandatory_inputs, non_mandatory_inputs = self._get_mandatory_inputs(connect, closed_loop)

        ## Check closed loop and connect
        self._clean_log_internal()

        ## Create the train, validation and test loss dictionaries
        train_losses, val_losses = {}, {}
        for key in self._model_def['Minimizers'].keys():
            train_losses[key] = []
            if n_samples_val > 0:
                val_losses[key] = []

        ## Set the gradient to True where necessary
        model_inputs = self._model_def['Inputs']
        for key in model_inputs.keys():
            if 'type' in model_inputs[key]:
                if key in XY_train:
                    XY_train[key].requires_grad_(True)
                if key in XY_val:
                    XY_val[key].requires_grad_(True)
        selected_model_def = ModelDef(self._model_def.getJson())

        ## Show the training parameters
        self.visualizer.showTrainParams()
        self.visualizer.showStartTraining()

        ## Update with virtual states
        if prediction_samples >= 0:
            self._model.update(closed_loop=closed_loop, connect=connect)
        else:
            self._model.update(disconnect=True)

        self.resetStates()  ## Reset the states

        ## Start the train timer
        start = time.time()
        for epoch in range(num_of_epochs):
            ## TRAIN
            self._model.train()
            if prediction_samples >= 0:
                losses = self._recurrent_inference(XY_train, train_indexes, train_batch_size, minimize_gain, prediction_samples, train_step, non_mandatory_inputs, mandatory_inputs, self.__loss_functions, shuffle=shuffle_data, optimizer=torch_optimizer)
            else:
                losses = self._inference(XY_train, n_samples_train, train_batch_size, minimize_gain, self.__loss_functions, shuffle=shuffle_data, optimizer=torch_optimizer)
            ## Save the losses
            for ind, key in enumerate(self._model_def['Minimizers'].keys()):
                train_losses[key].append(torch.mean(losses[ind]).tolist())

            if n_samples_val > 0:
                ## VALIDATION
                self._model.eval()
                setted_log_internal = self._log_internal
                self._set_log_internal(False)  # TODO To remove when the function is moved outside the train
                if prediction_samples >= 0:
                    losses = self._recurrent_inference(XY_val, val_indexes, val_batch_size, minimize_gain, prediction_samples, val_step,
                                                       non_mandatory_inputs, mandatory_inputs, self.__loss_functions)
                else:
                    losses = self._inference(XY_val, n_samples_val, val_batch_size, minimize_gain, self.__loss_functions)
                self._set_log_internal(setted_log_internal)

                ## Save the losses
                for ind, key in enumerate(self._model_def['Minimizers'].keys()):
                    val_losses[key].append(torch.mean(losses[ind]).tolist())

            if callable(select_model):
                if select_model(train_losses, val_losses, select_model_params):  # new, uncovered
                    best_model_epoch = epoch  # new, uncovered
                    selected_model_def.updateParameters(self._model)  # new, uncovered

            ## Early stopping
            if callable(early_stopping):
                if early_stopping(train_losses, val_losses, early_stopping_params):
                    log.info(f'Stopping the training at epoch {epoch} due to early stopping.')
                    break

            ## Visualize the training...
            self.visualizer.showTraining(epoch, train_losses, val_losses)
            self.visualizer.showWeightsInTrain(epoch=epoch)

        ## Visualize the training time
        end = time.time()
        self.visualizer.showTrainingTime(end - start)

        for key in self._model_def['Minimizers'].keys():
            self._training[key] = {'train': train_losses[key]}
            if n_samples_val > 0:
                self._training[key]['val'] = val_losses[key]
        self.visualizer.showEndTraining(num_of_epochs - 1, train_losses, val_losses)

        ## Select the model
        if callable(select_model):
            if not val_losses:  # new, uncovered
                # The selected model is updated for the last time by the final batch,
                # so the minimum loss (the selected model) refers to the model before the last update.
                # If the batch is small compared to the dataset size, the differences in the model are small.
                log.warning('If no validation set is provided, the selected model can differ from the optimal one.')  # new, uncovered
            log.info(f'Selected the model at the epoch {best_model_epoch + 1}.')  # uncovered
            self._model = Model(selected_model_def)  # uncovered
        else:
            log.info('The selected model is the LAST model of the training.')

        ## Remove virtual states
        self._remove_virtual_states(connect, closed_loop)

        ## Get the trained model from torch and set the model_def
        self._model_def.updateParameters(self._model)

#from 685
#from 840
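
# A sketch of a user-defined early-stopping callable, compatible with the call
# early_stopping(train_losses, val_losses, early_stopping_params) made in the
# training loop above; the function name and the 'patience' key are illustrative,
# not nnodely API:
#
#   def no_val_improvement(train_losses, val_losses, params):
#       patience = (params or {}).get('patience', 10)
#       losses = next(iter(val_losses.values()), [])
#       if len(losses) <= patience:
#           return False
#       # Stop when the last `patience` epochs did not improve on the earlier best
#       return min(losses[-patience:]) >= min(losses[:-patience])
#
#   model.trainModel(splits=[70, 20, 10], early_stopping=no_val_improvement,
#                    early_stopping_params={'patience': 10})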