import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from tokenizer.tokenizer_image.lpips import LPIPS
from tokenizer.tokenizer_image.discriminator_patchgan import (
    NLayerDiscriminator as PatchGANDiscriminator,
)
from tokenizer.tokenizer_image.discriminator_stylegan import (
    Discriminator as StyleGANDiscriminator,
)
from tokenizer.tokenizer_image.discriminator_dino import DinoDisc as DINODiscriminator
from tokenizer.tokenizer_image.diffaug import DiffAug

import wandb
import torch.distributed as tdist


def hinge_d_loss(logits_real, logits_fake):
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss

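# The hinge discriminator loss above corresponds to
#   L_D = 0.5 * ( E[max(0, 1 - D(x_real))] + E[max(0, 1 + D(x_fake))] ),
# penalizing real logits below +1 and fake logits above -1.
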

def vanilla_d_loss(logits_real, logits_fake):
    loss_real = torch.mean(F.softplus(-logits_real))
    loss_fake = torch.mean(F.softplus(logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss

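# Since softplus(-x) = -log(sigmoid(x)) and softplus(x) = -log(1 - sigmoid(x)),
# this is the original (vanilla) GAN discriminator loss, scaled by 0.5.
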

def non_saturating_d_loss(logits_real, logits_fake):
    # F.binary_cross_entropy_with_logits expects (input, target); real logits
    # are trained toward 1, fake logits toward 0. The default reduction is
    # already "mean", so no extra torch.mean is needed.
    loss_real = F.binary_cross_entropy_with_logits(
        logits_real, torch.ones_like(logits_real)
    )
    loss_fake = F.binary_cross_entropy_with_logits(
        logits_fake, torch.zeros_like(logits_fake)
    )
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss

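# Sanity check: well-separated logits drive this loss toward zero, e.g.
#   non_saturating_d_loss(torch.full((4,), 10.0), torch.full((4,), -10.0))
# evaluates to roughly 4.5e-5.
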

def hinge_gen_loss(logit_fake):
    return -torch.mean(logit_fake)

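# The hinge generator loss simply pushes fake logits up: L_G = -E[D(G(z))].
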

def non_saturating_gen_loss(logit_fake):
    # Same (input, target) ordering as above: fake logits are trained toward 1.
    return F.binary_cross_entropy_with_logits(
        logit_fake, torch.ones_like(logit_fake)
    )


def adopt_weight(weight, global_step, threshold=0, value=0.0):
    if global_step < threshold:
        weight = value
    return weight

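# Example: with a discriminator warm-up, the adversarial term is disabled early:
#   adopt_weight(0.5, global_step=10_000, threshold=20_000)  ->  0.0
#   adopt_weight(0.5, global_step=25_000, threshold=20_000)  ->  0.5
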

def anneal_weight(
    weight,
    global_step,
    threshold=0,
    initial_value=0.3,
    final_value=0.1,
    anneal_steps=2000,
):
    # note: the incoming `weight` argument is not used; the schedule below
    # fully determines the returned value
    if global_step < threshold:
        return initial_value
    elif global_step < threshold + anneal_steps:
        # linearly decay from initial_value to final_value over anneal_steps
        decay_ratio = (global_step - threshold) / anneal_steps
        weight = initial_value - decay_ratio * (initial_value - final_value)
    else:
        weight = final_value
    return weight

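# Example: with threshold=0, initial_value=0.3, final_value=0.1, anneal_steps=2000,
# the schedule gives 0.3 at step 0, 0.2 at step 1000, and 0.1 from step 2000 on.
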

class LeCAM_EMA(object):
    """Exponential moving averages of discriminator logits for LeCAM regularization."""

    def __init__(self, init=0.0, decay=0.999):
        self.logits_real_ema = init
        self.logits_fake_ema = init
        self.decay = decay

    def update(self, logits_real, logits_fake):
        self.logits_real_ema = self.logits_real_ema * self.decay + torch.mean(
            logits_real
        ).item() * (1 - self.decay)
        self.logits_fake_ema = self.logits_fake_ema * self.decay + torch.mean(
            logits_fake
        ).item() * (1 - self.decay)


def lecam_reg(real_pred, fake_pred, lecam_ema):
    reg = torch.mean(F.relu(real_pred - lecam_ema.logits_fake_ema).pow(2)) + torch.mean(
        F.relu(lecam_ema.logits_real_ema - fake_pred).pow(2)
    )
    return reg

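# LeCAM regularization (Tseng et al., 2021, "Regularizing Generative Adversarial
# Networks under Limited Data"):
#   R = E[ReLU(D(x_real) - ema_fake)^2] + E[ReLU(ema_real - D(x_fake))^2],
# anchoring the current logits to moving averages of their opposites, which
# bounds the discriminator and stabilizes GAN training.
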

class VQLoss(nn.Module):
    def __init__(
        self,
        disc_start,
        disc_loss="hinge",
        disc_dim=64,
        disc_type="patchgan",
        image_size=256,
        disc_num_layers=3,
        disc_in_channels=3,
        disc_weight=1.0,
        disc_adaptive_weight=False,
        gen_adv_loss="hinge",
        reconstruction_loss="l2",
        reconstruction_weight=1.0,
        codebook_weight=1.0,
        perceptual_weight=1.0,
        lecam_loss_weight=None,
        norm_type="bn",
        aug_prob=1.0,
    ):
        super().__init__()

        # discriminator
        assert disc_type in ["patchgan", "stylegan", "dinodisc", "samdisc"]
        assert disc_loss in ["hinge", "vanilla", "non-saturating"]
        self.disc_type = disc_type
        if disc_type == "patchgan":
            self.discriminator = PatchGANDiscriminator(
                input_nc=disc_in_channels,
                n_layers=disc_num_layers,
                ndf=disc_dim,
            )
        elif disc_type == "stylegan":
            self.discriminator = StyleGANDiscriminator(
                input_nc=disc_in_channels,
                image_size=image_size,
            )
        elif disc_type == "dinodisc":
            self.discriminator = DINODiscriminator(norm_type=norm_type)
            self.daug = DiffAug(prob=aug_prob, cutout=0.2)
        elif disc_type == "samdisc":
            # NOTE: SAMDiscriminator is not imported in this file; this branch
            # assumes it is made available elsewhere before being selected.
            self.discriminator = SAMDiscriminator(norm_type=norm_type)
        else:
            raise ValueError(f"Unknown GAN discriminator type '{disc_type}'.")

        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        elif disc_loss == "non-saturating":
            self.disc_loss = non_saturating_d_loss
        else:
            raise ValueError(f"Unknown GAN discriminator loss '{disc_loss}'.")
        self.discriminator_iter_start = disc_start
        self.disc_weight = disc_weight
        self.disc_adaptive_weight = disc_adaptive_weight

        # generator adversarial loss
        assert gen_adv_loss in ["hinge", "non-saturating"]
        if gen_adv_loss == "hinge":
            self.gen_adv_loss = hinge_gen_loss
        elif gen_adv_loss == "non-saturating":
            self.gen_adv_loss = non_saturating_gen_loss
        else:
            raise ValueError(f"Unknown GAN generator loss '{gen_adv_loss}'.")

        # perceptual loss
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight

        # reconstruction loss
        if reconstruction_loss == "l1":
            self.rec_loss = F.l1_loss
        elif reconstruction_loss == "l2":
            self.rec_loss = F.mse_loss
        else:
            raise ValueError(f"Unknown rec loss '{reconstruction_loss}'.")
        self.rec_weight = reconstruction_weight

        # codebook loss
        self.codebook_weight = codebook_weight

        # LeCAM regularization
        self.lecam_loss_weight = lecam_loss_weight
        if self.lecam_loss_weight is not None:
            self.lecam_ema = LeCAM_EMA()

        # assumes torch.distributed is already initialized; only rank 0 logs to wandb
        if tdist.get_rank() == 0:
            self.wandb_tracker = wandb.init(
                project="MSVQ",
            )

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        return d_weight

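    # calculate_adaptive_weight follows the adaptive generator weight of VQGAN
    # (Esser et al., 2021, "Taming Transformers for High-Resolution Image Synthesis"):
    #   lambda = ||grad_last(L_rec)|| / (||grad_last(L_adv)|| + 1e-4),
    # balancing reconstruction and adversarial gradients at the decoder's last layer.
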
    def forward(
        self,
        codebook_loss,
        sem_loss,
        detail_loss,
        dependency_loss,
        inputs,
        reconstructions,
        optimizer_idx,
        global_step,
        last_layer=None,
        logger=None,
        log_every=100,
        fade_blur_schedule=0,
    ):
        # generator update
        if optimizer_idx == 0:
            # reconstruction loss
            rec_loss = self.rec_loss(inputs.contiguous(), reconstructions.contiguous())

            # perceptual loss
            p_loss = self.perceptual_loss(
                inputs.contiguous(), reconstructions.contiguous()
            )
            p_loss = torch.mean(p_loss)

            # generator adversarial loss
            if self.disc_type == "dinodisc":
                if fade_blur_schedule < 1e-6:
                    fade_blur_schedule = 0
                logits_fake = self.discriminator(
                    self.daug.aug(reconstructions.contiguous(), fade_blur_schedule)
                )
            else:
                logits_fake = self.discriminator(reconstructions.contiguous())
            generator_adv_loss = self.gen_adv_loss(logits_fake)

            if self.disc_adaptive_weight:
                nll_loss = self.rec_weight * rec_loss + self.perceptual_weight * p_loss
                disc_adaptive_weight = self.calculate_adaptive_weight(
                    nll_loss, generator_adv_loss, last_layer=last_layer
                )
            else:
                disc_adaptive_weight = 1
            disc_weight = adopt_weight(
                self.disc_weight, global_step, threshold=self.discriminator_iter_start
            )
            if sem_loss is None:
                sem_loss = 0
            if detail_loss is None:
                detail_loss = 0
            if dependency_loss is None:
                dependency_loss = 0
            loss = (
                self.rec_weight * rec_loss
                + self.perceptual_weight * p_loss
                + disc_adaptive_weight * disc_weight * generator_adv_loss
                + codebook_loss[0]
                + codebook_loss[1]
                + codebook_loss[2]
                + sem_loss
                + detail_loss
                + dependency_loss
            )

            if global_step % log_every == 0:
                rec_loss = self.rec_weight * rec_loss
                p_loss = self.perceptual_weight * p_loss
                generator_adv_loss = (
                    disc_adaptive_weight * disc_weight * generator_adv_loss
                )
                logger.info(
                    f"(Generator) rec_loss: {rec_loss:.4f}, perceptual_loss: {p_loss:.4f}, "
                    f"sem_loss: {sem_loss:.4f}, detail_loss: {detail_loss:.4f}, "
                    f"dependency_loss: {dependency_loss:.4f}, vq_loss: {codebook_loss[0]:.4f}, "
                    f"commit_loss: {codebook_loss[1]:.4f}, entropy_loss: {codebook_loss[2]:.4f}, "
                    f"codebook_usage: {codebook_loss[3]}, generator_adv_loss: {generator_adv_loss:.4f}, "
                    f"disc_adaptive_weight: {disc_adaptive_weight:.4f}, disc_weight: {disc_weight:.4f}"
                )
                if tdist.get_rank() == 0:
                    self.wandb_tracker.log(
                        {
                            "rec_loss": rec_loss,
                            "perceptual_loss": p_loss,
                            "sem_loss": sem_loss,
                            "detail_loss": detail_loss,
                            "dependency_loss": dependency_loss,
                            "vq_loss": codebook_loss[0],
                            "commit_loss": codebook_loss[1],
                            "entropy_loss": codebook_loss[2],
                            "codebook_usage": np.mean(codebook_loss[3]),
                            "generator_adv_loss": generator_adv_loss,
                            "disc_adaptive_weight": disc_adaptive_weight,
                            "disc_weight": disc_weight,
                        },
                        step=global_step,
                    )
            return loss

        # discriminator update
        if optimizer_idx == 1:
            if self.disc_type == "dinodisc":
                if fade_blur_schedule < 1e-6:
                    fade_blur_schedule = 0
                logits_fake = self.discriminator(
                    self.daug.aug(
                        reconstructions.contiguous().detach(), fade_blur_schedule
                    )
                )
                logits_real = self.discriminator(
                    self.daug.aug(inputs.contiguous().detach(), fade_blur_schedule)
                )
            else:
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
                logits_real = self.discriminator(inputs.contiguous().detach())

            disc_weight = adopt_weight(
                self.disc_weight, global_step, threshold=self.discriminator_iter_start
            )

            if self.lecam_loss_weight is not None:
                self.lecam_ema.update(logits_real, logits_fake)
                lecam_loss = lecam_reg(logits_real, logits_fake, self.lecam_ema)
                # self.disc_loss may be hinge, vanilla, or non-saturating
                d_loss = self.disc_loss(logits_real, logits_fake)
                d_adversarial_loss = disc_weight * (
                    lecam_loss * self.lecam_loss_weight + d_loss
                )
            else:
                d_adversarial_loss = disc_weight * self.disc_loss(
                    logits_real, logits_fake
                )

            if global_step % log_every == 0:
                logits_real = logits_real.detach().mean()
                logits_fake = logits_fake.detach().mean()
                logger.info(
                    f"(Discriminator) "
                    f"discriminator_adv_loss: {d_adversarial_loss:.4f}, disc_weight: {disc_weight:.4f}, "
                    f"logits_real: {logits_real:.4f}, logits_fake: {logits_fake:.4f}"
                )
                if tdist.get_rank() == 0:
                    self.wandb_tracker.log(
                        {
                            "discriminator_adv_loss": d_adversarial_loss,
                            "disc_weight": disc_weight,
                            "logits_real": logits_real,
                            "logits_fake": logits_fake,
                        },
                        step=global_step,
                    )
            return d_adversarial_loss
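

if __name__ == "__main__":
    # Minimal smoke test for the standalone loss functions above (an illustrative
    # sketch, not part of the training pipeline; it avoids VQLoss itself, which
    # assumes an initialized torch.distributed process group and wandb).
    real = torch.randn(8, 1) + 2.0  # logits the discriminator should call real
    fake = torch.randn(8, 1) - 2.0  # logits the discriminator should call fake

    print("hinge_d_loss:", hinge_d_loss(real, fake).item())
    print("vanilla_d_loss:", vanilla_d_loss(real, fake).item())
    print("non_saturating_d_loss:", non_saturating_d_loss(real, fake).item())
    print("hinge_gen_loss:", hinge_gen_loss(fake).item())
    print("non_saturating_gen_loss:", non_saturating_gen_loss(fake).item())

    ema = LeCAM_EMA()
    ema.update(real, fake)
    print("lecam_reg:", lecam_reg(real, fake, ema).item())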