• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ContinualAI / avalanche / 5600673849

pending completion
5600673849

Pull #1463

github

web-flow
Merge abde4c21e into 435b40d2b
Pull Request #1463: Various fixes and improvements

19 of 70 new or added lines in 7 files covered. (27.14%)

2 existing lines in 2 files now uncovered.

16709 of 22963 relevant lines covered (72.76%)

2.89 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

22.76
/avalanche/logging/wandb_logger.py
1
################################################################################
2
# Copyright (c) 2021 ContinualAI.                                              #
3
# Copyrights licensed under the MIT License.                                   #
4
# See the accompanying LICENSE file for terms.                                 #
5
#                                                                              #
6
# Date: 25-11-2020                                                             #
7
# Author(s): Diganta Misra, Andrea Cossu, Lorenzo Pellegrini                   #
8
# E-mail: contact@continualai.org                                              #
9
# Website: www.continualai.org                                                 #
10
################################################################################
11
""" This module handles all the functionalities related to the logging of
4✔
12
Avalanche experiments using Weights & Biases. """
13

14
import re
4✔
15
from typing import Optional, Union, List, TYPE_CHECKING
4✔
16
from pathlib import Path
4✔
17
import os
4✔
18
import warnings
4✔
19

20
import numpy as np
4✔
21
from numpy import array
4✔
22
from torch import Tensor
4✔
23

24
from PIL.Image import Image
4✔
25
from matplotlib.pyplot import Figure
4✔
26

27
from avalanche.core import SupervisedPlugin
4✔
28
from avalanche.evaluation.metric_results import (
4✔
29
    AlternativeValues,
30
    MetricValue,
31
    TensorImage,
32
)
33
from avalanche.logging import BaseLogger
4✔
34

35
if TYPE_CHECKING:
4✔
36
    from avalanche.evaluation.metric_results import MetricValue
×
37
    from avalanche.training.templates import SupervisedTemplate
×
38

39

40
# Matches the metric names emitted by the WeightCheckpoint metric, e.g.
# "WeightCheckpoint/train_phase/train_stream/Task000/Exp000".
# Named groups: phase_name, stream_name, task_id (None when the optional
# "/TaskNNN" component is absent) and experience_id.
CHECKPOINT_METRIC_NAME = re.compile(
    r"^WeightCheckpoint\/(?P<phase_name>\S+)_phase\/"
    r"(?P<stream_name>\S+)_stream(\/Task(?P<task_id>\d+))?"
    r"\/Exp(?P<experience_id>\d+)$"
)
44

45

46
class WandBLogger(BaseLogger, SupervisedPlugin):
    """Weights and Biases logger.

    The `WandBLogger` provides an easy integration with
    Weights & Biases logging. Each monitored metric is automatically
    logged to a dedicated Weights & Biases project dashboard.

    External storage for W&B Artifacts (for instance, AWS S3 and GCS
    buckets) is supported through the ``uri`` parameter.

    The wandb log files are placed by default in "./wandb/" unless specified.

    .. note::

        TensorBoard can be synced on to the W&B dedicated dashboard.
    """

    def __init__(
        self,
        project_name: str = "Avalanche",
        run_name: str = "Test",
        log_artifacts: bool = False,
        path: Union[str, Path] = "Checkpoints",
        uri: Optional[str] = None,
        sync_tfboard: bool = False,
        save_code: bool = True,
        config: Optional[object] = None,
        dir: Optional[Union[str, Path]] = None,
        params: Optional[dict] = None,
    ):
        """Creates an instance of the `WandBLogger`.

        Creating the logger immediately starts (or resumes) a W&B run.

        :param project_name: Name of the W&B project.
        :param run_name: Name of the W&B run.
        :param log_artifacts: Option to log model weights as W&B Artifacts.
            Note that, in order for model weights to be logged, the
            :class:`WeightCheckpoint` metric must be added to the
            evaluation plugin.
        :param path: Path where model checkpoints are saved locally,
            resolved relative to the current working directory.
        :param uri: URI identifier for external storage buckets (GCS, S3).
            When set, it is added as a reference to every checkpoint
            artifact.
        :param sync_tfboard: Syncs TensorBoard to the W&B dashboard UI.
        :param save_code: Saves the main training script to W&B.
        :param config: Syncs hyper-parameters and config values used to W&B.
        :param dir: Path to the local log directory for W&B logs to be saved at.
        :param params: All arguments for wandb.init() function call. Visit
            https://docs.wandb.ai/ref/python/init to learn about all
            wandb.init() parameters. Entries in this dict override the
            corresponding explicit arguments above.
        """
        super().__init__()
        self.import_wandb()
        self.project_name = project_name
        self.run_name = run_name
        self.log_artifacts = log_artifacts
        self.path = path
        self.uri = uri
        self.sync_tfboard = sync_tfboard
        self.save_code = save_code
        self.config = config
        self.dir = dir
        self.params = params
        self.args_parse()
        self.before_run()
        self.step = 0
        self.exp_count = 0

    def import_wandb(self):
        """Import the `wandb` package and store it in ``self.wandb``.

        :raises ImportError: if wandb cannot be imported, or if the imported
            module is not the real wandb package (no ``__version__``).
        """
        try:
            import wandb
        except ImportError as err:
            raise ImportError('Please run "pip install wandb" to install wandb') from err

        # Explicit check instead of `assert` so it is not stripped under -O.
        # A module without __version__ is presumably a directory named
        # "wandb" (e.g. the default W&B log dir) imported as a namespace
        # package instead of the real library.
        if not hasattr(wandb, "__version__"):
            raise ImportError(
                "The imported 'wandb' module has no __version__ attribute; it "
                "may be shadowed by a local directory named 'wandb'. "
                'Please run "pip install wandb" to install wandb'
            )
        self.wandb = wandb

    def args_parse(self):
        """Build the keyword arguments passed to ``wandb.init()``.

        The explicit constructor arguments are used first; entries from
        ``self.params`` (if any) are then merged on top and take precedence.
        """
        self.init_kwargs = {
            "project": self.project_name,
            "name": self.run_name,
            "sync_tensorboard": self.sync_tfboard,
            "dir": self.dir,
            "save_code": self.save_code,
            "config": self.config,
        }
        if self.params:
            self.init_kwargs.update(self.params)

    def before_run(self):
        """Start (or resume) the W&B run described by ``self.init_kwargs``.

        The run id is taken, in order of preference, from
        ``init_kwargs["id"]``, the ``WANDB_RUN_ID`` environment variable, or a
        freshly generated id. It is stored back into ``init_kwargs`` so that
        an unpickled logger (see :meth:`__setstate__`) reconnects to the same
        run.
        """
        if self.wandb is None:
            self.import_wandb()

        if self.init_kwargs is None:
            self.init_kwargs = dict()

        run_id = self.init_kwargs.get("id", None)
        if run_id is None:
            run_id = os.environ.get("WANDB_RUN_ID", None)
        if run_id is None:
            run_id = self.wandb.util.generate_id()

        self.init_kwargs["id"] = run_id

        self.wandb.init(**self.init_kwargs)
        # Private wandb API: labels the run as originating from Avalanche.
        self.wandb.run._label(repo="Avalanche")

    def after_training_exp(
        self,
        strategy: "SupervisedTemplate",
        metric_values: List["MetricValue"],
        **kwargs,
    ):
        """Log the pending metric values, then a "TrainingExperience" counter.

        The counter is logged at the current step and incremented after each
        training experience.
        """
        for val in metric_values:
            self.log_metrics([val])

        self.wandb.log({"TrainingExperience": self.exp_count}, step=self.step)
        self.exp_count += 1

    def log_single_metric(self, name, value, x_plot):
        """Log a single metric value to W&B.

        Metrics whose name starts with "WeightCheckpoint" are never logged as
        regular values. They are forwarded to :meth:`_log_checkpoint` only
        when ``log_artifacts`` is enabled. For :class:`AlternativeValues`, the
        best representation supported here is selected first. Values of
        unsupported types are silently skipped.

        :param name: The metric name.
        :param value: The metric value.
        :param x_plot: The x-axis position, used as the W&B ``step``.
        """
        self.step = x_plot

        if name.startswith("WeightCheckpoint"):
            if self.log_artifacts:
                self._log_checkpoint(name, value, x_plot)
            return

        if isinstance(value, AlternativeValues):
            value = value.best_supported_value(
                Image,
                Tensor,
                TensorImage,
                Figure,
                float,
                int,
                self.wandb.viz.CustomChart,
            )

        if not isinstance(
            value,
            (
                Image,
                TensorImage,
                Tensor,
                Figure,
                float,
                int,
                self.wandb.viz.CustomChart,
            ),
        ):
            # Unsupported type
            return

        if isinstance(value, Image):
            self.wandb.log({name: self.wandb.Image(value)}, step=self.step)

        elif isinstance(value, Tensor):
            # detach/cpu/reshape: `.view(-1).numpy()` raises for tensors that
            # require grad, live on an accelerator, or are non-contiguous.
            flat_values = value.detach().cpu().reshape(-1).numpy()
            self.wandb.log(
                {name: self.wandb.Histogram(np_histogram=np.histogram(flat_values))},
                step=self.step,
            )

        elif isinstance(value, (float, int, Figure, self.wandb.viz.CustomChart)):
            self.wandb.log({name: value}, step=self.step)

        elif isinstance(value, TensorImage):
            self.wandb.log({name: self.wandb.Image(array(value))}, step=self.step)

    def _log_checkpoint(self, name, value, x_plot):
        """Save a model checkpoint locally and log it as a W&B model Artifact.

        The checkpoint is written to ``<cwd>/<self.path>/Model_<exp>.pth`` and
        added to an artifact named ``Model_<exp>`` under the entry
        ``Models/Model_<exp>.pth``. If ``self.uri`` is set, it is added as a
        reference to the same artifact.

        :param name: The metric name, expected to match
            ``CHECKPOINT_METRIC_NAME``, e.g.
            'WeightCheckpoint/train_phase/train_stream/Task000/Exp000'.
            Non-matching names trigger a warning and are skipped.
        :param value: The serialized checkpoint. It is written in binary mode,
            so it must be bytes-like.
        :param x_plot: The x-axis position, stored in the artifact metadata.
        """
        assert self.wandb is not None

        name_match = CHECKPOINT_METRIC_NAME.match(name)
        if name_match is None:
            warnings.warn(f"Checkpoint metric has unsupported name {name}.")
            return

        # The "/TaskNNN" component of the name is optional.
        task_id: Optional[int] = (
            int(name_match["task_id"]) if name_match["task_id"] is not None else None
        )
        experience_id: int = int(name_match["experience_id"])

        checkpoint_directory = Path.cwd() / self.path
        checkpoint_directory.mkdir(parents=True, exist_ok=True)

        checkpoint_name = f"Model_{experience_id}"
        checkpoint_file_name = checkpoint_name + ".pth"
        checkpoint_path = checkpoint_directory / checkpoint_file_name
        artifact_name = "Models/" + checkpoint_file_name

        # Write the checkpoint blob
        with open(checkpoint_path, "wb") as f:
            f.write(value)

        metadata = {
            "experience": experience_id,
            "x_step": x_plot,
            **({"task_id": task_id} if task_id is not None else {}),
        }

        artifact = self.wandb.Artifact(checkpoint_name, type="model", metadata=metadata)
        artifact.add_file(str(checkpoint_path), name=artifact_name)
        if self.uri is not None:
            # Must happen before log_artifact(): logging finalizes the
            # artifact, and W&B rejects additions to a finalized artifact.
            # NOTE(review): this reuses `artifact_name`, the entry name of the
            # local file above; confirm the installed wandb version accepts
            # both entries under the same name.
            artifact.add_reference(self.uri, name=artifact_name)
        self.wandb.run.log_artifact(artifact)

    def __getstate__(self):
        """Return the picklable state, without the (unpicklable) wandb module."""
        state = self.__dict__.copy()
        if "wandb" in state:
            del state["wandb"]
        return state

    def __setstate__(self, state):
        """Restore the logger and reconnect to its W&B run.

        ``init_kwargs`` still holds the run id chosen by :meth:`before_run`.
        Setting ``resume="allow"`` lets ``wandb.init()`` continue that run.
        """
        print("[W&B logger] Resuming from checkpoint...")
        self.__dict__ = state
        if self.init_kwargs is None:
            self.init_kwargs = dict()
        self.init_kwargs["resume"] = "allow"

        self.wandb = None
        self.before_run()
267

268

269
# Public API of this module: only the logger class is exported.
__all__ = ["WandBLogger"]
4✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc