import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, List, NamedTuple, Optional, Tuple
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.hub import load_state_dict_from_url
from einops import rearrange
from ..utils import Conv2dNormActivation, MLP, _log_api_usage_once
weights = {
"vit_b_16": "https://download.pytorch.org/models/vit_b_16-c867db91.pth",
"vit_b_32": "https://download.pytorch.org/models/vit_b_32-d86f8d99.pth",
"vit_l_16": "https://download.pytorch.org/models/vit_l_16-852ce7e3.pth",
"vit_l_32": "https://download.pytorch.org/models/vit_l_32-c7638314.pth",
"vit_h_14": "https://download.pytorch.org/models/vit_h_14-6kbcf7eb.pth",
}
class ConvStemConfig(NamedTuple):
out_channels: int
kernel_size: int
stride: int
norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d
activation_layer: Callable[..., nn.Module] = nn.ReLU
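# Hedged example of a convolutional stem specification (the channel values are illustrative
# only and are not tied to any of the checkpoints referenced above). Four stride-2 convolutions
# give a total stride of 16, matching the patch grid expected by the /16 models:
#
#   example_stem = [
#       ConvStemConfig(out_channels=64, kernel_size=3, stride=2),
#       ConvStemConfig(out_channels=128, kernel_size=3, stride=2),
#       ConvStemConfig(out_channels=256, kernel_size=3, stride=2),
#       ConvStemConfig(out_channels=512, kernel_size=3, stride=2),
#   ]
#   model = VisionTransformer(image_size=224, patch_size=16, num_layers=12, num_heads=12,
#                             hidden_dim=768, mlp_dim=3072, conv_stem_configs=example_stem)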
class MLPBlock(MLP):
"""Transformer MLP block."""
_version = 2
def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.normal_(m.bias, std=1e-6)
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
version = local_metadata.get("version", None)
if version is None or version < 2:
# Replacing legacy MLPBlock with MLP. See https://github.com/pytorch/vision/pull/6053
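            # In the new MLP the two Linear layers sit at sequential indices 0 and 3
            # (Linear, GELU, Dropout, Linear, Dropout), hence the f"{3*i}" target keys below.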
for i in range(2):
for type in ["weight", "bias"]:
old_key = f"{prefix}linear_{i+1}.{type}"
new_key = f"{prefix}{3*i}.{type}"
if old_key in state_dict:
state_dict[new_key] = state_dict.pop(old_key)
super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
class EncoderBlock(nn.Module):
"""Transformer encoder block."""
def __init__(
self,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
dropout: float,
attention_dropout: float,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
super().__init__()
self.num_heads = num_heads
# Attention block
self.ln_1 = norm_layer(hidden_dim)
self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
self.dropout = nn.Dropout(dropout)
# MLP block
self.ln_2 = norm_layer(hidden_dim)
self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)
def forward(self, input: Tensor):
torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
x = self.ln_1(input)
x, _ = self.self_attention(x, x, x, need_weights=False)
x = self.dropout(x)
x = x + input
y = self.ln_2(x)
y = self.mlp(y)
return x + y
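# Shape contract (sketch): an EncoderBlock maps (batch, seq_length, hidden_dim) to a tensor of
# the same shape via pre-LayerNorm residual attention and MLP sub-blocks, e.g.:
#
#   block = EncoderBlock(num_heads=12, hidden_dim=768, mlp_dim=3072, dropout=0.0, attention_dropout=0.0)
#   block(torch.randn(2, 197, 768)).shape  # torch.Size([2, 197, 768])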
class Encoder(nn.Module):
"""Transformer Model Encoder for sequence to sequence translation."""
def __init__(
self,
num_h_patches: int,
num_w_patches: int,
num_layers: int,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
dropout: float,
attention_dropout: float,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
super().__init__()
self.num_h_patches = num_h_patches
self.num_w_patches = num_w_patches
        # Note that batch_size is on the first dim because
        # we construct nn.MultiheadAttention with batch_first=True in EncoderBlock
seq_length = num_h_patches * num_w_patches + 1 # +1 for the class token
self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)) # from BERT
self.dropout = nn.Dropout(dropout)
layers: OrderedDict[str, nn.Module] = OrderedDict()
for i in range(num_layers):
layers[f"encoder_layer_{i}"] = EncoderBlock(
num_heads,
hidden_dim,
mlp_dim,
dropout,
attention_dropout,
norm_layer,
)
self.layers = nn.Sequential(layers)
self.ln = norm_layer(hidden_dim)
def _get_pos_embedding(self, n_h: int, n_w: int) -> Tensor:
if n_h == self.num_h_patches and n_w == self.num_w_patches:
return self.pos_embedding
else:
pos_embedding = self.pos_embedding[:, 1:, :]
pos_embedding = rearrange(pos_embedding, "1 (h w) d -> 1 d h w", h=self.num_h_patches, w=self.num_w_patches)
pos_embedding = F.interpolate(pos_embedding, size=(n_h, n_w), mode="bicubic")
pos_embedding = rearrange(pos_embedding, "1 d h w -> 1 (h w) d")
return torch.cat([self.pos_embedding[:, :1, :], pos_embedding], dim=1)
def forward(self, input: Tensor, n_h: int, n_w: int) -> Tensor:
torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
input = input + self._get_pos_embedding(n_h, n_w)
return self.ln(self.layers(self.dropout(input)))
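# Worked example of the positional-embedding resizing above: for a model built with a 14x14
# patch grid (224px inputs, patch 16), feeding a 384px image produces a 24x24 grid, so
# _get_pos_embedding bicubic-resizes the stored 14x14 grid to 24x24 and re-attaches the
# class-token embedding, yielding a (1, 24 * 24 + 1, hidden_dim) tensor.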
class VisionTransformer(nn.Module):
"""Vision Transformer as a feature extractor."""
def __init__(
self,
image_size: int,
patch_size: int,
num_layers: int,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
dropout: float = 0.0,
attention_dropout: float = 0.0,
# num_classes: int = 1000, # No need for the classification head as we only need the features
reduction: Optional[int] = None,
representation_size: Optional[int] = None,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
conv_stem_configs: Optional[List[ConvStemConfig]] = None,
):
super().__init__()
_log_api_usage_once(self)
torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
self.image_size = image_size
self.patch_size = patch_size
self.hidden_dim = hidden_dim
self.mlp_dim = mlp_dim
self.attention_dropout = attention_dropout
self.dropout = dropout
# self.num_classes = num_classes
self.representation_size = representation_size
self.norm_layer = norm_layer
if conv_stem_configs is not None:
# As per https://arxiv.org/abs/2106.14881
seq_proj = nn.Sequential()
prev_channels = 3
for i, conv_stem_layer_config in enumerate(conv_stem_configs):
seq_proj.add_module(
f"conv_bn_relu_{i}",
Conv2dNormActivation(
in_channels=prev_channels,
out_channels=conv_stem_layer_config.out_channels,
kernel_size=conv_stem_layer_config.kernel_size,
stride=conv_stem_layer_config.stride,
norm_layer=conv_stem_layer_config.norm_layer,
activation_layer=conv_stem_layer_config.activation_layer,
),
)
prev_channels = conv_stem_layer_config.out_channels
seq_proj.add_module(
"conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1)
)
self.conv_proj: nn.Module = seq_proj
else:
self.conv_proj = nn.Conv2d(
in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
)
seq_length = (image_size // patch_size) ** 2
# Add a class token
self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
seq_length += 1
self.encoder = Encoder(
image_size // patch_size,
image_size // patch_size,
num_layers,
num_heads,
hidden_dim,
mlp_dim,
dropout,
attention_dropout,
norm_layer,
)
self.seq_length = seq_length
# heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
# if representation_size is None:
# heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
# else:
# heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
# heads_layers["act"] = nn.Tanh()
# heads_layers["head"] = nn.Linear(representation_size, num_classes)
# self.heads = nn.Sequential(heads_layers)
if isinstance(self.conv_proj, nn.Conv2d):
# Init the patchify stem
fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
if self.conv_proj.bias is not None:
nn.init.zeros_(self.conv_proj.bias)
elif self.conv_proj.conv_last is not None and isinstance(self.conv_proj.conv_last, nn.Conv2d):
# Init the last 1x1 conv of the conv stem
nn.init.normal_(
self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels)
)
if self.conv_proj.conv_last.bias is not None:
nn.init.zeros_(self.conv_proj.conv_last.bias)
# if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear):
# fan_in = self.heads.pre_logits.in_features
# nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
# nn.init.zeros_(self.heads.pre_logits.bias)
# if isinstance(self.heads.head, nn.Linear):
# nn.init.zeros_(self.heads.head.weight)
# nn.init.zeros_(self.heads.head.bias)
self.encoder_reduction = self.patch_size
self.reduction = self.encoder_reduction if reduction is None else reduction
self.channels = hidden_dim
def _process_input(self, x: Tensor) -> Tuple[Tensor, int, int, int]:
# (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
x = self.conv_proj(x)
n, _, n_h, n_w = x.shape
# (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
x = x.reshape(n, self.hidden_dim, n_h * n_w)
# (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
# The self attention layer expects inputs in the format (N, S, E)
# where S is the source sequence length, N is the batch size, E is the
# embedding dimension
x = x.permute(0, 2, 1)
return x, n, n_h, n_w
def forward(self, x: Tensor) -> Tensor:
# Reshape and permute the input tensor
x, n, n_h, n_w = self._process_input(x)
# Expand the class token to the full batch
batch_class_token = self.class_token.expand(n, -1, -1)
x = torch.cat([batch_class_token, x], dim=1)
x = self.encoder(x, n_h, n_w) # Allows input image to be of any size.
# Classifier "token" as used by standard language architectures
# x = x[:, 0]
# x = self.heads(x)
x = x[:, 1:, :]
x = rearrange(x, "n (h w) d -> n d h w", h=n_h, w=n_w)
if self.encoder_reduction != self.reduction:
x = F.interpolate(x, scale_factor=self.encoder_reduction / self.reduction, mode="bilinear")
return x # To be consistent with timm models
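# Usage sketch (no pretrained weights involved; shapes assume ViT-B/16 hyper-parameters):
#
#   model = VisionTransformer(image_size=224, patch_size=16, num_layers=12, num_heads=12,
#                             hidden_dim=768, mlp_dim=3072)
#   model(torch.randn(1, 3, 224, 224)).shape  # torch.Size([1, 768, 14, 14])
#   model(torch.randn(1, 3, 256, 256)).shape  # torch.Size([1, 768, 16, 16]), pos. emb. interpolated
#
# Passing reduction=8 instead would bilinearly upsample the 14x14 map to 28x28 for a 224px input.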
def _vision_transformer(
patch_size: int,
num_layers: int,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
weights: str,
**kwargs: Any,
) -> VisionTransformer:
    image_size = kwargs.pop("image_size", 224)
    progress = kwargs.pop("progress", True)
model = VisionTransformer(
image_size=image_size,
patch_size=patch_size,
num_layers=num_layers,
num_heads=num_heads,
hidden_dim=hidden_dim,
mlp_dim=mlp_dim,
**kwargs,
)
    if weights is not None:
        state_dict = load_state_dict_from_url(weights, progress=progress)
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
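        # Note: the torchvision checkpoints referenced above still contain classification-head
        # tensors (e.g. heads.head.weight); with strict=False these are reported as unexpected
        # keys, and when a conv stem is configured its conv_proj.* tensors are reported as missing.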
if len(missing_keys) > 0:
print(f"Missing keys: {missing_keys}")
if len(unexpected_keys) > 0:
print(f"Unexpected keys: {unexpected_keys}")
return model
def interpolate_embeddings(
image_size: int,
patch_size: int,
pos_embedding: Tensor,
interpolation_mode: str = "bicubic",
) -> Tensor:
"""This function helps interpolate positional embeddings during checkpoint loading,
especially when you want to apply a pre-trained model on images with different resolution.
Args:
image_size (int): Image size of the new model.
patch_size (int): Patch size of the new model.
model_state (OrderedDict[str, Tensor]): State dict of the pre-trained model.
interpolation_mode (str): The algorithm used for upsampling. Default: bicubic.
reset_heads (bool): If true, not copying the state of heads. Default: False.
Returns:
Tensor: The interpolated positional embedding.
"""
# Shape of pos_embedding is (1, seq_length, hidden_dim)
n, seq_length, hidden_dim = pos_embedding.shape
if n != 1:
raise ValueError(f"Unexpected position embedding shape: {pos_embedding.shape}")
new_seq_length = (image_size // patch_size) ** 2 + 1
# Need to interpolate the weights for the position embedding.
# We do this by reshaping the positions embeddings to a 2d grid, performing
# an interpolation in the (h, w) space and then reshaping back to a 1d grid.
if new_seq_length != seq_length:
# The class token embedding shouldn't be interpolated, so we split it up.
seq_length -= 1
new_seq_length -= 1
pos_embedding_token = pos_embedding[:, :1, :]
pos_embedding_img = pos_embedding[:, 1:, :]
# (1, seq_length, hidden_dim) -> (1, hidden_dim, seq_length)
pos_embedding_img = pos_embedding_img.permute(0, 2, 1)
seq_length_1d = int(math.sqrt(seq_length))
if seq_length_1d * seq_length_1d != seq_length:
raise ValueError(
f"seq_length is not a perfect square! Instead got seq_length_1d * seq_length_1d = {seq_length_1d * seq_length_1d } and seq_length = {seq_length}"
)
# (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d)
pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d)
new_seq_length_1d = image_size // patch_size
# Perform interpolation.
# (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d)
new_pos_embedding_img = nn.functional.interpolate(
pos_embedding_img,
size=new_seq_length_1d,
mode=interpolation_mode,
)
# (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length)
new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length)
# (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim)
new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1)
new_pos_embedding = torch.cat([pos_embedding_token, new_pos_embedding_img], dim=1)
return new_pos_embedding
return pos_embedding
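# Worked example (sketch): a checkpoint trained at 224px with patch_size=16 stores a
# (1, 197, hidden_dim) positional embedding (14 * 14 patches + 1 class token). Re-targeting it
# to image_size=384 gives 384 // 16 = 24, i.e. a (1, 24 * 24 + 1, hidden_dim) result:
#
#   new_pe = interpolate_embeddings(384, 16, old_pe)  # old_pe: (1, 197, 768) -> new_pe: (1, 577, 768)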
def vit_b_16(
image_size: int = 224,
reduction: int = 16,
**kwargs: Any,
) -> VisionTransformer:
vit = _vision_transformer(
patch_size=16,
num_layers=12,
num_heads=12,
hidden_dim=768,
mlp_dim=3072,
weights=weights["vit_b_16"],
reduction=reduction,
**kwargs,
)
    if image_size != 224:
        vit.image_size = image_size
        new_pos_embedding = interpolate_embeddings(image_size, 16, vit.state_dict()["encoder.pos_embedding"], "bicubic")
        vit.encoder.pos_embedding = nn.Parameter(new_pos_embedding, requires_grad=True)
        # Keep the encoder's stored patch grid in sync with the resized positional embedding,
        # otherwise Encoder._get_pos_embedding would reshape against the old 224px grid.
        vit.encoder.num_h_patches = vit.encoder.num_w_patches = image_size // 16
return vit
def vit_b_32(
image_size: int = 224,
reduction: int = 32,
**kwargs: Any,
) -> VisionTransformer:
vit = _vision_transformer(
patch_size=32,
num_layers=12,
num_heads=12,
hidden_dim=768,
mlp_dim=3072,
weights=weights["vit_b_32"],
reduction=reduction,
**kwargs,
)
    if image_size != 224:
        vit.image_size = image_size
        new_pos_embedding = interpolate_embeddings(image_size, 32, vit.state_dict()["encoder.pos_embedding"], "bicubic")
        vit.encoder.pos_embedding = nn.Parameter(new_pos_embedding, requires_grad=True)
        # Keep the stored patch grid in sync with the resized positional embedding.
        vit.encoder.num_h_patches = vit.encoder.num_w_patches = image_size // 32
return vit
def vit_l_16(
image_size: int = 224,
reduction: int = 16,
**kwargs: Any,
) -> VisionTransformer:
vit = _vision_transformer(
patch_size=16,
num_layers=24,
num_heads=16,
hidden_dim=1024,
mlp_dim=4096,
weights=weights["vit_l_16"],
reduction=reduction,
**kwargs,
)
    if image_size != 224:
        vit.image_size = image_size
        new_pos_embedding = interpolate_embeddings(image_size, 16, vit.state_dict()["encoder.pos_embedding"], "bicubic")
        vit.encoder.pos_embedding = nn.Parameter(new_pos_embedding, requires_grad=True)
        # Keep the stored patch grid in sync with the resized positional embedding.
        vit.encoder.num_h_patches = vit.encoder.num_w_patches = image_size // 16
return vit
def vit_l_32(
image_size: int = 224,
reduction: int = 32,
**kwargs: Any,
) -> VisionTransformer:
vit = _vision_transformer(
patch_size=32,
num_layers=24,
num_heads=16,
hidden_dim=1024,
mlp_dim=4096,
weights=weights["vit_l_32"],
reduction=reduction,
**kwargs,
)
    if image_size != 224:
        vit.image_size = image_size
        new_pos_embedding = interpolate_embeddings(image_size, 32, vit.state_dict()["encoder.pos_embedding"], "bicubic")
        vit.encoder.pos_embedding = nn.Parameter(new_pos_embedding, requires_grad=True)
        # Keep the stored patch grid in sync with the resized positional embedding.
        vit.encoder.num_h_patches = vit.encoder.num_w_patches = image_size // 32
return vit
def vit_h_14(
image_size: int = 224,
reduction: int = 14,
**kwargs: Any,
) -> VisionTransformer:
vit = _vision_transformer(
patch_size=14,
num_layers=32,
num_heads=16,
hidden_dim=1280,
mlp_dim=5120,
weights=weights["vit_h_14"],
reduction=reduction,
**kwargs,
)
    if image_size != 224:
        vit.image_size = image_size
        new_pos_embedding = interpolate_embeddings(image_size, 14, vit.state_dict()["encoder.pos_embedding"], "bicubic")
        vit.encoder.pos_embedding = nn.Parameter(new_pos_embedding, requires_grad=True)
        # Keep the stored patch grid in sync with the resized positional embedding.
        vit.encoder.num_h_patches = vit.encoder.num_w_patches = image_size // 14
return vit
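if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the public API): build an untrained ViT-B/16
    # backbone directly so nothing is downloaded, and check the spatial layout of the features.
    model = VisionTransformer(image_size=224, patch_size=16, num_layers=12, num_heads=12,
                              hidden_dim=768, mlp_dim=3072)
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 224, 224))
    print(feats.shape)  # expected: torch.Size([1, 768, 14, 14])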