
ContinualAI / avalanche, build 4993189103 (pending completion)

Pull Request #1370: Add base elements to support distributed comms. Add supports_distributed plugin flag.

258 of 822 new or added lines in 27 files covered (31.39%).

80 existing lines in 5 files are now uncovered.

15585 of 21651 relevant lines covered (71.98%).

2.88 hits per line.

Source File: /avalanche/training/templates/base_sgd.py (95.82% covered)
from typing import Any, Callable, Iterable, Sequence, Optional, TypeVar, Union

from pkg_resources import parse_version

import torch
from torch.nn import Module, CrossEntropyLoss
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch import Tensor

from avalanche.benchmarks import CLExperience, CLStream
from avalanche.benchmarks.scenarios.generic_scenario import DatasetExperience
from avalanche.benchmarks.utils.data import AvalancheDataset
from avalanche.core import BasePlugin, BaseSGDPlugin
from avalanche.training.plugins import SupervisedPlugin, EvaluationPlugin
from avalanche.training.plugins.clock import Clock
from avalanche.training.plugins.evaluation import default_evaluator
from avalanche.training.templates.base import BaseTemplate
from avalanche.benchmarks.utils.data_loader import TaskBalancedDataLoader, \
    collate_from_data_or_kwargs
from avalanche.training.templates.strategy_mixin_protocol import \
    SGDStrategyProtocol
from avalanche.training.utils import trigger_plugins


TDatasetExperience = TypeVar('TDatasetExperience', bound=DatasetExperience)
TMBInput = TypeVar('TMBInput')
TMBOutput = TypeVar('TMBOutput')


class BaseSGDTemplate(
        SGDStrategyProtocol[
            TDatasetExperience,
            TMBInput,
            TMBOutput],
        BaseTemplate[
            TDatasetExperience]
        ):
    """Base SGD class for continual learning skeletons.

    **Training loop**
    The training loop is organized as follows::

        train
            train_exp  # for each experience

    **Evaluation loop**
    The evaluation loop is organized as follows::

        eval
            eval_exp  # for each experience

    """

    PLUGIN_CLASS = BaseSGDPlugin

    def __init__(
        self,
        model: Module,
        optimizer: Optimizer,
        criterion=CrossEntropyLoss(),
        train_mb_size: int = 1,
        train_epochs: int = 1,
        eval_mb_size: Optional[int] = 1,
        device: Union[str, torch.device] = "cpu",
        plugins: Optional[Sequence[BasePlugin]] = None,
        evaluator: Union[
            EvaluationPlugin,
            Callable[[], EvaluationPlugin]
        ] = default_evaluator,
        eval_every=-1,
        peval_mode="epoch"
    ):
        """Init.

        :param model: PyTorch model.
        :param optimizer: PyTorch optimizer.
        :param criterion: loss function.
        :param train_mb_size: mini-batch size for training.
        :param train_epochs: number of training epochs.
        :param eval_mb_size: mini-batch size for eval.
        :param evaluator: (optional) instance of EvaluationPlugin for logging
            and metric computations. None to remove logging.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop. -1 disables the evaluation. 0 means `eval` is called
            only at the end of the learning experience. Values >0 mean that
            `eval` is called every `eval_every` epochs and at the end of the
            learning experience.
        :param peval_mode: one of {'epoch', 'iteration'}. Decides whether the
            periodic evaluation during training should execute every
            `eval_every` epochs or iterations (Default='epoch').
        """

        super().__init__()  # type: ignore
        BaseTemplate.__init__(
            self=self,
            model=model,
            device=device,
            plugins=plugins)

        self.optimizer: Optimizer = optimizer
        """ PyTorch optimizer. """

        self._criterion = criterion
        """ Criterion. """

        self.train_epochs: int = train_epochs
        """ Number of training epochs. """

        self.train_mb_size: int = train_mb_size
        """ Training mini-batch size. """

        self.eval_mb_size: int = (
            train_mb_size if eval_mb_size is None else eval_mb_size
        )
        """ Eval mini-batch size. """

        if evaluator is None:
            evaluator = EvaluationPlugin()
        elif callable(evaluator):
            evaluator = evaluator()

        self.plugins.append(evaluator)  # type: ignore
        self.evaluator: EvaluationPlugin = evaluator
        """ EvaluationPlugin used for logging and metric computations. """

        # Configure periodic evaluation.
        assert peval_mode in {"experience", "epoch", "iteration"}
        self.eval_every = eval_every
        peval = PeriodicEval(eval_every, peval_mode)
        self.plugins.append(peval)

        self.clock: Clock = Clock()
        """ Incremental counters for strategy events. """
        # WARNING: Clock needs to be the last plugin, otherwise
        # counters will be wrong for plugins called after it.
        self.plugins.append(self.clock)

        ###################################################################
        # State variables. These are updated during the train/eval loops. #
        ###################################################################

        self.adapted_dataset: Optional[AvalancheDataset] = None
        """ Data used to train. It may be modified by plugins. Plugins can
        append data to it (e.g. for replay).

        .. note::

            This dataset may contain samples from different experiences. If you
            want the original data for the current experience
            use :attr:`.BaseTemplate.experience`.
        """

        self.dataloader: Iterable[Any] = []
        """ Dataloader. """

        self.mbatch: Optional[TMBInput] = None
        """ Current mini-batch. """

        self.mb_output: Optional[TMBOutput] = None
        """ Model's output computed on the current mini-batch. """

        self.loss: Tensor = self._make_empty_loss()
        """ Loss of the current mini-batch. """

        self._stop_training = False
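
    # Usage sketch (not part of the diff): the knobs above are normally
    # exercised through a concrete subclass such as Avalanche's `Naive`;
    # exact signatures may differ across versions:
    #
    #   from torch.nn import Linear
    #   from torch.optim import SGD
    #   from avalanche.training import Naive
    #
    #   model = Linear(784, 10)
    #   strategy = Naive(
    #       model, SGD(model.parameters(), lr=0.01), CrossEntropyLoss(),
    #       train_mb_size=32, train_epochs=4, eval_mb_size=64,
    #       device="cpu", eval_every=1, peval_mode="epoch",
    #   )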

    def train(self,
              experiences: Union[TDatasetExperience,
                                 Iterable[TDatasetExperience]],
              eval_streams: Optional[
                    Sequence[Union[TDatasetExperience,
                                   Iterable[TDatasetExperience]]]] = None,
              **kwargs):

        super().train(experiences, eval_streams, **kwargs)
        return self.evaluator.get_last_metrics()

    @torch.no_grad()
    def eval(self, exp_list: Union[CLExperience, CLStream], **kwargs):
        """
        Evaluate the current model on a series of experiences and
        return the last recorded value for each metric.

        :param exp_list: CL experience information.
        :param kwargs: custom arguments.

        :return: dictionary containing the last recorded value for
            each metric name.
        """
        super().eval(exp_list, **kwargs)
        return self.evaluator.get_last_metrics()
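
    # Sketch: both `train` and `eval` hand back the evaluator's last recorded
    # metrics, so a typical continual loop reads (assuming a standard
    # Avalanche benchmark object; the metric key shown is illustrative):
    #
    #   for experience in benchmark.train_stream:
    #       strategy.train(experience)
    #       metrics = strategy.eval(benchmark.test_stream)
    #       # metrics maps metric names to their latest values, e.g.
    #       # {"Top1_Acc_Stream/eval_phase/test_stream/Task000": 0.91, ...}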

    def _eval_exp(self, **kwargs):
        self.eval_epoch(**kwargs)

    def make_optimizer(self):
        """Optimizer initialization."""
        # Should be implemented in Observation Type
        raise NotImplementedError()

    def criterion(self) -> Tensor:
        """Compute loss function."""
        raise NotImplementedError()

    def forward(self):
        """Compute the model's output given the current mini-batch."""
        raise NotImplementedError()

    def model_adaptation(self, model=None):
        """Adapts the model to the current experience."""
        raise NotImplementedError()

    def stop_training(self):
        """Signals to stop training at the next iteration."""
        self._stop_training = True

    def training_epoch(self, **kwargs):
        # Should be implemented in Update Type
        raise NotImplementedError()

    def backward(self):
        """Run the backward pass."""
        self.loss.backward()

    def optimizer_step(self):
        """Execute the optimizer step (weights update)."""
        self.optimizer.step()
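
    # Sketch of how the hooks above are filled in by the "Observation Type"
    # and "Update Type" mixins. A minimal supervised variant could look like
    # this (the (input, target) minibatch layout is an assumption):
    #
    #   class ToySupervisedTemplate(BaseSGDTemplate):
    #       def forward(self):
    #           inputs, _ = self.mbatch
    #           return self.model(inputs)
    #
    #       def criterion(self) -> Tensor:
    #           _, targets = self.mbatch
    #           return self._criterion(self.mb_output, targets)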

    def eval_epoch(self, **kwargs):
        """Evaluation loop over the current `self.dataloader`."""
        for self.mbatch in self.dataloader:
            self._unpack_minibatch()
            self._before_eval_iteration(**kwargs)

            self._before_eval_forward(**kwargs)
            self.mb_output = self.forward()
            self._after_eval_forward(**kwargs)
            self.loss = self.criterion()

            self._after_eval_iteration(**kwargs)
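
    # The `_before_*`/`_after_*` calls above dispatch to plugins, which is
    # the extension point targeted by this PR's `supports_distributed` flag.
    # A minimal plugin sketch:
    #
    #   class EvalIterationCounter(BaseSGDPlugin, supports_distributed=True):
    #       def __init__(self):
    #           super().__init__()
    #           self.n_eval_iterations = 0
    #
    #       def after_eval_iteration(self, strategy, **kwargs):
    #           self.n_eval_iterations += 1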

    # ==================================================================> NEW

    def check_model_and_optimizer(self):
        # Should be implemented in observation type
        raise NotImplementedError()

    def _before_training_exp(self, **kwargs):
        """Setup to train on a single experience."""
        # Data Adaptation (e.g. add new samples/data augmentation)
        self._before_train_dataset_adaptation(**kwargs)
        self.train_dataset_adaptation(**kwargs)
        self._after_train_dataset_adaptation(**kwargs)

        self.make_train_dataloader(**kwargs)

        # Model Adaptation (e.g. freeze/add new units)
        # self.model = self.model_adaptation()
        # self.make_optimizer()
        self.check_model_and_optimizer()

        super()._before_training_exp(**kwargs)

    def _train_cleanup(self):
        super()._train_cleanup()
        # reset for faster serialization
        self.adapted_dataset = None
        self.dataloader = []
        self.mbatch = None
        self.mb_output = None
        self.loss = self._make_empty_loss()

    def _eval_cleanup(self):
        super()._eval_cleanup()
        # reset for faster serialization
        self.adapted_dataset = None
        self.dataloader = []
        self.mbatch = None
        self.mb_output = None
        self.loss = self._make_empty_loss()

    def _train_exp(
        self, experience: CLExperience, eval_streams=None, **kwargs
    ):
        """Training loop over a single Experience object.

        :param experience: CL experience information.
        :param eval_streams: list of streams for evaluation.
            If None: use the training experience for evaluation.
            Use [] if you do not want to evaluate during training.
        :param kwargs: custom arguments.
        """
        if eval_streams is None:
            eval_streams = [experience]
        for i, exp in enumerate(eval_streams):
            if not isinstance(exp, Iterable):
                eval_streams[i] = [exp]
        for _ in range(self.train_epochs):
            self._before_training_epoch(**kwargs)

            if self._stop_training:  # Early stopping
                self._stop_training = False
                break

            self.training_epoch(**kwargs)
            self._after_training_epoch(**kwargs)
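
    # The early-stopping check above is driven by `stop_training()`. Sketch
    # of a plugin that halts the current experience once the minibatch loss
    # drops below an (arbitrary) threshold:
    #
    #   class LossThresholdStop(BaseSGDPlugin, supports_distributed=True):
    #       def __init__(self, threshold: float = 0.05):
    #           super().__init__()
    #           self.threshold = threshold
    #
    #       def after_training_epoch(self, strategy, **kwargs):
    #           if strategy.loss.item() < self.threshold:
    #               strategy.stop_training()  # honored at the next epoch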

    def _save_train_state(self):
        """Save the training state, which may be modified by the eval loop.

        This currently includes: experience, adapted_dataset, dataloader,
        is_training, and train/eval modes for each module.

        TODO: we probably need a better way to do this.
        """
        state = super()._save_train_state()
        new_state = {
            "adapted_dataset": self.adapted_dataset,
            "dataloader": self.dataloader,
        }
        return {**state, **new_state}

    def train_dataset_adaptation(self, **kwargs):
        """Initialize `self.adapted_dataset`."""
        assert self.experience is not None
        self.adapted_dataset = self.experience.dataset
        assert self.adapted_dataset is not None
        self.adapted_dataset = self.adapted_dataset.train()

    def _load_train_state(self, prev_state):
        super()._load_train_state(prev_state)
        self.adapted_dataset = prev_state["adapted_dataset"]
        self.dataloader = prev_state["dataloader"]
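
    # `adapted_dataset` is the handle replay-style plugins use to inject old
    # data around the adaptation callbacks. A naive sketch, assuming a
    # `concat_datasets` helper (the helper name may vary across Avalanche
    # versions):
    #
    #   from avalanche.benchmarks.utils import concat_datasets
    #
    #   class TinyReplay(BaseSGDPlugin, supports_distributed=True):
    #       def __init__(self):
    #           super().__init__()
    #           self.buffer = None  # data seen in past experiences
    #
    #       def after_train_dataset_adaptation(self, strategy, **kwargs):
    #           if self.buffer is not None:
    #               strategy.adapted_dataset = concat_datasets(
    #                   [strategy.adapted_dataset, self.buffer])
    #
    #       def after_training_exp(self, strategy, **kwargs):
    #           self.buffer = strategy.experience.dataset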

    def _before_eval_exp(self, **kwargs):

        # Data Adaptation
        self._before_eval_dataset_adaptation(**kwargs)
        self.eval_dataset_adaptation(**kwargs)
        self._after_eval_dataset_adaptation(**kwargs)

        self.make_eval_dataloader(**kwargs)
        # Model Adaptation (e.g. freeze/add new units)
        self.model = self.model_adaptation()

        super()._before_eval_exp(**kwargs)

    def _obtain_common_dataloader_parameters(self, **kwargs):
        """
        Utility function that returns the dictionary of parameters to be passed
        to the train and eval dataloaders.

        This function can be useful when you need to customize the data loading
        parameters but no radical changes are needed. When overriding it to
        add or customize parameters, it is recommended to first call this
        implementation (super) to obtain a base dictionary of parameters.

        However, if a deeper change is needed in the data loading procedure,
        it is better to override :meth:`make_train_dataloader` and/or
        :meth:`make_eval_dataloader` directly.

        Note: the resulting dictionary does not include the collate function
        unless explicitly passed.

        :param kwargs: The dataloader arguments as passed to the `train`
            or `eval` method.
        :return: A dictionary of parameters to be passed to the DataLoader class
            or to one of the Avalanche dataloaders.
        """
        other_dataloader_args = {}

        if 'persistent_workers' in kwargs:
            if parse_version(torch.__version__) >= parse_version("1.7.0"):
                other_dataloader_args["persistent_workers"] = \
                    kwargs['persistent_workers']
            else:
                # `persistent_workers` is not supported by this PyTorch
                # version: drop it instead of crashing the DataLoader.
                del kwargs['persistent_workers']

        for k, v in kwargs.items():
            other_dataloader_args[k] = v

        if other_dataloader_args.get('pin_memory', None) is None:
            other_dataloader_args['pin_memory'] = self.device.type == 'cuda'

        return other_dataloader_args
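
    # Overriding this hook is the lightweight way to tune data loading. A
    # sketch; `prefetch_factor` is a standard torch DataLoader argument that
    # is only valid with `num_workers > 0`:
    #
    #   class PrefetchingTemplate(BaseSGDTemplate):
    #       def _obtain_common_dataloader_parameters(self, **kwargs):
    #           params = super()._obtain_common_dataloader_parameters(**kwargs)
    #           if params.get("num_workers", 0) > 0:
    #               params.setdefault("prefetch_factor", 4)
    #           return params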

    def make_train_dataloader(
        self,
        num_workers=0,
        shuffle=True,
        pin_memory=None,
        persistent_workers=False,
        **kwargs
    ):
        """Data loader initialization.

        Called at the start of each learning experience after the dataset
        adaptation.

        :param num_workers: number of thread workers for the data loading.
        :param shuffle: True if the data should be shuffled, False otherwise.
        :param pin_memory: If True, the data loader will copy Tensors into CUDA
            pinned memory before returning them. If None (the default), it is
            enabled automatically when training on a CUDA device.
        """

        assert self.adapted_dataset is not None

        other_dataloader_args = self._obtain_common_dataloader_parameters(
            batch_size=self.train_mb_size,
            num_workers=num_workers,
            shuffle=shuffle,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
            **kwargs
        )

        self.dataloader = TaskBalancedDataLoader(
            self.adapted_dataset,
            oversample_small_groups=True,
            **other_dataloader_args
        )
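
    # Dataloader keyword arguments flow from `train()`/`eval()` down to these
    # hooks, so loading can be tuned per call (values illustrative):
    #
    #   strategy.train(experience, num_workers=8, persistent_workers=True)
    #   strategy.eval(benchmark.test_stream, num_workers=4)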

    def make_eval_dataloader(
        self,
        num_workers=0,
        shuffle=False,
        pin_memory=None,
        persistent_workers=False,
        drop_last=False,
        **kwargs
    ):
        """
        Initializes the eval data loader.

        :param num_workers: How many subprocesses to use for data loading.
            0 means that the data will be loaded in the main process.
            (default: 0).
        :param pin_memory: If True, the data loader will copy Tensors into CUDA
            pinned memory before returning them. If None (the default), it is
            enabled automatically when evaluating on a CUDA device.
        :param kwargs: additional dataloader arguments.
        :return: None. Sets `self.dataloader`.
        """

        assert self.adapted_dataset is not None

        other_dataloader_args = self._obtain_common_dataloader_parameters(
            batch_size=self.eval_mb_size,
            num_workers=num_workers,
            shuffle=shuffle,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
            drop_last=drop_last,
            **kwargs
        )

        collate_from_data_or_kwargs(
            self.adapted_dataset,
            other_dataloader_args)

        self.dataloader = DataLoader(
            self.adapted_dataset,
            **other_dataloader_args
        )

    def eval_dataset_adaptation(self, **kwargs):
        """Initialize `self.adapted_dataset`."""
        assert self.experience is not None
        self.adapted_dataset = self.experience.dataset
        assert self.adapted_dataset is not None
        self.adapted_dataset = self.adapted_dataset.eval()

    def _unpack_minibatch(self):
        raise NotImplementedError()

    def _make_empty_loss(self) -> Tensor:
        return torch.zeros(1, device=self.device)

    #########################################################
    # Plugin Triggers                                       #
    #########################################################

    def _before_training_epoch(self, **kwargs):
        trigger_plugins(self, "before_training_epoch", **kwargs)

    def _after_training_epoch(self, **kwargs):
        trigger_plugins(self, "after_training_epoch", **kwargs)

    def _before_training_iteration(self, **kwargs):
        trigger_plugins(self, "before_training_iteration", **kwargs)

    def _before_forward(self, **kwargs):
        trigger_plugins(self, "before_forward", **kwargs)

    def _after_forward(self, **kwargs):
        trigger_plugins(self, "after_forward", **kwargs)

    def _before_backward(self, **kwargs):
        trigger_plugins(self, "before_backward", **kwargs)

    def _after_backward(self, **kwargs):
        trigger_plugins(self, "after_backward", **kwargs)

    def _after_training_iteration(self, **kwargs):
        trigger_plugins(self, "after_training_iteration", **kwargs)

    def _before_update(self, **kwargs):
        trigger_plugins(self, "before_update", **kwargs)

    def _after_update(self, **kwargs):
        trigger_plugins(self, "after_update", **kwargs)

    def _before_eval_iteration(self, **kwargs):
        trigger_plugins(self, "before_eval_iteration", **kwargs)

    def _before_eval_forward(self, **kwargs):
        trigger_plugins(self, "before_eval_forward", **kwargs)

    def _after_eval_forward(self, **kwargs):
        trigger_plugins(self, "after_eval_forward", **kwargs)

    def _after_eval_iteration(self, **kwargs):
        trigger_plugins(self, "after_eval_iteration", **kwargs)

    # ==================================================================> NEW

    def _before_train_dataset_adaptation(self, **kwargs):
        trigger_plugins(self, "before_train_dataset_adaptation", **kwargs)

    def _after_train_dataset_adaptation(self, **kwargs):
        trigger_plugins(self, "after_train_dataset_adaptation", **kwargs)

    def _before_eval_dataset_adaptation(self, **kwargs):
        trigger_plugins(self, "before_eval_dataset_adaptation", **kwargs)

    def _after_eval_dataset_adaptation(self, **kwargs):
        trigger_plugins(self, "after_eval_dataset_adaptation", **kwargs)


class PeriodicEval(BaseSGDPlugin, supports_distributed=True):
    """Schedules periodic evaluation during training.

    This plugin is automatically configured and added by the BaseTemplate.
    """

    def __init__(
            self,
            eval_every=-1,
            peval_mode="epoch",
            do_initial=True):
        """Init.

        :param eval_every: the frequency of the calls to `eval` inside the
            training loop. -1 disables the evaluation. 0 means `eval` is called
            only at the end of the learning experience. Values >0 mean that
            `eval` is called every `eval_every` epochs and at the end of the
            learning experience.
        :param peval_mode: one of {'epoch', 'iteration'}. Decides whether the
            periodic evaluation during training should execute every
            `eval_every` epochs or iterations (Default='epoch').
        :param do_initial: whether to evaluate before each `train` call.
            Occasionally needed because some metrics need to know the
            accuracy before training.
        """
        super().__init__()
        assert peval_mode in {"experience", "epoch", "iteration"}
        self.eval_every = eval_every
        self.peval_mode = peval_mode
        self.do_initial = do_initial and eval_every > -1
        self.do_final: Optional[bool] = None
        self._is_eval_updated = False

    def before_training(self, strategy, **kwargs):
        """Eval before each learning experience.

        Occasionally needed because some metrics need the accuracy before
        training.
        """
        if self.do_initial:
            self._peval(strategy, **kwargs)

    def before_training_exp(self, strategy, **kwargs):
        # We evaluate at the start of each experience because train_epochs
        # could change.
        self.do_final = True
        if self.peval_mode == "epoch":
            if (
                self.eval_every > 0
                and (strategy.train_epochs - 1) % self.eval_every == 0
            ):
                self.do_final = False
        else:  # peval_mode == 'iteration'
            # we may need to fix this but we don't have a way to know
            # the number of total iterations.
            # Right now there may be two eval calls at the last iterations.
            pass
        self.do_final = self.do_final and self.eval_every > -1
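
    # Worked sketch of the `do_final` rule above as standalone logic (the
    # helper name is illustrative, not Avalanche API):
    #
    #   def needs_final_eval(train_epochs: int, eval_every: int) -> bool:
    #       # Skip the final eval when the last epoch's periodic eval
    #       # already fired (peval_mode="epoch").
    #       if eval_every > 0 and (train_epochs - 1) % eval_every == 0:
    #           return False
    #       return eval_every > -1
    #
    #   needs_final_eval(10, 3)   # False: the last epoch triggers a peval
    #   needs_final_eval(10, 4)   # True: a final eval is still needed
    #   needs_final_eval(10, -1)  # False: evaluation disabled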

    def _peval(self, strategy, **kwargs):
        for el in strategy._eval_streams:
            strategy.eval(el, **kwargs)

    def _maybe_peval(self, strategy, counter, **kwargs):
        if self.eval_every > 0 and counter % self.eval_every == 0:
            self._peval(strategy, **kwargs)

    def after_training_epoch(self, strategy: "BaseSGDTemplate",
                             **kwargs):
        """Periodic eval controlled by `self.eval_every` and
        `self.peval_mode`."""
        if self.peval_mode == "epoch":
            self._maybe_peval(strategy, strategy.clock.train_exp_epochs,
                              **kwargs)

    def after_training_iteration(self, strategy: "BaseSGDTemplate",
                                 **kwargs):
        """Periodic eval controlled by `self.eval_every` and
        `self.peval_mode`."""
        if self.peval_mode == "iteration":
            self._maybe_peval(strategy, strategy.clock.train_exp_iterations,
                              **kwargs)

    # ---> New
    def after_training_exp(self, strategy, **kwargs):
        """Final eval after a learning experience."""
        if self.do_final:
            self._peval(strategy, **kwargs)

    # def after_training_exp(self, strategy: "BaseOnlineSGDTemplate", **kwargs):
    #     """Periodic eval controlled by `self.eval_every` and
    #     `self.peval_mode`."""
    #     if self.peval_mode == "experience":
    #         self._maybe_peval(strategy, strategy.clock.train_exp_counter,
    #                           **kwargs)