import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from timm.models.vision_transformer import Block
from functools import partial


class MARDecoder(nn.Module):
    """MAR decoder: a stack of Vision Transformer blocks applied to a
    fixed-size token buffer followed by the latent token sequence.
    """
    def __init__(self, img_size=256, vae_stride=16,
                 patch_size=1,
                 # encoder_embed_dim=1024,
                 decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
                 mlp_ratio=4.,
                 attn_dropout=0.1,
                 proj_dropout=0.1,
                 buffer_size=64,
                 grad_checkpointing=False,
                 ):
        super().__init__()

        # --------------------------------------------------------------------------
        # VAE
        self.img_size = img_size
        self.vae_stride = vae_stride
        self.seq_h = self.seq_w = img_size // vae_stride // patch_size
        self.seq_len = self.seq_h * self.seq_w
        self.grad_checkpointing = grad_checkpointing

        # --------------------------------------------------------------------------
        # MAR decoder specifics
        self.buffer_size = buffer_size
        # self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed_learned = nn.Parameter(
            torch.zeros(1, self.seq_len + self.buffer_size, decoder_embed_dim))

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True,
                  norm_layer=partial(nn.LayerNorm, eps=1e-6),
                  proj_drop=proj_dropout, attn_drop=attn_dropout)
            for _ in range(decoder_depth)])
        self.decoder_norm = nn.LayerNorm(decoder_embed_dim, eps=1e-6)
        self.diffusion_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len, decoder_embed_dim))

        self.initialize_weights()
    def initialize_weights(self):
        # parameters
        torch.nn.init.normal_(self.mask_token, std=.02)
        torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)
        torch.nn.init.normal_(self.diffusion_pos_embed_learned, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)
    def forward(self, x, mask):
        # x: tokens for the buffer and the unmasked latent positions, [B, buffer + kept, C]
        # mask: [B, seq_len] binary mask over latent positions, 1 = masked, 0 = kept
        # x = self.decoder_embed(x)
        # buffer positions are prepended with zeros, i.e. they are never masked
        mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)

        # pad mask tokens: fill every position with the learned mask token,
        # then scatter the input tokens back into the unmasked positions
        mask_tokens = self.mask_token.repeat(mask_with_buffer.shape[0], mask_with_buffer.shape[1], 1).to(x.dtype)
        x_after_pad = mask_tokens.clone()
        x_after_pad[(1 - mask_with_buffer).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])

        # decoder position embedding
        x = x_after_pad + self.decoder_pos_embed_learned

        # apply Transformer blocks
        if self.grad_checkpointing and not torch.jit.is_scripting():
            for block in self.decoder_blocks:
                x = checkpoint(block, x)
        else:
            for block in self.decoder_blocks:
                x = block(x)
        x = self.decoder_norm(x)

        # drop the buffer tokens and add the diffusion position embedding
        x = x[:, self.buffer_size:]
        x = x + self.diffusion_pos_embed_learned

        return x
    def gradient_checkpointing_enable(self):
        self.grad_checkpointing = True

    def gradient_checkpointing_disable(self):
        self.grad_checkpointing = False
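

if __name__ == "__main__":
    # Minimal usage sketch, assuming the input conventions documented in
    # forward(): `mask` is [B, seq_len] with 1 = masked, and `x` carries one
    # token per buffer slot and per kept latent position. The batch size,
    # masking ratio, and reduced depth below are illustrative choices only.
    torch.manual_seed(0)
    decoder = MARDecoder(decoder_embed_dim=1024, decoder_depth=2,
                         decoder_num_heads=16, buffer_size=64)

    bsz, num_masked = 2, 100
    mask = torch.zeros(bsz, decoder.seq_len)
    mask[:, :num_masked] = 1  # mask the first 100 of the 256 latent positions

    # one input token for each buffer slot and each kept latent position
    num_kept = decoder.buffer_size + decoder.seq_len - num_masked
    x = torch.randn(bsz, num_kept, 1024)

    out = decoder(x, mask)
    print(out.shape)  # torch.Size([2, 256, 1024]) == [B, seq_len, decoder_embed_dim]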