# VTBench / src / vqvaes / xqgan / vq_loss.py
# (page metadata: huaweilin's picture · update · 14ce5a9)
# Modified from:
# taming-transformers: https://github.com/CompVis/taming-transformers
# muse-maskgit-pytorch: https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/vqgan_vae.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tokenizer.tokenizer_image.lpips import LPIPS
from tokenizer.tokenizer_image.discriminator_patchgan import (
NLayerDiscriminator as PatchGANDiscriminator,
)
from tokenizer.tokenizer_image.discriminator_stylegan import (
Discriminator as StyleGANDiscriminator,
)
from tokenizer.tokenizer_image.discriminator_dino import DinoDisc as DINODiscriminator
from tokenizer.tokenizer_image.diffaug import DiffAug
import wandb
import torch.distributed as tdist
def hinge_d_loss(logits_real, logits_fake):
    """Hinge loss for the discriminator.

    Real logits are penalised by ``relu(1 - logit)`` and fake logits by
    ``relu(1 + logit)``; each penalty is averaged over all elements and the
    two means are averaged together.

    Args:
        logits_real: Discriminator outputs for real images.
        logits_fake: Discriminator outputs for reconstructed images.

    Returns:
        Scalar tensor holding the discriminator loss.
    """
    real_term = F.relu(1.0 - logits_real).mean()
    fake_term = F.relu(1.0 + logits_fake).mean()
    return 0.5 * (real_term + fake_term)
def vanilla_d_loss(logits_real, logits_fake):
    """Softplus ("vanilla" GAN) loss for the discriminator.

    Computes ``softplus(-logits_real)`` and ``softplus(logits_fake)``,
    averages each over all elements, and returns half their sum.

    Args:
        logits_real: Discriminator outputs for real images.
        logits_fake: Discriminator outputs for reconstructed images.

    Returns:
        Scalar tensor holding the discriminator loss.
    """
    real_term = F.softplus(-logits_real).mean()
    fake_term = F.softplus(logits_fake).mean()
    return 0.5 * (real_term + fake_term)
def non_saturating_d_loss(logits_real, logits_fake):
    """Binary cross-entropy loss for the discriminator.

    Real logits are scored against a target of 1 and fake logits against a
    target of 0; the two mean losses are then averaged.

    Args:
        logits_real: Discriminator outputs for real images.
        logits_fake: Discriminator outputs for reconstructed images.

    Returns:
        Scalar tensor holding the discriminator loss (always >= 0).
    """
    # binary_cross_entropy_with_logits expects (input, target): the logits are
    # the input and the constant labels are the target. Passing them the other
    # way round yields a meaningless (possibly negative) value.
    loss_real = F.binary_cross_entropy_with_logits(
        logits_real, torch.ones_like(logits_real)
    )
    loss_fake = F.binary_cross_entropy_with_logits(
        logits_fake, torch.zeros_like(logits_fake)
    )
    # The default reduction is already "mean", so no extra torch.mean is needed.
    return 0.5 * (loss_real + loss_fake)
def hinge_gen_loss(logit_fake):
    """Hinge generator loss: the negated mean of the fake logits.

    Args:
        logit_fake: Discriminator outputs for reconstructed images.

    Returns:
        Scalar tensor; lower when the discriminator scores fakes higher.
    """
    mean_fake_logit = logit_fake.mean()
    return -mean_fake_logit
def non_saturating_gen_loss(logit_fake):
    """Non-saturating generator loss.

    Scores the fake logits against a target of 1 with binary cross-entropy,
    so the loss shrinks as the discriminator is fooled.

    Args:
        logit_fake: Discriminator outputs for reconstructed images.

    Returns:
        Scalar tensor holding the generator adversarial loss (always >= 0).
    """
    # binary_cross_entropy_with_logits expects (input, target): logits first,
    # labels second. The default reduction is already "mean".
    return F.binary_cross_entropy_with_logits(
        logit_fake, torch.ones_like(logit_fake)
    )
def adopt_weight(weight, global_step, threshold=0, value=0.0):
    """Gate a loss weight on the training step.

    Args:
        weight: Weight to use once ``global_step`` has reached ``threshold``.
        global_step: Current training step.
        threshold: Step before which ``value`` is returned instead.
        value: Replacement weight used while ``global_step < threshold``.

    Returns:
        ``value`` before the threshold, otherwise ``weight`` unchanged.
    """
    return value if global_step < threshold else weight
def anneal_weight(
    weight,
    global_step,
    threshold=0,
    initial_value=0.3,
    final_value=0.1,
    anneal_steps=2000,
):
    """Linearly anneal a weight from ``initial_value`` to ``final_value``.

    The schedule has three phases:
      * before ``threshold``: constant ``initial_value``;
      * for the next ``anneal_steps`` steps: linear interpolation;
      * afterwards: constant ``final_value``.

    Args:
        weight: Accepted for signature compatibility; its incoming value is
            never read, the result depends only on the schedule arguments.
        global_step: Current training step.
        threshold: Step at which annealing begins.
        initial_value: Weight returned before and at the start of annealing.
        final_value: Weight returned once annealing has finished.
        anneal_steps: Number of steps over which to interpolate.

    Returns:
        The scheduled weight for ``global_step``.
    """
    if global_step < threshold:
        return initial_value
    if global_step < threshold + anneal_steps:
        # Fraction of the annealing window that has elapsed, in [0, 1).
        progress = (global_step - threshold) / anneal_steps
        return initial_value - progress * (initial_value - final_value)
    return final_value
class LeCAM_EMA:
    """Running exponential moving averages of discriminator logits.

    Keeps one Python-float average for the mean real logit and one for the
    mean fake logit; both start at ``init`` and are blended with factor
    ``decay`` on every ``update`` call.
    """

    def __init__(self, init=0.0, decay=0.999):
        self.decay = decay
        self.logits_real_ema = init
        self.logits_fake_ema = init

    def update(self, logits_real, logits_fake):
        """Fold the batch means of the given logits into the running averages.

        Args:
            logits_real: Discriminator outputs for real images.
            logits_fake: Discriminator outputs for reconstructed images.
        """
        # .item() converts each batch mean to a plain float, so the stored
        # averages carry no autograd history.
        real_mean = torch.mean(logits_real).item()
        self.logits_real_ema = (
            self.logits_real_ema * self.decay + real_mean * (1 - self.decay)
        )

        fake_mean = torch.mean(logits_fake).item()
        self.logits_fake_ema = (
            self.logits_fake_ema * self.decay + fake_mean * (1 - self.decay)
        )
def lecam_reg(real_pred, fake_pred, lecam_ema):
    """LeCam regularisation term for the discriminator.

    Sums two mean-squared, ReLU-clipped gaps: real predictions above the
    running fake average, and fake predictions below the running real
    average.

    Args:
        real_pred: Discriminator outputs for real images.
        fake_pred: Discriminator outputs for reconstructed images.
        lecam_ema: Object exposing ``logits_real_ema`` and
            ``logits_fake_ema`` running averages.

    Returns:
        Scalar tensor holding the regularisation value.
    """
    real_above_fake_avg = F.relu(real_pred - lecam_ema.logits_fake_ema)
    fake_below_real_avg = F.relu(lecam_ema.logits_real_ema - fake_pred)
    return real_above_fake_avg.pow(2).mean() + fake_below_real_avg.pow(2).mean()
class VQLoss(nn.Module):
    """Training loss for a VQ autoencoder with an adversarial discriminator.

    The generator objective sums a pixel reconstruction loss, an LPIPS
    perceptual loss, an adversarial term and the codebook / auxiliary losses
    supplied by the caller. The discriminator objective is the selected GAN
    loss, optionally with LeCam regularisation. Metrics are written to the
    supplied logger and, on the main process, to a wandb run.
    """

    def __init__(
        self,
        disc_start,
        disc_loss="hinge",
        disc_dim=64,
        disc_type="patchgan",
        image_size=256,
        disc_num_layers=3,
        disc_in_channels=3,
        disc_weight=1.0,
        disc_adaptive_weight=False,
        gen_adv_loss="hinge",
        reconstruction_loss="l2",
        reconstruction_weight=1.0,
        codebook_weight=1.0,
        perceptual_weight=1.0,
        lecam_loss_weight=None,
        norm_type="bn",
        aug_prob=1,
    ):
        """Build the discriminator and select the loss functions.

        Args:
            disc_start: Global step before which the adversarial weight is 0.
            disc_loss: Discriminator loss: "hinge", "vanilla" or
                "non-saturating".
            disc_dim: Passed as ``ndf`` to the PatchGAN discriminator.
            disc_type: "patchgan", "stylegan", "dinodisc" or "samdisc".
            image_size: Passed to the StyleGAN discriminator.
            disc_num_layers: Passed as ``n_layers`` to the PatchGAN
                discriminator.
            disc_in_channels: Passed as ``input_nc`` to the PatchGAN and
                StyleGAN discriminators.
            disc_weight: Adversarial weight used once ``disc_start`` is
                reached.
            disc_adaptive_weight: If true, rescale the generator adversarial
                term by a gradient-norm ratio (needs ``last_layer`` in
                ``forward``).
            gen_adv_loss: Generator adversarial loss: "hinge" or
                "non-saturating".
            reconstruction_loss: "l1" or "l2".
            reconstruction_weight: Multiplier for the reconstruction loss.
            codebook_weight: Stored on the instance; not applied anywhere in
                this class.
            perceptual_weight: Multiplier for the LPIPS loss.
            lecam_loss_weight: Weight of the LeCam regulariser; ``None``
                disables it.
            norm_type: Passed to the DINO / SAM discriminators.
            aug_prob: ``prob`` for the DiffAug used with "dinodisc".

        Raises:
            AssertionError: For an unknown ``disc_type``, ``disc_loss`` or
                ``gen_adv_loss`` (when assertions are enabled).
            ValueError: For an unknown ``reconstruction_loss``, or for the
                options above when assertions are disabled.
        """
        super().__init__()
        # discriminator loss
        assert disc_type in ["patchgan", "stylegan", "dinodisc", "samdisc"]
        assert disc_loss in ["hinge", "vanilla", "non-saturating"]
        self.disc_type = disc_type
        if disc_type == "patchgan":
            self.discriminator = PatchGANDiscriminator(
                input_nc=disc_in_channels,
                n_layers=disc_num_layers,
                ndf=disc_dim,
            )
        elif disc_type == "stylegan":
            self.discriminator = StyleGANDiscriminator(
                input_nc=disc_in_channels,
                image_size=image_size,
            )
        elif disc_type == "dinodisc":
            self.discriminator = DINODiscriminator(
                norm_type=norm_type
            )  # default 224 otherwise crop
            self.daug = DiffAug(prob=aug_prob, cutout=0.2)
        elif disc_type == "samdisc":
            # NOTE(review): SAMDiscriminator is not imported by this file as
            # shown, so this branch raises NameError unless the name is
            # provided elsewhere -- confirm the intended import.
            self.discriminator = SAMDiscriminator(norm_type=norm_type)
        else:
            raise ValueError(f"Unknown GAN discriminator type '{disc_type}'.")

        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        elif disc_loss == "non-saturating":
            self.disc_loss = non_saturating_d_loss
        else:
            raise ValueError(f"Unknown GAN discriminator loss '{disc_loss}'.")
        self.discriminator_iter_start = disc_start
        self.disc_weight = disc_weight
        self.disc_adaptive_weight = disc_adaptive_weight

        assert gen_adv_loss in ["hinge", "non-saturating"]
        # gen_adv_loss
        if gen_adv_loss == "hinge":
            self.gen_adv_loss = hinge_gen_loss
        elif gen_adv_loss == "non-saturating":
            self.gen_adv_loss = non_saturating_gen_loss
        else:
            raise ValueError(f"Unknown GAN generator loss '{gen_adv_loss}'.")

        # perceptual loss
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight

        # reconstruction loss
        if reconstruction_loss == "l1":
            self.rec_loss = F.l1_loss
        elif reconstruction_loss == "l2":
            self.rec_loss = F.mse_loss
        else:
            raise ValueError(f"Unknown rec loss '{reconstruction_loss}'.")
        self.rec_weight = reconstruction_weight

        # codebook loss
        self.codebook_weight = codebook_weight

        self.lecam_loss_weight = lecam_loss_weight
        if self.lecam_loss_weight is not None:
            self.lecam_ema = LeCAM_EMA()

        # Only the main process owns a wandb run.
        if self._is_main_process():
            self.wandb_tracker = wandb.init(
                project="MSVQ",
            )

    @staticmethod
    def _is_main_process():
        """Return True on rank 0, or when torch.distributed is not in use."""
        # get_rank() raises if no default process group exists, so a plain
        # single-process run is treated as the main process.
        if not (tdist.is_available() and tdist.is_initialized()):
            return True
        return tdist.get_rank() == 0

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
        """Balance the adversarial term against the reconstruction term.

        Args:
            nll_loss: Reconstruction-side loss (scalar tensor with a graph).
            g_loss: Generator adversarial loss (scalar tensor with a graph).
            last_layer: Tensor with respect to which both gradients are taken.

        Returns:
            Detached scalar tensor ``||grad nll|| / (||grad g|| + 1e-4)``,
            clamped to ``[0, 1e4]``.
        """
        # retain_graph=True keeps the graph alive for the later backward pass.
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        return torch.clamp(d_weight, 0.0, 1e4).detach()

    def forward(
        self,
        codebook_loss,
        sem_loss,
        detail_loss,
        dependency_loss,
        inputs,
        reconstructions,
        optimizer_idx,
        global_step,
        last_layer=None,
        logger=None,
        log_every=100,
        fade_blur_schedule=0,
    ):
        """Compute the generator or discriminator loss for one step.

        Args:
            codebook_loss: Indexable; items 0, 1 and 2 (logged as vq, commit
                and entropy losses) are added to the generator loss, item 3
                is only logged as codebook usage.
            sem_loss: Extra generator loss term, or ``None`` for 0.
            detail_loss: Extra generator loss term, or ``None`` for 0.
            dependency_loss: Extra generator loss term, or ``None`` for 0.
            inputs: Real images.
            reconstructions: Reconstructed images.
            optimizer_idx: 0 for the generator update, 1 for the
                discriminator update.
            global_step: Current training step; gates the adversarial weight
                and the logging.
            last_layer: Tensor used for the adaptive weight; needed when
                ``disc_adaptive_weight`` is enabled.
            logger: Object with an ``info`` method; text logging is skipped
                when ``None``.
            log_every: Log when ``global_step % log_every == 0``.
            fade_blur_schedule: Passed to ``DiffAug.aug`` for "dinodisc";
                values below 1e-6 are snapped to 0.

        Returns:
            The scalar loss for the selected optimizer. Any other
            ``optimizer_idx`` falls through and returns ``None``.
        """
        # generator update
        if optimizer_idx == 0:
            # reconstruction loss
            rec_loss = self.rec_loss(inputs.contiguous(), reconstructions.contiguous())

            # perceptual loss
            p_loss = self.perceptual_loss(
                inputs.contiguous(), reconstructions.contiguous()
            )
            p_loss = torch.mean(p_loss)

            # discriminator loss
            if self.disc_type == "dinodisc":
                if fade_blur_schedule < 1e-6:
                    fade_blur_schedule = 0
                logits_fake = self.discriminator(
                    self.daug.aug(reconstructions.contiguous(), fade_blur_schedule)
                )
            else:
                logits_fake = self.discriminator(reconstructions.contiguous())
            generator_adv_loss = self.gen_adv_loss(logits_fake)

            if self.disc_adaptive_weight:
                null_loss = self.rec_weight * rec_loss + self.perceptual_weight * p_loss
                disc_adaptive_weight = self.calculate_adaptive_weight(
                    null_loss, generator_adv_loss, last_layer=last_layer
                )
            else:
                disc_adaptive_weight = 1
            # Adversarial term is switched off until discriminator_iter_start.
            disc_weight = adopt_weight(
                self.disc_weight, global_step, threshold=self.discriminator_iter_start
            )

            if sem_loss is None:
                sem_loss = 0
            if detail_loss is None:
                detail_loss = 0
            if dependency_loss is None:
                dependency_loss = 0

            loss = (
                self.rec_weight * rec_loss
                + self.perceptual_weight * p_loss
                + disc_adaptive_weight * disc_weight * generator_adv_loss
                + codebook_loss[0]
                + codebook_loss[1]
                + codebook_loss[2]
                + sem_loss
                + detail_loss
                + dependency_loss
            )

            if global_step % log_every == 0:
                # Report the weighted contributions, not the raw terms.
                rec_loss = self.rec_weight * rec_loss
                p_loss = self.perceptual_weight * p_loss
                generator_adv_loss = (
                    disc_adaptive_weight * disc_weight * generator_adv_loss
                )
                if logger is not None:
                    logger.info(
                        f"(Generator) rec_loss: {rec_loss:.4f}, perceptual_loss: {p_loss:.4f}, sem_loss: {sem_loss:.4f}, detail_loss: {detail_loss} "
                        f"dependency_loss: {dependency_loss} vq_loss: {codebook_loss[0]:.4f}, commit_loss: {codebook_loss[1]:.4f}, entropy_loss: {codebook_loss[2]:.4f}, "
                        f"codebook_usage: {codebook_loss[3]}, generator_adv_loss: {generator_adv_loss:.4f}, "
                        f"disc_adaptive_weight: {disc_adaptive_weight:.4f}, disc_weight: {disc_weight:.4f}"
                    )
                if self._is_main_process():
                    self.wandb_tracker.log(
                        {
                            "rec_loss": rec_loss,
                            "perceptual_loss": p_loss,
                            "sem_loss": sem_loss,
                            "detail_loss": detail_loss,
                            "dependency_loss": dependency_loss,
                            "vq_loss": codebook_loss[0],
                            "commit_loss": codebook_loss[1],
                            "entropy_loss": codebook_loss[2],
                            "codebook_usage": np.mean(codebook_loss[3]),
                            "generator_adv_loss": generator_adv_loss,
                            "disc_adaptive_weight": disc_adaptive_weight,
                            "disc_weight": disc_weight,
                        },
                        step=global_step,
                    )
            return loss

        # discriminator update
        if optimizer_idx == 1:
            # Inputs are detached so no gradient reaches the generator here.
            if self.disc_type == "dinodisc":
                if fade_blur_schedule < 1e-6:
                    fade_blur_schedule = 0
                # add blur since disc is too strong
                logits_fake = self.discriminator(
                    self.daug.aug(
                        reconstructions.contiguous().detach(), fade_blur_schedule
                    )
                )
                logits_real = self.discriminator(
                    self.daug.aug(inputs.contiguous().detach(), fade_blur_schedule)
                )
            else:
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
                logits_real = self.discriminator(inputs.contiguous().detach())

            disc_weight = adopt_weight(
                self.disc_weight, global_step, threshold=self.discriminator_iter_start
            )

            if self.lecam_loss_weight is not None:
                # The running averages are updated before the regulariser
                # reads them.
                self.lecam_ema.update(logits_real, logits_fake)
                lecam_loss = lecam_reg(logits_real, logits_fake, self.lecam_ema)
                non_saturate_d_loss = self.disc_loss(logits_real, logits_fake)
                d_adversarial_loss = disc_weight * (
                    lecam_loss * self.lecam_loss_weight + non_saturate_d_loss
                )
            else:
                d_adversarial_loss = disc_weight * self.disc_loss(
                    logits_real, logits_fake
                )

            if global_step % log_every == 0:
                logits_real = logits_real.detach().mean()
                logits_fake = logits_fake.detach().mean()
                if logger is not None:
                    logger.info(
                        f"(Discriminator) "
                        f"discriminator_adv_loss: {d_adversarial_loss:.4f}, disc_weight: {disc_weight:.4f}, "
                        f"logits_real: {logits_real:.4f}, logits_fake: {logits_fake:.4f}"
                    )
                if self._is_main_process():
                    self.wandb_tracker.log(
                        {
                            "discriminator_adv_loss": d_adversarial_loss,
                            "disc_weight": disc_weight,
                            "logits_real": logits_real,
                            "logits_fake": logits_fake,
                        },
                        step=global_step,
                    )
            return d_adversarial_loss