• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ContinualAI / avalanche / 5600673849

pending completion
5600673849

Pull #1463

github

web-flow
Merge abde4c21e into 435b40d2b
Pull Request #1463: Various fixes and improvements

19 of 70 new or added lines in 7 files covered. (27.14%)

2 existing lines in 2 files now uncovered.

16709 of 22963 relevant lines covered (72.76%)

2.89 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

51.35
/avalanche/evaluation/metrics/checkpoint.py
1
################################################################################
2
# Copyright (c) 2021 ContinualAI.                                              #
3
# Copyrights licensed under the MIT License.                                   #
4
# See the accompanying LICENSE file for terms.                                 #
5
#                                                                              #
6
# Date: 30-12-2020                                                             #
7
# Author(s): Diganta Misra                                                     #
8
# E-mail: contact@continualai.org                                              #
9
# Website: www.continualai.org                                                 #
10
################################################################################
11

12
import copy
4✔
13
import io
4✔
14
from typing import TYPE_CHECKING, Optional
4✔
15

16
from torch import Tensor
4✔
17
import torch
4✔
18

19
from avalanche.evaluation import PluginMetric
4✔
20
from avalanche.evaluation.metric_results import MetricValue, MetricResult
4✔
21
from avalanche.evaluation.metric_utils import get_metric_name
4✔
22

23
if TYPE_CHECKING:
4✔
24
    from avalanche.training.templates import SupervisedTemplate
×
25

26

27
class WeightCheckpoint(PluginMetric[bytes]):
    """
    The WeightCheckpoint Metric.

    Instances of this metric keep a serialized checkpoint of the model taken
    at the end of each training experience. The checkpoint is the ``bytes``
    produced by ``torch.save`` on a CPU deep copy of ``strategy.model``
    (the whole module object, not only a ``state_dict``).

    Each time `result` is called, this metric emits the latest experience's
    serialized checkpoint since the last `reset`.

    The reset method will bring the metric to its initial state. By default
    this metric in its initial state will return None.
    """

    def __init__(self):
        """
        Creates an instance of the WeightCheckpoint Metric.

        By default this metric in its initial state will return None.
        The metric can be updated by using the `update` method
        while the current experience's serialized checkpoint can be
        retrieved using the `result` method.
        """
        super().__init__()
        # Serialized checkpoint of the latest experience, or None after reset.
        self.weights: Optional[bytes] = None

    def update(self, weights: bytes):
        """
        Update the checkpoint at the current experience.

        :param weights: the serialized model checkpoint (``torch.save``
            output) at the current experience.
        :return: None.
        """
        self.weights = weights

    def result(self) -> Optional[bytes]:
        """
        Retrieves the checkpoint at the current experience.

        :return: The serialized checkpoint as ``bytes``, or None if no
            checkpoint has been recorded since the last reset.
        """
        return self.weights

    def reset(self) -> None:
        """
        Resets the metric.

        :return: None.
        """
        self.weights = None

    def _package_result(self, strategy) -> "MetricResult":
        """
        Wrap the current checkpoint into a list of metric values.

        :param strategy: the strategy, used for the metric name and the
            ``clock.train_iterations`` x-axis value.
        :return: a one-element list with the checkpoint, or None when no
            checkpoint is available.
        """
        weights = self.result()
        if weights is None:
            return None

        metric_name = get_metric_name(
            self, strategy, add_experience=True, add_task=False
        )
        return [
            MetricValue(self, metric_name, weights, strategy.clock.train_iterations)
        ]

    def after_training_exp(self, strategy: "SupervisedTemplate") -> "MetricResult":
        """
        Serialize a CPU copy of the model at the end of a training experience
        and emit it as this metric's value.

        :param strategy: the strategy whose ``model`` is checkpointed.
        :return: the packaged metric result.
        """
        # Deep-copy first so that moving to CPU does not relocate the model
        # that is still being trained. NOTE(review): if the model lives on an
        # accelerator, the copy is made on that device before `.to("cpu")`.
        cpu_model = copy.deepcopy(strategy.model).to("cpu")
        with io.BytesIO() as buff:
            torch.save(cpu_model, buff)
            # getvalue() returns the full buffer contents regardless of the
            # current stream position, so no seek is needed.
            self.update(buff.getvalue())

        return self._package_result(strategy)

    def __str__(self):
        return "WeightCheckpoint"
101

102

103
__all__ = ["WeightCheckpoint"]
4✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc