
cosanlab / py-feat / build 15090929758

19 Oct 2024 05:10AM UTC · coverage: 54.553% · first build
Trigger: push (GitHub, web-flow)
Commit: Merge pull request #228 from cosanlab/huggingface (WIP: Huggingface Integration)

702 of 1620 new or added lines in 46 files covered (43.33%)
3409 of 6249 relevant lines covered (54.55%)
3.27 hits per line

Source File: /feat/landmark_detectors/mobilefacenet_test.py (86.43% covered)
from torch.nn import (
    Linear,
    Conv2d,
    BatchNorm1d,
    BatchNorm2d,
    PReLU,
    Sequential,
    Module,
)
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

##################################  Original Arcface Model #############################################################


class Flatten(Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


##################################  MobileFaceNet #############################################################


class Conv_block(Module):
    def __init__(
        self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1
    ):
        super(Conv_block, self).__init__()
        self.conv = Conv2d(
            in_c,
            out_channels=out_c,
            kernel_size=kernel,
            groups=groups,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = BatchNorm2d(
            out_c
        )  # Here is another MPS issue where data is not float32
        self.prelu = PReLU(out_c)

        # Ensure BatchNorm parameters are float32
        self.bn.weight.data = self.bn.weight.data.float()
        self.bn.bias.data = self.bn.bias.data.float()
        self.bn.running_mean.data = self.bn.running_mean.data.float()
        self.bn.running_var.data = self.bn.running_var.data.float()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.prelu(x)
        return x

    @property
    def weight(self):
        return self.conv.weight

    @property
    def bias(self):
        return self.conv.bias


class Linear_block(Module):
    def __init__(
        self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1
    ):
        super(Linear_block, self).__init__()
        self.conv = Conv2d(
            in_c,
            out_channels=out_c,
            kernel_size=kernel,
            groups=groups,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = BatchNorm2d(out_c)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


class Depth_Wise(Module):
    def __init__(
        self,
        in_c,
        out_c,
        residual=False,
        kernel=(3, 3),
        stride=(2, 2),
        padding=(1, 1),
        groups=1,
    ):
        super(Depth_Wise, self).__init__()
        self.conv = Conv_block(
            in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)
        )
        self.conv_dw = Conv_block(
            groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride
        )
        self.project = Linear_block(
            groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1)
        )
        self.residual = residual

    def forward(self, x):
        if self.residual:
            short_cut = x
        x = self.conv(x)
        x = self.conv_dw(x)
        x = self.project(x)
        if self.residual:
            output = short_cut + x
        else:
            output = x
        return output
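

# A minimal shape-check sketch for Depth_Wise, assuming the same configuration that
# MobileFaceNet uses for conv_23 below; the 56x56 input resolution is an illustrative
# assumption, not a value taken from this file. The block expands in_c to `groups`
# channels with a 1x1 Conv_block, filters spatially with a depthwise Conv_block
# (groups=groups), then projects back to out_c with a 1x1 Linear_block.
def _depth_wise_shape_check():
    block = Depth_Wise(64, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128)
    dummy = torch.randn(1, 64, 56, 56)  # (batch, channels, height, width)
    out = block(dummy)
    assert out.shape == (1, 64, 28, 28)  # stride (2, 2) halves the spatial resolution
    return out.shape
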

class Residual(Module):
    def __init__(
        self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)
    ):
        super(Residual, self).__init__()
        modules = []
        for _ in range(num_block):
            modules.append(
                Depth_Wise(
                    c,
                    c,
                    residual=True,
                    kernel=kernel,
                    padding=padding,
                    stride=stride,
                    groups=groups,
                )
            )
        self.model = Sequential(*modules)

    def forward(self, x):
        return self.model(x)


class GNAP(Module):
    def __init__(self, embedding_size):
        super(GNAP, self).__init__()
        assert embedding_size == 512
        self.bn1 = BatchNorm2d(512, affine=False)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        self.bn2 = BatchNorm1d(512, affine=False)

    def forward(self, x):
        x = self.bn1(x)
        x_norm = torch.norm(x, 2, 1, True)
        x_norm_mean = torch.mean(x_norm)
        weight = x_norm_mean / x_norm
        x = x * weight
        x = self.pool(x)
        x = x.view(x.shape[0], -1)
        feature = self.bn2(x)
        return feature


class GDC(Module):
    def __init__(self, embedding_size):
        super(GDC, self).__init__()
        self.conv_6_dw = Linear_block(
            512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)
        )
        self.conv_6_flatten = Flatten()
        self.linear = Linear(512, embedding_size, bias=False)
        # self.bn = BatchNorm1d(embedding_size, affine=False)
        self.bn = BatchNorm1d(embedding_size)

    def forward(self, x):
        x = self.conv_6_dw(x)
        x = self.conv_6_flatten(x)
        x = self.linear(x)
        x = self.bn(x)
        return x

class MobileFaceNet(Module, PyTorchModelHubMixin):
    def __init__(self, input_size, embedding_size=512, output_name="GDC", device="cpu"):
        super(MobileFaceNet, self).__init__()
        # Make sure this module is compatible with mps
        self.device = device
        self.to(device)
        # torch.set_default_dtype(torch.float32) # Ensure default dtype is float32 for MPS compatibility

        assert output_name in ["GNAP", "GDC"]
        assert input_size[0] in [112]
        self.conv1 = Conv_block(3, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1))
        self.conv2_dw = Conv_block(
            64, 64, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64
        )
        self.conv_23 = Depth_Wise(
            64, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128
        )
        self.conv_3 = Residual(
            64, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        self.conv_34 = Depth_Wise(
            64, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256
        )
        self.conv_4 = Residual(
            128, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        self.conv_45 = Depth_Wise(
            128, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512
        )
        self.conv_5 = Residual(
            128, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        self.conv_6_sep = Conv_block(
            128, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0)
        )
        if output_name == "GNAP":
            self.output_layer = GNAP(512)
        else:
            self.output_layer = GDC(embedding_size)

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x):
        # Ensure this module is compatible with mps by moving the input to the
        # model's device and casting to float32 in a single step
        x = x.to(self.device).float()

        out = self.conv1(x)

        out = self.conv2_dw(out)

        out = self.conv_23(out)

        out = self.conv_3(out)

        out = self.conv_34(out)

        out = self.conv_4(out)

        out = self.conv_45(out)

        out = self.conv_5(out)

        conv_features = self.conv_6_sep(out)
        out = self.output_layer(conv_features)
        return out, conv_features
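

# A minimal usage sketch, assuming a 112x112 RGB input; the batch size, the local save
# directory, and the commented-out reload call are illustrative assumptions, not part of
# this file. MobileFaceNet returns both the embedding from the output head and the
# 7x7 feature map produced by conv_6_sep.
if __name__ == "__main__":
    model = MobileFaceNet(input_size=[112, 112], embedding_size=512, output_name="GDC")
    model.eval()  # eval mode avoids BatchNorm batch-size constraints during a quick check

    dummy = torch.randn(2, 3, 112, 112)
    with torch.no_grad():
        embeddings, conv_features = model(dummy)
    print(embeddings.shape)  # torch.Size([2, 512])
    print(conv_features.shape)  # torch.Size([2, 512, 7, 7])

    # Because the class also inherits from PyTorchModelHubMixin, the standard
    # huggingface_hub helpers are available for saving and (re)loading weights.
    model.save_pretrained("mobilefacenet-local")  # hypothetical local directory
    # reloaded = MobileFaceNet.from_pretrained("mobilefacenet-local")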