cosanlab / py-feat / build 15090929758

19 Oct 2024 05:10AM UTC · coverage: 54.553% · first build
Push via GitHub (web-flow): Merge pull request #228 from cosanlab/huggingface ("WIP: Huggingface Integration")

702 of 1620 new or added lines in 46 files covered (43.33%)
3409 of 6249 relevant lines covered (54.55%)
3.27 hits per line
Source File: /feat/pretrained.py (file coverage: 17.99%)

"""
Helper functions specifically for working with included pre-trained models
"""

from feat.face_detectors.FaceBoxes.FaceBoxes_test import FaceBoxes
from feat.face_detectors.Retinaface.Retinaface_test import Retinaface
from feat.face_detectors.MTCNN.MTCNN_test import MTCNN
from feat.landmark_detectors.basenet_test import MobileNet_GDConv
from feat.landmark_detectors.pfld_compressed_test import PFLDInference
from feat.landmark_detectors.mobilefacenet_test import MobileFaceNet
from feat.au_detectors.StatLearning.SL_test import SVMClassifier, XGBClassifier
from feat.emo_detectors.ResMaskNet.resmasknet_test import ResMaskNet
from feat.emo_detectors.StatLearning.EmoSL_test import (
    EmoSVMClassifier,
)
from feat.facepose_detectors.img2pose.deps.models import FasterDoFRCNN
from feat.identity_detectors.facenet.facenet_test import Facenet
from feat.utils.io import get_resource_path, download_url
import os
import json
import pickle
from skops.io import load, get_untrusted_types
from huggingface_hub import hf_hub_download
import xgboost as xgb

__all__ = ["get_pretrained_models", "fetch_model", "load_model_weights"]
# Currently supported pre-trained detectors
PRETRAINED_MODELS = {
    "face_model": [
        {"retinaface": Retinaface},
        {"faceboxes": FaceBoxes},
        {"mtcnn": MTCNN},
        {"img2pose": FasterDoFRCNN},
        {"img2pose-c": FasterDoFRCNN},
    ],
    "landmark_model": [
        {"mobilenet": MobileNet_GDConv},
        {"mobilefacenet": MobileFaceNet},
        {"pfld": PFLDInference},
    ],
    "au_model": [{"svm": SVMClassifier}, {"xgb": XGBClassifier}],
    "emotion_model": [
        {"resmasknet": ResMaskNet},
        {"svm": EmoSVMClassifier},
    ],
    "facepose_model": [
        {"img2pose": FasterDoFRCNN},
        {"img2pose-c": FasterDoFRCNN},
    ],
    "identity_model": [{"facenet": Facenet}],
}
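
# Illustrative note (for clarity): each entry above maps a lowercase model name
# to its constructor class, so the supported names for a given detector type can
# be listed with a comprehension, e.g.
#   [list(e.keys())[0] for e in PRETRAINED_MODELS["au_model"]]  # -> ["svm", "xgb"]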

# Compatibility support for OpenFace, which uses different AU names than py-feat
AU_LANDMARK_MAP = {
    "OpenFace": [
        "AU01_r",
        "AU02_r",
        "AU04_r",
        "AU05_r",
        "AU06_r",
        "AU07_r",
        "AU09_r",
        "AU10_r",
        "AU12_r",
        "AU14_r",
        "AU15_r",
        "AU17_r",
        "AU20_r",
        "AU23_r",
        "AU25_r",
        "AU26_r",
        "AU45_r",
    ],
    "Feat": [
        "AU01",
        "AU02",
        "AU04",
        "AU05",
        "AU06",
        "AU07",
        "AU09",
        "AU10",
        "AU11",
        "AU12",
        "AU14",
        "AU15",
        "AU17",
        "AU20",
        "AU23",
        "AU24",
        "AU25",
        "AU26",
        "AU28",
        "AU43",
    ],
}
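
# Note (illustrative): the OpenFace list uses intensity-style column names (the
# "_r" suffix) while the Feat list uses bare AU codes, and the two lists differ
# in length (17 vs. 20 AUs), so these are reference name sets rather than a
# one-to-one mapping.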


def get_pretrained_models(
    face_model,
    landmark_model,
    au_model,
    emotion_model,
    facepose_model,
    identity_model,
    verbose,
):
    """Helper function that validates the requested model names and downloads them if
    necessary using the URLs in the included JSON file. Used by detector init"""

    # Get supported model URLs
    with open(os.path.join(get_resource_path(), "model_list.json"), "r") as f:
        model_urls = json.load(f)

    def get_names(model_type):
        # Supported names for a given detector type, e.g. ["svm", "xgb"] for "au_model"
        return [list(e.keys())[0] for e in PRETRAINED_MODELS[model_type]]

    # Face model
    if face_model is None:
        raise ValueError(
            f"face_model must be one of {[list(e.keys())[0] for e in PRETRAINED_MODELS['face_model']]}"
        )
    else:
        face_model = face_model.lower()
        if face_model not in get_names("face_model"):
            raise ValueError(
                f"Requested face_model was {face_model}. Must be one of {[list(e.keys())[0] for e in PRETRAINED_MODELS['face_model']]}"
            )
        for url in model_urls["face_detectors"][face_model]["urls"]:
            download_url(url, get_resource_path(), verbose=verbose)

    # Landmark model
    if landmark_model is None:
        raise ValueError(
            f"landmark_model must be one of {[list(e.keys())[0] for e in PRETRAINED_MODELS['landmark_model']]}"
        )
    else:
        landmark_model = landmark_model.lower()
        if landmark_model not in get_names("landmark_model"):
            raise ValueError(
                f"Requested landmark_model was {landmark_model}. Must be one of {[list(e.keys())[0] for e in PRETRAINED_MODELS['landmark_model']]}"
            )
        for url in model_urls["landmark_detectors"][landmark_model]["urls"]:
            download_url(url, get_resource_path(), verbose=verbose)

    # AU model
    if au_model is None:
        raise ValueError(
            f"au_model must be one of {[list(e.keys())[0] for e in PRETRAINED_MODELS['au_model']]}"
        )
    else:
        au_model = au_model.lower()
        if au_model not in get_names("au_model"):
            raise ValueError(
                f"Requested au_model was {au_model}. Must be one of {[list(e.keys())[0] for e in PRETRAINED_MODELS['au_model']]}"
            )

        for url in model_urls["au_detectors"][au_model]["urls"]:
            download_url(url, get_resource_path(), verbose=verbose)
            if au_model in ["xgb", "svm"]:
                # The statistical-learning AU models also need the six HOG-PCA
                # scaler/PCA weight files
                for hog_pca_url in model_urls["au_detectors"]["hog-pca"]["urls"][:6]:
                    download_url(hog_pca_url, get_resource_path(), verbose=verbose)

    # Emotion model
    if emotion_model is None:
        raise ValueError(
            f"emotion_model must be one of {[list(e.keys())[0] for e in PRETRAINED_MODELS['emotion_model']]}"
        )
    else:
        emotion_model = emotion_model.lower()
        if emotion_model not in get_names("emotion_model"):
            raise ValueError(
                f"Requested emotion_model was {emotion_model}. Must be one of {[list(e.keys())[0] for e in PRETRAINED_MODELS['emotion_model']]}"
            )
        for url in model_urls["emotion_detectors"][emotion_model]["urls"]:
            download_url(url, get_resource_path(), verbose=verbose)
            if emotion_model in ["svm"]:
                # The SVM emotion model also needs its PCA and scaler weight files
                download_url(
                    model_urls["emotion_detectors"]["emo_pca"]["urls"][0],
                    get_resource_path(),
                    verbose=verbose,
                )
                download_url(
                    model_urls["emotion_detectors"]["emo_scalar"]["urls"][0],
                    get_resource_path(),
                    verbose=verbose,
                )

    # Facepose model
    if facepose_model is None:
        raise ValueError(
            f"facepose_model must be one of {[list(e.keys())[0] for e in PRETRAINED_MODELS['facepose_model']]}"
        )
    else:
        facepose_model = facepose_model.lower()
        if facepose_model not in get_names("facepose_model"):
            raise ValueError(
                f"Requested facepose_model was {facepose_model}. Must be one of {[list(e.keys())[0] for e in PRETRAINED_MODELS['facepose_model']]}"
            )
        for url in model_urls["facepose_detectors"][facepose_model]["urls"]:
            download_url(url, get_resource_path(), verbose=verbose)

    # Face Identity model
    if identity_model is None:
        raise ValueError(
            f"identity_model must be one of {[list(e.keys())[0] for e in PRETRAINED_MODELS['identity_model']]}"
        )
    else:
        identity_model = identity_model.lower()
        if identity_model not in get_names("identity_model"):
            raise ValueError(
                f"Requested identity_model was {identity_model}. Must be one of {[list(e.keys())[0] for e in PRETRAINED_MODELS['identity_model']]}"
            )
        for url in model_urls["identity_detectors"][identity_model]["urls"]:
            download_url(url, get_resource_path(), verbose=verbose)

    return (
        face_model,
        landmark_model,
        au_model,
        emotion_model,
        facepose_model,
        identity_model,
    )
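
# Illustrative usage (a sketch, not part of the original module): the Detector
# constructor is expected to call this helper with the user-requested names,
# roughly as
#   face, landmark, au, emo, pose, ident = get_pretrained_models(
#       "retinaface", "mobilefacenet", "xgb", "resmasknet", "img2pose", "facenet",
#       True,
#   )
# which validates each name against PRETRAINED_MODELS, downloads the associated
# weight files into get_resource_path(), and returns the lowercased names.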


def fetch_model(model_type, model_name):
    """Fetch a pre-trained model class constructor. Used by detector init"""
    if model_name is None:
        raise ValueError(f"{model_type} must be a valid string model name, not None")
    model_type = PRETRAINED_MODELS[model_type]
    matches = list(filter(lambda e: model_name in e.keys(), model_type))[0]
    return list(matches.values())[0]
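
# Illustrative example (for clarity): given the PRETRAINED_MODELS table above,
#   fetch_model("au_model", "xgb")
# returns the XGBClassifier class itself (not an instance), which the caller can
# then construct; an IndexError is raised if the name has no match.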


def load_classifier_pkl(cf_path):
    """Load a pickled scaler/PCA/classifier object from disk."""
    with open(cf_path, "rb") as f:
        return pickle.load(f)


def load_model_weights(model_type="au", model="xgb", location="huggingface"):
    """Load pre-trained weights for the AU and emotion models, either from the
    Hugging Face Hub or from the local resources folder."""
    if model_type == "au":
        if model == "xgb":
            if location == "huggingface":
                # Load the entire model from skops serialized file
                model_path = hf_hub_download(
                    repo_id="py-feat/xgb_au",
                    filename="xgb_au_classifier.skops",
                    cache_dir=get_resource_path(),
                )
                unknown_types = get_untrusted_types(file=model_path)
                loaded_model = load(model_path, trusted=unknown_types)
                return {
                    "scaler_upper": loaded_model.scaler_upper,
                    "pca_model_upper": loaded_model.pca_model_upper,
                    "scaler_lower": loaded_model.scaler_lower,
                    "pca_model_lower": loaded_model.pca_model_lower,
                    "scaler_full": loaded_model.scaler_full,
                    "pca_model_full": loaded_model.pca_model_full,
                    "au_classifiers": loaded_model.classifiers,
                }
            elif location == "local":
                # Load weights from local Resources folder
                scaler_upper = load_classifier_pkl(
                    os.path.join(get_resource_path(), "all_data_Upperscalar_June30.pkl")
                )
                pca_model_upper = load_classifier_pkl(
                    os.path.join(get_resource_path(), "all_data_Upperpca_June30.pkl")
                )
                scaler_lower = load_classifier_pkl(
                    os.path.join(get_resource_path(), "all_data_Lowerscalar_June30.pkl")
                )
                pca_model_lower = load_classifier_pkl(
                    os.path.join(get_resource_path(), "all_data_Lowerpca_June30.pkl")
                )
                scaler_full = load_classifier_pkl(
                    os.path.join(get_resource_path(), "all_data_Fullscalar_June30.pkl")
                )
                pca_model_full = load_classifier_pkl(
                    os.path.join(get_resource_path(), "all_data_Fullpca_June30.pkl")
                )

                au_keys = [
                    "AU1",
                    "AU2",
                    "AU4",
                    "AU5",
                    "AU6",
                    "AU7",
                    "AU9",
                    "AU10",
                    "AU11",
                    "AU12",
                    "AU14",
                    "AU15",
                    "AU17",
                    "AU20",
                    "AU23",
                    "AU24",
                    "AU25",
                    "AU26",
                    "AU28",
                    "AU43",
                ]
                classifiers = {}
                for key in au_keys:
                    classifier = xgb.XGBClassifier()
                    classifier.load_model(
                        os.path.join(get_resource_path(), f"July4_{key}_XGB.ubj")
                    )
                    classifiers[key] = classifier
                return {
                    "scaler_upper": scaler_upper,
                    "pca_model_upper": pca_model_upper,
                    "scaler_lower": scaler_lower,
                    "pca_model_lower": pca_model_lower,
                    "scaler_full": scaler_full,
                    "pca_model_full": pca_model_full,
                    "au_classifiers": classifiers,
                }

        elif model == "svm":
            if location == "huggingface":
                # Load the entire model from skops serialized file
                model_path = hf_hub_download(
                    repo_id="py-feat/svm_au",
                    filename="svm_au_classifier.skops",
                    cache_dir=get_resource_path(),
                )
                unknown_types = get_untrusted_types(file=model_path)
                loaded_model = load(model_path, trusted=unknown_types)
                return {
                    "scaler_upper": loaded_model.scaler_upper,
                    "pca_model_upper": loaded_model.pca_model_upper,
                    "scaler_lower": loaded_model.scaler_lower,
                    "pca_model_lower": loaded_model.pca_model_lower,
                    "scaler_full": loaded_model.scaler_full,
                    "pca_model_full": loaded_model.pca_model_full,
                    "au_classifiers": loaded_model.classifiers,
                }
            elif location == "local":
                # Load weights from local Resources folder
                scaler_upper = load_classifier_pkl(
                    os.path.join(get_resource_path(), "all_data_Upperscalar_June30.pkl")
                )
                pca_model_upper = load_classifier_pkl(
                    os.path.join(get_resource_path(), "all_data_Upperpca_June30.pkl")
                )
                scaler_lower = load_classifier_pkl(
                    os.path.join(get_resource_path(), "all_data_Lowerscalar_June30.pkl")
                )
                pca_model_lower = load_classifier_pkl(
                    os.path.join(get_resource_path(), "all_data_Lowerpca_June30.pkl")
                )
                scaler_full = load_classifier_pkl(
                    os.path.join(get_resource_path(), "all_data_Fullscalar_June30.pkl")
                )
                pca_model_full = load_classifier_pkl(
                    os.path.join(get_resource_path(), "all_data_Fullpca_June30.pkl")
                )
                classifiers = load_classifier_pkl(
                    os.path.join(get_resource_path(), "svm_60_July2023.pkl")
                )
                return {
                    "scaler_upper": scaler_upper,
                    "pca_model_upper": pca_model_upper,
                    "scaler_lower": scaler_lower,
                    "pca_model_lower": pca_model_lower,
                    "scaler_full": scaler_full,
                    "pca_model_full": pca_model_full,
                    "au_classifiers": classifiers,
                }
            else:
                raise ValueError(f"Unsupported location: {location}")
    elif model_type == "emotion":
        if model == "svm":
            if location == "huggingface":
                # Load the entire model from skops serialized file
                model_path = hf_hub_download(
                    repo_id="py-feat/svm_emo",
                    filename="svm_emo_classifier.skops",
                    cache_dir=get_resource_path(),
                )
                unknown_types = get_untrusted_types(file=model_path)
                loaded_model = load(model_path, trusted=unknown_types)
                return {
                    "scaler_full": loaded_model.scaler_full,
                    "pca_model_full": loaded_model.pca_model_full,
                    "emo_classifiers": loaded_model.classifiers,
                }
            elif location == "local":
                # Load weights from local Resources folder
                scaler_full = load_classifier_pkl(
                    os.path.join(get_resource_path(), "emo_data_Fullscalar_Jun30.pkl")
                )
                pca_model_full = load_classifier_pkl(
                    os.path.join(get_resource_path(), "emo_data_Fullpca_Jun30.pkl")
                )
                classifiers = load_classifier_pkl(
                    os.path.join(get_resource_path(), "July4_emo_SVM.pkl")
                )
                return {
                    "scaler_full": scaler_full,
                    "pca_model_full": pca_model_full,
                    "emo_classifiers": classifiers,
                }
        else:
            raise ValueError(f"This function does not support {model_type} {model}")
    else:
        raise ValueError(f"This function does not support {model_type}")
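

# Illustrative usage (a sketch, not part of the original module), assuming the
# Hugging Face Hub is reachable:
#   weights = load_model_weights(model_type="au", model="xgb", location="huggingface")
#   weights["au_classifiers"]                            # per-AU classifiers
#   weights["scaler_full"], weights["pca_model_full"]    # feature scaler / PCA
# The emotion variant, load_model_weights("emotion", "svm", "huggingface"),
# returns "scaler_full", "pca_model_full", and "emo_classifiers" instead.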