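"""Distributed training script for the XQGAN image tokenizer.

Trains a multi-scale VQ model (optionally with an EMA copy) against a
GAN + perceptual + codebook loss under DistributedDataParallel, logs to
wandb, and with --save_best tracks the best checkpoint by reconstruction
FID on a validation split.
"""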
|
import torch |
|
|
|
|
|
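# Allow TF32 for matmul and cuDNN kernels (faster on Ampere+ GPUs, slightly lower precision).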
torch.backends.cuda.matmul.allow_tf32 = True |
|
torch.backends.cudnn.allow_tf32 = True |
|
import torch.distributed as dist
import torch.distributed.nn  # must be imported explicitly for dist.nn.all_gather in the eval loop
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
from torch.utils.data import Dataset, DataLoader |
|
from torch.utils.data.distributed import DistributedSampler |
|
from torchvision.datasets import ImageFolder |
|
from torchvision import transforms |
|
from torchvision.utils import make_grid |
|
from huggingface_hub import upload_folder |
|
|
|
import warnings |
|
|
|
warnings.filterwarnings("ignore") |
|
|
|
from PIL import Image |
|
from tqdm import tqdm |
|
import ruamel.yaml as yaml |
|
|
|
import os |
|
import time |
|
import argparse |
|
from glob import glob |
|
from copy import deepcopy |
|
import sys |
|
import math |
|
import numpy as np |
|
|
|
current_dir = os.path.dirname(os.path.abspath(__file__)) |
|
project_root = os.path.abspath(os.path.join(current_dir, "../..")) |
|
sys.path.append(project_root) |
|
from utils.logger import create_logger |
|
from utils.distributed import init_distributed_mode |
|
from utils.ema import update_ema, requires_grad |
|
from dataset.augmentation import random_crop_arr, center_crop_arr |
|
from dataset.build import build_dataset |
|
from tokenizer.tokenizer_image.xqgan_model import VQ_models |
|
from tokenizer.tokenizer_image.vq_loss import VQLoss |
|
|
|
from timm.scheduler import create_scheduler_v2 as create_scheduler |
|
|
|
from evaluator import Evaluator |
|
import tensorflow.compat.v1 as tf |
|
|
|
try: |
|
import horovod.torch as hvd |
|
except ImportError: |
|
hvd = None |
|
|
|
|
|
|
import wandb |
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_random_ratio( |
|
randomness_anneal_start, randomness_anneal_end, end_ratio, cur_step |
|
): |
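    """Linearly anneal a ratio from 1.0 toward `end_ratio` over [anneal_start, anneal_end].

    Note that the linear segment reaches 1.0 - end_ratio at `randomness_anneal_end`,
    so the schedule is only continuous there when end_ratio == 0.5 (the default).
    """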
|
if cur_step < randomness_anneal_start: |
|
return 1.0 |
|
elif cur_step > randomness_anneal_end: |
|
return end_ratio |
|
else: |
|
return ( |
|
1.0 |
|
- (cur_step - randomness_anneal_start) |
|
/ (randomness_anneal_end - randomness_anneal_start) |
|
* end_ratio |
|
) |
|
|
|
|
|
def parse_args(): |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument( |
|
"--data-path", type=str, default="/mnt/localssd/ImageNet2012/train" |
|
) |
|
parser.add_argument( |
|
"--data-face-path", |
|
type=str, |
|
default=None, |
|
help="face datasets to improve vq model", |
|
) |
|
parser.add_argument( |
|
"--cloud-save-path", |
|
type=str, |
|
default="output/debug", |
|
help="please specify a cloud disk path, if not, local path", |
|
) |
|
parser.add_argument( |
|
"--no-local-save", |
|
action="store_true", |
|
help="no save checkpoints to local path for limited disk volume", |
|
) |
|
parser.add_argument( |
|
"--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16" |
|
) |
|
parser.add_argument( |
|
"--vq-ckpt", type=str, default=None, help="ckpt path for resume training" |
|
) |
|
parser.add_argument( |
|
"--finetune", action="store_true", help="finetune a pre-trained vq model" |
|
) |
|
parser.add_argument("--ema", action="store_true", help="whether using ema training") |
|
parser.add_argument( |
|
"--codebook-size", |
|
type=int, |
|
default=16384, |
|
help="codebook size for vector quantization", |
|
) |
|
parser.add_argument( |
|
"--codebook-embed-dim", |
|
type=int, |
|
default=8, |
|
help="codebook dimension for vector quantization", |
|
) |
|
parser.add_argument( |
|
"--codebook-l2-norm", action="store_true", default=True, help="l2 norm codebook" |
|
) |
|
parser.add_argument( |
|
"--codebook-weight", |
|
type=float, |
|
default=1.0, |
|
help="codebook loss weight for vector quantization", |
|
) |
|
parser.add_argument( |
|
"--entropy-loss-ratio", |
|
type=float, |
|
default=0.0, |
|
help="entropy loss ratio in codebook loss", |
|
) |
|
parser.add_argument( |
|
"--commit-loss-beta", |
|
type=float, |
|
default=0.25, |
|
help="commit loss beta in codebook loss", |
|
) |
|
parser.add_argument( |
|
"--reconstruction-weight", |
|
type=float, |
|
default=1.0, |
|
help="reconstruction loss weight of image pixel", |
|
) |
|
parser.add_argument( |
|
"--reconstruction-loss", |
|
type=str, |
|
default="l2", |
|
help="reconstruction loss type of image pixel", |
|
) |
|
parser.add_argument( |
|
"--perceptual-weight", |
|
type=float, |
|
default=1.0, |
|
help="perceptual loss weight of LPIPS", |
|
) |
|
parser.add_argument( |
|
"--disc-weight", |
|
type=float, |
|
default=0.5, |
|
help="discriminator loss weight for gan training", |
|
) |
|
parser.add_argument( |
|
"--disc-epoch-start", |
|
type=int, |
|
default=0, |
|
help="iteration to start discriminator training and loss", |
|
) |
|
parser.add_argument( |
|
"--disc-start", |
|
type=int, |
|
default=0, |
|
help="iteration to start discriminator training and loss", |
|
) |
|
parser.add_argument( |
|
"--disc-type", |
|
type=str, |
|
choices=["patchgan", "stylegan"], |
|
default="patchgan", |
|
help="discriminator type", |
|
) |
|
parser.add_argument( |
|
"--disc-loss", |
|
type=str, |
|
choices=["hinge", "vanilla", "non-saturating"], |
|
default="hinge", |
|
help="discriminator loss", |
|
) |
|
parser.add_argument( |
|
"--gen-loss", |
|
type=str, |
|
choices=["hinge", "non-saturating"], |
|
default="hinge", |
|
help="generator loss for gan training", |
|
) |
|
parser.add_argument("--compile", action="store_true", default=False) |
|
parser.add_argument("--dropout-p", type=float, default=0.0, help="dropout_p") |
|
parser.add_argument("--results-dir", type=str, default="results_tokenizer_image") |
|
parser.add_argument("--dataset", type=str, default="imagenet") |
|
parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) |
|
parser.add_argument("--epochs", type=int, default=40) |
|
parser.add_argument("--lr", type=float, default=1e-4) |
|
parser.add_argument("--disc_lr", type=float, default=1e-4) |
|
parser.add_argument("--max_grad_norm", type=float, default=0.0) |
|
parser.add_argument("--lr_scheduler", type=str, default="none") |
|
parser.add_argument( |
|
"--weight-decay", type=float, default=0.0, help="Weight decay to use." |
|
) |
|
parser.add_argument( |
|
"--disc-weight-decay", type=float, default=0.0, help="Weight decay to use." |
|
) |
|
parser.add_argument( |
|
"--beta1", |
|
type=float, |
|
default=0.9, |
|
help="The beta1 parameter for the Adam optimizer.", |
|
) |
|
parser.add_argument( |
|
"--beta2", |
|
type=float, |
|
default=0.95, |
|
help="The beta2 parameter for the Adam optimizer.", |
|
) |
|
parser.add_argument( |
|
"--max-grad-norm", default=1.0, type=float, help="Max gradient norm." |
|
) |
|
parser.add_argument("--global-batch-size", type=int, default=128) |
|
parser.add_argument("--global-seed", type=int, default=0) |
|
parser.add_argument("--num-workers", type=int, default=16) |
|
parser.add_argument("--log-every", type=int, default=100) |
|
parser.add_argument("--vis-every", type=int, default=5000) |
|
parser.add_argument("--ckpt-every", type=int, default=10000) |
|
parser.add_argument("--gradient-accumulation-steps", type=int, default=1) |
|
parser.add_argument( |
|
"--mixed-precision", type=str, default="bf16", choices=["none", "fp16", "bf16"] |
|
) |
|
parser.add_argument("--save_best", action="store_true", default=False) |
|
parser.add_argument( |
|
"--val_data_path", type=str, default="/mnt/localssd/ImageNet2012/val" |
|
) |
|
parser.add_argument("--sample_folder_dir", type=str, default="samples") |
|
parser.add_argument( |
|
"--reconstruction_folder_dir", type=str, default="reconstruction" |
|
) |
|
parser.add_argument( |
|
"--v-patch-nums", |
|
type=int, |
|
default=[1, 2, 3, 4, 5, 6, 8, 10, 13, 16], |
|
nargs="+", |
|
help="number of patch numbers of each scale", |
|
) |
|
parser.add_argument("--enc_type", type=str, default="cnn") |
|
parser.add_argument("--dec_type", type=str, default="cnn") |
|
parser.add_argument("--semantic_guide", type=str, default="none") |
|
parser.add_argument("--detail_guide", type=str, default="none") |
|
parser.add_argument("--num_latent_tokens", type=int, default=256) |
|
parser.add_argument( |
|
"--encoder_model", |
|
type=str, |
|
default="vit_small_patch14_dinov2.lvd142m", |
|
help="encoder model name", |
|
) |
|
parser.add_argument( |
|
"--decoder_model", |
|
type=str, |
|
default="vit_small_patch14_dinov2.lvd142m", |
|
help="encoder model name", |
|
) |
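    # NOTE: argparse's type=bool treats any non-empty string (including "False") as
    # True; prefer setting the flags below via the YAML config, where real booleans
    # are injected through parser.set_defaults().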
|
parser.add_argument("--disc_adaptive_weight", type=bool, default=False) |
|
parser.add_argument("--abs_pos_embed", type=bool, default=False) |
|
parser.add_argument("--product_quant", type=int, default=1) |
|
parser.add_argument("--share_quant_resi", type=int, default=4) |
|
parser.add_argument("--codebook_drop", type=float, default=0.0) |
|
parser.add_argument("--half_sem", type=bool, default=False) |
|
parser.add_argument("--start_drop", type=int, default=1) |
|
parser.add_argument("--lecam_loss_weight", type=float, default=None) |
|
parser.add_argument("--sem_loss_weight", type=float, default=0.1) |
|
parser.add_argument("--detail_loss_weight", type=float, default=0.1) |
|
parser.add_argument("--enc_tuning_method", type=str, default="full") |
|
parser.add_argument("--dec_tuning_method", type=str, default="full") |
|
parser.add_argument("--clip_norm", type=bool, default=False) |
|
parser.add_argument("--sem_loss_scale", type=float, default=1.0) |
|
parser.add_argument("--detail_loss_scale", type=float, default=1.0) |
|
parser.add_argument("--config", type=str, default=None) |
|
parser.add_argument("--norm_type", type=str, default="bn") |
|
parser.add_argument("--aug_prob", type=float, default=1.0) |
|
parser.add_argument("--aug_fade_steps", type=int, default=0) |
|
parser.add_argument("--disc_reinit", type=int, default=0) |
|
parser.add_argument("--debug_disc", type=bool, default=False) |
|
parser.add_argument( |
|
"--guide_type_1", type=str, default="class", choices=["patch", "class"] |
|
) |
|
parser.add_argument( |
|
"--guide_type_2", type=str, default="class", choices=["patch", "class"] |
|
) |
|
parser.add_argument("--lfq", action="store_true", default=False, help="if use LFQ") |
|
|
|
parser.add_argument("--end-ratio", type=float, default=0.5) |
|
parser.add_argument("--anneal-start", type=int, default=200) |
|
parser.add_argument("--anneal-end", type=int, default=200) |
|
|
|
parser.add_argument("--alpha", type=float, default=0.0) |
|
parser.add_argument("--beta", type=float, default=0.0) |
|
parser.add_argument("--delta", type=int, default=100) |
|
|
|
args = parser.parse_args() |
|
if args.config is not None: |
|
with open(args.config, "r", encoding="utf-8") as f: |
|
file_yaml = yaml.YAML() |
|
config_args = file_yaml.load(f) |
|
parser.set_defaults(**config_args) |
|
|
|
|
|
args = parser.parse_args() |
|
return args |
|
|
|
|
|
def main(args): |
|
""" |
|
Trains a new model. |
|
""" |
|
assert torch.cuda.is_available(), "Training currently requires at least one GPU." |
|
|
|
|
|
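    # Set up DDP: one process per GPU, with a distinct per-rank seed.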
init_distributed_mode(args) |
|
    assert (
        args.global_batch_size % dist.get_world_size() == 0
    ), f"Batch size ({args.global_batch_size}) must be divisible by world size ({dist.get_world_size()})."
|
rank = dist.get_rank() |
|
device = rank % torch.cuda.device_count() |
|
seed = args.global_seed * dist.get_world_size() + rank |
|
torch.manual_seed(seed) |
|
torch.cuda.set_device(device) |
|
|
|
|
|
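    # Rank 0 creates the experiment/checkpoint directories, the logger, and dumps the config.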
if rank == 0: |
|
os.makedirs( |
|
args.results_dir, exist_ok=True |
|
) |
|
experiment_index = len(glob(f"{args.results_dir}/*")) |
|
model_string_name = args.vq_model.replace("/", "-") |
|
experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}" |
|
checkpoint_dir = ( |
|
f"{experiment_dir}/checkpoints" |
|
) |
|
os.makedirs(checkpoint_dir, exist_ok=True) |
|
logger = create_logger(experiment_dir) |
|
logger.info(f"Experiment directory created at {experiment_dir}") |
|
|
|
cloud_results_dir = f"{args.cloud_save_path}" |
|
cloud_checkpoint_dir = f"{cloud_results_dir}" |
|
os.makedirs(cloud_checkpoint_dir, exist_ok=True) |
|
logger.info(f"Experiment directory created in cloud at {cloud_checkpoint_dir}") |
|
|
|
experiment_config = vars(args) |
|
with open( |
|
os.path.join(cloud_checkpoint_dir, "config.yaml"), "w", encoding="utf-8" |
|
) as f: |
|
|
|
file_yaml = yaml.YAML() |
|
file_yaml.dump(experiment_config, f) |
|
|
|
else: |
|
logger = create_logger(None) |
|
|
|
|
|
logger.info(f"{args}") |
|
|
|
|
|
logger.info( |
|
f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}." |
|
) |
|
|
|
|
|
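    # Training data: random crop + horizontal flip, normalized to [-1, 1].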
transform = transforms.Compose( |
|
[ |
|
transforms.Lambda( |
|
lambda pil_image: random_crop_arr(pil_image, args.image_size) |
|
), |
|
transforms.RandomHorizontalFlip(), |
|
transforms.ToTensor(), |
|
transforms.Normalize( |
|
mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True |
|
), |
|
] |
|
) |
|
dataset = build_dataset(args, transform=transform) |
|
sampler = DistributedSampler( |
|
dataset, |
|
num_replicas=dist.get_world_size(), |
|
rank=rank, |
|
shuffle=True, |
|
seed=args.global_seed, |
|
) |
|
loader = DataLoader( |
|
dataset, |
|
batch_size=int(args.global_batch_size // dist.get_world_size()), |
|
shuffle=False, |
|
sampler=sampler, |
|
num_workers=args.num_workers, |
|
pin_memory=True, |
|
drop_last=True, |
|
) |
|
logger.info(f"Dataset contains {len(dataset):,} images ({args.data_path})") |
|
|
|
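    # Optional validation split (center crop, no flip) for FID-based best-checkpoint selection.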
if args.save_best: |
|
transform = transforms.Compose( |
|
[ |
|
transforms.Lambda( |
|
lambda pil_image: center_crop_arr(pil_image, args.image_size) |
|
), |
|
transforms.ToTensor(), |
|
transforms.Normalize( |
|
mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True |
|
), |
|
] |
|
) |
|
args.data_path = args.val_data_path |
|
val_dataset = build_dataset(args, transform=transform) |
|
val_sampler = DistributedSampler( |
|
val_dataset, |
|
num_replicas=dist.get_world_size(), |
|
rank=rank, |
|
shuffle=False, |
|
seed=args.global_seed, |
|
) |
|
val_loader = DataLoader( |
|
val_dataset, |
|
batch_size=int(args.global_batch_size // dist.get_world_size()), |
|
shuffle=False, |
|
sampler=val_sampler, |
|
num_workers=args.num_workers, |
|
pin_memory=True, |
|
drop_last=False, |
|
) |
|
if rank % torch.cuda.device_count() == 0: |
|
os.makedirs(args.sample_folder_dir, exist_ok=True) |
|
os.makedirs(args.reconstruction_folder_dir, exist_ok=True) |
|
logger.info(f"Saving .png samples at {args.sample_folder_dir}") |
|
logger.info( |
|
f"Saving .png reconstruction at {args.reconstruction_folder_dir}" |
|
) |
|
|
|
num_update_steps_per_epoch = len(loader) |
|
max_train_steps = args.epochs * num_update_steps_per_epoch |
|
args.disc_start = args.disc_epoch_start * num_update_steps_per_epoch |
|
|
|
|
|
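    # Build the VQ tokenizer and the GAN/perceptual loss module (which owns the discriminator).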
vq_model = VQ_models[args.vq_model]( |
|
codebook_size=args.codebook_size, |
|
codebook_embed_dim=args.codebook_embed_dim, |
|
commit_loss_beta=args.commit_loss_beta, |
|
entropy_loss_ratio=args.entropy_loss_ratio, |
|
dropout_p=args.dropout_p, |
|
v_patch_nums=args.v_patch_nums, |
|
enc_type=args.enc_type, |
|
encoder_model=args.encoder_model, |
|
dec_type=args.dec_type, |
|
decoder_model=args.decoder_model, |
|
semantic_guide=args.semantic_guide, |
|
detail_guide=args.detail_guide, |
|
num_latent_tokens=args.num_latent_tokens, |
|
abs_pos_embed=args.abs_pos_embed, |
|
share_quant_resi=args.share_quant_resi, |
|
product_quant=args.product_quant, |
|
codebook_drop=args.codebook_drop, |
|
half_sem=args.half_sem, |
|
start_drop=args.start_drop, |
|
sem_loss_weight=args.sem_loss_weight, |
|
detail_loss_weight=args.detail_loss_weight, |
|
clip_norm=args.clip_norm, |
|
sem_loss_scale=args.sem_loss_scale, |
|
detail_loss_scale=args.detail_loss_scale, |
|
guide_type_1=args.guide_type_1, |
|
guide_type_2=args.guide_type_2, |
|
lfq=args.lfq, |
|
) |
|
logger.info( |
|
f"VQ Model Parameters: {sum(p.numel() for p in vq_model.parameters()):,}" |
|
) |
|
if args.ema: |
|
ema = deepcopy(vq_model).to( |
|
device |
|
) |
|
requires_grad(ema, False) |
|
logger.info( |
|
f"VQ Model EMA Parameters: {sum(p.numel() for p in ema.parameters()):,}" |
|
) |
|
vq_model = vq_model.to(device) |
|
vq_loss = VQLoss( |
|
disc_start=args.disc_start, |
|
disc_weight=args.disc_weight, |
|
disc_type=args.disc_type, |
|
disc_loss=args.disc_loss, |
|
gen_adv_loss=args.gen_loss, |
|
image_size=args.image_size, |
|
perceptual_weight=args.perceptual_weight, |
|
reconstruction_weight=args.reconstruction_weight, |
|
reconstruction_loss=args.reconstruction_loss, |
|
codebook_weight=args.codebook_weight, |
|
lecam_loss_weight=args.lecam_loss_weight, |
|
disc_adaptive_weight=args.disc_adaptive_weight, |
|
norm_type=args.norm_type, |
|
aug_prob=args.aug_prob, |
|
).to(device) |
|
logger.info( |
|
f"Discriminator Parameters: {sum(p.numel() for p in vq_loss.discriminator.parameters()):,}" |
|
) |
|
|
|
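    # Linear lr scaling rule: scale both learning rates with the global batch size (base 128).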
args.lr = args.lr * args.global_batch_size / 128 |
|
args.disc_lr = args.disc_lr * args.global_batch_size / 128 |
|
|
|
scaler = torch.cuda.amp.GradScaler(enabled=(args.mixed_precision == "fp16")) |
|
scaler_disc = torch.cuda.amp.GradScaler(enabled=(args.mixed_precision == "fp16")) |
|
|
|
optimizer = torch.optim.AdamW( |
|
vq_model.parameters(), |
|
lr=args.lr, |
|
betas=(args.beta1, args.beta2), |
|
weight_decay=args.weight_decay, |
|
) |
|
optimizer_disc = torch.optim.AdamW( |
|
vq_loss.discriminator.parameters(), |
|
lr=args.disc_lr, |
|
betas=(args.beta1, args.beta2), |
|
weight_decay=args.disc_weight_decay, |
|
) |
|
|
|
|
|
if args.lr_scheduler == "none": |
|
vqvae_lr_scheduler = None |
|
disc_lr_scheduler = None |
|
else: |
|
vqvae_lr_scheduler, _ = create_scheduler( |
|
sched=args.lr_scheduler, |
|
optimizer=optimizer, |
|
patience_epochs=0, |
|
step_on_epochs=True, |
|
updates_per_epoch=num_update_steps_per_epoch, |
|
num_epochs=args.epochs, |
|
warmup_epochs=1, |
|
min_lr=5e-5, |
|
) |
|
disc_lr_scheduler, _ = create_scheduler( |
|
sched=args.lr_scheduler, |
|
optimizer=optimizer_disc, |
|
patience_epochs=0, |
|
step_on_epochs=True, |
|
updates_per_epoch=num_update_steps_per_epoch, |
|
num_epochs=args.epochs - args.disc_epoch_start, |
|
warmup_epochs=int(0.02 * args.epochs), |
|
min_lr=5e-5, |
|
) |
|
|
|
logger.info( |
|
f"num_update_steps_per_epoch {num_update_steps_per_epoch:,} max_train_steps ({max_train_steps})" |
|
) |
|
|
|
|
|
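    # Resume training: restore model/EMA/optimizer states and recover step/epoch counters.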
if args.vq_ckpt: |
|
checkpoint = torch.load(args.vq_ckpt, map_location="cpu") |
|
vq_model.load_state_dict(checkpoint["model"]) |
|
if args.ema: |
|
ema.load_state_dict(checkpoint["ema"]) |
|
optimizer.load_state_dict(checkpoint["optimizer"]) |
|
if not args.debug_disc: |
|
vq_loss.discriminator.load_state_dict(checkpoint["discriminator"]) |
|
optimizer_disc.load_state_dict(checkpoint["optimizer_disc"]) |
|
else: |
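            # Debugging the discriminator: carry over only the Adam step counter from the old optimizer state.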
|
num_step = checkpoint["optimizer_disc"]["state"][ |
|
next(iter(checkpoint["optimizer_disc"]["state"])) |
|
]["step"] |
|
for param_state in optimizer_disc.state.values(): |
|
param_state["step"] = num_step |
|
if not args.finetune: |
|
train_steps = ( |
|
checkpoint["steps"] |
|
if "steps" in checkpoint |
|
else int(args.vq_ckpt.split("/")[-1].split(".")[0]) |
|
) |
|
start_epoch = ( |
|
int(train_steps / int(len(dataset) / args.global_batch_size)) + 1 |
|
) |
|
train_steps = int(start_epoch * int(len(dataset) / args.global_batch_size)) |
|
else: |
|
train_steps = 0 |
|
start_epoch = 0 |
|
del checkpoint |
|
vq_model.finetune(args.enc_tuning_method, args.dec_tuning_method) |
|
logger.info(f"Resume training from checkpoint: {args.vq_ckpt}") |
|
logger.info(f"Initial state: steps={train_steps}, epochs={start_epoch}") |
|
else: |
|
train_steps = 0 |
|
start_epoch = 0 |
|
if args.ema: |
|
update_ema( |
|
ema, vq_model, decay=0 |
|
) |
|
|
|
if args.compile: |
|
logger.info("compiling the model... (may take several minutes)") |
|
vq_model = torch.compile(vq_model, mode="max-autotune") |
|
|
|
vq_model = DDP(vq_model.to(device), device_ids=[args.gpu]) |
|
vq_model.train() |
|
if args.ema: |
|
ema.eval() |
|
vq_loss = DDP(vq_loss.to(device), device_ids=[args.gpu]) |
|
vq_loss.train() |
|
|
|
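    # Autocast dtype for forward passes; the GradScalers above are active only for fp16.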
ptdtype = {"none": torch.float32, "bf16": torch.bfloat16, "fp16": torch.float16}[ |
|
args.mixed_precision |
|
] |
|
|
|
|
|
log_steps = 0 |
|
running_loss = 0 |
|
start_time = time.time() |
|
curr_fid = None |
|
|
|
logger.info(f"Training for {args.epochs} epochs...") |
|
for epoch in range(start_epoch, args.epochs): |
|
ratio = get_random_ratio( |
|
args.anneal_start, args.anneal_end, args.end_ratio, epoch |
|
) |
|
delta = int(ratio * args.delta) |
|
alpha = ratio * args.alpha |
|
beta = args.beta |
|
|
|
sampler.set_epoch(epoch) |
|
logger.info(f"Beginning epoch {epoch}...") |
|
if args.disc_reinit != 0: |
|
if epoch % args.disc_reinit == 0: |
|
vq_loss.module.discriminator.reinit() |
|
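        # One iteration = a generator/VQ update followed by a discriminator update on the same batch.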
for x, y in loader: |
|
imgs = x.to(device, non_blocking=True) |
|
|
|
if args.aug_fade_steps >= 0: |
|
fade_blur_schedule = ( |
|
0 |
|
if train_steps < args.disc_start |
|
else min( |
|
1.0, (train_steps - args.disc_start) / (args.aug_fade_steps + 1) |
|
) |
|
) |
|
fade_blur_schedule = 1 - fade_blur_schedule |
|
else: |
|
fade_blur_schedule = 0 |
|
|
|
optimizer.zero_grad() |
|
with torch.cuda.amp.autocast(dtype=ptdtype): |
|
recons_imgs, codebook_loss, sem_loss, detail_loss, dependency_loss = ( |
|
vq_model(imgs, epoch, alpha, beta, delta) |
|
) |
|
loss_gen = vq_loss( |
|
codebook_loss, |
|
sem_loss, |
|
detail_loss, |
|
dependency_loss, |
|
imgs, |
|
recons_imgs, |
|
optimizer_idx=0, |
|
global_step=train_steps + 1, |
|
last_layer=vq_model.module.decoder.last_layer, |
|
logger=logger, |
|
log_every=args.log_every, |
|
fade_blur_schedule=fade_blur_schedule, |
|
) |
|
|
|
scaler.scale(loss_gen).backward() |
|
if args.max_grad_norm != 0.0: |
|
scaler.unscale_(optimizer) |
|
torch.nn.utils.clip_grad_norm_( |
|
vq_model.parameters(), args.max_grad_norm |
|
) |
|
scaler.step(optimizer) |
|
scaler.update() |
|
if args.ema: |
|
update_ema( |
|
ema, vq_model.module._orig_mod if args.compile else vq_model.module |
|
) |
|
|
|
|
|
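            # Discriminator update (optimizer_idx=1) on the reconstructions just produced.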
optimizer_disc.zero_grad() |
|
|
|
with torch.cuda.amp.autocast(dtype=ptdtype): |
|
loss_disc = vq_loss( |
|
codebook_loss, |
|
sem_loss, |
|
detail_loss, |
|
dependency_loss, |
|
imgs, |
|
recons_imgs, |
|
optimizer_idx=1, |
|
global_step=train_steps + 1, |
|
logger=logger, |
|
log_every=args.log_every, |
|
fade_blur_schedule=fade_blur_schedule, |
|
) |
|
scaler_disc.scale(loss_disc).backward() |
|
if args.max_grad_norm != 0.0: |
|
scaler_disc.unscale_(optimizer_disc) |
|
torch.nn.utils.clip_grad_norm_( |
|
vq_loss.module.discriminator.parameters(), args.max_grad_norm |
|
) |
|
scaler_disc.step(optimizer_disc) |
|
scaler_disc.update() |
|
|
|
|
|
running_loss += loss_gen.item() + loss_disc.item() |
|
|
|
log_steps += 1 |
|
train_steps += 1 |
|
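            # Periodic logging: average the running loss across ranks and report throughput.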
if train_steps % args.log_every == 0: |
|
|
|
torch.cuda.synchronize() |
|
end_time = time.time() |
|
steps_per_sec = log_steps / (end_time - start_time) |
|
|
|
avg_loss = torch.tensor(running_loss / log_steps, device=device) |
|
dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM) |
|
avg_loss = avg_loss.item() / dist.get_world_size() |
|
logger.info( |
|
f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}" |
|
) |
|
|
|
running_loss = 0 |
|
log_steps = 0 |
|
start_time = time.time() |
|
|
|
if dist.get_rank() == 0: |
|
vq_loss.module.wandb_tracker.log( |
|
{"lr": optimizer.param_groups[0]["lr"], "train_loss": avg_loss}, |
|
step=train_steps, |
|
) |
|
|
|
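            # Periodic visualization: log per-scale reconstructions of the first 4 images to wandb.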
if train_steps % args.vis_every == 0: |
|
with torch.no_grad(): |
|
recons_with_scale = ( |
|
vq_model.module.img_to_reconstructed_img( |
|
imgs[:4], last_one=False |
|
) |
|
) |
|
image = torch.cat(recons_with_scale + [imgs[:4]], dim=0) |
|
image = torch.clamp(image, min=-1, max=1) |
|
image = make_grid( |
|
(image + 1) / 2, nrow=4, padding=0, pad_value=1.0 |
|
) |
|
image = image.permute(1, 2, 0).mul_(255).cpu().numpy() |
|
image = Image.fromarray(image.astype(np.uint8)) |
|
|
|
vq_loss.module.wandb_tracker.log( |
|
{"recon_images": [wandb.Image(image)]}, step=train_steps |
|
) |
|
|
|
|
|
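            # Periodic checkpointing; with --save_best, first measure reconstruction FID on the val split.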
if train_steps % args.ckpt_every == 0 and train_steps > 0: |
|
if args.save_best: |
|
vq_model.eval() |
|
total = 0 |
|
samples = [] |
|
gt = [] |
|
for x, _ in tqdm( |
|
val_loader, |
|
desc=f"evaluation for step {train_steps:07d}", |
|
disable=not rank == 0, |
|
): |
|
with torch.no_grad(): |
|
x = x.to(device, non_blocking=True) |
|
sample = vq_model.module.img_to_reconstructed_img(x) |
|
sample = ( |
|
torch.clamp(127.5 * sample + 128.0, 0, 255) |
|
.permute(0, 2, 3, 1) |
|
.to(torch.uint8) |
|
.contiguous() |
|
) |
|
x = ( |
|
torch.clamp(127.5 * x + 128.0, 0, 255) |
|
.permute(0, 2, 3, 1) |
|
.to(torch.uint8) |
|
.contiguous() |
|
) |
|
|
|
sample = torch.cat(dist.nn.all_gather(sample), dim=0) |
|
x = torch.cat(dist.nn.all_gather(x), dim=0) |
|
samples.append(sample.to("cpu", dtype=torch.uint8).numpy()) |
|
gt.append(x.to("cpu", dtype=torch.uint8).numpy()) |
|
|
|
total += sample.shape[0] |
|
vq_model.train() |
|
logger.info(f"Ealuate total {total} files.") |
|
dist.barrier() |
|
|
|
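                    # FID is computed on rank 0 with the TF-based Evaluator; other ranks wait at the barrier below.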
if rank == 0: |
|
samples = np.concatenate(samples, axis=0) |
|
gt = np.concatenate(gt, axis=0) |
|
config = tf.ConfigProto( |
|
allow_soft_placement=True |
|
) |
|
config.gpu_options.allow_growth = True |
|
|
|
evaluator = Evaluator(tf.Session(config=config), batch_size=32) |
|
evaluator.warmup() |
|
logger.info("computing reference batch activations...") |
|
ref_acts = evaluator.read_activations(gt) |
|
logger.info("computing/reading reference batch statistics...") |
|
ref_stats, _ = evaluator.read_statistics(gt, ref_acts) |
|
logger.info("computing sample batch activations...") |
|
sample_acts = evaluator.read_activations(samples) |
|
logger.info("computing/reading sample batch statistics...") |
|
sample_stats, _ = evaluator.read_statistics( |
|
samples, sample_acts |
|
) |
|
FID = sample_stats.frechet_distance(ref_stats) |
|
|
|
logger.info(f"traing step: {train_steps:07d}, FID {FID:07f}") |
|
|
|
                        if curr_fid is None:
|
curr_fid = [FID, train_steps] |
|
elif FID <= curr_fid[0]: |
|
|
|
curr_fid = [FID, train_steps] |
|
|
|
vq_loss.module.wandb_tracker.log( |
|
{"eval FID": FID}, step=train_steps |
|
) |
|
|
|
dist.barrier() |
|
|
|
if rank == 0: |
|
if args.compile: |
|
model_weight = vq_model.module._orig_mod.state_dict() |
|
else: |
|
model_weight = vq_model.module.state_dict() |
|
checkpoint = { |
|
"model": model_weight, |
|
"optimizer": optimizer.state_dict(), |
|
"discriminator": vq_loss.module.discriminator.state_dict(), |
|
"optimizer_disc": optimizer_disc.state_dict(), |
|
"steps": train_steps, |
|
"args": args, |
|
} |
|
if args.ema: |
|
checkpoint["ema"] = ema.state_dict() |
|
if not args.no_local_save: |
|
checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt" |
|
torch.save(checkpoint, checkpoint_path) |
|
logger.info(f"Saved checkpoint to {checkpoint_path}") |
|
|
|
|
|
|
|
|
|
|
|
if args.save_best: |
|
last_checkpoint_path = f"{args.cloud_save_path}/last_ckpt.pt" |
|
if os.path.exists(last_checkpoint_path): |
|
os.remove(last_checkpoint_path) |
|
else: |
|
os.makedirs(f"{args.cloud_save_path}", exist_ok=True) |
|
torch.save(checkpoint, last_checkpoint_path) |
|
logger.info( |
|
f"Saved checkpoint in cloud to {last_checkpoint_path}" |
|
) |
|
if curr_fid[1] == train_steps: |
|
best_checkpoint_path = ( |
|
f"{args.cloud_save_path}/best_ckpt.pt" |
|
) |
|
torch.save(checkpoint, best_checkpoint_path) |
|
logger.info( |
|
f"Saved checkpoint in cloud to {best_checkpoint_path}" |
|
) |
|
|
|
dist.barrier() |
|
|
|
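        # Epoch-level LR scheduling; the discriminator schedule starts after disc_epoch_start.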
if vqvae_lr_scheduler is not None: |
|
vqvae_lr_scheduler.step(epoch + 1) |
|
if disc_lr_scheduler is not None and epoch >= args.disc_epoch_start: |
|
disc_lr_scheduler.step(epoch + 1 - args.disc_epoch_start) |
|
|
|
vq_model.eval() |
|
|
|
|
|
logger.info("Done!") |
|
dist.destroy_process_group() |
|
|
|
|
|
if __name__ == "__main__": |
|
args = parse_args() |
|
main(args) |
|
|