import torch
import torch.nn as nn
import torch.nn.functional as F

from .quantizer.bsq import BinarySphericalQuantizer
from .quantizer.vq import VectorQuantizer
from .transformer import TransformerDecoder, TransformerEncoder


class VITVQModel(nn.Module):
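    """ViT autoencoder with a vector-quantized latent bottleneck: a Transformer
    encoder produces patch tokens, a linear layer projects them to `embed_dim`,
    a VectorQuantizer discretizes them, and a Transformer decoder reconstructs
    the input from the projected quantized tokens.
    """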
    def __init__(self, vitconfig, n_embed, embed_dim,
                 l2_norm=False, logit_laplace=False, ckpt_path=None, ignore_keys=[],
                 grad_checkpointing=False, selective_checkpointing=False,
                 clamp_range=(0, 1),
                 dvitconfig=None,
                 ):
        super().__init__()
        self.encoder = TransformerEncoder(**vitconfig)
        dvitconfig = vitconfig if dvitconfig is None else dvitconfig
        self.decoder = TransformerDecoder(**dvitconfig, logit_laplace=logit_laplace)
        if self.training and grad_checkpointing:
            self.encoder.set_grad_checkpointing(True, selective=selective_checkpointing)
            self.decoder.set_grad_checkpointing(True, selective=selective_checkpointing)

        self.n_embed = n_embed
        self.embed_dim = embed_dim
        self.l2_norm = l2_norm
        self.setup_quantizer()

        self.quant_embed = nn.Linear(in_features=vitconfig['width'], out_features=embed_dim)
        self.post_quant_embed = nn.Linear(in_features=embed_dim, out_features=dvitconfig['width'])
        self.logit_laplace = logit_laplace
        self.clamp_range = clamp_range

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

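    # Quantizer factory; subclasses (e.g. VITBSQModel below) override this to
    # plug in a different quantizer.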
    def setup_quantizer(self):
        self.quantize = VectorQuantizer(self.n_embed, self.embed_dim, l2_norm=self.l2_norm, beta=0.25, input_format='blc')

    def init_from_ckpt(self, ckpt_path, ignore_keys=[]):
        # Load the checkpoint from disk; this assumes it holds a raw state dict
        # saved from a DDP-wrapped model, so the 'module.' prefix is stripped.
        state_dict = torch.load(ckpt_path, map_location='cpu')
        state_dict = {k[7:]: v for k, v in state_dict.items() if k.startswith('module.')}
        filtered_state_dict = {k: v for k, v in state_dict.items() if all([not k.startswith(ig) for ig in ignore_keys])}
        missing_keys, unexpected_keys = self.load_state_dict(filtered_state_dict, strict=False)
        print(f"missing_keys: {missing_keys}")
        print(f"unexpected_keys: {unexpected_keys}")

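    # Encode to the latent token sequence; with skip_quantize=True (eval only)
    # the continuous, optionally l2-normalized features are returned unquantized.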
    def encode(self, x, skip_quantize=False):
        h = self.encoder(x)
        h = self.quant_embed(h)
        if skip_quantize:
            assert not self.training, 'skip_quantize should be used in eval mode only.'
            if self.l2_norm:
                h = F.normalize(h, dim=-1)
            return h, {}, {}
        quant, loss, info = self.quantize(h)
        return quant, loss, info

    def decode(self, quant):
        h = self.post_quant_embed(quant)
        x = self.decoder(h)
        return x

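    # When logit_laplace is enabled, the decoder emits 2x channels (mean and
    # log-scale); self.logit_laplace_loss is assumed to be provided elsewhere
    # (it is not defined in this module).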
    def clamp(self, x):
        if self.logit_laplace:
            dec, _ = x.chunk(2, dim=1)
            x = self.logit_laplace_loss.unmap(torch.sigmoid(dec))
        else:
            x = x.clamp_(self.clamp_range[0], self.clamp_range[1])
        return x

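    # Full autoencoding pass: returns the reconstruction (clamped or unmapped to
    # image range), the quantizer loss, and a dict of auxiliary info.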
    def forward(self, input, skip_quantize=False):
        if self.logit_laplace:
            input = self.logit_laplace_loss.inmap(input)
        quant, loss, info = self.encode(input, skip_quantize=skip_quantize)
        dec = self.decode(quant)
        if self.logit_laplace:
            dec, lnb = dec.chunk(2, dim=1)
            logit_laplace_loss = self.logit_laplace_loss(dec, lnb, input)
            info.update({'logit_laplace_loss': logit_laplace_loss})
            dec = self.logit_laplace_loss.unmap(torch.sigmoid(dec))
        else:
            dec = dec.clamp_(self.clamp_range[0], self.clamp_range[1])
        return dec, loss, info

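    # Weight of the decoder's final output layer, typically used to balance an
    # adaptive GAN loss weight during training.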
    def get_last_layer(self):
        return self.decoder.conv_out.weight


class VITBSQModel(VITVQModel):
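    """Variant of VITVQModel that replaces vector quantization with Binary
    Spherical Quantization (BSQ): latents are l2-normalized onto the unit
    hypersphere and binarized per dimension, giving an implicit codebook of
    size 2**embed_dim (passed to the parent as n_embed).
    """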
    def __init__(self, vitconfig, embed_dim, embed_group_size=9,
                 l2_norm=False, logit_laplace=False, ckpt_path=None, ignore_keys=[],
                 grad_checkpointing=False, selective_checkpointing=False,
                 clamp_range=(0, 1),
                 dvitconfig=None, beta=0., gamma0=1.0, gamma=1.0, zeta=1.0,
                 persample_entropy_compute='group',
                 cb_entropy_compute='group',
                 post_q_l2_norm=False,
                 inv_temperature=1.,
                 ):
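        # BSQ hyperparameters are stored before super().__init__(), which
        # invokes setup_quantizer() and therefore needs them to be set.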
        self.beta = beta
        self.gamma0 = gamma0
        self.gamma = gamma
        self.zeta = zeta
        self.embed_group_size = embed_group_size
        self.persample_entropy_compute = persample_entropy_compute
        self.cb_entropy_compute = cb_entropy_compute
        self.post_q_l2_norm = post_q_l2_norm
        self.inv_temperature = inv_temperature

        super().__init__(
            vitconfig,
            2 ** embed_dim,
            embed_dim,
            l2_norm=l2_norm,
            logit_laplace=logit_laplace,
            ckpt_path=ckpt_path,
            ignore_keys=ignore_keys,
            grad_checkpointing=grad_checkpointing,
            selective_checkpointing=selective_checkpointing,
            clamp_range=clamp_range,
            dvitconfig=dvitconfig,
        )

    def setup_quantizer(self):
        self.quantize = BinarySphericalQuantizer(
            self.embed_dim, self.beta, self.gamma0, self.gamma, self.zeta,
            group_size=self.embed_group_size,
            persample_entropy_compute=self.persample_entropy_compute,
            cb_entropy_compute=self.cb_entropy_compute,
            input_format='blc',
            l2_norm=self.post_q_l2_norm,
            inv_temperature=self.inv_temperature,
        )

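    # Unlike the parent class, l2 normalization is applied before quantization
    # so that BSQ operates on unit-norm latents.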
    def encode(self, x, skip_quantize=False):
        h = self.encoder(x)
        h = self.quant_embed(h)
        if self.l2_norm:
            h = F.normalize(h, dim=-1)
        if skip_quantize:
            assert not self.training, 'skip_quantize should be used in eval mode only.'
            return h, {}, {}
        quant, loss, info = self.quantize(h)
        return quant, loss, info