# Copyright (c) OpenMMLab. All rights reserved.
import itertools

import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer, fuse_conv_bn
from mmcv.cnn.bricks import DropPath
from mmengine.model import BaseModule, ModuleList, Sequential

from mmpretrain.models.backbones.base_backbone import BaseBackbone
from mmpretrain.registry import MODELS
from ..utils import build_norm_layer


class HybridBackbone(BaseModule):
    """Convolutional stem that embeds the input image into patch tokens.

    A stack of conv-BN blocks (stride 2 by default), interleaved with
    activations, that expands the channels to ``embed_dim`` while reducing
    the spatial resolution.
    """

    def __init__(
        self,
        embed_dim,
        kernel_size=3,
        stride=2,
        pad=1,
        dilation=1,
        groups=1,
        act_cfg=dict(type='HSwish'),
        conv_cfg=None,
        norm_cfg=dict(type='BN'),
        init_cfg=None,
    ):
        super(HybridBackbone, self).__init__(init_cfg=init_cfg)

        self.input_channels = [
            3, embed_dim // 8, embed_dim // 4, embed_dim // 2
        ]
        self.output_channels = [
            embed_dim // 8, embed_dim // 4, embed_dim // 2, embed_dim
        ]
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg

        self.patch_embed = Sequential()
        for i in range(len(self.input_channels)):
            conv_bn = ConvolutionBatchNorm(
                self.input_channels[i],
                self.output_channels[i],
                kernel_size=kernel_size,
                stride=stride,
                pad=pad,
                dilation=dilation,
                groups=groups,
                norm_cfg=norm_cfg,
            )
            self.patch_embed.add_module('%d' % (2 * i), conv_bn)
            if i < len(self.input_channels) - 1:
                self.patch_embed.add_module('%d' % (i * 2 + 1),
                                            build_activation_layer(act_cfg))

    def forward(self, x):
        x = self.patch_embed(x)
        return x


class ConvolutionBatchNorm(BaseModule):
    """A ``Conv2d + BatchNorm`` block that can be fused for deployment."""

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size=3,
        stride=2,
        pad=1,
        dilation=1,
        groups=1,
        norm_cfg=dict(type='BN'),
    ):
        super(ConvolutionBatchNorm, self).__init__()
        self.conv = nn.Conv2d(
            in_channel,
            out_channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=pad,
            dilation=dilation,
            groups=groups,
            bias=False)
        self.bn = build_norm_layer(norm_cfg, out_channel)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x

    def fuse(self):
        # Fold the BN statistics into the conv weights (used in deploy mode).
        return fuse_conv_bn(self).conv


class LinearBatchNorm(BaseModule):
    """A ``Linear + BatchNorm1d`` block operating on (B, N, C) tokens."""

    def __init__(self, in_feature, out_feature, norm_cfg=dict(type='BN1d')):
        super(LinearBatchNorm, self).__init__()
        self.linear = nn.Linear(in_feature, out_feature, bias=False)
        self.bn = build_norm_layer(norm_cfg, out_feature)

    def forward(self, x):
        x = self.linear(x)
        # BatchNorm1d expects a 2D input, so flatten batch and token dims.
        x = self.bn(x.flatten(0, 1)).reshape_as(x)
        return x

    def fuse(self):
        # Fold the BN affine transform into the linear layer:
        #   W' = W * gamma / sqrt(var + eps)
        #   b' = beta - mean * gamma / sqrt(var + eps)
        w = self.bn.weight / (self.bn.running_var + self.bn.eps)**0.5
        w = self.linear.weight * w[:, None]
        b = self.bn.bias - self.bn.running_mean * self.bn.weight / \
            (self.bn.running_var + self.bn.eps)**0.5

        factory_kwargs = {
            'device': self.linear.weight.device,
            'dtype': self.linear.weight.dtype
        }
        bias = nn.Parameter(
            torch.empty(self.linear.out_features, **factory_kwargs))
        self.linear.register_parameter('bias', bias)
        self.linear.weight.data.copy_(w)
        self.linear.bias.data.copy_(b)
        return self.linear


class Residual(BaseModule):
    """Residual connection wrapper with optional DropPath (stochastic depth)."""

    def __init__(self, block, drop_path_rate=0.):
        super(Residual, self).__init__()
        self.block = block
        if drop_path_rate > 0:
            self.drop_path = DropPath(drop_path_rate)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        x = x + self.drop_path(self.block(x))
        return x


class Attention(BaseModule):
    """Multi-head attention with learned relative-position attention biases."""

    def __init__(
        self,
        dim,
        key_dim,
        num_heads=8,
        attn_ratio=4,
        act_cfg=dict(type='HSwish'),
        resolution=14,
    ):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.scale = key_dim**-0.5
        self.key_dim = key_dim
        self.nh_kd = nh_kd = key_dim * num_heads
        self.d = int(attn_ratio * key_dim)
        self.dh = int(attn_ratio * key_dim) * num_heads
        self.attn_ratio = attn_ratio
        h = self.dh + nh_kd * 2
        self.qkv = LinearBatchNorm(dim, h)
        self.proj = nn.Sequential(
            build_activation_layer(act_cfg), LinearBatchNorm(self.dh, dim))

        # Build the index table of relative-position attention biases.
        points = list(itertools.product(range(resolution), range(resolution)))
        N = len(points)
        attention_offsets = {}
        idxs = []
        for p1 in points:
            for p2 in points:
                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = torch.nn.Parameter(
            torch.zeros(num_heads, len(attention_offsets)))
        self.register_buffer('attention_bias_idxs',
                             torch.LongTensor(idxs).view(N, N))

    def train(self, mode=True):
        """Change the mode of the model and cache the attention biases."""
        super(Attention, self).train(mode)
        if mode and hasattr(self, 'ab'):
            del self.ab
        else:
            self.ab = self.attention_biases[:, self.attention_bias_idxs]

    def forward(self, x):  # x: (B, N, C)
        B, N, C = x.shape
        qkv = self.qkv(x)  # (B, N, dh + 2 * nh_kd)
        q, k, v = qkv.view(B, N, self.num_heads, -1).split(
            [self.key_dim, self.key_dim, self.d], dim=3)
        q = q.permute(0, 2, 1, 3)  # (B, num_heads, N, key_dim)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)  # (B, num_heads, N, d)

        attn = ((q @ k.transpose(-2, -1)) * self.scale +
                (self.attention_biases[:, self.attention_bias_idxs]
                 if self.training else self.ab))
        attn = attn.softmax(dim=-1)  # (B, num_heads, N, N)
        x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
        x = self.proj(x)
        return x


class MLP(nn.Sequential):
    """Two-layer MLP with batch-normalized linear layers."""

    def __init__(self, embed_dim, mlp_ratio, act_cfg=dict(type='HSwish')):
        super(MLP, self).__init__()
        h = embed_dim * mlp_ratio
        self.linear1 = LinearBatchNorm(embed_dim, h)
        self.activation = build_activation_layer(act_cfg)
        self.linear2 = LinearBatchNorm(h, embed_dim)

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)
        return x


class Subsample(BaseModule):
    """Spatially subsample a token sequence by a given stride."""

    def __init__(self, stride, resolution):
        super(Subsample, self).__init__()
        self.stride = stride
        self.resolution = resolution

    def forward(self, x):
        B, _, C = x.shape
        # (B, N, C) -> (B, H, W, C)
        x = x.view(B, self.resolution, self.resolution, C)
        x = x[:, ::self.stride, ::self.stride]
        # (B, H', W', C) -> (B, N', C)
        x = x.reshape(B, -1, C)
        return x


class AttentionSubsample(nn.Sequential):
    """Attention block that also downsamples the token resolution."""

    def __init__(self,
                 in_dim,
                 out_dim,
                 key_dim,
                 num_heads=8,
                 attn_ratio=2,
                 act_cfg=dict(type='HSwish'),
                 stride=2,
                 resolution=14):
        super(AttentionSubsample, self).__init__()
        self.num_heads = num_heads
        self.scale = key_dim**-0.5
        self.key_dim = key_dim
        self.nh_kd = nh_kd = key_dim * num_heads
        self.d = int(attn_ratio * key_dim)
        self.dh = int(attn_ratio * key_dim) * self.num_heads
        self.attn_ratio = attn_ratio
        self.sub_resolution = (resolution - 1) // stride + 1
        h = self.dh + nh_kd
        self.kv = LinearBatchNorm(in_dim, h)
        self.q = nn.Sequential(
            Subsample(stride, resolution), LinearBatchNorm(in_dim, nh_kd))
        self.proj = nn.Sequential(
            build_activation_layer(act_cfg), LinearBatchNorm(self.dh, out_dim))

        self.stride = stride
        self.resolution = resolution

        # Build the index table of relative-position attention biases between
        # the subsampled queries and the full-resolution keys.
        points = list(itertools.product(range(resolution), range(resolution)))
        sub_points = list(
            itertools.product(
                range(self.sub_resolution), range(self.sub_resolution)))
        N = len(points)
        N_sub = len(sub_points)
        attention_offsets = {}
        idxs = []
        for p1 in sub_points:
            for p2 in points:
                size = 1
                offset = (abs(p1[0] * stride - p2[0] + (size - 1) / 2),
                          abs(p1[1] * stride - p2[1] + (size - 1) / 2))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = torch.nn.Parameter(
            torch.zeros(num_heads, len(attention_offsets)))
        self.register_buffer('attention_bias_idxs',
                             torch.LongTensor(idxs).view(N_sub, N))

    def train(self, mode=True):
        """Change the mode of the model and cache the attention biases."""
        super(AttentionSubsample, self).train(mode)
        if mode and hasattr(self, 'ab'):
            del self.ab
        else:
            self.ab = self.attention_biases[:, self.attention_bias_idxs]

    def forward(self, x):
        B, N, C = x.shape
        k, v = self.kv(x).view(B, N, self.num_heads,
                               -1).split([self.key_dim, self.d], dim=3)
        k = k.permute(0, 2, 1, 3)  # (B, num_heads, N, key_dim)
        v = v.permute(0, 2, 1, 3)  # (B, num_heads, N, d)
        q = self.q(x).view(B, self.sub_resolution**2, self.num_heads,
                           self.key_dim).permute(0, 2, 1, 3)

        attn = (q @ k.transpose(-2, -1)) * self.scale + \
            (self.attention_biases[:, self.attention_bias_idxs]
             if self.training else self.ab)
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh)
        x = self.proj(x)
        return x


@MODELS.register_module()
class LeViT(BaseBackbone):
    """LeViT backbone.

    A PyTorch implementation of `LeViT: A Vision Transformer in ConvNet's
    Clothing for Faster Inference <https://arxiv.org/abs/2104.01136>`_

    Modified from the official implementation:
    https://github.com/facebookresearch/LeViT

    Args:
        arch (str | dict): LeViT architecture.

            If a string, choose from '128s', '128', '192', '256' and '384'.
            If a dict, it should contain the following keys:

            - **embed_dims** (List[int]): The embed dimensions of each stage.
            - **key_dims** (List[int]): The embed dimensions of the key in the
              attention layers of each stage.
            - **num_heads** (List[int]): The number of heads in each stage.
            - **depths** (List[int]): The number of blocks in each stage.

        img_size (int): Input image size. Defaults to 224.
        patch_size (int | tuple): The patch size. Defaults to 16.
        attn_ratio (int): Ratio of hidden dimensions of the value in attention
            layers. Defaults to 2.
        mlp_ratio (int): Ratio of hidden dimensions in MLP layers.
            Defaults to 2.
        act_cfg (dict): The config of activation functions.
            Defaults to ``dict(type='HSwish')``.
        hybrid_backbone (callable): A callable object to build the patch embed
            module. Defaults to use :class:`HybridBackbone`.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, means the last stage.
        deploy (bool): Whether to switch the model structure to deployment
            mode, which fuses BatchNorm into the preceding conv/linear layers
            (see the usage sketch at the end of this file). Defaults to False.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
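
    Example:
        The following is an illustrative usage sketch (the import path
        assumes the standard ``mmpretrain`` package layout; the output shape
        assumes the default 224x224 input and ``out_indices=-1``):

        >>> import torch
        >>> from mmpretrain.models import LeViT
        >>> model = LeViT(arch='128s')
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> outputs = model(inputs)
        >>> for out in outputs:
        ...     print(out.shape)
        torch.Size([1, 384, 4, 4])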
| """ | |
| arch_zoo = { | |
| '128s': { | |
| 'embed_dims': [128, 256, 384], | |
| 'num_heads': [4, 6, 8], | |
| 'depths': [2, 3, 4], | |
| 'key_dims': [16, 16, 16], | |
| }, | |
| '128': { | |
| 'embed_dims': [128, 256, 384], | |
| 'num_heads': [4, 8, 12], | |
| 'depths': [4, 4, 4], | |
| 'key_dims': [16, 16, 16], | |
| }, | |
| '192': { | |
| 'embed_dims': [192, 288, 384], | |
| 'num_heads': [3, 5, 6], | |
| 'depths': [4, 4, 4], | |
| 'key_dims': [32, 32, 32], | |
| }, | |
| '256': { | |
| 'embed_dims': [256, 384, 512], | |
| 'num_heads': [4, 6, 8], | |
| 'depths': [4, 4, 4], | |
| 'key_dims': [32, 32, 32], | |
| }, | |
| '384': { | |
| 'embed_dims': [384, 512, 768], | |
| 'num_heads': [6, 9, 12], | |
| 'depths': [4, 4, 4], | |
| 'key_dims': [32, 32, 32], | |
| }, | |
| } | |

    def __init__(self,
                 arch,
                 img_size=224,
                 patch_size=16,
                 attn_ratio=2,
                 mlp_ratio=2,
                 act_cfg=dict(type='HSwish'),
                 hybrid_backbone=HybridBackbone,
                 out_indices=-1,
                 deploy=False,
                 drop_path_rate=0,
                 init_cfg=None):
        super(LeViT, self).__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch = self.arch_zoo[arch]
        elif isinstance(arch, dict):
            # Keys must match those used below (and in ``arch_zoo``).
            essential_keys = {'embed_dims', 'num_heads', 'depths', 'key_dims'}
            assert set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch = arch
        else:
            raise TypeError('Expect "arch" to be either a string '
                            f'or a dict, got {type(arch)}')

        self.embed_dims = self.arch['embed_dims']
        self.num_heads = self.arch['num_heads']
        self.key_dims = self.arch['key_dims']
        self.depths = self.arch['depths']
        self.num_stages = len(self.embed_dims)
        self.drop_path_rate = drop_path_rate

        self.patch_embed = hybrid_backbone(self.embed_dims[0])

        self.resolutions = []
        resolution = img_size // patch_size
        self.stages = ModuleList()
        for i, (embed_dims, key_dims, depth, num_heads) in enumerate(
                zip(self.embed_dims, self.key_dims, self.depths,
                    self.num_heads)):
            blocks = []
            if i > 0:
                downsample = AttentionSubsample(
                    in_dim=self.embed_dims[i - 1],
                    out_dim=embed_dims,
                    key_dim=key_dims,
                    num_heads=self.embed_dims[i - 1] // key_dims,
                    attn_ratio=4,
                    act_cfg=act_cfg,
                    stride=2,
                    resolution=resolution)
                blocks.append(downsample)
                resolution = downsample.sub_resolution
                if mlp_ratio > 0:
                    blocks.append(
                        Residual(
                            MLP(embed_dims, mlp_ratio, act_cfg=act_cfg),
                            self.drop_path_rate))
            self.resolutions.append(resolution)
            for _ in range(depth):
                blocks.append(
                    Residual(
                        Attention(
                            embed_dims,
                            key_dims,
                            num_heads,
                            attn_ratio=attn_ratio,
                            act_cfg=act_cfg,
                            resolution=resolution,
                        ), self.drop_path_rate))
                if mlp_ratio > 0:
                    blocks.append(
                        Residual(
                            MLP(embed_dims, mlp_ratio, act_cfg=act_cfg),
                            self.drop_path_rate))
            self.stages.append(Sequential(*blocks))

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        elif isinstance(out_indices, tuple):
            out_indices = list(out_indices)
        elif not isinstance(out_indices, list):
            raise TypeError('"out_indices" must be a list, tuple or int, '
                            f'got {type(out_indices)} instead.')
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.num_stages + index
            assert 0 <= out_indices[i] < self.num_stages, \
                f'Invalid out_indices {index}.'
        self.out_indices = out_indices

        self.deploy = False
        if deploy:
            self.switch_to_deploy()

    def switch_to_deploy(self):
        if self.deploy:
            return
        fuse_parameters(self)
        self.deploy = True

    def forward(self, x):
        x = self.patch_embed(x)
        # (B, C, H, W) -> (B, L, C)
        x = x.flatten(2).transpose(1, 2)

        outs = []
        for i, stage in enumerate(self.stages):
            x = stage(x)
            B, _, C = x.shape
            if i in self.out_indices:
                out = x.reshape(B, self.resolutions[i], self.resolutions[i], C)
                out = out.permute(0, 3, 1, 2).contiguous()
                outs.append(out)

        return tuple(outs)


def fuse_parameters(module):
    """Recursively replace fusable children (conv+BN, linear+BN) with their
    fused counterparts."""
    for child_name, child in module.named_children():
        if hasattr(child, 'fuse'):
            setattr(module, child_name, child.fuse())
        else:
            fuse_parameters(child)
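

# The block below is an illustrative smoke test, not part of the original
# module: it sketches how the backbone and its deployment-mode fusion are
# meant to be used, assuming torch, mmcv, mmengine and mmpretrain are
# importable. Run this file directly to execute it.
if __name__ == '__main__':
    model = LeViT(arch='128s')
    model.eval()
    inputs = torch.rand(1, 3, 224, 224)

    with torch.no_grad():
        out_before = model(inputs)[0]
        # Fold every BatchNorm into the preceding conv/linear layer.
        model.switch_to_deploy()
        out_fused = model(inputs)[0]

    print('output shape:', tuple(out_fused.shape))
    # Fusion is mathematically exact in eval mode, so the difference should
    # be numerical noise only.
    print('max abs diff after fusion:',
          (out_before - out_fused).abs().max().item())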