ControlNet / MARLIN · build 9011167809
09 May 2024 01:56AM UTC · coverage: 65.763% · first build

Pull Request #25: hotfix: add lightning dependancy for _cosine_scheduler_fn(..)
Merge 7c160e431 into 23491494b · github · web-flow

461 of 701 relevant lines covered (65.76%) · 0.66 hits per line

Source file: /src/marlin_pytorch/model/decoder.py (92.73% covered)

import torch
from einops import rearrange
from torch import nn, Tensor
from torch.nn import LayerNorm, Linear, ModuleList

from .modules import Block, no_grad_trunc_normal_
from .positional_embedding import SinCosPositionalEmbedding


class MarlinDecoder(nn.Module):

    def __init__(self, img_size=224, patch_size=16, n_frames=16, embed_dim=384, depth=8,
        num_heads=6, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
        norm_layer="LayerNorm", init_values=1., tubelet_size=2
    ):
        super().__init__()
        output_dim = 3 * tubelet_size * patch_size * patch_size
        self.patch_size = patch_size
        self.tubelet_size = tubelet_size
        self.n_patch_h = img_size // patch_size
        self.n_patch_w = img_size // patch_size
        self.embed_dim = embed_dim
        if norm_layer == "LayerNorm":
            self.norm_layer = LayerNorm
            self.norm = self.norm_layer(embed_dim)
        else:
            raise NotImplementedError("Only LayerNorm is supported")  # not covered

        # sine-cosine positional embeddings
        self.pos_embedding = SinCosPositionalEmbedding(
            (self.n_patch_h * self.n_patch_w * (n_frames // tubelet_size), embed_dim), dropout_rate=0.)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.blocks = ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=self.norm_layer,
                init_values=init_values
            ) for _ in range(depth)])

        self.head = Linear(embed_dim, output_dim)
        self.apply(self._init_weights)
        no_grad_trunc_normal_(self.mask_token, mean=0., std=0.02, a=-0.02, b=0.02)

    @staticmethod
    def _init_weights(m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def unpatch_to_img(self, x: Tensor) -> Tensor:
        # x: (Batch, No. batches, Prod of cube size * C)
        x = rearrange(x, "b n (c p) -> b n p c", c=3)  # not covered
        # x: (Batch, No. batches, Prod of cube size, C)
        x = rearrange(x, "b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2)", p0=self.tubelet_size,  # not covered
            p1=self.patch_size, p2=self.patch_size, h=self.n_patch_h, w=self.n_patch_w)
        # x: (B, C, T, H, W)
        return x  # not covered

    def forward_features(self, x, return_token_num=0):
        for block in self.blocks:
            x = block(x)

        if return_token_num > 0:
            x = x[:, -return_token_num:]

        x = self.norm(x)
        x = self.head(x)
        # x: (B, N_mask, C)
        return x

    def forward(self, x, mask):
        # mask: 0 -> masked, 1 -> visible
        b, n, c = x.shape
        expand_pos_embed = self.pos_embedding.emb.data.expand(b, -1, -1)
        pos_emb_vis = expand_pos_embed[mask].view(b, -1, c)
        pos_emb_mask = expand_pos_embed[~mask].view(b, -1, c)
        x = torch.cat([x + pos_emb_vis, self.mask_token + pos_emb_mask], dim=1)

        mask_num = pos_emb_mask.shape[1]

        x = self.forward_features(x, return_token_num=mask_num)
        return x
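
For context, forward(x, mask) only documents its mask convention in a one-line comment, so the following is a minimal usage sketch with the default constructor arguments. The import path, the batch and token shapes, and the random stand-in features are assumptions for illustration, not part of the PR; in MARLIN the visible tokens would come from the encoder.

import torch
from marlin_pytorch.model.decoder import MarlinDecoder  # import path assumed from the file location

decoder = MarlinDecoder()  # img_size=224, patch_size=16, n_frames=16, tubelet_size=2, embed_dim=384

b = 2
# Total tokens per clip: 14 * 14 spatial patches * (16 / 2) tubelets = 1568
n_total = decoder.n_patch_h * decoder.n_patch_w * (16 // decoder.tubelet_size)
n_visible = 160  # each sample must keep the same number of visible tokens for the .view() calls

# Boolean mask over all tokens: True (1) -> visible, False (0) -> masked
mask = torch.zeros(b, n_total, dtype=torch.bool)
mask[:, :n_visible] = True

# Stand-in for encoder output over the visible tokens only: (B, N_visible, embed_dim)
visible_tokens = torch.randn(b, n_visible, decoder.embed_dim)

out = decoder(visible_tokens, mask)
# out: (B, N_masked, 3 * tubelet_size * patch_size**2) = (2, 1408, 1536),
# i.e. one reconstructed pixel cube per masked token; unpatch_to_img expects
# the full n_total-token sequence when folding tokens back to (B, C, T, H, W).
print(out.shape)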