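"""Lookup-free-quantized VQ tokenizer.

`VQModel` is a Lightning module that pairs a convolutional `Encoder`/`Decoder`
with lookup-free quantization (`LFQ`). It optionally keeps EMA weights,
supports staged checkpoint loading (tokenizer training vs. second-stage
transformer training), and disables Lightning's automatic optimization.
"""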
import torch
import torch.nn.functional as F
import lightning as L
from contextlib import contextmanager
from collections import OrderedDict
from .improved_model import Encoder, Decoder
from .lookup_free_quantize import LFQ
from .ema import LitEma

class VQModel(L.LightningModule):
    def __init__(
        self,
        ddconfig,
        lossconfig,
        ## Quantize Related
        n_embed,
        embed_dim,
        sample_minimization_weight,
        batch_maximization_weight,
        ckpt_path=None,
        ignore_keys=[],
        image_key="image",
        colorize_nlabels=None,
        monitor=None,
        learning_rate=None,
        resume_lr=None,
        ### scheduler config
        warmup_epochs=1.0,  # warmup epochs
        scheduler_type="linear-warmup_cosine-decay",
        min_learning_rate=0,
        use_ema=False,
        token_factorization=False,
        stage=None,
        lr_drop_epoch=None,
        lr_drop_rate=0.1,
        factorized_bits=[9, 9],
    ):
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.quantize = LFQ(
            dim=embed_dim,
            codebook_size=n_embed,
            sample_minimization_weight=sample_minimization_weight,
            batch_maximization_weight=batch_maximization_weight,
            token_factorization=token_factorization,
            factorized_bits=factorized_bits,
        )
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor

        self.use_ema = use_ema
        if (
            self.use_ema and stage is None
        ):  # no need to construct EMA when training Transformer
            self.model_ema = LitEma(self)
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, stage=stage)

        self.resume_lr = resume_lr
        self.learning_rate = learning_rate
        self.lr_drop_epoch = lr_drop_epoch
        self.lr_drop_rate = lr_drop_rate
        self.scheduler_type = scheduler_type
        self.warmup_epochs = warmup_epochs
        self.min_learning_rate = min_learning_rate
        self.automatic_optimization = False
        self.strict_loading = False

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")
    def load_state_dict(self, *args, strict=False):
        """Load checkpoints non-strictly, so resuming tolerates the keys filtered out below."""
        return super().load_state_dict(*args, strict=strict)

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        """Filter out keys that belong to the loss networks and are not needed in checkpoints."""
        return {
            k: v
            for k, v in super()
            .state_dict(*args, destination, prefix, keep_vars)
            .items()
            if (
                "inception_model" not in k
                and "lpips_vgg" not in k
                and "lpips_alex" not in k
            )
        }
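
    # For the "transformer" stage only the (EMA) encoder/decoder weights are
    # reloaded. LitEma registers its buffers under dot-free parameter names, so
    # `ema_mapping` is built from the plain encoder/decoder keys (which come
    # before their `model_ema` counterparts in the state dict) and maps those
    # dot-free names back to the dotted names expected by load_state_dict.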
    def init_from_ckpt(self, path, ignore_keys=list(), stage="transformer"):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        ema_mapping = {}
        new_params = OrderedDict()
        if stage == "transformer":  ### directly use the EMA encoder and decoder parameters
            if self.use_ema:
                for k, v in sd.items():
                    if "encoder" in k:
                        if "model_ema" in k:
                            k = k.replace(
                                "model_ema.", ""
                            )  # load EMA Encoder or Decoder
                            new_k = ema_mapping[k]  # recover the dotted parameter name
                            new_params[new_k] = v
                        s_name = k.replace(".", "")
                        ema_mapping.update({s_name: k})
                        continue
                    if "decoder" in k:
                        if "model_ema" in k:
                            k = k.replace(
                                "model_ema.", ""
                            )  # load EMA Encoder or Decoder
                            new_k = ema_mapping[k]  # recover the dotted parameter name
                            new_params[new_k] = v
                        s_name = k.replace(".", "")
                        ema_mapping.update({s_name: k})
                        continue
            else:  # also only load the generator (encoder/decoder) weights
                for k, v in sd.items():
                    if "encoder" in k:
                        new_params[k] = v
                    elif "decoder" in k:
                        new_params[k] = v
        missing_keys, unexpected_keys = self.load_state_dict(
            new_params, strict=False
        )  # first stage
        print(f"Restored from {path}")
    def encode(self, x):
        h = self.encoder(x)
        (quant, emb_loss, info), loss_breakdown = self.quantize(
            h, return_loss_breakdown=True
        )
        return quant, emb_loss, info, loss_breakdown

    def decode(self, quant):
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input):
        quant, diff, img_toks, loss_break = self.encode(input)
        pixels = self.decode(quant)
        return pixels, img_toks, quant
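
    # Batches store images channels-last; get_input converts them to the
    # channels-first float tensors the encoder expects.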
    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).contiguous()
        return x.float()

    def get_last_layer(self):
        return self.decoder.conv_out.weight
    def log_images(self, batch, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        xrec, _, _ = self(x)  # forward returns (pixels, img_toks, quant)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        return log
    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
        return x
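
# Typical two-stage usage (sketch; config and script names are illustrative, not
# part of this file): first train VQModel as the image tokenizer, then reload its
# (EMA) encoder/decoder via init_from_ckpt(..., stage="transformer") and train a
# second-stage transformer on the token indices returned by encode()/forward().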