# Source: EdgeFace / timmfrv2.py (Hugging Face repository file view).
# Last commit 528e972 (verified) by bornet:
#   "Refactor to use hf_hub_download instead of torch.hub.load"
import torch.nn as nn
import timm
class TimmFRWrapperV2(nn.Module):
    """Thin face-recognition wrapper around a ``timm`` backbone.

    The backbone is built with ``timm.create_model(model_name)`` and then has
    its classifier reset via ``reset_classifier(featdim)``. In timm, that call
    presumably resizes the head to emit ``featdim`` features; this is timm
    semantics, not enforced here.

    Args:
        model_name: Name of the timm architecture to instantiate.
        featdim: Value passed to ``reset_classifier``; stored as ``self.featdim``.
        batchnorm: Accepted for signature compatibility only; it is not used.
    """

    def __init__(self, model_name="edgenext_x_small", featdim=512, batchnorm=False):
        super().__init__()
        self.featdim = featdim
        self.model_name = model_name
        # Configure the backbone fully before registering it as a submodule.
        backbone = timm.create_model(model_name)
        backbone.reset_classifier(featdim)
        self.model = backbone

    def forward(self, x):
        """Return the wrapped backbone's output for input ``x``."""
        return self.model(x)
class LoRaLin(nn.Module):
    """Low-rank factorisation of a dense linear layer.

    A single ``in_features -> out_features`` projection is replaced by two
    chained projections, ``in_features -> rank -> out_features``. The first
    projection never has a bias; the second has one only if ``bias`` is true.

    Args:
        in_features: Size of the input's last dimension.
        out_features: Size of the output's last dimension.
        rank: Width of the intermediate bottleneck.
        bias: Whether the second projection carries a learnable bias.
    """

    def __init__(self, in_features, out_features, rank, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        # The attribute names linear1/linear2 become the state_dict keys, so any
        # saved weights must use them; keep the names and order stable.
        self.linear1 = nn.Linear(in_features, rank, bias=False)
        self.linear2 = nn.Linear(rank, out_features, bias=bias)

    def forward(self, input):
        """Project ``input`` down to ``rank`` and back up to ``out_features``."""
        return self.linear2(self.linear1(input))
def replace_linear_with_lowrank_recursive_2(model, rank_ratio=0.2):
    """Recursively swap ``nn.Linear`` children of ``model`` for ``LoRaLin``.

    Walks the module tree depth-first. Each direct ``nn.Linear`` child whose
    attribute name does not contain ``"head"`` is replaced by a ``LoRaLin``:

    - Its rank is ``max(2, int(min(in_features, out_features) * rank_ratio))``.
    - It is freshly initialised; the original weights are not copied.
    - It keeps the bias setting, device and dtype of the layer it replaces.

    Every other child is recursed into. ``model`` is modified in place and
    nothing is returned.

    Args:
        model: Module whose submodules are rewritten in place.
        rank_ratio: Fraction of ``min(in_features, out_features)`` used as rank.

    NOTE(review): the ``"head"`` test only looks at the immediate child name.
    A Linear nested inside a module called ``head`` (e.g. ``head.fc``) is
    therefore still replaced. This is left as-is because pretrained checkpoints
    were presumably produced with this exact behaviour; confirm before changing.
    """
    for name, module in model.named_children():
        if isinstance(module, nn.Linear) and "head" not in name:
            in_features = module.in_features
            out_features = module.out_features
            rank = max(2, int(min(in_features, out_features) * rank_ratio))
            lowrank_module = LoRaLin(
                in_features, out_features, rank, module.bias is not None
            )
            # Match the replaced layer's device/dtype. Otherwise a model that was
            # already moved (e.g. to GPU or half precision) ends up with
            # mismatched CPU/float32 layers.
            lowrank_module = lowrank_module.to(
                device=module.weight.device, dtype=module.weight.dtype
            )
            # Re-assigning an existing child name does not change the size of
            # the underlying dict, so this is safe while iterating.
            setattr(model, name, lowrank_module)
        else:
            replace_linear_with_lowrank_recursive_2(module, rank_ratio)
def replace_linear_with_lowrank_2(model, rank_ratio=0.2):
    """Low-rank-factorise ``model`` in place and hand the same object back.

    Convenience wrapper around ``replace_linear_with_lowrank_recursive_2``.
    Returning the mutated input lets the call be used inside an expression,
    such as a ``post_setup`` lambda.

    Args:
        model: Module to rewrite in place.
        rank_ratio: Fraction of ``min(in_features, out_features)`` used as rank.

    Returns:
        The very ``model`` object that was passed in.
    """
    replace_linear_with_lowrank_recursive_2(model, rank_ratio=rank_ratio)
    return model
# Registry of EdgeFace variants, keyed by variant name. Each entry holds:
#   repo / filename - Hugging Face repo and weight file (presumably passed to
#                     hf_hub_download by the loader; confirm against caller).
#   timm_model      - timm architecture name for the backbone.
#   post_setup      - callable taking the model and returning it. It is the
#                     identity for dense variants and applies low-rank
#                     factorisation with the given ratio for "gamma" variants.
# The lambdas look up replace_linear_with_lowrank_2 lazily, at call time.
model_configs = dict(
    edgeface_base=dict(
        repo="idiap/EdgeFace-Base",
        filename="edgeface_base.pt",
        timm_model="edgenext_base",
        post_setup=lambda x: x,
    ),
    edgeface_s_gamma_05=dict(
        repo="idiap/EdgeFace-S-GAMMA",
        filename="edgeface_s_gamma_05.pt",
        timm_model="edgenext_small",
        post_setup=lambda x: replace_linear_with_lowrank_2(x, rank_ratio=0.5),
    ),
    edgeface_xs_gamma_06=dict(
        repo="idiap/EdgeFace-XS-GAMMA",
        filename="edgeface_xs_gamma_06.pt",
        timm_model="edgenext_x_small",
        post_setup=lambda x: replace_linear_with_lowrank_2(x, rank_ratio=0.6),
    ),
    edgeface_xxs=dict(
        repo="idiap/EdgeFace-XXS",
        filename="edgeface_xxs.pt",
        timm_model="edgenext_xx_small",
        post_setup=lambda x: x,
    ),
)