|
|
|
|
|
|
|
from dataclasses import dataclass, field |
|
from typing import List |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from timm.models import create_model |
|
|
|
import sys, os |
|
from math import sqrt |
|
|
|
|
|
|
|
|
|
|
|
|
|
from .cliploss import ClipLoss |
|
from .quant import VectorQuantizer2 |
|
from .lookup_free_quantize import LFQ |
|
from .dino_enc.dinov2 import DINOv2Encoder, DINOv2Decoder |
|
from .latent_perturbation import add_perturbation |
|
from datasets import Denormalize |
|
from datasets import Normalize as ImgNormalize |
|
|
|
import torch.distributed as tdist |
|
|
|
|
|
@dataclass |
|
class ModelArgs: |
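    """Hyperparameters for VQModel: codebook size/dim, encoder/decoder backbones,
    the multi-scale patch schedule (v_patch_nums), product quantization, and the
    semantic (DINOv2) / detail (CLIP) guidance-loss settings."""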
|
codebook_size: int = 16384 |
|
codebook_embed_dim: int = 8 |
|
codebook_l2_norm: bool = True |
|
codebook_show_usage: bool = True |
|
commit_loss_beta: float = 0.25 |
|
entropy_loss_ratio: float = 0.0 |
|
|
|
encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4]) |
|
decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4]) |
|
z_channels: int = 256 |
|
dropout_p: float = 0.0 |
|
|
|
v_patch_nums: List[int] = field( |
|
default_factory=lambda: [1, 2, 3, 4, 5, 6, 8, 10, 13, 16] |
|
) |
|
enc_type: str = "cnn" |
|
dec_type: str = "cnn" |
|
semantic_guide: str = "dinov2" |
|
detail_guide: str = "clip" |
|
num_latent_tokens: int = 256 |
|
encoder_model: str = "vit_small_patch14_dinov2.lvd142m" |
|
decoder_model: str = "vit_small_patch14_dinov2.lvd142m" |
|
abs_pos_embed: bool = False |
|
share_quant_resi: int = 4 |
|
product_quant: int = 1 |
|
codebook_drop: float = 0.0 |
|
half_sem: bool = False |
|
start_drop: int = 1 |
|
sem_loss_weight: float = 0.1 |
|
detail_loss_weight: float = 0.1 |
|
clip_norm: bool = False |
|
sem_loss_scale: float = 1.0 |
|
detail_loss_scale: float = 1.0 |
|
guide_type_1: str = "class" |
|
guide_type_2: str = "class" |
|
|
|
lfq: bool = False |
|
scale: float = 1.0 |
|
soft_entropy: bool = True |
|
|
|
dependency_loss_weight: float = 0.0 |
|
|
|
test_model: bool = False |
|
|
|
|
|
class VQModel(nn.Module): |
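    """Multi-scale residual VQ tokenizer with optional product quantization.

    Combines a CNN or DINOv2-ViT encoder/decoder with VectorQuantizer,
    VectorQuantizer2, or LFQ codebooks, and adds frozen DINOv2 (semantic) and
    CLIP (detail) teachers whose features guide the quantized latents.
    """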
|
def __init__( |
|
self, |
|
config: ModelArgs, |
|
): |
|
super().__init__() |
|
self.config = config |
|
self.enc_type = config.enc_type |
|
self.dec_type = config.dec_type |
|
self.product_quant = config.product_quant |
|
self.half_sem = config.half_sem |
|
self.start_drop = config.start_drop |
|
self.clip_norm = config.clip_norm |
|
config.num_latent_tokens = ( |
|
config.num_latent_tokens * config.product_quant |
|
) |
|
|
|
if config.enc_type == "cnn": |
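            # VQGAN-style convolutional encoder; a 1x1 conv projects its output
            # into the codebook embedding space.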
|
self.encoder = Encoder( |
|
ch_mult=config.encoder_ch_mult, |
|
z_channels=config.z_channels, |
|
dropout=config.dropout_p, |
|
) |
|
self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1) |
|
elif config.enc_type == "dinov2": |
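            # ViT encoder initialized from DINOv2 weights (via timm), producing a
            # sequence of latent tokens instead of a spatial feature map.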
|
self.encoder = DINOv2Encoder( |
|
in_channels=3, |
|
num_latent_tokens=config.num_latent_tokens, |
|
model_name=config.encoder_model, |
|
model_kwargs={"img_size": 256, "patch_size": 16, "drop_path_rate": 0.1}, |
|
pretrained=True, |
|
tuning_method="full", |
|
tuning_kwargs={"r": 8}, |
|
abs_pos_embed=config.abs_pos_embed, |
|
product_quant=config.product_quant, |
|
) |
|
self.quant_conv = nn.Conv2d( |
|
self.encoder.embed_dim, config.codebook_embed_dim, 1 |
|
) |
|
else: |
|
raise NotImplementedError |
|
|
|
if config.dec_type == "cnn": |
|
self.decoder = Decoder( |
|
ch_mult=config.decoder_ch_mult, |
|
z_channels=config.z_channels, |
|
dropout=config.dropout_p, |
|
) |
|
self.post_quant_conv = nn.Conv2d( |
|
config.codebook_embed_dim, config.z_channels, 1 |
|
) |
|
elif config.dec_type == "dinov2": |
|
self.decoder = DINOv2Decoder( |
|
in_channels=3, |
|
num_latent_tokens=config.num_latent_tokens // self.product_quant, |
|
model_name=config.decoder_model, |
|
model_kwargs={"img_size": 256, "patch_size": 16, "drop_path_rate": 0.1}, |
|
pretrained=True, |
|
tuning_method="full", |
|
tuning_kwargs={"r": 8}, |
|
to_pixel="linear", |
|
use_rope=False, |
|
cond_latent=False, |
|
abs_pos_embed=config.abs_pos_embed, |
|
) |
|
self.post_quant_conv = nn.Conv2d( |
|
config.codebook_embed_dim, self.decoder.embed_dim, 1 |
|
) |
|
|
|
self.V = self.vocab_size = config.codebook_size * self.product_quant |
|
self.Cvae = config.codebook_embed_dim * self.product_quant |
|
if self.product_quant > 1: |
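            # Product quantization: one independent quantizer (and codebook) per group.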
|
if len(config.v_patch_nums) == 1: |
|
self.quantizes = nn.ModuleList( |
|
[ |
|
VectorQuantizer( |
|
config.codebook_size, |
|
config.codebook_embed_dim, |
|
config.commit_loss_beta, |
|
config.codebook_l2_norm, |
|
) |
|
for _ in range(self.product_quant) |
|
] |
|
) |
|
elif not config.lfq: |
|
self.quantizes = nn.ModuleList( |
|
[ |
|
VectorQuantizer2( |
|
config.codebook_size, |
|
config.codebook_embed_dim, |
|
v_patch_nums=config.v_patch_nums, |
|
num_latent_tokens=config.num_latent_tokens |
|
// self.product_quant, |
|
share_quant_resi=config.share_quant_resi, |
|
codebook_drop=config.codebook_drop, |
|
) |
|
for _ in range(self.product_quant) |
|
] |
|
) |
|
else: |
|
self.quantizes = nn.ModuleList( |
|
[ |
|
LFQ( |
|
config.codebook_size, |
|
config.codebook_embed_dim, |
|
v_patch_nums=config.v_patch_nums, |
|
num_latent_tokens=config.num_latent_tokens |
|
// self.product_quant, |
|
share_quant_resi=config.share_quant_resi, |
|
codebook_drop=config.codebook_drop, |
|
using_znorm=config.codebook_l2_norm, |
|
scale=config.scale, |
|
entropy_weight=config.entropy_loss_ratio, |
|
soft_entropy=config.soft_entropy, |
|
) |
|
for _ in range(self.product_quant) |
|
] |
|
) |
|
self.post_quant_conv = nn.Conv2d( |
|
config.codebook_embed_dim * self.product_quant, |
|
self.decoder.embed_dim, |
|
1, |
|
) |
|
else: |
|
if len(config.v_patch_nums) == 1: |
|
self.quantize = VectorQuantizer( |
|
config.codebook_size, |
|
config.codebook_embed_dim, |
|
config.commit_loss_beta, |
|
config.codebook_l2_norm, |
|
) |
|
elif not config.lfq: |
|
self.quantize = VectorQuantizer2( |
|
config.codebook_size, |
|
config.codebook_embed_dim, |
|
v_patch_nums=config.v_patch_nums, |
|
num_latent_tokens=config.num_latent_tokens, |
|
share_quant_resi=config.share_quant_resi, |
|
) |
|
else: |
|
self.quantize = LFQ( |
|
config.codebook_size, |
|
config.codebook_embed_dim, |
|
v_patch_nums=config.v_patch_nums, |
|
num_latent_tokens=config.num_latent_tokens, |
|
share_quant_resi=config.share_quant_resi, |
|
codebook_drop=config.codebook_drop, |
|
using_znorm=config.codebook_l2_norm, |
|
scale=config.scale, |
|
entropy_weight=config.entropy_loss_ratio, |
|
soft_entropy=config.soft_entropy, |
|
) |
|
|
|
self.codebook_embed_dim = config.codebook_embed_dim |
|
self.v_patch_nums = config.v_patch_nums |
|
self.codebook_drop = config.codebook_drop |
|
|
|
self.semantic_guide = config.semantic_guide |
|
self.denormalize = Denormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) |
|
self.normalize = ImgNormalize( |
|
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] |
|
) |
|
if self.semantic_guide == "dinov2": |
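            # Frozen DINOv2 teacher; its features are the targets of the
            # ClipLoss-based semantic alignment below.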
|
semantic_model = create_model( |
|
config.encoder_model, |
|
pretrained=True, |
|
img_size=256, |
|
patch_size=16, |
|
drop_path_rate=0.0, |
|
) |
|
semantic_model.eval() |
|
for param in semantic_model.parameters(): |
|
param.requires_grad = False |
|
            self.semantic_model = semantic_model
|
|
|
local_loss = False |
|
gather_with_grad = True |
|
rank = tdist.get_rank() |
|
world_size = tdist.get_world_size() |
|
use_horovod = False |
|
sem_loss_scale = config.sem_loss_scale |
|
|
|
self.sem_loss_scale = sem_loss_scale |
|
self.semantic_loss = ClipLoss( |
|
local_loss=local_loss, |
|
gather_with_grad=gather_with_grad, |
|
cache_labels=True, |
|
rank=rank, |
|
world_size=world_size, |
|
use_horovod=use_horovod, |
|
) |
|
if not self.half_sem and self.product_quant > 1: |
|
self.sem_linear = nn.Conv2d( |
|
self.product_quant * config.codebook_embed_dim, |
|
config.codebook_embed_dim, |
|
1, |
|
) |
|
elif self.half_sem and self.product_quant == 1: |
|
self.sem_linear = nn.Conv2d(768, config.codebook_embed_dim // 2, 1) |
|
if self.enc_type == "cnn": |
|
self.sem_linear = torch.nn.Linear(384, config.codebook_embed_dim) |
|
|
|
self.sem_loss_weight = config.sem_loss_weight |
|
|
|
self.detail_guide = config.detail_guide |
|
if self.detail_guide != "none": |
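            # Frozen CLIP ViT-B/16 teacher providing patch-level "detail" targets,
            # mirroring the semantic guide above.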
|
detail_model = create_model( |
|
"vit_base_patch16_clip_224.openai", |
|
pretrained=True, |
|
img_size=256, |
|
patch_size=16, |
|
drop_path_rate=0.0, |
|
) |
|
detail_model.eval() |
|
for param in detail_model.parameters(): |
|
param.requires_grad = False |
|
self.detail_model = detail_model |
|
|
|
self.detail_loss_scale = config.detail_loss_scale |
|
self.detail_loss = ClipLoss( |
|
local_loss=False, |
|
gather_with_grad=True, |
|
cache_labels=True, |
|
rank=tdist.get_rank(), |
|
world_size=tdist.get_world_size(), |
|
use_horovod=False, |
|
) |
|
self.detail_loss_weight = config.detail_loss_weight |
|
|
|
self.guide_type_1 = config.guide_type_1 |
|
self.guide_type_2 = config.guide_type_2 |
|
self.dependency_loss_weight = config.dependency_loss_weight |
|
|
|
self.test_mode = config.test_model |
|
|
|
if self.test_mode: |
|
self.eval() |
|
            for p in self.parameters():
                p.requires_grad_(False)
|
|
|
def finetune(self, enc_tuning_method, dec_tuning_method): |
|
self.encoder.finetine(enc_tuning_method) |
|
self.decoder.finetine(dec_tuning_method) |
|
|
|
def encode(self, x): |
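        """Encode images and project them with quant_conv into the codebook space;
        ViT features are reshaped to (B, C, L, 1) under product quantization,
        otherwise to (B, C, H, W)."""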
|
h = self.encoder(x) |
|
if self.enc_type == "dinov2": |
|
b, l, c = h.shape |
|
if self.product_quant > 1: |
|
assert int(sqrt(l // self.product_quant)) ** 2 * self.product_quant == l |
|
h = h.view(b, l, 1, c) |
|
h = h.permute(0, 3, 1, 2) |
|
else: |
|
assert int(sqrt(l)) ** 2 == l |
|
h = h.view(b, int(sqrt(l)), int(sqrt(l)), c) |
|
h = h.permute(0, 3, 1, 2) |
|
h = self.quant_conv(h) |
|
return h |
|
|
|
def decode(self, quant, return_quant=False): |
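        """Project quantized features through post_quant_conv and decode to pixels;
        a ViT decoder consumes the features as a (B, L, C) token sequence."""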
|
quant = self.post_quant_conv(quant) |
|
if self.dec_type == "dinov2": |
|
quant = quant.flatten(2).permute(0, 2, 1) |
|
dec = self.decoder(quant) |
|
return dec |
|
|
|
def decode_code( |
|
self, |
|
code_b, |
|
): |
|
        quant_b, usages, mean_vq_loss, *_ = self.quantize(code_b, ret_usages=True)
|
dec = self.decode(quant_b) |
|
return dec |
|
|
|
def forward(self, input, epoch, alpha, beta, delta): |
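        """Training forward pass: encode, (product-)quantize across scales, decode,
        and compute the guidance losses.

        Returns (dec, (vq_loss, commit_loss, entropy, usages), sem_loss,
        detail_loss, dependency_loss).
        """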
|
h = self.encode(input) |
|
b, c, l, _ = h.shape |
|
if len(self.v_patch_nums) == 1: |
|
dropout_rand = None |
|
else: |
|
dropout_rand = torch.randint( |
|
self.start_drop, len(self.v_patch_nums) + 1, (b,) |
|
) |
|
|
|
if self.product_quant > 1: |
|
h_list = h.chunk(chunks=self.product_quant, dim=2) |
|
( |
|
quant_list, |
|
usages_list, |
|
mean_vq_loss_list, |
|
commit_loss_list, |
|
entropy_list, |
|
) = ([], [], [], [], []) |
|
for i, h in enumerate(h_list): |
|
h = h.view( |
|
b, |
|
-1, |
|
int(sqrt(l // self.product_quant)), |
|
int(sqrt(l // self.product_quant)), |
|
) |
|
quant, usages, vq_loss, commit_loss, entropy_loss = self.quantizes[ |
|
i |
|
].forward(h, ret_usages=True, dropout=dropout_rand) |
|
quant_list.append(quant) |
|
usages_list.append(usages) |
|
mean_vq_loss_list.append(vq_loss) |
|
commit_loss_list.append(commit_loss) |
|
entropy_list.append(entropy_loss) |
|
dependency_loss = self.dependency_loss_weight * orthogonal_cosine_loss( |
|
torch.mean(quant_list[0], dim=(2, 3)).contiguous(), |
|
torch.mean(quant_list[-1], dim=(2, 3)).contiguous(), |
|
) |
|
usages = [sum(us) / self.product_quant for us in zip(*usages_list)] |
|
mean_vq_loss = sum(mean_vq_loss_list) / self.product_quant |
|
mean_commit_loss = sum(commit_loss_list) / self.product_quant |
|
mean_entropy = sum(entropy_list) / self.product_quant |
|
quant = torch.cat(quant_list, dim=1) |
|
else: |
|
dependency_loss = 0.0 |
|
quant, usages, mean_vq_loss, mean_commit_loss, mean_entropy = ( |
|
self.quantize.forward(h, ret_usages=True, dropout=dropout_rand) |
|
) |
|
|
quant = add_perturbation( |
|
h, |
|
quant, |
|
self.quantize.z_channels, |
|
self.quantize.codebook_norm, |
|
self.quantize.embedding, |
|
alpha, |
|
beta, |
|
delta, |
|
) |
|
quant_list = [quant] |
|
|
|
dec = self.decode(quant) |
|
|
|
|
|
input = self.normalize(self.denormalize(input)) |
|
if self.semantic_guide != "none": |
|
if self.guide_type_1 == "class": |
|
z_s = self.semantic_model(input) |
|
z_s = z_s[..., None, None] |
|
else: |
|
z_s = self.semantic_model.forward_features(input)[:, 1:, :] |
|
z_s = z_s.reshape(b, 768, 16, 16) |
|
if self.enc_type == "dinov2": |
|
z_s = self.quant_conv(z_s).contiguous() |
|
semantic_quant = quant_list[-1] |
|
z_s = torch.mean(z_s, dim=(2, 3)).contiguous() |
|
z_q_ = torch.mean(semantic_quant, dim=(2, 3)).contiguous() |
|
elif self.enc_type == "cnn": |
|
z_q_ = torch.mean(h, dim=(2, 3)).contiguous() |
|
z_s = self.sem_linear(z_s).contiguous() |
|
|
|
n_drop = int(b * self.codebook_drop) |
|
with torch.cuda.amp.autocast(enabled=False): |
|
sem_loss_scale = self.sem_loss_scale |
|
feat1 = z_s[n_drop:].float() |
|
feat2 = z_q_[n_drop:].float() |
|
if self.clip_norm: |
|
feat1 = feat1 / feat1.norm(dim=1, keepdim=True) |
|
feat2 = feat2 / feat2.norm(dim=1, keepdim=True) |
|
sem_loss_scale = ( |
|
(epoch % 200) / 200 * (100 - sem_loss_scale) + sem_loss_scale |
|
if epoch < 200 |
|
else 100 |
|
) |
|
sem_loss = self.semantic_loss.forward( |
|
feat1, feat2, logit_scale=sem_loss_scale |
|
) |
|
sem_loss = sem_loss * self.sem_loss_weight |
|
else: |
|
sem_loss = None |
|
|
|
if self.detail_guide != "none": |
|
            assert (
                self.guide_type_2 == "patch"
            ), "currently only 'patch' is supported for the detail guide"
|
if self.guide_type_2 == "class": |
|
z_d = self.detail_model(input) |
|
z_d = z_d[..., None, None] |
|
else: |
|
z_d = self.detail_model.forward_features(input)[:, 1:, :] |
|
z_d = z_d.reshape(b, 768, 16, 16) |
|
if self.enc_type == "dinov2": |
|
z_d = self.quant_conv(z_d).contiguous() |
|
detail_quant = quant_list[0] |
|
z_d = torch.mean(z_d, dim=(2, 3)).contiguous() |
|
z_q_ = torch.mean(detail_quant, dim=(2, 3)).contiguous() |
|
elif self.enc_type == "cnn": |
|
pass |
|
|
|
n_drop = int(b * self.codebook_drop) |
|
with torch.cuda.amp.autocast(enabled=False): |
|
detail_loss_scale = self.detail_loss_scale |
|
feat1 = z_d[n_drop:].float() |
|
feat2 = z_q_[n_drop:].float() |
|
if self.clip_norm: |
|
feat1 = feat1 / feat1.norm(dim=1, keepdim=True) |
|
feat2 = feat2 / feat2.norm(dim=1, keepdim=True) |
|
detail_loss_scale = ( |
|
(epoch % 200) / 200 * (100 - detail_loss_scale) |
|
+ detail_loss_scale |
|
if epoch < 200 |
|
else 100 |
|
) |
|
detail_loss = self.detail_loss.forward( |
|
feat1, feat2, logit_scale=detail_loss_scale |
|
) |
|
detail_loss = detail_loss * self.detail_loss_weight |
|
else: |
|
detail_loss = None |
|
|
|
return ( |
|
dec, |
|
(mean_vq_loss, mean_commit_loss, mean_entropy, usages), |
|
sem_loss, |
|
detail_loss, |
|
dependency_loss, |
|
) |
|
|
|
def img_to_reconstructed_img( |
|
self, |
|
x, |
|
last_one=True, |
|
) -> List[torch.Tensor]: |
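        """Encode, quantize at every scale, and decode; returns the final-scale
        reconstruction when last_one=True, else one reconstruction per scale,
        clamped to [-1, 1]."""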
|
h = self.encoder(x) |
|
if self.enc_type == "dinov2": |
|
b, l, c = h.shape |
|
if self.product_quant > 1: |
|
assert int(sqrt(l // self.product_quant)) ** 2 * self.product_quant == l |
|
h = h.view(b, l, 1, c) |
|
h = h.permute(0, 3, 1, 2) |
|
else: |
|
assert int(sqrt(l)) ** 2 == l |
|
h = h.view(b, int(sqrt(l)), int(sqrt(l)), c) |
|
h = h.permute(0, 3, 1, 2) |
|
f = self.quant_conv(h) |
|
|
|
if self.product_quant > 1: |
|
b, c, l, _ = f.shape |
|
f_list = f.chunk(chunks=self.product_quant, dim=2) |
|
f_list = [ |
|
f.view( |
|
b, |
|
-1, |
|
int(sqrt(l // self.product_quant)), |
|
int(sqrt(l // self.product_quant)), |
|
) |
|
for f in f_list |
|
] |
|
if len(self.v_patch_nums) == 1: |
|
f_hats_list = [ |
|
self.quantizes[i].f_to_idxBl_or_fhat( |
|
f, to_fhat=True, v_patch_nums=None |
|
) |
|
for i, f in enumerate(f_list) |
|
] |
|
else: |
|
f_hats_list = [ |
|
self.quantizes[i].f_to_idxBl_or_fhat( |
|
f, to_fhat=True, v_patch_nums=self.v_patch_nums |
|
) |
|
for i, f in enumerate(f_list) |
|
] |
|
f_hats = [ |
|
self.post_quant_conv(torch.cat(f_hats, dim=1)) |
|
for f_hats in zip(*f_hats_list) |
|
] |
|
else: |
|
if len(self.v_patch_nums) == 1: |
|
ls_f_hat_BChw = self.quantize.f_to_idxBl_or_fhat( |
|
f, to_fhat=True, v_patch_nums=None |
|
) |
|
else: |
|
ls_f_hat_BChw = self.quantize.f_to_idxBl_or_fhat( |
|
f, to_fhat=True, v_patch_nums=self.v_patch_nums |
|
) |
|
f_hats = [self.post_quant_conv(f_hat) for f_hat in ls_f_hat_BChw] |
|
|
|
if self.dec_type == "dinov2": |
|
f_hats = [f_hat.flatten(2).permute(0, 2, 1) for f_hat in f_hats] |
|
|
|
if last_one: |
|
return self.decoder(f_hats[-1]).clamp_(-1, 1) |
|
else: |
|
return [self.decoder(f_hat).clamp_(-1, 1) for f_hat in f_hats] |
|
|
|
def img_to_sem_feat( |
|
self, |
|
x, |
|
) -> List[torch.Tensor]: |
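        """Return the finest-scale quantized features of the last (semantic)
        sub-quantizer; assumes product quantization is enabled."""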
|
h = self.encoder(x) |
|
if self.enc_type == "dinov2": |
|
b, l, c = h.shape |
|
if self.product_quant > 1: |
|
assert int(sqrt(l // self.product_quant)) ** 2 * self.product_quant == l |
|
h = h.view(b, l, 1, c) |
|
h = h.permute(0, 3, 1, 2) |
|
else: |
|
assert int(sqrt(l)) ** 2 == l |
|
h = h.view(b, int(sqrt(l)), int(sqrt(l)), c) |
|
h = h.permute(0, 3, 1, 2) |
|
f = self.quant_conv(h) |
|
|
|
b, c, l, _ = f.shape |
|
f_list = f.chunk(chunks=self.product_quant, dim=2) |
|
f_list = [ |
|
f.view( |
|
b, |
|
-1, |
|
int(sqrt(l // self.product_quant)), |
|
int(sqrt(l // self.product_quant)), |
|
) |
|
for f in f_list |
|
] |
|
f_hats_list = [ |
|
self.quantizes[i].f_to_idxBl_or_fhat( |
|
f, to_fhat=True, v_patch_nums=self.v_patch_nums |
|
) |
|
for i, f in enumerate(f_list) |
|
] |
|
|
|
        z_q = f_hats_list[-1][-1]
|
return z_q |
|
|
|
def fhat_to_img(self, f_hat: torch.Tensor): |
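        """Decode an accumulated f_hat back to an image clamped to [-1, 1]."""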
|
f_hat = self.post_quant_conv(f_hat) |
|
if self.dec_type == "dinov2": |
|
f_hat = f_hat.flatten(2).permute(0, 2, 1) |
|
return self.decoder(f_hat).clamp_(-1, 1) |
|
|
|
def idxBl_to_var_input(self, gt_idx_Bl): |
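        """Build the teacher-forcing input of the autoregressive transformer from
        ground-truth index lists; product-quantized branches are concatenated
        along the last dimension."""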
|
if self.product_quant > 1: |
|
x_BLCv_wo_first_l_list = [ |
|
self.quantizes[i].idxBl_to_var_input(gt_idx_Bl[i]) |
|
for i in range(self.product_quant) |
|
] |
|
return torch.cat(x_BLCv_wo_first_l_list, dim=-1) |
|
else: |
|
return self.quantize.idxBl_to_var_input(gt_idx_Bl) |
|
|
|
def get_next_autoregressive_input(self, si, SN, f_hat, h_BChw): |
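        """Advance one autoregressive scale per sub-quantizer and concatenate the
        updated f_hat and next-scale token maps channel-wise."""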
|
f_hat_list = f_hat.chunk(self.product_quant, dim=1) |
|
h_BChw_list = h_BChw.chunk(self.product_quant, dim=1) |
|
out_fhat_list, out_next_token_map_list = [], [] |
|
for i, (f_hat, h_BChw) in enumerate(zip(f_hat_list, h_BChw_list)): |
|
out_fhat, out_next_token_map = self.quantizes[ |
|
i |
|
].get_next_autoregressive_input(si, SN, f_hat, h_BChw) |
|
out_fhat_list.append(out_fhat) |
|
out_next_token_map_list.append(out_next_token_map) |
|
f_hat = torch.cat(out_fhat_list, dim=1) |
|
next_token_map = torch.cat(out_next_token_map_list, dim=1) |
|
return f_hat, next_token_map |
|
|
|
|
|
class Encoder(nn.Module): |
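    """VQGAN-style convolutional encoder: downsampling stages of residual blocks,
    with self-attention at the lowest resolution, ending in a z_channels map."""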
|
def __init__( |
|
self, |
|
in_channels=3, |
|
ch=128, |
|
ch_mult=(1, 1, 2, 2, 4), |
|
num_res_blocks=2, |
|
norm_type="group", |
|
dropout=0.0, |
|
resamp_with_conv=True, |
|
z_channels=256, |
|
): |
|
super().__init__() |
|
self.num_resolutions = len(ch_mult) |
|
self.num_res_blocks = num_res_blocks |
|
self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1) |
|
|
|
|
|
in_ch_mult = (1,) + tuple(ch_mult) |
|
self.conv_blocks = nn.ModuleList() |
|
for i_level in range(self.num_resolutions): |
|
conv_block = nn.Module() |
|
|
|
res_block = nn.ModuleList() |
|
attn_block = nn.ModuleList() |
|
block_in = ch * in_ch_mult[i_level] |
|
block_out = ch * ch_mult[i_level] |
|
for _ in range(self.num_res_blocks): |
|
res_block.append( |
|
ResnetBlock( |
|
block_in, block_out, dropout=dropout, norm_type=norm_type |
|
) |
|
) |
|
block_in = block_out |
|
if i_level == self.num_resolutions - 1: |
|
attn_block.append(AttnBlock(block_in, norm_type)) |
|
conv_block.res = res_block |
|
conv_block.attn = attn_block |
|
|
|
if i_level != self.num_resolutions - 1: |
|
conv_block.downsample = Downsample(block_in, resamp_with_conv) |
|
self.conv_blocks.append(conv_block) |
|
|
|
|
|
self.mid = nn.ModuleList() |
|
self.mid.append( |
|
ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type) |
|
) |
|
self.mid.append(AttnBlock(block_in, norm_type=norm_type)) |
|
self.mid.append( |
|
ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type) |
|
) |
|
|
|
|
|
self.norm_out = Normalize(block_in, norm_type) |
|
self.conv_out = nn.Conv2d( |
|
block_in, z_channels, kernel_size=3, stride=1, padding=1 |
|
) |
|
|
|
def forward(self, x): |
|
h = self.conv_in(x) |
|
|
|
for i_level, block in enumerate(self.conv_blocks): |
|
for i_block in range(self.num_res_blocks): |
|
h = block.res[i_block](h) |
|
if len(block.attn) > 0: |
|
h = block.attn[i_block](h) |
|
if i_level != self.num_resolutions - 1: |
|
h = block.downsample(h) |
|
|
|
|
|
for mid_block in self.mid: |
|
h = mid_block(h) |
|
|
|
|
|
h = self.norm_out(h) |
|
h = nonlinearity(h) |
|
h = self.conv_out(h) |
|
return h |
|
|
|
|
|
class Decoder(nn.Module): |
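    """Convolutional decoder mirroring Encoder: mid blocks followed by upsampling
    stages of residual (and attention) blocks back to RGB."""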
|
def __init__( |
|
self, |
|
z_channels=256, |
|
ch=128, |
|
ch_mult=(1, 1, 2, 2, 4), |
|
num_res_blocks=2, |
|
norm_type="group", |
|
dropout=0.0, |
|
resamp_with_conv=True, |
|
out_channels=3, |
|
): |
|
super().__init__() |
|
self.num_resolutions = len(ch_mult) |
|
self.num_res_blocks = num_res_blocks |
|
|
|
block_in = ch * ch_mult[self.num_resolutions - 1] |
|
|
|
self.conv_in = nn.Conv2d( |
|
z_channels, block_in, kernel_size=3, stride=1, padding=1 |
|
) |
|
|
|
|
|
self.mid = nn.ModuleList() |
|
self.mid.append( |
|
ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type) |
|
) |
|
self.mid.append(AttnBlock(block_in, norm_type=norm_type)) |
|
self.mid.append( |
|
ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type) |
|
) |
|
|
|
|
|
self.conv_blocks = nn.ModuleList() |
|
for i_level in reversed(range(self.num_resolutions)): |
|
conv_block = nn.Module() |
|
|
|
res_block = nn.ModuleList() |
|
attn_block = nn.ModuleList() |
|
block_out = ch * ch_mult[i_level] |
|
for _ in range(self.num_res_blocks + 1): |
|
res_block.append( |
|
ResnetBlock( |
|
block_in, block_out, dropout=dropout, norm_type=norm_type |
|
) |
|
) |
|
block_in = block_out |
|
if i_level == self.num_resolutions - 1: |
|
attn_block.append(AttnBlock(block_in, norm_type)) |
|
conv_block.res = res_block |
|
conv_block.attn = attn_block |
|
|
|
if i_level != 0: |
|
conv_block.upsample = Upsample(block_in, resamp_with_conv) |
|
self.conv_blocks.append(conv_block) |
|
|
|
|
|
self.norm_out = Normalize(block_in, norm_type) |
|
self.conv_out = nn.Conv2d( |
|
block_in, out_channels, kernel_size=3, stride=1, padding=1 |
|
) |
|
|
|
@property |
|
def last_layer(self): |
|
return self.conv_out.weight |
|
|
|
def forward(self, z): |
|
|
|
h = self.conv_in(z) |
|
|
|
|
|
for mid_block in self.mid: |
|
h = mid_block(h) |
|
|
|
|
|
for i_level, block in enumerate(self.conv_blocks): |
|
for i_block in range(self.num_res_blocks + 1): |
|
h = block.res[i_block](h) |
|
if len(block.attn) > 0: |
|
h = block.attn[i_block](h) |
|
if i_level != self.num_resolutions - 1: |
|
h = block.upsample(h) |
|
|
|
|
|
h = self.norm_out(h) |
|
h = nonlinearity(h) |
|
h = self.conv_out(h) |
|
return h |
|
|
|
|
|
class ResnetBlock(nn.Module): |
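    """Pre-activation residual block (Normalize + swish + conv), with a 3x3 or 1x1
    shortcut when the channel count changes."""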
|
def __init__( |
|
self, |
|
in_channels, |
|
out_channels=None, |
|
conv_shortcut=False, |
|
dropout=0.0, |
|
norm_type="group", |
|
): |
|
super().__init__() |
|
self.in_channels = in_channels |
|
out_channels = in_channels if out_channels is None else out_channels |
|
self.out_channels = out_channels |
|
self.use_conv_shortcut = conv_shortcut |
|
|
|
self.norm1 = Normalize(in_channels, norm_type) |
|
self.conv1 = nn.Conv2d( |
|
in_channels, out_channels, kernel_size=3, stride=1, padding=1 |
|
) |
|
self.norm2 = Normalize(out_channels, norm_type) |
|
self.dropout = nn.Dropout(dropout) |
|
self.conv2 = nn.Conv2d( |
|
out_channels, out_channels, kernel_size=3, stride=1, padding=1 |
|
) |
|
|
|
if self.in_channels != self.out_channels: |
|
if self.use_conv_shortcut: |
|
self.conv_shortcut = nn.Conv2d( |
|
in_channels, out_channels, kernel_size=3, stride=1, padding=1 |
|
) |
|
else: |
|
self.nin_shortcut = nn.Conv2d( |
|
in_channels, out_channels, kernel_size=1, stride=1, padding=0 |
|
) |
|
|
|
def forward(self, x): |
|
h = x |
|
h = self.norm1(h) |
|
h = nonlinearity(h) |
|
h = self.conv1(h) |
|
h = self.norm2(h) |
|
h = nonlinearity(h) |
|
h = self.dropout(h) |
|
h = self.conv2(h) |
|
|
|
if self.in_channels != self.out_channels: |
|
if self.use_conv_shortcut: |
|
x = self.conv_shortcut(x) |
|
else: |
|
x = self.nin_shortcut(x) |
|
return x + h |
|
|
|
|
|
class AttnBlock(nn.Module): |
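    """Single-head self-attention over spatial positions with a residual connection."""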
|
def __init__(self, in_channels, norm_type="group"): |
|
super().__init__() |
|
self.norm = Normalize(in_channels, norm_type) |
|
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) |
|
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) |
|
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) |
|
self.proj_out = nn.Conv2d( |
|
in_channels, in_channels, kernel_size=1, stride=1, padding=0 |
|
) |
|
|
|
def forward(self, x): |
|
h_ = x |
|
h_ = self.norm(h_) |
|
q = self.q(h_) |
|
k = self.k(h_) |
|
v = self.v(h_) |
|
|
|
|
|
b, c, h, w = q.shape |
|
q = q.reshape(b, c, h * w) |
|
q = q.permute(0, 2, 1) |
|
k = k.reshape(b, c, h * w) |
|
w_ = torch.bmm(q, k) |
|
w_ = w_ * (int(c) ** (-0.5)) |
|
w_ = F.softmax(w_, dim=2) |
|
|
|
|
|
v = v.reshape(b, c, h * w) |
|
w_ = w_.permute(0, 2, 1) |
|
h_ = torch.bmm(v, w_) |
|
h_ = h_.reshape(b, c, h, w) |
|
|
|
h_ = self.proj_out(h_) |
|
|
|
return x + h_ |
|
|
|
|
|
def nonlinearity(x): |
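    # swish / SiLU activation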
|
|
|
return x * torch.sigmoid(x) |
|
|
|
|
|
def Normalize(in_channels, norm_type="group"): |
|
assert norm_type in ["group", "batch"] |
|
if norm_type == "group": |
|
return nn.GroupNorm( |
|
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True |
|
) |
|
elif norm_type == "batch": |
|
return nn.SyncBatchNorm(in_channels) |
|
|
|
|
|
class Upsample(nn.Module): |
|
def __init__(self, in_channels, with_conv): |
|
super().__init__() |
|
self.with_conv = with_conv |
|
if self.with_conv: |
|
self.conv = nn.Conv2d( |
|
in_channels, in_channels, kernel_size=3, stride=1, padding=1 |
|
) |
|
|
|
def forward(self, x): |
|
x = F.interpolate(x, scale_factor=2.0, mode="nearest") |
|
if self.with_conv: |
|
x = self.conv(x) |
|
return x |
|
|
|
|
|
class Downsample(nn.Module): |
|
def __init__(self, in_channels, with_conv): |
|
super().__init__() |
|
self.with_conv = with_conv |
|
if self.with_conv: |
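            # stride-2 conv with no built-in asymmetric padding; forward() pads
            # the right/bottom edge by one pixel instead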
|
|
|
self.conv = nn.Conv2d( |
|
in_channels, in_channels, kernel_size=3, stride=2, padding=0 |
|
) |
|
|
|
def forward(self, x): |
|
if self.with_conv: |
|
pad = (0, 1, 0, 1) |
|
x = F.pad(x, pad, mode="constant", value=0) |
|
x = self.conv(x) |
|
else: |
|
x = F.avg_pool2d(x, kernel_size=2, stride=2) |
|
return x |
|
|
|
|
|
def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01): |
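    """Entropy objective over code affinities: minimizing it encourages confident
    per-sample assignments while keeping the batch-averaged code distribution diverse."""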
|
flat_affinity = affinity.reshape(-1, affinity.shape[-1]) |
|
flat_affinity /= temperature |
|
probs = F.softmax(flat_affinity, dim=-1) |
|
log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1) |
|
if loss_type == "softmax": |
|
target_probs = probs |
|
else: |
|
raise ValueError("Entropy loss {} not supported".format(loss_type)) |
|
avg_probs = torch.mean(target_probs, dim=0) |
|
avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-5)) |
|
sample_entropy = -torch.mean(torch.sum(target_probs * log_probs, dim=-1)) |
|
loss = sample_entropy - avg_entropy |
|
return loss |
|
|
|
|
|
class VectorQuantizer(nn.Module): |
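    """Plain single-scale vector quantizer: nearest-codeword lookup with optional
    L2-normalized codes, straight-through gradients, and EMA codebook-usage tracking."""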
|
|
|
def __init__(self, vocab_size=8192, z_channels=32, beta=0.25, codebook_norm=True): |
|
super().__init__() |
|
|
|
self.vocab_size = vocab_size |
|
self.z_channels = z_channels |
|
self.beta = beta |
|
self.codebook_norm = codebook_norm |
|
|
|
|
|
|
|
self.embedding = nn.Embedding(self.vocab_size, self.z_channels) |
|
self.embedding.weight.data.uniform_( |
|
-1.0 / self.vocab_size, 1.0 / self.vocab_size |
|
) |
|
if self.codebook_norm: |
|
self.embedding.weight.data = F.normalize( |
|
self.embedding.weight.data, p=2, dim=-1 |
|
) |
|
|
|
self.register_buffer( |
|
"ema_vocab_hit_SV", torch.full((self.vocab_size,), fill_value=0.0) |
|
) |
|
self.record_hit = 0 |
|
|
|
def no_weight_decay(self): |
|
return [ |
|
"embedding.weight", |
|
] |
|
|
|
def forward(self, z, ret_usages=True, dropout=None): |
|
|
|
vocab_hit_V = torch.zeros(self.vocab_size, dtype=torch.float, device=z.device) |
|
|
|
|
|
z = torch.einsum("b c h w -> b h w c", z).contiguous() |
|
z_flattened = z.view(-1, self.z_channels) |
|
|
|
if self.codebook_norm: |
|
z = F.normalize(z, p=2, dim=-1) |
|
z_flattened = F.normalize(z_flattened, p=2, dim=-1) |
|
embedding = F.normalize(self.embedding.weight, p=2, dim=-1) |
|
else: |
|
embedding = self.embedding.weight |
|
|
|
|
|
d = ( |
|
torch.sum(z_flattened**2, dim=1, keepdim=True) |
|
+ torch.sum(embedding**2, dim=1) |
|
- 2 |
|
* torch.einsum( |
|
"bd,dn->bn", z_flattened, torch.einsum("n d -> d n", embedding) |
|
) |
|
) |
|
|
|
|
|
min_encoding_indices = torch.argmin(d, dim=1) |
|
|
|
z_q = self.embedding(min_encoding_indices).view(z.shape) |
|
if self.codebook_norm: |
|
z_q = F.normalize(z_q, p=2, dim=-1) |
|
|
|
        codebook_usage = 0.0  # usage tracking is skipped outside training
        if ret_usages and self.training:
|
hit_V = min_encoding_indices.bincount(minlength=self.vocab_size).float() |
|
handler = tdist.all_reduce(hit_V, async_op=True) |
|
handler.wait() |
|
if self.record_hit == 0: |
|
self.ema_vocab_hit_SV.copy_(hit_V) |
|
elif self.record_hit < 100: |
|
self.ema_vocab_hit_SV.mul_(0.9).add_(hit_V.mul(0.1)) |
|
else: |
|
self.ema_vocab_hit_SV.mul_(0.99).add_(hit_V.mul(0.01)) |
|
self.record_hit += 1 |
|
vocab_hit_V.add_(hit_V) |
|
|
|
margin = ( |
|
tdist.get_world_size() |
|
* (z.numel() / self.z_channels) |
|
/ self.vocab_size |
|
* 0.08 |
|
) |
|
|
|
codebook_usage = ( |
|
self.ema_vocab_hit_SV >= margin |
|
).float().mean().item() * 100 |
|
|
|
|
|
commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2) |
|
vq_loss = torch.mean((z_q - z.detach()) ** 2) |
|
|
|
|
|
z_q = z + (z_q - z).detach() |
|
|
|
|
|
z_q = torch.einsum("b h w c -> b c h w", z_q) |
|
|
|
return z_q, [codebook_usage], vq_loss, commit_loss, 0.0 |
|
|
|
def f_to_idxBl_or_fhat( |
|
self, z: torch.Tensor, to_fhat: bool, v_patch_nums |
|
): |
|
|
|
z = torch.einsum("b c h w -> b h w c", z).contiguous() |
|
z_flattened = z.view(-1, self.z_channels) |
|
|
|
|
|
if self.codebook_norm: |
|
z = F.normalize(z, p=2, dim=-1) |
|
z_flattened = F.normalize(z_flattened, p=2, dim=-1) |
|
embedding = F.normalize(self.embedding.weight, p=2, dim=-1) |
|
else: |
|
embedding = self.embedding.weight |
|
|
|
|
|
d = ( |
|
torch.sum(z_flattened**2, dim=1, keepdim=True) |
|
+ torch.sum(embedding**2, dim=1) |
|
- 2 |
|
* torch.einsum( |
|
"bd,dn->bn", z_flattened, torch.einsum("n d -> d n", embedding) |
|
) |
|
) |
|
|
|
|
|
min_encoding_indices = torch.argmin(d, dim=1) |
|
|
|
z_q = self.embedding(min_encoding_indices).view(z.shape) |
|
if self.codebook_norm: |
|
z_q = F.normalize(z_q, p=2, dim=-1) |
|
|
|
|
|
z_q = torch.einsum("b h w c -> b c h w", z_q) |
|
|
|
f_hat_or_idx_Bl: List[torch.Tensor] = [z_q if to_fhat else min_encoding_indices] |
|
|
|
return f_hat_or_idx_Bl |
|
|
|
|
|
def orthogonal_cosine_loss(A, B): |
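    """Mean cosine similarity between rows of A and B; used as a dependency loss
    that decorrelates the first and last product-quantization branches."""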
|
A_norm = A / A.norm(dim=1, keepdim=True) |
|
B_norm = B / B.norm(dim=1, keepdim=True) |
|
loss = (A_norm * B_norm).sum(dim=1).mean() |
|
return loss |
|
|
|
|
|
|
|
|
|
|
|
def VQ_8(**kwargs): |
|
return VQModel( |
|
ModelArgs(encoder_ch_mult=[1, 2, 2, 4], decoder_ch_mult=[1, 2, 2, 4], **kwargs) |
|
) |
|
|
|
|
|
def VQ_16(**kwargs): |
|
return VQModel( |
|
ModelArgs( |
|
encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs |
|
) |
|
) |
|
|
|
|
|
VQ_models = {"VQ-16": VQ_16, "VQ-8": VQ_8} |
|
|
|
if __name__ == "__main__": |
|
semantic_model = create_model( |
|
"vit_small_patch14_dinov2.lvd142m", |
|
pretrained=True, |
|
img_size=256, |
|
patch_size=16, |
|
drop_path_rate=0.0, |
|
) |
|
semantic_model.eval() |
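
    # Minimal local sanity check (a sketch, not tied to any checkpoint): quantize a
    # random feature map with the plain VectorQuantizer. Codebook-usage statistics
    # are only tracked in training mode, so no distributed setup is needed here.
    vq = VectorQuantizer(vocab_size=512, z_channels=32).eval()
    with torch.no_grad():
        z_q, usages, vq_loss, commit_loss, _ = vq(torch.randn(2, 32, 16, 16))
    print("z_q:", tuple(z_q.shape), "vq:", vq_loss.item(), "commit:", commit_loss.item())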
|
|