
ContinualAI / avalanche, build 4993189103 (pending completion)

Pull Request #1370 (GitHub): Add base elements to support distributed comms. Add supports_distributed plugin flag.

Committer: unknown. Commit message: unknown.

258 of 822 new or added lines in 27 files covered. (31.39%)

80 existing lines in 5 files now uncovered.

15585 of 21651 relevant lines covered (71.98%)

2.88 hits per line

Source file: /avalanche/training/plugins/evaluation.py (88.39% of lines covered)
import warnings
from copy import copy
from collections import defaultdict
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Tuple,
    Union,
    Sequence,
    TYPE_CHECKING,
)
from avalanche.distributed.distributed_helper import DistributedHelper

from avalanche.evaluation.metric_results import MetricValue
from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics
from avalanche.logging import InteractiveLogger

if TYPE_CHECKING:
    from avalanche.evaluation import PluginMetric
    from avalanche.logging import BaseLogger
    from avalanche.training.templates import SupervisedTemplate


def _init_metrics_list_lambda():
    # SERIALIZATION NOTICE: we need these because lambda serialization
    # does not work in some cases (yes, even with dill).
    return [], []


class EvaluationPlugin:
    """Manager for logging and metrics.

    An evaluation plugin that obtains relevant data from the
    training and eval loops of the strategy through callbacks.
    The plugin keeps a dictionary with the last recorded value for each metric.
    The dictionary will be returned by the `train` and `eval` methods of the
    strategies.
    It is also possible to keep a dictionary with all recorded metrics by
    specifying `collect_all=True`. The dictionary can be retrieved via
    the `get_all_metrics` method.

    This plugin also logs metrics using the provided loggers.
    """

    def __init__(
        self,
        *metrics: Union["PluginMetric", Sequence["PluginMetric"]],
        loggers: Optional[Union[
            "BaseLogger",
            Sequence["BaseLogger"],
            Callable[[], Sequence["BaseLogger"]]]] = None,
        collect_all=True,
        strict_checks=False
    ):
        """Creates an instance of the evaluation plugin.

        :param metrics: The metrics to compute.
        :param loggers: The loggers to be used to log the metric values.
        :param collect_all: if True, collect in a separate dictionary all
            metric curve values. This dictionary is accessible with the
            `get_all_metrics` method.
        :param strict_checks: if True, checks that the full evaluation stream
            is used when calling `eval`. An error will be raised otherwise.
        """
        super().__init__()
        self.supports_distributed = True
        self.collect_all = collect_all
        self.strict_checks = strict_checks

        flat_metrics_list = []
        for metric in metrics:
            if isinstance(metric, Sequence):
                flat_metrics_list += list(metric)
            else:
                flat_metrics_list.append(metric)
        self.metrics = flat_metrics_list

        if loggers is None:
            loggers = []
        elif callable(loggers):
            loggers = loggers()
        elif not isinstance(loggers, Sequence):
            loggers = [loggers]

        self.loggers: Sequence["BaseLogger"] = loggers

        if len(self.loggers) == 0 and DistributedHelper.is_main_process:
            warnings.warn("No loggers specified, metrics will not be logged")

        self.all_metric_results: Dict[str, Tuple[List[int], List[Any]]]
        if self.collect_all:
            # For each curve, collect all emitted values.
            # The dictionary key is the full metric name.
            # The dictionary value is a tuple of two lists:
            # the first list gathers x values (indices representing
            # the time steps at which the corresponding metric value
            # has been emitted), the second list gathers metric values.
            # SERIALIZATION NOTICE: don't use a lambda here, otherwise
            # serialization may fail in some cases.
            self.all_metric_results = defaultdict(_init_metrics_list_lambda)
        else:
            self.all_metric_results = dict()

        # Dictionary of last values emitted. The dictionary key
        # is the full metric name, while the dictionary value is
        # the metric value.
        self.last_metric_results: Dict[str, Any] = {}

        self._active = True
        """If False, no metrics will be collected."""

        self._metric_values: List[MetricValue] = []
        """List of metrics that have yet to be processed by loggers."""

    @property
    def active(self):
        return self._active

    @active.setter
    def active(self, value):
        assert (
            value is True or value is False
        ), "Active must be set as either True or False"
        self._active = value

    def publish_metric_value(self, mval: MetricValue):
        """Publish a MetricValue to be processed by the loggers."""
        self._metric_values.append(mval)

        name = mval.name
        x = mval.x_plot
        val = mval.value
        if self.collect_all:
            self.all_metric_results[name][0].append(x)
            self.all_metric_results[name][1].append(val)
        self.last_metric_results[name] = val

    def _update_metrics_and_loggers(
        self, strategy: "SupervisedTemplate", callback: str
    ):
        """Call the metric plugins with the correct callback `callback` and
        update the loggers with the new metric values."""
        original_experience = strategy.experience
        if original_experience is not None:
            # Set experience to LOGGING so that certain fields can be accessed
            strategy.experience = original_experience.logging()
        try:
            if not self._active:
                return []

            for metric in self.metrics:
                if hasattr(metric, callback):
                    metric_result = getattr(metric, callback)(strategy)
                    if isinstance(metric_result, Sequence):
                        for mval in metric_result:
                            self.publish_metric_value(mval)
                    elif metric_result is not None:
                        self.publish_metric_value(metric_result)

            for logger in self.loggers:
                logger.log_metrics(self._metric_values)
                if hasattr(logger, callback):
                    getattr(logger, callback)(strategy, self._metric_values)
            self._metric_values = []
        finally:
            # Revert to previous experience (mode = EVAL or TRAIN)
            strategy.experience = original_experience

    def get_last_metrics(self):
        """
        Return a shallow copy of the dictionary with metric names
        as keys and last metric values as values.

        :return: a dictionary with full metric
            names as keys and last metric values as values.
        """
        return copy(self.last_metric_results)

    def get_all_metrics(self):
        """
        Return the dictionary of all collected metrics.
        This method should be called only when `collect_all` is set to True.

        :return: if `collect_all` is True, returns a dictionary
            with full metric names as keys and a tuple of two lists
            as value. The first list gathers x values (indices
            representing time steps at which the corresponding
            metric value has been emitted). The second list
            gathers metric values. If `collect_all`
            is False, returns an empty dictionary.
        """
        if self.collect_all:
            return self.all_metric_results
        else:
            return {}

    def reset_last_metrics(self):
        """
        Set the dictionary storing the last value for each metric to an
        empty dict.
        """
        self.last_metric_results = {}

    def __getattribute__(self, item):
        # We don't want to reimplement all the callbacks just to call the
        # metrics. What we do instead is assume that any method that
        # starts with `before` or `after` is a callback of the plugin system,
        # and we forward that call to the metrics.
        try:
            return super().__getattribute__(item)
        except AttributeError as e:
            if item.startswith("before_") or item.startswith("after_"):
                # The method is a callback. Forward it to the metrics.
                def fun(strat, **kwargs):
                    return self._update_metrics_and_loggers(strat, item)
                return fun
            raise

    def before_eval(self, strategy: "SupervisedTemplate", **kwargs):
        self._update_metrics_and_loggers(strategy, "before_eval")
        msge = (
            "Stream provided to `eval` must be the same of the entire "
            "evaluation stream."
        )
        if self.strict_checks:
            curr_stream = next(iter(strategy.current_eval_stream)).origin_stream
            benchmark = curr_stream[0].origin_stream.benchmark
            full_stream = benchmark.streams[curr_stream.name]

            if len(curr_stream) != len(full_stream):
                raise ValueError(msge)


def default_loggers() -> Sequence["BaseLogger"]:
    if DistributedHelper.is_main_process:
        return [InteractiveLogger()]
    else:
        return []


def default_evaluator() -> EvaluationPlugin:
    return EvaluationPlugin(
        accuracy_metrics(
            minibatch=False, epoch=True, experience=True, stream=True
        ),
        loss_metrics(minibatch=False, epoch=True, experience=True, stream=True),
        loggers=default_loggers,
    )


__all__ = [
    "EvaluationPlugin",
    "default_evaluator"
]
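
For orientation, below is a minimal usage sketch of the plugin defined above. It assumes standard Avalanche components that are not part of this file (SplitMNIST, SimpleMLP, and the Naive strategy), and exact signatures may differ across Avalanche versions.

# Minimal usage sketch. Assumptions: SplitMNIST, SimpleMLP, and the Naive
# strategy from standard Avalanche; none of these are defined in the file
# above, and their signatures may vary between Avalanche versions.
import torch

from avalanche.benchmarks.classic import SplitMNIST
from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics
from avalanche.logging import InteractiveLogger
from avalanche.models import SimpleMLP
from avalanche.training.plugins import EvaluationPlugin
from avalanche.training.supervised import Naive

benchmark = SplitMNIST(n_experiences=5)
model = SimpleMLP(num_classes=10)

# Build the plugin explicitly instead of relying on default_evaluator().
eval_plugin = EvaluationPlugin(
    accuracy_metrics(epoch=True, experience=True, stream=True),
    loss_metrics(epoch=True, experience=True, stream=True),
    loggers=[InteractiveLogger()],
    collect_all=True,
)

strategy = Naive(
    model,
    torch.optim.SGD(model.parameters(), lr=0.01),
    torch.nn.CrossEntropyLoss(),
    train_mb_size=64,
    train_epochs=1,
    evaluator=eval_plugin,
)

for experience in benchmark.train_stream:
    strategy.train(experience)
    strategy.eval(benchmark.test_stream)

# Last recorded value for each metric (also returned by train()/eval()).
last_values = eval_plugin.get_last_metrics()
# Full curves, since collect_all=True:
# {metric_name: ([x0, x1, ...], [value0, value1, ...])}.
all_curves = eval_plugin.get_all_metrics()

The default_evaluator() factory above builds essentially the same plugin, but uses default_loggers so that, under distributed training, only the main process attaches an InteractiveLogger.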