f-dangel / backpack
/backpack/utils/conv.py
"""Utility functions for convolution layers."""

from typing import Callable, Tuple, Type, Union
from warnings import warn

from einops import rearrange
from torch import Tensor, einsum
from torch.nn import (
    Conv1d,
    Conv2d,
    Conv3d,
    ConvTranspose1d,
    ConvTranspose2d,
    ConvTranspose3d,
    Module,
)
from torch.nn.functional import conv1d, conv2d, conv3d, unfold
from unfoldNd import unfoldNd


def get_conv_module(N: int) -> Type[Module]:
    """Return the PyTorch module class of N-dimensional convolution.

    Args:
        N: Convolution dimension.

    Returns:
        Convolution class.
    """
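    # For example, ``get_conv_module(2)`` returns ``torch.nn.Conv2d``; values of ``N``
    # other than 1, 2, 3 raise a ``KeyError``.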
    return {
        1: Conv1d,
        2: Conv2d,
        3: Conv3d,
    }[N]


def get_conv_function(N: int) -> Callable:
    """Return the PyTorch function of N-dimensional convolution.

    Args:
        N: Convolution dimension.

    Returns:
        Convolution function.
    """
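    # For example, ``get_conv_function(3)`` returns ``torch.nn.functional.conv3d``;
    # like ``get_conv_module``, unsupported values of ``N`` raise a ``KeyError``.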
    return {
        1: conv1d,
        2: conv2d,
        3: conv3d,
    }[N]


def unfold_input(module: Union[Conv1d, Conv2d, Conv3d], input: Tensor) -> Tensor:
    """Return unfolded input to a convolution.

    Use PyTorch's ``unfold`` operation for 2d convolutions (4d input tensors),
    otherwise fall back to a custom implementation.

    Args:
        module: Convolution module whose hyperparameters are used for the unfold.
        input: Input to convolution that will be unfolded.

    Returns:
        Unfolded input.
    """
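    # Shape sketch: for a ``Conv2d`` with ``kernel_size=(K_H, K_W)`` applied to an input
    # of shape ``(N, C_in, H, W)``, the result has shape ``(N, C_in * K_H * K_W, L)``,
    # where ``L`` is the number of output locations ``H_out * W_out``.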
    if input.dim() == 4:
        return unfold(
            input,
            kernel_size=module.kernel_size,
            dilation=module.dilation,
            padding=module.padding,
            stride=module.stride,
        )
    else:
        return unfold_by_conv(input, module)


def get_weight_gradient_factors(
    input: Tensor, grad_out: Tensor, module: Union[Conv1d, Conv2d, Conv3d]
) -> Tuple[Tensor, Tensor]:
    """Return the factors for constructing the gradients w.r.t. the kernel.

    Args:
        input: Convolution layer input.
        grad_out: Gradient w.r.t. the convolution layer output.
        module: Convolution layer.

    Returns:
        Unfolded input, output gradient with flattened spatial dimensions.
    """
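    # In the im2col view (ignoring groups), the convolution output with flattened
    # spatial dimensions is the reshaped kernel matrix times the unfolded input, so the
    # per-sample weight gradient is roughly ``dE_dY[n] @ X[n].T`` reshaped to the kernel
    # shape.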
    X = unfold_input(module, input)
    dE_dY = rearrange(grad_out, "n c ... -> n c (...)")
    return X, dE_dY


def extract_weight_diagonal(
    module: Union[Conv1d, Conv2d, Conv3d],
    unfolded_input: Tensor,
    S: Tensor,
    sum_batch: bool = True,
) -> Tensor:
    """Extract diagonal of ``(Jᵀ S) (Jᵀ S)ᵀ`` where ``J`` is the weight Jacobian.

    Args:
        module: Convolution layer for which the diagonal is extracted w.r.t. the weight.
        unfolded_input: Unfolded input to the convolution. Shape must follow the
            conventions of ``torch.nn.Unfold``.
        S: Backpropagated symmetric factorization of the loss Hessian.
            Has shape ``(V, *module.output.shape)``.
        sum_batch: Sum out the batch dimension of the weight diagonals.
            Default value: ``True``.

    Returns:
        Per-sample weight diagonal if ``sum_batch=False`` (shape
        ``(N, *module.weight.shape)`` with batch size ``N``) or summed weight
        diagonal if ``sum_batch=True`` (shape ``module.weight.shape``).
    """
    S = rearrange(S, "v n (g c) ... -> v n g c (...)", g=module.groups)
    unfolded_input = rearrange(unfolded_input, "n (g c) k -> n g c k", g=module.groups)

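    # Apply the transposed Jacobian group-wise. In the einsum below, ``l`` indexes
    # output locations and ``k`` the per-group input-channel/kernel entries:
    # JS[v, n, g, m, k] = Σ_l unfolded_input[n, g, k, l] * S[v, n, g, m, l].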
    JS = einsum("ngkl,vngml->vngmk", (unfolded_input, S))

    sum_dims = [0, 1] if sum_batch else [0]
    out_shape = (
        module.weight.shape if sum_batch else (JS.shape[1], *module.weight.shape)
    )

    return JS.pow_(2).sum(sum_dims).reshape(out_shape)


# TODO This method applies the bias Jacobian, then squares and sums the result.
# Introduce a base class for {Batch}DiagHessian and DiagGGN{Exact,MC} and remove this
# method.
def extract_bias_diagonal(
    module: Union[
        Conv1d, Conv2d, Conv3d, ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
    ],
    S: Tensor,
    sum_batch: bool = True,
) -> Tensor:
    """Extract diagonal of ``(Jᵀ S) (Jᵀ S)ᵀ`` where ``J`` is the bias Jacobian.

    Args:
        module: Convolution layer for which the diagonal is extracted w.r.t. the bias.
        S: Backpropagated symmetric factorization of the loss Hessian.
            Has shape ``(V, *module.output.shape)``.
        sum_batch: Sum out the batch dimension of the bias diagonals.
            Default value: ``True``.

    Returns:
        Per-sample bias diagonal if ``sum_batch=False`` (shape
        ``(N, *module.bias.shape)`` with batch size ``N``) or summed bias
        diagonal if ``sum_batch=True`` (shape ``module.bias.shape``).
    """
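    # S has shape (V, N, C_out, *spatial). The transposed bias Jacobian sums S over the
    # spatial dimensions (starting at dim 3); squaring and summing over V (and over N if
    # ``sum_batch``) then yields the diagonal.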
    start_spatial = 3
    sum_before = list(range(start_spatial, S.dim()))
    sum_after = [0, 1] if sum_batch else [0]

    return S.sum(sum_before).pow_(2).sum(sum_after)


def unfold_by_conv(input: Tensor, module: Union[Conv1d, Conv2d, Conv3d]) -> Tensor:
    """Return the unfolded input using convolution.

    Args:
        input: Convolution layer input.
        module: Convolution layer.

    Returns:
        Unfolded input. For a 2d convolution with input of shape `[N, C_in, *, *]`
        and a kernel of shape `[_, _, K_H, K_W]`, this tensor has shape
        `[N, C_in * K_H * K_W, L]` where `L` is the output's number of patches.
    """
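    # ``unfoldNd`` provides an unfold (im2col) that also handles 1d and 3d convolution
    # inputs, whereas the built-in ``torch.nn.functional.unfold`` only supports 4d
    # (2d-convolution) inputs; this is the fallback path used by ``unfold_input``.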
    return unfoldNd(
        input,
        module.kernel_size,
        dilation=module.dilation,
        padding=module.padding,
        stride=module.stride,
    )


def _grad_input_padding(
    grad_output: Tensor,
    input_size: Tuple[int, ...],
    stride: Tuple[int, ...],
    padding: Tuple[int, ...],
    kernel_size: Tuple[int, ...],
    dilation: Union[None, Tuple[int]] = None,
) -> Tuple[int, ...]:
    """Determine padding for the VJP of convolution.

    # noqa: DAR101
    # noqa: DAR201
    # noqa: DAR401

    Note:
        This function was copied from the PyTorch repository (version 1.9).
        It was removed between torch 1.12.1 and torch 1.13.
    """
    if dilation is None:
        # For backward compatibility
        warn(
            "_grad_input_padding 'dilation' argument not provided. Default of 1 is used."
        )
        dilation = [1] * len(stride)

    input_size = list(input_size)
    k = grad_output.dim() - 2

    if len(input_size) == k + 2:
        input_size = input_size[-k:]
    if len(input_size) != k:
        raise ValueError(f"input_size must have {k+2} elements (got {len(input_size)})")

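    # ``dim_size`` inverts the convolution output-size formula: it is the smallest input
    # length along spatial dim ``d`` that produces ``grad_output.size(d + 2)`` outputs;
    # any input size up to that minimum plus ``stride[d] - 1`` yields the same output
    # size.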
    def dim_size(d):
        return (
            (grad_output.size(d + 2) - 1) * stride[d]
            - 2 * padding[d]
            + 1
            + dilation[d] * (kernel_size[d] - 1)
        )

    min_sizes = [dim_size(d) for d in range(k)]
    max_sizes = [min_sizes[d] + stride[d] - 1 for d in range(k)]
    for size, min_size, max_size in zip(input_size, min_sizes, max_sizes):
        if size < min_size or size > max_size:
            raise ValueError(
                f"requested an input grad size of {input_size}, but valid sizes range "
                f"from {min_sizes} to {max_sizes} (for a grad_output of "
                f"{grad_output.size()[2:]})"
            )

    return tuple(input_size[d] - min_sizes[d] for d in range(k))