ContinualAI / avalanche · build 4993189103 (GitHub, pending completion)

Pull Request #1370: Add base elements to support distributed comms. Add supports_distributed plugin flag.

258 of 822 new or added lines in 27 files covered (31.39%).
80 existing lines in 5 files are now uncovered.
15585 of 21651 relevant lines covered (71.98%), at 2.88 hits per line.

Source file: /avalanche/training/templates/base.py (88.74% covered)

import sys
import warnings
from collections import defaultdict
from typing import Generic, Iterable, Sequence, Optional, TypeVar, Union, List

import torch
from torch.nn import Module

from avalanche.benchmarks import CLExperience, CLStream
from avalanche.core import BasePlugin
from avalanche.distributed.distributed_helper import DistributedHelper
from avalanche.training.templates.strategy_mixin_protocol import \
    BaseStrategyProtocol
from avalanche.training.utils import trigger_plugins


TExperienceType = TypeVar('TExperienceType', bound=CLExperience)
TPluginType = TypeVar('TPluginType', bound=BasePlugin, contravariant=True)


class BaseTemplate(BaseStrategyProtocol[TExperienceType]):
    """Base class for continual learning skeletons.

    **Training loop**
    The training loop is organized as follows::

        train
            train_exp  # for each experience

    **Evaluation loop**
    The evaluation loop is organized as follows::

        eval
            eval_exp  # for each experience

    """

    # we need this only for type checking
    PLUGIN_CLASS = BasePlugin

    def __init__(
        self,
        model: Module,
        device: Union[str, torch.device] = "cpu",
        plugins: Optional[Sequence[BasePlugin]] = None,
    ):
        """Init."""
        super().__init__()

        self.model: Module = model
        """ PyTorch model. """

        if device is None:
            warnings.warn(
                'When instantiating a strategy, please pass a non-None device.'
            )
            device = 'cpu'

        self.device = torch.device(device)
        """ PyTorch device where the model will be allocated. """

        self.plugins: List[BasePlugin] = [] \
            if plugins is None else list(plugins)
        """ List of `SupervisedPlugin`s. """

        # check plugin compatibility
        self._check_plugin_compatibility()

        ###################################################################
        # State variables. These are updated during the train/eval loops. #
        ###################################################################
        self.experience: Optional[TExperienceType] = None
        """ Current experience. """

        self.is_training: bool = False
        """ True if the strategy is in training mode. """

        self.current_eval_stream: Iterable[TExperienceType] = []
        """ Current evaluation stream. """

        self._distributed_check: bool = False
        """
        Internal flag used to verify the support for distributed
        training only once.
        """

        ###################################################################
        # Other variables #
        ###################################################################
        self._eval_streams: Optional[List[List[CLExperience]]] = None

    @property
    def is_eval(self):
        """True if the strategy is in evaluation mode."""
        return not self.is_training

    def train(
        self,
        experiences: Union[TExperienceType, Iterable[TExperienceType]],
        eval_streams: Optional[
            Sequence[Union[TExperienceType, Iterable[TExperienceType]]]
        ] = None,
        **kwargs,
    ):
        """Training loop.

        If `experiences` is a single experience, trains on it.
        If it is a sequence, trains the model on each experience in order.
        This is different from joint training on the entire stream.
        It returns a dictionary with the last recorded value for each metric.

        :param experiences: single Experience or sequence.
        :param eval_streams: sequence of streams for evaluation.
            If None: use training experiences for evaluation.
            Use [] if you do not want to evaluate during training.
            Experiences in `eval_streams` are grouped by stream name
            when calling `eval`. If you use multiple streams, they must
            have different names.
        """
        if not self._distributed_check:
            # Checks if the strategy elements are compatible with
            # distributed training
            self._check_distributed_training_compatibility()
            self._distributed_check = True

        self.is_training = True
        self._stop_training = False

        self.model.train()
        self.model.to(self.device)

        # Normalize training and eval data.
        experiences_list: Iterable[TExperienceType] = \
            _experiences_parameter_as_iterable(experiences)

        if eval_streams is None:
            eval_streams = [experiences_list]

        self._eval_streams = _group_experiences_by_stream(eval_streams)

        self._before_training(**kwargs)

        for self.experience in experiences_list:
            self._before_training_exp(**kwargs)
            self._train_exp(self.experience, eval_streams, **kwargs)
            self._after_training_exp(**kwargs)
        self._after_training(**kwargs)
        self._train_cleanup()

    def _train_cleanup(self):
        # reset _eval_streams for faster serialization
        self._eval_streams = None
        self.experience = None

    def _train_exp(self, experience: CLExperience, eval_streams, **kwargs):
        raise NotImplementedError()

    @torch.no_grad()
    def eval(
        self,
        experiences: Union[TExperienceType, CLStream[TExperienceType]],
        **kwargs,
    ):
        """
        Evaluates the current model on a series of experiences and
        returns the last recorded value for each metric.

        :param experiences: CL experience information.
        :param kwargs: custom arguments.

        :return: dictionary containing the last recorded value for
            each metric name.
        """
        if not self._distributed_check:
            # Checks if the strategy elements are compatible with
            # distributed training
            self._check_distributed_training_compatibility()
            self._distributed_check = True

        # eval can be called inside the train method.
        # Save the shared state here to restore before returning.
        prev_train_state = self._save_train_state()
        self.is_training = False
        self.model.eval()

        experiences_list: Iterable[TExperienceType] = \
            _experiences_parameter_as_iterable(experiences)
        self.current_eval_stream = experiences_list

        self._before_eval(**kwargs)
        for self.experience in experiences_list:
            self._before_eval_exp(**kwargs)
            self._eval_exp(**kwargs)
            self._after_eval_exp(**kwargs)

        self._after_eval(**kwargs)
        self._eval_cleanup()

        # restore previous shared state.
        self._load_train_state(prev_train_state)

    def _eval_cleanup(self):
        # reset for faster serialization
        self.current_eval_stream = []
        self.experience = None

    def _eval_exp(self, **kwargs):
        raise NotImplementedError()

    def _save_train_state(self):
        """Save the training state, which may be modified by the eval loop.

        TODO: we probably need a better way to do this.
        """
        # save each layer's training mode, to restore it later
        _prev_model_training_modes = {}
        for name, layer in self.model.named_modules():
            _prev_model_training_modes[name] = layer.training

        _prev_state = {
            "experience": self.experience,
            "is_training": self.is_training,
            "model_training_mode": _prev_model_training_modes,
        }
        return _prev_state

    def _load_train_state(self, prev_state):
        # restore train-state variables and training mode.
        self.experience = prev_state["experience"]
        self.is_training = prev_state["is_training"]

        # restore each layer's training mode to original
        prev_training_modes = prev_state["model_training_mode"]
        for name, layer in self.model.named_modules():
            try:
                prev_mode = prev_training_modes[name]
                layer.train(mode=prev_mode)
            except KeyError:
                # Unknown module, probably added during the eval
                # model's adaptation. We set it to train mode.
                layer.train()

    def _check_plugin_compatibility(self):
        """Check that the list of plugins is compatible with the template.

        This means checking that each plugin implements a subset of the
        supported callbacks.
        """
        # TODO: ideally we would like to check the argument's type to check
        #  that it's a supertype of the template.
        # I don't know if it's possible to do it in Python.
        ps = self.plugins

        def get_plugins_from_object(obj):
            def is_callback(x):
                return x.startswith("before") or x.startswith("after")

            return filter(is_callback, dir(obj))

        cb_supported = set(get_plugins_from_object(self.PLUGIN_CLASS))
        for p in ps:
            cb_p = set(get_plugins_from_object(p))

            if not cb_p.issubset(cb_supported):
                warnings.warn(
                    f"Plugin {p} implements incompatible callbacks for template"
                    f" {self}. This may result in errors. Incompatible "
                    f"callbacks: {cb_p - cb_supported}",
                )
                return

    def _check_distributed_training_compatibility(self):
        """
        Check if strategy elements (plugins, ...) are compatible with
        distributed training.
        This check does nothing if not training in distributed mode.
        """
        if not DistributedHelper.is_distributed:
            return True

        unsupported_plugins = []
        for plugin in self.plugins:
            if not getattr(plugin, "supports_distributed", False):
                unsupported_plugins.append(plugin)

        if len(unsupported_plugins) > 0:
            warnings.warn('You are using plugins that are not compatible '
                          'with distributed training:')
            for plugin in unsupported_plugins:
                print(type(plugin), file=sys.stderr)

        return len(unsupported_plugins) == 0

    #########################################################
    # Plugin Triggers                                       #
    #########################################################

    def _before_training_exp(self, **kwargs):
        trigger_plugins(self, "before_training_exp", **kwargs)

    def _after_training_exp(self, **kwargs):
        trigger_plugins(self, "after_training_exp", **kwargs)

    def _before_training(self, **kwargs):
        trigger_plugins(self, "before_training", **kwargs)

    def _after_training(self, **kwargs):
        trigger_plugins(self, "after_training", **kwargs)

    def _before_eval(self, **kwargs):
        trigger_plugins(self, "before_eval", **kwargs)

    def _after_eval(self, **kwargs):
        trigger_plugins(self, "after_eval", **kwargs)

    def _before_eval_exp(self, **kwargs):
        trigger_plugins(self, "before_eval_exp", **kwargs)

    def _after_eval_exp(self, **kwargs):
        trigger_plugins(self, "after_eval_exp", **kwargs)


def _group_experiences_by_stream(
    eval_streams: Iterable[Union[Iterable[CLExperience], CLExperience]]
) -> List[List[CLExperience]]:

    exps: List[CLExperience] = []
    # First, we unpack the list of experiences.
    for exp in eval_streams:
        if isinstance(exp, Iterable):
            exps.extend(exp)
        else:
            exps.append(exp)
    # Then, we group them by stream.
    exps_by_stream = defaultdict(list)
    for exp in exps:
        sname = exp.origin_stream.name
        exps_by_stream[sname].append(exp)
    # Finally, we return a list of lists.
    return list(exps_by_stream.values())


def _experiences_parameter_as_iterable(
    experiences: Union[Iterable[TExperienceType], TExperienceType]
) -> Iterable[TExperienceType]:
    if isinstance(experiences, Iterable):
        return experiences
    else:
        return [experiences]


__all__ = [
    'BaseTemplate'
]
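
The docstrings above describe the train/eval skeleton, but `_train_exp` and `_eval_exp` are left abstract. As a hedged sketch of how a concrete strategy plugs in, the subclass below is hypothetical (avalanche's real strategies live elsewhere in `avalanche.training`); the loop bodies are placeholders, and `benchmark`/`my_model` in the usage comment are assumed to exist:

from avalanche.training.templates.base import BaseTemplate


class NaiveLikeTemplate(BaseTemplate):
    """Hypothetical minimal strategy illustrating the BaseTemplate hooks."""

    def _train_exp(self, experience, eval_streams, **kwargs):
        # A real strategy would build a DataLoader from the experience
        # and run an optimization loop on self.model here.
        print(f"training on experience {experience.current_experience}")

    def _eval_exp(self, **kwargs):
        # A real strategy would compute metrics for self.experience here.
        print(f"evaluating experience {self.experience.current_experience}")


# Usage sketch, assuming `benchmark` is an avalanche benchmark and
# `my_model` a torch.nn.Module:
#
#   strategy = NaiveLikeTemplate(model=my_model, device="cpu")
#   for exp in benchmark.train_stream:
#       strategy.train(exp, eval_streams=[benchmark.test_stream])
#   strategy.eval(benchmark.test_stream)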
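
`train`'s docstring promises that experiences in `eval_streams` are grouped by stream name; `_group_experiences_by_stream` implements this by flattening the input and regrouping on `origin_stream.name`. A small runnable sketch, using `SimpleNamespace` stand-ins (an assumption for illustration, since the helper reads nothing but `origin_stream.name`):

from types import SimpleNamespace

from avalanche.training.templates.base import _group_experiences_by_stream

# Stand-in experiences exposing only origin_stream.name.
test_stream = [SimpleNamespace(origin_stream=SimpleNamespace(name="test"))
               for _ in range(2)]
valid_stream = [SimpleNamespace(origin_stream=SimpleNamespace(name="valid"))]

groups = _group_experiences_by_stream([test_stream, valid_stream])
assert [len(g) for g in groups] == [2, 1]  # one group per stream name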
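
`_check_plugin_compatibility` collects every attribute name starting with `before` or `after` from each plugin and warns when one is not part of `PLUGIN_CLASS`'s callback surface, which catches misspelled callbacks at construction time. A hypothetical example of what gets flagged:

from avalanche.core import BasePlugin


class TypoPlugin(BasePlugin):
    """Hypothetical plugin with a misspelled callback name."""

    def before_trainning_exp(self, strategy, **kwargs):  # typo: "trainning"
        pass


# Constructing any BaseTemplate subclass with plugins=[TypoPlugin()]
# triggers the warning above, with {'before_trainning_exp'} reported
# as the incompatible callback set.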
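
Finally, the distributed check at the heart of this PR reads the new flag with `getattr(plugin, "supports_distributed", False)` and warns on any plugin where it is falsy. Under that check alone, declaring the attribute is enough to opt in; the plugin below is a hypothetical sketch, and if `BasePlugin`'s constructor also sets the flag per instance (not shown in this file), the constructor route would take precedence:

from avalanche.core import BasePlugin


class DistributedAwarePlugin(BasePlugin):
    """Hypothetical plugin declaring distributed-training support."""

    # Satisfies getattr(plugin, "supports_distributed", False) in
    # _check_distributed_training_compatibility, provided no instance
    # attribute of the same name shadows it.
    supports_distributed = True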