|
from functools import partial |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import os |
|
from timm.models.vision_transformer import Block |
|
import torch.nn.functional as F |
|
|
|
|
|
class PatchEmbed(nn.Module):
    """Split a square 2D image into non-overlapping patches and embed each one.

    A strided convolution (kernel size == stride == patch size) projects every
    patch to an ``embed_dim``-dimensional vector.

    Attributes set in ``__init__``:
        img_size / patch_size: the scalar arguments expanded to (size, size).
        grid_size: number of patches along each image side.
        num_patches: total number of patches (product of ``grid_size``).
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
        super().__init__()
        self.img_size = (img_size, img_size)
        self.patch_size = (patch_size, patch_size)

        rows = self.img_size[0] // self.patch_size[0]
        cols = self.img_size[1] // self.patch_size[1]
        self.grid_size = (rows, cols)
        self.num_patches = rows * cols
        self.flatten = flatten

        # One conv step per patch: kernel and stride both equal the patch size.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
        # Optional normalisation over the embedding dimension; identity when not requested.
        self.norm = nn.Identity() if not norm_layer else norm_layer(embed_dim)

    def forward(self, x):
        """Embed ``x`` of shape (B, C, H, W).

        Returns (B, num_tokens, embed_dim) when ``flatten`` is set, otherwise the
        raw convolution output (B, embed_dim, H', W'), passed through ``norm``.
        """
        # Unpacking doubles as a check that the input is 4-dimensional.
        _batch, _chans, _height, _width = x.shape

        tokens = self.proj(x)
        if self.flatten:
            # (B, D, H', W') -> (B, H'*W', D)
            tokens = tokens.flatten(2).transpose(1, 2)
        return self.norm(tokens)
|
|
|
|
|
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Build a fixed 2D sine-cosine position-embedding table for a square grid.

    embed_dim: number of channels per position
    grid_size: int, the grid height and width
    cls_token: when true, an all-zero row is prepended for the class token
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or
               [1+grid_size*grid_size, embed_dim] (w/o or w/ cls_token)
    """
    axis = np.arange(grid_size, dtype=np.float32)
    # np.meshgrid defaults to 'xy' indexing, so the first plane holds the
    # column (width) coordinate and the second the row coordinate.
    mesh = np.stack(np.meshgrid(axis, axis), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])

    table = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if not cls_token:
        return table
    zero_row = np.zeros([1, embed_dim])
    return np.concatenate([zero_row, table], axis=0)
|
|
|
|
|
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a pair of coordinate planes as a 2D sine-cosine embedding.

    embed_dim: total output channels; must be even, half is spent on each plane
    grid: indexable with two entries (grid[0], grid[1]), each holding the
          coordinates of every position along one axis
    out: (num_positions, embed_dim); the first half of the channels encodes
         grid[0], the second half encodes grid[1]
    """
    assert embed_dim % 2 == 0

    per_axis_dim = embed_dim // 2
    per_axis = [
        get_1d_sincos_pos_embed_from_grid(per_axis_dim, grid[plane])
        for plane in (0, 1)
    ]
    return np.concatenate(per_axis, axis=1)
|
|
|
|
|
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Encode scalar positions with sine/cosine features.

    embed_dim: output dimension for each position (must be even)
    pos: positions to be encoded; any array-like, flattened to size (M,)
    out: (M, D) float64 array laid out as [sin features | cos features],
         each half of width D/2
    """
    assert embed_dim % 2 == 0

    # ``np.float`` was only an alias of the builtin float and was removed in
    # NumPy 1.24; ``np.float64`` is the same dtype and works on every version.
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000 ** omega  # (D/2,) frequencies from 1 down towards 1/10000

    # asarray lets callers pass plain lists as well as ndarrays.
    pos = np.asarray(pos).reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    return np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def interpolate_pos_embed(model, checkpoint_model):
    """Resize a checkpoint's position embedding to the model's patch grid, in place.

    model: object exposing ``patch_embed.num_patches`` and ``pos_embed``
    checkpoint_model: state dict; its 'pos_embed' entry is replaced when the
        checkpoint grid side differs from the model grid side

    Both grids are treated as square (side = int(sqrt(count))). Extra leading
    tokens (those that are not patches) are copied through unchanged; the patch
    tokens are resampled with bicubic interpolation. Nothing happens when the
    key is absent or the grid sides already agree. Returns None.
    """
    if 'pos_embed' not in checkpoint_model:
        return

    ckpt_embed = checkpoint_model['pos_embed']
    channels = ckpt_embed.shape[-1]
    patch_count = model.patch_embed.num_patches
    # Tokens in the model's table that are not patches.
    extra_count = model.pos_embed.shape[-2] - patch_count

    src_side = int((ckpt_embed.shape[-2] - extra_count) ** 0.5)
    dst_side = int(patch_count ** 0.5)
    if src_side == dst_side:
        return

    print("Position interpolate from %dx%d to %dx%d" % (src_side, src_side, dst_side, dst_side))

    kept_tokens = ckpt_embed[:, :extra_count]
    # (1, N, C) -> (1, C, side, side) so the table can be resized like an image.
    patch_grid = ckpt_embed[:, extra_count:]
    patch_grid = patch_grid.reshape(-1, src_side, src_side, channels).permute(0, 3, 1, 2)
    patch_grid = torch.nn.functional.interpolate(
        patch_grid, size=(dst_side, dst_side), mode='bicubic', align_corners=False)
    # Back to a token sequence: (1, C, side, side) -> (1, side*side, C).
    patch_grid = patch_grid.permute(0, 2, 3, 1).flatten(1, 2)

    checkpoint_model['pos_embed'] = torch.cat((kept_tokens, patch_grid), dim=1)
|
|
|
|
|
def sample(t: torch.Tensor, coords: torch.Tensor):
    """Sample ``t`` at ``coords`` with ``F.grid_sample``.

    The two middle axes of ``coords`` are swapped before sampling, so the
    output's spatial axes are in the swapped order as well. Out-of-range
    coordinates are clamped to the border, and -1 / 1 address the centres of
    the corner pixels (``align_corners=True``).
    """
    swapped_grid = coords.permute(0, 2, 1, 3)
    return F.grid_sample(
        t,
        swapped_grid,
        padding_mode='border',
        align_corners=True,
    )
|
|
|
|
|
class MaskedAutoencoderViT(nn.Module):
    """Masked Autoencoder with VisionTransformer backbone.

    Encoder: patch embedding + fixed sin-cos position embedding + ``depth``
    transformer blocks. Decoder: linear projection to ``decoder_embed_dim``,
    mask tokens re-inserted at the dropped positions, ``decoder_depth``
    transformer blocks and a linear head predicting the pixels of each patch.

    Position embeddings are resampled to the token grid of the actual input
    (see ``sample_pe``), so inputs need not match ``img_size``.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
        super().__init__()

        # ------------------------------------------------------------------
        # Encoder
        self.embed_dim = embed_dim
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Fixed (non-trainable) table; filled in by initialize_weights().
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim),
                                      requires_grad=False)

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(depth)])
        self.norm = norm_layer(embed_dim)

        # ------------------------------------------------------------------
        # Decoder
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        # Fixed (non-trainable) table; filled in by initialize_weights().
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim),
                                              requires_grad=False)

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        # One output value per pixel per channel of a patch.
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True)

        self.norm_pix_loss = norm_pix_loss

        self.initialize_weights()

    def initialize_weights(self):
        """Fill the position tables and initialise all learnable weights."""
        # Both position tables get a sin-cos embedding with a zero cls row.
        grid_side = int(self.patch_embed.num_patches ** .5)

        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], grid_side,
                                            cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1],
                                                    grid_side, cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # Initialise the patch-projection conv through a 2D view of its weight,
        # i.e. as if it were a linear layer.
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        # Linear and LayerNorm layers everywhere in the module tree.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initialiser used with ``nn.Module.apply``."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """Cut square images into flattened patches.

        imgs: (N, C, H, W) with H == W and H divisible by the patch size
        x: (N, L, patch_size**2 * C)
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        # Channel count comes from the input rather than a hard-coded 3, so the
        # method agrees with the configurable ``in_chans`` of the model.
        c = imgs.shape[1]
        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], c, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * c))
        return x

    def unpatchify(self, x):
        """Inverse of ``patchify`` for a square grid of patches.

        x: (N, L, patch_size**2 * C) with L a perfect square
        imgs: (N, C, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1] ** .5)
        assert h * w == x.shape[1]

        # Recover the channel count from the per-patch feature length.
        c = x.shape[2] // (p ** 2)
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def random_masking(self, x, mask_ratio):
        """Perform per-sample random masking by per-sample shuffling.

        Per-sample shuffling is done by argsort of random noise.

        x: [N, L, D], sequence
        mask_ratio: fraction of the L tokens to drop
        returns:
            x_masked: [N, len_keep, D], the kept tokens in shuffled order
            mask: [N, L] in original token order, 0 is keep, 1 is remove
            ids_restore: [N, L], indices that undo the shuffle
        """
        N, L, D = x.shape
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1)

        # Ascending sort: tokens with the smallest noise are kept.
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # Build the mask in shuffled order, then unshuffle it.
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def sample_pe(self, img, pe):
        """Resample a stored position-embedding table to the token grid of ``img``.

        img: (N, C, H_img, W_img); only its spatial size, device and dtype are used
        pe: (1, 1 + num_patches, D) table whose first row is the cls slot
        returns: (1, H * W, D) with H = H_img // patch, W = W_img // patch
        """
        p = self.patch_embed.patch_size[0]

        H = img.shape[2] // p
        W = img.shape[3] // p

        # The table was built for the grid of ``patch_embed``; reading the grid
        # from there (instead of assuming a 224-pixel training resolution)
        # keeps models created with a different ``img_size`` working.
        grid_h, grid_w = self.patch_embed.grid_size
        embed_dim = pe.shape[-1]

        # Drop the cls row and lay the table out as an image: (1, D, grid_h, grid_w).
        reshaped_pe = pe.squeeze(0)[1:] \
            .reshape(1, grid_h, grid_w, embed_dim) \
            .permute(0, 3, 1, 2)

        XX, YY = torch.meshgrid(torch.linspace(-1, 1, H, device=img.device, dtype=img.dtype),
                                torch.linspace(-1, 1, W, device=img.device, dtype=img.dtype))

        coords = torch.cat([XX.unsqueeze(-1), YY.unsqueeze(-1)], dim=-1).unsqueeze(0)

        # NOTE(review): the flattening order below matches the patch-token
        # order for square inputs (H == W); the non-square case should be
        # verified before relying on it.
        return sample(reshaped_pe, coords).reshape(embed_dim, H * W).permute(1, 0).unsqueeze(0)

    def _encode_tokens(self, x):
        """Prepend the cls token (plus its position row) and run the encoder stack."""
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        for blk in self.blocks:
            x = blk(x)
        return self.norm(x)

    def featurize(self, img, n_decoder_blocks=None):
        """Run the encoder on every patch (no masking).

        img: (N, C, H_img, W_img)
        n_decoder_blocks: accepted for interface compatibility; not used
        returns: (features, cls) where features is (N, D, H, W) over the patch
            grid and cls is the (N, D) class-token output
        """
        p = self.patch_embed.patch_size[0]
        H = img.shape[2] // p
        W = img.shape[3] // p

        x = self.patch_embed(img)
        x = x + self.sample_pe(img, self.pos_embed)
        x = self._encode_tokens(x)

        feats = x[:, 1:, :].reshape(shape=(x.shape[0], H, W, -1)).permute(0, 3, 1, 2)
        return feats, x[:, 0, :]

    def forward_encoder(self, img, mask_ratio):
        """Encode a random subset of the patches.

        returns: (latent, mask, ids_restore); latent holds the cls token
            followed by the kept tokens, mask / ids_restore are as returned by
            ``random_masking``
        """
        x = self.patch_embed(img)

        # Position information is added before masking so kept tokens carry
        # their original location.
        x = x + self.sample_pe(img, self.pos_embed)

        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        x = self._encode_tokens(x)

        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore, img):
        """Predict the pixels of every patch from the encoder output.

        x: encoder output, cls token first
        ids_restore: [N, L] unshuffle indices from ``random_masking``
        img: the input image batch; used only to size the position embedding
        returns: [N, L, patch_size**2 * in_chans]
        """
        x = self.decoder_embed(x)

        # Pad the kept tokens with mask tokens up to L, then unshuffle so every
        # token sits at its original position again.
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
        x = torch.cat([x[:, :1, :], x_], dim=1)  # re-attach cls token

        # Decoder position embedding, resampled to the input's token grid.
        cls_token = x[:, :1] + self.decoder_pos_embed[0, :1]
        x = x[:, 1:] + self.sample_pe(img, self.decoder_pos_embed)
        x = torch.cat((cls_token, x), dim=1)

        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        x = self.decoder_pred(x)

        # Drop the cls token; only patch predictions are returned.
        x = x[:, 1:, :]

        return x

    def forward_loss(self, imgs, pred, mask):
        """Mean squared error over the removed patches.

        imgs: [N, C, H, W]
        pred: [N, L, p*p*C]
        mask: [N, L], 0 is keep, 1 is remove
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            # Normalise each target patch by its own mean and variance.
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6) ** .5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # average over removed patches only
        return loss

    def forward(self, imgs, mask_ratio=0.75):
        """Mask, encode, decode and score ``imgs``; returns (loss, pred, mask)."""
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        pred = self.forward_decoder(latent, ids_restore, imgs)
        loss = self.forward_loss(imgs, pred, mask)
        return loss, pred, mask
|
|
|
|
|
class MAEFeaturizer(nn.Module):
    """Wrap a ``MaskedAutoencoderViT`` loaded from a checkpoint as a dense feature extractor.

    ``forward`` returns the encoder's per-patch feature map; ``get_cls_token``
    returns the class-token output instead.
    """

    # arch name -> (encoder hyper-parameters, checkpoint path relative to this file)
    _ARCHS = {
        "mae_vit_base_patch16": (
            dict(patch_size=16, embed_dim=768, depth=12, num_heads=12),
            '../models/mae_visualize_vit_base.pth'),
        "mae_vit_large_patch16": (
            dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16),
            '../models/mae_visualize_vit_large.pth'),
        "mae_vit_large_patch16_gan": (
            dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16),
            '../models/mae_visualize_vit_large_ganloss.pth'),
        "mae_vit_huge_patch14": (
            dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16),
            '../models/mae_visualize_vit_huge.pth'),
    }

    def __init__(self, arch="mae_vit_large_patch16_gan"):
        """Build the model for ``arch`` and load its weights.

        Raises:
            ValueError: if ``arch`` is not one of the known architecture names.
        """
        super().__init__()

        try:
            encoder_args, relative_path = self._ARCHS[arch]
        except (KeyError, TypeError):
            raise ValueError("Unknown model arch {}".format(arch)) from None

        self.model = MaskedAutoencoderViT(
            decoder_embed_dim=512,
            decoder_depth=8,
            decoder_num_heads=16,
            mlp_ratio=4,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            **encoder_args)

        # Checkpoints are resolved relative to this source file, not the CWD.
        checkpoint_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), relative_path)

        # map_location="cpu" lets checkpoints that were saved from a GPU load
        # on CPU-only machines; the module can still be moved with .to()/.cuda()
        # afterwards.
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        # strict=False: missing or unexpected keys are ignored without an error.
        self.model.load_state_dict(checkpoint['model'], strict=False)

    def get_cls_token(self, img):
        """Return the class-token output of the encoder for ``img``."""
        _feats, cls_token = self.model.featurize(img)
        return cls_token

    def forward(self, img):
        """Return the per-patch encoder feature map for ``img``."""
        feats, _cls_token = self.model.featurize(img)
        return feats
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: featurize one sample image and print the feature-map shape.
    import torchvision.transforms as T
    from PIL import Image
    from shared import norm, crop_to_divisor

    device = "cuda" if torch.cuda.is_available() else "cpu"

    image = Image.open("../samples/lex1.jpg")
    load_size = 224
    transform = T.Compose([
        T.Resize(load_size, Image.BILINEAR),
        T.ToTensor(),
        # Crop so both sides are a multiple of the 16-pixel patch size.
        lambda x: crop_to_divisor(x, 16),
        norm])

    # Use the detected device instead of calling .cuda() unconditionally, so
    # the script also runs on machines without a GPU.
    model = MAEFeaturizer().to(device)

    results = model(transform(image).to(device).unsqueeze(0))

    print(results.shape)
|
|