• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ContinualAI / avalanche / 4993189103

pending completion
4993189103

Pull #1370

github

Unknown Committer
Unknown Commit Message
Pull Request #1370: Add base elements to support distributed comms. Add supports_distributed plugin flag.

258 of 822 new or added lines in 27 files covered. (31.39%)

80 existing lines in 5 files now uncovered.

15585 of 21651 relevant lines covered (71.98%)

2.88 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

76.19
/avalanche/training/supervised/strategy_wrappers.py
1
################################################################################
2
# Copyright (c) 2021 ContinualAI.                                              #
3
# Copyrights licensed under the MIT License.                                   #
4
# See the accompanying LICENSE file for terms.                                 #
5
#                                                                              #
6
# Date: 01-12-2020                                                             #
7
# Author(s): Antonio Carta, Andrea Cossu                                       #
8
# E-mail: contact@continualai.org                                              #
9
# Website: avalanche.continualai.org                                           #
10
################################################################################
11
from typing import Callable, Optional, Sequence, List, Union
4✔
12
import torch
4✔
13
from torch.nn.parameter import Parameter
4✔
14

15
from torch.nn import Module, CrossEntropyLoss
4✔
16
from torch.optim import Optimizer
4✔
17

18
from avalanche.models.pnn import PNN
4✔
19
from avalanche.training.plugins.evaluation import (
4✔
20
    default_evaluator,
21
    default_loggers,
22
)
23
from avalanche.training.plugins import (
4✔
24
    SupervisedPlugin,
25
    CWRStarPlugin,
26
    ReplayPlugin,
27
    GenerativeReplayPlugin,
28
    TrainGeneratorAfterExpPlugin,
29
    GDumbPlugin,
30
    LwFPlugin,
31
    AGEMPlugin,
32
    GEMPlugin,
33
    EWCPlugin,
34
    EvaluationPlugin,
35
    SynapticIntelligencePlugin,
36
    CoPEPlugin,
37
    GSS_greedyPlugin,
38
    LFLPlugin,
39
    MASPlugin,
40
    BiCPlugin,
41
    MIRPlugin,
42
    FromScratchTrainingPlugin
43
)
44
from avalanche.training.templates.base import BaseTemplate
4✔
45
from avalanche.training.templates import SupervisedTemplate
4✔
46
from avalanche.models.generator import MlpVAE, VAE_loss
4✔
47
from avalanche.logging import InteractiveLogger
4✔
48

49

50
class Naive(SupervisedTemplate):
    """Naive finetuning.

    The simplest (and least effective) Continual Learning strategy: the model
    is fine-tuned incrementally on every experience, with no mechanism at all
    to counter catastrophic forgetting of previous knowledge. Task identities
    are not used.

    Because it is trivial to set up, Naive is commonly reported as the
    worst-performing baseline.
    """

    def __init__(
        self,
        model: Module,
        optimizer: Optimizer,
        criterion=CrossEntropyLoss(),
        train_mb_size: int = 1,
        train_epochs: int = 1,
        eval_mb_size: Optional[int] = None,
        device: Union[str, torch.device] = "cpu",
        plugins: Optional[List[SupervisedPlugin]] = None,
        evaluator: Union[
            EvaluationPlugin,
            Callable[[], EvaluationPlugin]
        ] = default_evaluator,
        eval_every=-1,
        **base_kwargs
    ):
        """Build a Naive strategy; every argument is forwarded unchanged.

        :param model: The model to fine-tune.
        :param optimizer: The optimizer to use.
        :param criterion: The loss criterion. Defaults to a
            ``CrossEntropyLoss`` instance.
        :param train_mb_size: The train minibatch size. Defaults to 1.
        :param train_epochs: The number of training epochs. Defaults to 1.
        :param eval_mb_size: The eval minibatch size. Defaults to None, in
            which case the choice is left to the base template.
        :param device: The device to use. Defaults to "cpu".
        :param plugins: Plugins to be added. Defaults to None.
        :param evaluator: (optional) EvaluationPlugin instance, or a
            zero-argument callable returning one, used for logging and
            metric computations.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop. -1 disables the evaluation. 0 means `eval` is called
            only at the end of the learning experience. Values >0 mean that
            `eval` is called every `eval_every` epochs and at the end of the
            learning experience.
        :param base_kwargs: any additional
            :class:`~avalanche.training.BaseTemplate` constructor arguments.
        """
        super().__init__(
            model=model,
            optimizer=optimizer,
            criterion=criterion,
            train_mb_size=train_mb_size,
            train_epochs=train_epochs,
            eval_mb_size=eval_mb_size,
            device=device,
            plugins=plugins,
            evaluator=evaluator,
            eval_every=eval_every,
            **base_kwargs
        )
113

114

115
class PNNStrategy(SupervisedTemplate):
    """Progressive Neural Network strategy.

    To use this strategy you need to instantiate a PNN model.
    """

    def __init__(
        self,
        model: Module,
        optimizer: Optimizer,
        criterion=CrossEntropyLoss(),
        train_mb_size: int = 1,
        train_epochs: int = 1,
        eval_mb_size: int = 1,
        device: Union[str, torch.device] = "cpu",
        plugins: Optional[Sequence[SupervisedPlugin]] = None,
        evaluator: Union[
            EvaluationPlugin,
            Callable[[], EvaluationPlugin]
        ] = default_evaluator,
        eval_every=-1,
        **base_kwargs
    ):
        """Init.

        :param model: PyTorch model. Must be an instance of
            :class:`~avalanche.models.pnn.PNN`.
        :param optimizer: PyTorch optimizer.
        :param criterion: loss function. Defaults to a ``CrossEntropyLoss``
            instance.
        :param train_mb_size: mini-batch size for training. Defaults to 1.
        :param train_epochs: number of training epochs. Defaults to 1.
        :param eval_mb_size: mini-batch size for eval. Defaults to 1.
        :param device: PyTorch device where the model will be allocated.
            Defaults to "cpu".
        :param plugins: (optional) list of StrategyPlugins.
        :param evaluator: (optional) EvaluationPlugin instance, or a
            zero-argument callable returning one, used for logging and
            metric computations.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop. -1 disables the evaluation. 0 means `eval` is called
            only at the end of the learning experience. Values >0 mean that
            `eval` is called every `eval_every` epochs and at the end of the
            learning experience.
        :param base_kwargs: any additional
            :class:`~avalanche.training.BaseTemplate` constructor arguments.
        :raises AssertionError: if ``model`` is not a PNN.
        """
        # Check that the model has the correct architecture. An explicit
        # raise is used instead of an ``assert`` statement so the check is
        # not stripped under ``python -O``; the AssertionError type and the
        # message are kept so existing callers that catch it keep working.
        if not isinstance(model, PNN):
            raise AssertionError("PNNStrategy requires a PNN model.")
        super().__init__(
            model=model,
            optimizer=optimizer,
            criterion=criterion,
            train_mb_size=train_mb_size,
            train_epochs=train_epochs,
            eval_mb_size=eval_mb_size,
            device=device,
            plugins=plugins,
            evaluator=evaluator,
            eval_every=eval_every,
            **base_kwargs
        )
173

174

175
class CWRStar(SupervisedTemplate):
    """CWR* Strategy.

    Wrapper around :class:`SupervisedTemplate` that always attaches a
    :class:`CWRStarPlugin` (built with ``freeze_remaining_model=True``) after
    any user-provided plugins.
    """

    def __init__(
        self,
        model: Module,
        optimizer: Optimizer,
        criterion,
        cwr_layer_name: str,
        train_mb_size: int = 1,
        train_epochs: int = 1,
        eval_mb_size: Optional[int] = None,
        device: Union[str, torch.device] = "cpu",
        plugins: Optional[Sequence[SupervisedPlugin]] = None,
        evaluator: Union[
            EvaluationPlugin,
            Callable[[], EvaluationPlugin]
        ] = default_evaluator,
        eval_every=-1,
        **base_kwargs
    ):
        """Init.

        :param model: The model.
        :param optimizer: The optimizer to use.
        :param criterion: The loss criterion to use.
        :param cwr_layer_name: name of the CWR layer, forwarded unchanged to
            :class:`CWRStarPlugin`. This argument is required here.
            NOTE(review): whether the plugin accepts ``None`` (e.g. to pick
            a default layer) is defined by the plugin; confirm there.
        :param train_mb_size: The train minibatch size. Defaults to 1.
        :param train_epochs: The number of training epochs. Defaults to 1.
        :param eval_mb_size: The eval minibatch size. Defaults to None, in
            which case the choice is left to the base template.
        :param device: The device to use. Defaults to "cpu".
        :param plugins: Plugins to be added. Defaults to None. The given
            sequence is copied and never modified in place.
        :param evaluator: (optional) EvaluationPlugin instance, or a
            zero-argument callable returning one, used for logging and
            metric computations.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop. -1 disables the evaluation. 0 means `eval` is called
            only at the end of the learning experience. Values >0 mean that
            `eval` is called every `eval_every` epochs and at the end of the
            learning experience.
        :param \\*\\*base_kwargs: any additional
            :class:`~avalanche.training.BaseTemplate` constructor arguments.
        """
        cwsp = CWRStarPlugin(model, cwr_layer_name, freeze_remaining_model=True)
        # Copy instead of appending to the caller's list: reusing one list
        # for several strategies would otherwise accumulate plugins.
        plugins = [] if plugins is None else list(plugins)
        plugins.append(cwsp)
        super().__init__(
            model,
            optimizer,
            criterion,
            train_mb_size=train_mb_size,
            train_epochs=train_epochs,
            eval_mb_size=eval_mb_size,
            device=device,
            plugins=plugins,
            evaluator=evaluator,
            eval_every=eval_every,
            **base_kwargs
        )
236

237

238
class Replay(SupervisedTemplate):
    """Experience replay strategy.

    See ReplayPlugin for more details.
    This strategy does not use task identities.
    """

    def __init__(
        self,
        model: Module,
        optimizer: Optimizer,
        criterion,
        mem_size: int = 200,
        train_mb_size: int = 1,
        train_epochs: int = 1,
        eval_mb_size: Optional[int] = None,
        device: Union[str, torch.device] = "cpu",
        plugins: Optional[Sequence[SupervisedPlugin]] = None,
        evaluator: Union[
            EvaluationPlugin,
            Callable[[], EvaluationPlugin]
        ] = default_evaluator,
        eval_every=-1,
        **base_kwargs
    ):
        """Init.

        :param model: The model.
        :param optimizer: The optimizer to use.
        :param criterion: The loss criterion to use.
        :param mem_size: replay buffer size. Defaults to 200.
        :param train_mb_size: The train minibatch size. Defaults to 1.
        :param train_epochs: The number of training epochs. Defaults to 1.
        :param eval_mb_size: The eval minibatch size. Defaults to None, in
            which case the choice is left to the base template.
        :param device: The device to use. Defaults to "cpu".
        :param plugins: Plugins to be added. Defaults to None. The given
            sequence is copied and never modified in place.
        :param evaluator: (optional) EvaluationPlugin instance, or a
            zero-argument callable returning one, used for logging and
            metric computations.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop. -1 disables the evaluation. 0 means `eval` is called
            only at the end of the learning experience. Values >0 mean that
            `eval` is called every `eval_every` epochs and at the end of the
            learning experience.
        :param \\*\\*base_kwargs: any additional
            :class:`~avalanche.training.BaseTemplate` constructor arguments.
        """

        rp = ReplayPlugin(mem_size)
        # Copy instead of appending to the caller's list: reusing one list
        # for several strategies would otherwise accumulate plugins.
        plugins = [] if plugins is None else list(plugins)
        plugins.append(rp)
        super().__init__(
            model,
            optimizer,
            criterion,
            train_mb_size=train_mb_size,
            train_epochs=train_epochs,
            eval_mb_size=eval_mb_size,
            device=device,
            plugins=plugins,
            evaluator=evaluator,
            eval_every=eval_every,
            **base_kwargs
        )
303

304

305
class GenerativeReplay(SupervisedTemplate):
    """Generative Replay Strategy

    This implements Deep Generative Replay for a Scholar consisting of a Solver
    and Generator as described in https://arxiv.org/abs/1705.08690.

    The model parameter should contain the solver. As an optional input
    a generator can be wrapped in a trainable strategy
    and passed to the generator_strategy parameter. By default a simple VAE will
    be used as generator.

    For the case where the Generator is the model itself that is to be trained,
    please simply add the GenerativeReplayPlugin() when instantiating
    your Generator's strategy.

    See GenerativeReplayPlugin for more details.
    This strategy does not use task identities.
    """

    def __init__(
        self,
        model: Module,
        optimizer: Optimizer,
        criterion=CrossEntropyLoss(),
        train_mb_size: int = 1,
        train_epochs: int = 1,
        eval_mb_size: Optional[int] = None,
        device: Union[str, torch.device] = "cpu",
        plugins: Optional[Sequence[SupervisedPlugin]] = None,
        evaluator: Union[
            EvaluationPlugin,
            Callable[[], EvaluationPlugin]
        ] = default_evaluator,
        eval_every=-1,
        generator_strategy: Optional[BaseTemplate] = None,
        replay_size: Optional[int] = None,
        increasing_replay_size: bool = False,
        **base_kwargs
    ):
        """
        Creates an instance of Generative Replay Strategy
        for a solver-generator pair.

        :param model: The solver model.
        :param optimizer: The optimizer to use.
        :param criterion: The loss criterion to use. Defaults to a
            ``CrossEntropyLoss`` instance.
        :param train_mb_size: The train minibatch size. Defaults to 1.
        :param train_epochs: The number of training epochs. Defaults to 1.
        :param eval_mb_size: The eval minibatch size. Defaults to None, in
            which case the choice is left to the base template.
        :param device: The device to use. Defaults to "cpu".
        :param plugins: Plugins to be added. Defaults to None. The given
            sequence is copied and never modified in place.
        :param evaluator: (optional) EvaluationPlugin instance, or a
            zero-argument callable returning one, used for logging and
            metric computations.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop. -1 disables the evaluation. 0 means `eval` is called
            only at the end of the learning experience. Values >0 mean that
            `eval` is called every `eval_every` epochs and at the end of the
            learning experience.
        :param generator_strategy: A trainable strategy with a generative model,
            which employs GenerativeReplayPlugin. Defaults to None, in which
            case a :class:`VAETraining` strategy around an ``MlpVAE`` is
            built (see the note in the body about its fixed input shape).
        :param replay_size: forwarded to every GenerativeReplayPlugin created
            here; see that plugin for its semantics. Defaults to None.
        :param increasing_replay_size: forwarded to every
            GenerativeReplayPlugin created here; see that plugin for its
            semantics. Defaults to False.
        :param \\*\\*base_kwargs: any additional
            :class:`~avalanche.training.BaseTemplate` constructor arguments.
        """

        # Check if user inputs a generator model
        # (which is wrapped in a strategy that can be trained and
        # uses the GenerativeReplayPlugin;
        # see 'VAETraining' as an example below.)
        if generator_strategy is not None:
            self.generator_strategy = generator_strategy
        else:
            # By default we use a fully-connected VAE as the generator.
            # NOTE(review): the shape (1, 28, 28) is hard-coded, presumably
            # the expected input shape; pass a custom generator_strategy
            # for other data.
            generator = MlpVAE((1, 28, 28), nhid=2, device=device)
            # optimizer:
            lr = 0.01
            from torch.optim import Adam

            to_optimize: List[Parameter] = [
                p for p in generator.parameters() if p.requires_grad
            ]
            optimizer_generator = Adam(
                to_optimize,
                lr=lr,
                weight_decay=0.0001,
            )
            # strategy (with plugin):
            self.generator_strategy = VAETraining(
                model=generator,
                optimizer=optimizer_generator,
                criterion=VAE_loss,
                train_mb_size=train_mb_size,
                train_epochs=train_epochs,
                eval_mb_size=eval_mb_size,
                device=device,
                plugins=[
                    GenerativeReplayPlugin(
                        replay_size=replay_size,
                        increasing_replay_size=increasing_replay_size,
                    )
                ],
            )

        rp = GenerativeReplayPlugin(
            generator_strategy=self.generator_strategy,
            replay_size=replay_size,
            increasing_replay_size=increasing_replay_size,
        )

        tgp = TrainGeneratorAfterExpPlugin()

        # Copy instead of appending to the caller's list: reusing one list
        # for several strategies would otherwise accumulate plugins. The
        # generator-training plugin goes before the replay plugin.
        plugins = [] if plugins is None else list(plugins)
        plugins.append(tgp)
        plugins.append(rp)

        super().__init__(
            model,
            optimizer,
            criterion,
            train_mb_size=train_mb_size,
            train_epochs=train_epochs,
            eval_mb_size=eval_mb_size,
            device=device,
            plugins=plugins,
            evaluator=evaluator,
            eval_every=eval_every,
            **base_kwargs
        )
435

436

437
def get_default_vae_logger():
    """Build the default evaluator used by :class:`VAETraining`.

    Each call constructs a new :class:`EvaluationPlugin` that is given only
    ``default_loggers`` (no metric arguments), so strategies never share an
    evaluator instance.
    """
    vae_evaluator = EvaluationPlugin(loggers=default_loggers)
    return vae_evaluator
439

440

441
class VAETraining(SupervisedTemplate):
    """VAETraining class

    Training strategy for the VAE model found in the models directory.

    It reuses :class:`SupervisedTemplate` even though VAE training is not
    supervised in the strict sense; this keeps the required changes minimal.
    Only the ``criterion`` hook is overridden, so that the VAE loss receives
    the inputs together with the model output. By default the evaluator is
    built by ``get_default_vae_logger``, which passes loggers but no metrics.
    """

    def __init__(
        self,
        model: Module,
        optimizer: Optimizer,
        criterion=VAE_loss,
        train_mb_size: int = 1,
        train_epochs: int = 1,
        eval_mb_size: Optional[int] = None,
        device: Union[str, torch.device] = "cpu",
        plugins: Optional[List[SupervisedPlugin]] = None,
        evaluator: Union[
            EvaluationPlugin,
            Callable[[], EvaluationPlugin]
        ] = get_default_vae_logger,
        eval_every=-1,
        **base_kwargs
    ):
        """
        Creates an instance of the VAETraining strategy.

        :param model: The VAE model.
        :param optimizer: The optimizer to use.
        :param criterion: The loss criterion to use. Defaults to ``VAE_loss``;
            it is called as ``criterion(mb_x, mb_output)``.
        :param train_mb_size: The train minibatch size. Defaults to 1.
        :param train_epochs: The number of training epochs. Defaults to 1.
        :param eval_mb_size: The eval minibatch size. Defaults to None, in
            which case the choice is left to the base template.
        :param device: The device to use. Defaults to "cpu".
        :param plugins: Plugins to be added. Defaults to None.
        :param evaluator: (optional) EvaluationPlugin instance, or a
            zero-argument callable returning one. Defaults to
            ``get_default_vae_logger``.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop. -1 disables the evaluation. 0 means `eval` is called
            only at the end of the learning experience. Values >0 mean that
            `eval` is called every `eval_every` epochs and at the end of the
            learning experience.
        :param \\*\\*base_kwargs: any additional
            :class:`~avalanche.training.BaseTemplate` constructor arguments.
        """

        super().__init__(
            model=model,
            optimizer=optimizer,
            criterion=criterion,
            train_mb_size=train_mb_size,
            train_epochs=train_epochs,
            eval_mb_size=eval_mb_size,
            device=device,
            plugins=plugins,
            evaluator=evaluator,
            eval_every=eval_every,
            **base_kwargs
        )

    def criterion(self):
        """Compute the VAE loss for the current minibatch.

        The stored criterion is called with the minibatch inputs and the model
        output (instead of the targets), which is the signature ``VAE_loss``
        expects for reconstruction loss plus KL divergence.
        """
        loss_fn = self._criterion
        return loss_fn(self.mb_x, self.mb_output)
511

512

513
class GSS_greedy(SupervisedTemplate):
    """GSS-greedy strategy.

    Wrapper around :class:`SupervisedTemplate` that always attaches a
    :class:`GSS_greedyPlugin` after any user-provided plugins.
    See GSS_greedyPlugin for more details.
    """

    def __init__(
        self,
        model: Module,
        optimizer: Optimizer,
        criterion,
        mem_size: int = 200,
        mem_strength=1,
        input_size=None,
        train_mb_size: int = 1,
        train_epochs: int = 1,
        eval_mb_size: Optional[int] = None,
        device: Union[str, torch.device] = "cpu",
        plugins: Optional[Sequence[SupervisedPlugin]] = None,
        evaluator: Union[
            EvaluationPlugin,
            Callable[[], EvaluationPlugin]
        ] = default_evaluator,
        eval_every=-1,
        **base_kwargs
    ):
        """Init.

        :param model: The model.
        :param optimizer: The optimizer to use.
        :param criterion: The loss criterion to use.
        :param mem_size: replay buffer size. Defaults to 200.
        :param mem_strength: forwarded to GSS_greedyPlugin. Defaults to 1.
        :param input_size: forwarded to GSS_greedyPlugin; presumably the
            shape of a single input sample (confirm in the plugin). Defaults
            to None, which is passed on as an empty list.
        :param train_mb_size: The train minibatch size. Defaults to 1.
        :param train_epochs: The number of training epochs. Defaults to 1.
        :param eval_mb_size: The eval minibatch size. Defaults to None, in
            which case the choice is left to the base template.
        :param device: The device to use. Defaults to "cpu".
        :param plugins: Plugins to be added. Defaults to None. The given
            sequence is copied and never modified in place.
        :param evaluator: (optional) EvaluationPlugin instance, or a
            zero-argument callable returning one, used for logging and
            metric computations.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop. -1 disables the evaluation. 0 means `eval` is called
            only at the end of the learning experience. Values >0 mean that
            `eval` is called every `eval_every` epochs and at the end of the
            learning experience.
        :param \\*\\*base_kwargs: any additional
            :class:`~avalanche.training.BaseTemplate` constructor arguments.
        """
        # A fresh list per call replaces the former shared ``[]`` default,
        # which would have been the same object for every instance.
        if input_size is None:
            input_size = []
        rp = GSS_greedyPlugin(
            mem_size=mem_size, mem_strength=mem_strength, input_size=input_size
        )
        # Copy instead of appending to the caller's list: reusing one list
        # for several strategies would otherwise accumulate plugins.
        plugins = [] if plugins is None else list(plugins)
        plugins.append(rp)
        super().__init__(
            model,
            optimizer,
            criterion,
            train_mb_size=train_mb_size,
            train_epochs=train_epochs,
            eval_mb_size=eval_mb_size,
            device=device,
            plugins=plugins,
            evaluator=evaluator,
            eval_every=eval_every,
            **base_kwargs
        )
582

583

584
class GDumb(SupervisedTemplate):
    """GDumb strategy.

    See GDumbPlugin for more details.
    This strategy does not use task identities.
    """

    def __init__(
        self,
        model: Module,
        optimizer: Optimizer,
        criterion,
        mem_size: int = 200,
        train_mb_size: int = 1,
        train_epochs: int = 1,
        eval_mb_size: Optional[int] = None,
        device: Union[str, torch.device] = "cpu",
        plugins: Optional[Sequence[SupervisedPlugin]] = None,
        evaluator: Union[
            EvaluationPlugin,
            Callable[[], EvaluationPlugin]
        ] = default_evaluator,
        eval_every=-1,
        **base_kwargs
    ):
        """Init.

        :param model: The model.
        :param optimizer: The optimizer to use.
        :param criterion: The loss criterion to use.
        :param mem_size: replay buffer size. Defaults to 200.
        :param train_mb_size: The train minibatch size. Defaults to 1.
        :param train_epochs: The number of training epochs. Defaults to 1.
        :param eval_mb_size: The eval minibatch size. Defaults to None, in
            which case the choice is left to the base template.
        :param device: The device to use. Defaults to "cpu".
        :param plugins: Plugins to be added. Defaults to None. The given
            sequence is copied and never modified in place.
        :param evaluator: (optional) EvaluationPlugin instance, or a
            zero-argument callable returning one, used for logging and
            metric computations.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop. -1 disables the evaluation. 0 means `eval` is called
            only at the end of the learning experience. Values >0 mean that
            `eval` is called every `eval_every` epochs and at the end of the
            learning experience.
        :param base_kwargs: any additional
            :class:`~avalanche.training.BaseTemplate` constructor arguments.
        """

        gdumb = GDumbPlugin(mem_size)
        # Copy instead of appending to the caller's list: reusing one list
        # for several strategies would otherwise accumulate plugins.
        plugins = [] if plugins is None else list(plugins)
        plugins.append(gdumb)

        super().__init__(
            model,
            optimizer,
            criterion,
            train_mb_size=train_mb_size,
            train_epochs=train_epochs,
            eval_mb_size=eval_mb_size,
            device=device,
            plugins=plugins,
            evaluator=evaluator,
            eval_every=eval_every,
            **base_kwargs
        )
650

651

652
class LwF(SupervisedTemplate):
    """Learning without Forgetting (LwF) strategy.

    See LwF plugin for details.
    """

    def __init__(
        self,
        model: Module,
        optimizer: Optimizer,
        criterion,
        alpha: Union[float, Sequence[float]],
        temperature: float,
        train_mb_size: int = 1,
        train_epochs: int = 1,
        eval_mb_size: Optional[int] = None,
        device: Union[str, torch.device] = "cpu",
        plugins: Optional[Sequence[SupervisedPlugin]] = None,
        evaluator: Union[
            EvaluationPlugin,
            Callable[[], EvaluationPlugin]
        ] = default_evaluator,
        eval_every=-1,
        **base_kwargs
    ):
        """Init.

        :param model: The model.
        :param optimizer: The optimizer to use.
        :param criterion: The loss criterion to use.
        :param alpha: distillation hyperparameter. It can be either a float
                number or a list containing alpha for each experience.
        :param temperature: softmax temperature for distillation
        :param train_mb_size: The train minibatch size. Defaults to 1.
        :param train_epochs: The number of training epochs. Defaults to 1.
        :param eval_mb_size: The eval minibatch size. Defaults to None, in
            which case the choice is left to the base template.
        :param device: The device to use. Defaults to "cpu".
        :param plugins: Plugins to be added. Defaults to None. The given
            sequence is copied and never modified in place.
        :param evaluator: (optional) EvaluationPlugin instance, or a
            zero-argument callable returning one, used for logging and
            metric computations.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop. -1 disables the evaluation. 0 means `eval` is called
            only at the end of the learning experience. Values >0 mean that
            `eval` is called every `eval_every` epochs and at the end of the
            learning experience.
        :param base_kwargs: any additional
            :class:`~avalanche.training.BaseTemplate` constructor arguments.
        """

        lwf = LwFPlugin(alpha, temperature)
        # Copy instead of appending to the caller's list: reusing one list
        # for several strategies would otherwise accumulate plugins.
        plugins = [] if plugins is None else list(plugins)
        plugins.append(lwf)

        super().__init__(
            model,
            optimizer,
            criterion,
            train_mb_size=train_mb_size,
            train_epochs=train_epochs,
            eval_mb_size=eval_mb_size,
            device=device,
            plugins=plugins,
            evaluator=evaluator,
            eval_every=eval_every,
            **base_kwargs
        )
720

721

722
class AGEM(SupervisedTemplate):
    """Average Gradient Episodic Memory (A-GEM) strategy.

    See AGEM plugin for details.
    This strategy does not use task identities.
    """

    def __init__(
        self,
        model: Module,
        optimizer: Optimizer,
        criterion,
        patterns_per_exp: int,
        sample_size: int = 64,
        train_mb_size: int = 1,
        train_epochs: int = 1,
        eval_mb_size: Optional[int] = None,
        device: Union[str, torch.device] = "cpu",
        plugins: Optional[Sequence[SupervisedPlugin]] = None,
        evaluator: Union[
            EvaluationPlugin,
            Callable[[], EvaluationPlugin]
        ] = default_evaluator,
        eval_every=-1,
        **base_kwargs
    ):
        """Init.

        :param model: The model.
        :param optimizer: The optimizer to use.
        :param criterion: The loss criterion to use.
        :param patterns_per_exp: number of patterns per experience in the memory
        :param sample_size: number of patterns in memory sample when computing
            reference gradient. Defaults to 64.
        :param train_mb_size: The train minibatch size. Defaults to 1.
        :param train_epochs: The number of training epochs. Defaults to 1.
        :param eval_mb_size: The eval minibatch size. Defaults to None, in
            which case the choice is left to the base template.
        :param device: The device to use. Defaults to "cpu".
        :param plugins: Plugins to be added. Defaults to None. The given
            sequence is copied and never modified in place.
        :param evaluator: (optional) EvaluationPlugin instance, or a
            zero-argument callable returning one, used for logging and
            metric computations.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop. -1 disables the evaluation. 0 means `eval` is called
            only at the end of the learning experience. Values >0 mean that
            `eval` is called every `eval_every` epochs and at the end of the
            learning experience.
        :param base_kwargs: any additional
            :class:`~avalanche.training.BaseTemplate` constructor arguments.
        """

        agem = AGEMPlugin(patterns_per_exp, sample_size)
        # Copy instead of appending to the caller's list: reusing one list
        # for several strategies would otherwise accumulate plugins.
        plugins = [] if plugins is None else list(plugins)
        plugins.append(agem)

        super().__init__(
            model,
            optimizer,
            criterion,
            train_mb_size=train_mb_size,
            train_epochs=train_epochs,
            eval_mb_size=eval_mb_size,
            device=device,
            plugins=plugins,
            evaluator=evaluator,
            eval_every=eval_every,
            **base_kwargs
        )
791

792

793
class GEM(SupervisedTemplate):
    """Gradient Episodic Memory (GEM) strategy.

    See GEM plugin for details.
    This strategy does not use task identities.
    """

    def __init__(
        self,
        model: Module,
        optimizer: Optimizer,
        criterion,
        patterns_per_exp: int,
        memory_strength: float = 0.5,
        train_mb_size: int = 1,
        train_epochs: int = 1,
        eval_mb_size: Optional[int] = None,
        device: Union[str, torch.device] = "cpu",
        plugins: Optional[List[SupervisedPlugin]] = None,
        evaluator: Union[
            EvaluationPlugin,
            Callable[[], EvaluationPlugin]
        ] = default_evaluator,
        eval_every=-1,
        **base_kwargs
    ):
        """Init.

        :param model: The model.
        :param optimizer: The optimizer to use.
        :param criterion: The loss criterion to use.
        :param patterns_per_exp: number of patterns per experience in the memory
        :param memory_strength: offset to add to the projection direction
            in order to favour backward transfer (gamma in original paper).
        :param train_mb_size: The train minibatch size. Defaults to 1.
        :param train_epochs: The number of training epochs. Defaults to 1.
        :param eval_mb_size: The eval minibatch size. Defaults to None, which
            is forwarded unchanged to the base template.
        :param device: The device to use. Defaults to "cpu".
        :param plugins: Plugins to be added. Defaults to None. The given
            list is copied, never modified in place; the GEM plugin is
            appended to the copy.
        :param evaluator: (optional) instance of EvaluationPlugin for logging
            and metric computations.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop. -1 disables the evaluation. 0 means `eval` is called
            only at the end of the learning experience. Values >0 mean that
            `eval` is called every `eval_every` epochs and at the end of the
            learning experience.
        :param base_kwargs: any additional
            :class:`~avalanche.training.BaseTemplate` constructor arguments.
        """

        gem = GEMPlugin(patterns_per_exp, memory_strength)
        # Copy instead of appending to the caller's list: re-using the same
        # list for another strategy would otherwise accumulate plugins.
        plugins = [] if plugins is None else list(plugins)
        plugins.append(gem)

        super().__init__(
            model,
            optimizer,
            criterion,
            train_mb_size=train_mb_size,
            train_epochs=train_epochs,
            eval_mb_size=eval_mb_size,
            device=device,
            plugins=plugins,
            evaluator=evaluator,
            eval_every=eval_every,
            **base_kwargs
        )
862

863

864
class EWC(SupervisedTemplate):
    """Elastic Weight Consolidation (EWC) strategy.

    See EWC plugin for details.
    This strategy does not use task identities.
    """

    def __init__(
        self,
        model: Module,
        optimizer: Optimizer,
        criterion,
        ewc_lambda: float,
        mode: str = "separate",
        decay_factor: Optional[float] = None,
        keep_importance_data: bool = False,
        train_mb_size: int = 1,
        train_epochs: int = 1,
        eval_mb_size: Optional[int] = None,
        device: Union[str, torch.device] = "cpu",
        plugins: Optional[List[SupervisedPlugin]] = None,
        evaluator: Union[
            EvaluationPlugin,
            Callable[[], EvaluationPlugin]
        ] = default_evaluator,
        eval_every=-1,
        **base_kwargs
    ):
        """Init.

        :param model: The model.
        :param optimizer: The optimizer to use.
        :param criterion: The loss criterion to use.
        :param ewc_lambda: hyperparameter to weigh the penalty inside the total
               loss. The larger the lambda, the larger the regularization.
        :param mode: `separate` to keep a separate penalty for each previous
               experience. `onlinesum` to keep a single penalty summed over all
               previous tasks. `onlineweightedsum` to keep a single penalty
               summed with a decay factor over all previous tasks.
        :param decay_factor: used only if mode is `onlineweightedsum`.
               It specifies the decay term of the importance matrix.
        :param keep_importance_data: if True, keep in memory both parameter
                values and importances for all previous tasks, for all modes.
                If False, keep only last parameter values and importances.
                If mode is `separate`, the value of `keep_importance_data` is
                set to be True.
        :param train_mb_size: The train minibatch size. Defaults to 1.
        :param train_epochs: The number of training epochs. Defaults to 1.
        :param eval_mb_size: The eval minibatch size. Defaults to None, which
            is forwarded unchanged to the base template.
        :param device: The device to use. Defaults to "cpu".
        :param plugins: Plugins to be added. Defaults to None. The given
            list is copied, never modified in place; the EWC plugin is
            appended to the copy.
        :param evaluator: (optional) instance of EvaluationPlugin for logging
            and metric computations.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop. -1 disables the evaluation. 0 means `eval` is called
            only at the end of the learning experience. Values >0 mean that
            `eval` is called every `eval_every` epochs and at the end of the
            learning experience.
        :param base_kwargs: any additional
            :class:`~avalanche.training.BaseTemplate` constructor arguments.
        """
        ewc = EWCPlugin(ewc_lambda, mode, decay_factor, keep_importance_data)
        # Copy instead of appending to the caller's list: re-using the same
        # list for another strategy would otherwise accumulate plugins.
        plugins = [] if plugins is None else list(plugins)
        plugins.append(ewc)

        super().__init__(
            model,
            optimizer,
            criterion,
            train_mb_size=train_mb_size,
            train_epochs=train_epochs,
            eval_mb_size=eval_mb_size,
            device=device,
            plugins=plugins,
            evaluator=evaluator,
            eval_every=eval_every,
            **base_kwargs
        )
944

945

946
class SynapticIntelligence(SupervisedTemplate):
    """Synaptic Intelligence strategy.

    This is the Synaptic Intelligence PyTorch implementation of the
    algorithm described in the paper
    "Continuous Learning in Single-Incremental-Task Scenarios"
    (https://arxiv.org/abs/1806.08568)

    The original implementation has been proposed in the paper
    "Continual Learning Through Synaptic Intelligence"
    (https://arxiv.org/abs/1703.04200).

    The Synaptic Intelligence regularization can also be used in a different
    strategy by applying the :class:`SynapticIntelligencePlugin` plugin.
    """

    def __init__(
        self,
        model: Module,
        optimizer: Optimizer,
        criterion,
        si_lambda: Union[float, Sequence[float]],
        eps: float = 0.0000001,
        train_mb_size: int = 1,
        train_epochs: int = 1,
        eval_mb_size: int = 1,
        device: Union[str, torch.device] = "cpu",
        plugins: Optional[Sequence["SupervisedPlugin"]] = None,
        evaluator: Union[
            EvaluationPlugin,
            Callable[[], EvaluationPlugin]
        ] = default_evaluator,
        eval_every=-1,
        **base_kwargs
    ):
        """Init.

        Creates an instance of the Synaptic Intelligence strategy.

        :param model: PyTorch model.
        :param optimizer: PyTorch optimizer.
        :param criterion: loss function.
        :param si_lambda: Synaptic Intelligence lambda term.
            If list, one lambda for each experience. If the list has less
            elements than the number of experiences, last lambda will be
            used for the remaining experiences.
        :param eps: Synaptic Intelligence damping parameter.
        :param train_mb_size: mini-batch size for training.
        :param train_epochs: number of training epochs.
        :param eval_mb_size: mini-batch size for eval.
        :param device: PyTorch device to run the model.
        :param plugins: (optional) sequence of StrategyPlugins. It is copied
            into a new list, never modified in place.
        :param evaluator: (optional) instance of EvaluationPlugin for logging
            and metric computations.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop. -1 disables the evaluation. 0 means `eval` is called
            only at the end of the learning experience. Values >0 mean that
            `eval` is called every `eval_every` epochs and at the end of the
            learning experience.
        :param base_kwargs: any additional
            :class:`~avalanche.training.BaseTemplate` constructor arguments.
        """
        # Copy so the caller's sequence is left untouched (and tuples work).
        plugins = [] if plugins is None else list(plugins)

        # This implementation relies on the S.I. Plugin, which contains the
        # entire implementation of the strategy!
        plugins.append(SynapticIntelligencePlugin(si_lambda=si_lambda, eps=eps))

        # Keyword arguments (as in the sibling strategies) so the call does
        # not depend on the positional order of the base constructor.
        super().__init__(
            model,
            optimizer,
            criterion,
            train_mb_size=train_mb_size,
            train_epochs=train_epochs,
            eval_mb_size=eval_mb_size,
            device=device,
            plugins=plugins,
            evaluator=evaluator,
            eval_every=eval_every,
            **base_kwargs
        )
1030

1031

1032
class CoPE(SupervisedTemplate):
    """Continual Prototype Evolution strategy.

    See CoPEPlugin for more details.
    This strategy does not use task identities during training.
    """

    def __init__(
        self,
        model: Module,
        optimizer: Optimizer,
        criterion,
        mem_size: int = 200,
        n_classes: int = 10,
        p_size: int = 100,
        alpha: float = 0.99,
        T: float = 0.1,
        train_mb_size: int = 1,
        train_epochs: int = 1,
        eval_mb_size: Optional[int] = None,
        device: Union[str, torch.device] = "cpu",
        plugins: Optional[List[SupervisedPlugin]] = None,
        evaluator: Union[
            EvaluationPlugin,
            Callable[[], EvaluationPlugin]
        ] = default_evaluator,
        eval_every=-1,
        **base_kwargs
    ):
        """Init.

        :param model: The model.
        :param optimizer: The optimizer to use.
        :param criterion: Loss criterion to use. The default is overwritten
            by PPPloss (see CoPEPlugin).
        :param mem_size: replay buffer size.
        :param n_classes: total number of classes that will be encountered. This
            is used to output predictions for all classes, with zero probability
            for unseen classes.
        :param p_size: The prototype size, which equals the feature size of the
            last layer.
        :param alpha: The momentum for the exponentially moving average of the
            prototypes.
        :param T: The softmax temperature, used as a concentration parameter.
        :param train_mb_size: The train minibatch size. Defaults to 1.
        :param train_epochs: The number of training epochs. Defaults to 1.
        :param eval_mb_size: The eval minibatch size. Defaults to None, which
            is forwarded unchanged to the base template.
        :param device: The device to use. Defaults to "cpu".
        :param plugins: Plugins to be added. Defaults to None. The given
            list is copied, never modified in place; the CoPE plugin is
            appended to the copy.
        :param evaluator: (optional) instance of EvaluationPlugin for logging
            and metric computations.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop. -1 disables the evaluation. 0 means `eval` is called
            only at the end of the learning experience. Values >0 mean that
            `eval` is called every `eval_every` epochs and at the end of the
            learning experience.
        :param base_kwargs: any additional
            :class:`~avalanche.training.BaseTemplate` constructor arguments.
        """
        copep = CoPEPlugin(mem_size, n_classes, p_size, alpha, T)
        # Copy instead of appending to the caller's list: re-using the same
        # list for another strategy would otherwise accumulate plugins.
        plugins = [] if plugins is None else list(plugins)
        plugins.append(copep)

        super().__init__(
            model,
            optimizer,
            criterion,
            train_mb_size=train_mb_size,
            train_epochs=train_epochs,
            eval_mb_size=eval_mb_size,
            device=device,
            plugins=plugins,
            evaluator=evaluator,
            eval_every=eval_every,
            **base_kwargs
        )
1109

1110

1111
class LFL(SupervisedTemplate):
    """Less Forgetful Learning strategy.

    See LFL plugin for details.
    Reference paper: https://arxiv.org/pdf/1607.00122.pdf
    This strategy does not use task identities.
    """

    def __init__(
        self,
        model: Module,
        optimizer: Optimizer,
        criterion,
        lambda_e: Union[float, Sequence[float]],
        train_mb_size: int = 1,
        train_epochs: int = 1,
        eval_mb_size: Optional[int] = None,
        device: Union[str, torch.device] = "cpu",
        plugins: Optional[List[SupervisedPlugin]] = None,
        evaluator: Union[
            EvaluationPlugin,
            Callable[[], EvaluationPlugin]
        ] = default_evaluator,
        eval_every=-1,
        **base_kwargs
    ):
        """Init.

        :param model: The model.
        :param optimizer: The optimizer to use.
        :param criterion: The loss criterion to use.
        :param lambda_e: euclidean loss hyper parameter. It can be either a
                float number or a list containing lambda_e for each experience.
        :param train_mb_size: The train minibatch size. Defaults to 1.
        :param train_epochs: The number of training epochs. Defaults to 1.
        :param eval_mb_size: The eval minibatch size. Defaults to None, which
            is forwarded unchanged to the base template.
        :param device: The device to use. Defaults to "cpu".
        :param plugins: Plugins to be added. Defaults to None. The given
            list is copied, never modified in place; the LFL plugin is
            appended to the copy.
        :param evaluator: (optional) instance of EvaluationPlugin for logging
            and metric computations.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop. -1 disables the evaluation. 0 means `eval` is called
            only at the end of the learning experience. Values >0 mean that
            `eval` is called every `eval_every` epochs and at the end of the
            learning experience.
        :param base_kwargs: any additional
            :class:`~avalanche.training.BaseTemplate` constructor arguments.
        """

        lfl = LFLPlugin(lambda_e)
        # Copy instead of appending to the caller's list: re-using the same
        # list for another strategy would otherwise accumulate plugins.
        plugins = [] if plugins is None else list(plugins)
        plugins.append(lfl)

        super().__init__(
            model,
            optimizer,
            criterion,
            train_mb_size=train_mb_size,
            train_epochs=train_epochs,
            eval_mb_size=eval_mb_size,
            device=device,
            plugins=plugins,
            evaluator=evaluator,
            eval_every=eval_every,
            **base_kwargs
        )
1178

1179

1180
class MAS(SupervisedTemplate):
    """Memory Aware Synapses (MAS) strategy.

    See MAS plugin for details.
    This strategy does not use task identities.
    """

    def __init__(
        self,
        model: Module,
        optimizer: Optimizer,
        criterion,
        lambda_reg: float = 1.0,
        alpha: float = 0.5,
        verbose: bool = False,
        train_mb_size: int = 1,
        train_epochs: int = 1,
        eval_mb_size: int = 1,
        device: Union[str, torch.device] = "cpu",
        plugins: Optional[List[SupervisedPlugin]] = None,
        evaluator: Union[
            EvaluationPlugin,
            Callable[[], EvaluationPlugin]
        ] = default_evaluator,
        eval_every=-1,
        **base_kwargs
    ):
        """Init.

        :param model: The model.
        :param optimizer: The optimizer to use.
        :param criterion: The loss criterion to use.
        :param lambda_reg: hyperparameter weighting the penalty term
               in the overall loss.
        :param alpha: hyperparameter that specifies the weight given
               to the influence of the previous experience.
        :param verbose: when True, the computation of the influence of
               each parameter shows a progress bar.
        :param train_mb_size: The train minibatch size. Defaults to 1.
        :param train_epochs: The number of training epochs. Defaults to 1.
        :param eval_mb_size: The eval minibatch size. Defaults to 1.
        :param device: The device to use. Defaults to "cpu".
        :param plugins: Plugins to be added. Defaults to None. The given
            list is copied, never modified in place; the MAS plugin is
            appended to the copy.
        :param evaluator: (optional) instance of EvaluationPlugin for logging
            and metric computations.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop. -1 disables the evaluation. 0 means `eval` is called
            only at the end of the learning experience. Values >0 mean that
            `eval` is called every `eval_every` epochs and at the end of the
            learning experience.
        :param base_kwargs: any additional
            :class:`~avalanche.training.BaseTemplate` constructor arguments.
        """

        # Instantiate plugin
        mas = MASPlugin(lambda_reg=lambda_reg, alpha=alpha, verbose=verbose)

        # Add plugin to the strategy. Copy instead of appending to the
        # caller's list: re-using the same list for another strategy would
        # otherwise accumulate plugins.
        plugins = [] if plugins is None else list(plugins)
        plugins.append(mas)

        super().__init__(
            model,
            optimizer,
            criterion,
            train_mb_size=train_mb_size,
            train_epochs=train_epochs,
            eval_mb_size=eval_mb_size,
            device=device,
            plugins=plugins,
            evaluator=evaluator,
            eval_every=eval_every,
            **base_kwargs
        )
1256

1257

1258
class BiC(SupervisedTemplate):
    """Bias Correction (BiC) strategy.

    See BiC plugin for details.
    This strategy does not use task identities.
    """

    def __init__(
        self,
        model: Module,
        optimizer: Optimizer,
        criterion,
        mem_size: int = 200,
        val_percentage: float = 0.1,
        T: int = 2,
        stage_2_epochs: int = 200,
        lamb: float = -1,
        lr: float = 0.1,
        train_mb_size: int = 1,
        train_epochs: int = 1,
        eval_mb_size: Optional[int] = None,
        device: Union[str, torch.device] = "cpu",
        plugins: Optional[List[SupervisedPlugin]] = None,
        evaluator: Union[
            EvaluationPlugin,
            Callable[[], EvaluationPlugin]
        ] = default_evaluator,
        eval_every=-1,
        **base_kwargs
    ):
        """Init.

        :param model: The model.
        :param optimizer: The optimizer to use.
        :param criterion: The loss criterion to use.
        :param mem_size: replay buffer size.
        :param val_percentage: hyperparameter used to set the
                percentage of exemplars in the val set.
        :param T: hyperparameter used to set the temperature
                used in stage 1.
        :param stage_2_epochs: hyperparameter used to set the
                amount of epochs of stage 2.
        :param lamb: hyperparameter used to balance the distilling
                loss and the classification loss.
        :param lr: hyperparameter used as a learning rate for
                the second phase of training.
        :param train_mb_size: The train minibatch size. Defaults to 1.
        :param train_epochs: The number of training epochs. Defaults to 1.
        :param eval_mb_size: The eval minibatch size. Defaults to None, which
            is forwarded unchanged to the base template.
        :param device: The device to use. Defaults to "cpu".
        :param plugins: Plugins to be added. Defaults to None. The given
            list is copied, never modified in place; the BiC plugin is
            appended to the copy.
        :param evaluator: (optional) instance of EvaluationPlugin for logging
            and metric computations.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop. -1 disables the evaluation. 0 means `eval` is called
            only at the end of the learning experience. Values >0 mean that
            `eval` is called every `eval_every` epochs and at the end of the
            learning experience.
        :param base_kwargs: any additional
            :class:`~avalanche.training.BaseTemplate` constructor arguments.
        """

        # Instantiate plugin
        bic = BiCPlugin(
            mem_size=mem_size,
            val_percentage=val_percentage,
            T=T,
            stage_2_epochs=stage_2_epochs,
            lamb=lamb,
            lr=lr,
        )

        # Add plugin to the strategy. Copy instead of appending to the
        # caller's list: re-using the same list for another strategy would
        # otherwise accumulate plugins.
        plugins = [] if plugins is None else list(plugins)
        plugins.append(bic)

        super().__init__(
            model,
            optimizer,
            criterion,
            train_mb_size=train_mb_size,
            train_epochs=train_epochs,
            eval_mb_size=eval_mb_size,
            device=device,
            plugins=plugins,
            evaluator=evaluator,
            eval_every=eval_every,
            **base_kwargs
        )
1348

1349

1350
class MIR(SupervisedTemplate):
    """Maximally Interfered Replay Strategy.

    See ER_MIR plugin for details.
    """

    def __init__(
        self,
        model: Module,
        optimizer: Optimizer,
        criterion,
        mem_size: int,
        subsample: int,
        batch_size_mem: int = 1,
        train_mb_size: int = 1,
        train_epochs: int = 1,
        eval_mb_size: int = 1,
        device: Union[str, torch.device] = "cpu",
        plugins: Optional[List[SupervisedPlugin]] = None,
        evaluator: Union[
            EvaluationPlugin,
            Callable[[], EvaluationPlugin]
        ] = default_evaluator,
        eval_every=-1,
        **base_kwargs
    ):
        """Init.

        :param model: The model.
        :param optimizer: The optimizer to use.
        :param criterion: The loss criterion to use.
        :param mem_size: Amount of fixed memory to use
        :param subsample: Size of the initial sample
                from which to select the replay batch
        :param batch_size_mem: Size of the replay batch after
                loss-based selection
        :param train_mb_size: The train minibatch size. Defaults to 1.
        :param train_epochs: The number of training epochs. Defaults to 1.
        :param eval_mb_size: The eval minibatch size. Defaults to 1.
        :param device: The device to use. Defaults to "cpu".
        :param plugins: Plugins to be added. Defaults to None. The given
            list is copied, never modified in place; the MIR plugin is
            appended to the copy.
        :param evaluator: (optional) instance of EvaluationPlugin for logging
            and metric computations.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop. -1 disables the evaluation. 0 means `eval` is called
            only at the end of the learning experience. Values >0 mean that
            `eval` is called every `eval_every` epochs and at the end of the
            learning experience.
        :param base_kwargs: any additional
            :class:`~avalanche.training.BaseTemplate` constructor arguments.
        """

        # Instantiate plugin
        mir = MIRPlugin(
            mem_size=mem_size,
            subsample=subsample,
            batch_size_mem=batch_size_mem
        )

        # Add plugin to the strategy. Copy instead of appending to the
        # caller's list: re-using the same list for another strategy would
        # otherwise accumulate plugins.
        plugins = [] if plugins is None else list(plugins)
        plugins.append(mir)

        super().__init__(
            model,
            optimizer,
            criterion,
            train_mb_size=train_mb_size,
            train_epochs=train_epochs,
            eval_mb_size=eval_mb_size,
            device=device,
            plugins=plugins,
            evaluator=evaluator,
            eval_every=eval_every,
            **base_kwargs
        )
1426

1427

1428
class FromScratchTraining(SupervisedTemplate):
    """From scratch training strategy.

    This strategy trains a model on a stream of experiences, but resets the
    model's weight initialization and optimizer state after each experience.
    It is usually used as a baseline for comparison with the Naive strategy
    where the model is fine-tuned to every new experience. See
    FromScratchTrainingPlugin for more details.
    """

    def __init__(
        self,
        model: Module,
        optimizer: Optimizer,
        criterion,
        reset_optimizer: bool = True,
        train_mb_size: int = 1,
        train_epochs: int = 1,
        eval_mb_size: Optional[int] = None,
        device: Union[str, torch.device] = "cpu",
        plugins: Optional[List[SupervisedPlugin]] = None,
        evaluator: Union[
            EvaluationPlugin,
            Callable[[], EvaluationPlugin]
        ] = default_evaluator,
        eval_every=-1,
        **base_kwargs
    ):
        """Init.

        :param model: The model.
        :param optimizer: The optimizer to use.
        :param criterion: The loss criterion to use.
        :param reset_optimizer: If True, optimizer state will be reset after
            each experience.
        :param train_mb_size: The train minibatch size. Defaults to 1.
        :param train_epochs: The number of training epochs. Defaults to 1.
        :param eval_mb_size: The eval minibatch size. Defaults to None, which
            is forwarded unchanged to the base template.
        :param device: The device to use. Defaults to "cpu".
        :param plugins: Plugins to be added. Defaults to None. The given
            list is copied, never modified in place; the
            FromScratchTrainingPlugin is appended to the copy.
        :param evaluator: (optional) instance of EvaluationPlugin for logging
            and metric computations.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop. -1 disables the evaluation. 0 means `eval` is called
            only at the end of the learning experience. Values >0 mean that
            `eval` is called every `eval_every` epochs and at the end of the
            learning experience.
        :param base_kwargs: any additional
            :class:`~avalanche.training.BaseTemplate` constructor arguments.
        """

        fstp = FromScratchTrainingPlugin(reset_optimizer=reset_optimizer)
        # Copy instead of appending to the caller's list: re-using the same
        # list for another strategy would otherwise accumulate plugins.
        plugins = [] if plugins is None else list(plugins)
        plugins.append(fstp)

        super().__init__(
            model,
            optimizer,
            criterion,
            train_mb_size=train_mb_size,
            train_epochs=train_epochs,
            eval_mb_size=eval_mb_size,
            device=device,
            plugins=plugins,
            evaluator=evaluator,
            eval_every=eval_every,
            **base_kwargs
        )
1496

1497

1498
# Names exported by ``from <this module> import *``.
__all__ = [
    "Naive",
    "PNNStrategy",
    "CWRStar",
    "Replay",
    "GenerativeReplay",
    "VAETraining",
    "GDumb",
    "LwF",
    "AGEM",
    "GEM",
    "EWC",
    "SynapticIntelligence",
    "GSS_greedy",
    "CoPE",
    "LFL",
    "MAS",
    "BiC",
    "MIR",
    "FromScratchTraining",
]
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc