Spaces:
Running
Running
import torch.nn as nn | |
import timm | |
class TimmFRWrapperV2(nn.Module):
    """Wrap a timm backbone so that it emits ``featdim`` outputs.

    A timm architecture is instantiated by name (``timm.create_model`` is
    called with no other arguments, so timm's defaults apply) and its
    classifier is reset to ``featdim`` outputs via ``reset_classifier``. The
    forward pass simply runs the wrapped model.

    Args:
        model_name: timm architecture name, e.g. ``"edgenext_x_small"``.
        featdim: number of outputs requested from ``reset_classifier``.
        batchnorm: accepted but never used; kept only so that existing call
            sites passing it keep working.
    """

    def __init__(self, model_name="edgenext_x_small", featdim=512, batchnorm=False):
        super().__init__()
        self.model_name = model_name
        self.featdim = featdim
        backbone = timm.create_model(model_name)
        backbone.reset_classifier(featdim)
        # Registered as ``model``: that attribute name becomes the "model."
        # prefix of every state_dict key, so it must not be renamed.
        self.model = backbone

    def forward(self, x):
        """Return the wrapped timm model's output for ``x``."""
        return self.model(x)
class LoRaLin(nn.Module):
    """Factorized (low-rank) drop-in replacement for ``nn.Linear``.

    Computes ``linear2(linear1(input))``:
    - ``linear1`` is a bias-free projection from ``in_features`` down to ``rank``.
    - ``linear2`` projects from ``rank`` up to ``out_features``, with an optional bias.

    The two factors hold ``rank * (in_features + out_features)`` weights. That
    is fewer than a dense layer's ``in_features * out_features`` only when
    ``rank`` is small enough. Unlike a LoRA adapter, there is no separate
    full-rank weight: the two factors are the entire layer.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        rank: inner (bottleneck) width shared by the two factors.
        bias: whether ``linear2`` carries a bias term.
    """

    def __init__(self, in_features, out_features, rank, bias=True):
        super().__init__()
        self.in_features, self.out_features, self.rank = in_features, out_features, rank
        # Submodule names and creation order fix the state_dict keys and the
        # order in which parameters are initialised; keep both as they are.
        self.linear1 = nn.Linear(in_features, rank, bias=False)
        self.linear2 = nn.Linear(rank, out_features, bias=bias)

    def forward(self, input):
        """Apply the down-projection, then the up-projection."""
        return self.linear2(self.linear1(input))
def replace_linear_with_lowrank_recursive_2(model, rank_ratio=0.2):
    """Recursively swap ``nn.Linear`` submodules of ``model`` for ``LoRaLin``.

    Every ``nn.Linear`` child whose attribute name does not contain "head" is
    replaced in place by a new ``LoRaLin`` with the same in/out features. Its
    rank is ``max(2, int(min(in_features, out_features) * rank_ratio))``, and
    it has a bias only if the replaced layer had one. All other children are
    searched recursively.

    The original weights are NOT copied; the new factors are freshly
    initialised. This is therefore a structural transform, presumably applied
    before loading a checkpoint trained with the factorised layout. Confirm
    against the callers.

    NOTE(review): the "head" filter only tests the immediate attribute name.
    A Linear nested inside a module called "head" (e.g. ``head.fc``) is still
    replaced. Pre-trained factorised checkpoints may rely on this, so it is
    intentionally left unchanged.

    Args:
        model: module to modify in place.
        rank_ratio: fraction of ``min(in_features, out_features)`` used as the
            rank, floored at 2.

    Returns:
        None; ``model`` is mutated in place.
    """
    # Snapshot the children so the loop does not iterate the live module
    # registry while setattr rewrites entries in it.
    for name, module in list(model.named_children()):
        if isinstance(module, nn.Linear) and "head" not in name:
            in_features = module.in_features
            out_features = module.out_features
            rank = max(2, int(min(in_features, out_features) * rank_ratio))
            lowrank_module = LoRaLin(
                in_features, out_features, rank, bias=module.bias is not None
            )
            # LoRaLin is otherwise built on the default device/dtype. Match
            # the layer being replaced so that a model already moved to
            # another device or cast to another dtype does not end up mixed.
            weight = module.weight
            lowrank_module = lowrank_module.to(device=weight.device, dtype=weight.dtype)
            setattr(model, name, lowrank_module)
        else:
            replace_linear_with_lowrank_recursive_2(module, rank_ratio)
def replace_linear_with_lowrank_2(model, rank_ratio=0.2):
    """Factorise the eligible ``nn.Linear`` layers of ``model`` and return it.

    Delegates to ``replace_linear_with_lowrank_recursive_2``, which mutates
    ``model`` in place. The same object is handed back so the call can be used
    as a model -> model transform.

    Args:
        model: module to transform.
        rank_ratio: fraction of ``min(in_features, out_features)`` used as rank.

    Returns:
        The same ``model`` object, after modification.
    """
    replace_linear_with_lowrank_recursive_2(model, rank_ratio=rank_ratio)
    return model
# Registry of EdgeFace variants, keyed by variant name. Every entry has:
#   repo       - presumably a Hugging Face Hub repository id (TODO confirm
#                against the loader that consumes this table)
#   filename   - the .pt weights file to fetch from that repo
#   timm_model - timm architecture name, presumably passed to
#                TimmFRWrapperV2 as model_name
#   post_setup - callable taking the constructed model and returning the
#                model to use; the "gamma" variants swap Linear layers for
#                low-rank factors with the rank_ratio shown.
model_configs = {
    "edgeface_base": dict(
        repo="idiap/EdgeFace-Base",
        filename="edgeface_base.pt",
        timm_model="edgenext_base",
        post_setup=lambda x: x,  # identity: no structural change
    ),
    "edgeface_s_gamma_05": dict(
        repo="idiap/EdgeFace-S-GAMMA",
        filename="edgeface_s_gamma_05.pt",
        timm_model="edgenext_small",
        post_setup=lambda x: replace_linear_with_lowrank_2(x, rank_ratio=0.5),
    ),
    "edgeface_xs_gamma_06": dict(
        repo="idiap/EdgeFace-XS-GAMMA",
        filename="edgeface_xs_gamma_06.pt",
        timm_model="edgenext_x_small",
        post_setup=lambda x: replace_linear_with_lowrank_2(x, rank_ratio=0.6),
    ),
    "edgeface_xxs": dict(
        repo="idiap/EdgeFace-XXS",
        filename="edgeface_xxs.pt",
        timm_model="edgenext_xx_small",
        post_setup=lambda x: x,  # identity: no structural change
    ),
}