import argparse
import datetime
import numpy as np
import os
import time
from pathlib import Path
import yaml
import glob

import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from data import cityscapes
from util.crop import center_crop_arr
import util.misc as misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler
from util.loader import CachedFolder

from models.vae import AutoencoderKL
from models import mar
import copy
from tqdm import tqdm
import util.lr_sched as lr_sched
import logging


def update_ema(target_params, source_params, rate=0.99):
    """
    Update target parameters to be closer to those of source parameters using
    an exponential moving average.

    :param target_params: the target parameter sequence.
    :param source_params: the source parameter sequence.
    :param rate: the EMA rate (closer to 1 means slower).
    """
    for targ, src in zip(target_params, source_params):
        targ.detach().mul_(rate).add_(src, alpha=1 - rate)


def logger_file(path):
    """Attach a file handler to the root logger and return it."""
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    handler = logging.FileHandler(path, "w", encoding=None, delay=True)
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter("%(message)s")
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    return logger


def get_args_parser():
    parser = argparse.ArgumentParser('MAR training with Diffusion Loss', add_help=False)
    parser.add_argument('--batch_size', default=2, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * # gpus)')
    parser.add_argument('--epochs', default=2000, type=int)

    # Model parameters
    parser.add_argument('--model', default='mar_base', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--ckpt_path', default="pretrained_models/mar/city768.16.pth", type=str,
                        help='model checkpoint path')

    # VAE parameters
    parser.add_argument('--img_size', default=768, type=int,
                        help='images input size')
    parser.add_argument('--vae_path', default="pretrained_models/vae/modelf16.ckpt", type=str,
                        help='vae checkpoint path')
    parser.add_argument('--vae_embed_dim', default=16, type=int,
                        help='vae output embedding dimension')
    parser.add_argument('--vae_stride', default=16, type=int,
                        help='tokenizer stride, default use KL16')
    parser.add_argument('--patch_size', default=1, type=int,
                        help='number of tokens to group as a patch.')
    parser.add_argument('--config', default="ldm/config.yaml", type=str,
                        help='vae model configuration file')

    # Generation parameters
    parser.add_argument('--num_iter', default=64, type=int,
                        help='number of autoregressive iterations to generate an image')
    parser.add_argument('--num_images', default=3000, type=int,
                        help='number of images to generate')
    parser.add_argument('--cfg', default=1.0, type=float, help="classifier-free guidance")
    parser.add_argument('--cfg_schedule', default="linear", type=str)
    parser.add_argument('--label_drop_prob', default=0.1, type=float)
    parser.add_argument('--eval_freq', type=int, default=40, help='evaluation frequency')
    parser.add_argument('--save_last_freq', type=int, default=5, help='save last frequency')
    parser.add_argument('--online_eval', action='store_true')
    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--eval_bsz', type=int, default=64, help='generation batch size')

    # Optimizer parameters
    parser.add_argument('--weight_decay', type=float, default=0.02,
                        help='weight decay (default: 0.02)')
    parser.add_argument('--grad_checkpointing', action='store_true')
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr)')
    parser.add_argument('--blr', type=float, default=1e-4, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')
    parser.add_argument('--lr_schedule', type=str, default='constant',
                        help='learning rate schedule')
    parser.add_argument('--warmup_epochs', type=int, default=100, metavar='N',
                        help='epochs to warmup LR')
    parser.add_argument('--ema_rate', default=0.9999, type=float)

    # MAR params
    parser.add_argument('--mask_ratio_min', type=float, default=0.7,
                        help='Minimum mask ratio')
    parser.add_argument('--grad_clip', type=float, default=3.0,
                        help='Gradient clip')
    parser.add_argument('--attn_dropout', type=float, default=0.1,
                        help='attention dropout')
    parser.add_argument('--proj_dropout', type=float, default=0.1,
                        help='projection dropout')
    parser.add_argument('--buffer_size', type=int, default=64)

    # Diffusion Loss params
    parser.add_argument('--diffloss_d', type=int, default=6)
    parser.add_argument('--diffloss_w', type=int, default=1024)
    parser.add_argument('--num_sampling_steps', type=str, default="100")
    parser.add_argument('--diffusion_batch_mul', type=int, default=4)
    parser.add_argument('--temperature', default=1.0, type=float,
                        help='diffusion loss sampling temperature')

    # Dataset parameters
    parser.add_argument('--output_dir', default='./output_dir',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default='./output_dir',
                        help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument('--resume', default=None,
                        help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    # caching latents
    parser.add_argument('--use_cached', action='store_true', dest='use_cached',
                        help='Use cached latents')
    parser.set_defaults(use_cached=False)
    parser.add_argument('--cached_path', default='', help='path to cached latents')

    return parser


def main(args):
    misc.init_distributed_mode(args)

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    num_tasks = misc.get_world_size()
    global_rank = misc.get_rank()
    log_writer = None

    # augmentation following DiT and ADM
    transform_train = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])

    dataset_train = cityscapes.CityScapes('dataset/CityScapes/trainlist.txt',
                                          transform=transform_train, img_size=args.img_size)
    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
    )
    print("Sampler_train = %s" % str(sampler_train))

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    # define the vae and mar model
    with open(args.config, "r") as f:
        config = yaml.safe_load(f)
    args.ddconfig = config["ddconfig"]
    print('config: ', config)

    vae = AutoencoderKL(
        ddconfig=args.ddconfig,
        embed_dim=args.vae_embed_dim,
        ckpt_path=args.vae_path
    ).cuda().eval()
    for param in vae.parameters():
        param.requires_grad = False

    model = mar.__dict__[args.model](
        img_size=args.img_size,
        vae_stride=args.vae_stride,
        patch_size=args.patch_size,
        vae_embed_dim=args.vae_embed_dim,
        mask_ratio_min=args.mask_ratio_min,
        label_drop_prob=args.label_drop_prob,
        attn_dropout=args.attn_dropout,
        proj_dropout=args.proj_dropout,
        buffer_size=args.buffer_size,
        diffloss_d=args.diffloss_d,
        diffloss_w=args.diffloss_w,
        num_sampling_steps=args.num_sampling_steps,
        diffusion_batch_mul=args.diffusion_batch_mul,
        grad_checkpointing=args.grad_checkpointing,
    )

    if args.ckpt_path:
        checkpoint = torch.load(args.ckpt_path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])

    print("Model = %s" % str(model))

    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Number of trainable parameters: {}M".format(n_params / 1e6))

    model.to(device)
    model_without_ddp = model

    eff_batch_size = args.batch_size * misc.get_world_size()

    if args.lr is None:  # only base_lr is specified
        args.lr = args.blr

    print("base lr: %.2e" % args.blr)
    print("actual lr: %.2e" % args.lr)
    print("effective batch size: %d" % eff_batch_size)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    # following timm: no weight decay on bias, norm layers, and diffloss MLP
    param_groups = misc.add_weight_decay(model_without_ddp, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
    print(optimizer)
    loss_scaler = NativeScaler()

    # resume training
    if args.resume and glob.glob(os.path.join(args.output_dir, args.resume, 'checkpoint*.pth')):
        try:
            checkpoint = torch.load(
                sorted(glob.glob(os.path.join(args.output_dir, args.resume, 'checkpoint*.pth')))[-1],
                map_location='cpu')
            model.load_state_dict(checkpoint['model'])
        except Exception:
            # fall back to the second-newest checkpoint if the latest one cannot be loaded
            checkpoint = torch.load(
                sorted(glob.glob(os.path.join(args.output_dir, args.resume, 'checkpoint*.pth')))[-2],
                map_location='cpu')
            model.load_state_dict(checkpoint['model'])

        state_dict = {key.replace("module.", ""): value for key, value in checkpoint['model'].items()}
        model_without_ddp.load_state_dict(state_dict)

        model_params = list(model_without_ddp.parameters())
        ema_state_dict = {key.replace("module.", ""): value for key, value in checkpoint['model_ema'].items()}
        ema_params = [ema_state_dict[name].cuda() for name, _ in model_without_ddp.named_parameters()]

        print("Resume checkpoint %s" % args.resume)

        if 'optimizer' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
            print("With optim & sched!")
        del checkpoint
        args.output_dir = os.path.join(args.output_dir, args.resume)
        log_path = args.log_dir + '/' + args.resume + '.log'
        logger = logger_file(log_path)
        if os.path.exists(log_path):
            # re-emit the existing log so the resumed run keeps a single continuous file
            with open(log_path, 'r') as infile:
                for line in infile:
                    logger.info(line.rstrip())
        else:
            logger.info("All the arguments")
            for k, v in vars(args).items():
                logger.info(f"{k}: {v}")
            logger.info("\n\n Loss information")
    else:
        model_params = list(model_without_ddp.parameters())
        ema_params = copy.deepcopy(model_params)
        print("Training from scratch")
        args.resume = datetime.datetime.now().strftime("%Y.%m.%d.%H.%M")
        args.output_dir = os.path.join(args.output_dir, args.resume)
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
        logger = logger_file(args.log_dir + '/' + args.resume + '.log')
        logger.info("All the arguments")
        for k, v in vars(args).items():
            logger.info(f"{k}: {v}")
        logger.info("\n\n Loss information")

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()

    for epoch in tqdm(range(args.start_epoch, args.epochs), desc="Training Progress"):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        model.train(True)
        metric_logger = misc.MetricLogger(delimiter="  ")
        metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
        header = 'Epoch: [{}]'.format(epoch)
        print_freq = 20

        optimizer.zero_grad()

        for data_iter_step, (samples, labels, _) in enumerate(data_loader_train):
            # per-iteration (instead of per-epoch) lr scheduling
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader_train) + epoch, args)

            samples = samples.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            # encode images and condition maps to VAE latents (frozen VAE, no gradients)
            with torch.no_grad():
                posterior_x = vae.encode(samples)
                posterior_y = vae.encode(labels)
                x = posterior_x.sample().mul_(0.2325)
                y = posterior_y.sample().mul_(0.2325)

            with torch.cuda.amp.autocast():
                loss = model(x, y)

            loss_value = loss.item()

            loss_scaler(loss, optimizer, clip_grad=args.grad_clip,
                        parameters=model.parameters(), update_grad=True)
            optimizer.zero_grad()

            torch.cuda.synchronize()

            update_ema(ema_params, model_params, rate=args.ema_rate)

            metric_logger.update(loss=loss_value)
            lr = optimizer.param_groups[0]["lr"]
            metric_logger.update(lr=lr)

            loss_value_reduce = misc.all_reduce_mean(loss_value)

        metric_logger.synchronize_between_processes()
        logger.info(f"epoch: {epoch:4d}, Averaged stats: {metric_logger}")

        if (epoch + 1) % args.save_last_freq == 0:
            misc.save_model(args=args, model=model, model_without_ddp=model,
                            optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch,
                            ema_params=ema_params, epoch_name=str(epoch).zfill(5))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    args = get_args_parser()
    args = args.parse_args()
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    Path(args.log_dir).mkdir(parents=True, exist_ok=True)
    main(args)
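# Example launch (a sketch only: the script filename, GPU count, and dataset/checkpoint
# paths are assumptions and not taken from this file; adjust them to your setup).
# The defaults above expect torchrun-style env:// initialization for distributed training:
#
#   torchrun --nproc_per_node=4 train_mar_cityscapes.py \
#       --img_size 768 --vae_path pretrained_models/vae/modelf16.ckpt \
#       --vae_embed_dim 16 --vae_stride 16 --config ldm/config.yaml \
#       --batch_size 2 --blr 1e-4 --epochs 2000 --output_dir ./output_dir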