import torch
import torch.nn as nn
import torch.nn.functional as F
from .quantizer.bsq import BinarySphericalQuantizer
from .quantizer.vq import VectorQuantizer
from .transformer import TransformerDecoder, TransformerEncoder


class VITVQModel(nn.Module):
    def __init__(self, vitconfig, n_embed, embed_dim,
                 l2_norm=False, logit_laplace=False, ckpt_path=None, ignore_keys=[],
                 grad_checkpointing=False, selective_checkpointing=False,
                 clamp_range=(0, 1),
                 dvitconfig=None,
                 ):
        super().__init__()
        self.encoder = TransformerEncoder(**vitconfig)
        dvitconfig = vitconfig if dvitconfig is None else dvitconfig
        self.decoder = TransformerDecoder(**dvitconfig, logit_laplace=logit_laplace)
        if self.training and grad_checkpointing:
            self.encoder.set_grad_checkpointing(True, selective=selective_checkpointing)
            self.decoder.set_grad_checkpointing(True, selective=selective_checkpointing)

        self.n_embed = n_embed
        self.embed_dim = embed_dim
        self.l2_norm = l2_norm
        self.setup_quantizer()

        self.quant_embed = nn.Linear(in_features=vitconfig['width'], out_features=embed_dim)
        self.post_quant_embed = nn.Linear(in_features=embed_dim, out_features=dvitconfig['width'])
        self.logit_laplace = logit_laplace
        self.clamp_range = clamp_range

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def setup_quantizer(self):
        self.quantize = VectorQuantizer(self.n_embed, self.embed_dim, l2_norm=self.l2_norm, beta=0.25, input_format='blc')

    def init_from_ckpt(self, ckpt_path, ignore_keys=[]):
        # Accept either a checkpoint path or an already-loaded state dict.
        state_dict = torch.load(ckpt_path, map_location='cpu') if isinstance(ckpt_path, str) else ckpt_path
        # Strip the 'module.' prefix added by DistributedDataParallel wrapping.
        state_dict = {k[7:]: v for k, v in state_dict.items() if k.startswith('module.')}
        filtered_state_dict = {k: v for k, v in state_dict.items() if all(not k.startswith(ig) for ig in ignore_keys)}
        missing_keys, unexpected_keys = self.load_state_dict(filtered_state_dict, strict=False)
        print(f"missing_keys: {missing_keys}")
        print(f"unexpected_keys: {unexpected_keys}")

    def encode(self, x, skip_quantize=False):
        h = self.encoder(x)
        h = self.quant_embed(h)
        if skip_quantize:
            assert not self.training, 'skip_quantize should be used in eval mode only.'
            if self.l2_norm:
                h = F.normalize(h, dim=-1)
            return h, {}, {}
        quant, loss, info = self.quantize(h)
        return quant, loss, info

    def decode(self, quant):
        h = self.post_quant_embed(quant)
        x = self.decoder(h)
        return x

    def clamp(self, x):
        if self.logit_laplace:
            # self.logit_laplace_loss is not defined in this class; it is expected
            # to be attached externally when logit_laplace=True.
            dec, _ = x.chunk(2, dim=1)
            x = self.logit_laplace_loss.unmap(F.sigmoid(dec))
        else:
            x = x.clamp_(self.clamp_range[0], self.clamp_range[1])
        return x

    def forward(self, input, skip_quantize=False):
        if self.logit_laplace:
            input = self.logit_laplace_loss.inmap(input)
        quant, loss, info = self.encode(input, skip_quantize=skip_quantize)
        dec = self.decode(quant)
        if self.logit_laplace:
            dec, lnb = dec.chunk(2, dim=1)
            logit_laplace_loss = self.logit_laplace_loss(dec, lnb, input)
            info.update({'logit_laplace_loss': logit_laplace_loss})
            dec = self.logit_laplace_loss.unmap(F.sigmoid(dec))
        else:
            dec = dec.clamp_(self.clamp_range[0], self.clamp_range[1])
        return dec, loss, info

    def get_last_layer(self):
        return self.decoder.conv_out.weight
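

# Usage sketch (illustrative comments only, not executed on import): the vitconfig
# keys are hypothetical placeholders for whatever TransformerEncoder expects; only
# 'width' is referenced directly in this file, and the input shape is an assumption.
#
#   vitconfig = {'width': 768, ...}  # plus encoder-specific keys
#   model = VITVQModel(vitconfig, n_embed=8192, embed_dim=32, l2_norm=True)
#   images = torch.randn(4, 3, 256, 256)
#   recon, vq_loss, info = model(images)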


class VITBSQModel(VITVQModel):
    def __init__(self, vitconfig, embed_dim, embed_group_size=9,
                 l2_norm=False, logit_laplace=False, ckpt_path=None, ignore_keys=[],
                 grad_checkpointing=False, selective_checkpointing=False,
                 clamp_range=(0, 1),
                 dvitconfig=None, beta=0., gamma0=1.0, gamma=1.0, zeta=1.0,
                 persample_entropy_compute='group',
                 cb_entropy_compute='group',
                 post_q_l2_norm=False,
                 inv_temperature=1.,
                 ):
        # Quantizer hyper-parameters must be set before super().__init__(),
        # which calls self.setup_quantizer().
        self.beta = beta        # commitment loss weight
        self.gamma0 = gamma0    # per-sample entropy weight
        self.gamma = gamma      # codebook entropy penalty weight
        self.zeta = zeta        # lpips weight
        self.embed_group_size = embed_group_size
        self.persample_entropy_compute = persample_entropy_compute
        self.cb_entropy_compute = cb_entropy_compute
        self.post_q_l2_norm = post_q_l2_norm
        self.inv_temperature = inv_temperature

        super().__init__(
            vitconfig,
            2 ** embed_dim,  # the BSQ codebook is implicit, with 2 ** embed_dim binary codes
            embed_dim,
            l2_norm=l2_norm,
            logit_laplace=logit_laplace,
            ckpt_path=ckpt_path,
            ignore_keys=ignore_keys,
            grad_checkpointing=grad_checkpointing,
            selective_checkpointing=selective_checkpointing,
            clamp_range=clamp_range,
            dvitconfig=dvitconfig,
        )

    def setup_quantizer(self):
        self.quantize = BinarySphericalQuantizer(
            self.embed_dim, self.beta, self.gamma0, self.gamma, self.zeta,
            group_size=self.embed_group_size,
            persample_entropy_compute=self.persample_entropy_compute,
            cb_entropy_compute=self.cb_entropy_compute,
            input_format='blc',
            l2_norm=self.post_q_l2_norm,
            inv_temperature=self.inv_temperature,
        )

    def encode(self, x, skip_quantize=False):
        h = self.encoder(x)
        h = self.quant_embed(h)
        # Unlike the parent class, the latent is (optionally) l2-normalized before
        # quantization, since BSQ operates on the unit hypersphere.
        if self.l2_norm:
            h = F.normalize(h, dim=-1)
        if skip_quantize:
            assert not self.training, 'skip_quantize should be used in eval mode only.'
            return h, {}, {}
        quant, loss, info = self.quantize(h)
        return quant, loss, info
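

if __name__ == "__main__":
    # Minimal self-check sketch (an addition for illustration, not part of the
    # original module): it only exercises the checkpoint-key handling mirrored
    # from init_from_ckpt and the implicit BSQ codebook size, so no encoder or
    # decoder weights are needed.
    toy_state_dict = {
        'module.encoder.proj.weight': torch.zeros(4, 4),
        'module.decoder.proj.weight': torch.zeros(4, 4),
        'global_step': torch.tensor(0),  # no 'module.' prefix, so it is dropped
    }
    ignore_keys = ['decoder.']
    stripped = {k[7:]: v for k, v in toy_state_dict.items() if k.startswith('module.')}
    kept = {k: v for k, v in stripped.items() if all(not k.startswith(ig) for ig in ignore_keys)}
    print(sorted(kept))  # -> ['encoder.proj.weight']

    # VITBSQModel passes 2 ** embed_dim as n_embed: the BSQ codebook is implicit,
    # one code per sign pattern of the embed_dim binary axes.
    embed_dim = 18
    print(f"implicit codebook size for embed_dim={embed_dim}: {2 ** embed_dim}")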