import torch.nn as nn
import timm


class TimmFRWrapperV2(nn.Module):
    """Thin ``nn.Module`` wrapper around a model built by ``timm.create_model``.

    The timm model is created from ``model_name`` alone (no other arguments
    are forwarded) and ``reset_classifier(featdim)`` is then called on it.

    Args:
        model_name: Name handed to ``timm.create_model``.
        featdim: Value handed to the timm model's ``reset_classifier``.
        batchnorm: Accepted for interface compatibility; this class never
            reads it.
    """

    def __init__(self, model_name="edgenext_x_small", featdim=512, batchnorm=False):
        super().__init__()
        self.featdim = featdim
        self.model_name = model_name

        # Build the backbone first, then resize its classifier, and only
        # then register it as the ``model`` submodule.
        backbone = timm.create_model(model_name)
        backbone.reset_classifier(featdim)
        self.model = backbone

    def forward(self, x):
        """Return the wrapped timm model's output for ``x``."""
        return self.model(x)


class LoRaLin(nn.Module):
    """Low-rank stand-in for ``nn.Linear``: two stacked linear layers.

    ``linear1`` maps ``in_features -> rank`` without a bias and ``linear2``
    maps ``rank -> out_features`` with an optional bias.

    Args:
        in_features: Input width of the first projection.
        out_features: Output width of the second projection.
        rank: Width of the intermediate representation.
        bias: Whether ``linear2`` carries a bias term.
    """

    def __init__(self, in_features, out_features, rank, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        # Registration order (linear1, then linear2) is kept so that the
        # state_dict layout and parameter-initialisation order are unchanged.
        self.linear1 = nn.Linear(in_features, rank, bias=False)
        self.linear2 = nn.Linear(rank, out_features, bias=bias)

    def forward(self, input):
        """Apply ``linear1`` followed by ``linear2`` to ``input``."""
        return self.linear2(self.linear1(input))


def replace_linear_with_lowrank_recursive_2(model, rank_ratio=0.2):
    """Swap ``nn.Linear`` children of ``model`` for ``LoRaLin``, in place.

    Walks the module tree depth-first. A direct child that is an
    ``nn.Linear`` and whose attribute name does not contain ``"head"`` is
    replaced by a freshly constructed ``LoRaLin``; nothing is copied from the
    replaced layer. Every other child is recursed into.

    Args:
        model: Module whose children are inspected and possibly replaced.
        rank_ratio: Fraction of ``min(in_features, out_features)`` used as
            the rank; the result is truncated to ``int`` and floored at 2.

    Returns:
        None. ``model`` is mutated.
    """
    for child_name, child in model.named_children():
        # NOTE: the "head" test looks only at the immediate attribute name,
        # not at the dotted path from the root module.
        if not isinstance(child, nn.Linear) or "head" in child_name:
            replace_linear_with_lowrank_recursive_2(child, rank_ratio)
            continue

        narrow_side = min(child.in_features, child.out_features)
        rank = max(2, int(narrow_side * rank_ratio))
        has_bias = child.bias is not None

        replacement = LoRaLin(child.in_features, child.out_features, rank, has_bias)
        setattr(model, child_name, replacement)


def replace_linear_with_lowrank_2(model, rank_ratio=0.2):
    """Run the recursive low-rank replacement on ``model`` and return it.

    Args:
        model: Root module to convert; it is modified in place.
        rank_ratio: Passed through to
            ``replace_linear_with_lowrank_recursive_2``.

    Returns:
        The same ``model`` object, after conversion.
    """
    replace_linear_with_lowrank_recursive_2(model, rank_ratio)
    return model


# Per-variant settings, keyed by variant name.
#   "repo" / "filename": identify the weight file (presumably a model-hub
#       repository id and a file inside it -- confirm against the loader).
#   "timm_model": a timm architecture name (presumably what the loader passes
#       to TimmFRWrapperV2 -- confirm against the loader).
#   "post_setup": callable that takes the built model and returns the model
#       to use; it is either the identity or the low-rank conversion with the
#       variant's rank ratio.
model_configs = {
    "edgeface_base": {
        "repo": "idiap/EdgeFace-Base",
        "filename": "edgeface_base.pt",
        "timm_model": "edgenext_base",
        "post_setup": lambda x: x,
    },
    "edgeface_s_gamma_05": {
        "repo": "idiap/EdgeFace-S-GAMMA",
        "filename": "edgeface_s_gamma_05.pt",
        "timm_model": "edgenext_small",
        "post_setup": lambda x: replace_linear_with_lowrank_2(x, rank_ratio=0.5),
    },
    "edgeface_xs_gamma_06": {
        "repo": "idiap/EdgeFace-XS-GAMMA",
        "filename": "edgeface_xs_gamma_06.pt",
        "timm_model": "edgenext_x_small",
        "post_setup": lambda x: replace_linear_with_lowrank_2(x, rank_ratio=0.6),
    },
    "edgeface_xxs": {
        "repo": "idiap/EdgeFace-XXS",
        "filename": "edgeface_xxs.pt",
        "timm_model": "edgenext_xx_small",
        "post_setup": lambda x: x,
    },
}