ControlNet / MARLIN · build 9011167809
09 May 2024 01:56AM CUT · coverage: 65.763% · first build

Pull Request #25: hotfix: add lightning dependency for _cosine_scheduler_fn(..)
Merge 7c160e431 into 23491494b (via github · web-flow)

461 of 701 relevant lines covered (65.76%)

0.66 hits per line

Source File

/src/marlin_pytorch/model/marlin.py — 51.05% of lines covered

import os.path
import shutil
from collections import deque
from pathlib import Path
from typing import Generator, Optional
from urllib.request import urlretrieve

import cv2
import ffmpeg
import numpy as np
import torch
from einops import rearrange
from torch import Tensor
from torch.nn import Linear, Module

from ..config import resolve_config, Downloadable
from ..face_detector import FaceXZooFaceDetector

from .decoder import MarlinDecoder
from .encoder import MarlinEncoder
from ..util import read_video, padding_video, DownloadProgressBar


class Marlin(Module):

    def __init__(self,
        img_size: int,
        patch_size: int,
        n_frames: int,
        encoder_embed_dim: int,
        encoder_depth: int,
        encoder_num_heads: int,
        decoder_embed_dim: int,
        decoder_depth: int,
        decoder_num_heads: int,
        mlp_ratio: float,
        qkv_bias: bool,
        qk_scale: Optional[float],
        drop_rate: float,
        attn_drop_rate: float,
        norm_layer: str,
        init_values: float,
        tubelet_size: int,
        as_feature_extractor: bool = True,
    ):
        super().__init__()
        self.encoder = MarlinEncoder(
            img_size=img_size,
            patch_size=patch_size,
            n_frames=n_frames,
            embed_dim=encoder_embed_dim,
            depth=encoder_depth,
            num_heads=encoder_num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            norm_layer=norm_layer,
            init_values=init_values,
            tubelet_size=tubelet_size
        )
        self.as_feature_extractor = as_feature_extractor
        self.clip_frames = n_frames
        if as_feature_extractor:
            self.enc_dec_proj = None
            self.decoder = None
        else:
            self.decoder = MarlinDecoder(
                img_size=img_size,
                patch_size=patch_size,
                embed_dim=decoder_embed_dim,
                depth=decoder_depth,
                num_heads=decoder_num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop_rate=drop_rate,
                attn_drop_rate=attn_drop_rate,
                norm_layer=norm_layer,
                init_values=init_values,
                tubelet_size=tubelet_size
            )

            self.enc_dec_proj = Linear(encoder_embed_dim, decoder_embed_dim, bias=False)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        if self.as_feature_extractor:
            raise RuntimeError("For feature extraction, please use `extract_features` or `extract_video`.")
        else:
            assert mask is not None
            x = self.encoder(x, mask)
            x = self.enc_dec_proj(x)
            x = self.decoder(x, mask)
        return x

    @property
    def device(self):
        return self.encoder.norm.weight.device

    def extract_features(self, x: Tensor, keep_seq: bool = True):
        """Extract features for one video clip (v)"""
        if self.training:
            return self.encoder.extract_features(x, seq_mean_pool=not keep_seq)
        else:
            with torch.no_grad():
                return self.encoder.extract_features(x, seq_mean_pool=not keep_seq)

    def _crop_face(self, v: Tensor) -> Tensor:
        # use face sdk to crop face
        # v: (1, C, T, H, W)
        v = (rearrange(v, "b c t h w -> (b t) h w c").cpu().numpy() * 255).astype(np.uint8)
        face_frames = []
        for i in range(v.shape[0]):
            # crop_face result: (H, W, C)
            face_frames.append(torch.from_numpy(FaceXZooFaceDetector.crop_face(v[i])[0]))

        faces = torch.stack(face_frames)  # (T, H, W, C)
        return rearrange(faces, "(b t) h w c -> b c t h w", b=1).to(self.device) / 255

    @torch.no_grad()
    def extract_video(self, video_path: str, crop_face: bool = False, sample_rate: int = 2,
        stride: int = 16,
        reduction: str = "none",
        keep_seq: bool = False,
        detector_device: Optional[str] = None
    ) -> Tensor:
        self.eval()
        features = []
        for v in self._load_video(video_path, sample_rate, stride):
            # v: (1, C, T, H, W)
            if crop_face:
                if not FaceXZooFaceDetector.inited:
                    Path(".marlin").mkdir(exist_ok=True)
                    FaceXZooFaceDetector.init(
                        face_sdk_path=FaceXZooFaceDetector.install(os.path.join(".marlin", "FaceXZoo")),
                        device=detector_device or self.device
                    )
                v = self._crop_face(v)
            assert v.shape[3:] == (224, 224)
            features.append(self.extract_features(v, keep_seq=keep_seq))

        features = torch.cat(features)  # (N, 768)

        if reduction == "mean":
            return features.mean(dim=0)
        elif reduction == "max":
            return features.max(dim=0)[0]

        return features

    def _load_video(self, video_path: str, sample_rate: int, stride: int) -> Generator[Tensor, None, None]:
        probe = ffmpeg.probe(video_path)
        total_frames = int(probe["streams"][0]["nb_frames"])
        if total_frames <= self.clip_frames:
            video = read_video(video_path, channel_first=True) / 255  # (T, C, H, W)
            # pad frames to 16
            v = padding_video(video, self.clip_frames, "same")  # (T, C, H, W)
            assert v.shape[0] == self.clip_frames
            yield v.permute(1, 0, 2, 3).unsqueeze(0).to(self.device)
        elif total_frames <= self.clip_frames * sample_rate:
            video = read_video(video_path, channel_first=True) / 255  # (T, C, H, W)
            # use first 16 frames
            if video.shape[0] < self.clip_frames:
                # double-check the number of frames, see https://github.com/pytorch/vision/issues/2490
                v = padding_video(video, self.clip_frames, "same")  # (T, C, H, W)
            v = video[:self.clip_frames]
            yield v.permute(1, 0, 2, 3).unsqueeze(0).to(self.device)
        else:
            # extract features based on sliding window
            cap = cv2.VideoCapture(video_path)
            deq = deque(maxlen=self.clip_frames)

            clip_start_indexes = list(range(0, total_frames - self.clip_frames * sample_rate, stride * sample_rate))
            clip_end_indexes = [i + self.clip_frames * sample_rate - 1 for i in clip_start_indexes]

            current_index = -1
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                current_index += 1
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = torch.from_numpy(frame).permute(2, 0, 1) / 255  # (C, H, W)

                for _ in range(sample_rate - 1):
                    cap.read()
                    current_index += 1

                deq.append(frame)
                if current_index in clip_end_indexes:
                    v = torch.stack(list(deq))  # (T, C, H, W)
                    yield v.permute(1, 0, 2, 3).unsqueeze(0).to(self.device)

            cap.release()

    @classmethod
    def from_file(cls, model_name: str, path: str) -> "Marlin":
        if path.endswith(".pt"):
            state_dict = torch.load(path, map_location="cpu")
        elif path.endswith(".ckpt"):
            state_dict = torch.load(path, map_location="cpu")["state_dict"]

            discriminator_keys = [k for k in state_dict.keys() if k.startswith("discriminator")]
            for key in discriminator_keys:
                del state_dict[key]
        else:
            raise ValueError(f"Unsupported file type: {path.split('.')[-1]}")
        # determine if the checkpoint is full model or encoder only.
        for key in state_dict.keys():
            if key.startswith("decoder."):
                as_feature_extractor = False
                break
        else:
            as_feature_extractor = True

        config = resolve_config(model_name)
        model = cls(
            img_size=config.img_size,
            patch_size=config.patch_size,
            n_frames=config.n_frames,
            encoder_embed_dim=config.encoder_embed_dim,
            encoder_depth=config.encoder_depth,
            encoder_num_heads=config.encoder_num_heads,
            decoder_embed_dim=config.decoder_embed_dim,
            decoder_depth=config.decoder_depth,
            decoder_num_heads=config.decoder_num_heads,
            mlp_ratio=config.mlp_ratio,
            qkv_bias=config.qkv_bias,
            qk_scale=config.qk_scale,
            drop_rate=config.drop_rate,
            attn_drop_rate=config.attn_drop_rate,
            norm_layer=config.norm_layer,
            init_values=config.init_values,
            tubelet_size=config.tubelet_size,
            as_feature_extractor=as_feature_extractor
        )
        model.load_state_dict(state_dict)
        return model

    @classmethod
    def from_online(cls, model_name: str, full_model: bool = False) -> "Marlin":
        config = resolve_config(model_name)
        if not isinstance(config, Downloadable):
            raise ValueError(f"Model {model_name} is not downloadable.")

        url = config.full_model_url if full_model else config.encoder_model_url
        path = Path(".marlin")
        path.mkdir(exist_ok=True)
        file = path / f"{model_name}.{'full' if full_model else 'encoder'}.pt"
        if not file.exists():
            with DownloadProgressBar(unit="B", unit_scale=True, miniters=1, desc="Downloading Marlin model") as pb:
                urlretrieve(url, filename=file, reporthook=pb.update_to)
        return cls.from_file(model_name, str(file))

    @classmethod
    def clean_cache(cls, verbose: bool = True) -> None:
        path = Path(".marlin")
        if path.exists():
            shutil.rmtree(path)
            if verbose:
                print("Marlin checkpoints cache cleaned.")
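
For context, a minimal usage sketch of the class listed above. This is an illustration only, not part of the covered file: the model name "marlin_vit_base_ytf" and the video path are placeholders, and the import path is inferred from the file location /src/marlin_pytorch/model/marlin.py rather than taken from this report.

# Hypothetical usage sketch (not part of the covered source).
# "marlin_vit_base_ytf" is a placeholder; pass any name that resolve_config() accepts.
from marlin_pytorch.model.marlin import Marlin

# Download an encoder-only checkpoint into the local .marlin cache (or reuse a cached one).
model = Marlin.from_online("marlin_vit_base_ytf", full_model=False)

# Extract one feature vector per sliding-window clip. With crop_face=False the input
# video is assumed to already contain 224x224 face crops (see the assert in extract_video).
features = model.extract_video("face_clip.mp4", crop_face=False, reduction="none")
print(features.shape)  # (num_clips, encoder_embed_dim)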