File size: 13,383 Bytes
14ce5a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 |
# Modified from:
# taming-transformers: https://github.com/CompVis/taming-transformers
# muse-maskgit-pytorch: https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/vqgan_vae.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tokenizer.tokenizer_image.lpips import LPIPS
from tokenizer.tokenizer_image.discriminator_patchgan import (
NLayerDiscriminator as PatchGANDiscriminator,
)
from tokenizer.tokenizer_image.discriminator_stylegan import (
Discriminator as StyleGANDiscriminator,
)
from tokenizer.tokenizer_image.discriminator_dino import DinoDisc as DINODiscriminator
from tokenizer.tokenizer_image.diffaug import DiffAug
import wandb
import torch.distributed as tdist
def hinge_d_loss(logits_real, logits_fake):
    """Hinge discriminator loss.

    Returns 0.5 * (mean(relu(1 - logits_real)) + mean(relu(1 + logits_fake))).
    """
    real_term = F.relu(1.0 - logits_real).mean()
    fake_term = F.relu(1.0 + logits_fake).mean()
    return 0.5 * (real_term + fake_term)
def vanilla_d_loss(logits_real, logits_fake):
    """Logistic ("vanilla" GAN) discriminator loss.

    Returns 0.5 * (mean(softplus(-logits_real)) + mean(softplus(logits_fake))).
    """
    return 0.5 * (F.softplus(-logits_real).mean() + F.softplus(logits_fake).mean())
def non_saturating_d_loss(logits_real, logits_fake):
    """Non-saturating (binary cross-entropy) discriminator loss.

    Real logits are pushed towards label 1 and fake logits towards label 0:
    0.5 * (BCE(logits_real, 1) + BCE(logits_fake, 0)), each term averaged over
    all elements.
    """
    # F.binary_cross_entropy_with_logits takes (input_logits, target). The
    # previous code passed them in reverse, using a constant 1/0 as the logit
    # and the discriminator output as the "label", which made the loss linear
    # in the logits (and possibly negative).
    loss_real = F.binary_cross_entropy_with_logits(
        logits_real, torch.ones_like(logits_real)
    )
    loss_fake = F.binary_cross_entropy_with_logits(
        logits_fake, torch.zeros_like(logits_fake)
    )
    return 0.5 * (loss_real + loss_fake)
def hinge_gen_loss(logit_fake):
    """Hinge-GAN generator loss: the negated mean discriminator logit on fakes."""
    return logit_fake.mean().neg()
def non_saturating_gen_loss(logit_fake):
    """Non-saturating generator loss.

    BCE that pushes the discriminator's logits on fakes towards the "real"
    label 1, averaged over all elements (equivalently mean(softplus(-logit_fake))).
    """
    # Argument order is (input_logits, target); it was previously reversed.
    return F.binary_cross_entropy_with_logits(
        logit_fake, torch.ones_like(logit_fake)
    )
def adopt_weight(weight, global_step, threshold=0, value=0.0):
    """Gate a loss weight by training step.

    Returns ``value`` while ``global_step < threshold`` and ``weight`` afterwards.
    """
    return value if global_step < threshold else weight
def anneal_weight(
    weight,
    global_step,
    threshold=0,
    initial_value=0.3,
    final_value=0.1,
    anneal_steps=2000,
):
    """Linearly anneal a weight from ``initial_value`` to ``final_value``.

    Before ``threshold`` the result is ``initial_value``. It then moves linearly
    towards ``final_value`` over ``anneal_steps`` steps and stays at
    ``final_value`` afterwards. The ``weight`` argument is accepted but not read.
    """
    if global_step < threshold:
        return initial_value
    if global_step < threshold + anneal_steps:
        # Fraction of the annealing window already elapsed, in [0, 1).
        progress = (global_step - threshold) / anneal_steps
        return initial_value - progress * (initial_value - final_value)
    return final_value
class LeCAM_EMA:
    """Running exponential moving averages of the mean real and fake
    discriminator logits, consumed by ``lecam_reg``."""

    def __init__(self, init=0.0, decay=0.999):
        self.decay = decay
        self.logits_real_ema = init
        self.logits_fake_ema = init

    def update(self, logits_real, logits_fake):
        """Blend the batch-mean logits into the running averages.

        The averages are kept as Python scalars (via ``.item()``), so they carry
        no autograd history.
        """
        keep = self.decay
        blend = 1 - self.decay
        self.logits_real_ema = self.logits_real_ema * keep + logits_real.mean().item() * blend
        self.logits_fake_ema = self.logits_fake_ema * keep + logits_fake.mean().item() * blend
def lecam_reg(real_pred, fake_pred, lecam_ema):
    """LeCAM regularizer for the discriminator.

    Squared-hinge penalty on real logits that exceed the EMA of fake logits,
    plus the penalty on fake logits that fall below the EMA of real logits.
    Each part is averaged over its elements.
    """
    real_gap = F.relu(real_pred - lecam_ema.logits_fake_ema)
    fake_gap = F.relu(lecam_ema.logits_real_ema - fake_pred)
    return real_gap.pow(2).mean() + fake_gap.pow(2).mean()
class VQLoss(nn.Module):
    """Training loss for the VQ tokenizer.

    Combines reconstruction, perceptual (LPIPS), codebook, optional auxiliary
    (semantic / detail / dependency) and GAN terms. ``forward`` returns the
    generator/autoencoder loss when ``optimizer_idx == 0`` and the discriminator
    loss when ``optimizer_idx == 1``. Every ``log_every`` steps the weighted
    components go to ``logger`` (when one is given) and, on the main process, to
    a Weights & Biases run created in ``__init__``.
    """

    def __init__(
        self,
        disc_start,
        disc_loss="hinge",
        disc_dim=64,
        disc_type="patchgan",
        image_size=256,
        disc_num_layers=3,
        disc_in_channels=3,
        disc_weight=1.0,
        disc_adaptive_weight=False,
        gen_adv_loss="hinge",
        reconstruction_loss="l2",
        reconstruction_weight=1.0,
        codebook_weight=1.0,
        perceptual_weight=1.0,
        lecam_loss_weight=None,
        norm_type="bn",
        aug_prob=1,
    ):
        """Build the discriminator and select the loss functions.

        Args:
            disc_start: step before which the adversarial weight is 0
                (see ``adopt_weight``).
            disc_loss: discriminator loss, "hinge", "vanilla" or "non-saturating".
            disc_dim: base channel width (``ndf``) of the PatchGAN discriminator.
            disc_type: "patchgan", "stylegan", "dinodisc" or "samdisc".
            image_size: input resolution passed to the StyleGAN discriminator.
            disc_num_layers: number of PatchGAN layers.
            disc_in_channels: input channels of the PatchGAN/StyleGAN discriminator.
            disc_weight: adversarial weight once ``disc_start`` is reached.
            disc_adaptive_weight: if True, rescale the generator adversarial term
                with ``calculate_adaptive_weight``.
            gen_adv_loss: generator adversarial loss, "hinge" or "non-saturating".
            reconstruction_loss: "l1" or "l2".
            reconstruction_weight: weight of the reconstruction term.
            codebook_weight: stored on the module only; ``forward`` adds the
                codebook terms unweighted.
            perceptual_weight: weight of the LPIPS term.
            lecam_loss_weight: if not None, enables LeCAM regularization of the
                discriminator with this weight.
            norm_type: normalization type forwarded to the DINO/SAM discriminator.
            aug_prob: DiffAug probability (DINO discriminator only).

        Raises:
            ValueError: unknown ``disc_type``, ``disc_loss``, ``gen_adv_loss`` or
                ``reconstruction_loss``.
            NotImplementedError: ``disc_type="samdisc"`` while
                ``SAMDiscriminator`` is not available in this module.
        """
        super().__init__()

        # Validate the string options up front (the previous asserts were
        # stripped under ``python -O`` and made the ValueError branches dead).
        d_loss_fns = {
            "hinge": hinge_d_loss,
            "vanilla": vanilla_d_loss,
            "non-saturating": non_saturating_d_loss,
        }
        g_loss_fns = {
            "hinge": hinge_gen_loss,
            "non-saturating": non_saturating_gen_loss,
        }
        rec_loss_fns = {"l1": F.l1_loss, "l2": F.mse_loss}
        if disc_loss not in d_loss_fns:
            raise ValueError(f"Unknown GAN discriminator loss '{disc_loss}'.")
        if gen_adv_loss not in g_loss_fns:
            raise ValueError(f"Unknown GAN generator loss '{gen_adv_loss}'.")
        if reconstruction_loss not in rec_loss_fns:
            raise ValueError(f"Unknown rec loss '{reconstruction_loss}'.")

        # discriminator
        self.disc_type = disc_type
        if disc_type == "patchgan":
            self.discriminator = PatchGANDiscriminator(
                input_nc=disc_in_channels,
                n_layers=disc_num_layers,
                ndf=disc_dim,
            )
        elif disc_type == "stylegan":
            self.discriminator = StyleGANDiscriminator(
                input_nc=disc_in_channels,
                image_size=image_size,
            )
        elif disc_type == "dinodisc":
            self.discriminator = DINODiscriminator(
                norm_type=norm_type
            )  # default 224 otherwise crop
            self.daug = DiffAug(prob=aug_prob, cutout=0.2)
        elif disc_type == "samdisc":
            # NOTE(review): SAMDiscriminator is not imported at the top of this
            # module, so this branch used to fail with a bare NameError.
            try:
                sam_discriminator_cls = SAMDiscriminator  # noqa: F821
            except NameError as err:
                raise NotImplementedError(
                    "disc_type='samdisc' requires SAMDiscriminator, which is not "
                    "imported in this module."
                ) from err
            self.discriminator = sam_discriminator_cls(norm_type=norm_type)
        else:
            raise ValueError(f"Unknown GAN discriminator type '{disc_type}'.")

        self.disc_loss = d_loss_fns[disc_loss]
        self.discriminator_iter_start = disc_start
        self.disc_weight = disc_weight
        self.disc_adaptive_weight = disc_adaptive_weight
        self.gen_adv_loss = g_loss_fns[gen_adv_loss]

        # perceptual loss
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight

        # reconstruction loss
        self.rec_loss = rec_loss_fns[reconstruction_loss]
        self.rec_weight = reconstruction_weight

        # codebook loss weight (not applied in forward; see docstring)
        self.codebook_weight = codebook_weight

        self.lecam_loss_weight = lecam_loss_weight
        if self.lecam_loss_weight is not None:
            self.lecam_ema = LeCAM_EMA()

        # Only the main process logs to W&B; other ranks keep None.
        self.wandb_tracker = (
            wandb.init(project="MSVQ") if self._is_main_process() else None
        )

    @staticmethod
    def _is_main_process():
        """True on global rank 0, or when torch.distributed is not initialized."""
        if not (tdist.is_available() and tdist.is_initialized()):
            return True
        return tdist.get_rank() == 0

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
        """Ratio of gradient norms used to balance the adversarial term.

        Returns ``||d nll_loss / d last_layer|| / (||d g_loss / d last_layer|| + 1e-4)``
        clamped to [0, 1e4] and detached. ``retain_graph=True`` keeps both graphs
        alive so the caller can still backpropagate the total loss.

        Raises:
            ValueError: if ``last_layer`` is None.
        """
        if last_layer is None:
            raise ValueError(
                "last_layer is required when disc_adaptive_weight=True."
            )
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        return torch.clamp(d_weight, 0.0, 1e4).detach()

    def _discriminator_logits(self, images, fade_blur_schedule):
        """Run the discriminator, applying DiffAug first for the DINO discriminator."""
        images = images.contiguous()
        if self.disc_type == "dinodisc":
            # Treat numerically tiny schedule values as exactly zero.
            if fade_blur_schedule < 1e-6:
                fade_blur_schedule = 0
            # add blur since disc is too strong
            images = self.daug.aug(images, fade_blur_schedule)
        return self.discriminator(images)

    def forward(
        self,
        codebook_loss,
        sem_loss,
        detail_loss,
        dependency_loss,
        inputs,
        reconstructions,
        optimizer_idx,
        global_step,
        last_layer=None,
        logger=None,
        log_every=100,
        fade_blur_schedule=0,
    ):
        """Compute the generator (``optimizer_idx == 0``) or discriminator
        (``optimizer_idx == 1``) loss.

        Args:
            codebook_loss: indexable with at least 4 entries; [0], [1], [2] are
                added to the generator loss (logged as vq/commit/entropy loss) and
                [3] is logged as codebook usage.
            sem_loss, detail_loss, dependency_loss: optional auxiliary generator
                terms; None counts as 0.
            inputs, reconstructions: presumably image batches (B, C, H, W)
                accepted by LPIPS and the discriminator. TODO confirm.
            optimizer_idx: 0 for the generator step, 1 for the discriminator step.
            global_step: current step (adversarial warm-up and log cadence).
            last_layer: parameter used by ``calculate_adaptive_weight``; required
                only when ``disc_adaptive_weight`` is enabled.
            logger: object with ``.info``; when None, text logging is skipped.
            log_every: logging period in steps.
            fade_blur_schedule: DiffAug schedule value (DINO discriminator only).

        Raises:
            ValueError: if ``optimizer_idx`` is neither 0 nor 1.
        """
        if optimizer_idx == 0:
            return self._generator_loss(
                codebook_loss,
                sem_loss,
                detail_loss,
                dependency_loss,
                inputs,
                reconstructions,
                global_step,
                last_layer,
                logger,
                log_every,
                fade_blur_schedule,
            )
        if optimizer_idx == 1:
            return self._discriminator_loss(
                inputs,
                reconstructions,
                global_step,
                logger,
                log_every,
                fade_blur_schedule,
            )
        raise ValueError(
            f"optimizer_idx must be 0 (generator) or 1 (discriminator), got {optimizer_idx!r}."
        )

    def _generator_loss(
        self,
        codebook_loss,
        sem_loss,
        detail_loss,
        dependency_loss,
        inputs,
        reconstructions,
        global_step,
        last_layer,
        logger,
        log_every,
        fade_blur_schedule,
    ):
        """Total autoencoder loss for the generator step; also handles its logging."""
        rec_loss = self.rec_loss(inputs.contiguous(), reconstructions.contiguous())
        p_loss = torch.mean(
            self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
        )

        logits_fake = self._discriminator_logits(reconstructions, fade_blur_schedule)
        generator_adv_loss = self.gen_adv_loss(logits_fake)

        if self.disc_adaptive_weight:
            nll_loss = self.rec_weight * rec_loss + self.perceptual_weight * p_loss
            disc_adaptive_weight = self.calculate_adaptive_weight(
                nll_loss, generator_adv_loss, last_layer=last_layer
            )
        else:
            disc_adaptive_weight = 1
        disc_weight = adopt_weight(
            self.disc_weight, global_step, threshold=self.discriminator_iter_start
        )

        # Absent auxiliary terms contribute nothing.
        sem_loss = 0 if sem_loss is None else sem_loss
        detail_loss = 0 if detail_loss is None else detail_loss
        dependency_loss = 0 if dependency_loss is None else dependency_loss

        weighted_rec = self.rec_weight * rec_loss
        weighted_p = self.perceptual_weight * p_loss
        weighted_adv = disc_adaptive_weight * disc_weight * generator_adv_loss
        loss = (
            weighted_rec
            + weighted_p
            + weighted_adv
            + codebook_loss[0]
            + codebook_loss[1]
            + codebook_loss[2]
            + sem_loss
            + detail_loss
            + dependency_loss
        )

        if global_step % log_every == 0:
            if logger is not None:
                logger.info(
                    f"(Generator) rec_loss: {weighted_rec:.4f}, perceptual_loss: {weighted_p:.4f}, sem_loss: {sem_loss:.4f}, detail_loss: {detail_loss}, "
                    f"dependency_loss: {dependency_loss}, vq_loss: {codebook_loss[0]:.4f}, commit_loss: {codebook_loss[1]:.4f}, entropy_loss: {codebook_loss[2]:.4f}, "
                    f"codebook_usage: {codebook_loss[3]}, generator_adv_loss: {weighted_adv:.4f}, "
                    f"disc_adaptive_weight: {disc_adaptive_weight:.4f}, disc_weight: {disc_weight:.4f}"
                )
            if self.wandb_tracker is not None:
                self.wandb_tracker.log(
                    {
                        "rec_loss": weighted_rec,
                        "perceptual_loss": weighted_p,
                        "sem_loss": sem_loss,
                        "detail_loss": detail_loss,
                        "dependency_loss": dependency_loss,
                        "vq_loss": codebook_loss[0],
                        "commit_loss": codebook_loss[1],
                        "entropy_loss": codebook_loss[2],
                        "codebook_usage": np.mean(codebook_loss[3]),
                        "generator_adv_loss": weighted_adv,
                        "disc_adaptive_weight": disc_adaptive_weight,
                        "disc_weight": disc_weight,
                    },
                    step=global_step,
                )
        return loss

    def _discriminator_loss(
        self, inputs, reconstructions, global_step, logger, log_every, fade_blur_schedule
    ):
        """Adversarial loss for the discriminator step; also handles its logging."""
        # Detach so the discriminator step does not backprop into the generator.
        logits_fake = self._discriminator_logits(
            reconstructions.detach(), fade_blur_schedule
        )
        logits_real = self._discriminator_logits(inputs.detach(), fade_blur_schedule)
        disc_weight = adopt_weight(
            self.disc_weight, global_step, threshold=self.discriminator_iter_start
        )

        if self.lecam_loss_weight is not None:
            # The EMA is updated with this batch before computing the regularizer.
            self.lecam_ema.update(logits_real, logits_fake)
            lecam_loss = lecam_reg(logits_real, logits_fake, self.lecam_ema)
            base_d_loss = self.disc_loss(logits_real, logits_fake)
            d_adversarial_loss = disc_weight * (
                lecam_loss * self.lecam_loss_weight + base_d_loss
            )
        else:
            d_adversarial_loss = disc_weight * self.disc_loss(logits_real, logits_fake)

        if global_step % log_every == 0:
            mean_real = logits_real.detach().mean()
            mean_fake = logits_fake.detach().mean()
            if logger is not None:
                logger.info(
                    f"(Discriminator) "
                    f"discriminator_adv_loss: {d_adversarial_loss:.4f}, disc_weight: {disc_weight:.4f}, "
                    f"logits_real: {mean_real:.4f}, logits_fake: {mean_fake:.4f}"
                )
            if self.wandb_tracker is not None:
                self.wandb_tracker.log(
                    {
                        "discriminator_adv_loss": d_adversarial_loss,
                        "disc_weight": disc_weight,
                        "logits_real": mean_real,
                        "logits_fake": mean_fake,
                    },
                    step=global_step,
                )
        return d_adversarial_loss
|