ControlNet / MARLIN / 9011167809

09 May 2024 01:56AM UTC coverage: 65.763%. First build
Merge 7c160e431 into 23491494b
Pull Request #25: hotfix: add lightning dependency for _cosine_scheduler_fn(..)

461 of 701 relevant lines covered (65.76%)

0.66 hits per line

Source File

/src/marlin_pytorch/model/modules.py (83.94% covered)
import math
import warnings
from typing import Union, Optional, Callable, Tuple, List, Sequence

import torch
from einops.layers.torch import Rearrange
from torch import Tensor, nn, Size
from torch.nn import Conv3d, ModuleList
from torch.nn import functional as F

Shape = Union[Size, List[int], Tuple[int, ...]]
ModuleFactory = Union[Callable[[], nn.Module], Callable[[int], nn.Module]]


class PatchEmbedding3d(nn.Module):

    def __init__(self, input_size: Shape, patch_size: Union[int, Shape], embedding: int,
        strides: Optional[Union[int, Shape]] = None,
        build_normalization: Optional[ModuleFactory] = None
    ):
        super().__init__()
        # channel, time, height, width
        c, t, h, w = input_size
        # patch_time, patch_height, patch_width
        pt, ph, pw = (patch_size, patch_size, patch_size) if type(patch_size) is int else patch_size

        # configure the strides for conv3d
        if strides is None:
            # if not specified, patches have no overlap and no gap between them
            strides = (pt, ph, pw)
        elif type(strides) is int:
            # transform the side length of strides to 3D
            strides = (strides, strides, strides)

        self.projection = Conv3d(c, embedding, kernel_size=(pt, ph, pw), stride=strides)
        self.has_norm = build_normalization is not None
        if self.has_norm:
            self.normalization = build_normalization()
        self.rearrange = Rearrange("b d nt nh nw -> b (nt nh nw) d")

    def forward(self, x: Tensor) -> Tensor:
        x = self.projection(x)
        x = self.rearrange(x)
        if self.has_norm:
            x = self.normalization(x)
        return x


class Linear(nn.Module):

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
        build_activation: Optional[ModuleFactory] = None,
        build_normalization: Optional[ModuleFactory] = None,
        normalization_after_activation: bool = False,
        dropout_rate: float = 0.
    ):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias)

        self.has_act = build_activation is not None
        if self.has_act:
            self.activation = build_activation()
        else:
            self.activation = None

        self.has_norm = build_normalization is not None
        if self.has_norm:
            self.normalization = build_normalization()
            self.norm_after_act = normalization_after_activation
        else:
            self.normalization = None

        self.has_dropout = dropout_rate > 0
        if self.has_dropout:
            self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x: Tensor) -> Tensor:
        x = self.linear(x)
        if self.has_act and self.has_norm:
            if self.norm_after_act:
                x = self.activation(x)
                x = self.normalization(x)
            else:
                x = self.normalization(x)
                x = self.activation(x)
        elif self.has_act and not self.has_norm:
            x = self.activation(x)
        elif not self.has_act and self.has_norm:
            x = self.normalization(x)

        if self.has_dropout:
            x = self.dropout(x)
        return x


class MLP(nn.Module):

    def __init__(self, neurons: Sequence[int],
        build_activation: Optional[ModuleFactory] = None, dropout_rate: float = 0.
    ):
        super().__init__()
        n_features = neurons[1:]
        self.layers: ModuleList[Linear] = ModuleList(
            [Linear(neurons[i], neurons[i + 1], True, build_activation, None,
                False, dropout_rate
            ) for i in range(len(n_features) - 1)
            ] + [
                Linear(neurons[-2], neurons[-1], True)
            ]
        )

    def forward(self, x: Tensor) -> Tensor:
        for layer in self.layers:
            x = layer(x)
        return x


class Attention(nn.Module):

    def __init__(
        self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
        proj_drop=0., attn_head_dim=None
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # a single projection produces q, k and v; the q and v biases are added separately in forward()
        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            # learnable biases for q and v only; the key bias stays fixed at zero
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
        init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
        attn_head_dim=None
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(
            neurons=[dim, mlp_hidden_dim, dim],
            build_activation=act_layer,
            dropout_rate=drop
        )

        # guard against the default init_values=None, for which `init_values > 0` would raise a TypeError
        if init_values is not None and init_values > 0:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x):
        if self.gamma_1 is None:
            x = x + self.attn(self.norm1(x))
            x = x + self.mlp(self.norm2(x))
        else:
            x = x + (self.gamma_1 * self.attn(self.norm1(x)))
            x = x + (self.gamma_2 * self.mlp(self.norm2(x)))
        return x


def no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
            stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor