
ContinualAI / avalanche / build 4993189103 (pending completion, via GitHub)
Pull Request #1370: Add base elements to support distributed comms. Add supports_distributed plugin flag.

258 of 822 new or added lines in 27 files covered (31.39%).

80 existing lines in 5 files are now uncovered.

15,585 of 21,651 relevant lines covered (71.98%).

2.88 hits per line.

Source File: /avalanche/training/supervised/l2p.py (77.97% covered)
from typing import Callable, List, Optional, Union

import numpy as np
import torch
import torch.nn as nn

from avalanche.training.plugins import EvaluationPlugin
from avalanche.training.plugins.strategy_plugin import SupervisedPlugin
from avalanche.training.plugins.evaluation import default_evaluator
from avalanche.training.templates import SupervisedTemplate
from avalanche.models.vit import create_model


class LearningToPrompt(SupervisedTemplate):
    """
    Learning to Prompt (L2P) strategy.

    Technique introduced in:
    Wang, Zifeng, et al. "Learning to Prompt for Continual Learning."
    Proceedings of the IEEE/CVF Conference on Computer Vision and
    Pattern Recognition (CVPR), 2022.

    Implementation based on:
    - https://github.com/JH-LEE-KR/l2p-pytorch
    - implementations by Dario Salvati

    As model_name, we expect one of the models listed in
    avalanche.models.vit. Those models are based on the timm library.
    """

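    # How L2P works (as described in the paper): a pool of pool_size
    # learnable prompts is kept alongside learnable keys. For each input,
    # the top_k prompts whose keys best match the image's query feature
    # are selected, and each contributes prompt_length tokens that are
    # prepended to the token sequence of a frozen pre-trained ViT.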
    def __init__(
        self,
        model_name: str,
        criterion: nn.Module = nn.CrossEntropyLoss(),
        train_mb_size: int = 1,
        train_epochs: int = 1,
        eval_mb_size: Optional[int] = 1,
        device: Union[str, torch.device] = "cpu",
        plugins: Optional[List["SupervisedPlugin"]] = None,
        evaluator: Union[
            EvaluationPlugin,
            Callable[[], EvaluationPlugin]
        ] = default_evaluator,
        eval_every: int = -1,
        peval_mode: str = "epoch",
        prompt_pool: bool = True,
        pool_size: int = 20,
        prompt_length: int = 5,
        top_k: int = 5,
        lr: float = 0.03,
        sim_coefficient: float = 0.1,
        prompt_key: bool = True,
        pretrained: bool = True,
        num_classes: int = 10,
        drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        embedding_key: str = "cls",
        prompt_init: str = "uniform",
        batchwise_prompt: bool = False,
        head_type: str = "prompt",
        use_prompt_mask: bool = False,
        train_prompt_mask: bool = False,
        use_cls_features: bool = True,
        use_mask: bool = True,
        use_vit: bool = True,
        **kwargs,
    ):
        """Init.

        :param model_name: Name of the model to use. For a complete list,
            check models.vit.py.
        :param criterion: Loss function used during training.
            Defaults to CrossEntropyLoss.
        :param train_mb_size: The train minibatch size. Defaults to 1.
        :param train_epochs: The number of training epochs. Defaults to 1.
        :param eval_mb_size: The eval minibatch size. Defaults to 1.
        :param device: The device to use. Defaults to "cpu".
        :param plugins: Plugins to be added. Defaults to None.
        :param evaluator: (optional) instance of EvaluationPlugin for logging
            and metric computations.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop. -1 disables the evaluation. 0 means `eval` is called
            only at the end of the learning experience. Values >0 mean that
            `eval` is called every `eval_every` epochs and at the end of the
            learning experience.
        :param use_cls_features: Use an external pre-trained model to obtain
            the features used to select the prompts. Defaults to True.
        :param use_mask: Mask the logits so that only the classes of the
            current experience are trained. Defaults to True.
        :param use_vit: Whether the backbone is a Vision Transformer.
            Defaults to True.
        """

        if device is None:
            device = torch.device("cpu")

        self.num_classes = num_classes
        self.lr = lr
        self.sim_coefficient = sim_coefficient
        model = create_model(
            model_name=model_name,
            prompt_pool=prompt_pool,
            pool_size=pool_size,
            prompt_length=prompt_length,
            top_k=top_k,
            prompt_key=prompt_key,
            pretrained=pretrained,
            num_classes=num_classes,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            embedding_key=embedding_key,
            prompt_init=prompt_init,
            batchwise_prompt=batchwise_prompt,
            head_type=head_type,
            use_prompt_mask=use_prompt_mask,
        )
119
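        # Freeze the pre-trained backbone parameters; the prompt pool and
        # the classification head created below remain trainable.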
        for n, p in model.named_parameters():
4✔
120
            if n.startswith(tuple(["blocks", "patch_embed", 
4✔
121
                                   "cls_token", "norm", "pos_embed"])):
122
                p.requires_grad = False
×
123
        
124
        model.head = torch.nn.Linear(768, num_classes).to(device)
4✔
125

126
        optimizer = torch.optim.Adam(
4✔
127
            model.parameters(),
128
            betas=(0.9, 0.999),
129
            lr=self.lr,
130
        )
131

        super().__init__(
            model,
            optimizer,
            criterion,
            train_mb_size,
            train_epochs,
            eval_mb_size,
            device,
            plugins,
            evaluator,
            eval_every,
            peval_mode,
        )

        self._criterion = criterion
        self.use_cls_features = use_cls_features
        self.train_prompt_mask = train_prompt_mask
        self.use_mask = use_mask
        self.use_vit = use_vit

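        # When use_cls_features is True, a second frozen copy of the
        # pre-trained ViT (without classifier) computes the query features
        # used to select prompts in forward().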
        if use_cls_features:
            self.original_vit = create_model(
                model_name=model_name,
                pretrained=pretrained,
                num_classes=num_classes,
                drop_rate=drop_rate,
                drop_path_rate=drop_path_rate,
            ).to(device)

            self.original_vit.reset_classifier(0)

            for p in self.original_vit.parameters():
                p.requires_grad = False

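    # Re-creating Adam at the start of every experience also resets its
    # moment estimates, so optimizer state does not leak across tasks.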
    def _before_training_exp(self, **kwargs):
        super()._before_training_exp(**kwargs)
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            betas=(0.9, 0.999),
            lr=self.lr,
        )

    def forward(self):
        assert self.experience is not None
        if self.use_cls_features:
            with torch.no_grad():
                cls_features = self.original_vit(self.mb_x)["pre_logits"]
        else:
            cls_features = None

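        # The prompted ViT returns a dict containing "logits" and
        # "reduce_sim", the key-query similarity term reused in criterion().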
        if self.use_vit:
            self.res = self.model(
                x=self.mb_x,
                task_id=self.mb_task_id,
                cls_features=cls_features,
                train=self.train_prompt_mask,
            )
        else:
            self.res = {}
            self.res["logits"] = self.model(x=self.mb_x)
            self.res["reduce_sim"] = 0

        logits = self.res["logits"]

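        # Class masking: during training, logits of classes outside the
        # current experience are set to -inf so cross-entropy ignores them.
        # Toy example with num_classes=4 and mask=[0, 1]:
        #   logits   -> [[0.5, 0.2, 0.9, 0.1]]
        #   not_mask -> [2, 3]
        #   masked   -> [[0.5, 0.2, -inf, -inf]]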
        if self.use_mask and self.is_training:
            mask = self.experience.classes_in_this_experience
            not_mask = np.setdiff1d(np.arange(self.num_classes), mask)
            not_mask = torch.tensor(not_mask, dtype=torch.int64).to(self.device)
            logits = logits.index_fill(dim=1,
                                       index=not_mask,
                                       value=float("-inf"))

        return logits

    def criterion(self):
        # L2P objective: task loss minus the scaled key-query similarity,
        # i.e. loss = CE(output, y) - sim_coefficient * reduce_sim.
        loss = self._criterion(self.mb_output, self.mb_y)
        loss = loss - self.sim_coefficient * self.res["reduce_sim"]
        return loss
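
A minimal usage sketch (not part of the file above), assuming the classic
SplitCIFAR100 benchmark and the "vit_base_patch16_224" model name; any model
listed in avalanche.models.vit should work, and the hyper-parameters shown
are illustrative only.

from torchvision import transforms

from avalanche.benchmarks.classic import SplitCIFAR100
from avalanche.training.supervised.l2p import LearningToPrompt

# ViT-Base expects 224x224 RGB inputs, so resize CIFAR images accordingly.
transform = transforms.Compose([transforms.Resize(224), transforms.ToTensor()])

benchmark = SplitCIFAR100(
    n_experiences=10,
    train_transform=transform,
    eval_transform=transform,
)

strategy = LearningToPrompt(
    model_name="vit_base_patch16_224",  # assumed name; see avalanche.models.vit
    train_mb_size=16,
    train_epochs=1,
    num_classes=100,
    device="cpu",
)

for experience in benchmark.train_stream:
    strategy.train(experience)           # trains only prompts and head
    strategy.eval(benchmark.test_stream)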