Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import Sequence | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from mmcv.cnn import build_activation_layer, build_norm_layer | |
| from mmcv.cnn.bricks.drop import DropPath | |
| from mmcv.cnn.bricks.transformer import AdaptivePadding | |
| from mmengine.model import BaseModule | |
| from mmengine.model.weight_init import trunc_normal_ | |
| from mmpretrain.registry import MODELS | |
| from .base_backbone import BaseBackbone | |
| from .vision_transformer import TransformerEncoderLayer | |
| class ConvBlock(BaseModule): | |
| """Basic convluation block used in Conformer. | |
| This block includes three convluation modules, and supports three new | |
| functions: | |
| 1. Returns the output of both the final layers and the second convluation | |
| module. | |
| 2. Fuses the input of the second convluation module with an extra input | |
| feature map. | |
| 3. Supports to add an extra convluation module to the identity connection. | |
| Args: | |
| in_channels (int): The number of input channels. | |
| out_channels (int): The number of output channels. | |
| stride (int): The stride of the second convluation module. | |
| Defaults to 1. | |
| groups (int): The groups of the second convluation module. | |
| Defaults to 1. | |
| drop_path_rate (float): The rate of the DropPath layer. Defaults to 0. | |
| with_residual_conv (bool): Whether to add an extra convluation module | |
| to the identity connection. Defaults to False. | |
| norm_cfg (dict): The config of normalization layers. | |
| Defaults to ``dict(type='BN', eps=1e-6)``. | |
| act_cfg (dict): The config of activative functions. | |
| Defaults to ``dict(type='ReLU', inplace=True))``. | |
| init_cfg (dict, optional): The extra config to initialize the module. | |
| Defaults to None. | |
| """ | |
| def __init__(self, | |
| in_channels, | |
| out_channels, | |
| stride=1, | |
| groups=1, | |
| drop_path_rate=0., | |
| with_residual_conv=False, | |
| norm_cfg=dict(type='BN', eps=1e-6), | |
| act_cfg=dict(type='ReLU', inplace=True), | |
| init_cfg=None): | |
| super(ConvBlock, self).__init__(init_cfg=init_cfg) | |
| expansion = 4 | |
| mid_channels = out_channels // expansion | |
| self.conv1 = nn.Conv2d( | |
| in_channels, | |
| mid_channels, | |
| kernel_size=1, | |
| stride=1, | |
| padding=0, | |
| bias=False) | |
| self.bn1 = build_norm_layer(norm_cfg, mid_channels)[1] | |
| self.act1 = build_activation_layer(act_cfg) | |
| self.conv2 = nn.Conv2d( | |
| mid_channels, | |
| mid_channels, | |
| kernel_size=3, | |
| stride=stride, | |
| groups=groups, | |
| padding=1, | |
| bias=False) | |
| self.bn2 = build_norm_layer(norm_cfg, mid_channels)[1] | |
| self.act2 = build_activation_layer(act_cfg) | |
| self.conv3 = nn.Conv2d( | |
| mid_channels, | |
| out_channels, | |
| kernel_size=1, | |
| stride=1, | |
| padding=0, | |
| bias=False) | |
| self.bn3 = build_norm_layer(norm_cfg, out_channels)[1] | |
| self.act3 = build_activation_layer(act_cfg) | |
| if with_residual_conv: | |
| self.residual_conv = nn.Conv2d( | |
| in_channels, | |
| out_channels, | |
| kernel_size=1, | |
| stride=stride, | |
| padding=0, | |
| bias=False) | |
| self.residual_bn = build_norm_layer(norm_cfg, out_channels)[1] | |
| self.with_residual_conv = with_residual_conv | |
| self.drop_path = DropPath( | |
| drop_path_rate) if drop_path_rate > 0. else nn.Identity() | |
| def zero_init_last_bn(self): | |
| nn.init.zeros_(self.bn3.weight) | |
| def forward(self, x, fusion_features=None, out_conv2=True): | |
| identity = x | |
| x = self.conv1(x) | |
| x = self.bn1(x) | |
| x = self.act1(x) | |
| x = self.conv2(x) if fusion_features is None else self.conv2( | |
| x + fusion_features) | |
| x = self.bn2(x) | |
| x2 = self.act2(x) | |
| x = self.conv3(x2) | |
| x = self.bn3(x) | |
| if self.drop_path is not None: | |
| x = self.drop_path(x) | |
| if self.with_residual_conv: | |
| identity = self.residual_conv(identity) | |
| identity = self.residual_bn(identity) | |
| x += identity | |
| x = self.act3(x) | |
| if out_conv2: | |
| return x, x2 | |
| else: | |
| return x | |
| class FCUDown(BaseModule): | |
| """CNN feature maps -> Transformer patch embeddings.""" | |
| def __init__(self, | |
| in_channels, | |
| out_channels, | |
| down_stride, | |
| with_cls_token=True, | |
| norm_cfg=dict(type='LN', eps=1e-6), | |
| act_cfg=dict(type='GELU'), | |
| init_cfg=None): | |
| super(FCUDown, self).__init__(init_cfg=init_cfg) | |
| self.down_stride = down_stride | |
| self.with_cls_token = with_cls_token | |
| self.conv_project = nn.Conv2d( | |
| in_channels, out_channels, kernel_size=1, stride=1, padding=0) | |
| self.sample_pooling = nn.AvgPool2d( | |
| kernel_size=down_stride, stride=down_stride) | |
| self.ln = build_norm_layer(norm_cfg, out_channels)[1] | |
| self.act = build_activation_layer(act_cfg) | |
| def forward(self, x, x_t): | |
| x = self.conv_project(x) # [N, C, H, W] | |
| x = self.sample_pooling(x).flatten(2).transpose(1, 2) | |
| x = self.ln(x) | |
| x = self.act(x) | |
| if self.with_cls_token: | |
| x = torch.cat([x_t[:, 0][:, None, :], x], dim=1) | |
| return x | |
| class FCUUp(BaseModule): | |
| """Transformer patch embeddings -> CNN feature maps.""" | |
| def __init__(self, | |
| in_channels, | |
| out_channels, | |
| up_stride, | |
| with_cls_token=True, | |
| norm_cfg=dict(type='BN', eps=1e-6), | |
| act_cfg=dict(type='ReLU', inplace=True), | |
| init_cfg=None): | |
| super(FCUUp, self).__init__(init_cfg=init_cfg) | |
| self.up_stride = up_stride | |
| self.with_cls_token = with_cls_token | |
| self.conv_project = nn.Conv2d( | |
| in_channels, out_channels, kernel_size=1, stride=1, padding=0) | |
| self.bn = build_norm_layer(norm_cfg, out_channels)[1] | |
| self.act = build_activation_layer(act_cfg) | |
| def forward(self, x, H, W): | |
| B, _, C = x.shape | |
| # [N, 197, 384] -> [N, 196, 384] -> [N, 384, 196] -> [N, 384, 14, 14] | |
| if self.with_cls_token: | |
| x_r = x[:, 1:].transpose(1, 2).reshape(B, C, H, W) | |
| else: | |
| x_r = x.transpose(1, 2).reshape(B, C, H, W) | |
| x_r = self.act(self.bn(self.conv_project(x_r))) | |
| return F.interpolate( | |
| x_r, size=(H * self.up_stride, W * self.up_stride)) | |
| class ConvTransBlock(BaseModule): | |
| """Basic module for Conformer. | |
| This module is a fusion of CNN block transformer encoder block. | |
| Args: | |
| in_channels (int): The number of input channels in conv blocks. | |
| out_channels (int): The number of output channels in conv blocks. | |
| embed_dims (int): The embedding dimension in transformer blocks. | |
| conv_stride (int): The stride of conv2d layers. Defaults to 1. | |
| groups (int): The groups of conv blocks. Defaults to 1. | |
| with_residual_conv (bool): Whether to add a conv-bn layer to the | |
| identity connect in the conv block. Defaults to False. | |
| down_stride (int): The stride of the downsample pooling layer. | |
| Defaults to 4. | |
| num_heads (int): The number of heads in transformer attention layers. | |
| Defaults to 12. | |
| mlp_ratio (float): The expansion ratio in transformer FFN module. | |
| Defaults to 4. | |
| qkv_bias (bool): Enable bias for qkv if True. Defaults to False. | |
| with_cls_token (bool): Whether use class token or not. | |
| Defaults to True. | |
| drop_rate (float): The dropout rate of the output projection and | |
| FFN in the transformer block. Defaults to 0. | |
| attn_drop_rate (float): The dropout rate after the attention | |
| calculation in the transformer block. Defaults to 0. | |
| drop_path_rate (bloat): The drop path rate in both the conv block | |
| and the transformer block. Defaults to 0. | |
| last_fusion (bool): Whether this block is the last stage. If so, | |
| downsample the fusion feature map. | |
| init_cfg (dict, optional): The extra config to initialize the module. | |
| Defaults to None. | |
| """ | |
| def __init__(self, | |
| in_channels, | |
| out_channels, | |
| embed_dims, | |
| conv_stride=1, | |
| groups=1, | |
| with_residual_conv=False, | |
| down_stride=4, | |
| num_heads=12, | |
| mlp_ratio=4., | |
| qkv_bias=False, | |
| with_cls_token=True, | |
| drop_rate=0., | |
| attn_drop_rate=0., | |
| drop_path_rate=0., | |
| last_fusion=False, | |
| init_cfg=None): | |
| super(ConvTransBlock, self).__init__(init_cfg=init_cfg) | |
| expansion = 4 | |
| self.cnn_block = ConvBlock( | |
| in_channels=in_channels, | |
| out_channels=out_channels, | |
| with_residual_conv=with_residual_conv, | |
| stride=conv_stride, | |
| groups=groups) | |
| if last_fusion: | |
| self.fusion_block = ConvBlock( | |
| in_channels=out_channels, | |
| out_channels=out_channels, | |
| stride=2, | |
| with_residual_conv=True, | |
| groups=groups, | |
| drop_path_rate=drop_path_rate) | |
| else: | |
| self.fusion_block = ConvBlock( | |
| in_channels=out_channels, | |
| out_channels=out_channels, | |
| groups=groups, | |
| drop_path_rate=drop_path_rate) | |
| self.squeeze_block = FCUDown( | |
| in_channels=out_channels // expansion, | |
| out_channels=embed_dims, | |
| down_stride=down_stride, | |
| with_cls_token=with_cls_token) | |
| self.expand_block = FCUUp( | |
| in_channels=embed_dims, | |
| out_channels=out_channels // expansion, | |
| up_stride=down_stride, | |
| with_cls_token=with_cls_token) | |
| self.trans_block = TransformerEncoderLayer( | |
| embed_dims=embed_dims, | |
| num_heads=num_heads, | |
| feedforward_channels=int(embed_dims * mlp_ratio), | |
| drop_rate=drop_rate, | |
| drop_path_rate=drop_path_rate, | |
| attn_drop_rate=attn_drop_rate, | |
| qkv_bias=qkv_bias, | |
| norm_cfg=dict(type='LN', eps=1e-6)) | |
| self.down_stride = down_stride | |
| self.embed_dim = embed_dims | |
| self.last_fusion = last_fusion | |
| def forward(self, cnn_input, trans_input): | |
| x, x_conv2 = self.cnn_block(cnn_input, out_conv2=True) | |
| _, _, H, W = x_conv2.shape | |
| # Convert the feature map of conv2 to transformer embedding | |
| # and concat with class token. | |
| conv2_embedding = self.squeeze_block(x_conv2, trans_input) | |
| trans_output = self.trans_block(conv2_embedding + trans_input) | |
| # Convert the transformer output embedding to feature map | |
| trans_features = self.expand_block(trans_output, H // self.down_stride, | |
| W // self.down_stride) | |
| x = self.fusion_block( | |
| x, fusion_features=trans_features, out_conv2=False) | |
| return x, trans_output | |
| class Conformer(BaseBackbone): | |
| """Conformer backbone. | |
| A PyTorch implementation of : `Conformer: Local Features Coupling Global | |
| Representations for Visual Recognition <https://arxiv.org/abs/2105.03889>`_ | |
| Args: | |
| arch (str | dict): Conformer architecture. Defaults to 'tiny'. | |
| patch_size (int): The patch size. Defaults to 16. | |
| base_channels (int): The base number of channels in CNN network. | |
| Defaults to 64. | |
| mlp_ratio (float): The expansion ratio of FFN network in transformer | |
| block. Defaults to 4. | |
| with_cls_token (bool): Whether use class token or not. | |
| Defaults to True. | |
| drop_path_rate (float): stochastic depth rate. Defaults to 0. | |
| out_indices (Sequence | int): Output from which stages. | |
| Defaults to -1, means the last stage. | |
| init_cfg (dict, optional): Initialization config dict. | |
| Defaults to None. | |
| """ | |
| arch_zoo = { | |
| **dict.fromkeys(['t', 'tiny'], | |
| {'embed_dims': 384, | |
| 'channel_ratio': 1, | |
| 'num_heads': 6, | |
| 'depths': 12 | |
| }), | |
| **dict.fromkeys(['s', 'small'], | |
| {'embed_dims': 384, | |
| 'channel_ratio': 4, | |
| 'num_heads': 6, | |
| 'depths': 12 | |
| }), | |
| **dict.fromkeys(['b', 'base'], | |
| {'embed_dims': 576, | |
| 'channel_ratio': 6, | |
| 'num_heads': 9, | |
| 'depths': 12 | |
| }), | |
| } # yapf: disable | |
| _version = 1 | |
| def __init__(self, | |
| arch='tiny', | |
| patch_size=16, | |
| base_channels=64, | |
| mlp_ratio=4., | |
| qkv_bias=True, | |
| with_cls_token=True, | |
| drop_path_rate=0., | |
| norm_eval=True, | |
| frozen_stages=0, | |
| out_indices=-1, | |
| init_cfg=None): | |
| super().__init__(init_cfg=init_cfg) | |
| if isinstance(arch, str): | |
| arch = arch.lower() | |
| assert arch in set(self.arch_zoo), \ | |
| f'Arch {arch} is not in default archs {set(self.arch_zoo)}' | |
| self.arch_settings = self.arch_zoo[arch] | |
| else: | |
| essential_keys = { | |
| 'embed_dims', 'depths', 'num_heads', 'channel_ratio' | |
| } | |
| assert isinstance(arch, dict) and set(arch) == essential_keys, \ | |
| f'Custom arch needs a dict with keys {essential_keys}' | |
| self.arch_settings = arch | |
| self.num_features = self.embed_dims = self.arch_settings['embed_dims'] | |
| self.depths = self.arch_settings['depths'] | |
| self.num_heads = self.arch_settings['num_heads'] | |
| self.channel_ratio = self.arch_settings['channel_ratio'] | |
| if isinstance(out_indices, int): | |
| out_indices = [out_indices] | |
| assert isinstance(out_indices, Sequence), \ | |
| f'"out_indices" must by a sequence or int, ' \ | |
| f'get {type(out_indices)} instead.' | |
| for i, index in enumerate(out_indices): | |
| if index < 0: | |
| out_indices[i] = self.depths + index + 1 | |
| assert out_indices[i] >= 0, f'Invalid out_indices {index}' | |
| self.out_indices = out_indices | |
| self.norm_eval = norm_eval | |
| self.frozen_stages = frozen_stages | |
| self.with_cls_token = with_cls_token | |
| if self.with_cls_token: | |
| self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) | |
| # stochastic depth decay rule | |
| self.trans_dpr = [ | |
| x.item() for x in torch.linspace(0, drop_path_rate, self.depths) | |
| ] | |
| # Stem stage: get the feature maps by conv block | |
| self.conv1 = nn.Conv2d( | |
| 3, 64, kernel_size=7, stride=2, padding=3, | |
| bias=False) # 1 / 2 [112, 112] | |
| self.bn1 = nn.BatchNorm2d(64) | |
| self.act1 = nn.ReLU(inplace=True) | |
| self.maxpool = nn.MaxPool2d( | |
| kernel_size=3, stride=2, padding=1) # 1 / 4 [56, 56] | |
| assert patch_size % 16 == 0, 'The patch size of Conformer must ' \ | |
| 'be divisible by 16.' | |
| trans_down_stride = patch_size // 4 | |
| # To solve the issue #680 | |
| # Auto pad the feature map to be divisible by trans_down_stride | |
| self.auto_pad = AdaptivePadding(trans_down_stride, trans_down_stride) | |
| # 1 stage | |
| stage1_channels = int(base_channels * self.channel_ratio) | |
| self.conv_1 = ConvBlock( | |
| in_channels=64, | |
| out_channels=stage1_channels, | |
| with_residual_conv=True, | |
| stride=1) | |
| self.trans_patch_conv = nn.Conv2d( | |
| 64, | |
| self.embed_dims, | |
| kernel_size=trans_down_stride, | |
| stride=trans_down_stride, | |
| padding=0) | |
| self.trans_1 = TransformerEncoderLayer( | |
| embed_dims=self.embed_dims, | |
| num_heads=self.num_heads, | |
| feedforward_channels=int(self.embed_dims * mlp_ratio), | |
| drop_path_rate=self.trans_dpr[0], | |
| qkv_bias=qkv_bias, | |
| norm_cfg=dict(type='LN', eps=1e-6)) | |
| # 2~4 stage | |
| init_stage = 2 | |
| fin_stage = self.depths // 3 + 1 | |
| for i in range(init_stage, fin_stage): | |
| self.add_module( | |
| f'conv_trans_{i}', | |
| ConvTransBlock( | |
| in_channels=stage1_channels, | |
| out_channels=stage1_channels, | |
| embed_dims=self.embed_dims, | |
| conv_stride=1, | |
| with_residual_conv=False, | |
| down_stride=trans_down_stride, | |
| num_heads=self.num_heads, | |
| mlp_ratio=mlp_ratio, | |
| qkv_bias=qkv_bias, | |
| drop_path_rate=self.trans_dpr[i - 1], | |
| with_cls_token=self.with_cls_token)) | |
| stage2_channels = int(base_channels * self.channel_ratio * 2) | |
| # 5~8 stage | |
| init_stage = fin_stage # 5 | |
| fin_stage = fin_stage + self.depths // 3 # 9 | |
| for i in range(init_stage, fin_stage): | |
| if i == init_stage: | |
| conv_stride = 2 | |
| in_channels = stage1_channels | |
| else: | |
| conv_stride = 1 | |
| in_channels = stage2_channels | |
| with_residual_conv = True if i == init_stage else False | |
| self.add_module( | |
| f'conv_trans_{i}', | |
| ConvTransBlock( | |
| in_channels=in_channels, | |
| out_channels=stage2_channels, | |
| embed_dims=self.embed_dims, | |
| conv_stride=conv_stride, | |
| with_residual_conv=with_residual_conv, | |
| down_stride=trans_down_stride // 2, | |
| num_heads=self.num_heads, | |
| mlp_ratio=mlp_ratio, | |
| qkv_bias=qkv_bias, | |
| drop_path_rate=self.trans_dpr[i - 1], | |
| with_cls_token=self.with_cls_token)) | |
| stage3_channels = int(base_channels * self.channel_ratio * 2 * 2) | |
| # 9~12 stage | |
| init_stage = fin_stage # 9 | |
| fin_stage = fin_stage + self.depths // 3 # 13 | |
| for i in range(init_stage, fin_stage): | |
| if i == init_stage: | |
| conv_stride = 2 | |
| in_channels = stage2_channels | |
| with_residual_conv = True | |
| else: | |
| conv_stride = 1 | |
| in_channels = stage3_channels | |
| with_residual_conv = False | |
| last_fusion = (i == self.depths) | |
| self.add_module( | |
| f'conv_trans_{i}', | |
| ConvTransBlock( | |
| in_channels=in_channels, | |
| out_channels=stage3_channels, | |
| embed_dims=self.embed_dims, | |
| conv_stride=conv_stride, | |
| with_residual_conv=with_residual_conv, | |
| down_stride=trans_down_stride // 4, | |
| num_heads=self.num_heads, | |
| mlp_ratio=mlp_ratio, | |
| qkv_bias=qkv_bias, | |
| drop_path_rate=self.trans_dpr[i - 1], | |
| with_cls_token=self.with_cls_token, | |
| last_fusion=last_fusion)) | |
| self.fin_stage = fin_stage | |
| self.pooling = nn.AdaptiveAvgPool2d(1) | |
| self.trans_norm = nn.LayerNorm(self.embed_dims) | |
| if self.with_cls_token: | |
| trunc_normal_(self.cls_token, std=.02) | |
| def _init_weights(self, m): | |
| if isinstance(m, nn.Linear): | |
| trunc_normal_(m.weight, std=.02) | |
| if isinstance(m, nn.Linear) and m.bias is not None: | |
| nn.init.constant_(m.bias, 0) | |
| elif isinstance(m, nn.LayerNorm): | |
| nn.init.constant_(m.bias, 0) | |
| nn.init.constant_(m.weight, 1.0) | |
| elif isinstance(m, nn.Conv2d): | |
| nn.init.kaiming_normal_( | |
| m.weight, mode='fan_out', nonlinearity='relu') | |
| elif isinstance(m, nn.BatchNorm2d): | |
| nn.init.constant_(m.weight, 1.) | |
| nn.init.constant_(m.bias, 0.) | |
| if hasattr(m, 'zero_init_last_bn'): | |
| m.zero_init_last_bn() | |
| def init_weights(self): | |
| super(Conformer, self).init_weights() | |
| if (isinstance(self.init_cfg, dict) | |
| and self.init_cfg['type'] == 'Pretrained'): | |
| # Suppress default init if use pretrained model. | |
| return | |
| self.apply(self._init_weights) | |
| def forward(self, x): | |
| output = [] | |
| B = x.shape[0] | |
| if self.with_cls_token: | |
| cls_tokens = self.cls_token.expand(B, -1, -1) | |
| # stem | |
| x_base = self.maxpool(self.act1(self.bn1(self.conv1(x)))) | |
| x_base = self.auto_pad(x_base) | |
| # 1 stage [N, 64, 56, 56] -> [N, 128, 56, 56] | |
| x = self.conv_1(x_base, out_conv2=False) | |
| x_t = self.trans_patch_conv(x_base).flatten(2).transpose(1, 2) | |
| if self.with_cls_token: | |
| x_t = torch.cat([cls_tokens, x_t], dim=1) | |
| x_t = self.trans_1(x_t) | |
| # 2 ~ final | |
| for i in range(2, self.fin_stage): | |
| stage = getattr(self, f'conv_trans_{i}') | |
| x, x_t = stage(x, x_t) | |
| if i in self.out_indices: | |
| if self.with_cls_token: | |
| output.append([ | |
| self.pooling(x).flatten(1), | |
| self.trans_norm(x_t)[:, 0] | |
| ]) | |
| else: | |
| # if no class token, use the mean patch token | |
| # as the transformer feature. | |
| output.append([ | |
| self.pooling(x).flatten(1), | |
| self.trans_norm(x_t).mean(dim=1) | |
| ]) | |
| return tuple(output) | |