• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

cosanlab / py-feat / 15090929758

19 Oct 2024 05:10AM UTC coverage: 54.553%. First build
15090929758

push

github

web-flow
Merge pull request #228 from cosanlab/huggingface

WIP: Huggingface Integration

702 of 1620 new or added lines in 46 files covered. (43.33%)

3409 of 6249 relevant lines covered (54.55%)

3.27 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

83.81
/feat/emo_detectors/ResMaskNet/resmasknet_test.py
1
"""
2
All code & models from https://github.com/phamquiluan/ResidualMaskingNetwork
3
"""
4

5
# from lib2to3.pytree import convert
6
import os
6✔
7
import json
6✔
8
import numpy as np
6✔
9
import torch
6✔
10
from torchvision.transforms import (
6✔
11
    Resize,
12
    Grayscale,
13
    Compose,
14
)
15
import traceback
6✔
16
import torch.nn as nn
6✔
17
from feat.utils import set_torch_device
6✔
18
from feat.utils.io import get_resource_path
6✔
19
from feat.utils.image_operations import BBox
6✔
20
from huggingface_hub import PyTorchModelHubMixin
6✔
21

22
model_urls = {
6✔
23
    "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
24
    "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
25
    "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
26
    "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
27
    "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
28
    "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
29
    "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
30
    "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
31
    "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
32
}
33

34

35
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
6✔
36
    """3x3 convolution with padding"""
37
    return nn.Conv2d(
6✔
38
        in_planes,
39
        out_planes,
40
        kernel_size=3,
41
        stride=stride,
42
        padding=dilation,
43
        groups=groups,
44
        bias=False,
45
        dilation=dilation,
46
    )
47

48

49
def conv1x1(in_planes, out_planes, stride=1):
6✔
50
    """1x1 convolution"""
51
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
6✔
52

53

54
class BasicBlock(nn.Module):
6✔
55
    expansion = 1
6✔
56
    __constants__ = ["downsample"]
6✔
57

58
    def __init__(
6✔
59
        self,
60
        inplanes,
61
        planes,
62
        stride=1,
63
        downsample=None,
64
        groups=1,
65
        base_width=64,
66
        dilation=1,
67
        norm_layer=None,
68
    ):
69
        super(BasicBlock, self).__init__()
6✔
70
        if norm_layer is None:
6✔
71
            norm_layer = nn.BatchNorm2d
6✔
72
        if groups != 1 or base_width != 64:
6✔
73
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
×
74
        if dilation > 1:
6✔
75
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
×
76
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
77
        self.conv1 = conv3x3(inplanes, planes, stride)
6✔
78
        self.bn1 = norm_layer(planes)
6✔
79
        self.relu = nn.ReLU(inplace=True)
6✔
80
        self.conv2 = conv3x3(planes, planes)
6✔
81
        self.bn2 = norm_layer(planes)
6✔
82
        self.downsample = downsample
6✔
83
        self.stride = stride
6✔
84

85
    def forward(self, x):
6✔
86
        identity = x
6✔
87

88
        out = self.conv1(x)
6✔
89
        out = self.bn1(out)
6✔
90
        out = self.relu(out)
6✔
91

92
        out = self.conv2(out)
6✔
93
        out = self.bn2(out)
6✔
94

95
        if self.downsample is not None:
6✔
96
            identity = self.downsample(x)
6✔
97

98
        out += identity
6✔
99
        out = self.relu(out)
6✔
100

101
        return out
6✔
102

103

104
class ResNet(nn.Module):
6✔
105
    def __init__(
6✔
106
        self,
107
        block,
108
        layers,
109
        num_classes=1000,
110
        zero_init_residual=False,
111
        groups=1,
112
        width_per_group=64,
113
        replace_stride_with_dilation=None,
114
        norm_layer=None,
115
        in_channels=3,
116
    ):
117
        super(ResNet, self).__init__()
6✔
118
        if norm_layer is None:
6✔
119
            norm_layer = nn.BatchNorm2d
6✔
120
        self._norm_layer = norm_layer
6✔
121

122
        self.inplanes = 64
6✔
123
        self.dilation = 1
6✔
124
        if replace_stride_with_dilation is None:
6✔
125
            # each element in the tuple indicates if we should replace
126
            # the 2x2 stride with a dilated convolution instead
127
            replace_stride_with_dilation = [False, False, False]
6✔
128
        if len(replace_stride_with_dilation) != 3:
6✔
129
            raise ValueError(
×
130
                "replace_stride_with_dilation should be None "
131
                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
132
            )
133
        self.groups = groups
6✔
134
        self.base_width = width_per_group
6✔
135

136
        # NOTE: strictly set the in_channels = 3 to load the pretrained model
137
        self.conv1 = nn.Conv2d(
6✔
138
            in_channels, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False
139
        )
140
        self.bn1 = norm_layer(self.inplanes)
6✔
141
        self.relu = nn.ReLU(inplace=True)
6✔
142
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
6✔
143
        self.layer1 = self._make_layer(block, 64, layers[0])
6✔
144
        self.layer2 = self._make_layer(
6✔
145
            block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]
146
        )
147
        self.layer3 = self._make_layer(
6✔
148
            block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]
149
        )
150
        self.layer4 = self._make_layer(
6✔
151
            block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]
152
        )
153
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
6✔
154

155
        # NOTE: strictly set the num_classes = 1000 to load the pretrained model
156
        self.fc = nn.Linear(512 * block.expansion, num_classes)
6✔
157

158
        for m in self.modules():
6✔
159
            if isinstance(m, nn.Conv2d):
6✔
160
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
6✔
161
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
6✔
162
                nn.init.constant_(m.weight, 1)
6✔
163
                nn.init.constant_(m.bias, 0)
6✔
164

165
        # Zero-initialize the last BN in each residual branch,
166
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
167
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
168
        if zero_init_residual:
6✔
169
            for m in self.modules():
×
170
                if isinstance(m, BasicBlock):
×
171
                    nn.init.constant_(m.bn2.weight, 0)
×
172

173
    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
6✔
174
        norm_layer = self._norm_layer
6✔
175
        downsample = None
6✔
176
        previous_dilation = self.dilation
6✔
177
        if dilate:
6✔
178
            self.dilation *= stride
×
179
            stride = 1
×
180
        if stride != 1 or self.inplanes != planes * block.expansion:
6✔
181
            downsample = nn.Sequential(
6✔
182
                conv1x1(self.inplanes, planes * block.expansion, stride),
183
                norm_layer(planes * block.expansion),
184
            )
185

186
        layers = []
6✔
187
        layers.append(
6✔
188
            block(
189
                self.inplanes,
190
                planes,
191
                stride,
192
                downsample,
193
                self.groups,
194
                self.base_width,
195
                previous_dilation,
196
                norm_layer,
197
            )
198
        )
199
        self.inplanes = planes * block.expansion
6✔
200
        for _ in range(1, blocks):
6✔
201
            layers.append(
6✔
202
                block(
203
                    self.inplanes,
204
                    planes,
205
                    groups=self.groups,
206
                    base_width=self.base_width,
207
                    dilation=self.dilation,
208
                    norm_layer=norm_layer,
209
                )
210
            )
211

212
        return nn.Sequential(*layers)
6✔
213

214
    def forward(self, x):
6✔
215
        x = self.conv1(x)
×
216
        x = self.bn1(x)
×
217
        x = self.relu(x)
×
218
        x = self.maxpool(x)
×
219

220
        x = self.layer1(x)
×
221
        x = self.layer2(x)
×
222
        x = self.layer3(x)
×
223
        x = self.layer4(x)
×
224

225
        x = self.avgpool(x)
×
226
        x = torch.flatten(x, 1)
×
227
        x = self.fc(x)
×
228

229
        return x
×
230

231

232
###################### masking
233

234

235
# from .resnet import conv1x1, conv3x3, BasicBlock, Bottleneck
236

237

238
def up_pooling(in_channels, out_channels, kernel_size=2, stride=2):
6✔
239
    return nn.Sequential(
6✔
240
        nn.ConvTranspose2d(
241
            in_channels, out_channels, kernel_size=kernel_size, stride=stride
242
        ),
243
        nn.BatchNorm2d(out_channels),
244
        nn.ReLU(inplace=True),
245
    )
246

247

248
class Masking4(nn.Module):
6✔
249
    def __init__(self, in_channels, out_channels, block=BasicBlock):
6✔
250
        assert in_channels == out_channels
6✔
251
        super(Masking4, self).__init__()
6✔
252
        filters = [
6✔
253
            in_channels,
254
            in_channels * 2,
255
            in_channels * 4,
256
            in_channels * 8,
257
            in_channels * 16,
258
        ]
259

260
        self.downsample1 = nn.Sequential(
6✔
261
            conv1x1(filters[0], filters[1], 1),
262
            nn.BatchNorm2d(filters[1]),
263
        )
264

265
        self.downsample2 = nn.Sequential(
6✔
266
            conv1x1(filters[1], filters[2], 1),
267
            nn.BatchNorm2d(filters[2]),
268
        )
269

270
        self.downsample3 = nn.Sequential(
6✔
271
            conv1x1(filters[2], filters[3], 1),
272
            nn.BatchNorm2d(filters[3]),
273
        )
274

275
        self.downsample4 = nn.Sequential(
6✔
276
            conv1x1(filters[3], filters[4], 1),
277
            nn.BatchNorm2d(filters[4]),
278
        )
279

280
        """
6✔
281
        self.conv1 = block(filters[0], filters[1], downsample=conv1x1(filters[0], filters[1], 1))
282
        self.conv2 = block(filters[1], filters[2], downsample=conv1x1(filters[1], filters[2], 1))
283
        self.conv3 = block(filters[2], filters[3], downsample=conv1x1(filters[2], filters[3], 1))
284
        """
285

286
        self.conv1 = block(filters[0], filters[1], downsample=self.downsample1)
6✔
287
        self.conv2 = block(filters[1], filters[2], downsample=self.downsample2)
6✔
288
        self.conv3 = block(filters[2], filters[3], downsample=self.downsample3)
6✔
289
        self.conv4 = block(filters[3], filters[4], downsample=self.downsample4)
6✔
290

291
        self.down_pooling = nn.MaxPool2d(kernel_size=2)
6✔
292

293
        self.downsample5 = nn.Sequential(
6✔
294
            conv1x1(filters[4], filters[3], 1),
295
            nn.BatchNorm2d(filters[3]),
296
        )
297

298
        self.downsample6 = nn.Sequential(
6✔
299
            conv1x1(filters[3], filters[2], 1),
300
            nn.BatchNorm2d(filters[2]),
301
        )
302

303
        self.downsample7 = nn.Sequential(
6✔
304
            conv1x1(filters[2], filters[1], 1),
305
            nn.BatchNorm2d(filters[1]),
306
        )
307

308
        self.downsample8 = nn.Sequential(
6✔
309
            conv1x1(filters[1], filters[0], 1),
310
            nn.BatchNorm2d(filters[0]),
311
        )
312

313
        """
6✔
314
        self.up_pool4 = up_pooling(filters[3], filters[2])
315
        self.conv4 = block(filters[3], filters[2], downsample=conv1x1(filters[3], filters[2], 1))
316
        self.up_pool5 = up_pooling(filters[2], filters[1])
317
        self.conv5 = block(filters[2], filters[1], downsample=conv1x1(filters[2], filters[1], 1))
318

319
        self.conv6 = block(filters[1], filters[0], downsample=conv1x1(filters[1], filters[0], 1))
320
        """
321

322
        self.up_pool5 = up_pooling(filters[4], filters[3])
6✔
323
        self.conv5 = block(filters[4], filters[3], downsample=self.downsample5)
6✔
324
        self.up_pool6 = up_pooling(filters[3], filters[2])
6✔
325
        self.conv6 = block(filters[3], filters[2], downsample=self.downsample6)
6✔
326
        self.up_pool7 = up_pooling(filters[2], filters[1])
6✔
327
        self.conv7 = block(filters[2], filters[1], downsample=self.downsample7)
6✔
328
        self.conv8 = block(filters[1], filters[0], downsample=self.downsample8)
6✔
329

330
        # init weight
331
        for m in self.modules():
6✔
332
            if isinstance(m, nn.Conv2d):
6✔
333
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
6✔
334
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
6✔
335
                nn.init.constant_(m.weight, 1)
6✔
336
                nn.init.constant_(m.bias, 0)
6✔
337

338
        # Zero-initialize the last BN in each residual branch,
339
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
340
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
341
        for m in self.modules():
6✔
342
            if isinstance(m, BasicBlock):
6✔
343
                nn.init.constant_(m.bn2.weight, 0)
6✔
344

345
    def forward(self, x):
6✔
346
        x1 = self.conv1(x)
6✔
347
        p1 = self.down_pooling(x1)
6✔
348
        x2 = self.conv2(p1)
6✔
349
        p2 = self.down_pooling(x2)
6✔
350
        x3 = self.conv3(p2)
6✔
351
        p3 = self.down_pooling(x3)
6✔
352
        x4 = self.conv4(p3)
6✔
353

354
        x5 = self.up_pool5(x4)
6✔
355
        x5 = torch.cat([x5, x3], dim=1)
6✔
356
        x5 = self.conv5(x5)
6✔
357

358
        x6 = self.up_pool6(x5)
6✔
359
        x6 = torch.cat([x6, x2], dim=1)
6✔
360
        x6 = self.conv6(x6)
6✔
361

362
        x7 = self.up_pool7(x6)
6✔
363
        x7 = torch.cat([x7, x1], dim=1)
6✔
364
        x7 = self.conv7(x7)
6✔
365

366
        x8 = self.conv8(x7)
6✔
367

368
        output = torch.softmax(x8, dim=1)
6✔
369
        # output = torch.sigmoid(x8)
370
        return output
6✔
371

372

373
class Masking3(nn.Module):
6✔
374
    def __init__(self, in_channels, out_channels, block=BasicBlock):
6✔
375
        assert in_channels == out_channels
6✔
376
        super(Masking3, self).__init__()
6✔
377
        filters = [in_channels, in_channels * 2, in_channels * 4, in_channels * 8]
6✔
378

379
        self.downsample1 = nn.Sequential(
6✔
380
            conv1x1(filters[0], filters[1], 1),
381
            nn.BatchNorm2d(filters[1]),
382
        )
383

384
        self.downsample2 = nn.Sequential(
6✔
385
            conv1x1(filters[1], filters[2], 1),
386
            nn.BatchNorm2d(filters[2]),
387
        )
388

389
        self.downsample3 = nn.Sequential(
6✔
390
            conv1x1(filters[2], filters[3], 1),
391
            nn.BatchNorm2d(filters[3]),
392
        )
393

394
        """
6✔
395
        self.conv1 = block(filters[0], filters[1], downsample=conv1x1(filters[0], filters[1], 1))
396
        self.conv2 = block(filters[1], filters[2], downsample=conv1x1(filters[1], filters[2], 1))
397
        self.conv3 = block(filters[2], filters[3], downsample=conv1x1(filters[2], filters[3], 1))
398
        """
399

400
        self.conv1 = block(filters[0], filters[1], downsample=self.downsample1)
6✔
401
        self.conv2 = block(filters[1], filters[2], downsample=self.downsample2)
6✔
402
        self.conv3 = block(filters[2], filters[3], downsample=self.downsample3)
6✔
403

404
        self.down_pooling = nn.MaxPool2d(kernel_size=2)
6✔
405

406
        self.downsample4 = nn.Sequential(
6✔
407
            conv1x1(filters[3], filters[2], 1),
408
            nn.BatchNorm2d(filters[2]),
409
        )
410

411
        self.downsample5 = nn.Sequential(
6✔
412
            conv1x1(filters[2], filters[1], 1),
413
            nn.BatchNorm2d(filters[1]),
414
        )
415

416
        self.downsample6 = nn.Sequential(
6✔
417
            conv1x1(filters[1], filters[0], 1),
418
            nn.BatchNorm2d(filters[0]),
419
        )
420

421
        """
6✔
422
        self.up_pool4 = up_pooling(filters[3], filters[2])
423
        self.conv4 = block(filters[3], filters[2], downsample=conv1x1(filters[3], filters[2], 1))
424
        self.up_pool5 = up_pooling(filters[2], filters[1])
425
        self.conv5 = block(filters[2], filters[1], downsample=conv1x1(filters[2], filters[1], 1))
426

427
        self.conv6 = block(filters[1], filters[0], downsample=conv1x1(filters[1], filters[0], 1))
428
        """
429

430
        self.up_pool4 = up_pooling(filters[3], filters[2])
6✔
431
        self.conv4 = block(filters[3], filters[2], downsample=self.downsample4)
6✔
432
        self.up_pool5 = up_pooling(filters[2], filters[1])
6✔
433
        self.conv5 = block(filters[2], filters[1], downsample=self.downsample5)
6✔
434

435
        self.conv6 = block(filters[1], filters[0], downsample=self.downsample6)
6✔
436

437
        # init weight
438
        for m in self.modules():
6✔
439
            if isinstance(m, nn.Conv2d):
6✔
440
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
6✔
441
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
6✔
442
                nn.init.constant_(m.weight, 1)
6✔
443
                nn.init.constant_(m.bias, 0)
6✔
444

445
        # Zero-initialize the last BN in each residual branch,
446
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
447
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
448
        for m in self.modules():
6✔
449
            if isinstance(m, BasicBlock):
6✔
450
                nn.init.constant_(m.bn2.weight, 0)
6✔
451

452
    def forward(self, x):
6✔
453
        x1 = self.conv1(x)
6✔
454
        p1 = self.down_pooling(x1)
6✔
455
        x2 = self.conv2(p1)
6✔
456
        p2 = self.down_pooling(x2)
6✔
457
        x3 = self.conv3(p2)
6✔
458

459
        x4 = self.up_pool4(x3)
6✔
460
        x4 = torch.cat([x4, x2], dim=1)
6✔
461

462
        x4 = self.conv4(x4)
6✔
463

464
        x5 = self.up_pool5(x4)
6✔
465
        x5 = torch.cat([x5, x1], dim=1)
6✔
466
        x5 = self.conv5(x5)
6✔
467

468
        x6 = self.conv6(x5)
6✔
469

470
        output = torch.softmax(x6, dim=1)
6✔
471
        # output = torch.sigmoid(x6)
472
        return output
6✔
473

474

475
class Masking2(nn.Module):
6✔
476
    def __init__(self, in_channels, out_channels, block=BasicBlock):
6✔
477
        assert in_channels == out_channels
6✔
478
        super(Masking2, self).__init__()
6✔
479
        filters = [in_channels, in_channels * 2, in_channels * 4, in_channels * 8]
6✔
480

481
        self.downsample1 = nn.Sequential(
6✔
482
            conv1x1(filters[0], filters[1], 1),
483
            nn.BatchNorm2d(filters[1]),
484
        )
485

486
        self.downsample2 = nn.Sequential(
6✔
487
            conv1x1(filters[1], filters[2], 1),
488
            nn.BatchNorm2d(filters[2]),
489
        )
490

491
        """
6✔
492
        self.conv1 = block(filters[0], filters[1], downsample=conv1x1(filters[0], filters[1], 1))
493
        self.conv2 = block(filters[1], filters[2], downsample=conv1x1(filters[1], filters[2], 1))
494
        """
495
        self.conv1 = block(filters[0], filters[1], downsample=self.downsample1)
6✔
496
        self.conv2 = block(filters[1], filters[2], downsample=self.downsample2)
6✔
497

498
        self.down_pooling = nn.MaxPool2d(kernel_size=2)
6✔
499

500
        self.downsample3 = nn.Sequential(
6✔
501
            conv1x1(filters[2], filters[1], 1),
502
            nn.BatchNorm2d(filters[1]),
503
        )
504

505
        self.downsample4 = nn.Sequential(
6✔
506
            conv1x1(filters[1], filters[0], 1),
507
            nn.BatchNorm2d(filters[0]),
508
        )
509

510
        """
6✔
511
        self.up_pool3 = up_pooling(filters[2], filters[1])
512
        self.conv3 = block(filters[2], filters[1], downsample=conv1x1(filters[2], filters[1], 1))
513
        self.conv4 = block(filters[1], filters[0], downsample=conv1x1(filters[1], filters[0], 1))
514
        """
515
        self.up_pool3 = up_pooling(filters[2], filters[1])
6✔
516
        self.conv3 = block(filters[2], filters[1], downsample=self.downsample3)
6✔
517
        self.conv4 = block(filters[1], filters[0], downsample=self.downsample4)
6✔
518

519
        # init weight
520
        for m in self.modules():
6✔
521
            if isinstance(m, nn.Conv2d):
6✔
522
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
6✔
523
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
6✔
524
                nn.init.constant_(m.weight, 1)
6✔
525
                nn.init.constant_(m.bias, 0)
6✔
526

527
        # Zero-initialize the last BN in each residual branch,
528
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
529
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
530
        for m in self.modules():
6✔
531
            if isinstance(m, BasicBlock):
6✔
532
                nn.init.constant_(m.bn2.weight, 0)
6✔
533

534
    def forward(self, x):
6✔
535
        x1 = self.conv1(x)
6✔
536
        p1 = self.down_pooling(x1)
6✔
537
        x2 = self.conv2(p1)
6✔
538

539
        x3 = self.up_pool3(x2)
6✔
540
        x3 = torch.cat([x3, x1], dim=1)
6✔
541
        x3 = self.conv3(x3)
6✔
542

543
        x4 = self.conv4(x3)
6✔
544

545
        output = torch.softmax(x4, dim=1)
6✔
546
        # output = torch.sigmoid(x4)
547
        return output
6✔
548

549

550
class Masking1(nn.Module):
6✔
551
    def __init__(self, in_channels, out_channels, block=BasicBlock):
6✔
552
        assert in_channels == out_channels
6✔
553
        super(Masking1, self).__init__()
6✔
554
        filters = [in_channels, in_channels * 2, in_channels * 4, in_channels * 8]
6✔
555

556
        self.downsample1 = nn.Sequential(
6✔
557
            conv1x1(filters[0], filters[1], 1),
558
            nn.BatchNorm2d(filters[1]),
559
        )
560

561
        self.conv1 = block(filters[0], filters[1], downsample=self.downsample1)
6✔
562

563
        self.downsample2 = nn.Sequential(
6✔
564
            conv1x1(filters[1], filters[0], 1),
565
            nn.BatchNorm2d(filters[0]),
566
        )
567

568
        self.conv2 = block(filters[1], filters[0], downsample=self.downsample2)
6✔
569

570
        # init weight
571
        for m in self.modules():
6✔
572
            if isinstance(m, nn.Conv2d):
6✔
573
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
6✔
574
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
6✔
575
                nn.init.constant_(m.weight, 1)
6✔
576
                nn.init.constant_(m.bias, 0)
6✔
577

578
        # Zero-initialize the last BN in each residual branch,
579
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
580
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
581
        for m in self.modules():
6✔
582
            if isinstance(m, BasicBlock):
6✔
583
                nn.init.constant_(m.bn2.weight, 0)
6✔
584

585
    def forward(self, x):
6✔
586
        x1 = self.conv1(x)
6✔
587
        x2 = self.conv2(x1)
6✔
588
        output = torch.softmax(x2, dim=1)
6✔
589
        # output = torch.sigmoid(x2)
590
        return output
6✔
591

592

593
def masking(in_channels, out_channels, depth, block=BasicBlock):
6✔
594
    if depth == 1:
6✔
595
        return Masking1(in_channels, out_channels, block)
6✔
596
    elif depth == 2:
6✔
597
        return Masking2(in_channels, out_channels, block)
6✔
598
    elif depth == 3:
6✔
599
        return Masking3(in_channels, out_channels, block)
6✔
600
    elif depth == 4:
6✔
601
        return Masking4(in_channels, out_channels, block)
6✔
602
    else:
603
        traceback.print_exc()
×
604
        raise Exception("depth need to be from 0-3")
×
605

606

607
#######################
608
# from .resnet import conv1x1, conv3x3, BasicBlock, Bottleneck
609

610

611
class ResMasking(ResNet, PyTorchModelHubMixin):
6✔
612
    def __init__(self, weight_path, in_channels=3):
6✔
613
        super(ResMasking, self).__init__(
6✔
614
            block=BasicBlock,
615
            layers=[3, 4, 6, 3],
616
            in_channels=in_channels,
617
            num_classes=1000,
618
        )
619
        # state_dict = torch.load('saved/checkpoints/resnet18_rot30_2019Nov05_17.44')['net']
620
        # state_dict = load_state_dict_from_url(model_urls['resnet34'], progress=True)
621
        # self.load_state_dict(state_dict)
622

623
        self.fc = nn.Linear(512, 7)
6✔
624

625
        """
6✔
626
        # freeze all net
627
        for m in self.parameters():
628
            m.requires_grad = False
629
        """
630

631
        self.mask1 = masking(64, 64, depth=4)
6✔
632
        self.mask2 = masking(128, 128, depth=3)
6✔
633
        self.mask3 = masking(256, 256, depth=2)
6✔
634
        self.mask4 = masking(512, 512, depth=1)
6✔
635

636
    def forward(self, x):  # 224
6✔
637
        x = self.conv1(x)  # 112
6✔
638
        x = self.bn1(x)
6✔
639
        x = self.relu(x)
6✔
640
        x = self.maxpool(x)  # 56
6✔
641

642
        x = self.layer1(x)  # 56
6✔
643
        m = self.mask1(x)
6✔
644
        x = x * (1 + m)
6✔
645
        # x = x * m
646

647
        x = self.layer2(x)  # 28
6✔
648
        m = self.mask2(x)
6✔
649
        x = x * (1 + m)
6✔
650
        # x = x * m
651

652
        x = self.layer3(x)  # 14
6✔
653
        m = self.mask3(x)
6✔
654
        x = x * (1 + m)
6✔
655
        # x = x * m
656

657
        x = self.layer4(x)  # 7
6✔
658
        m = self.mask4(x)
6✔
659
        x = x * (1 + m)
6✔
660
        # x = x * m
661

662
        x = self.avgpool(x)
6✔
663
        x = torch.flatten(x, 1)
6✔
664

665
        x = self.fc(x)
6✔
666
        return x
6✔
667

668

669
# def resmasking(in_channels, num_classes, weight_path=""):
670
#     return ResMasking(weight_path)
671

672

673
def resmasking_dropout1(in_channels=3, num_classes=7, weight_path=""):
6✔
674
    model = ResMasking(weight_path, in_channels=in_channels)
×
675
    model.fc = nn.Sequential(nn.Dropout(0.4), nn.Linear(512, num_classes))
×
676
    return model
×
677

678

679
###########################
680

681

682
class ResMaskNet:
6✔
683
    def __init__(self, device="auto", pretrained="huggingface"):
6✔
684
        """Initialize ResMaskNet
685

686
        @misc{luanresmaskingnet2020,
687
        Author = {Luan Pham & Tuan Anh Tran},
688
        Title = {Facial Expression Recognition using Residual Masking Network},
689
        url = {https://github.com/phamquiluan/ResidualMaskingNetwork},
690
        Year = {2020}
691
        }
692

693
        """
694

695
        self.device = set_torch_device(device)
×
696

697
        self.FER_2013_EMO_DICT = {
×
698
            0: "angry",
699
            1: "disgust",
700
            2: "fear",
701
            3: "happy",
702
            4: "sad",
703
            5: "surprise",
704
            6: "neutral",
705
        }
706

707
        # load configs and set random seed
708
        configs = json.load(
×
709
            open(os.path.join(get_resource_path(), "ResMaskNet_fer2013_config.json"))
710
        )
711
        self.image_size = (configs["image_size"], configs["image_size"])
×
712

713
        self.model = resmasking_dropout1(in_channels=3, num_classes=7)
×
714

NEW
715
        if pretrained == "huggingface":
×
NEW
716
            self.model.from_pretrained("py-feat/resmasknet")
×
NEW
717
        elif pretrained == "local":
×
NEW
718
            self.model.load_state_dict(
×
719
                torch.load(
720
                    os.path.join(
721
                        get_resource_path(),
722
                        "ResMaskNet_Z_resmasking_dropout1_rot30.pth",
723
                    ),
724
                    map_location=self.device,
725
                )["net"]
726
            )
727
        self.model.eval()
×
728

729
    def detect_emo(self, frame, detected_face, *args, **kwargs):
6✔
730
        """Detect emotions.
731
        Args:
732
            frame ([type]): [description]
733
        Returns:
734
            List of predicted emotions in probability: [angry, disgust, fear, happy, sad, surprise, neutral]
735
        """
736

737
        face = self._batch_make(frame=frame, detected_face=detected_face)
×
738
        with torch.no_grad():
×
739
            output = self.model(face)
×
740
            proba = torch.softmax(output, 1)
×
741
            proba_np = proba.cpu().numpy()
×
742
            return proba_np
×
743

744
    def _batch_make(self, frame, detected_face, *args, **kwargs):
6✔
745
        transform = Compose([Grayscale(3)])
×
746
        gray = transform(frame)
×
747

748
        len_index = [len(aa) for aa in detected_face]
×
749
        length_cumu = np.cumsum(len_index)
×
750
        flat_faces = [item for sublist in detected_face for item in sublist]
×
751

752
        concat_batch = None
×
753
        for i, face in enumerate(flat_faces):
×
754
            frame_choice = np.where(i < length_cumu)[0][0]
×
755
            #     frame0 = np.fliplr(frame[frame_choice]).astype(np.uint8) # not sure why we need to flip the face
756
            bbox = BBox(face[:-1])
×
757
            face = (
×
758
                bbox.expand_by_factor(1.1)
759
                .extract_from_image(gray[frame_choice])
760
                .unsqueeze(0)
761
            )
762
            transform = Resize(self.image_size)
×
763
            face = transform(face) / 255
×
764
            if concat_batch is None:
×
765
                concat_batch = face
×
766
            else:
767
                concat_batch = torch.cat((concat_batch, face), 0)
×
768

769
        return concat_batch
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc