• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ContinualAI / avalanche / 5077487805

pending completion
5077487805

Pull #1387

github

Unknown Committer
Unknown Commit Message
Pull Request #1387: added topk patch

1 of 2 new or added lines in 1 file covered. (50.0%)

2 existing lines in 1 file now uncovered.

15749 of 21808 relevant lines covered (72.22%)

0.72 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

75.0
/avalanche/evaluation/metrics/amca.py
1
################################################################################
2
# Copyright (c) 2021 ContinualAI.                                              #
3
# Copyrights licensed under the MIT License.                                   #
4
# See the accompanying LICENSE file for terms.                                 #
5
#                                                                              #
6
# Date: 26-05-2022                                                             #
7
# Author(s): Eli Verwimp, Lorenzo Pellegrini                                   #
8
# E-mail: contact@continualai.org                                              #
9
# Website: www.continualai.org                                                 #
10
################################################################################
11

12

13
from typing import (
1✔
14
    Callable,
15
    Dict,
16
    Iterable,
17
    List,
18
    Union,
19
    TYPE_CHECKING,
20
    Optional,
21
    Sequence,
22
    Set,
23
)
24

25
# `statistics.fmean` (a fast, float-returning mean) only exists on
# Python >= 3.8; on older interpreters fall back to the slower,
# exact-fraction `statistics.mean` under the same name.
fmean: Callable[[Iterable[float]], float]
try:
    from statistics import fmean
except ImportError:
    from statistics import mean as fmean
30

31
from collections import defaultdict, OrderedDict
1✔
32

33
import torch
1✔
34
from torch import Tensor
1✔
35
from avalanche.evaluation import (
1✔
36
    Metric,
37
    PluginMetric,
38
    _ExtendedGenericPluginMetric,
39
    _ExtendedPluginMetricValue,
40
)
41
from avalanche.evaluation.metric_utils import generic_get_metric_name
1✔
42
from avalanche.evaluation.metrics.class_accuracy import (
1✔
43
    ClassAccuracy,
44
    TrackedClassesType,
45
)
46

47
if TYPE_CHECKING:
1✔
48
    from avalanche.training.templates import SupervisedTemplate
×
49

50

51
class AverageMeanClassAccuracy(Metric[Dict[int, float]]):
    """
    Standalone Average Mean Class Accuracy (AMCA) metric, used as a
    building block for more specific metrics.

    The metric keeps a running per-class accuracy over multiple
    <prediction, target> pairs of Tensors, provided incrementally, and
    reports, for each task, the average of the mean class accuracies of
    all experiences seen so far (the current, still-open experience
    included).

    Beware that this class does not provide mechanisms to separate scores
    based on the originating data stream. For this, please refer to
    :class:`MultiStreamAMCA`.

    The "prediction" and "target" tensors may contain plain labels or
    one-hot/logit vectors.

    The metric expects :meth:`next_experience` to be called after each
    experience: it consolidates the current mean accuracy, after which a
    new experience with accuracy 0.0 immediately starts. To obtain the
    AMCA up to experience `t-1`, read :meth:`result` *before* calling
    :meth:`next_experience`.

    The set of tracked classes can be restricted (see the constructor
    parameters).

    The reset method brings the metric back to its initial state (tracked
    classes are kept). By default, a reset metric returns a
    `{task_id -> amca}` dictionary with every AMCA set to 0 (so `reset`
    is of limited use with this metric).
    """

    def __init__(self, classes: Optional[TrackedClassesType] = None):
        """
        Creates an instance of the standalone AMCA metric.

        In its initial state the metric returns an empty dictionary. Feed
        data through `update` and read the running AMCA through `result`.

        Setting the `classes` parameter matters: it restricts (and
        pre-initializes to 0.0) the tracked classes, which affects the
        mean class accuracy. Leaving it None is only safe when the test
        set is fixed and contains at least one sample per class.

        :param classes: The classes to keep track of. If None (default),
            all classes seen are tracked. Otherwise, it can be a dict of
            classes to be tracked (as "task-id" -> "list of class ids")
            or, for a task-free benchmark (task 0 only), a plain list of
            class ids. Passing this parameter creates the class list
            immediately, ensuring a correctly computed mean class
            accuracy.
        """
        # Per-task, per-class running accuracies of the current
        # (unconsolidated) experience.
        self._class_accuracies = ClassAccuracy(classes=classes)

        # Consolidated mean class accuracies of previous experiences, as
        # `{task_id -> [accuracies]}`. Plain float lists (instead of Mean
        # metrics) keep the running AMCA straightforward to compute.
        self._prev_exps_accuracies: Dict[int, List[float]] = defaultdict(list)

        # True once `update` has run at least once; lets
        # `next_experience` ignore calls made before any data arrived.
        self._updated_once = False

    @torch.no_grad()
    def update(
        self,
        predicted_y: Tensor,
        true_y: Tensor,
        task_labels: Union[int, Tensor],
    ) -> None:
        """
        Update the running accuracy given the true and predicted labels
        for each class.

        :param predicted_y: The model prediction. Both labels and logit
            vectors are supported.
        :param true_y: The ground truth. Both labels and one-hot vectors
            are supported.
        :param task_labels: the int task label associated to the current
            experience or the task labels vector showing the task label
            for each pattern.
        :return: None.
        """
        self._updated_once = True
        self._class_accuracies.update(predicted_y, true_y, task_labels)

    def result(self) -> Dict[int, float]:
        """
        Retrieves the running AMCA for each task.

        Calling this method does not change the internal state of the
        metric.

        :return: A dictionary `{task_id -> amca}`. The running AMCA of
            each task is a float value between 0 and 1.
        """
        current = self._get_curr_task_acc()

        # Consider every task seen either in a past experience or in the
        # current one; tasks with no current data contribute 0 for the
        # open experience (documented behavior).
        known_tasks = set(self._prev_exps_accuracies.keys()).union(
            current.keys()
        )

        return OrderedDict(
            (
                t_id,
                fmean(
                    self._prev_exps_accuracies.get(t_id, list())
                    + [current.get(t_id, 0)]
                ),
            )
            for t_id in sorted(known_tasks)
        )

    def next_experience(self):
        """
        Moves to the next experience.

        Consolidates the class accuracies of the current experience into
        the per-task history and starts a fresh one.

        Safe to call before the first `update`: such calls are ignored.
        """
        if not self._updated_once:
            return

        for t_id, mean_acc in self._get_curr_task_acc().items():
            self._prev_exps_accuracies[t_id].append(mean_acc)
        self._class_accuracies.reset()

    def reset(self) -> None:
        """
        Resets the metric (tracked classes are kept).

        :return: None.
        """
        self._updated_once = False
        self._class_accuracies.reset()
        self._prev_exps_accuracies.clear()

    def _get_curr_task_acc(self):
        # Mean class accuracy of the current (unconsolidated) experience,
        # as `{task_id -> mean accuracy}`.
        return {
            t_id: fmean(list(per_class.values()))
            for t_id, per_class in self._class_accuracies.result().items()
        }
209

210

211
class MultiStreamAMCA(Metric[Dict[str, Dict[int, float]]]):
    """
    An extension of the Average Mean Class Accuracy (AMCA) metric
    (:class:`AverageMeanClassAccuracy`) able to separate the computation of
    the AMCA based on the current stream.
    """

    def __init__(self, classes=None, streams=None):
        """
        Creates an instance of a MultiStream AMCA.

        :param classes: The list of classes to track. This has the same
            semantic of the `classes` parameter of class
            :class:`AverageMeanClassAccuracy`.
        :param streams: The list of streams to track. Defaults to None,
            which means that all streams will be tracked. This is not
            recommended, as you usually will want to track the "test"
            stream only.
        """

        # Streams to consider; None means "track every stream".
        self._limit_streams = streams
        if self._limit_streams is not None:
            self._limit_streams = set(self._limit_streams)

        # Class restriction, forwarded to each per-stream AMCA metric.
        self._limit_classes = classes
        # One standalone AMCA metric per tracked stream, created lazily
        # by `set_stream`.
        self._amcas: Dict[str, AverageMeanClassAccuracy] = dict()

        # The stream selected through `set_stream`; `update` raises a
        # RuntimeError until one is set.
        self._current_stream: Optional[str] = None
        # Streams touched since the last `finish_phase` call.
        self._streams_in_this_phase: Set[str] = set()

    @torch.no_grad()
    def update(
        self,
        predicted_y: Tensor,
        true_y: Tensor,
        task_labels: Union[int, Tensor],
    ) -> None:
        """
        Update the running accuracy given the true and predicted labels for each
        class.

        This will update the accuracies for the "current stream" (the one set
        through `set_stream`). If `set_stream` has not been called,
        then an error will be raised.

        :param predicted_y: The model prediction. Both labels and logit vectors
            are supported.
        :param true_y: The ground truth. Both labels and one-hot vectors
            are supported.
        :param task_labels: the int task label associated to the current
            experience or the task labels vector showing the task label
            for each pattern.
        :return: None.
        """
        if self._current_stream is None:
            raise RuntimeError(
                "No current stream set. "
                'Call "set_stream" to set the current stream.'
            )

        if self._is_stream_tracked(self._current_stream):
            self._amcas[self._current_stream].update(
                predicted_y, true_y, task_labels
            )

    def result(self) -> Dict[str, Dict[int, float]]:
        """
        Retrieves the running AMCA for each stream.

        Calling this method will not change the internal state of the metric.

        :return: A dictionary `{stream_name -> {task_id -> amca}}`. The
            running AMCA of each task is a float value between 0 and 1.
        """
        all_streams_dict = OrderedDict()
        for stream_name in sorted(self._amcas.keys()):
            stream_metric = self._amcas[stream_name]
            stream_result = stream_metric.result()
            all_streams_dict[stream_name] = stream_result
        return all_streams_dict

    def set_stream(self, stream_name: str):
        """
        Switches to a specific stream, lazily creating the AMCA metric for
        that stream if it is tracked.

        :param stream_name: The name of the stream.
        """
        self._current_stream = stream_name
        if not self._is_stream_tracked(stream_name):
            # Untracked streams are still "current" (so update() won't
            # raise), but no metric is created for them.
            return

        if self._current_stream not in self._amcas:
            self._amcas[stream_name] = AverageMeanClassAccuracy(
                classes=self._limit_classes
            )
        self._streams_in_this_phase.add(stream_name)

    def finish_phase(self):
        """
        Moves to the next phase.

        This will consolidate the class accuracies recorded so far for
        every stream touched in the phase that is ending.
        """
        for stream_name in self._streams_in_this_phase:
            self._amcas[stream_name].next_experience()

        self._streams_in_this_phase.clear()

    def reset(self) -> None:
        """
        Resets the metric.

        The per-stream metrics are reset in place (the set of known
        streams is kept) and the current stream selection is cleared.

        :return: None.
        """
        for metric in self._amcas.values():
            metric.reset()
        self._current_stream = None
        self._streams_in_this_phase.clear()

    def _is_stream_tracked(self, stream_name):
        # A stream is tracked when no restriction was given, or when it
        # belongs to the allowed set.
        return self._limit_streams is None or stream_name in self._limit_streams
331

332

333
class AMCAPluginMetric(_ExtendedGenericPluginMetric):
    """
    Plugin metric for the Average Mean Class Accuracy (AMCA).

    The AMCA is tracked for the classes and streams defined in the constructor.

    In addition, by default, the results obtained through the periodic
    evaluation (mid-training validation) mechanism are ignored.
    """

    # Template used to build the emitted metric value names.
    VALUE_NAME = "{metric_name}/{stream_name}_stream/Task{task_label:03}"

    def __init__(self, classes=None, streams=None, ignore_validation=True):
        """
        Instantiates the AMCA plugin metric.

        :param classes: The classes to track. Refer to :class:`MultiStreamAMCA`
            for more details.
        :param streams: The streams to track. Defaults to None, which means that
            all streams will be considered. Beware that, when creating instances
            of this class using the :func:`amca_metrics` helper, the resulting
            metric will only track the "test" stream by default.
        :param ignore_validation: Defaults to True, which means that periodic
            evaluations will be ignored (recommended).
        """
        self._ms_amca = MultiStreamAMCA(classes=classes, streams=streams)
        self._ignore_validation = ignore_validation

        # Tracks whether a training phase is in progress, so that eval
        # phases nested inside training (periodic validation) can be
        # recognized and skipped.
        self._is_training = False
        super().__init__(
            self._ms_amca, reset_at="never", emit_at="stream", mode="eval"
        )

    def _in_validation(self) -> bool:
        # An eval phase running inside a train phase is a validation run;
        # it is skipped whenever `ignore_validation` is enabled.
        return self._is_training and self._ignore_validation

    def update(self, strategy: "SupervisedTemplate"):
        if self._in_validation():
            return

        self._ms_amca.update(
            strategy.mb_output, strategy.mb_y, strategy.mb_task_id
        )

    def before_training(self, strategy: "SupervisedTemplate"):
        self._is_training = True
        return super().before_training(strategy)

    def after_training(self, strategy: "SupervisedTemplate"):
        self._is_training = False
        return super().after_training(strategy)

    def before_eval(self, strategy: "SupervisedTemplate"):
        # In the first eval phase, calling finish_phase will do nothing
        # (as expected).
        if not self._in_validation():
            self._ms_amca.finish_phase()
        return super().before_eval(strategy)

    def before_eval_exp(self, strategy: "SupervisedTemplate"):
        assert strategy.experience is not None
        if not self._in_validation():
            self._ms_amca.set_stream(strategy.experience.origin_stream.name)
        return super().before_eval_exp(strategy)

    def result(self) -> List[_ExtendedPluginMetricValue]:
        if self._in_validation():
            # Running a validation, emit nothing.
            return []

        # One value per (stream, task) pair.
        return [
            _ExtendedPluginMetricValue(
                metric_name=str(self),
                metric_value=task_amca,
                phase_name="eval",
                stream_name=stream_name,
                task_label=task_id,
                experience_id=None,
            )
            for stream_name, stream_accs in self._ms_amca.result().items()
            for task_id, task_amca in stream_accs.items()
        ]

    def metric_value_name(self, m_value: _ExtendedPluginMetricValue) -> str:
        return generic_get_metric_name(
            AMCAPluginMetric.VALUE_NAME, vars(m_value)
        )

    def __str__(self):
        return "Top1_AMCA_Stream"
428

429

430
def amca_metrics(streams: Sequence[str] = ("test",)) -> PluginMetric:
    """
    Helper method that can be used to obtain the desired set of
    plugin metrics.

    The returned metric will not compute the AMCA when the
    :class:`PeriodicEval` plugin is used. To change this behavior,
    you can instantiate a :class:`AMCAPluginMetric` by setting
    `ignore_validation` to False.

    :param streams: The list of streams to track. Defaults to "test" only.

    :return: The AMCA plugin metric.
    """
    amca_plugin = AMCAPluginMetric(streams=streams, ignore_validation=True)
    return amca_plugin
445

446

447
# Public API of this module (names exported by `from ... import *`).
__all__ = [
    "AverageMeanClassAccuracy",
    "MultiStreamAMCA",
    "AMCAPluginMetric",
    "amca_metrics",
]
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc