Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import math | |
| from typing import Callable, Optional, Sequence | |
| import torch | |
| import torch.nn.functional as F | |
| from mmcv.cnn import ConvModule, build_norm_layer | |
| from torch import nn | |
| from mmpretrain.registry import MODELS | |
| from .base_backbone import BaseBackbone | |
| from .mobilenet_v2 import InvertedResidual | |
| from .vision_transformer import TransformerEncoderLayer | |
class MobileVitBlock(nn.Module):
    """MobileViT block.

    According to the paper, the MobileViT block has a local representation,
    a transformer-as-convolution layer which consists of a global
    representation with unfolding and folding, and a final fusion layer.

    Args:
        in_channels (int): Number of input image channels.
        transformer_dim (int): Number of transformer channels.
        ffn_dim (int): Number of ffn channels in transformer block.
        out_channels (int): Number of channels in output.
        conv_ksize (int): Conv kernel size in local representation
            and fusion. Defaults to 3.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Defaults to None, which means using conv2d.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Defaults to dict(type='BN').
        act_cfg (dict, optional): Config dict for activation layer.
            Defaults to dict(type='Swish').
        num_transformer_blocks (int): Number of transformer blocks in
            a MobileViT block. Defaults to 2.
        patch_size (int): Patch size for unfolding and folding.
            Defaults to 2.
        num_heads (int): Number of heads in global representation.
            Defaults to 4.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        attn_drop_rate (float): The drop out rate for attention output
            weights. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        no_fusion (bool): Whether to remove the fusion layer.
            Defaults to False.
        transformer_norm_cfg (dict, optional): Config dict for normalization
            layer in transformer. Defaults to dict(type='LN').
    """

    def __init__(
            self,
            in_channels: int,
            transformer_dim: int,
            ffn_dim: int,
            out_channels: int,
            conv_ksize: int = 3,
            conv_cfg: Optional[dict] = None,
            norm_cfg: Optional[dict] = dict(type='BN'),
            act_cfg: Optional[dict] = dict(type='Swish'),
            num_transformer_blocks: int = 2,
            patch_size: int = 2,
            num_heads: int = 4,
            drop_rate: float = 0.,
            attn_drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            no_fusion: bool = False,
            transformer_norm_cfg: Optional[dict] = dict(type='LN'),
    ):
        super().__init__()

        # "Same" padding for the k x k convolutions (exact for odd kernels).
        padding = int((conv_ksize - 1) / 2)

        # Local representation: a k x k conv that keeps the channel count,
        # then a bias-free 1x1 conv (no norm, no activation) that projects
        # to the transformer width.
        self.local_rep = nn.Sequential(
            ConvModule(
                in_channels=in_channels,
                out_channels=in_channels,
                kernel_size=conv_ksize,
                padding=padding,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg),
            ConvModule(
                in_channels=in_channels,
                out_channels=transformer_dim,
                kernel_size=1,
                bias=False,
                conv_cfg=conv_cfg,
                norm_cfg=None,
                act_cfg=None),
        )

        # Global representation: a stack of transformer encoder layers
        # followed by one final normalization layer.
        global_rep = [
            TransformerEncoderLayer(
                embed_dims=transformer_dim,
                num_heads=num_heads,
                feedforward_channels=ffn_dim,
                drop_rate=drop_rate,
                attn_drop_rate=attn_drop_rate,
                drop_path_rate=drop_path_rate,
                qkv_bias=True,
                act_cfg=dict(type='Swish'),
                norm_cfg=transformer_norm_cfg)
            for _ in range(num_transformer_blocks)
        ]
        # build_norm_layer returns a (name, layer) pair; keep the layer.
        global_rep.append(
            build_norm_layer(transformer_norm_cfg, transformer_dim)[1])
        self.global_rep = nn.Sequential(*global_rep)

        # 1x1 projection from the transformer width to the output width.
        self.conv_proj = ConvModule(
            in_channels=transformer_dim,
            out_channels=out_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        if no_fusion:
            self.conv_fusion = None
        else:
            # The fusion conv consumes the block input concatenated with the
            # projected features, hence in_channels + out_channels inputs.
            self.conv_fusion = ConvModule(
                in_channels=in_channels + out_channels,
                out_channels=out_channels,
                kernel_size=conv_ksize,
                padding=padding,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

        self.patch_size = (patch_size, patch_size)
        self.patch_area = self.patch_size[0] * self.patch_size[1]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run local conv, patch-wise transformer, projection and fusion.

        Args:
            x (torch.Tensor): Input feature map of shape (B, C_in, H, W).

        Returns:
            torch.Tensor: Output feature map with the same spatial size as
            the input. When the fusion layer is enabled, the block input is
            concatenated with the projected features along the channel
            dimension and passed through the fusion conv.
        """
        shortcut = x

        # Local representation
        x = self.local_rep(x)

        # Unfold (feature map -> patches)
        patch_h, patch_w = self.patch_size
        B, C, H, W = x.shape
        # Round the spatial size up to a multiple of the patch size.
        new_h = math.ceil(H / patch_h) * patch_h
        new_w = math.ceil(W / patch_w) * patch_w
        num_patch_h, num_patch_w = new_h // patch_h, new_w // patch_w  # n_h, n_w # noqa
        num_patches = num_patch_h * num_patch_w  # N

        interpolate = False
        if new_h != H or new_w != W:
            # Note: Padding can be done, but then it needs to be handled in attention function. # noqa
            x = F.interpolate(
                x, size=(new_h, new_w), mode='bilinear', align_corners=False)
            interpolate = True

        # [B, C, H, W] --> [B * C * n_h, n_w, p_h, p_w]
        x = x.reshape(B * C * num_patch_h, patch_h, num_patch_w,
                      patch_w).transpose(1, 2)
        # [B * C * n_h, n_w, p_h, p_w] --> [BP, N, C] where P = p_h * p_w and N = n_h * n_w # noqa
        x = x.reshape(B, C, num_patches,
                      self.patch_area).transpose(1, 3).reshape(
                          B * self.patch_area, num_patches, -1)

        # Global representations
        x = self.global_rep(x)

        # Fold (patch -> feature map)
        # [B, P, N, C] --> [B*C*n_h, n_w, p_h, p_w]
        x = x.contiguous().view(B, self.patch_area, num_patches, -1)
        x = x.transpose(1, 3).reshape(B * C * num_patch_h, num_patch_w,
                                      patch_h, patch_w)
        # [B*C*n_h, n_w, p_h, p_w] --> [B*C*n_h, p_h, n_w, p_w] --> [B, C, H, W] # noqa
        x = x.transpose(1, 2).reshape(B, C, num_patch_h * patch_h,
                                      num_patch_w * patch_w)
        if interpolate:
            # Undo the earlier resize so the output matches the input size.
            x = F.interpolate(
                x, size=(H, W), mode='bilinear', align_corners=False)

        x = self.conv_proj(x)

        if self.conv_fusion is not None:
            x = self.conv_fusion(torch.cat((shortcut, x), dim=1))

        return x
# NOTE(review): ``MODELS`` is imported at file level, but no
# ``@MODELS.register_module()`` decorator is visible on this class. Confirm
# whether registration is expected here before building it from a config.
class MobileViT(BaseBackbone):
    """MobileViT backbone.

    A PyTorch implementation of : `MobileViT: Light-weight, General-purpose,
    and Mobile-friendly Vision Transformer <https://arxiv.org/pdf/2110.02178.pdf>`_

    Modified from the `official repo
    <https://github.com/apple/ml-cvnets/blob/main/cvnets/models/classification/mobilevit.py>`_
    and `timm
    <https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/mobilevit.py>`_.

    Args:
        arch (str | List[list]): Architecture of MobileViT.

            - If a string, choose from "small", "x_small" and "xx_small".

            - If a list, every item should be also a list, and the first item
              of the sub-list can be chosen from "mobilenetv2" and "mobilevit",
              which indicates the type of this layer sequence. If "mobilenetv2",
              the other items are the arguments of :attr:`~MobileViT.make_mobilenetv2_layer`
              (except ``in_channels``) and if "mobilevit", the other items are
              the arguments of :attr:`~MobileViT.make_mobilevit_layer`
              (except ``in_channels``).

            Defaults to "small".
        in_channels (int): Number of input image channels. Defaults to 3.
        stem_channels (int): Channels of stem layer.  Defaults to 16.
        last_exp_factor (int): Channels expand factor of last layer.
            Defaults to 4.
        out_indices (Sequence[int]): Output from which stages. Negative
            values count from the last stage. Defaults to (4, ).
        frozen_stages (int): Stages to be frozen (all param fixed).
            Defaults to -1, which means not freezing any parameters.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Defaults to None, which means using conv2d.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Defaults to dict(type='BN').
        act_cfg (dict, optional): Config dict for activation layer.
            Defaults to dict(type='Swish').
        init_cfg (dict, optional): Initialization config dict.
    """  # noqa

    # Parameters to build layers. The first param is the type of layer.
    # For `mobilenetv2` layer, the rest params from left to right are:
    #     out channels, stride, num of blocks, expand_ratio.
    # For `mobilevit` layer, the rest params from left to right are:
    #     out channels, stride, transformer_channels, ffn channels,
    # num of transformer blocks, expand_ratio.
    arch_settings = {
        'small': [
            ['mobilenetv2', 32, 1, 1, 4],
            ['mobilenetv2', 64, 2, 3, 4],
            ['mobilevit', 96, 2, 144, 288, 2, 4],
            ['mobilevit', 128, 2, 192, 384, 4, 4],
            ['mobilevit', 160, 2, 240, 480, 3, 4],
        ],
        'x_small': [
            ['mobilenetv2', 32, 1, 1, 4],
            ['mobilenetv2', 48, 2, 3, 4],
            ['mobilevit', 64, 2, 96, 192, 2, 4],
            ['mobilevit', 80, 2, 120, 240, 4, 4],
            ['mobilevit', 96, 2, 144, 288, 3, 4],
        ],
        'xx_small': [
            ['mobilenetv2', 16, 1, 1, 2],
            ['mobilenetv2', 24, 2, 3, 2],
            ['mobilevit', 48, 2, 64, 128, 2, 2],
            ['mobilevit', 64, 2, 80, 160, 4, 2],
            ['mobilevit', 80, 2, 96, 192, 3, 2],
        ]
    }

    def __init__(self,
                 arch='small',
                 in_channels=3,
                 stem_channels=16,
                 last_exp_factor=4,
                 out_indices=(4, ),
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='Swish'),
                 init_cfg=[
                     dict(type='Kaiming', layer=['Conv2d']),
                     dict(
                         type='Constant',
                         val=1,
                         layer=['_BatchNorm', 'GroupNorm'])
                 ]):
        super().__init__(init_cfg)
        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in self.arch_settings, \
                f'Unavailable arch, please choose from ' \
                f'({set(self.arch_settings)}) or pass a list.'
            arch = self.arch_settings[arch]

        self.arch = arch
        self.num_stages = len(arch)

        # check out indices and frozen stages
        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must by a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        # Normalize on a private list copy: tuples (including the default)
        # do not support item assignment, and a list passed by the caller
        # must not be modified in place.
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.num_stages + index
                assert out_indices[i] >= 0, f'Invalid out_indices {index}'
        self.out_indices = out_indices

        if frozen_stages not in range(-1, self.num_stages):
            raise ValueError('frozen_stages must be in range(-1, '
                             f'{self.num_stages}). '
                             f'But received {frozen_stages}')
        self.frozen_stages = frozen_stages

        # Dispatch table from the layer-type tag of an arch entry to the
        # factory that builds that kind of stage.
        _make_layer_func = {
            'mobilenetv2': self.make_mobilenetv2_layer,
            'mobilevit': self.make_mobilevit_layer,
        }

        self.stem = ConvModule(
            in_channels=in_channels,
            out_channels=stem_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        # Chain the stages: each one consumes the previous stage's channels.
        in_channels = stem_channels
        layers = []
        for layer_settings in arch:
            layer_type, settings = layer_settings[0], layer_settings[1:]
            layer, out_channels = _make_layer_func[layer_type](in_channels,
                                                               *settings)
            layers.append(layer)
            in_channels = out_channels
        self.layers = nn.Sequential(*layers)

        # Final 1x1 conv expanding the last stage's channels by
        # ``last_exp_factor``.
        self.conv_1x1_exp = ConvModule(
            in_channels=in_channels,
            out_channels=last_exp_factor * in_channels,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

    @staticmethod
    def make_mobilevit_layer(in_channels,
                             out_channels,
                             stride,
                             transformer_dim,
                             ffn_dim,
                             num_transformer_blocks,
                             expand_ratio=4):
        """Build mobilevit layer, which consists of one InvertedResidual and
        one MobileVitBlock.

        Args:
            in_channels (int): The input channels.
            out_channels (int): The output channels.
            stride (int): The stride of the first 3x3 convolution in the
                ``InvertedResidual`` layers.
            transformer_dim (int): The channels of the transformer layers.
            ffn_dim (int): The mid-channels of the feedforward network in
                transformer layers.
            num_transformer_blocks (int): The number of transformer blocks.
            expand_ratio (int): adjusts number of channels of the hidden layer
                in ``InvertedResidual`` by this amount. Defaults to 4.

        Returns:
            tuple: ``(layer, out_channels)`` where ``layer`` is an
            ``nn.Sequential`` holding the two sub-modules.
        """
        layer = [
            InvertedResidual(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                expand_ratio=expand_ratio,
                act_cfg=dict(type='Swish'),
            ),
            MobileVitBlock(
                in_channels=out_channels,
                transformer_dim=transformer_dim,
                ffn_dim=ffn_dim,
                out_channels=out_channels,
                num_transformer_blocks=num_transformer_blocks,
            ),
        ]
        return nn.Sequential(*layer), out_channels

    @staticmethod
    def make_mobilenetv2_layer(in_channels,
                               out_channels,
                               stride,
                               num_blocks,
                               expand_ratio=4):
        """Build mobilenetv2 layer, which consists of several InvertedResidual
        layers.

        Args:
            in_channels (int): The input channels.
            out_channels (int): The output channels.
            stride (int): The stride of the first 3x3 convolution in the
                ``InvertedResidual`` layers.
            num_blocks (int): The number of ``InvertedResidual`` blocks.
            expand_ratio (int): adjusts number of channels of the hidden layer
                in ``InvertedResidual`` by this amount. Defaults to 4.

        Returns:
            tuple: ``(layer, out_channels)`` where ``layer`` is an
            ``nn.Sequential`` of ``num_blocks`` blocks.
        """
        layer = []
        for i in range(num_blocks):
            # Only the first block applies the requested stride; the
            # remaining blocks use stride 1.
            stride = stride if i == 0 else 1

            layer.append(
                InvertedResidual(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    stride=stride,
                    expand_ratio=expand_ratio,
                    act_cfg=dict(type='Swish'),
                ))
            in_channels = out_channels
        return nn.Sequential(*layer), out_channels

    def _freeze_stages(self):
        """Put the first ``frozen_stages`` entries of ``self.layers`` in eval
        mode and disable gradients for their parameters.

        Only ``self.layers`` is touched here; the stem is left as is.
        """
        for i in range(0, self.frozen_stages):
            layer = self.layers[i]
            layer.eval()
            for param in layer.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Switch train/eval mode, then re-apply stage freezing.

        Args:
            mode (bool): ``True`` for training mode, ``False`` for
                evaluation mode. Defaults to True.

        Returns:
            MobileViT: ``self``, so calls can be chained as with
            ``nn.Module.train``.
        """
        super().train(mode)
        # The parent call flips every sub-module to ``mode``; frozen stages
        # have to be put back in eval mode afterwards.
        self._freeze_stages()
        return self

    def forward(self, x):
        """Run the stem and all stages, collecting the requested outputs.

        Args:
            x (torch.Tensor): Input tensor passed to the stem.

        Returns:
            tuple: Outputs of the stages listed in ``self.out_indices``, in
            stage order. The last stage's output is passed through
            ``conv_1x1_exp`` before it is collected.
        """
        x = self.stem(x)
        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i == len(self.layers) - 1:
                x = self.conv_1x1_exp(x)
            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)