File size: 3,192 Bytes
20cf96a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""

===============================================================================

Author: Anjith George

Institution: Idiap Research Institute, Martigny, Switzerland.



Copyright (C) 2024 Anjith George



This software is distributed under the terms described in the LICENSE file 

located in the parent directory of this source code repository. 



For inquiries, please contact the author at [email protected]

===============================================================================

"""

# Third-party packages this module needs. Presumably this file is a torch.hub
# hubconf.py, in which case torch.hub reads this list and checks that each
# package is importable before running an entrypoint.
dependencies = [
    'torch',
    'torchvision',
    'timm',
]

from backbones import get_model
import torch

# Every released checkpoint lives under this location and its file name is the
# get_model() architecture key, so each URL is derived from that key instead of
# being copy-pasted into every entrypoint (which risks pairing the wrong weights
# with an architecture).
_CHECKPOINT_URL_TEMPLATE = (
    'https://gitlab.idiap.ch/bob/bob.paper.tbiom2023_edgeface/-/raw/master/'
    'checkpoints/{}.pt'
)


def _load_edgeface(model_name, pretrained, model_kwargs):
    """Build an EdgeFace variant and optionally load its released weights.

    Args:
        model_name: architecture key passed to ``get_model``; it is also the
            stem of the checkpoint file name on the release server.
        pretrained: when true, download the checkpoint with
            ``torch.hub.load_state_dict_from_url`` (tensors mapped to CPU) and
            load it into the freshly built model.
        model_kwargs: extra keyword arguments forwarded to ``get_model``. Taken
            as a plain dict rather than ``**kwargs`` so that no caller-supplied
            key can ever collide with this helper's own parameter names.

    Returns:
        The model object returned by ``get_model``.
    """
    model = get_model(model_name, **model_kwargs)
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            _CHECKPOINT_URL_TEMPLATE.format(model_name), map_location='cpu'
        )
        # load_state_dict is strict by default: if model_kwargs altered the
        # architecture, keys/shapes may not match the checkpoint and this raises.
        model.load_state_dict(state_dict)
    return model


def edgeface_base(pretrained=True, **kwargs):
    """EdgeFace ``edgeface_base`` model.

    Args:
        pretrained: if True (default), load the released ``edgeface_base.pt``
            weights onto CPU.
        **kwargs: forwarded unchanged to ``backbones.get_model``.

    Returns:
        The model built by ``get_model('edgeface_base', **kwargs)``.
    """
    return _load_edgeface('edgeface_base', pretrained, kwargs)


def edgeface_xs_gamma_06(pretrained=True, **kwargs):
    """EdgeFace ``edgeface_xs_gamma_06`` model.

    Args:
        pretrained: if True (default), load the released
            ``edgeface_xs_gamma_06.pt`` weights onto CPU.
        **kwargs: forwarded unchanged to ``backbones.get_model``.

    Returns:
        The model built by ``get_model('edgeface_xs_gamma_06', **kwargs)``.
    """
    return _load_edgeface('edgeface_xs_gamma_06', pretrained, kwargs)


def edgeface_xs_q(pretrained=True, **kwargs):
    """EdgeFace ``edgeface_xs_q`` model.

    Args:
        pretrained: if True (default), load the released ``edgeface_xs_q.pt``
            weights onto CPU.
        **kwargs: forwarded unchanged to ``backbones.get_model``.

    Returns:
        The model built by ``get_model('edgeface_xs_q', **kwargs)``.
    """
    return _load_edgeface('edgeface_xs_q', pretrained, kwargs)


def edgeface_xxs(pretrained=True, **kwargs):
    """EdgeFace ``edgeface_xxs`` model.

    Args:
        pretrained: if True (default), load the released ``edgeface_xxs.pt``
            weights onto CPU.
        **kwargs: forwarded unchanged to ``backbones.get_model``.

    Returns:
        The model built by ``get_model('edgeface_xxs', **kwargs)``.
    """
    return _load_edgeface('edgeface_xxs', pretrained, kwargs)


def edgeface_xxs_q(pretrained=True, **kwargs):
    """EdgeFace ``edgeface_xxs_q`` model.

    Args:
        pretrained: if True (default), load the released ``edgeface_xxs_q.pt``
            weights onto CPU.
        **kwargs: forwarded unchanged to ``backbones.get_model``.

    Returns:
        The model built by ``get_model('edgeface_xxs_q', **kwargs)``.
    """
    return _load_edgeface('edgeface_xxs_q', pretrained, kwargs)


def edgeface_s_gamma_05(pretrained=True, **kwargs):
    """EdgeFace ``edgeface_s_gamma_05`` model.

    Args:
        pretrained: if True (default), load the released
            ``edgeface_s_gamma_05.pt`` weights onto CPU.
        **kwargs: forwarded unchanged to ``backbones.get_model``.

    Returns:
        The model built by ``get_model('edgeface_s_gamma_05', **kwargs)``.
    """
    return _load_edgeface('edgeface_s_gamma_05', pretrained, kwargs)