import torch
import torch.nn as nn
import torch.nn.functional as F

from .quantizer.bsq import BinarySphericalQuantizer
from .quantizer.vq import VectorQuantizer
from .transformer import TransformerDecoder, TransformerEncoder


class VITVQModel(nn.Module):
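    """ViT encoder/decoder autoencoder with a vector-quantized bottleneck.

    Note: `self.logit_laplace_loss` is referenced below but never created in this
    module; when `logit_laplace=True` it is assumed to be attached externally.
    """
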
    def __init__(self, vitconfig, n_embed, embed_dim,
                 l2_norm=False, logit_laplace=False, ckpt_path=None, ignore_keys=[],
                 grad_checkpointing=False, selective_checkpointing=False,
                 clamp_range=(0, 1),
                 dvitconfig=None,
                 ):
        super().__init__()
        self.encoder = TransformerEncoder(**vitconfig)
        # fall back to the encoder config when no separate decoder config is given
        dvitconfig = vitconfig if dvitconfig is None else dvitconfig
        self.decoder = TransformerDecoder(**dvitconfig, logit_laplace=logit_laplace)
        if self.training and grad_checkpointing:
            self.encoder.set_grad_checkpointing(True, selective=selective_checkpointing)
            self.decoder.set_grad_checkpointing(True, selective=selective_checkpointing)

        self.n_embed = n_embed
        self.embed_dim = embed_dim
        self.l2_norm = l2_norm
        self.setup_quantizer()

        # project encoder features to the quantizer dimension and back to the decoder width
        self.quant_embed = nn.Linear(in_features=vitconfig['width'], out_features=embed_dim)
        self.post_quant_embed = nn.Linear(in_features=embed_dim, out_features=dvitconfig['width'])
        self.logit_laplace = logit_laplace
        self.clamp_range = clamp_range

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def setup_quantizer(self):
        self.quantize = VectorQuantizer(self.n_embed, self.embed_dim, l2_norm=self.l2_norm, beta=0.25, input_format='blc')

    def init_from_ckpt(self, state_dict, ignore_keys=[]):
        # expects an already-loaded state dict rather than a file path; keep only keys
        # carrying the 'module.' prefix (DistributedDataParallel) and strip that prefix,
        # then drop any keys matching ignore_keys
        state_dict = {k[7:]: v for k, v in state_dict.items() if k.startswith('module.')}
        filtered_state_dict = {k: v for k, v in state_dict.items() if all(not k.startswith(ig) for ig in ignore_keys)}
        missing_keys, unexpected_keys = self.load_state_dict(filtered_state_dict, strict=False)
        print(f"missing_keys: {missing_keys}")
        print(f"unexpected_keys: {unexpected_keys}")

    def encode(self, x, skip_quantize=False):
        h = self.encoder(x)
        h = self.quant_embed(h)
        if skip_quantize:
            # return the continuous (optionally l2-normalized) features without quantization
            assert not self.training, 'skip_quantize should be used in eval mode only.'
            if self.l2_norm:
                h = F.normalize(h, dim=-1)
            return h, {}, {}
        quant, loss, info = self.quantize(h)
        return quant, loss, info

    def decode(self, quant):
        h = self.post_quant_embed(quant)
        x = self.decoder(h)
        return x

    def clamp(self, x):
        if self.logit_laplace:
            # the decoder predicts (mu, ln b) pairs; keep mu and unmap it to the pixel range
            dec, _ = x.chunk(2, dim=1)
            x = self.logit_laplace_loss.unmap(F.sigmoid(dec))
        else:
            x = x.clamp_(self.clamp_range[0], self.clamp_range[1])
        return x

    def forward(self, input, skip_quantize=False):
        if self.logit_laplace:
            input = self.logit_laplace_loss.inmap(input)
        quant, loss, info = self.encode(input, skip_quantize=skip_quantize)
        dec = self.decode(quant)
        if self.logit_laplace:
            dec, lnb = dec.chunk(2, dim=1)
            logit_laplace_loss = self.logit_laplace_loss(dec, lnb, input)
            info.update({'logit_laplace_loss': logit_laplace_loss})
            dec = self.logit_laplace_loss.unmap(F.sigmoid(dec))
        else:
            dec = dec.clamp_(self.clamp_range[0], self.clamp_range[1])
        return dec, loss, info

    def get_last_layer(self):
        return self.decoder.conv_out.weight


class VITBSQModel(VITVQModel):
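    """VITVQModel with the vector-quantized bottleneck replaced by a Binary Spherical Quantizer (BSQ)."""
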
    def __init__(self, vitconfig, embed_dim, embed_group_size=9,
                 l2_norm=False, logit_laplace=False, ckpt_path=None, ignore_keys=[],
                 grad_checkpointing=False, selective_checkpointing=False,
                 clamp_range=(0, 1),
                 dvitconfig=None, beta=0., gamma0=1.0, gamma=1.0, zeta=1.0,
                 persample_entropy_compute='group',
                 cb_entropy_compute='group',
                 post_q_l2_norm=False,
                 inv_temperature=1.,
                 ):
        # quantizer hyperparameters must be set before super().__init__(),
        # which calls self.setup_quantizer()
        self.beta = beta  # commit loss
        self.gamma0 = gamma0  # entropy
        self.gamma = gamma  # entropy penalty
        self.zeta = zeta  # lpips
        self.embed_group_size = embed_group_size
        self.persample_entropy_compute = persample_entropy_compute
        self.cb_entropy_compute = cb_entropy_compute
        self.post_q_l2_norm = post_q_l2_norm
        self.inv_temperature = inv_temperature

        # BSQ has an implicit codebook with 2 ** embed_dim entries
        super().__init__(
            vitconfig,
            2 ** embed_dim,
            embed_dim,
            l2_norm=l2_norm,
            logit_laplace=logit_laplace,
            ckpt_path=ckpt_path,
            ignore_keys=ignore_keys,
            grad_checkpointing=grad_checkpointing,
            selective_checkpointing=selective_checkpointing,
            clamp_range=clamp_range,
            dvitconfig=dvitconfig,
        )

    def setup_quantizer(self):
        self.quantize = BinarySphericalQuantizer(
            self.embed_dim, self.beta, self.gamma0, self.gamma, self.zeta,
            group_size=self.embed_group_size,
            persample_entropy_compute=self.persample_entropy_compute,
            cb_entropy_compute=self.cb_entropy_compute,
            input_format='blc',
            l2_norm=self.post_q_l2_norm,
            inv_temperature=self.inv_temperature,
        )

    def encode(self, x, skip_quantize=False):
        h = self.encoder(x)
        h = self.quant_embed(h)
        # unlike the parent class, normalize before quantization (not only when skipping it)
        # so the latents lie on the unit hypersphere that BSQ operates on
        if self.l2_norm:
            h = F.normalize(h, dim=-1)
        if skip_quantize:
            assert not self.training, 'skip_quantize should be used in eval mode only.'
            return h, {}, {}
        quant, loss, info = self.quantize(h)
        return quant, loss, info
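

# Usage sketch (hypothetical): the exact vitconfig keys depend on TransformerEncoder /
# TransformerDecoder in .transformer; only 'width' is referenced in this file.
#
#   model = VITBSQModel(vitconfig=dict(width=768, ...), embed_dim=18, embed_group_size=9)
#   model.eval()
#   with torch.no_grad():
#       quant, loss, info = model.encode(images)       # quantized latents
#       recon = model.clamp(model.decode(quant))       # reconstruction in clamp_range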