Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import List, Optional, Tuple, Union | |
| import torch | |
| import torch.nn as nn | |
| from mmcv.cnn import build_norm_layer | |
| from mmengine.model import BaseModule | |
| from mmpretrain.registry import MODELS | |
| from ..backbones.vision_transformer import TransformerEncoderLayer | |
| from ..utils import build_2d_sincos_position_embedding | |
| class MAEPretrainDecoder(BaseModule): | |
| """Decoder for MAE Pre-training. | |
| Some of the code is borrowed from `https://github.com/facebookresearch/mae`. # noqa | |
| Args: | |
| num_patches (int): The number of total patches. Defaults to 196. | |
| patch_size (int): Image patch size. Defaults to 16. | |
| in_chans (int): The channel of input image. Defaults to 3. | |
| embed_dim (int): Encoder's embedding dimension. Defaults to 1024. | |
| decoder_embed_dim (int): Decoder's embedding dimension. | |
| Defaults to 512. | |
| decoder_depth (int): The depth of decoder. Defaults to 8. | |
| decoder_num_heads (int): Number of attention heads of decoder. | |
| Defaults to 16. | |
| mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim. | |
| Defaults to 4. | |
| norm_cfg (dict): Normalization layer. Defaults to LayerNorm. | |
| init_cfg (Union[List[dict], dict], optional): Initialization config | |
| dict. Defaults to None. | |
| Example: | |
| >>> from mmpretrain.models import MAEPretrainDecoder | |
| >>> import torch | |
| >>> self = MAEPretrainDecoder() | |
| >>> self.eval() | |
| >>> inputs = torch.rand(1, 50, 1024) | |
| >>> ids_restore = torch.arange(0, 196).unsqueeze(0) | |
| >>> level_outputs = self.forward(inputs, ids_restore) | |
| >>> print(tuple(level_outputs.shape)) | |
| (1, 196, 768) | |
| """ | |
| def __init__(self, | |
| num_patches: int = 196, | |
| patch_size: int = 16, | |
| in_chans: int = 3, | |
| embed_dim: int = 1024, | |
| decoder_embed_dim: int = 512, | |
| decoder_depth: int = 8, | |
| decoder_num_heads: int = 16, | |
| mlp_ratio: int = 4, | |
| norm_cfg: dict = dict(type='LN', eps=1e-6), | |
| predict_feature_dim: Optional[float] = None, | |
| init_cfg: Optional[Union[List[dict], dict]] = None) -> None: | |
| super().__init__(init_cfg=init_cfg) | |
| self.num_patches = num_patches | |
| # used to convert the dim of features from encoder to the dim | |
| # compatible with that of decoder | |
| self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) | |
| self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) | |
| # create new position embedding, different from that in encoder | |
| # and is not learnable | |
| self.decoder_pos_embed = nn.Parameter( | |
| torch.zeros(1, self.num_patches + 1, decoder_embed_dim), | |
| requires_grad=False) | |
| self.decoder_blocks = nn.ModuleList([ | |
| TransformerEncoderLayer( | |
| decoder_embed_dim, | |
| decoder_num_heads, | |
| int(mlp_ratio * decoder_embed_dim), | |
| qkv_bias=True, | |
| norm_cfg=norm_cfg) for _ in range(decoder_depth) | |
| ]) | |
| self.decoder_norm_name, decoder_norm = build_norm_layer( | |
| norm_cfg, decoder_embed_dim, postfix=1) | |
| self.add_module(self.decoder_norm_name, decoder_norm) | |
| # Used to map features to pixels | |
| if predict_feature_dim is None: | |
| predict_feature_dim = patch_size**2 * in_chans | |
| self.decoder_pred = nn.Linear( | |
| decoder_embed_dim, predict_feature_dim, bias=True) | |
| def init_weights(self) -> None: | |
| """Initialize position embedding and mask token of MAE decoder.""" | |
| super().init_weights() | |
| decoder_pos_embed = build_2d_sincos_position_embedding( | |
| int(self.num_patches**.5), | |
| self.decoder_pos_embed.shape[-1], | |
| cls_token=True) | |
| self.decoder_pos_embed.data.copy_(decoder_pos_embed.float()) | |
| torch.nn.init.normal_(self.mask_token, std=.02) | |
| def decoder_norm(self): | |
| """The normalization layer of decoder.""" | |
| return getattr(self, self.decoder_norm_name) | |
| def forward(self, x: torch.Tensor, | |
| ids_restore: torch.Tensor) -> torch.Tensor: | |
| """The forward function. | |
| The process computes the visible patches' features vectors and the mask | |
| tokens to output feature vectors, which will be used for | |
| reconstruction. | |
| Args: | |
| x (torch.Tensor): hidden features, which is of shape | |
| B x (L * mask_ratio) x C. | |
| ids_restore (torch.Tensor): ids to restore original image. | |
| Returns: | |
| torch.Tensor: The reconstructed feature vectors, which is of | |
| shape B x (num_patches) x C. | |
| """ | |
| # embed tokens | |
| x = self.decoder_embed(x) | |
| # append mask tokens to sequence | |
| mask_tokens = self.mask_token.repeat( | |
| x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) | |
| x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) | |
| x_ = torch.gather( | |
| x_, | |
| dim=1, | |
| index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) | |
| x = torch.cat([x[:, :1, :], x_], dim=1) | |
| # add pos embed | |
| x = x + self.decoder_pos_embed | |
| # apply Transformer blocks | |
| for blk in self.decoder_blocks: | |
| x = blk(x) | |
| x = self.decoder_norm(x) | |
| # predictor projection | |
| x = self.decoder_pred(x) | |
| # remove cls token | |
| x = x[:, 1:, :] | |
| return x | |
| class ClsBatchNormNeck(BaseModule): | |
| """Normalize cls token across batch before head. | |
| This module is proposed by MAE, when running linear probing. | |
| Args: | |
| input_features (int): The dimension of features. | |
| affine (bool): a boolean value that when set to ``True``, this module | |
| has learnable affine parameters. Defaults to False. | |
| eps (float): a value added to the denominator for numerical stability. | |
| Defaults to 1e-6. | |
| init_cfg (Dict or List[Dict], optional): Config dict for weight | |
| initialization. Defaults to None. | |
| """ | |
| def __init__(self, | |
| input_features: int, | |
| affine: bool = False, | |
| eps: float = 1e-6, | |
| init_cfg: Optional[Union[dict, List[dict]]] = None) -> None: | |
| super().__init__(init_cfg) | |
| self.bn = nn.BatchNorm1d(input_features, affine=affine, eps=eps) | |
| def forward( | |
| self, | |
| inputs: Tuple[List[torch.Tensor]]) -> Tuple[List[torch.Tensor]]: | |
| """The forward function.""" | |
| # Only apply batch norm to cls_token | |
| inputs = [self.bn(input_) for input_ in inputs] | |
| return tuple(inputs) | |