import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from timm.models.vision_transformer import Block
from functools import partial


class MARDecoder(nn.Module):
    """ MAR decoder: Vision Transformer blocks applied to the full token grid
    (prefix buffer + image tokens), with masked positions filled by a learnable
    mask token; outputs per-token features for the downstream diffusion head.
    """
    def __init__(self, img_size=256, vae_stride=16,
                 patch_size=1,
                 decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
                 mlp_ratio=4.,
                 attn_dropout=0.1,
                 proj_dropout=0.1,
                 buffer_size=64,
                 grad_checkpointing=False,
                 ):
        super().__init__()
        # Token grid: the VAE downsamples the image by `vae_stride`, and the
        # latent is patchified with `patch_size`, giving seq_h x seq_w tokens.
        self.img_size = img_size
        self.vae_stride = vae_stride
        self.seq_h = self.seq_w = img_size // vae_stride // patch_size
        self.seq_len = self.seq_h * self.seq_w

        self.grad_checkpointing = grad_checkpointing

        # Number of prefix (buffer) tokens prepended to the image-token sequence
        self.buffer_size = buffer_size

        # Learnable mask token, positional embedding, and Transformer blocks
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed_learned = nn.Parameter(
            torch.zeros(1, self.seq_len + self.buffer_size, decoder_embed_dim))
        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True,
                  norm_layer=partial(nn.LayerNorm, eps=1e-6),
                  proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(decoder_depth)])

        self.decoder_norm = nn.LayerNorm(decoder_embed_dim, eps=1e-6)
        # Positional embedding added to the decoder output before the diffusion head
        self.diffusion_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len, decoder_embed_dim))

        self.initialize_weights()

    def initialize_weights(self):
        # Initialize the mask token and positional embeddings with a small normal
        torch.nn.init.normal_(self.mask_token, std=.02)
        torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)
        torch.nn.init.normal_(self.diffusion_pos_embed_learned, std=.02)

        # Initialize Linear and LayerNorm submodules
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # Xavier-uniform init, as is common for ViT-style Linear layers
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def forward(self, x, mask):
        # x:    [B, buffer_size + num_visible, decoder_embed_dim], embeddings of
        #       the buffer tokens followed by the unmasked image tokens
        # mask: [B, seq_len], 1 at masked positions, 0 at visible ones.
        # The buffer tokens are never masked, so prepend zeros to the mask.
        mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)

        # Start from mask tokens everywhere, then scatter the visible token
        # embeddings back into their original positions.
        mask_tokens = self.mask_token.repeat(mask_with_buffer.shape[0], mask_with_buffer.shape[1], 1).to(x.dtype)
        x_after_pad = mask_tokens.clone()
        x_after_pad[(1 - mask_with_buffer).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])

        # Add the decoder positional embedding.
        x = x_after_pad + self.decoder_pos_embed_learned

        # Apply the Transformer blocks, optionally with gradient checkpointing.
        if self.grad_checkpointing and not torch.jit.is_scripting():
            for block in self.decoder_blocks:
                x = checkpoint(block, x)
        else:
            for block in self.decoder_blocks:
                x = block(x)
        x = self.decoder_norm(x)

        # Drop the buffer tokens and add the positional embedding for the diffusion head.
        x = x[:, self.buffer_size:]
        x = x + self.diffusion_pos_embed_learned
        return x

    def gradient_checkpointing_enable(self):
        self.grad_checkpointing = True

    def gradient_checkpointing_disable(self):
        self.grad_checkpointing = False
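
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It
# assumes `x` has already been projected to `decoder_embed_dim` and holds the
# buffer tokens followed by the *visible* (unmasked) image tokens, and that
# `mask` marks masked positions with 1.0 and visible ones with 0.0, with the
# same number of masked tokens per sample so the scatter in forward() lines up.
# The batch size and the small decoder_depth below are chosen only for the demo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    decoder = MARDecoder(img_size=256, vae_stride=16, patch_size=1,
                         decoder_embed_dim=1024, decoder_depth=2,
                         decoder_num_heads=16, buffer_size=64)

    batch, seq_len, dim = 2, decoder.seq_len, 1024
    num_masked = seq_len // 2  # identical count for every sample

    # 0/1 mask with exactly `num_masked` masked tokens per sample.
    mask = torch.zeros(batch, seq_len)
    for b in range(batch):
        mask[b, torch.randperm(seq_len)[:num_masked]] = 1.0

    # Embeddings for the buffer tokens plus the visible image tokens.
    num_visible = decoder.buffer_size + (seq_len - num_masked)
    x = torch.randn(batch, num_visible, dim)

    out = decoder(x, mask)
    print(out.shape)  # torch.Size([2, 256, 1024]) == [batch, seq_len, decoder_embed_dim]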