"""Command-line and VAE-config loading for MAR training with Diffusion Loss.

Importing this module has side effects: it parses ``sys.argv``, creates
``args.output_dir`` and loads the YAML file named by ``--config``.  The
resulting namespace is exposed as the module-level ``args``.
"""
import argparse
from pathlib import Path

import yaml


def get_args_parser():
    """Build the argument parser for MAR training / generation.

    The parser is created with ``add_help=False``, so ``-h``/``--help`` is not
    registered on it.

    Returns:
        argparse.ArgumentParser: parser holding the training, model, VAE,
        generation, optimizer, diffusion-loss, output/runtime, distributed
        and latent-caching options.
    """
    parser = argparse.ArgumentParser('MAR training with Diffusion Loss', add_help=False)
    parser.add_argument('--batch_size', default=16, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * # gpus)')
    parser.add_argument('--epochs', default=2000, type=int)

    # Model parameters
    parser.add_argument('--model', default='mar_base', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--ckpt_path', default="pretrained_models/mar/city768.16.pth", type=str,
                        help='model checkpoint path')

    # VAE parameters
    parser.add_argument('--img_size', default=768, type=int,
                        help='images input size')
    parser.add_argument('--vae_path', default="pretrained_models/vae/modelf16.ckpt", type=str,
                        help='path to the VAE checkpoint')
    parser.add_argument('--vae_embed_dim', default=16, type=int,
                        help='vae output embedding dimension')
    parser.add_argument('--vae_stride', default=16, type=int,
                        help='tokenizer stride, default use KL16')
    parser.add_argument('--patch_size', default=1, type=int,
                        help='number of tokens to group as a patch.')
    parser.add_argument('--config', default="ldm/config.yaml", type=str,
                        help='vae model configuration file')

    # Generation parameters
    parser.add_argument('--num_iter', default=64, type=int,
                        help='number of autoregressive iterations to generate an image')
    parser.add_argument('--num_images', default=3000, type=int,
                        help='number of images to generate')
    parser.add_argument('--cfg', default=1.0, type=float, help="classifier-free guidance")
    parser.add_argument('--cfg_schedule', default="linear", type=str)
    parser.add_argument('--label_drop_prob', default=0.1, type=float)
    parser.add_argument('--eval_freq', type=int, default=40, help='evaluation frequency')
    parser.add_argument('--save_last_freq', type=int, default=5, help='save last frequency')
    parser.add_argument('--online_eval', action='store_true')
    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--eval_bsz', type=int, default=64, help='generation batch size')

    # Optimizer parameters
    parser.add_argument('--weight_decay', type=float, default=0.02,
                        help='weight decay (default: 0.02)')
    parser.add_argument('--grad_checkpointing', action='store_true')
    # --lr (absolute) defaults to None; per its help text it is meant to be
    # derived from --blr when not given (the derivation happens elsewhere).
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr)')
    parser.add_argument('--blr', type=float, default=1e-4, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')
    parser.add_argument('--lr_schedule', type=str, default='constant',
                        help='learning rate schedule')
    parser.add_argument('--warmup_epochs', type=int, default=100, metavar='N',
                        help='epochs to warmup LR')
    parser.add_argument('--ema_rate', default=0.9999, type=float)

    # MAR params
    parser.add_argument('--mask_ratio_min', type=float, default=0.7,
                        help='Minimum mask ratio')
    parser.add_argument('--grad_clip', type=float, default=3.0,
                        help='Gradient clip')
    parser.add_argument('--attn_dropout', type=float, default=0.1,
                        help='attention dropout')
    parser.add_argument('--proj_dropout', type=float, default=0.1,
                        help='projection dropout')
    parser.add_argument('--buffer_size', type=int, default=64)

    # Diffusion Loss params
    parser.add_argument('--diffloss_d', type=int, default=6)
    parser.add_argument('--diffloss_w', type=int, default=1024)
    # Kept as a string on purpose; presumably interpreted by the diffusion
    # sampler downstream (NOTE(review): confirm expected format).
    parser.add_argument('--num_sampling_steps', type=str, default="100")
    parser.add_argument('--diffusion_batch_mul', type=int, default=4)
    parser.add_argument('--temperature', default=1.0, type=float,
                        help='diffusion loss sampling temperature')

    # Output, logging and runtime parameters
    parser.add_argument('--output_dir', default='./output_dir',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default='./output_dir',
                        help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument('--resume', default=None,
                        help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--num_workers', default=10, type=int)
    # --pin_mem / --no_pin_mem toggle the same destination; pinning is on by default.
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    # caching latents
    parser.add_argument('--use_cached', action='store_true', dest='use_cached',
                        help='Use cached latents')
    parser.set_defaults(use_cached=False)
    parser.add_argument('--cached_path', default='', help='path to cached latents')
    return parser


parser = get_args_parser()
args = parser.parse_args()

Path(args.output_dir).mkdir(parents=True, exist_ok=True)
# NOTE(review): this unconditionally replaces any --log_dir given on the
# command line with --output_dir; confirm that is intended.
args.log_dir = args.output_dir

# safe_load (not yaml.load) so the config file cannot construct arbitrary objects.
with open(args.config, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)
# safe_load returns None for an empty file; report a missing section clearly
# instead of failing with an opaque TypeError/KeyError.
if not isinstance(config, dict) or "ddconfig" not in config:
    raise KeyError(f"VAE config {args.config!r} has no top-level 'ddconfig' section")
args.ddconfig = config["ddconfig"]