• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

f-dangel / backpack / 8116256563

01 Mar 2024 07:29PM UTC coverage: 98.375%. Remained the same
8116256563

push

github

f-dangel
[DOC] Fix RTD build

4420 of 4493 relevant lines covered (98.38%)

13.73 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

100.0
/backpack/core/derivatives/conv_transposend.py
1
"""Partial derivatives for ``torch.nn.ConvTranspose{1,2,3}d``."""
2
from typing import List, Tuple, Union
14✔
3

4
from einops import rearrange
14✔
5
from numpy import prod
14✔
6
from torch import Tensor, einsum
14✔
7
from torch.nn import ConvTranspose1d, ConvTranspose2d, ConvTranspose3d, Module
14✔
8

9
from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
14✔
10
from backpack.utils import TORCH_VERSION_AT_LEAST_1_13
14✔
11
from backpack.utils.conv import get_conv_function
14✔
12
from backpack.utils.conv_transpose import (
14✔
13
    get_conv_transpose_function,
14
    unfold_by_conv_transpose,
15
)
16
from backpack.utils.subsampling import subsample
14✔
17

18
if TORCH_VERSION_AT_LEAST_1_13:
14✔
19
    from backpack.utils.conv import _grad_input_padding
6✔
20
else:
21
    from torch.nn.grad import _grad_input_padding
8✔
22

23

24
class ConvTransposeNDDerivatives(BaseParameterDerivatives):
    """Base class for partial derivatives of transpose convolution.

    Subclasses fix the number of spatial dimensions ``N`` (1, 2 or 3). The
    forward-mode helpers use the transpose convolution itself. The
    backward-mode helpers use the matching (non-transposed) convolution with
    the same weight, stride, padding, dilation and groups.
    """

    def __init__(self, N: int):
        """Store transpose convolution dimension and operations.

        Args:
            N: Transpose convolution dimension.
        """
        self.conv_func = get_conv_function(N)
        self.conv_transpose_func = get_conv_transpose_function(N)
        self.conv_dims = N

    def hessian_is_zero(self, module):
        """Return ``True``: the layer output is affine in its input.

        Args:
            module: The transpose convolution module (unused).

        Returns:
            Always ``True``.
        """
        return True

    def _bias_jac_t_mat_prod(
        self, module, g_inp, g_out, mat, sum_batch=True, subsampling=None
    ):
        """Multiply ``mat`` by the transposed Jacobian of the output w.r.t. the bias.

        The bias is added at every spatial output position. Its transposed
        Jacobian therefore sums ``mat`` over all trailing (spatial) axes, and
        over the batch axis if ``sum_batch`` is set.

        Args:
            module: The transpose convolution module (unused).
            g_inp: Gradients w.r.t. module inputs (unused).
            g_out: Gradients w.r.t. module outputs (unused).
            mat: Tensor whose first three axes are ``[V, N, C_out]``, followed by
                the spatial axes.
            sum_batch: Whether to also sum over the batch axis ``N``.
            subsampling: Not used by this method.

        Returns:
            Tensor of shape ``[V, C_out]`` if ``sum_batch``, else ``[V, N, C_out]``.
        """
        equation = f"vnc...->v{'' if sum_batch else 'n'}c"
        return einsum(equation, mat)

    def _bias_jac_mat_prod(self, module, g_inp, g_out, mat):
        """Multiply ``mat`` by the Jacobian of the output w.r.t. the bias.

        Each bias direction is broadcast across the batch and all spatial
        positions of the output.

        Args:
            module: The transpose convolution module (uses ``module.output``).
            g_inp: Gradients w.r.t. module inputs (unused).
            g_out: Gradients w.r.t. module outputs (unused).
            mat: Tensor of shape ``[V, C_out]``.

        Returns:
            Expanded (non-materialized) tensor of shape
            ``[V, N, C_out, *output_spatial]``.
        """
        # Insert a singleton batch axis: [V, 1, C_out].
        jac_mat = mat.unsqueeze(1)
        # Insert one singleton axis per spatial dimension of the output.
        for i in range(3, len(module.output.shape) + 1):
            jac_mat = jac_mat.unsqueeze(i)

        expand_shape = [-1, module.output.shape[0], -1, *module.output.shape[2:]]

        return jac_mat.expand(*expand_shape)

    def _weight_jac_mat_prod(self, module, g_inp, g_out, mat):
        """Multiply ``mat`` by the Jacobian of the output w.r.t. the weight.

        Args:
            module: The transpose convolution module.
            g_inp: Gradients w.r.t. module inputs (unused).
            g_out: Gradients w.r.t. module outputs (unused).
            mat: Tensor with leading axis ``V`` whose remaining entries per
                vector can be reshaped into ``module.weight``'s layout.

        Returns:
            Tensor reshaped like the module output with a leading ``V`` axis.
        """
        V = mat.shape[0]
        G = module.groups
        C_in = module.input0.shape[1]
        N = module.output.shape[0]
        C_out = module.output.shape[1]

        # Split the weight's channel axes by group: [V, G, C_in/G, C_out/G, *kernel].
        mat_reshape = mat.reshape(V, G, C_in // G, C_out // G, *module.weight.shape[2:])
        # Unfolded input: [N, G, C_in/G, *kernel, *output_spatial].
        u = unfold_by_conv_transpose(module.input0, module).reshape(
            N, G, C_in // G, *module.weight.shape[2:], *module.output.shape[2:]
        )

        dims_kern = "xyz"[: self.conv_dims]
        dims_data = "abc"[: self.conv_dims]
        # Contract over input channels within a group and over kernel positions.
        einstr = "ngi{0}{1},vgio{0}->vngo{1}".format(dims_kern, dims_data)
        jac_mat = einsum(einstr, u, mat_reshape)

        return self.reshape_like_output(jac_mat, module)

    def _weight_jac_t_mat_prod(
        self,
        module: Union[ConvTranspose1d, ConvTranspose2d, ConvTranspose3d],
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        mat: Tensor,
        sum_batch: bool = True,
        subsampling: List[int] = None,
    ) -> Tensor:
        """Multiply ``mat`` by the transposed Jacobian of the output w.r.t. the weight.

        Args:
            module: The transpose convolution module.
            g_inp: Gradients w.r.t. module inputs (unused).
            g_out: Gradients w.r.t. module outputs (unused).
            mat: Tensor with ``V * N * C_out * prod(output_spatial)`` entries, where
                ``N`` is the batch size or ``len(subsampling)``.
            sum_batch: Whether to sum the result over the batch axis.
            subsampling: Indices of the samples to use; ``None`` means all samples.

        Returns:
            Tensor of shape ``[V, *weight.shape]`` if ``sum_batch``, else
            ``[V, N, *weight.shape]``.
        """
        V = mat.shape[0]
        G = module.groups
        C_in = module.input0.shape[1]
        N = module.output.shape[0] if subsampling is None else len(subsampling)
        C_out = module.output.shape[1]

        mat_reshape = mat.reshape(V, N, G, C_out // G, *module.output.shape[2:])

        # Unfolded (sub-sampled) input: [N, G, C_in/G, *kernel, *output_spatial].
        u = unfold_by_conv_transpose(
            subsample(module.input0, subsampling=subsampling), module
        ).reshape(N, G, C_in // G, *module.weight.shape[2:], *module.output.shape[2:])

        dims_kern = "xyz"[: self.conv_dims]
        dims_data = "abc"[: self.conv_dims]
        result_str = ("vgio" if sum_batch else "vngio") + dims_kern
        equation = f"ngi{dims_kern}{dims_data},vngo{dims_data}->{result_str}"

        final_shape = (
            (V, *module.weight.shape) if sum_batch else (V, N, *module.weight.shape)
        )

        return einsum(equation, u, mat_reshape).reshape(final_shape)

    def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):
        """Compute ``J^T @ mat @ J`` for the per-sample input-output Jacobian ``J``.

        Args:
            module: The transpose convolution module.
            g_inp: Gradients w.r.t. module inputs (unused).
            g_out: Gradients w.r.t. module outputs (unused).
            mat: Tensor with ``out_features ** 2`` entries, where ``out_features``
                is the number of elements of one output sample.

        Returns:
            Tensor of shape ``[in_features, in_features]``, where ``in_features``
            is the number of elements of one input sample.
        """
        in_features = int(prod(module.input0.size()[1:]))
        out_features = int(prod(module.output.size()[1:]))

        # Treat each row of ``mat`` as an output-shaped tensor: rows of M @ J.
        mat = mat.reshape(out_features, *module.output.size()[1:])
        jac_t_mat = self.__jac_t(module, mat).reshape(out_features, in_features)

        # Apply J^T to each column of (M @ J): rows of J^T M^T J.
        mat_t_jac = jac_t_mat.t().reshape(in_features, *module.output.size()[1:])
        jac_t_mat_t_jac = self.__jac_t(module, mat_t_jac)
        jac_t_mat_t_jac = jac_t_mat_t_jac.reshape(in_features, in_features)

        return jac_t_mat_t_jac.t()

    def _jac_mat_prod(self, module, g_inp, g_out, mat):
        """Multiply ``mat`` by the Jacobian of the output w.r.t. the input.

        Args:
            module: The transpose convolution module.
            g_inp: Gradients w.r.t. module inputs (unused).
            g_out: Gradients w.r.t. module outputs (unused).
            mat: Tensor of shape ``[V, N, C_in, *input_spatial]``.

        Returns:
            Tensor reshaped like the module output with a leading ``V`` axis.
        """
        # Merge V and N so the whole stack runs through one convolution call.
        mat_as_conv = rearrange(mat, "v n c ... -> (v n) c ...")
        jmp_as_conv = self.__jac(module, mat_as_conv)
        return self.reshape_like_output(jmp_as_conv, module)

    def __jac(self, module, mat):
        """Apply the layer's (bias-free) transpose convolution to ``mat``.

        The output padding is recomputed so that the spatial size of the result
        matches ``module.output``.

        Args:
            module: The transpose convolution module.
            mat: Tensor of shape ``[B, C_in, *input_spatial]``.

        Returns:
            Tensor of shape ``[B, C_out, *output_spatial]``.
        """
        # Target shape of the result: the module output with batch size B.
        output_size = list(module.output.size())
        output_size[0] = mat.size(0)

        grad_padding = _grad_input_padding(
            grad_output=mat,
            input_size=output_size,
            stride=module.stride,
            padding=module.padding,
            kernel_size=module.kernel_size,
            dilation=module.dilation,
        )

        jac_mat = self.conv_transpose_func(
            input=mat,
            weight=module.weight,
            bias=None,
            stride=module.stride,
            padding=module.padding,
            output_padding=grad_padding,
            groups=module.groups,
            dilation=module.dilation,
        )

        return jac_mat

    def _jac_t_mat_prod(
        self,
        module: Module,
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        mat: Tensor,
        subsampling: List[int] = None,
    ) -> Tensor:
        """Multiply ``mat`` by the transposed Jacobian of the output w.r.t. the input.

        Args:
            module: The transpose convolution module.
            g_inp: Gradients w.r.t. module inputs (unused).
            g_out: Gradients w.r.t. module outputs (unused).
            mat: Tensor of shape ``[V, N, C_out, *output_spatial]``.
            subsampling: Indices of active samples, forwarded to
                ``reshape_like_input``; ``None`` means all samples.

        Returns:
            Tensor reshaped like the (sub-sampled) module input with a leading
            ``V`` axis.
        """
        # Merge V and N so the whole stack runs through one convolution call.
        mat_as_conv = rearrange(mat, "v n c ... -> (v n) c ...")
        jmp_as_conv = self.__jac_t(module, mat_as_conv)
        return self.reshape_like_input(jmp_as_conv, module, subsampling=subsampling)

    def __jac_t(self, module: Module, mat: Tensor) -> Tensor:
        """Apply the adjoint of the transpose convolution (a convolution) to ``mat``.

        Args:
            module: The transpose convolution module.
            mat: Tensor of shape ``[B, C_out, *output_spatial]``.

        Returns:
            Tensor of shape ``[B, C_in, *input_spatial]``.
        """
        jac_t = self.conv_func(
            mat,
            module.weight,
            bias=None,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            groups=module.groups,
        )

        # The convolution may yield extra trailing positions per spatial axis,
        # e.g. when the forward pass used output_padding >= stride (permitted if
        # dilation > stride). Trim them to match the layer input. Spatial axes
        # start at index 2, after the batch and channel axes.
        for dim in range(self.conv_dims):
            axis = dim + 2
            size = module.input0.shape[axis]
            jac_t = jac_t.narrow(axis, 0, size)

        return jac_t
14✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc