from collections import OrderedDict
from typing import Callable, Optional, Union

from einops import rearrange

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from timm.models.layers import to_2tuple, trunc_normal_, DropPath

from .attention_mask import get_attention_mask


class LayerScale(nn.Module):
    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        return x.mul_(self.gamma) if self.inplace else x * self.gamma
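
# Usage sketch (illustrative): LayerScale multiplies each channel of a residual
# branch by a learned gain initialized near zero, so new branches start close to
# the identity. It is enabled in ResidualAttentionBlock below only when
# ls_init_value is given, e.g.
#   ls = LayerScale(dim=512, init_values=1e-5)
#   y = ls(torch.randn(4, 16, 512))   # same shape; each channel scaled by ls.gamma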


class ResidualAttentionBlock(nn.Module):
    def __init__(
            self,
            d_model: int,
            n_head: int,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            drop: float = 0.,
            attn_drop: float = 0.,
            drop_path: float = 0.,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = nn.LayerNorm,
            use_preln: bool = True,
    ):
        super().__init__()

        self.ln_1 = norm_layer(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_head, dropout=attn_drop)
        self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()

        self.ln_2 = norm_layer(d_model)
        mlp_width = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, mlp_width)),
            ("gelu", act_layer()),
            ("drop1", nn.Dropout(drop)),  # dropout after the activation (mirrors "drop2" below)
            ("c_proj", nn.Linear(mlp_width, d_model)),
            ("drop2", nn.Dropout(drop)),
        ]))
        self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.use_preln = use_preln

    def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False):
        attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask, is_causal=is_causal)[0]

    def checkpoint_forward(self, x: torch.Tensor,
                           attn_mask: Optional[torch.Tensor] = None,
                           is_causal: bool = False):
        state = x
        if self.use_preln:
            x = checkpoint(self.ln_1, x, use_reentrant=False)
            x = self.attention(x, attn_mask, is_causal)
            x = checkpoint(self.ls_1, x, use_reentrant=False)
            state = state + self.drop_path(x)
            x = checkpoint(self.ln_2, state, use_reentrant=False)
            x = self.mlp(x)
            x = checkpoint(self.ls_2, x, use_reentrant=False)
            state = state + self.drop_path(x)
        else:
            x = self.attention(x, attn_mask, is_causal)
            x = state + self.drop_path(x)
            state = checkpoint(self.ln_1, x, use_reentrant=False)
            x = self.mlp(state)
            state = state + self.drop_path(x)
            state = checkpoint(self.ln_2, state, use_reentrant=False)
        return state

    def forward(self, x: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False,
                selective_checkpointing: bool = False):
        if selective_checkpointing:
            return self.checkpoint_forward(x, attn_mask, is_causal=is_causal)
        if self.use_preln:
            x = x + self.drop_path(self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask, is_causal=is_causal)))
            x = x + self.drop_path(self.ls_2(self.mlp(self.ln_2(x))))
        else:
            x = x + self.drop_path(self.attention(x, attn_mask=attn_mask, is_causal=is_causal))
            x = self.ln_1(x)
            x = x + self.drop_path(self.mlp(x))
            x = self.ln_2(x)
        return x
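
# Note on the two orderings above: with use_preln=True the block is "pre-LN"
# (normalize, then attention/MLP, then add back to the residual stream), which is
# the more stable choice for deep stacks; with use_preln=False it is "post-LN"
# (normalize after each residual addition). The selective checkpointing path
# recomputes only the cheap LayerNorm/LayerScale sublayers during the backward
# pass, trading a little extra compute for lower activation memory, while the
# attention and MLP outputs are still stored.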


class Transformer(nn.Module):
    def __init__(self,
                 width: int,
                 layers: int,
                 heads: int,
                 mlp_ratio: float = 4.0,
                 ls_init_value: float = None,
                 drop: float = 0.,
                 attn_drop: float = 0.,
                 drop_path: float = 0.,
                 act_layer: nn.Module = nn.GELU,
                 norm_layer: nn.Module = nn.LayerNorm,
                 use_preln: bool = True,
                 ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.grad_checkpointing = False
        self.selective_checkpointing = False
        self.grad_checkpointing_params = {'use_reentrant': False}
        # RNG state only needs to be preserved under checkpointing when some form of
        # dropout is active.
        if drop == 0 and attn_drop == 0 and drop_path == 0:
            self.grad_checkpointing_params.update({'preserve_rng_state': False})
        else:
            self.grad_checkpointing_params.update({'preserve_rng_state': True})

        self.resblocks = nn.ModuleList([
            ResidualAttentionBlock(
                width, heads, mlp_ratio, ls_init_value=ls_init_value,
                drop=drop, attn_drop=attn_drop, drop_path=drop_path,
                act_layer=act_layer, norm_layer=norm_layer,
                use_preln=use_preln)
            for _ in range(layers)
        ])

    def forward(self, x: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None,
                is_causal: bool = False):
        for r in self.resblocks:
            if self.training and self.grad_checkpointing and not torch.jit.is_scripting():
                if not self.selective_checkpointing:
                    x = checkpoint(r, x, attn_mask, is_causal=is_causal, **self.grad_checkpointing_params)
                else:
                    x = r(x, attn_mask=attn_mask, is_causal=is_causal, selective_checkpointing=True)
            else:
                x = r(x, attn_mask=attn_mask, is_causal=is_causal)
        return x
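
# Usage sketch (illustrative): gradient checkpointing is toggled through the owning
# encoder/decoder rather than on Transformer directly, e.g.
#   model = TransformerEncoder(...)
#   model.set_grad_checkpointing(True)                  # checkpoint each block
#   model.set_grad_checkpointing(True, selective=True)  # checkpoint only LN/LayerScale
# Either mode takes effect only in training mode and outside TorchScript.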


class TransformerEncoder(nn.Module):
    def __init__(self,
                 image_size: int,
                 patch_size: int,
                 width: int,
                 layers: int,
                 heads: int,
                 mlp_ratio: float,
                 num_frames: int = 1,
                 cross_frames: bool = True,
                 ls_init_value: float = None,
                 drop_rate: float = 0.,
                 attn_drop_rate: float = 0.,
                 drop_path_rate: float = 0.,
                 ln_pre: bool = True,
                 ln_post: bool = True,
                 act_layer: str = 'gelu',
                 norm_layer: str = 'layer_norm',
                 mask_type: Union[str, None] = 'none',
                 mask_block_size: int = -1
                 ):
        super().__init__()
        self.image_size = to_2tuple(image_size)
        self.patch_size = to_2tuple(patch_size)
        self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1])
        self.patches_per_frame = self.grid_size[0] * self.grid_size[1]
        self.mask_type = mask_type
        self.mask_block_size = mask_block_size

        if act_layer.lower() == 'gelu':
            self.act_layer = nn.GELU
        else:
            raise ValueError(f"Unsupported activation function: {act_layer}")
        if norm_layer.lower() == 'layer_norm':
            self.norm_layer = nn.LayerNorm
        else:
            raise ValueError(f"Unsupported normalization: {norm_layer}")

        self.conv1 = nn.Linear(
            in_features=3 * self.patch_size[0] * self.patch_size[1],
            out_features=width,
            bias=not ln_pre
        )

        scale = width ** -0.5
        self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1], width))
        assert num_frames >= 1
        self.num_frames = num_frames
        self.cross_frames = cross_frames
        if num_frames > 1 and cross_frames:
            self.temporal_positional_embedding = nn.Parameter(torch.zeros(num_frames, width))
        else:
            self.temporal_positional_embedding = None

        self.ln_pre = self.norm_layer(width) if ln_pre else nn.Identity()

        self.transformer = Transformer(
            width, layers, heads, mlp_ratio, ls_init_value=ls_init_value,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate,
            act_layer=self.act_layer, norm_layer=self.norm_layer,
        )

        self.ln_post = self.norm_layer(width) if ln_post else nn.Identity()

        self.init_parameters()

    def init_parameters(self):
        if self.positional_embedding is not None:
            nn.init.normal_(self.positional_embedding, std=0.02)
        trunc_normal_(self.conv1.weight, std=0.02)
        for block in self.transformer.resblocks:
            for n, p in block.named_parameters():
                if 'weight' in n:
                    if 'ln' not in n:
                        trunc_normal_(p, std=0.02)
                elif 'bias' in n:
                    nn.init.zeros_(p)
                else:
                    raise NotImplementedError(f'Unknown parameter named {n}')

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True, selective=False):
        self.transformer.grad_checkpointing = enable
        self.transformer.selective_checkpointing = selective

    def forward(self, x):
        if self.num_frames == 1:
            # Single image: (b, c, H, W) -> (b, patches, patch_dim)
            x = rearrange(
                x, "b c (hh sh) (ww sw) -> b (hh ww) (c sh sw)",
                sh=self.patch_size[0], sw=self.patch_size[1]
            )
            x = self.conv1(x)
            x = x + self.positional_embedding.to(x.dtype)
        elif self.cross_frames:
            # Joint space-time attention: flatten all frames into one token sequence.
            num_frames = x.shape[2]
            assert num_frames <= self.num_frames, 'Number of frames should be less than or equal to the model setting'
            x = rearrange(
                x, "b c t (hh sh) (ww sw) -> b (t hh ww) (c sh sw)",
                sh=self.patch_size[0], sw=self.patch_size[1]
            )
            x = self.conv1(x)
            tile_pos_embed = self.positional_embedding.repeat(num_frames, 1)
            tile_tem_embed = self.temporal_positional_embedding[:num_frames].repeat_interleave(self.patches_per_frame, 0)
            total_pos_embed = tile_pos_embed + tile_tem_embed
            x = x + total_pos_embed.to(x.dtype).squeeze(0)
        else:
            # Frame-independent encoding: fold the frame axis into the batch dimension.
            x = rearrange(
                x, "b c t (hh sh) (ww sw) -> (b t) (hh ww) (c sh sw)",
                sh=self.patch_size[0], sw=self.patch_size[1]
            )
            x = self.conv1(x)
            x = x + self.positional_embedding.to(x.dtype)

        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # (batch, seq, dim) -> (seq, batch, dim) for nn.MultiheadAttention
        block_size = self.grid_size[0] * self.grid_size[1] if self.mask_block_size <= 0 else self.mask_block_size
        attn_mask = get_attention_mask(x.size(0), x.device, mask_type=self.mask_type, block_size=block_size)
        x = self.transformer(x, attn_mask, is_causal=self.mask_type == 'causal')
        x = x.permute(1, 0, 2)
        x = self.ln_post(x)

        return x
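
# Shape sketch (illustrative; single-frame configuration):
#   enc = TransformerEncoder(image_size=224, patch_size=16, width=768, layers=12,
#                            heads=12, mlp_ratio=4.0)
#   tokens = enc(torch.randn(2, 3, 224, 224))   # -> (2, 196, 768), one token per 16x16 patch
# With num_frames > 1 and cross_frames=True the input is (b, c, t, H, W) and the
# output sequence length becomes t * 196.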


class TransformerDecoder(nn.Module):
    def __init__(self,
                 image_size: int,
                 patch_size: int,
                 width: int,
                 layers: int,
                 heads: int,
                 mlp_ratio: float,
                 num_frames: int = 1,
                 cross_frames: bool = True,
                 ls_init_value: float = None,
                 drop_rate: float = 0.,
                 attn_drop_rate: float = 0.,
                 drop_path_rate: float = 0.,
                 ln_pre: bool = True,
                 ln_post: bool = True,
                 act_layer: str = 'gelu',
                 norm_layer: str = 'layer_norm',
                 use_ffn_output: bool = True,
                 dim_ffn_output: int = 3072,
                 logit_laplace: bool = False,
                 mask_type: Union[str, None] = 'none',
                 mask_block_size: int = -1
                 ):
        super().__init__()
        self.image_size = to_2tuple(image_size)
        self.patch_size = to_2tuple(patch_size)
        self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1])
        self.patches_per_frame = self.grid_size[0] * self.grid_size[1]
        self.mask_type = mask_type
        self.mask_block_size = mask_block_size

        if act_layer.lower() == 'gelu':
            self.act_layer = nn.GELU
        else:
            raise ValueError(f"Unsupported activation function: {act_layer}")
        if norm_layer.lower() == 'layer_norm':
            self.norm_layer = nn.LayerNorm
        else:
            raise ValueError(f"Unsupported normalization: {norm_layer}")

        self.use_ffn_output = use_ffn_output
        if use_ffn_output:
            self.ffn = nn.Sequential(
                nn.Linear(width, dim_ffn_output),
                nn.Tanh(),
            )
            self.conv_out = nn.Linear(
                in_features=dim_ffn_output,
                out_features=3 * self.patch_size[0] * self.patch_size[1] * (1 + logit_laplace)
            )
        else:
            self.ffn = nn.Identity()
            self.conv_out = nn.Linear(
                in_features=width,
                out_features=3 * self.patch_size[0] * self.patch_size[1] * (1 + logit_laplace)
            )

        scale = width ** -0.5
        self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1], width))
        assert num_frames >= 1
        self.num_frames = num_frames
        self.cross_frames = cross_frames
        if num_frames > 1 and cross_frames:
            self.temporal_positional_embedding = nn.Parameter(torch.zeros(num_frames, width))
        else:
            self.temporal_positional_embedding = None

        self.ln_pre = self.norm_layer(width) if ln_pre else nn.Identity()

        self.transformer = Transformer(
            width, layers, heads, mlp_ratio, ls_init_value=ls_init_value,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate,
            act_layer=self.act_layer, norm_layer=self.norm_layer,
        )

        self.ln_post = self.norm_layer(width) if ln_post else nn.Identity()

        self.init_parameters()

    def init_parameters(self):
        if self.positional_embedding is not None:
            nn.init.normal_(self.positional_embedding, std=0.02)

        for block in self.transformer.resblocks:
            for n, p in block.named_parameters():
                if 'weight' in n:
                    if 'ln' not in n:
                        trunc_normal_(p, std=0.02)
                elif 'bias' in n:
                    nn.init.zeros_(p)
                else:
                    raise NotImplementedError(f'Unknown parameter named {n}')
        if self.use_ffn_output:
            trunc_normal_(self.ffn[0].weight, std=0.02)
        trunc_normal_(self.conv_out.weight, std=0.02)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True, selective=False):
        self.transformer.grad_checkpointing = enable
        self.transformer.selective_checkpointing = selective

    def forward(self, x):
        if self.num_frames == 1 or not self.cross_frames:
            # Single-frame input, or frames folded into the batch dimension: every token
            # gets the same spatial positional embedding. In the folded case the per-sample
            # frame count is not recoverable here, so fall back to the configured value for
            # the final rearrange below.
            num_frames = self.num_frames
            x = x + self.positional_embedding.to(x.dtype)
        else:
            num_frames = x.shape[1] // self.patches_per_frame
            assert num_frames <= self.num_frames, 'Number of frames should be less than or equal to the model setting'
            tile_pos_embed = self.positional_embedding.repeat(num_frames, 1)
            tile_tem_embed = self.temporal_positional_embedding[:num_frames].repeat_interleave(self.patches_per_frame, 0)
            total_pos_embed = tile_pos_embed + tile_tem_embed
            x = x + total_pos_embed.to(x.dtype).squeeze(0)
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)
        block_size = self.grid_size[0] * self.grid_size[1] if self.mask_block_size <= 0 else self.mask_block_size
        attn_mask = get_attention_mask(x.size(0), x.device, mask_type=self.mask_type, block_size=block_size)
        x = self.transformer(x, attn_mask, is_causal=self.mask_type == 'causal')
        x = x.permute(1, 0, 2)
        x = self.ln_post(x)
        x = self.ffn(x)
        x = self.conv_out(x)
        if self.num_frames == 1:
            x = rearrange(
                x, "b (hh ww) (c sh sw) -> b c (hh sh) (ww sw)",
                hh=self.grid_size[0], ww=self.grid_size[1],
                sh=self.patch_size[0], sw=self.patch_size[1]
            )
        elif self.cross_frames:
            x = rearrange(
                x, "b (t hh ww) (c sh sw) -> b c t (hh sh) (ww sw)",
                t=num_frames, hh=self.grid_size[0], ww=self.grid_size[1],
                sh=self.patch_size[0], sw=self.patch_size[1]
            )
        else:
            x = rearrange(
                x, "(b t) (hh ww) (c sh sw) -> b c t (hh sh) (ww sw)",
                t=num_frames, hh=self.grid_size[0], ww=self.grid_size[1],
                sh=self.patch_size[0], sw=self.patch_size[1]
            )

        return x
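

if __name__ == "__main__":
    # Minimal smoke test (illustrative only). It assumes the file is run as a module
    # inside its package (e.g. `python -m <package>.<module>`) so the relative import
    # of get_attention_mask resolves, and that mask_type='none' produces a mask that
    # nn.MultiheadAttention accepts (or None). Sizes are arbitrary example values.
    encoder = TransformerEncoder(image_size=224, patch_size=16, width=384,
                                 layers=2, heads=6, mlp_ratio=4.0)
    decoder = TransformerDecoder(image_size=224, patch_size=16, width=384,
                                 layers=2, heads=6, mlp_ratio=4.0, dim_ffn_output=1536)
    images = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        tokens = encoder(images)   # expected: (2, 196, 384)
        recon = decoder(tokens)    # expected: (2, 3, 224, 224)
    print(tokens.shape, recon.shape)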