f-dangel / backpack / build 8116261751

01 Mar 2024 07:30PM UTC. Coverage: 98.375% (remained the same).

Pull Request #323: [FIX | FMT] RTD build, apply latest `black` and `isort`
Merge 610195223 into e9b1dd361 (github / web-flow)

• 97 of 97 new or added lines in 97 files covered (100.0%)
• 43 existing lines in 18 files now uncovered
• 4420 of 4493 relevant lines covered (98.38%)
• 11.77 hits per line

Source File: /backpack/extensions/module_extension.py (98.33% covered; the only uncovered line is the `warn` call in `backpropagate`)

"""Contains base class for BackPACK module extensions."""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, List, Tuple
from warnings import warn

from torch import Tensor
from torch.nn import Module

from backpack.utils.module_classification import is_loss

if TYPE_CHECKING:
    from backpack import BackpropExtension


class ModuleExtension:
    """Base class for a Module Extension for BackPACK.

    Descendants of this class need to
    - define what parameters of the Module need to be treated (weight, bias)
      and provide functions to compute the quantities
    - extend the `backpropagate` function if information other than the gradient
      needs to be propagated through the graph.
    """

    def __init__(self, params: List[str] | None = None):
        """Initialization.

        Args:
            params: List of module parameters that need special treatment.
                For each param `p` in the list, instances of the extended module `m`
                need to have a field `m.p`, and the class extending `ModuleExtension`
                needs to provide a method with the same signature as the
                `backpropagate` method. The result of this method is saved in the
                extension's savefield on `m.p`.

        Raises:
            NotImplementedError: If the child class does not provide a method for
                each parameter in `params`.
        """
        self.__params: List[str] = [] if params is None else params

        for param in self.__params:
            if not hasattr(self, param):
                raise NotImplementedError(
                    f"The module extension {self} is missing an implementation "
                    f"of how to calculate the quantity for {param}. "
                    f"This should be realized in a function "
                    f"{param}(extension, module, g_inp, g_out, bpQuantities) -> Any."
                )

    def backpropagate(
        self,
        extension: BackpropExtension,
        module: Module,
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        bpQuantities: Any,
    ) -> Any:
        """Backpropagation of additional information through the graph.

        Args:
            extension: Instance of the extension currently running
            module: Instance of the extended module
            g_inp: Gradient of the loss w.r.t. the inputs
            g_out: Gradient of the loss w.r.t. the output
            bpQuantities: Quantities backpropagated w.r.t. the output

        Returns:
            Quantities backpropagated w.r.t. the input
        """
        warn("backpropagate has not been overridden")

    def __call__(
        self,
        extension: BackpropExtension,
        module: Module,
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
    ) -> None:
        """Apply all actions required by the extension.

        Fetch backpropagated quantities from the module output, apply the
        backpropagation rule, and store them as backpropagated quantities for
        the module input(s).

        Args:
            extension: current backpropagation extension
            module: current module
            g_inp: input gradients
            g_out: output gradients

        Raises:
            AssertionError: If no quantity was saved although the extension expects
                one, or if a backpropagated quantity is expected but is None while
                the old backward hook is used and the module is not a no-op Flatten.
        """
        self.check_hyperparameters_module_extension(extension, module, g_inp, g_out)
        delete_old_quantities = not self.__should_retain_backproped_quantities(module)
        bp_quantity = self.__get_backproped_quantity(
            extension, module.output, delete_old_quantities
        )
        if (
            extension.expects_backpropagation_quantities()
            and bp_quantity is None
            and not is_loss(module)
        ):
            raise AssertionError(
                "BackPACK extension expects a backpropagation quantity but it is None. "
                f"Module: {module}, Extension: {extension}."
            )

        for param in self.__params:
            if self.__param_exists_and_requires_grad(module, param):
                extFunc = getattr(self, param)
                extValue = extFunc(extension, module, g_inp, g_out, bp_quantity)
                self.__save_value_on_parameter(extValue, extension, module, param)

        module_inputs = self.__get_inputs_for_backpropagation(extension, module)
        if module_inputs:
            bp_quantity = self.backpropagate(
                extension, module, g_inp, g_out, bp_quantity
            )
            for module_inp in module_inputs:
                self.__save_backproped_quantity(extension, module_inp, bp_quantity)

    @staticmethod
    def __get_inputs_for_backpropagation(
        extension: BackpropExtension, module: Module
    ) -> Tuple[Tensor, ...]:
        """Returns the inputs on which a backpropagation should be performed.

        Args:
            extension: current extension
            module: current module

        Returns:
            the inputs which need a backpropagation quantity
        """
        module_inputs: Tuple[Tensor, ...] = ()

        if extension.expects_backpropagation_quantities():
            i = 0
            while hasattr(module, f"input{i}"):
                input = getattr(module, f"input{i}")
                if input.requires_grad:
                    module_inputs += (input,)
                i += 1

        return module_inputs

    @staticmethod
    def __should_retain_backproped_quantities(module: Module) -> bool:
        """Whether the backpropagation quantities should be kept.

        This is inherited legacy code and is not tested.

        Args:
            module: current module

        Returns:
            whether backpropagation quantities should be kept
        """
        is_a_leaf = module.output.grad_fn is None
        retain_grad_is_on = getattr(module.output, "retains_grad", False)
        # inp_is_out = id(module.input0) == id(module.output)
        should_retain_grad = is_a_leaf or retain_grad_is_on  # or inp_is_out
        return should_retain_grad

    @staticmethod
    def __get_backproped_quantity(
        extension: BackpropExtension,
        reference_tensor: Tensor,
        delete_old: bool,
    ) -> Tensor | None:
        """Fetch backpropagated quantities attached to the module output.

        The value of `reference_tensor.data_ptr()` is used as the lookup key.

        Args:
            extension: current BackPACK extension
            reference_tensor: the output Tensor of the current module
            delete_old: whether to delete the old backpropagated quantity

        Returns:
            the backpropagation quantity
        """
        return extension.saved_quantities.retrieve_quantity(
            reference_tensor.data_ptr(), delete_old
        )

    @staticmethod
    def __save_backproped_quantity(
        extension: BackpropExtension, reference_tensor: Tensor, bpQuantities: Any
    ) -> None:
        """Save additional information backpropagated for a tensor.

        Args:
            extension: current BackPACK extension
            reference_tensor: reference tensor for which additional information
                is backpropagated.
            bpQuantities: backpropagation quantities that should be saved
        """
        extension.saved_quantities.save_quantity(
            reference_tensor.data_ptr(),
            bpQuantities,
            extension.accumulate_backpropagated_quantities,
        )

    @staticmethod
    def __param_exists_and_requires_grad(module: Module, param_str: str) -> bool:
        """Whether the module has the parameter and it requires gradient.

        Args:
            module: current module
            param_str: parameter name

        Returns:
            whether the module has the parameter and it requires gradient
        """
        param_exists = getattr(module, param_str) is not None
        return param_exists and getattr(module, param_str).requires_grad

    @staticmethod
    def __save_value_on_parameter(
        value: Any, extension: BackpropExtension, module: Module, param_str: str
    ) -> None:
        """Saves the value on the parameter of that module.

        Args:
            value: The value that should be saved.
            extension: The current BackPACK extension.
            module: current module
            param_str: parameter name
        """
        setattr(getattr(module, param_str), extension.savefield, value)

    def check_hyperparameters_module_extension(
        self,
        ext: BackpropExtension,
        module: Module,
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
    ) -> None:
        """Check whether the current module is supported by the extension.

        Child classes can override this method.

        Args:
            ext: current extension
            module: module
            g_inp: input gradients
            g_out: output gradients
        """
        pass
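
For orientation, the sketch below shows how a descendant of `ModuleExtension` is typically written, following the contract from the docstrings above: list the parameters in `__init__`, and provide one method per parameter with the same signature as `backpropagate`. The class name `BatchGradPerSample`, its restriction to `torch.nn.Linear`, and the implied sum-reduced loss are illustrative assumptions, not BackPACK's actual implementation.

# Hypothetical ModuleExtension descendant for torch.nn.Linear that computes
# per-sample parameter gradients. BackPACK stores the layer's forward input
# as `module.input0`, and `__call__` above saves each returned value at
# `getattr(module.<param>, extension.savefield)`.
from torch import einsum

from backpack.extensions.module_extension import ModuleExtension


class BatchGradPerSample(ModuleExtension):
    def __init__(self):
        # __init__ verifies that a method of the same name exists for each
        # parameter listed here (see the NotImplementedError above).
        super().__init__(params=["weight", "bias"])

    def weight(self, extension, module, g_inp, g_out, bpQuantities):
        # Per-sample weight gradient: outer product of the output gradient
        # and the layer input, for every sample n in the batch.
        return einsum("no,ni->noi", g_out[0], module.input0)

    def bias(self, extension, module, g_inp, g_out, bpQuantities):
        # Per-sample bias gradient: the output gradient itself.
        return g_out[0]

A first-order quantity like this never touches `bpQuantities`, so `backpropagate` need not be overridden; a second-order extension would additionally override it to pass its own quantities (e.g. curvature factors) from `module.output` to `module.input0`, which is the flow `__call__` implements above.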