# VTBench / src / vqvaes / xqgan / discriminator_dino.py
# (repository page header: uploaded by huaweilin, commit "update", 14ce5a9)
import math
import os.path
import random
from typing import List, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.spectral_norm import SpectralNorm
from torchvision.transforms import RandomCrop
import dist
# Optional flash-attn fused kernels. Any failure while importing them leaves
# the names set to None so that the plain PyTorch code paths are used instead.
# `except Exception` (rather than a bare `except:`) keeps KeyboardInterrupt and
# SystemExit propagating normally.
try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm
    from flash_attn.ops.fused_dense import fused_mlp_func
except Exception:
    dropout_add_layer_norm = fused_mlp_func = None

try:
    from flash_attn import flash_attn_qkvpacked_func  # qkv: BL3Hc, ret: BLHc
except Exception:
    flash_attn_qkvpacked_func = None

# Pick the attention backend with an explicit condition. The previous
# `assert torch.cuda.is_available()` inside a try block is removed by
# `python -O`, which silently disabled the CPU fallback below.
if torch.cuda.is_available() and hasattr(F, "scaled_dot_product_attention"):
    slow_attn = F.scaled_dot_product_attention  # q, k, v: BHLc
else:

    def slow_attn(query, key, value, scale: float, attn_mask=None, dropout_p=0.0):
        """Reference softmax attention used when the fused kernel is not selected.

        Args:
            query, key, value: tensors laid out as BHLc.
            scale: factor multiplied into ``query`` before the dot product.
            attn_mask: optional additive mask, added in place to the BHLL logits.
            dropout_p: dropout probability applied to the attention weights.

        Returns:
            ``softmax(scale * q @ k^T + attn_mask) @ value``.
        """
        attn = query.mul(scale) @ key.transpose(-2, -1)  # BHLc @ BHcL => BHLL
        if attn_mask is not None:
            attn.add_(attn_mask)
        attn = attn.softmax(dim=-1)
        if dropout_p > 0:
            attn = F.dropout(attn, p=dropout_p, inplace=True)
        return attn @ value
class MLPNoDrop(nn.Module):
    """Feed-forward block without dropout: Linear -> tanh-GELU -> Linear.

    If CUDA is available and ``fused_if_available`` is true, the module-level
    ``fused_mlp_func`` (which may itself be None) is stored and, when not None,
    used by ``forward`` instead of the plain layers.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        fused_if_available=True,
    ):
        super().__init__()
        wants_fused = torch.cuda.is_available() and fused_if_available
        self.fused_mlp_func = fused_mlp_func if wants_fused else None

        # Falsy widths (None or 0) fall back to the input width.
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU(approximate="tanh")
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Apply the MLP to ``x`` through the fused kernel or the plain layers."""
        fused = self.fused_mlp_func
        if fused is None:
            hidden = self.act(self.fc1(x))
            return self.fc2(hidden)

        first, second = self.fc1, self.fc2
        return fused(
            x=x,
            weight1=first.weight,
            weight2=second.weight,
            bias1=first.bias,
            bias2=second.bias,
            activation="gelu_approx",
            save_pre_act=self.training,
            return_residual=False,
            checkpoint_lvl=0,
            heuristic=0,
            process_group=None,
        )

    def extra_repr(self) -> str:
        """Report whether a fused kernel is in use."""
        has_fused = self.fused_mlp_func is not None
        return f"fused_mlp_func={has_fused}"
class SelfAttentionNoDrop(nn.Module):
    """Multi-head self-attention without dropout.

    A single linear layer produces packed q/k/v. ``forward`` uses
    ``flash_attn_qkvpacked_func`` when it was enabled at construction time and
    the packed tensor is not float32; otherwise it calls ``slow_attn``.
    """

    def __init__(
        self,
        block_idx,
        embed_dim=768,
        num_heads=12,
        flash_if_available=True,
    ):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.block_idx = block_idx
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads  # 64 with the default arguments
        self.scale = 1 / math.sqrt(self.head_dim)

        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=True)
        self.proj = nn.Linear(embed_dim, embed_dim, bias=True)

        self.using_flash_attn = (
            torch.cuda.is_available()
            and flash_if_available
            and flash_attn_qkvpacked_func is not None
        )

    def forward(self, x):
        """Attend over ``x`` of shape (B, L, C) and return a (B, L, C) tensor."""
        B, L, C = x.shape
        packed = self.qkv(x).view(B, L, 3, self.num_heads, self.head_dim)

        take_flash_path = self.using_flash_attn and packed.dtype != torch.float32
        if take_flash_path:
            mixed = flash_attn_qkvpacked_func(packed, softmax_scale=self.scale)
            mixed = mixed.view(B, L, C)
        else:
            # BL3Hc -> 3BHLc, then split into q, k, v (each BHLc)
            q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(dim=0)
            mixed = slow_attn(query=q, key=k, value=v, scale=self.scale)
            mixed = mixed.transpose(1, 2).reshape(B, L, C)

        return self.proj(mixed)

    def extra_repr(self) -> str:
        """Report whether the flash-attention path is enabled."""
        flag = self.using_flash_attn
        return f"using_flash_attn={flag}"
class SABlockNoDrop(nn.Module):
    """Pre-norm transformer block without dropout.

    Computes ``x + attn(norm1(x))`` followed by ``x + mlp(norm2(x))``.
    """

    def __init__(self, block_idx, embed_dim, num_heads, mlp_ratio, norm_eps):
        super().__init__()
        mlp_hidden = round(embed_dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(embed_dim, eps=norm_eps)
        self.attn = SelfAttentionNoDrop(
            block_idx=block_idx,
            embed_dim=embed_dim,
            num_heads=num_heads,
            flash_if_available=True,
        )
        self.norm2 = nn.LayerNorm(embed_dim, eps=norm_eps)
        self.mlp = MLPNoDrop(
            in_features=embed_dim,
            hidden_features=mlp_hidden,
            fused_if_available=True,
        )

    def forward(self, x):
        """Run the attention and MLP sub-layers, each with a residual connection."""
        attn_branch = self.attn(self.norm1(x))
        x = x + attn_branch
        mlp_branch = self.mlp(self.norm2(x))
        return x + mlp_branch
class ResidualBlock(nn.Module):
    """Wrap ``fn`` as ``(fn(x) + x) / sqrt(2)``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.ratio = 1 / np.sqrt(2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the rescaled sum of ``fn(x)`` and ``x``; ``x`` is not modified."""
        branch = self.fn(x)
        # `.add` is out-of-place, so the in-place rescale only touches the new sum.
        summed = branch.add(x)
        return summed.mul_(self.ratio)
class SpectralConv1d(nn.Conv1d):
    """``nn.Conv1d`` with spectral normalization applied to its ``weight``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Attach the spectral-norm reparametrization right after the conv is built.
        SpectralNorm.apply(
            module=self,
            name="weight",
            n_power_iterations=1,
            dim=0,
            eps=1e-12,
        )
class BatchNormLocal(nn.Module):
    """Batch normalization over small "virtual" batches, without running stats.

    The batch dimension of a 3-D input is split into ``ceil(B / virtual_bs)``
    groups. Every group is normalized per channel with its own mean and
    (biased) variance, taken over the group's samples and the last dimension.

    Args:
        num_features: number of channels (size of dim 1 of the input).
        affine: if True, learn a per-channel scale and shift.
        virtual_bs: target number of samples per normalization group.
        eps: value added to the variance for numerical stability.
    """

    def __init__(
        self,
        num_features: int,
        affine: bool = True,
        virtual_bs: int = 8,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.virtual_bs = virtual_bs
        self.eps = eps
        self.affine = affine
        if self.affine:
            self.weight = nn.Parameter(torch.ones(num_features))
            self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` (indexed here as batch, channel, length) in float32.

        Batches whose size is not a multiple of the group count used to fail in
        ``view``; they are now split into nearly equal groups instead. An empty
        batch is returned unchanged (cast to float).
        """
        shape = x.size()
        x = x.float()

        batch = x.size(0)
        if batch == 0:
            return x.view(shape)

        # Number of groups: ceil(batch / virtual_bs) using integer arithmetic.
        num_groups = -(-batch // self.virtual_bs)

        if batch % num_groups == 0:
            # Fast path: equal-sized groups -> one reshape and one reduction.
            x = x.view(num_groups, -1, x.size(-2), x.size(-1))
            mean = x.mean([1, 3], keepdim=True)
            var = x.var([1, 3], keepdim=True, unbiased=False)
            x = (x - mean) / (torch.sqrt(var + self.eps))
        else:
            # Ragged path: group sizes differ by at most one sample.
            pieces = []
            for chunk in torch.tensor_split(x, num_groups, dim=0):
                mean = chunk.mean([0, 2], keepdim=True)
                var = chunk.var([0, 2], keepdim=True, unbiased=False)
                pieces.append((chunk - mean) / (torch.sqrt(var + self.eps)))
            x = torch.cat(pieces, dim=0)

        if self.affine:
            # (1, C, 1) broadcasts over the channel axis in both layouts above.
            x = x * self.weight[None, :, None] + self.bias[None, :, None]
        return x.view(shape)
def make_block(
    channels: int,
    kernel_size: int,
    norm_type: str,
    norm_eps: float,
    using_spec_norm: bool,
) -> nn.Module:
    """Build a Conv1d -> norm -> LeakyReLU(0.2) stack that keeps ``channels``.

    Args:
        channels: input and output channel count of the convolution.
        kernel_size: convolution kernel size; circular padding of
            ``kernel_size // 2`` is used.
        norm_type: one of "bn", "sbn", "lbn", "hbn" or "gn".
        norm_eps: epsilon passed to the normalization layer.
        using_spec_norm: use ``SpectralConv1d`` instead of ``nn.Conv1d``.

    Raises:
        NotImplementedError: for any other ``norm_type``.
    """

    def _machine_local_sync_bn():
        return nn.SyncBatchNorm(
            channels, eps=norm_eps, process_group=dist.new_local_machine_group()
        )

    # Factories are lazy so that only the selected normalization is constructed.
    norm_factories = {
        "bn": lambda: BatchNormLocal(channels, eps=norm_eps),
        "sbn": lambda: nn.SyncBatchNorm(channels, eps=norm_eps, process_group=None),
        "lbn": _machine_local_sync_bn,
        "hbn": _machine_local_sync_bn,
        "gn": lambda: nn.GroupNorm(
            num_groups=32, num_channels=channels, eps=norm_eps, affine=True
        ),
    }
    build_norm = norm_factories.get(norm_type)
    if build_norm is None:
        raise NotImplementedError
    norm = build_norm()

    conv_cls = SpectralConv1d if using_spec_norm else nn.Conv1d
    conv = conv_cls(
        channels,
        channels,
        kernel_size=kernel_size,
        padding=kernel_size // 2,
        padding_mode="circular",
    )
    activation = nn.LeakyReLU(negative_slope=0.2, inplace=True)
    return nn.Sequential(conv, norm, activation)
class DinoDisc(nn.Module):
    """Discriminator made of Conv1d heads on top of a frozen DINO ViT.

    The backbone (``FrozenDINOSmallNoDrop``) returns one activation before the
    first transformer block plus one per entry of ``key_depths``. Each
    activation is fed to its own head and the flattened head outputs are
    concatenated along dim 1.

    Args:
        dino_ckpt_path: URL of the DINO checkpoint, fetched through
            ``torch.hub.load_state_dict_from_url``.
        device: device the frozen backbone is moved to.
        ks: kernel size of the residual conv block inside every head.
        depth: number of transformer blocks to build in the backbone.
        key_depths: block indices whose outputs are tapped; entries that are
            not smaller than ``depth`` are dropped.
        norm_type: normalization used by ``make_block``.
        using_spec_norm: whether the head convolutions are spectrally normalized.
        norm_eps: epsilon for the normalization layers.
    """

    def __init__(
        self,
        dino_ckpt_path="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth",
        device="cuda",
        ks=9,
        depth=12,
        key_depths=(2, 5, 8, 11),
        norm_type="bn",
        using_spec_norm=True,
        norm_eps=1e-6,
    ):
        super().__init__()
        # load state
        state = torch.hub.load_state_dict_from_url(dino_ckpt_path, map_location="cpu")
        # state = torch.load(dino_ckpt_path, 'cpu')
        for k in sorted(state.keys()):
            if ".attn.qkv.bias" in k:
                bias = state[k]
                C = bias.numel() // 3
                bias[C : 2 * C].zero_()  # zero out k_bias

        # build DINO
        key_depths = tuple(d for d in key_depths if d < depth)
        d = FrozenDINOSmallNoDrop(depth=depth, key_depths=key_depths, norm_eps=norm_eps)
        missing, unexpected = d.load_state_dict(state, strict=False)
        # The normalization buffers are created by the backbone itself and are
        # allowed to be absent from the checkpoint.
        missing = [
            m for m in missing if all(x not in m for x in ("x_scale", "x_shift"))
        ]
        if torch.cuda.is_available():
            assert len(missing) == 0, f"missing keys: {missing}"
            assert len(unexpected) == 0, f"unexpected keys: {unexpected}"

        # todo: don't compile! reduce-overhead would raise CudaERR
        # A tuple is not an nn.Module, so the backbone is not registered as a
        # submodule (presumably to keep it out of parameters()/state_dict()).
        self.dino_proxy: Tuple[FrozenDINOSmallNoDrop] = (d.to(device=device),)
        dino_C = self.dino_proxy[0].embed_dim

        self.heads = self._build_heads(
            dino_C=dino_C,
            num_heads=len(key_depths) + 1,  # +1: before all attention blocks
            ks=ks,
            norm_type=norm_type,
            norm_eps=norm_eps,
            using_spec_norm=using_spec_norm,
        )

    @staticmethod
    def _build_heads(dino_C, num_heads, ks, norm_type, norm_eps, using_spec_norm):
        """Create ``num_heads`` freshly initialized heads mapping dino_C channels to 1."""
        out_conv = SpectralConv1d if using_spec_norm else nn.Conv1d
        return nn.ModuleList(
            [
                nn.Sequential(
                    make_block(
                        dino_C,
                        kernel_size=1,
                        norm_type=norm_type,
                        norm_eps=norm_eps,
                        using_spec_norm=using_spec_norm,
                    ),
                    ResidualBlock(
                        make_block(
                            dino_C,
                            kernel_size=ks,
                            norm_type=norm_type,
                            norm_eps=norm_eps,
                            using_spec_norm=using_spec_norm,
                        )
                    ),
                    out_conv(dino_C, 1, kernel_size=1, padding=0),
                )
                for _ in range(num_heads)
            ]
        )

    def reinit(
        self,
        dino_ckpt_path="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth",
        device="cuda",
        ks=9,
        depth=12,
        key_depths=(2, 5, 8, 11),
        norm_type="bn",
        using_spec_norm=True,
        norm_eps=1e-6,
    ):
        """Re-initialize the heads in place; the frozen backbone is left untouched.

        New heads are built with the given hyper-parameters and their state is
        loaded into ``self.heads``. ``dino_ckpt_path`` and ``device`` are accepted
        for signature parity with ``__init__`` and are not used.
        """
        # Apply the same depth filter as __init__ so the head count matches.
        key_depths = tuple(d for d in key_depths if d < depth)
        fresh_heads = self._build_heads(
            dino_C=self.dino_proxy[0].embed_dim,
            num_heads=len(key_depths) + 1,
            ks=ks,
            norm_type=norm_type,
            norm_eps=norm_eps,
            using_spec_norm=using_spec_norm,
        )
        self.heads.load_state_dict(fresh_heads.state_dict())

    def forward(
        self, x_in_pm1, grad_ckpt=False
    ):  # x_in_pm1: image tensor normalized to [-1, 1]
        """Return the concatenated head outputs, shape (B, sum of head lengths).

        Args:
            x_in_pm1: input batch; it is cast to float before the backbone.
            grad_ckpt: run the heads under activation checkpointing, and the
                backbone too when the input requires grad.
        """
        # Imported explicitly instead of relying on `import torch` having
        # already loaded the torch.utils.checkpoint submodule.
        from torch.utils.checkpoint import checkpoint

        dino_grad_ckpt = grad_ckpt and x_in_pm1.requires_grad
        activations: List[torch.Tensor] = self.dino_proxy[0](
            x_in_pm1.float(), grad_ckpt=dino_grad_ckpt
        )
        B = x_in_pm1.shape[0]

        outputs = []
        for h, act in zip(self.heads, activations):
            out = checkpoint(h, act, use_reentrant=False) if grad_ckpt else h(act)
            outputs.append(out.view(B, -1))
        # with the default key_depths: cat 5 BL => B, 5L
        return torch.cat(outputs, dim=1)
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and embed each one.

    Args:
        img_size: nominal (square) input size, used only for ``num_patches``.
        patch_size: side length of a patch; also the kernel size and stride of
            the projection conv.
        in_chans: number of input channels.
        embed_dim: number of output channels per patch.
        norm_layer: optional callable ``norm_layer(embed_dim)`` that builds the
            output normalization; ``nn.Identity`` is used when falsy.
        flatten: if True (default) return tokens as BLC; if False return the
            BCHW feature map produced by the projection.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
    ):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.flatten = flatten
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        """Project ``x`` (BCHW) to patch embeddings and apply the norm layer."""
        x = self.proj(x)
        # Honor the constructor flag; it was previously stored but ignored.
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW => BCL => BLC
        return self.norm(x)
class FrozenDINOSmallNoDrop(nn.Module):
    """
    Frozen DINO ViT without any dropout or droppath layers (eval mode only), based on
    timm.create_model('vit_small_patch16_224', pretrained=False, num_classes=0)

    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929
    See also `DeiT: Data-efficient Image Transformers`
        - https://arxiv.org/abs/2012.12877
    (no distillation token is used here: ``dist_token`` is always None).

    ``forward`` returns a list of intermediate activations rather than a
    classification output.
    """

    def __init__(
        self,
        depth=12,
        key_depths=(2, 5, 8, 11),
        norm_eps=1e-6,  # 4 stages: 012, 345, 678, 9 10 11
        patch_size=16,
        in_chans=3,
        num_classes=0,
        embed_dim=384,
        num_heads=6,
        mlp_ratio=4.0,
        # drop_rate=0., attn_drop_rate=0., drop_path_rate=0. # no drop for frozen model
    ):
        super().__init__()
        self.num_classes = num_classes
        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim
        self.img_size = 224
        self.patch_embed = PatchEmbed(
            img_size=self.img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        self.patch_size = patch_size
        self.patch_nums = self.img_size // patch_size

        # x \in [-1, 1]
        # x = ((x+1)/2 - m) / s = 0.5x/s + 0.5/s - m/s = (0.5/s) x + (0.5-m)/s
        m, s = torch.tensor((0.485, 0.456, 0.406)), torch.tensor((0.229, 0.224, 0.225))
        self.register_buffer("x_scale", (0.5 / s).reshape(1, 3, 1, 1))
        self.register_buffer("x_shift", ((0.5 - m) / s).reshape(1, 3, 1, 1))
        self.crop = RandomCrop(self.img_size)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = None
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.patch_nums * self.patch_nums + 1, embed_dim)
        )  # +1: for cls

        self.key_depths = {d for d in key_depths if d < depth}
        # Every kept key depth is < depth, so `depth` blocks always suffice.
        # The former `max(depth, 1 + max(self.key_depths))` evaluated to the
        # same number but raised ValueError when no key depth was left.
        self.blocks = nn.Sequential(
            *[
                SABlockNoDrop(
                    block_idx=i,
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    norm_eps=norm_eps,
                )
                for i in range(depth)
            ]
        )
        self.norm = nn.LayerNorm(embed_dim, eps=norm_eps)

        # eval mode only; freeze every parameter
        self.eval()
        self.requires_grad_(False)

    def inter_pos_embed(self, patch_nums=(14, 14)):
        """Return the positional embedding for a ``patch_nums`` (rows, cols) grid.

        The stored embedding is returned as-is when the grid matches
        ``self.patch_nums``; otherwise the grid part is bilinearly resized and
        the cls embedding is prepended unchanged.
        """
        if patch_nums[0] == self.patch_nums and patch_nums[1] == self.patch_nums:
            return self.pos_embed
        pe_cls, pe_grid = self.pos_embed[:, :1], self.pos_embed[0, 1:]
        # (L, C) -> (1, C, H, W) so F.interpolate can resize the grid
        pe_grid = pe_grid.reshape(1, self.patch_nums, self.patch_nums, -1).permute(
            0, 3, 1, 2
        )
        pe_grid = F.interpolate(
            pe_grid,
            size=(patch_nums[0], patch_nums[1]),
            mode="bilinear",
            align_corners=False,
        )
        pe_grid = pe_grid.permute(0, 2, 3, 1).reshape(
            1, patch_nums[0] * patch_nums[1], -1
        )
        return torch.cat([pe_cls, pe_grid], dim=1)

    def forward(self, x, grad_ckpt=False):
        """Return the tapped activations for a batch of images in [-1, 1].

        Args:
            x: image batch (B, 3, H, W); it is rescaled with ``x_scale`` and
                ``x_shift`` and brought to ``img_size`` x ``img_size`` by a
                random crop or by interpolation.
            grad_ckpt: run every transformer block under activation
                checkpointing.

        Returns:
            A list of (B, C, L) float tensors: one taken before the first block
            and one after each block whose index is in ``key_depths``. Each is
            the patch tokens plus the cls token, transposed to channels-first.
        """
        # Imported explicitly instead of relying on `import torch` having
        # already loaded the torch.utils.checkpoint submodule.
        from torch.utils.checkpoint import checkpoint

        # torch.autocast("cuda", ...) replaces the deprecated torch.cuda.amp.autocast
        with torch.autocast("cuda", enabled=False):
            x = (self.x_scale * x.float()).add_(self.x_shift)
            H, W = x.shape[-2], x.shape[-1]
            if H > self.img_size and W > self.img_size and random.random() <= 0.5:
                x = self.crop(x)
            else:
                x = F.interpolate(
                    x,
                    size=(self.img_size, self.img_size),
                    mode="area" if H > self.img_size else "bicubic",
                )
        # x now must be self.img_size x self.img_size

        x = self.patch_embed(x)

        with torch.autocast("cuda", enabled=False):
            x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x.float()), dim=1)
            x = x + self.pos_embed

        activations = [(x[:, 1:] + x[:, :1]).transpose_(1, 2)]  # readout
        for i, b in enumerate(self.blocks):
            if not grad_ckpt:
                x = b(x)
            else:
                x = checkpoint(b, x, use_reentrant=False)
            if i in self.key_depths:
                activations.append(
                    (x[:, 1:].float() + x[:, :1].float()).transpose_(1, 2)
                )  # readout
        # x = self.norm(x)
        return activations
# Smoke test: build the discriminator on CPU with deterministically seeded
# heads, print activation statistics, then run a checkpointed forward/backward.
if __name__ == "__main__":
    torch.manual_seed(0)
    np.random.seed(0)
    random.seed(0)

    ks = 9
    norm_type = "sbn"
    norm_eps = 1e-6
    dino_C = 384
    key_layers = (2, 5, 8, 11)
    using_spec_norm = True

    # Heads are created right after seeding and copied into the model below,
    # so their weights do not depend on anything DinoDisc draws while building.
    heads = nn.ModuleList(
        [
            nn.Sequential(
                make_block(
                    dino_C,
                    kernel_size=1,
                    norm_type=norm_type,
                    norm_eps=norm_eps,
                    using_spec_norm=using_spec_norm,
                ),
                ResidualBlock(
                    make_block(
                        dino_C,
                        kernel_size=ks,
                        norm_type=norm_type,
                        norm_eps=norm_eps,
                        using_spec_norm=using_spec_norm,
                    )
                ),
                (SpectralConv1d if using_spec_norm else nn.Conv1d)(
                    dino_C, 1, kernel_size=1, padding=0
                ),
            )
            for _ in range(len(key_layers) + 1)
        ]
    )

    # ckpt = os.path.join(os.path.dirname(__file__), '/mnt/bn/foundation-lq/tiankeyu/ckpt_vae/vit_small_patch16_224.pth')
    ckpt = "https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
    dd = DinoDisc(
        dino_ckpt_path=ckpt,
        device="cpu",
        ks=ks,
        norm_type=norm_type,
        norm_eps=norm_eps,
        key_depths=key_layers,
    )
    dd.eval()
    dd.heads.load_state_dict(heads.state_dict())
    print(f"{sum(p.numel() for p in dd.parameters() if p.requires_grad) / 1e6:.2f}M")

    inp = torch.linspace(-2, 2, 2 * 3 * 224 * 224).reshape(2, 3, 224, 224)
    inp.requires_grad = True

    # Statistics of the tapped backbone activations.
    mid_ls = dd.dino_proxy[0](inp)
    means = [round(m.mean().item(), 3) for m in mid_ls]
    stds = [round(m.std().item(), 3) for m in mid_ls]
    print(f"mean: {means}")
    print(f"std: {stds}")

    # Full forward with gradient checkpointing, followed by a backward pass.
    o = dd(inp, grad_ckpt=True)
    print(f"o: {o.abs().mean().item():.9f}, {o.abs().std().item():.9f}")
    o.abs().mean().backward()
    # for n, p in dd.named_parameters():
    #     tag = n.split('heads.')[-1][0]
    #     if p.ndim == 3: tag += '.conv1d'
    #     print(f'[{tag}] {n}: {p.shape}')
"""
For the version that uses qkv, the output is:
7.39M
mean: [0.019, -0.028, 0.054, 0.058, 0.074]
std: [0.427, 0.142, 0.169, 0.194, 0.153]
o: 50.266475677, 91.698143005
For the version that uses zero_k_bias, the output is:
7.39M
mean: [0.019, -0.028, 0.054, 0.058, 0.074]
std: [0.427, 0.142, 0.169, 0.194, 0.153]
o: 50.266475677, 91.698143005
"""