• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ControlNet / MARLIN / 9011167809

09 May 2024 01:56AM CUT coverage: 65.763%. First build
9011167809

Pull #25

github

web-flow
Merge 7c160e431 into 23491494b
Pull Request #25: hotfix: add lightning dependancy for _cosine_scheduler_fn(..)

461 of 701 relevant lines covered (65.76%)

0.66 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.48
/src/marlin_pytorch/config.py
1
from abc import ABC, abstractmethod
1✔
2
from dataclasses import dataclass
1✔
3
from typing import Optional, Type, TypeVar
1✔
4

5
from marlin_pytorch.util import read_yaml, Singleton, NoArgInit
1✔
6

7

8
@dataclass
class MarlinConfig:
    """Architecture hyper-parameters for a MARLIN encoder/decoder model.

    The dataclass ``__init__`` takes the fields positionally in the order
    declared below, so that order is part of the public constructor and
    must not be rearranged.
    """

    # Input geometry.
    img_size: int
    patch_size: int
    n_frames: int

    # Encoder stack.
    encoder_embed_dim: int
    encoder_depth: int
    encoder_num_heads: int

    # Decoder stack.
    decoder_embed_dim: int
    decoder_depth: int
    decoder_num_heads: int

    # Transformer block options; qk_scale may be None.
    mlp_ratio: float
    qkv_bias: bool
    qk_scale: Optional[float]
    drop_rate: float
    attn_drop_rate: float
    norm_layer: str  # a layer name given as a string, e.g. "LayerNorm"
    init_values: float

    # NOTE: kept last; moving it would change the positional __init__ order.
    tubelet_size: int
1✔
27

28

29
class Downloadable(ABC):
    """Interface for configs that publish download URLs.

    It exposes one URL for the full model file and one for the encoder-only
    file. A subclass becomes instantiable once it overrides both members;
    a plain class-level string attribute is enough to do so.
    """

    @property
    @abstractmethod
    def full_model_url(self) -> str:
        """URL of the full model file."""

    @property
    @abstractmethod
    def encoder_model_url(self) -> str:
        """URL of the encoder-only model file."""
×
40

41

42
# Type variable for the classes accepted by ``register_model``'s decorator.
T = TypeVar("T", bound=MarlinConfig)

# Module-level registry: model name -> registered config object.
# Written by register_model / register_model_from_yaml, read by resolve_config.
_configs = dict()
1✔
45

46

47
def register_model(name: str):
    """Build a decorator that records the decorated object under ``name``.

    The decorator stores its argument in the module registry and returns it
    unchanged, so it can be stacked with other decorators. Registering the
    same ``name`` twice silently replaces the earlier entry.
    """

    def _record(config_cls: Type[T]) -> Type[T]:
        _configs[name] = config_cls
        return config_cls

    return _record
1✔
53

54

55
class SharedConfig(MarlinConfig):
    """Hyper-parameter values shared by the built-in MARLIN ViT variants.

    These are plain class attributes, not dataclass defaults, because this
    subclass is not itself decorated with ``@dataclass``. The per-variant
    classes add the encoder/decoder sizes and download URLs on top.
    """

    # Input geometry.
    img_size = 224
    patch_size = 16
    n_frames = 16

    # Transformer block options.
    mlp_ratio = 4.0
    qkv_bias = True
    qk_scale = None
    drop_rate = 0.0
    attn_drop_rate = 0.0
    norm_layer = "LayerNorm"
    init_values = 0.0

    tubelet_size = 2
1✔
67

68

69
@register_model("marlin_vit_base_ytf")
@Singleton  # NOTE(review): Singleton comes from marlin_pytorch.util; its effect is not visible here
class MarlinVitBaseConfig(NoArgInit, SharedConfig, Downloadable):
    """Base-size MARLIN ViT config, registered as ``"marlin_vit_base_ytf"``."""

    # Encoder: embed dim 768, depth 12, 12 heads.
    encoder_embed_dim = 768
    encoder_depth = 12
    encoder_num_heads = 12

    # Decoder: embed dim 384, depth 4, 6 heads.
    decoder_embed_dim = 384
    decoder_depth = 4
    decoder_num_heads = 6

    # GitHub release assets; these override Downloadable's abstract properties.
    full_model_url = (
        "https://github.com/ControlNet/MARLIN/releases/download/model_v1/marlin_vit_base_ytf.full.pt"
    )
    encoder_model_url = (
        "https://github.com/ControlNet/MARLIN/releases/download/model_v1/marlin_vit_base_ytf.encoder.pt"
    )
1✔
80

81

82
@register_model("marlin_vit_small_ytf")
@Singleton  # NOTE(review): Singleton comes from marlin_pytorch.util; its effect is not visible here
class MarlinVitSmallConfig(NoArgInit, SharedConfig, Downloadable):
    """Small-size MARLIN ViT config, registered as ``"marlin_vit_small_ytf"``."""

    # Encoder: embed dim 384, depth 12, 6 heads.
    encoder_embed_dim = 384
    encoder_depth = 12
    encoder_num_heads = 6

    # Decoder: embed dim 192, depth 4, 3 heads.
    decoder_embed_dim = 192
    decoder_depth = 4
    decoder_num_heads = 3

    # GitHub release assets; these override Downloadable's abstract properties.
    full_model_url = (
        "https://github.com/ControlNet/MARLIN/releases/download/model_v1/marlin_vit_small_ytf.full.pt"
    )
    encoder_model_url = (
        "https://github.com/ControlNet/MARLIN/releases/download/model_v1/marlin_vit_small_ytf.encoder.pt"
    )
95

96

97
@register_model("marlin_vit_large_ytf")
@Singleton  # NOTE(review): Singleton comes from marlin_pytorch.util; its effect is not visible here
class MarlinVitLargeConfig(NoArgInit, SharedConfig, Downloadable):
    """Large-size MARLIN ViT config, registered as ``"marlin_vit_large_ytf"``."""

    # Encoder: embed dim 1024, depth 24, 16 heads.
    encoder_embed_dim = 1024
    encoder_depth = 24
    encoder_num_heads = 16

    # Decoder: embed dim 512, depth 12, 8 heads.
    decoder_embed_dim = 512
    decoder_depth = 12
    decoder_num_heads = 8

    # GitHub release assets; these override Downloadable's abstract properties.
    full_model_url = (
        "https://github.com/ControlNet/MARLIN/releases/download/model_v1/marlin_vit_large_ytf.full.pt"
    )
    encoder_model_url = (
        "https://github.com/ControlNet/MARLIN/releases/download/model_v1/marlin_vit_large_ytf.encoder.pt"
    )
110

111

112
def register_model_from_yaml(name: str, path: str) -> None:
    """Build a ``MarlinConfig`` from the YAML file at ``path`` and register it as ``name``.

    Keys read from the parsed YAML (presumably a nested dict; see
    ``read_yaml``):

    * top level: ``img_size``, ``patch_size``, ``clip_frames`` (stored as
      ``n_frames``), ``mlp_ratio``, ``qkv_bias``, ``qk_scale``,
      ``drop_rate``, ``attn_drop_rate``, ``norm_layer``, ``init_values``,
      ``tubelet_size``;
    * ``encoder`` and ``decoder`` sub-mappings, each with ``embed_dim``,
      ``depth`` and ``num_heads``.

    The resulting ``MarlinConfig`` instance itself is stored in the registry.
    An existing entry for ``name`` is silently replaced.
    """
    cfg = read_yaml(path)
    # A dict display evaluates left to right, so the YAML keys are looked up
    # in field order; the first missing key is the one that gets reported.
    field_values = {
        "img_size": cfg["img_size"],
        "patch_size": cfg["patch_size"],
        "n_frames": cfg["clip_frames"],  # the YAML calls this clip_frames
        "encoder_embed_dim": cfg["encoder"]["embed_dim"],
        "encoder_depth": cfg["encoder"]["depth"],
        "encoder_num_heads": cfg["encoder"]["num_heads"],
        "decoder_embed_dim": cfg["decoder"]["embed_dim"],
        "decoder_depth": cfg["decoder"]["depth"],
        "decoder_num_heads": cfg["decoder"]["num_heads"],
        "mlp_ratio": cfg["mlp_ratio"],
        "qkv_bias": cfg["qkv_bias"],
        "qk_scale": cfg["qk_scale"],
        "drop_rate": cfg["drop_rate"],
        "attn_drop_rate": cfg["attn_drop_rate"],
        "norm_layer": cfg["norm_layer"],
        "init_values": cfg["init_values"],
        "tubelet_size": cfg["tubelet_size"],
    }
    _configs[name] = MarlinConfig(**field_values)
×
134

135

136
def resolve_config(name: str) -> MarlinConfig:
    """Return the config registered under ``name``.

    Args:
        name: key previously passed to ``register_model`` or
            ``register_model_from_yaml``.

    Returns:
        The object stored in the registry for ``name``.

    Raises:
        ValueError: if nothing is registered under ``name``. The message lists
            the registered names, or ``(none)`` if the registry is empty.
    """
    try:
        return _configs[name]
    except KeyError:
        # Print plain names instead of the ``dict_keys([...])`` repr. ``str``
        # guards against non-string keys so the join cannot itself fail.
        # ``from None`` keeps the internal KeyError out of the traceback.
        registered = ", ".join(sorted(map(str, _configs))) or "(none)"
        raise ValueError(
            f"Model {name} not found. Please register it first. "
            f"The current registered models are: {registered}"
        ) from None
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc