#!/usr/bin/env python3
# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Code modified from
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py ;
# https://github.com/facebookresearch/deit/blob/main/models.py
# and https://github.com/facebookresearch/vissl/blob/main/vissl/models/trunks/vision_transformer.py


from functools import partial
from typing import Callable, List, Optional

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, trunc_normal_


class Attention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version,
        # can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # Project to q, k, v and reshape to (3, B, num_heads, N, head_dim)
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)

        # Scaled dot-product attention over the token dimension
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class MultiheadAttention(nn.MultiheadAttention):
    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
        return super().forward(x, x, x, need_weights=False, attn_mask=attn_mask)[0]


class ViTAttention(Attention):
    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
        assert attn_mask is None
        return super().forward(x)


class BlockWithMasking(nn.Module):
    def __init__(
        self,
        dim: int,
        attn_target: Callable,
        mlp_ratio: int = 4,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        ffn_dropout_rate: float = 0.0,
        drop_path: float = 0.0,
        layer_scale_type: Optional[str] = None,
        layer_scale_init_value: float = 1e-4,
    ):
        super().__init__()

        assert not isinstance(
            attn_target, nn.Module
        ), "attn_target should be a Callable. Otherwise attn_target is shared across blocks!"
        self.attn = attn_target()
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm_1 = norm_layer(dim)
        mlp_hidden_dim = int(mlp_ratio * dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=ffn_dropout_rate,
        )
        self.norm_2 = norm_layer(dim)
        self.layer_scale_type = layer_scale_type
        if self.layer_scale_type is not None:
            assert self.layer_scale_type in [
                "per_channel",
                "scalar",
            ], f"Found Layer scale type {self.layer_scale_type}"
            if self.layer_scale_type == "per_channel":
                # one gamma value per channel
                gamma_shape = [1, 1, dim]
            elif self.layer_scale_type == "scalar":
                # single gamma value for all channels
                gamma_shape = [1, 1, 1]
            # two gammas: for each part of the fwd in the encoder
            self.layer_scale_gamma1 = nn.Parameter(
                torch.ones(size=gamma_shape) * layer_scale_init_value,
                requires_grad=True,
            )
            self.layer_scale_gamma2 = nn.Parameter(
                torch.ones(size=gamma_shape) * layer_scale_init_value,
                requires_grad=True,
            )
    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
        if self.layer_scale_type is None:
            x = x + self.drop_path(self.attn(self.norm_1(x), attn_mask))
            x = x + self.drop_path(self.mlp(self.norm_2(x)))
        else:
            # Apply the learned layer-scale gammas to each residual branch
            x = (
                x
                + self.drop_path(self.attn(self.norm_1(x), attn_mask))
                * self.layer_scale_gamma1
            )
            x = x + self.drop_path(self.mlp(self.norm_2(x))) * self.layer_scale_gamma2
        return x


_LAYER_NORM = partial(nn.LayerNorm, eps=1e-6)


class SimpleTransformer(nn.Module):
    def __init__(
        self,
        attn_target: Callable,
        embed_dim: int,
        num_blocks: int,
        block: Callable = BlockWithMasking,
        pre_transformer_layer: Optional[Callable] = None,
        post_transformer_layer: Optional[Callable] = None,
        drop_path_rate: float = 0.0,
        drop_path_type: str = "progressive",
        norm_layer: Callable = _LAYER_NORM,
        mlp_ratio: int = 4,
        ffn_dropout_rate: float = 0.0,
        layer_scale_type: Optional[str] = None,  # from cait; possible values are None, "per_channel", "scalar"
        layer_scale_init_value: float = 1e-4,  # from cait; float
        weight_init_style: str = "jax",  # possible values jax or pytorch
    ):
| """ | |
| Simple Transformer with the following features | |
| 1. Supports masked attention | |
| 2. Supports DropPath | |
| 3. Supports LayerScale | |
| 4. Supports Dropout in Attention and FFN | |
| 5. Makes few assumptions about the input except that it is a Tensor | |
| """ | |
        super().__init__()
        self.pre_transformer_layer = pre_transformer_layer
        if drop_path_type == "progressive":
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_blocks)]
        elif drop_path_type == "uniform":
            dpr = [drop_path_rate for i in range(num_blocks)]
        else:
            raise ValueError(f"Unknown drop_path_type: {drop_path_type}")

        self.blocks = nn.Sequential(
            *[
                block(
                    dim=embed_dim,
                    attn_target=attn_target,
                    mlp_ratio=mlp_ratio,
                    ffn_dropout_rate=ffn_dropout_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    layer_scale_type=layer_scale_type,
                    layer_scale_init_value=layer_scale_init_value,
                )
                for i in range(num_blocks)
            ]
        )
        self.post_transformer_layer = post_transformer_layer
        self.weight_init_style = weight_init_style
        self.apply(self._init_weights)
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            if self.weight_init_style == "jax":
                # Based on MAE and official Jax ViT implementation
                torch.nn.init.xavier_uniform_(m.weight)
            elif self.weight_init_style == "pytorch":
                # PyTorch ViT uses trunc_normal_
                trunc_normal_(m.weight, std=0.02)

            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
    def forward(
        self,
        tokens: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        use_checkpoint: bool = False,
        checkpoint_every_n: int = 1,
        checkpoint_blk_ids: Optional[List[int]] = None,
        out_layers: Optional[List[int]] = None,
    ):
| """ | |
| Inputs | |
| - tokens: data of shape N x L x D (or L x N x D depending on the attention implementation) | |
| - attn: mask of shape L x L | |
| Output | |
| - x: data of shape N x L x D (or L x N x D depending on the attention implementation) | |
| """ | |
        if out_layers is None:
            out_layers = []
        out_tokens = []
        if self.pre_transformer_layer:
            tokens = self.pre_transformer_layer(tokens)
        if use_checkpoint and checkpoint_blk_ids is None:
            checkpoint_blk_ids = [
                blk_id
                for blk_id in range(len(self.blocks))
                if blk_id % checkpoint_every_n == 0
            ]
        if checkpoint_blk_ids:
            checkpoint_blk_ids = set(checkpoint_blk_ids)
        for blk_id, blk in enumerate(self.blocks):
            if use_checkpoint and blk_id in checkpoint_blk_ids:
                tokens = checkpoint.checkpoint(
                    blk, tokens, attn_mask, use_reentrant=False
                )
            else:
                tokens = blk(tokens, attn_mask=attn_mask)
            if blk_id in out_layers:
                out_tokens.append(tokens)
        if self.post_transformer_layer:
            tokens = self.post_transformer_layer(tokens)
        return tokens, out_tokens
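

if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Minimal usage sketch (illustrative addition, not part of the original
    # module): build a small SimpleTransformer whose blocks wrap
    # nn.MultiheadAttention, then run a forward pass with a causal attention
    # mask. All hyperparameter values below are arbitrary placeholders; token
    # shapes follow the default batch_first=False convention of
    # nn.MultiheadAttention, i.e. tokens are L x N x D.
    # ------------------------------------------------------------------
    embed_dim, num_heads, num_blocks = 64, 4, 2
    seq_len, batch_size = 16, 2

    # attn_target must be a Callable (not a module), so every block gets its
    # own attention instance.
    attn_target = partial(
        MultiheadAttention,
        embed_dim=embed_dim,
        num_heads=num_heads,
        bias=True,
    )
    model = SimpleTransformer(
        attn_target=attn_target,
        embed_dim=embed_dim,
        num_blocks=num_blocks,
        drop_path_rate=0.1,
        layer_scale_type="per_channel",
    )

    tokens = torch.randn(seq_len, batch_size, embed_dim)
    # Upper-triangular -inf mask: position i may not attend to positions > i.
    causal_mask = torch.triu(
        torch.full((seq_len, seq_len), float("-inf")), diagonal=1
    )
    out, intermediates = model(tokens, attn_mask=causal_mask, out_layers=[0])
    print(out.shape, len(intermediates))  # torch.Size([16, 2, 64]) 1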