ContinualAI / avalanche / 8317913660

17 Mar 2024 07:59PM UTC coverage: 51.761% (-0.05%) from 51.806%

Pull Request #1596: Generative Replay with weighted loss for replayed data
Merge 6082a3cc7 into dbdc3804b (github / web-flow)

3 of 29 new or added lines in 1 file covered. (10.34%)

1 existing line in 1 file now uncovered.

14757 of 28510 relevant lines covered (51.76%)

0.52 hits per line
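This pull request adds the weighted-replay options (is_weighted_replay, weight_replay_loss, weight_replay_loss_factor) exercised by the new before_backward hook in the file below. As rough orientation, a minimal sketch of enabling the new mode could look as follows; the import path and the pre-built generator_strategy are assumptions, and only the plugin class and its parameters come from the covered file.

from avalanche.training.plugins import GenerativeReplayPlugin  # assumed export path

# generator_strategy: assumed to be an Avalanche strategy whose model implements
# a generate() method (see avalanche.models.generator.Generator).
replay_plugin = GenerativeReplayPlugin(
    generator_strategy=generator_strategy,
    is_weighted_replay=True,        # add a weighted replay loss in before_backward
                                    # instead of concatenating replay data to the batch
    weight_replay_loss=0.0001,      # initial weight of the replay loss
    weight_replay_loss_factor=1.0,  # multiplier applied to the weight each iteration
)
# The plugin is then passed to the solver strategy, e.g. through its plugins argument.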

Source File

/avalanche/training/plugins/generative_replay.py (file coverage: 20.73%)
################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 05-03-2022                                                             #
# Author: Florian Mies                                                         #
# Website: https://github.com/travela                                          #
################################################################################

"""

All plugins related to Generative Replay.

"""

from copy import deepcopy
from typing import Optional, Any
from avalanche.core import SupervisedPlugin, Template
import torch


class GenerativeReplayPlugin(SupervisedPlugin):
    """
    Experience generative replay plugin.

    Updates the current mbatch of a strategy before training an experience
    by sampling a generator model and concatenating the replay data to the
    current batch.

    In this version of the plugin the number of replay samples is
    increased with each new experience. Another way to implement
    the algorithm is to weight the loss function, giving more
    importance to the replayed data as the number of experiences
    increases. This behaviour is available through the is_weighted_replay option.

    :param generator_strategy: In case the plugin is applied to a non-generative
     model (e.g. a simple classifier), this should contain an Avalanche strategy
     for a model that implements a 'generate' method
     (see avalanche.models.generator.Generator). Defaults to None.
    :param untrained_solver: if True we assume this is the beginning of
        a continual learning task and add replay data only from the second
        experience onwards, otherwise we sample and add generative replay data
        before training the first experience. Defaults to True.
    :param replay_size: The number of replay samples added to each data batch.
        By default, the number of replay samples matches the size of the
        data batch.
    :param increasing_replay_size: If set to True, the amount of replay data
        added to each data batch grows with every new experience (the data
        batch size times the current experience index), so that older
        experiences gradually gain importance in the final loss.
    :param is_weighted_replay: If set to True, the loss function will be weighted
        and more importance will be given to the replay data as the number of
        experiences increases.
    :param weight_replay_loss_factor: If is_weighted_replay is set to True, the
        factor by which the replay loss weight is multiplied in each iteration.
        Defaults to 1.0.
    :param weight_replay_loss: The initial weight of the loss on the replay data.
        Defaults to 0.0001.
    """

    def __init__(
        self,
        generator_strategy=None,
        untrained_solver: bool = True,
        replay_size: Optional[int] = None,
        increasing_replay_size: bool = False,
        is_weighted_replay: bool = False,
        weight_replay_loss_factor: float = 1.0,
        weight_replay_loss: float = 0.0001,
    ):
        """
        Init.
        """
        super().__init__()
        self.generator_strategy = generator_strategy
        if self.generator_strategy:
            self.generator = generator_strategy.model
        else:
            self.generator = None
        self.untrained_solver = untrained_solver
        self.model_is_generator = False
        self.replay_size = replay_size
        self.increasing_replay_size = increasing_replay_size
        self.is_weighted_replay = is_weighted_replay
        self.weight_replay_loss_factor = weight_replay_loss_factor
        self.weight_replay_loss = weight_replay_loss

    def before_training(self, strategy, *args, **kwargs):
        """Checks whether we are using a user-defined external generator
        or the strategy's model as the generator.
        If the generator is None after initialization
        we assume that strategy.model is the generator.
        (e.g. this would be the case when training a VAE with
        generative replay)"""
        if not self.generator_strategy:
            self.generator_strategy = strategy
            self.generator = strategy.model
            self.model_is_generator = True

    def before_training_exp(
        self, strategy, num_workers: int = 0, shuffle: bool = True, **kwargs
    ):
        """
        Make deep copies of generator and solver before training a new experience.
        """
        if self.untrained_solver:
            # The solver needs to be trained before labelling generated data and
            # the generator needs to be trained before we can sample.
            return
        self.old_generator = deepcopy(self.generator)
        self.old_generator.eval()
        if not self.model_is_generator:
            self.old_model = deepcopy(strategy.model)
            self.old_model.eval()

    def after_training_exp(
        self, strategy, num_workers: int = 0, shuffle: bool = True, **kwargs
    ):
        """
        Set the untrained_solver boolean to False after (the first) experience,
        in order to start training with replay data from the second experience.
        """
        self.untrained_solver = False

    def before_backward(self, strategy: Template, *args, **kwargs) -> Any:
        """
        Generate replay data and compute the loss on it.
        If is_weighted_replay is enabled, add the weighted replay loss to the total loss.
        """
        super().before_backward(strategy, *args, **kwargs)
        if not self.is_weighted_replay:
            # If we are not using weighted loss, ignore this method
            return

        if self.untrained_solver:
            # do not generate on the first experience
            return

        # determine how many replay data points to generate
        if self.replay_size:
            number_replays_to_generate = self.replay_size
        else:
            if self.increasing_replay_size:
                number_replays_to_generate = len(strategy.mbatch[0]) * (
                    strategy.experience.current_experience
                )
            else:
                number_replays_to_generate = len(strategy.mbatch[0])
        replay_data = self.old_generator.generate(number_replays_to_generate).to(
            strategy.device
        )
        # get labels for replay data
        if not self.model_is_generator:
            with torch.no_grad():
                replay_output = self.old_model(replay_data).argmax(dim=-1)
        else:
            # Mock labels:
            replay_output = torch.zeros(replay_data.shape[0])

        # make copy of mbatch
        mbatch = deepcopy(strategy.mbatch)
        # replace mbatch with replay data, calculate loss and add to strategy.loss
        strategy.mbatch = [replay_data, replay_output, strategy.mbatch[-1]]
        strategy.forward()
        strategy.loss += self.weight_replay_loss * strategy.criterion()
        self.weight_replay_loss *= self.weight_replay_loss_factor
        # restore mbatch
        strategy.mbatch = mbatch
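
    # Note on the schedule above (editor's sketch, not part of the covered file):
    # before_backward multiplies weight_replay_loss by weight_replay_loss_factor once
    # per training iteration, so with initial weight w0 and factor f the replay term
    # contributes w0 * (f ** t) * replay_loss after t iterations. The default factor
    # of 1.0 keeps the weight constant, while a factor greater than 1.0 makes the
    # replayed data progressively more important as training proceeds.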

    def before_training_iteration(self, strategy, **kwargs):
        """
        Generate and append replay data to the current minibatch before
        each training iteration.
        """
        if self.is_weighted_replay:
            # When using weighted loss, do not add replay data to the current minibatch
            return
        if self.untrained_solver:
            # The solver needs to be trained before labelling generated data and
            # the generator needs to be trained before we can sample.
            return
        # determine how many replay data points to generate
        if self.replay_size:
            number_replays_to_generate = self.replay_size
        else:
            if self.increasing_replay_size:
                number_replays_to_generate = len(strategy.mbatch[0]) * (
                    strategy.experience.current_experience
                )
            else:
                number_replays_to_generate = len(strategy.mbatch[0])
        # extend X with replay data
        replay = self.old_generator.generate(number_replays_to_generate).to(
            strategy.device
        )
        strategy.mbatch[0] = torch.cat([strategy.mbatch[0], replay], dim=0)
        # extend y with predicted labels (or mock labels if model==generator)
        if not self.model_is_generator:
            with torch.no_grad():
                replay_output = self.old_model(replay).argmax(dim=-1)
        else:
            # Mock labels:
            replay_output = torch.zeros(replay.shape[0])
        strategy.mbatch[1] = torch.cat(
            [strategy.mbatch[1], replay_output.to(strategy.device)], dim=0
        )
        # extend task id batch (we implicitly assume a task-free case)
        strategy.mbatch[-1] = torch.cat(
            [
                strategy.mbatch[-1],
                torch.ones(replay.shape[0]).to(strategy.device)
                * strategy.mbatch[-1][0],
            ],
            dim=0,
        )

class TrainGeneratorAfterExpPlugin(SupervisedPlugin):
    """
    TrainGeneratorAfterExpPlugin makes sure that after each experience of
    training the solver of a scholar model, we also train the generator on the
    data of the current experience.
    """

    def after_training_exp(self, strategy, **kwargs):
        """
        The training method expects an Experience object
        with a 'dataset' parameter.
        """
        for plugin in strategy.plugins:
            if type(plugin) is GenerativeReplayPlugin:
                plugin.generator_strategy.train(strategy.experience)


__all__ = ["GenerativeReplayPlugin", "TrainGeneratorAfterExpPlugin"]
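For context, the two plugins above are meant to be used together on the solver side of a scholar model: GenerativeReplayPlugin injects replay data (or a weighted replay loss), and TrainGeneratorAfterExpPlugin retrains the generator strategy on each finished experience. A rough wiring sketch follows; the solver strategy (Naive), its constructor arguments, and the benchmark object are assumptions about the surrounding Avalanche API, while only the two plugin classes come from this file.

from avalanche.training.plugins import (   # assumed export path for the two plugins
    GenerativeReplayPlugin,
    TrainGeneratorAfterExpPlugin,
)
from avalanche.training import Naive       # assumption: any standard solver strategy works

# Assumed to exist already: a solver `model`, `optimizer`, `criterion`, a `benchmark`
# with a train_stream, and a `generator_strategy` whose model implements generate().
replay_plugin = GenerativeReplayPlugin(generator_strategy=generator_strategy)

solver_strategy = Naive(
    model,
    optimizer,
    criterion,
    plugins=[replay_plugin, TrainGeneratorAfterExpPlugin()],
)

for experience in benchmark.train_stream:
    solver_strategy.train(experience)
    # After each experience, TrainGeneratorAfterExpPlugin finds the replay plugin in
    # solver_strategy.plugins and calls its generator_strategy.train(experience),
    # so the generator keeps up with the data seen by the solver.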