# VTBench/src/vqvaes/bsqvit/transformer.py
from collections import OrderedDict
from typing import Callable, Optional, Union
from einops import rearrange
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from timm.models.layers import to_2tuple
from timm.models.layers import trunc_normal_
from timm.models.layers import DropPath
from .attention_mask import get_attention_mask


class LayerScale(nn.Module):
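    """Learnable per-channel scaling of a residual branch (LayerScale)."""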
def __init__(self, dim, init_values=1e-5, inplace=False):
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x):
return x.mul_(self.gamma) if self.inplace else x * self.gamma


class ResidualAttentionBlock(nn.Module):
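    """Transformer block: multi-head self-attention plus an MLP, each in a residual
    branch with optional LayerScale and DropPath. `use_preln` chooses between
    pre-LayerNorm and post-LayerNorm ordering."""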
def __init__(
self,
d_model: int,
n_head: int,
mlp_ratio: float = 4.0,
            ls_init_value: Optional[float] = None,
drop: float = 0.,
attn_drop: float = 0.,
drop_path: float = 0.,
act_layer: Callable = nn.GELU,
norm_layer: Callable = nn.LayerNorm,
use_preln: bool = True,
):
super().__init__()
self.ln_1 = norm_layer(d_model)
self.attn = nn.MultiheadAttention(d_model, n_head, dropout=attn_drop)
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
self.ln_2 = norm_layer(d_model)
mlp_width = int(d_model * mlp_ratio)
self.mlp = nn.Sequential(OrderedDict([
("c_fc", nn.Linear(d_model, mlp_width)),
("gelu", act_layer()),
# disable this following JAX implementation.
# Reference: https://github.com/google-research/magvit/blob/main/videogvt/models/simplified_bert.py#L112
# ("drop1", nn.Dropout(drop)),
("c_proj", nn.Linear(mlp_width, d_model)),
("drop2", nn.Dropout(drop)),
]))
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.use_preln = use_preln
def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False):
attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None
return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask, is_causal=is_causal)[0]
def checkpoint_forward(self, x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False):
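        """Forward pass with selective activation checkpointing: the cheap LayerNorm
        (and LayerScale) sub-modules are wrapped in torch.utils.checkpoint so their
        outputs are recomputed in backward, while the attention and MLP activations
        are stored as usual."""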
state = x
if self.use_preln:
x = checkpoint(self.ln_1, x, use_reentrant=False)
x = self.attention(x, attn_mask, is_causal)
x = checkpoint(self.ls_1, x, use_reentrant=False)
state = state + self.drop_path(x)
x = checkpoint(self.ln_2, state, use_reentrant=False)
x = self.mlp(x)
x = checkpoint(self.ls_2, x, use_reentrant=False)
state = state + self.drop_path(x)
else:
x = self.attention(x, attn_mask, is_causal)
x = state + self.drop_path(x)
state = checkpoint(self.ln_1, x, use_reentrant=False)
x = self.mlp(state)
state = state + self.drop_path(x)
state = checkpoint(self.ln_2, state, use_reentrant=False)
return state
def forward(self, x: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False,
selective_checkpointing: bool = False):
if selective_checkpointing:
return self.checkpoint_forward(x, attn_mask, is_causal=is_causal)
if self.use_preln:
x = x + self.drop_path(self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask, is_causal=is_causal)))
x = x + self.drop_path(self.ls_2(self.mlp(self.ln_2(x))))
else:
x = x + self.drop_path(self.attention(x, attn_mask=attn_mask, is_causal=is_causal))
x = self.ln_1(x)
x = x + self.drop_path(self.mlp(x))
x = self.ln_2(x)
return x


class Transformer(nn.Module):
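    """Stack of ResidualAttentionBlocks with optional (selective) gradient checkpointing."""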
def __init__(self,
width: int,
layers: int,
heads: int,
mlp_ratio: float = 4.0,
                 ls_init_value: Optional[float] = None,
drop: float = 0.,
attn_drop: float = 0.,
drop_path: float = 0.,
                 act_layer: Callable = nn.GELU,
                 norm_layer: Callable = nn.LayerNorm,
use_preln: bool = True,
):
super().__init__()
self.width = width
self.layers = layers
self.grad_checkpointing = False
self.selective_checkpointing = False
self.grad_checkpointing_params = {'use_reentrant': False}
        # Preserving RNG state across recomputation is only needed when some stochastic
        # op (dropout / attention dropout / drop-path) is active.
        if drop == 0 and attn_drop == 0 and drop_path == 0:
            self.grad_checkpointing_params.update({'preserve_rng_state': False})
        else:
            self.grad_checkpointing_params.update({'preserve_rng_state': True})
self.resblocks = nn.ModuleList([
ResidualAttentionBlock(
width, heads, mlp_ratio, ls_init_value=ls_init_value,
drop=drop, attn_drop=attn_drop, drop_path=drop_path,
act_layer=act_layer, norm_layer=norm_layer,
use_preln=use_preln)
for _ in range(layers)
])
def forward(self, x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
                is_causal: bool = False):
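        """Apply all blocks to x, given in sequence-first layout (seq_len, batch, width),
        matching nn.MultiheadAttention's default batch_first=False."""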
for r in self.resblocks:
if self.training and self.grad_checkpointing and not torch.jit.is_scripting():
if not self.selective_checkpointing:
x = checkpoint(r, x, attn_mask, is_causal=is_causal, **self.grad_checkpointing_params)
else:
x = r(x, attn_mask=attn_mask, is_causal=is_causal, selective_checkpointing=True)
else:
                x = r(x, attn_mask=attn_mask, is_causal=is_causal)
return x


class TransformerEncoder(nn.Module):
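    """ViT-style patch encoder: inputs are split into non-overlapping patches, linearly
    projected to `width`, combined with learned spatial (and, for multi-frame inputs with
    cross_frames=True, temporal) positional embeddings, and processed by a Transformer."""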
def __init__(self,
image_size: int,
patch_size: int,
width: int,
layers: int,
heads: int,
mlp_ratio: float,
num_frames: int = 1,
cross_frames: bool = True,
                 ls_init_value: Optional[float] = None,
drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
ln_pre: bool = True,
ln_post: bool = True,
act_layer: str = 'gelu',
norm_layer: str = 'layer_norm',
mask_type: Union[str, None] = 'none',
mask_block_size: int = -1
):
super().__init__()
self.image_size = to_2tuple(image_size)
self.patch_size = to_2tuple(patch_size)
self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1])
self.patches_per_frame = self.grid_size[0] * self.grid_size[1]
self.mask_type = mask_type
self.mask_block_size = mask_block_size
if act_layer.lower() == 'gelu':
self.act_layer = nn.GELU
else:
raise ValueError(f"Unsupported activation function: {act_layer}")
if norm_layer.lower() == 'layer_norm':
self.norm_layer = nn.LayerNorm
else:
raise ValueError(f"Unsupported normalization: {norm_layer}")
self.conv1 = nn.Linear(
in_features=3 * self.patch_size[0] * self.patch_size[1],
out_features=width,
bias=not ln_pre
)
scale = width ** -0.5
self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1], width))
assert num_frames >= 1
self.num_frames = num_frames
self.cross_frames = cross_frames
if num_frames > 1 and cross_frames:
self.temporal_positional_embedding = nn.Parameter(torch.zeros(num_frames, width))
else:
self.temporal_positional_embedding = None
self.ln_pre = self.norm_layer(width) if ln_pre else nn.Identity()
self.transformer = Transformer(
width, layers, heads, mlp_ratio, ls_init_value=ls_init_value,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate,
act_layer=self.act_layer, norm_layer=self.norm_layer,
)
self.ln_post = self.norm_layer(width)
self.init_parameters()
def init_parameters(self):
if self.positional_embedding is not None:
nn.init.normal_(self.positional_embedding, std=0.02)
trunc_normal_(self.conv1.weight, std=0.02)
for block in self.transformer.resblocks:
for n, p in block.named_parameters():
if 'weight' in n:
if 'ln' not in n:
trunc_normal_(p, std=0.02)
elif 'bias' in n:
nn.init.zeros_(p)
else:
                    raise NotImplementedError(f'Unknown parameter named {n}')
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True, selective=False):
self.transformer.grad_checkpointing = enable
self.transformer.selective_checkpointing = selective
def forward(self, x):
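        # Accepts (B, C, H, W) when num_frames == 1, otherwise (B, C, T, H, W);
        # with cross_frames=False each frame is encoded independently by folding
        # the time dimension into the batch.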
if self.num_frames == 1:
x = rearrange(
x, "b c (hh sh) (ww sw) -> b (hh ww) (c sh sw)",
sh=self.patch_size[0], sw=self.patch_size[1]
)
x = self.conv1(x)
x = x + self.positional_embedding.to(x.dtype)
elif self.cross_frames:
num_frames = x.shape[2]
            assert num_frames <= self.num_frames, 'Number of frames must be less than or equal to the model setting'
x = rearrange(
x, "b c t (hh sh) (ww sw) -> b (t hh ww) (c sh sw)",
sh=self.patch_size[0], sw=self.patch_size[1]
)
x = self.conv1(x)
tile_pos_embed = self.positional_embedding.repeat(num_frames, 1)
tile_tem_embed = self.temporal_positional_embedding[:num_frames].repeat_interleave(self.patches_per_frame, 0)
total_pos_embed = tile_pos_embed + tile_tem_embed
x = x + total_pos_embed.to(x.dtype).squeeze(0)
else:
x = rearrange(
x, "b c t (hh sh) (ww sw) -> (b t) (hh ww) (c sh sw)",
sh=self.patch_size[0], sw=self.patch_size[1]
)
x = self.conv1(x)
x = x + self.positional_embedding.to(x.dtype)
x = self.ln_pre(x)
x = x.permute(1, 0, 2)
block_size = self.grid_size[0] * self.grid_size[1] if self.mask_block_size <= 0 else self.mask_block_size
attn_mask = get_attention_mask(x.size(0), x.device, mask_type=self.mask_type, block_size=block_size)
x = self.transformer(x, attn_mask, is_causal=self.mask_type == 'causal')
x = x.permute(1, 0, 2)
x = self.ln_post(x)
return x


class TransformerDecoder(nn.Module):
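    """Counterpart of TransformerEncoder: token embeddings are processed by a Transformer,
    optionally expanded through a Tanh FFN, projected back to pixel patches by a linear
    output head, and reassembled into images (or frame stacks)."""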
def __init__(self,
image_size: int,
patch_size: int,
width: int,
layers: int,
heads: int,
mlp_ratio: float,
num_frames: int = 1,
cross_frames: bool = True,
                 ls_init_value: Optional[float] = None,
drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
ln_pre: bool = True,
ln_post: bool = True,
act_layer: str = 'gelu',
norm_layer: str = 'layer_norm',
use_ffn_output: bool = True,
dim_ffn_output: int = 3072,
logit_laplace: bool = False,
mask_type: Union[str, None] = 'none',
mask_block_size: int = -1
):
super().__init__()
self.image_size = to_2tuple(image_size)
self.patch_size = to_2tuple(patch_size)
self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1])
self.patches_per_frame = self.grid_size[0] * self.grid_size[1]
self.mask_type = mask_type
self.mask_block_size = mask_block_size
if act_layer.lower() == 'gelu':
self.act_layer = nn.GELU
else:
raise ValueError(f"Unsupported activation function: {act_layer}")
if norm_layer.lower() == 'layer_norm':
self.norm_layer = nn.LayerNorm
else:
raise ValueError(f"Unsupported normalization: {norm_layer}")
self.use_ffn_output = use_ffn_output
if use_ffn_output:
self.ffn = nn.Sequential(
nn.Linear(width, dim_ffn_output),
nn.Tanh(),
)
self.conv_out = nn.Linear(
in_features=dim_ffn_output,
out_features=3 * self.patch_size[0] * self.patch_size[1] * (1 + logit_laplace)
)
else:
self.ffn = nn.Identity()
self.conv_out = nn.Linear(
in_features=width,
out_features=3 * self.patch_size[0] * self.patch_size[1] * (1 + logit_laplace)
)
scale = width ** -0.5
self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1], width))
assert num_frames >= 1
self.num_frames = num_frames
self.cross_frames = cross_frames
if num_frames > 1 and cross_frames:
self.temporal_positional_embedding = nn.Parameter(torch.zeros(num_frames, width))
else:
self.temporal_positional_embedding = None
self.ln_pre = self.norm_layer(width) if ln_pre else nn.Identity()
self.transformer = Transformer(
width, layers, heads, mlp_ratio, ls_init_value=ls_init_value,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate,
act_layer=self.act_layer, norm_layer=self.norm_layer,
)
self.ln_post = self.norm_layer(width) if ln_post else nn.Identity()
self.init_parameters()
def init_parameters(self):
if self.positional_embedding is not None:
nn.init.normal_(self.positional_embedding, std=0.02)
for block in self.transformer.resblocks:
for n, p in block.named_parameters():
if 'weight' in n:
if 'ln' not in n:
trunc_normal_(p, std=0.02)
elif 'bias' in n:
nn.init.zeros_(p)
else:
                    raise NotImplementedError(f'Unknown parameter named {n}')
if self.use_ffn_output:
trunc_normal_(self.ffn[0].weight, std=0.02)
trunc_normal_(self.conv_out.weight, std=0.02)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True, selective=False):
self.transformer.grad_checkpointing = enable
self.transformer.selective_checkpointing = selective
def forward(self, x):
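        # Expects token embeddings of shape (B, num_tokens, width); frames are folded
        # into the batch dimension when cross_frames=False.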
        if self.num_frames == 1 or not self.cross_frames:
            x = x + self.positional_embedding.to(x.dtype)
            # Frames, if any, were folded into the batch dimension, so the frame count
            # cannot be recovered from x here; fall back to the configured value.
            num_frames = self.num_frames
else:
num_frames = x.shape[1] // self.patches_per_frame
            assert num_frames <= self.num_frames, 'Number of frames must be less than or equal to the model setting'
tile_pos_embed = self.positional_embedding.repeat(num_frames, 1)
tile_tem_embed = self.temporal_positional_embedding[:num_frames].repeat_interleave(self.patches_per_frame, 0)
total_pos_embed = tile_pos_embed + tile_tem_embed
x = x + total_pos_embed.to(x.dtype).squeeze(0)
x = self.ln_pre(x)
x = x.permute(1, 0, 2)
block_size = self.grid_size[0] * self.grid_size[1] if self.mask_block_size <= 0 else self.mask_block_size
attn_mask = get_attention_mask(x.size(0), x.device, mask_type=self.mask_type, block_size=block_size)
x = self.transformer(x, attn_mask, is_causal=self.mask_type == 'causal')
x = x.permute(1, 0, 2)
x = self.ln_post(x)
x = self.ffn(x)
x = self.conv_out(x)
if self.num_frames == 1:
x = rearrange(
x, "b (hh ww) (c sh sw) -> b c (hh sh) (ww sw)",
                hh=self.grid_size[0], ww=self.grid_size[1],
sh=self.patch_size[0], sw=self.patch_size[1]
)
elif self.cross_frames:
x = rearrange(
x, "b (t hh ww) (c sh sw) -> b c t (hh sh) (ww sw)",
                t=num_frames, hh=self.grid_size[0], ww=self.grid_size[1],
sh=self.patch_size[0], sw=self.patch_size[1]
)
else:
x = rearrange(
x, "(b t) (hh ww) (c sh sw) -> b c t (hh sh) (ww sw)",
                t=num_frames, hh=self.grid_size[0], ww=self.grid_size[1],
sh=self.patch_size[0], sw=self.patch_size[1]
)
return x
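

# Minimal usage sketch: round-trip a batch of images through the encoder and decoder
# to check shapes. The hyperparameters below are illustrative assumptions, not
# VTBench/BSQ-ViT configuration values, and the sketch assumes get_attention_mask
# returns a permissive mask (or None) for the default mask_type='none'.
if __name__ == "__main__":
    enc = TransformerEncoder(image_size=256, patch_size=16, width=512,
                             layers=2, heads=8, mlp_ratio=4.0)
    dec = TransformerDecoder(image_size=256, patch_size=16, width=512,
                             layers=2, heads=8, mlp_ratio=4.0)
    images = torch.randn(2, 3, 256, 256)  # (batch, channels, height, width)
    tokens = enc(images)                  # (2, 256, 512): one token per 16x16 patch
    recon = dec(tokens)                   # (2, 3, 256, 256): reconstructed pixels
    print(tokens.shape, recon.shape)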