# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmengine.model.weight_init import trunc_normal_

from mmpretrain.registry import MODELS
from .vision_transformer import VisionTransformer


@MODELS.register_module()
class DistilledVisionTransformer(VisionTransformer):
| """Distilled Vision Transformer. | |
| A PyTorch implement of : `Training data-efficient image transformers & | |
| distillation through attention <https://arxiv.org/abs/2012.12877>`_ | |
| Args: | |
| arch (str | dict): Vision Transformer architecture. If use string, | |
| choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small' | |
| and 'deit-base'. If use dict, it should have below keys: | |
| - **embed_dims** (int): The dimensions of embedding. | |
| - **num_layers** (int): The number of transformer encoder layers. | |
| - **num_heads** (int): The number of heads in attention modules. | |
| - **feedforward_channels** (int): The hidden dimensions in | |
| feedforward modules. | |
| Defaults to 'deit-base'. | |
| img_size (int | tuple): The expected input image shape. Because we | |
| support dynamic input shape, just set the argument to the most | |
| common input image shape. Defaults to 224. | |
| patch_size (int | tuple): The patch size in patch embedding. | |
| Defaults to 16. | |
| in_channels (int): The num of input channels. Defaults to 3. | |
| out_indices (Sequence | int): Output from which stages. | |
| Defaults to -1, means the last stage. | |
| drop_rate (float): Probability of an element to be zeroed. | |
| Defaults to 0. | |
| drop_path_rate (float): stochastic depth rate. Defaults to 0. | |
| qkv_bias (bool): Whether to add bias for qkv in attention modules. | |
| Defaults to True. | |
| norm_cfg (dict): Config dict for normalization layer. | |
| Defaults to ``dict(type='LN')``. | |
| final_norm (bool): Whether to add a additional layer to normalize | |
| final feature map. Defaults to True. | |
| out_type (str): The type of output features. Please choose from | |
| - ``"cls_token"``: A tuple with the class token and the | |
| distillation token. The shapes of both tensor are (B, C). | |
| - ``"featmap"``: The feature map tensor from the patch tokens | |
| with shape (B, C, H, W). | |
| - ``"avg_featmap"``: The global averaged feature map tensor | |
| with shape (B, C). | |
| - ``"raw"``: The raw feature tensor includes patch tokens and | |
| class tokens with shape (B, L, C). | |
| Defaults to ``"cls_token"``. | |
| interpolate_mode (str): Select the interpolate mode for position | |
| embeding vector resize. Defaults to "bicubic". | |
| patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. | |
| layer_cfgs (Sequence | dict): Configs of each transformer layer in | |
| encoder. Defaults to an empty dict. | |
| init_cfg (dict, optional): Initialization config dict. | |
| Defaults to None. | |
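
    The example below is an illustrative sketch rather than part of the
    original docstring; it assumes the class is exported from
    ``mmpretrain.models`` as in the upstream package and that the default
    ``'deit-base'`` architecture uses an embedding dimension of 768.

    Example:
        >>> import torch
        >>> from mmpretrain.models import DistilledVisionTransformer
        >>> model = DistilledVisionTransformer(arch='deit-base')
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> # With the default ``out_type='cls_token'``, each output stage is
        >>> # a (class token, distillation token) tuple of (B, C) tensors.
        >>> cls_token, dist_token = model(inputs)[-1]
        >>> cls_token.shape
        torch.Size([1, 768])
        >>> dist_token.shape
        torch.Size([1, 768])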
| """ | |

    num_extra_tokens = 2  # class token and distillation token

    def __init__(self, arch='deit-base', *args, **kwargs):
        super(DistilledVisionTransformer, self).__init__(
            arch=arch,
            with_cls_token=True,
            *args,
            **kwargs,
        )
        # The extra distillation token, learned jointly with the class token.
        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))

    def forward(self, x):
        B = x.shape[0]
        x, patch_resolution = self.patch_embed(x)

        # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = self.cls_token.expand(B, -1, -1)
        dist_token = self.dist_token.expand(B, -1, -1)
        # Prepend the class token and the distillation token to the patch
        # tokens, then add the position embedding resized to the current
        # patch resolution.
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
        x = x + self.resize_pos_embed(
            self.pos_embed,
            self.patch_resolution,
            patch_resolution,
            mode=self.interpolate_mode,
            num_extra_tokens=self.num_extra_tokens)
        x = self.drop_after_pos(x)

        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)

            if i == len(self.layers) - 1 and self.final_norm:
                x = self.ln1(x)

            if i in self.out_indices:
                outs.append(self._format_output(x, patch_resolution))

        return tuple(outs)

    def _format_output(self, x, hw):
        if self.out_type == 'cls_token':
            # Return both the class token and the distillation token.
            return x[:, 0], x[:, 1]

        return super()._format_output(x, hw)

    def init_weights(self):
        super(DistilledVisionTransformer, self).init_weights()

        if not (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Only initialize the distillation token when not loading
            # pretrained weights.
            trunc_normal_(self.dist_token, std=0.02)
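

# Illustrative usage sketch, not part of the original module: because the
# class is registered in mmpretrain's MODELS registry (see the decorator
# above), the backbone is typically built from a config dict rather than
# instantiated directly. The shape noted below assumes the default
# 'deit-base' settings (patch_size=16, embed_dims=768) and a 224x224 input.
#
#     import torch
#     from mmpretrain.registry import MODELS
#
#     backbone = MODELS.build(
#         dict(type='DistilledVisionTransformer',
#              arch='deit-base',
#              out_type='featmap'))
#     feat = backbone(torch.rand(2, 3, 224, 224))[-1]
#     # 'featmap' output shape: (B, C, H, W) == (2, 768, 14, 14)
#     print(feat.shape)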