from functools import partial
from typing import Optional

from timm import create_model
from torch import nn, Tensor

from ..utils import _get_activation, _get_norm_layer, ConvUpsample, ConvDownsample
from ..utils import LightConvUpsample, LightConvDownsample, LighterConvUpsample, LighterConvDownsample
from ..utils import ConvRefine, LightConvRefine, LighterConvRefine

# Backbones grouped by capacity. The group a backbone belongs to selects which
# family of refine/upsample/downsample blocks is wrapped around its encoder.
regular_models = [
    "resnet18",
    "resnet34",
    "resnet50",
    "resnet101",
    "resnet152",
    "convnext_nano",
    "convnext_tiny",
    "convnext_small",
    "convnext_base",
    "mobilenetv4_conv_large",
]

heavy_models = [
    "convnext_large",
    "convnext_xlarge",
    "convnext_xxlarge",
]

light_models = [
    "mobilenetv1_100",
    "mobilenetv1_125",
    "mobilenetv2_100",
    "mobilenetv2_140",
    "mobilenetv3_large_100",
    "mobilenetv4_conv_medium",
]

lighter_models = [
    "mobilenetv2_050",
    "mobilenetv3_small_050",
    "mobilenetv3_small_075",
    "mobilenetv3_small_100",
    "mobilenetv4_conv_small_050",
    "mobilenetv4_conv_small",
]

supported_models = regular_models + heavy_models + light_models + lighter_models

# Channel width fed INTO the refiner's intermediate stage, per backbone.
refiner_in_channels = {
    # ResNet
    "resnet18": 512,
    "resnet34": 512,
    "resnet50": 2048,
    "resnet101": 2048,
    "resnet152": 2048,
    # ConvNeXt
    "convnext_nano": 640,
    "convnext_tiny": 768,
    "convnext_small": 768,
    "convnext_base": 1024,
    "convnext_large": 1536,
    "convnext_xlarge": 2048,
    "convnext_xxlarge": 3072,
    # MobileNet V1
    "mobilenetv1_100": 1024,
    "mobilenetv1_125": 1280,
    # MobileNet V2
    "mobilenetv2_050": 160,
    "mobilenetv2_100": 320,
    "mobilenetv2_140": 448,
    # MobileNet V3
    "mobilenetv3_small_050": 288,
    "mobilenetv3_small_075": 432,
    "mobilenetv3_small_100": 576,
    "mobilenetv3_large_100": 960,
    # MobileNet V4
    "mobilenetv4_conv_small_050": 480,
    "mobilenetv4_conv_small": 960,
    "mobilenetv4_conv_medium": 960,
    "mobilenetv4_conv_large": 960,
}

# Channel width the refiner ultimately emits, per backbone. Differs from
# refiner_in_channels only for a few MobileNet variants.
refiner_out_channels = {
    # ResNet
    "resnet18": 512,
    "resnet34": 512,
    "resnet50": 2048,
    "resnet101": 2048,
    "resnet152": 2048,
    # ConvNeXt
    "convnext_nano": 640,
    "convnext_tiny": 768,
    "convnext_small": 768,
    "convnext_base": 1024,
    "convnext_large": 1536,
    "convnext_xlarge": 2048,
    "convnext_xxlarge": 3072,
    # MobileNet V1
    "mobilenetv1_100": 512,
    "mobilenetv1_125": 640,
    # MobileNet V2
    "mobilenetv2_050": 160,
    "mobilenetv2_100": 320,
    "mobilenetv2_140": 448,
    # MobileNet V3
    "mobilenetv3_small_050": 288,
    "mobilenetv3_small_075": 432,
    "mobilenetv3_small_100": 576,
    "mobilenetv3_large_100": 480,
    # MobileNet V4
    "mobilenetv4_conv_small_050": 480,
    "mobilenetv4_conv_small": 960,
    "mobilenetv4_conv_medium": 960,
    "mobilenetv4_conv_large": 960,
}

# Conv group counts for the regular/heavy block family. Entries are None for
# the light/lighter backbones because those use the Light*/Lighter* blocks,
# which never read this table.
groups = {
    # ResNet
    "resnet18": 1,
    "resnet34": 1,
    "resnet50": refiner_in_channels["resnet50"] // 512,
    "resnet101": refiner_in_channels["resnet101"] // 512,
    "resnet152": refiner_in_channels["resnet152"] // 512,
    # ConvNeXt
    "convnext_nano": 8,
    "convnext_tiny": 8,
    "convnext_small": 8,
    "convnext_base": 8,
    "convnext_large": refiner_in_channels["convnext_large"] // 512,
    "convnext_xlarge": refiner_in_channels["convnext_xlarge"] // 512,
    "convnext_xxlarge": refiner_in_channels["convnext_xxlarge"] // 512,
    # MobileNet V1
    "mobilenetv1_100": None,
    "mobilenetv1_125": None,
    # MobileNet V2
    "mobilenetv2_050": None,
    "mobilenetv2_100": None,
    "mobilenetv2_140": None,
    # MobileNet V3
    "mobilenetv3_small_050": None,
    "mobilenetv3_small_075": None,
    "mobilenetv3_small_100": None,
    "mobilenetv3_large_100": None,
    # MobileNet V4
    "mobilenetv4_conv_small_050": None,
    "mobilenetv4_conv_small": None,
    "mobilenetv4_conv_medium": None,
    "mobilenetv4_conv_large": 1,
}


class TIMMModel(nn.Module):
    """A timm backbone wrapped with a resolution "refiner" and a channel decoder.

    Pipeline: ``encoder`` (timm, deepest feature map only) -> ``refiner``
    (up/downsamples so the effective stride equals ``block_size``) ->
    ``decoder`` (halves channels until the width is at most 256).

    Exposed attributes: ``encoder_channels``, ``encoder_reduction``,
    ``block_size``, ``refiner_channels``, ``refiner_reduction``,
    ``decoder_channels``, ``decoder_reduction``.
    """

    def __init__(
        self,
        model_name: str,
        block_size: Optional[int] = None,
        norm: str = "none",
        act: str = "none",
    ) -> None:
        """Build the encoder/refiner/decoder stack.

        Args:
            model_name: One of ``supported_models``.
            block_size: Target output stride (8, 16, or 32). Defaults to the
                encoder's native reduction, making the refiner an identity.
            norm: "bn", "ln", or anything else to infer from the encoder.
            act: "relu", "gelu", or anything else to infer from the encoder.
        """
        super().__init__()
        assert model_name in supported_models, f"Backbone {model_name} not supported. Supported models are {supported_models}"
        assert block_size is None or block_size in [8, 16, 32], f"Block size should be one of [8, 16, 32], but got {block_size}."
        self.model_name = model_name
        # Only the deepest feature map is used (out_indices=[-1]).
        self.encoder = create_model(model_name, pretrained=True, features_only=True, out_indices=[-1])
        self.encoder_channels = self.encoder.feature_info.channels()[-1]
        self.encoder_reduction = self.encoder.feature_info.reduction()[-1]
        self.block_size = block_size if block_size is not None else self.encoder_reduction

        upsample_block, downsample_block, decoder_block = self._select_blocks(model_name)
        norm_layer = self._select_norm(norm)
        activation = self._select_act(act)

        self._build_refiner(upsample_block, downsample_block, norm_layer, activation)
        self.refiner_reduction = self.block_size
        self._build_decoder(decoder_block, norm_layer, activation)
        self.decoder_reduction = self.refiner_reduction

    @staticmethod
    def _select_blocks(model_name: str):
        """Return (upsample, downsample, refine) block constructors for the backbone's capacity class."""
        if model_name in lighter_models:
            return LighterConvUpsample, LighterConvDownsample, LighterConvRefine
        if model_name in light_models:
            return LightConvUpsample, LightConvDownsample, LightConvRefine
        # Regular/heavy backbones use grouped convolutions.
        g = groups[model_name]
        return (
            partial(ConvUpsample, groups=g),
            partial(ConvDownsample, groups=g),
            partial(ConvRefine, groups=g),
        )

    def _select_norm(self, norm: str):
        """Resolve the normalization layer class; falls back to the encoder's own."""
        if norm == "bn":
            return nn.BatchNorm2d
        if norm == "ln":
            return nn.LayerNorm
        return _get_norm_layer(self.encoder)

    def _select_act(self, act: str):
        """Resolve the activation module; falls back to the encoder's own."""
        if act == "relu":
            return nn.ReLU(inplace=True)
        if act == "gelu":
            return nn.GELU()
        return _get_activation(self.encoder)

    def _build_refiner(self, upsample_block, downsample_block, norm_layer, activation) -> None:
        """Create ``self.refiner``/``self.refiner_channels`` to bridge the
        encoder's stride (``encoder_reduction``) to ``block_size``.

        Only x1, x2 and x4 stride changes are supported; a x4 change stacks
        two scale blocks through ``refiner_in_channels`` as the middle width.
        """
        if self.block_size == self.encoder_reduction:
            # Native stride already matches the requested block size.
            self.refiner = nn.Identity()
            self.refiner_channels = self.encoder_channels
            return

        if self.block_size > self.encoder_reduction:
            scale_block = downsample_block
            ratio = self.block_size // self.encoder_reduction
            supported = f"Supported block sizes are {self.encoder_reduction}, {self.encoder_reduction * 2}, and {self.encoder_reduction * 4}."
        else:
            scale_block = upsample_block
            ratio = self.encoder_reduction // self.block_size
            supported = f"Supported block sizes are {self.encoder_reduction}, {self.encoder_reduction // 2}, and {self.encoder_reduction // 4}."
        assert ratio in (2, 4), f"Block size {self.block_size} is not supported for model {self.model_name}. {supported}"

        if ratio == 4:
            self.refiner = nn.Sequential(
                scale_block(
                    in_channels=self.encoder_channels,
                    out_channels=refiner_in_channels[self.model_name],
                    norm_layer=norm_layer,
                    activation=activation,
                ),
                scale_block(
                    in_channels=refiner_in_channels[self.model_name],
                    out_channels=refiner_out_channels[self.model_name],
                    norm_layer=norm_layer,
                    activation=activation,
                ),
            )
        else:
            self.refiner = scale_block(
                in_channels=self.encoder_channels,
                out_channels=refiner_out_channels[self.model_name],
                norm_layer=norm_layer,
                activation=activation,
            )
        self.refiner_channels = refiner_out_channels[self.model_name]

    def _build_decoder(self, decoder_block, norm_layer, activation) -> None:
        """Create ``self.decoder``/``self.decoder_channels``: a chain of
        channel-halving refine blocks until the width is <= 256 (max 3 stages).
        """
        if self.refiner_channels <= 256:
            n_stages = 0
        elif self.refiner_channels <= 512:
            n_stages = 1
        elif self.refiner_channels <= 1024:
            n_stages = 2
        else:
            n_stages = 3

        if n_stages == 0:
            self.decoder = nn.Identity()
            self.decoder_channels = self.refiner_channels
            return

        stages = []
        width = self.refiner_channels
        for _ in range(n_stages):
            stages.append(
                decoder_block(
                    in_channels=width,
                    out_channels=width // 2,
                    norm_layer=norm_layer,
                    activation=activation,
                )
            )
            width //= 2
        # A single stage is stored bare (not wrapped in Sequential), matching
        # the historical module structure so state_dicts stay compatible.
        self.decoder = stages[0] if n_stages == 1 else nn.Sequential(*stages)
        self.decoder_channels = width

    def encode(self, x: Tensor) -> Tensor:
        """Run the timm encoder and return its deepest feature map."""
        return self.encoder(x)[0]

    def refine(self, x: Tensor) -> Tensor:
        """Adjust the feature map stride to ``block_size``."""
        return self.refiner(x)

    def decode(self, x: Tensor) -> Tensor:
        """Reduce channel width via the decoder chain."""
        return self.decoder(x)

    def forward(self, x: Tensor) -> Tensor:
        """Full pipeline: encode -> refine -> decode."""
        x = self.encode(x)
        x = self.refine(x)
        x = self.decode(x)
        return x


def _timm_model(
    model_name: str,
    block_size: Optional[int] = None,
    norm: str = "none",
    act: str = "none",
) -> TIMMModel:
    """Factory wrapper around :class:`TIMMModel`."""
    return TIMMModel(model_name, block_size=block_size, norm=norm, act=act)