Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import Optional, Sequence | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from mmcv.cnn import build_activation_layer, build_norm_layer | |
| from mmcv.cnn.bricks import DropPath | |
| from mmcv.cnn.bricks.transformer import PatchEmbed | |
| from mmengine.model import BaseModule, ModuleList | |
| from mmengine.model.weight_init import trunc_normal_ | |
| from mmengine.utils import to_2tuple | |
| from ..builder import BACKBONES | |
| from ..utils import resize_pos_embed | |
| from .base_backbone import BaseBackbone | |
def resize_decomposed_rel_pos(rel_pos, q_size, k_size):
    """Gather the relative-position table entries for every (query, key) pair.

    If the table length differs from ``2 * max(q_size, k_size) - 1`` it is
    first linearly resampled to that length.

    Args:
        rel_pos (Tensor): relative position embeddings of shape (L, C).
        q_size (int): size of query q along one spatial axis.
        k_size (int): size of key k along the same axis.

    Returns:
        Tensor: embeddings of shape (q_size, k_size, C), where entry
        ``[i, j]`` is the table row for the (scaled) offset between query
        position ``i`` and key position ``j``.
    """
    target_len = int(2 * max(q_size, k_size) - 1)

    if rel_pos.shape[0] == target_len:
        table = rel_pos
    else:
        # (L, C) -> (1, C, L) so that L is the 1-D axis F.interpolate resizes.
        table = F.interpolate(
            rel_pos.transpose(0, 1).unsqueeze(0),
            size=target_len,
            mode='linear',
        )
        # (1, C, L') -> (L', C)
        table = table.squeeze(0).transpose(0, 1)

    # When q and k lengths differ, stretch the shorter axis so that both
    # coordinate grids span the same range.
    q_scale = max(k_size / q_size, 1.0)
    k_scale = max(q_size / k_size, 1.0)
    q_pos = torch.arange(q_size)[:, None] * q_scale
    k_pos = torch.arange(k_size)[None, :] * k_scale
    # Shift offsets so the most negative one maps to row 0.
    offset = (k_size - 1) * k_scale
    index = (q_pos - k_pos) + offset
    return table[index.long()]
def add_decomposed_rel_pos(attn,
                           q,
                           q_shape,
                           k_shape,
                           rel_pos_h,
                           rel_pos_w,
                           has_cls_token=False):
    """Add decomposed (height + width) spatial relative-position bias.

    Args:
        attn (Tensor): attention logits of shape
            (B, num_heads, q_len, k_len). It is updated in place and also
            returned.
        q (Tensor): queries of shape (B, num_heads, q_len, C).
        q_shape (tuple): spatial (q_h, q_w) of the query tokens.
        k_shape (tuple): spatial (k_h, k_w) of the key tokens.
        rel_pos_h (Tensor): relative-position table for the height axis.
        rel_pos_w (Tensor): relative-position table for the width axis.
        has_cls_token (bool): If True, the first query/key token is skipped
            when adding the spatial bias. Defaults to False.

    Returns:
        Tensor: ``attn`` with the bias added to its spatial part.
    """
    start = 1 if has_cls_token else 0
    batch, heads, _, dim = q.shape
    q_h, q_w = q_shape
    k_h, k_w = k_shape

    table_h = resize_decomposed_rel_pos(rel_pos_h, q_h, k_h)
    table_w = resize_decomposed_rel_pos(rel_pos_w, q_w, k_w)

    q_grid = q[:, :, start:].reshape(batch, heads, q_h, q_w, dim)
    bias_h = torch.einsum('byhwc,hkc->byhwk', q_grid, table_h)
    bias_w = torch.einsum('byhwc,wkc->byhwk', q_grid, table_w)
    # Broadcast to (B, heads, q_h, q_w, k_h, k_w).
    bias = bias_h.unsqueeze(-1) + bias_w.unsqueeze(-2)

    spatial = attn[:, :, start:, start:].view(batch, -1, q_h, q_w, k_h, k_w)
    spatial += bias
    attn[:, :, start:, start:] = spatial.view(batch, -1, q_h * q_w,
                                              k_h * k_w)
    return attn
class MLP(BaseModule):
    """Two-layer multilayer perceptron.

    Unlike :class:`mmcv.cnn.bricks.transformer.FFN`, the input and output
    channel numbers may differ.

    Args:
        in_channels (int): The number of input channels.
        hidden_channels (int, optional): The number of hidden layer channels.
            Falls back to ``in_channels`` when falsy. Defaults to None.
        out_channels (int, optional): The number of output channels. Falls
            back to ``in_channels`` when falsy. Defaults to None.
        act_cfg (dict): The config of activation function.
            Defaults to ``dict(type='GELU')``.
        init_cfg (dict, optional): The config of weight initialization.
            Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 hidden_channels=None,
                 out_channels=None,
                 act_cfg=dict(type='GELU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        # ``or`` substitutes ``in_channels`` for any falsy width (None or 0).
        hidden = hidden_channels or in_channels
        out = out_channels or in_channels
        self.fc1 = nn.Linear(in_channels, hidden)
        self.act = build_activation_layer(act_cfg)
        self.fc2 = nn.Linear(hidden, out)

    def forward(self, x):
        """Apply ``fc1 -> act -> fc2`` over the last dimension of ``x``."""
        return self.fc2(self.act(self.fc1(x)))
def attention_pool(x: torch.Tensor,
                   pool: nn.Module,
                   in_size: tuple,
                   norm: Optional[nn.Module] = None):
    """Spatially pool a sequence of feature tokens.

    The tokens are folded back into an (H, W) map, passed through ``pool``,
    then flattened into tokens again (optionally followed by ``norm``).

    Args:
        x (torch.Tensor): Tokens of shape ``(B, num_heads, L, C)`` or
            ``(B, L, C)``, where ``L == H * W``.
        pool (nn.Module): The pooling module, applied to a
            ``(B*num_heads, C, H, W)`` tensor.
        in_size (Tuple[int]): The ``(H, W)`` of the input feature map.
        norm (nn.Module, optional): Normalization applied to the pooled
            tokens. Defaults to None.

    Returns:
        tuple: the pooled tokens (same rank as ``x``) and the pooled spatial
        size ``(H', W')`` as a ``torch.Size``.

    Raises:
        RuntimeError: if ``x`` is neither 3-D nor 4-D.
    """
    if x.ndim == 4:
        batch, heads, length, channels = x.shape
    elif x.ndim == 3:
        heads = 1
        batch, length, channels = x.shape
    else:
        raise RuntimeError(f'Unsupported input dimension {x.shape}')
    was_3d = x.ndim == 3

    height, width = in_size
    assert length == height * width

    # (B, heads, H*W, C) -> (B*heads, C, H, W)
    maps = x.reshape(batch * heads, height, width, channels)
    maps = maps.permute(0, 3, 1, 2).contiguous()
    maps = pool(maps)
    out_size = maps.shape[-2:]

    # (B*heads, C, H', W') -> (B, heads, H'*W', C)
    tokens = maps.reshape(batch, heads, channels, -1).transpose(2, 3)
    if norm is not None:
        tokens = norm(tokens)
    if was_3d:
        tokens = tokens.squeeze(1)
    return tokens, out_size
class MultiScaleAttention(BaseModule):
    """Multiscale Multi-head Attention block.

    Args:
        in_dims (int): Number of input channels.
        out_dims (int): Number of output channels.
        num_heads (int): Number of attention heads.
        qkv_bias (bool): If True, add a learnable bias to query, key and
            value. Defaults to True.
        norm_cfg (dict): The config of normalization layers.
            Defaults to ``dict(type='LN')``.
        pool_kernel (tuple): kernel size for qkv pooling layers.
            Defaults to (3, 3).
        stride_q (int): stride size for q pooling layer. Defaults to 1.
        stride_kv (int): stride size for kv pooling layer. Defaults to 1.
        rel_pos_spatial (bool): Whether to enable the spatial relative
            position embedding. Defaults to False.
        residual_pooling (bool): Whether to enable the residual connection
            after attention pooling. Defaults to True.
        input_size (Tuple[int], optional): The input resolution, necessary
            if enable the ``rel_pos_spatial``. Defaults to None.
        rel_pos_zero_init (bool): If True, zero initialize relative
            positional parameters. Defaults to False.
        init_cfg (dict, optional): The config of weight initialization.
            Defaults to None.
    """

    def __init__(self,
                 in_dims,
                 out_dims,
                 num_heads,
                 qkv_bias=True,
                 norm_cfg=dict(type='LN'),
                 pool_kernel=(3, 3),
                 stride_q=1,
                 stride_kv=1,
                 rel_pos_spatial=False,
                 residual_pooling=True,
                 input_size=None,
                 rel_pos_zero_init=False,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.num_heads = num_heads
        self.in_dims = in_dims
        self.out_dims = out_dims

        head_dim = out_dims // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(in_dims, out_dims * 3, bias=qkv_bias)
        self.proj = nn.Linear(out_dims, out_dims)

        # qkv pooling: a depth-wise conv (groups == channels) per q/k/v,
        # operating on the per-head channel dimension.
        pool_padding = [k // 2 for k in pool_kernel]
        pool_dims = head_dim

        def build_pooling(stride):
            pool = nn.Conv2d(
                pool_dims,
                pool_dims,
                pool_kernel,
                stride=stride,
                padding=pool_padding,
                groups=pool_dims,
                bias=False,
            )
            norm = build_norm_layer(norm_cfg, pool_dims)[1]
            return pool, norm

        self.pool_q, self.norm_q = build_pooling(stride_q)
        self.pool_k, self.norm_k = build_pooling(stride_kv)
        self.pool_v, self.norm_v = build_pooling(stride_kv)

        self.residual_pooling = residual_pooling

        self.rel_pos_spatial = rel_pos_spatial
        self.rel_pos_zero_init = rel_pos_zero_init
        if self.rel_pos_spatial:
            # initialize relative positional embeddings
            assert input_size is not None, \
                '`input_size` is required when `rel_pos_spatial` is True.'
            assert input_size[0] == input_size[1]

            size = input_size[0]
            rel_dim = 2 * max(size // stride_q, size // stride_kv) - 1
            self.rel_pos_h = nn.Parameter(torch.zeros(rel_dim, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(rel_dim, head_dim))

    def init_weights(self):
        """Weight initialization."""
        super().init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress rel_pos_zero_init if use pretrained model.
            return

        # ``rel_pos_h``/``rel_pos_w`` are only created when
        # ``rel_pos_spatial`` is enabled; touching them otherwise would
        # raise AttributeError.
        if self.rel_pos_spatial and not self.rel_pos_zero_init:
            trunc_normal_(self.rel_pos_h, std=0.02)
            trunc_normal_(self.rel_pos_w, std=0.02)

    def forward(self, x, in_size):
        """Forward the MultiScaleAttention.

        Args:
            x (Tensor): tokens of shape (B, H*W, in_dims).
            in_size (tuple): the spatial size (H, W) of ``x``.

        Returns:
            tuple: output tokens of shape (B, H'*W', out_dims) and the
            pooled query spatial size (H', W').
        """
        B, N, _ = x.shape  # (B, H*W, C)
        # qkv: (B, H*W, 3, num_heads, C)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1)
        # q, k, v: (B, num_heads, H*W, C)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        q, q_shape = attention_pool(q, self.pool_q, in_size, norm=self.norm_q)
        k, k_shape = attention_pool(k, self.pool_k, in_size, norm=self.norm_k)
        v, _ = attention_pool(v, self.pool_v, in_size, norm=self.norm_v)

        attn = (q * self.scale) @ k.transpose(-2, -1)
        if self.rel_pos_spatial:
            attn = add_decomposed_rel_pos(attn, q, q_shape, k_shape,
                                          self.rel_pos_h, self.rel_pos_w)

        attn = attn.softmax(dim=-1)
        x = attn @ v

        if self.residual_pooling:
            x = x + q

        # (B, num_heads, H'*W', C'//num_heads) -> (B, H'*W', C')
        x = x.transpose(1, 2).reshape(B, -1, self.out_dims)
        x = self.proj(x)

        return x, q_shape
class MultiScaleBlock(BaseModule):
    """Multiscale Transformer block: pooled attention followed by an MLP.

    Args:
        in_dims (int): Number of input channels.
        out_dims (int): Number of output channels.
        num_heads (int): Number of attention heads.
        mlp_ratio (float): Ratio of hidden dimensions in MLP layers.
            Defaults to 4.0.
        qkv_bias (bool): If True, add a learnable bias to query, key and
            value. Defaults to True.
        drop_path (float): Stochastic depth rate. Defaults to 0.
        norm_cfg (dict): The config of normalization layers.
            Defaults to ``dict(type='LN')``.
        act_cfg (dict): The config of activation function.
            Defaults to ``dict(type='GELU')``.
        qkv_pool_kernel (tuple): kernel size for qkv pooling layers.
            Defaults to (3, 3).
        stride_q (int): stride size for q pooling layer. Defaults to 1.
        stride_kv (int): stride size for kv pooling layer. Defaults to 1.
        rel_pos_spatial (bool): Whether to enable the spatial relative
            position embedding. Defaults to True.
        residual_pooling (bool): Whether to enable the residual connection
            after attention pooling. Defaults to True.
        dim_mul_in_attention (bool): Whether to multiply the ``embed_dims`` in
            attention layers. If False, multiply it in MLP layers.
            Defaults to True.
        input_size (Tuple[int], optional): The input resolution, necessary
            if enable the ``rel_pos_spatial``. Defaults to None.
        rel_pos_zero_init (bool): If True, zero initialize relative
            positional parameters. Defaults to False.
        init_cfg (dict, optional): The config of weight initialization.
            Defaults to None.
    """

    def __init__(
        self,
        in_dims,
        out_dims,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=True,
        drop_path=0.0,
        norm_cfg=dict(type='LN'),
        act_cfg=dict(type='GELU'),
        qkv_pool_kernel=(3, 3),
        stride_q=1,
        stride_kv=1,
        rel_pos_spatial=True,
        residual_pooling=True,
        dim_mul_in_attention=True,
        input_size=None,
        rel_pos_zero_init=False,
        init_cfg=None,
    ):
        super().__init__(init_cfg=init_cfg)
        self.in_dims = in_dims
        self.out_dims = out_dims
        self.norm1 = build_norm_layer(norm_cfg, in_dims)[1]
        self.dim_mul_in_attention = dim_mul_in_attention

        # Channel width inside attention; the MLP maps it to ``out_dims``.
        if dim_mul_in_attention:
            attn_dims = out_dims
        else:
            attn_dims = in_dims

        # NOTE: submodules are created in a fixed order (it determines the
        # state_dict key order and the RNG draws during init).
        self.attn = MultiScaleAttention(
            in_dims,
            attn_dims,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            norm_cfg=norm_cfg,
            pool_kernel=qkv_pool_kernel,
            stride_q=stride_q,
            stride_kv=stride_kv,
            rel_pos_spatial=rel_pos_spatial,
            residual_pooling=residual_pooling,
            input_size=input_size,
            rel_pos_zero_init=rel_pos_zero_init)
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        self.norm2 = build_norm_layer(norm_cfg, attn_dims)[1]
        self.mlp = MLP(
            in_channels=attn_dims,
            hidden_channels=int(attn_dims * mlp_ratio),
            out_channels=out_dims,
            act_cfg=act_cfg)

        # Linear shortcut projection, only needed when the width changes.
        self.proj = nn.Linear(in_dims, out_dims) \
            if in_dims != out_dims else None

        if stride_q > 1:
            # Max-pool the shortcut so it matches the strided query tokens.
            skip_kernel = stride_q + 1
            self.pool_skip = nn.MaxPool2d(
                skip_kernel, stride_q, int(skip_kernel // 2), ceil_mode=False)
            # NOTE(review): this uses floor division, while kernel-3/pad-1
            # stride-2 pooling yields ceil(H / 2) for odd H. Confirm that odd
            # resolutions are not expected here.
            if input_size is None:
                self.init_out_size = None
            else:
                self.init_out_size = [
                    side // stride_q for side in to_2tuple(input_size)
                ]
        else:
            self.pool_skip = None
            self.init_out_size = input_size

    def forward(self, x, in_size):
        """Run attention and MLP sub-layers, each with a residual shortcut.

        Returns the output tokens and the (possibly pooled) spatial size.
        """
        normed = self.norm1(x)
        attn_out, out_size = self.attn(normed, in_size)

        project_in_attn = self.dim_mul_in_attention and self.proj is not None
        shortcut = self.proj(normed) if project_in_attn else x
        if self.pool_skip is not None:
            shortcut, _ = attention_pool(shortcut, self.pool_skip, in_size)
        x = shortcut + self.drop_path(attn_out)

        normed = self.norm2(x)
        mlp_out = self.mlp(normed)

        project_in_mlp = (not self.dim_mul_in_attention
                          and self.proj is not None)
        shortcut = self.proj(normed) if project_in_mlp else x
        return shortcut + self.drop_path(mlp_out), out_size
# NOTE(review): ``BACKBONES`` is imported in this file but no registration of
# this class is visible here; confirm it is registered elsewhere, otherwise
# ``build_backbone(dict(type='MViT'))`` cannot find it.
class MViT(BaseBackbone):
    """Multi-scale ViT v2.

    A PyTorch implement of : `MViTv2: Improved Multiscale Vision Transformers
    for Classification and Detection <https://arxiv.org/abs/2112.01526>`_

    Inspiration from `the official implementation
    <https://github.com/facebookresearch/mvit>`_ and `the detectron2
    implementation <https://github.com/facebookresearch/detectron2>`_

    Args:
        arch (str | dict): MViT architecture. If use string, choose
            from 'tiny', 'small', 'base' and 'large'. If use dict, it should
            have below keys:

            - **embed_dims** (int): The dimensions of embedding.
            - **num_layers** (int): The number of layers.
            - **num_heads** (int): The number of heads in attention
              modules of the initial layer.
            - **downscale_indices** (List[int]): The layer indices to downscale
              the feature map.

            Defaults to 'base'.
        img_size (int): The expected input image shape. Defaults to 224.
        in_channels (int): The num of input channels. Defaults to 3.
        out_scales (int | Sequence[int]): The output scale indices.
            They should not exceed the length of ``downscale_indices``.
            Negative values count from the last scale. Defaults to -1,
            which means the last scale.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        use_abs_pos_embed (bool): If True, add absolute position embedding to
            the patch embedding. Defaults to False.
        interpolate_mode (str): Select the interpolate mode for absolute
            position embedding vector resize. Defaults to "bicubic".
        pool_kernel (tuple): kernel size for qkv pooling layers.
            Defaults to (3, 3).
        dim_mul (int): The magnification for ``embed_dims`` in the downscale
            layers. Defaults to 2.
        head_mul (int): The magnification for ``num_heads`` in the downscale
            layers. Defaults to 2.
        adaptive_kv_stride (int): The stride size for kv pooling in the initial
            layer. Defaults to 4.
        rel_pos_spatial (bool): Whether to enable the spatial relative position
            embedding. Defaults to True.
        residual_pooling (bool): Whether to enable the residual connection
            after attention pooling. Defaults to True.
        dim_mul_in_attention (bool): Whether to multiply the ``embed_dims`` in
            attention layers. If False, multiply it in MLP layers.
            Defaults to True.
        rel_pos_zero_init (bool): If True, zero initialize relative
            positional parameters. Defaults to False.
        mlp_ratio (float): Ratio of hidden dimensions in MLP layers.
            Defaults to 4.0.
        qkv_bias (bool): enable bias for qkv if True. Defaults to True.
        norm_cfg (dict): Config dict for normalization layer for all output
            features. Defaults to ``dict(type='LN', eps=1e-6)``.
        patch_cfg (dict): Config dict for the patch embedding layer.
            Defaults to ``dict(kernel_size=7, stride=4, padding=3)``.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.

    Examples:
        >>> import torch
        >>> from mmpretrain.models import build_backbone
        >>>
        >>> cfg = dict(type='MViT', arch='tiny', out_scales=[0, 1, 2, 3])
        >>> model = build_backbone(cfg)
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> outputs = model(inputs)
        >>> for i, output in enumerate(outputs):
        >>>     print(f'scale{i}: {output.shape}')
        scale0: torch.Size([1, 96, 56, 56])
        scale1: torch.Size([1, 192, 28, 28])
        scale2: torch.Size([1, 384, 14, 14])
        scale3: torch.Size([1, 768, 7, 7])
    """
    arch_zoo = {
        'tiny': {
            'embed_dims': 96,
            'num_layers': 10,
            'num_heads': 1,
            'downscale_indices': [1, 3, 8]
        },
        'small': {
            'embed_dims': 96,
            'num_layers': 16,
            'num_heads': 1,
            'downscale_indices': [1, 3, 14]
        },
        'base': {
            'embed_dims': 96,
            'num_layers': 24,
            'num_heads': 1,
            'downscale_indices': [2, 5, 21]
        },
        'large': {
            'embed_dims': 144,
            'num_layers': 48,
            'num_heads': 2,
            'downscale_indices': [2, 8, 44]
        },
    }
    num_extra_tokens = 0

    def __init__(self,
                 arch='base',
                 img_size=224,
                 in_channels=3,
                 out_scales=-1,
                 drop_path_rate=0.,
                 use_abs_pos_embed=False,
                 interpolate_mode='bicubic',
                 pool_kernel=(3, 3),
                 dim_mul=2,
                 head_mul=2,
                 adaptive_kv_stride=4,
                 rel_pos_spatial=True,
                 residual_pooling=True,
                 dim_mul_in_attention=True,
                 rel_pos_zero_init=False,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 norm_cfg=dict(type='LN', eps=1e-6),
                 patch_cfg=dict(kernel_size=7, stride=4, padding=3),
                 init_cfg=None):
        super().__init__(init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {
                'embed_dims', 'num_layers', 'num_heads', 'downscale_indices'
            }
            assert isinstance(arch, dict) and essential_keys <= set(arch), \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims = self.arch_settings['embed_dims']
        self.num_layers = self.arch_settings['num_layers']
        self.num_heads = self.arch_settings['num_heads']
        self.downscale_indices = self.arch_settings['downscale_indices']
        self.num_scales = len(self.downscale_indices) + 1
        # Map "last layer of a stage" -> stage id: the layer right before each
        # downscale layer closes a stage, and the final layer closes the last.
        self.stage_indices = {
            index - 1: i
            for i, index in enumerate(self.downscale_indices)
        }
        self.stage_indices[self.num_layers - 1] = self.num_scales - 1
        self.use_abs_pos_embed = use_abs_pos_embed
        self.interpolate_mode = interpolate_mode

        if isinstance(out_scales, int):
            out_scales = [out_scales]
        assert isinstance(out_scales, Sequence), \
            f'"out_scales" must by a sequence or int, ' \
            f'get {type(out_scales)} instead.'
        # Build a fresh list so tuples are accepted and the caller's
        # sequence is never mutated.
        normalized_scales = []
        for index in out_scales:
            scale = self.num_scales + index if index < 0 else index
            # Valid stage ids are 0 .. num_scales - 1 (see stage_indices).
            assert 0 <= scale < self.num_scales, \
                f'Invalid out_scales {index}'
            normalized_scales.append(scale)
        self.out_scales = sorted(normalized_scales)

        # Set patch embedding
        _patch_cfg = dict(
            in_channels=in_channels,
            input_size=img_size,
            embed_dims=self.embed_dims,
            conv_type='Conv2d',
        )
        _patch_cfg.update(patch_cfg)
        self.patch_embed = PatchEmbed(**_patch_cfg)
        self.patch_resolution = self.patch_embed.init_out_size

        # Set absolute position embedding
        if self.use_abs_pos_embed:
            num_patches = self.patch_resolution[0] * self.patch_resolution[1]
            self.pos_embed = nn.Parameter(
                torch.zeros(1, num_patches, self.embed_dims))

        # stochastic depth decay rule
        dpr = np.linspace(0, drop_path_rate, self.num_layers)

        self.blocks = ModuleList()
        out_dims_list = [self.embed_dims]
        num_heads = self.num_heads
        stride_kv = adaptive_kv_stride
        input_size = self.patch_resolution
        for i in range(self.num_layers):
            if i in self.downscale_indices:
                num_heads *= head_mul
                stride_q = 2
                stride_kv = max(stride_kv // 2, 1)
            else:
                stride_q = 1

            # Set output embed_dims
            if dim_mul_in_attention and i in self.downscale_indices:
                # multiply embed_dims in downscale layers.
                out_dims = out_dims_list[-1] * dim_mul
            elif not dim_mul_in_attention and i + 1 in self.downscale_indices:
                # multiply embed_dims before downscale layers.
                out_dims = out_dims_list[-1] * dim_mul
            else:
                out_dims = out_dims_list[-1]

            attention_block = MultiScaleBlock(
                in_dims=out_dims_list[-1],
                out_dims=out_dims,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_path=dpr[i],
                norm_cfg=norm_cfg,
                qkv_pool_kernel=pool_kernel,
                stride_q=stride_q,
                stride_kv=stride_kv,
                rel_pos_spatial=rel_pos_spatial,
                residual_pooling=residual_pooling,
                dim_mul_in_attention=dim_mul_in_attention,
                input_size=input_size,
                rel_pos_zero_init=rel_pos_zero_init)
            self.blocks.append(attention_block)

            input_size = attention_block.init_out_size
            out_dims_list.append(out_dims)

            if i in self.stage_indices:
                stage_index = self.stage_indices[i]
                if stage_index in self.out_scales:
                    norm_layer = build_norm_layer(norm_cfg, out_dims)[1]
                    self.add_module(f'norm{stage_index}', norm_layer)

    def init_weights(self):
        """Initialize weights, skipping ``pos_embed`` for pretrained loads."""
        super().init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress default init if use pretrained model.
            return

        if self.use_abs_pos_embed:
            trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        """Forward the MViT.

        Returns:
            tuple[Tensor]: one ``(B, C, H, W)`` feature map per requested
            scale in ``self.out_scales``, in ascending scale order.
        """
        x, patch_resolution = self.patch_embed(x)

        if self.use_abs_pos_embed:
            x = x + resize_pos_embed(
                self.pos_embed,
                self.patch_resolution,
                patch_resolution,
                mode=self.interpolate_mode,
                num_extra_tokens=self.num_extra_tokens)

        outs = []
        for i, block in enumerate(self.blocks):
            x, patch_resolution = block(x, patch_resolution)

            if i in self.stage_indices:
                stage_index = self.stage_indices[i]
                if stage_index in self.out_scales:
                    B, _, C = x.shape
                    # Normalize only the emitted feature; the token stream
                    # fed to later blocks stays untouched, so deeper outputs
                    # do not depend on which earlier scales were requested.
                    out = getattr(self, f'norm{stage_index}')(x)
                    out = out.transpose(1, 2).reshape(B, C, *patch_resolution)
                    outs.append(out.contiguous())

        return tuple(outs)