from functools import partial

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from timm.models.vision_transformer import Block


class MARDecoder(nn.Module):
    """Transformer decoder that fills masked token positions and returns per-token features.

    Visible-token features (plus ``buffer_size`` leading buffer tokens) are scattered
    into a full-length sequence of learned mask tokens. The sequence receives a learned
    position embedding and is run through a stack of timm ViT ``Block``s. The buffer
    positions are then dropped and a second learned position embedding
    (``diffusion_pos_embed_learned``) is added to the output.
    """

    def __init__(self, img_size=256, vae_stride=16, patch_size=1,
                 decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
                 mlp_ratio=4., attn_dropout=0.1, proj_dropout=0.1,
                 buffer_size=64, grad_checkpointing=False,
                 ):
        """Build the decoder.

        Args:
            img_size: Input image side length; together with ``vae_stride`` and
                ``patch_size`` it fixes the token grid side
                ``img_size // vae_stride // patch_size`` (sequence length = side ** 2).
            vae_stride: Spatial downsampling factor used to derive the token grid.
            patch_size: Additional patchification factor on top of ``vae_stride``.
            decoder_embed_dim: Token width. There is no input projection, so the
                features passed to ``forward`` must already have this width.
            decoder_depth: Number of transformer blocks.
            decoder_num_heads: Attention heads per block.
            mlp_ratio: MLP hidden-size ratio passed to each block.
            attn_dropout: Attention dropout (``attn_drop`` of each block).
            proj_dropout: Projection dropout (``proj_drop`` of each block).
            buffer_size: Number of extra leading positions that are always treated
                as visible and are stripped from the output.
            grad_checkpointing: Whether ``forward`` uses activation checkpointing.
        """
        super().__init__()

        # --------------------------------------------------------------------------
        # Token grid derived from the VAE / patch geometry.
        self.img_size = img_size
        self.vae_stride = vae_stride
        self.seq_h = self.seq_w = img_size // vae_stride // patch_size
        self.seq_len = self.seq_h * self.seq_w
        self.grad_checkpointing = grad_checkpointing

        # --------------------------------------------------------------------------
        # Decoder specifics.
        self.buffer_size = buffer_size
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        # Covers the buffer positions as well as the image tokens.
        self.decoder_pos_embed_learned = nn.Parameter(
            torch.zeros(1, self.seq_len + self.buffer_size, decoder_embed_dim))

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True,
                  norm_layer=partial(nn.LayerNorm, eps=1e-6),
                  proj_drop=proj_dropout, attn_drop=attn_dropout)
            for _ in range(decoder_depth)])

        self.decoder_norm = nn.LayerNorm(decoder_embed_dim, eps=1e-6)
        # Added to the output after the buffer is removed; presumably consumed by a
        # diffusion head downstream — NOTE(review): confirm against the caller.
        self.diffusion_pos_embed_learned = nn.Parameter(
            torch.zeros(1, self.seq_len, decoder_embed_dim))

        self.initialize_weights()

    def initialize_weights(self):
        """Initialize learned tokens/embeddings and all Linear/LayerNorm submodules."""
        torch.nn.init.normal_(self.mask_token, std=.02)
        torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)
        torch.nn.init.normal_(self.diffusion_pos_embed_learned, std=.02)

        # Initialize nn.Linear and nn.LayerNorm (including those inside the blocks).
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initializer applied via ``nn.Module.apply``."""
        if isinstance(m, nn.Linear):
            # We use xavier_uniform following the official JAX ViT.
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def forward(self, x, mask):
        """Decode visible-token features into a full-length feature sequence.

        Args:
            x: Features of the visible tokens, buffer tokens included, with
                ``x.shape[2] == decoder_embed_dim``. It is flattened to
                ``(x.shape[0] * x.shape[1], x.shape[2])`` and written, in row-major
                order, into every buffered position whose mask value is not 1. The
                number of such positions must therefore equal
                ``x.shape[0] * x.shape[1]``. Presumably every sample has the same
                number of visible tokens, otherwise rows land in the wrong sample —
                NOTE(review): confirm against the caller.
            mask: ``(batch, seq_len)`` mask; positions equal to 1 keep the mask token.

        Returns:
            Tensor of shape ``(batch, seq_len, decoder_embed_dim)``.
        """
        batch_size = x.size(0)

        # Buffer positions are always visible (mask value 0).
        buffer_mask = torch.zeros(batch_size, self.buffer_size, device=x.device)
        mask_with_buffer = torch.cat([buffer_mask, mask], dim=1)

        # Start from mask tokens everywhere. ``repeat`` already returns a freshly
        # allocated tensor (never a view of the parameter), so it can be written
        # in place without an extra ``clone``.
        x_after_pad = self.mask_token.repeat(
            batch_size, mask_with_buffer.shape[1], 1).to(x.dtype)
        visible = (1 - mask_with_buffer).nonzero(as_tuple=True)
        x_after_pad[visible] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])

        # Decoder position embedding.
        x = x_after_pad + self.decoder_pos_embed_learned

        # Apply Transformer blocks.
        if self.grad_checkpointing and not torch.jit.is_scripting():
            for block in self.decoder_blocks:
                # Pass ``use_reentrant`` explicitly: PyTorch warns when it is omitted,
                # and the non-reentrant variant is the recommended implementation.
                x = checkpoint(block, x, use_reentrant=False)
        else:
            for block in self.decoder_blocks:
                x = block(x)
        x = self.decoder_norm(x)

        # Drop the buffer positions, then add the output position embedding.
        x = x[:, self.buffer_size:]
        x = x + self.diffusion_pos_embed_learned
        return x

    def gradient_checkpointing_enable(self):
        """Turn on activation checkpointing for the decoder blocks."""
        self.grad_checkpointing = True

    def gradient_checkpointing_disable(self):
        """Turn off activation checkpointing for the decoder blocks."""
        self.grad_checkpointing = False