# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmengine.model import BaseModule

from mmpretrain.models.backbones.beit import BEiTTransformerEncoderLayer
from mmpretrain.registry import MODELS


@MODELS.register_module()
class BEiTV2Neck(BaseModule):
| """Neck for BEiTV2 Pre-training. | |
| This module construct the decoder for the final prediction. | |
| Args: | |
| num_layers (int): Number of encoder layers of neck. Defaults to 2. | |
| early_layers (int): The layer index of the early output from the | |
| backbone. Defaults to 9. | |
| backbone_arch (str): Vision Transformer architecture. Defaults to base. | |
| drop_rate (float): Probability of an element to be zeroed. | |
| Defaults to 0. | |
| drop_path_rate (float): stochastic depth rate. Defaults to 0. | |
| layer_scale_init_value (float): The initialization value for the | |
| learnable scaling of attention and FFN. Defaults to 0.1. | |
| use_rel_pos_bias (bool): Whether to use unique relative position bias, | |
| if False, use shared relative position bias defined in backbone. | |
| norm_cfg (dict): Config dict for normalization layer. | |
| Defaults to ``dict(type='LN')``. | |
| init_cfg (dict, optional): Initialization config dict. | |
| Defaults to None. | |
| """ | |

    arch_zoo = {
        **dict.fromkeys(
            ['b', 'base'], {
                'embed_dims': 768,
                'depth': 12,
                'num_heads': 12,
                'feedforward_channels': 3072,
            }),
        **dict.fromkeys(
            ['l', 'large'], {
                'embed_dims': 1024,
                'depth': 24,
                'num_heads': 16,
                'feedforward_channels': 4096,
            }),
    }
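
    # Note: ``backbone_arch`` may also be passed as a dict with the same keys
    # as the presets above instead of a string, e.g. a hypothetical small
    # variant (values here are for illustration only):
    #   backbone_arch=dict(
    #       embed_dims=384, depth=12, num_heads=6, feedforward_channels=1536)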

    def __init__(
        self,
        num_layers: int = 2,
        early_layers: int = 9,
        backbone_arch: str = 'base',
        drop_rate: float = 0.,
        drop_path_rate: float = 0.,
        layer_scale_init_value: float = 0.1,
        use_rel_pos_bias: bool = False,
        norm_cfg: dict = dict(type='LN', eps=1e-6),
        init_cfg: Optional[Union[dict, List[dict]]] = dict(
            type='TruncNormal', layer='Linear', std=0.02, bias=0)
    ) -> None:
        super().__init__(init_cfg=init_cfg)

        if isinstance(backbone_arch, str):
            backbone_arch = backbone_arch.lower()
            assert backbone_arch in set(self.arch_zoo), \
                (f'Arch {backbone_arch} is not in default archs '
                 f'{set(self.arch_zoo)}')
            self.arch_settings = self.arch_zoo[backbone_arch]
        else:
            essential_keys = {
                'embed_dims', 'depth', 'num_heads', 'feedforward_channels'
            }
            assert isinstance(backbone_arch, dict) and essential_keys <= set(
                backbone_arch
            ), f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = backbone_arch

        # stochastic depth decay rule
        self.early_layers = early_layers
        depth = self.arch_settings['depth']
        dpr = np.linspace(0, drop_path_rate,
                          max(depth, early_layers + num_layers))

        self.patch_aggregation = nn.ModuleList()
        for i in range(early_layers, early_layers + num_layers):
            _layer_cfg = dict(
                embed_dims=self.arch_settings['embed_dims'],
                num_heads=self.arch_settings['num_heads'],
                feedforward_channels=self.
                arch_settings['feedforward_channels'],
                drop_rate=drop_rate,
                drop_path_rate=dpr[i],
                norm_cfg=norm_cfg,
                layer_scale_init_value=layer_scale_init_value,
                window_size=None,
                use_rel_pos_bias=use_rel_pos_bias)
            self.patch_aggregation.append(
                BEiTTransformerEncoderLayer(**_layer_cfg))

        self.rescale_patch_aggregation_init_weight()

        embed_dims = self.arch_settings['embed_dims']
        _, norm = build_norm_layer(norm_cfg, embed_dims)
        self.add_module('norm', norm)

    def rescale_patch_aggregation_init_weight(self):
        """Rescale the initialized weights."""

        # Divide the residual-branch projection weights by sqrt(2 * layer_id)
        # so that deeper layers start with smaller contributions, matching
        # the depth-dependent rescaling used in BEiT.
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.patch_aggregation):
            rescale(layer.attn.proj.weight.data,
                    self.early_layers + layer_id + 1)
            rescale(layer.ffn.layers[1].weight.data,
                    self.early_layers + layer_id + 1)

    def forward(self, inputs: Tuple[torch.Tensor], rel_pos_bias: torch.Tensor,
                **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get the latent prediction and final prediction.

        Args:
            inputs (Tuple[torch.Tensor]): Features of tokens.
            rel_pos_bias (torch.Tensor): Shared relative position bias table.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:

            - ``x``: The final-layer features from the backbone, normalized
              in ``BEiTV2Neck``.
            - ``x_cls_pt``: The early-state features from the backbone,
              consisting of the final-layer cls_token and the early-state
              patch tokens, which are fed to the PatchAggregation layers in
              the neck.
        """
        early_states, x = inputs[0], inputs[1]
        x_cls_pt = torch.cat([x[:, [0]], early_states[:, 1:]], dim=1)
        for layer in self.patch_aggregation:
            x_cls_pt = layer(x_cls_pt, rel_pos_bias=rel_pos_bias)

        # shared norm
        x, x_cls_pt = self.norm(x), self.norm(x_cls_pt)

        # remove cls_token
        x = x[:, 1:]
        x_cls_pt = x_cls_pt[:, 1:]
        return x, x_cls_pt
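

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original module).
# It assumes a BEiT-base backbone that yields 196 patch tokens plus a
# cls_token per image, so the random tensors below stand in for the backbone
# outputs; ``rel_pos_bias=None`` should be acceptable here because the neck
# layers are built with ``window_size=None``. Inside mmpretrain the neck is
# normally built through the MODELS registry, so this is kept as a comment.
#
#   neck = BEiTV2Neck(num_layers=2, early_layers=9, backbone_arch='base')
#   early_states = torch.rand(2, 197, 768)  # early-layer tokens (cls + patch)
#   final_states = torch.rand(2, 197, 768)  # final-layer tokens (cls + patch)
#   x, x_cls_pt = neck((early_states, final_states), rel_pos_bias=None)
#   # x and x_cls_pt both have shape (2, 196, 768) once the cls_token is
#   # removed.
# ---------------------------------------------------------------------------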