ContinualAI / avalanche, build 5077487805 (pending completion)

Pull Request #1387 (GitHub): added topk patch

1 of 2 new or added lines in 1 file covered. (50.0%)

2 existing lines in 1 file now uncovered.

15749 of 21808 relevant lines covered (72.22%)

0.72 hits per line

Source File
/avalanche/evaluation/metrics/topk_acc.py (53.9% covered)
################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 29-03-2022                                                             #
# Author(s): Rudy Semola                                                       #
# E-mail: contact@continualai.org                                              #
# Website: www.continualai.org                                                 #
################################################################################

from typing import TYPE_CHECKING, List, Union, Dict

import torch
from torch import Tensor
import torchmetrics
from torchmetrics.functional import accuracy

from avalanche.evaluation import Metric, GenericPluginMetric
from avalanche.evaluation.metrics.mean import Mean
from avalanche.evaluation.metric_utils import phase_and_task

from collections import defaultdict
from packaging import version

if TYPE_CHECKING:
    from avalanche.training.templates.common_templates import SupervisedTemplate


class TopkAccuracy(Metric[Dict[int, float]]):
    """
    The Top-k Accuracy metric. This is a standalone metric.
    It is computed using `torchmetrics.functional.accuracy` with the
    `top_k` argument.
    """

    def __init__(self, top_k: int):
        """
        Creates an instance of the standalone Top-k Accuracy metric.

        By default this metric in its initial state will return a value of 0.
        The metric can be updated by using the `update` method while
        the running top-k accuracy can be retrieved using the `result` method.

        :param top_k: integer number to define the value of k.
        """
        self._topk_acc_dict: Dict[int, Mean] = defaultdict(Mean)
        self.top_k: int = top_k

        # torchmetrics >= 0.11.0 requires the `task` argument when
        # computing the accuracy.
        self.__torchmetrics_requires_task = \
            version.parse(torchmetrics.__version__) >= version.parse('0.11.0')

    @torch.no_grad()
    def update(
        self,
        predicted_y: Tensor,
        true_y: Tensor,
        task_labels: Union[float, Tensor],
    ) -> None:
        """
        Update the running top-k accuracy given the true and predicted labels.
        Parameter `task_labels` is used to decide how to update the inner
        dictionary: if it is a scalar, only the dictionary value related to
        that task is updated. If it is a Tensor, the dictionary entry of each
        task label in the tensor is updated.

        :param predicted_y: The model prediction. Both labels and logit vectors
            are supported.
        :param true_y: The ground truth. Both labels and one-hot vectors
            are supported.
        :param task_labels: the int task label associated to the current
            experience or the task labels vector showing the task label
            for each pattern.

        :return: None.
        """
        if len(true_y) != len(predicted_y):
            raise ValueError("Size mismatch for true_y and predicted_y tensors")

        if isinstance(task_labels, Tensor) and len(task_labels) != len(true_y):
            raise ValueError("Size mismatch for true_y and task_labels tensors")

        true_y = torch.as_tensor(true_y)
        predicted_y = torch.as_tensor(predicted_y)

        if isinstance(task_labels, int):
            total_patterns = len(true_y)
            self._topk_acc_dict[task_labels].update(
                self._compute_topk_acc(predicted_y, true_y, top_k=self.top_k),
                total_patterns
            )
        elif isinstance(task_labels, Tensor):
            for pred, true, t in zip(predicted_y, true_y, task_labels):
                self._topk_acc_dict[int(t)].update(
                    self._compute_topk_acc(pred, true, top_k=self.top_k),
                    1
                )
        else:
            raise ValueError(
                f"Task label type: {type(task_labels)}, "
                f"expected int/float or Tensor"
            )

    def _compute_topk_acc(self, pred, gt, top_k):
        if self.__torchmetrics_requires_task:
            # Infer the number of classes from the targets (and, below,
            # from the logits), as required by recent torchmetrics versions.
            num_classes = int(torch.max(torch.as_tensor(gt))) + 1

            # When fewer than `top_k` classes are present, the top-k
            # accuracy is not computable here: return 0 in that case.
            if num_classes < top_k:
                return 0

            pred_t = torch.as_tensor(pred)
            if len(pred_t.shape) > 1:
                num_classes = max(num_classes, pred_t.shape[1])

            return accuracy(
                pred,
                gt,
                task="multiclass",
                num_classes=num_classes,
                top_k=self.top_k
            )
        else:
            return accuracy(
                pred,
                gt,
                top_k=self.top_k)

    def result_task_label(self, task_label: int) -> Dict[int, float]:
        """
        Retrieves the running top-k accuracy for the given task.

        Calling this method will not change the internal state of the metric.

        :param task_label: the task label for which to retrieve the running
            top-k accuracy. Must not be None.

        :return: A dictionary `{task_label: topk_accuracy}`, where the accuracy
            is a float value between 0 and 1.
        """
        assert task_label is not None
        return {task_label: self._topk_acc_dict[task_label].result()}

    def result(self) -> Dict[int, float]:
        """
        Retrieves the running top-k accuracy for all tasks.

        Calling this method will not change the internal state of the metric.

        :return: A dict of running top-k accuracies for each task label,
            where each value is a float value between 0 and 1.
        """
        return {k: v.result() for k, v in self._topk_acc_dict.items()}

    def reset(self, task_label=None) -> None:
        """
        Resets the metric.

        :param task_label: if None, reset the entire dictionary.
            Otherwise, reset the value associated to `task_label`.

        :return: None.
        """
        assert task_label is None or isinstance(task_label, int)
        if task_label is None:
            self._topk_acc_dict = defaultdict(Mean)
        else:
            self._topk_acc_dict[task_label].reset()


class TopkAccuracyPluginMetric(
        GenericPluginMetric[
            Dict[int, float],
            TopkAccuracy]):
    """
    Base class for all top-k accuracy plugin metrics.
    """

    def __init__(self, reset_at, emit_at, mode, top_k):
        super(TopkAccuracyPluginMetric, self).__init__(
            TopkAccuracy(top_k=top_k),
            reset_at=reset_at,
            emit_at=emit_at,
            mode=mode
        )

    def reset(self, strategy=None) -> None:
        if self._reset_at == "stream" or strategy is None:
            self._metric.reset()
        else:
            self._metric.reset(phase_and_task(strategy)[1])

    def result(self, strategy=None) -> Dict[int, float]:
        if self._emit_at == "stream" or strategy is None:
            return self._metric.result()
        else:
            return self._metric.result_task_label(phase_and_task(strategy)[1])

    def update(self, strategy: "SupervisedTemplate"):
        assert strategy.experience is not None
        # task labels defined for each experience
        task_labels = strategy.experience.task_labels
        if len(task_labels) > 1:
            # task labels defined for each pattern
            task_labels = strategy.mb_task_id
        else:
            task_labels = task_labels[0]
        self._metric.update(strategy.mb_output, strategy.mb_y, task_labels)


class MinibatchTopkAccuracy(TopkAccuracyPluginMetric):
    """
    The minibatch plugin top-k accuracy metric.
    This metric only works at training time.

    This metric computes the average top-k accuracy over patterns
    from a single minibatch.
    It reports the result after each iteration.
    """

    def __init__(self, top_k):
        """
        Creates an instance of the MinibatchTopkAccuracy metric.
        """
        super(MinibatchTopkAccuracy, self).__init__(
            reset_at="iteration", emit_at="iteration", mode="train", top_k=top_k
        )
        self.top_k = top_k

    def __str__(self):
        return "Topk_" + str(self.top_k) + "_Acc_MB"


class EpochTopkAccuracy(TopkAccuracyPluginMetric):
    """
    The average top-k accuracy over a single training epoch.
    This plugin metric only works at training time.

    The top-k accuracy will be logged after each training epoch by computing
    the number of correctly predicted patterns during the epoch divided by
    the overall number of patterns encountered in that epoch.
    """

    def __init__(self, top_k):
        """
        Creates an instance of the EpochTopkAccuracy metric.
        """

        super(EpochTopkAccuracy, self).__init__(
            reset_at="epoch", emit_at="epoch", mode="train", top_k=top_k
        )
        self.top_k = top_k

    def __str__(self):
        return "Topk_" + str(self.top_k) + "_Acc_Epoch"


class RunningEpochTopkAccuracy(TopkAccuracyPluginMetric):
    """
    The average top-k accuracy across all minibatches up to the current
    epoch iteration.
    This plugin metric only works at training time.

    At each iteration, this metric logs the top-k accuracy averaged over all
    patterns seen so far in the current epoch.
    The metric resets its state after each training epoch.
    """

    def __init__(self, top_k):
        """
        Creates an instance of the RunningEpochTopkAccuracy metric.
        """

        super(RunningEpochTopkAccuracy, self).__init__(
            reset_at="epoch", emit_at="iteration", mode="train", top_k=top_k
        )
        self.top_k = top_k

    def __str__(self):
        return "Topk_" + str(self.top_k) + "_Acc_Epoch"


class ExperienceTopkAccuracy(TopkAccuracyPluginMetric):
    """
    At the end of each experience, this plugin metric reports
    the average top-k accuracy over all patterns seen in that experience.
    This metric only works at eval time.
    """

    def __init__(self, top_k):
        """
        Creates an instance of the ExperienceTopkAccuracy metric.
        """
        super(ExperienceTopkAccuracy, self).__init__(
            reset_at="experience",
            emit_at="experience",
            mode="eval",
            top_k=top_k,
        )
        self.top_k = top_k

    def __str__(self):
        return "Topk_" + str(self.top_k) + "_Acc_Exp"


class TrainedExperienceTopkAccuracy(TopkAccuracyPluginMetric):
    """
    At the end of each experience, this plugin metric reports the average
    top-k accuracy only for the experiences
    that the model has been trained on so far.

    This metric only works at eval time.
    """

    def __init__(self, top_k):
        """
        Creates an instance of the TrainedExperienceTopkAccuracy metric.
        """
        super(TrainedExperienceTopkAccuracy, self).__init__(
            reset_at="stream", emit_at="stream", mode="eval", top_k=top_k
        )
        self._current_experience = 0
        self.top_k = top_k

    def after_training_exp(self, strategy):
        self._current_experience = strategy.experience.current_experience
        # Reset average after learning from a new experience
        self.reset(strategy)
        return super().after_training_exp(strategy)

    def update(self, strategy):
        """
        Only update the top-k accuracy with results from experiences
        that have been trained on.
        """
        if strategy.experience.current_experience <= self._current_experience:
            TopkAccuracyPluginMetric.update(self, strategy)

    def __str__(self):
        return "Topk_" + str(self.top_k) + "_Acc_On_Trained_Experiences"


class StreamTopkAccuracy(TopkAccuracyPluginMetric):
    """
    At the end of the entire stream of experiences, this plugin metric
    reports the average top-k accuracy over all patterns
    seen in all experiences. This metric only works at eval time.
    """

    def __init__(self, top_k):
        """
        Creates an instance of the StreamTopkAccuracy metric.
        """
        super(StreamTopkAccuracy, self).__init__(
            reset_at="stream", emit_at="stream", mode="eval", top_k=top_k
        )
        self.top_k = top_k

    def __str__(self):
        return "Topk_" + str(self.top_k) + "_Acc_Stream"


def topk_acc_metrics(
    *,
    top_k=3,
    minibatch=False,
    epoch=False,
    epoch_running=False,
    experience=False,
    trained_experience=False,
    stream=False,
) -> List[TopkAccuracyPluginMetric]:
    """
    Helper method that can be used to obtain the desired set of
    plugin metrics.

    :param top_k: the value of k used by all returned metrics (default: 3).
    :param minibatch: If True, will return a metric able to log
        the minibatch top-k accuracy at training time.
    :param epoch: If True, will return a metric able to log
        the epoch top-k accuracy at training time.
    :param epoch_running: If True, will return a metric able to log
        the running epoch top-k accuracy at training time.
    :param experience: If True, will return a metric able to log
        the top-k accuracy on each evaluation experience.
    :param trained_experience: If True, will return a metric able to log
        the average evaluation top-k accuracy only for experiences that the
        model has been trained on.
    :param stream: If True, will return a metric able to log the top-k accuracy
        averaged over the entire evaluation stream of experiences.

    :return: A list of plugin metrics.
    """

    metrics: List[TopkAccuracyPluginMetric] = []
    if minibatch:
        metrics.append(MinibatchTopkAccuracy(top_k=top_k))
    if epoch:
        metrics.append(EpochTopkAccuracy(top_k=top_k))
    if epoch_running:
        metrics.append(RunningEpochTopkAccuracy(top_k=top_k))
    if experience:
        metrics.append(ExperienceTopkAccuracy(top_k=top_k))
    if trained_experience:
        metrics.append(TrainedExperienceTopkAccuracy(top_k=top_k))
    if stream:
        metrics.append(StreamTopkAccuracy(top_k=top_k))

    return metrics


__all__ = [
    "TopkAccuracy",
    "MinibatchTopkAccuracy",
    "EpochTopkAccuracy",
    "RunningEpochTopkAccuracy",
    "ExperienceTopkAccuracy",
    "StreamTopkAccuracy",
    "TrainedExperienceTopkAccuracy",
    "topk_acc_metrics",
]


"""
UNIT TEST
"""
if __name__ == "__main__":
    metric = topk_acc_metrics(trained_experience=True, top_k=5)
    print(metric)
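
Usage note: the plugin metrics built by `topk_acc_metrics` are meant to be handed to Avalanche's evaluation machinery rather than called directly. A minimal sketch (assuming `topk_acc_metrics` is re-exported from `avalanche.evaluation.metrics`, as the other metric helpers are, and that model, optimizer, and benchmark are configured elsewhere) could look like this:

    from avalanche.evaluation.metrics import topk_acc_metrics
    from avalanche.logging import InteractiveLogger
    from avalanche.training.plugins import EvaluationPlugin

    # Build the desired top-5 accuracy plugin metrics.
    metrics = topk_acc_metrics(
        top_k=5, epoch=True, experience=True, stream=True
    )

    # The list is unpacked into the EvaluationPlugin, which forwards the
    # strategy callbacks (end of epoch, end of eval experience, ...) to
    # each plugin metric and collects the emitted values.
    eval_plugin = EvaluationPlugin(*metrics, loggers=[InteractiveLogger()])

    # The plugin is then passed to a strategy, e.g.:
    # strategy = Naive(model, optimizer, criterion, evaluator=eval_plugin)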