|
import os |
|
import sys |
|
|
|
# Make the external LlamaGen checkout importable (presumably the source of the
# `tokenizer` package imported further down - confirm). The location can be
# overridden with the LLAMAGEN_ROOT environment variable; the previously
# hard-coded path is kept as the default so existing setups keep working.
sys.path.append(os.environ.get("LLAMAGEN_ROOT", "/home/xiangl/LlamaGen"))
|
import logging |
|
import json |
|
import numpy as np |
|
import torch.distributed |
|
from tqdm.auto import tqdm |
|
from PIL import Image |
|
from logging import getLogger as get_logger |
|
import torch |
|
import torch.nn as nn |
|
from torch.nn.parallel import DistributedDataParallel as NativeDDP |
|
from torch.utils.data import DataLoader, default_collate |
|
from utils2 import ( |
|
init_distributed_device, |
|
is_global_primary, |
|
is_primary, |
|
seed_everything, |
|
str2bool, |
|
) |
|
from tokenizer.tokenizer_image.msvq_model import VQ_models |
|
from datasets import ( |
|
create_dataset, |
|
fast_collate, |
|
PrefetchLoader, |
|
Normalize, |
|
Denormalize, |
|
) |
|
|
|
from timm.optim import create_optimizer_v2 as create_optimizer |
|
from timm.scheduler import create_scheduler_v2 as create_scheduler |
|
|
|
import argparse |
|
|
|
logger = get_logger(__name__) |
|
|
|
|
|
def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Many of the options are not read anywhere in the visible script;
    presumably they are kept so the CLI stays compatible with the tokenizer
    training entry point - TODO confirm before pruning.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser()

    # ---- config file -------------------------------------------------------
    parser.add_argument(
        "--config",
        type=str,
        default="configs/vqgan/imagenet/vqvae_vq_dinov2base_v4096z16n64_pretrained_ae.yaml",
        help="config file used to specify parameters",
    )

    # ---- data --------------------------------------------------------------
    parser.add_argument(
        "--data_dir", type=str, default="imagenet/train", help="data folder"
    )
    parser.add_argument(
        "--dataset_name", type=str, default="imagenet", help="dataset name"
    )
    parser.add_argument(
        "--val_data_dir", type=str, default="imagenet/val", help="data folder"
    )
    parser.add_argument("--image_size", type=int, default=256, help="image size")
    parser.add_argument("--batch_size", type=int, default=4, help="per gpu batch size")
    # Help text used to say "batch size", which was a copy/paste slip.
    parser.add_argument(
        "--num_workers", type=int, default=8, help="number of dataloader workers"
    )
    parser.add_argument(
        "--num_classes", type=int, default=1000, help="number of classes in dataset"
    )
    parser.add_argument(
        "--use_prefetcher", type=str2bool, default=True, help="use prefetch"
    )

    # ---- training ----------------------------------------------------------
    parser.add_argument("--run_name", type=str, default=None, help="run_name")
    parser.add_argument(
        "--output_dir", type=str, default="experiments", help="output folder"
    )
    parser.add_argument("--num_epochs", type=int, default=10)
    parser.add_argument("--optimizer", type=str, default="adamw", help="optimizer")
    parser.add_argument(
        "--learning_rate", type=float, default=1e-4, help="learning rate"
    )
    parser.add_argument("--min_lr", type=float, default=5e-5, help="end learning rate")
    parser.add_argument("--weight_decay", type=float, default=1e-4, help="weight decay")
    parser.add_argument(
        "--lr_scheduler", type=str, default="cosine", help="lr scheduler"
    )
    parser.add_argument(
        "--lr_warmup_epochs", type=float, default=1, help="warmup epochs"
    )
    parser.add_argument(
        "--log_interval", type=int, default=50, help="log interval for steps"
    )
    parser.add_argument(
        "--val_interval", type=int, default=1000, help="validation interval for epochs"
    )
    parser.add_argument("--save_interval", type=int, default=1, help="save interval")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="gradient accumulation steps",
    )
    parser.add_argument(
        "--gradient_clip", type=float, default=1.0, help="gradient clip"
    )
    parser.add_argument(
        "--torchcompile", type=str2bool, default=False, help="use torch compile"
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="wandb",
        help="report to",
        choices=["wandb", "tensorboard", "none"],
    )
    parser.add_argument(
        "--resume", type=str, default=None, help="resume from pre-trained checkpoint"
    )
    parser.add_argument(
        "--auto_resume",
        type=str2bool,
        default=False,
        help="auto resume from latest checkpoint",
    )
    parser.add_argument("--seed", type=int, default=42, help="random seed")
    # "none" is accepted as an alias of "fp32": the train/eval loops in this
    # file look the value up in a table that is keyed by "none".
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="bf16",
        choices=["fp16", "bf16", "fp32", "none"],
        help="mixed precision",
    )
    parser.add_argument(
        "--ema", type=float, default=0, help="ema updates of the models"
    )
    # These were declared type=int, so any value given on the command line
    # (e.g. "--beta1 0.9") was rejected by argparse.
    parser.add_argument("--beta1", type=float, default=0.9, help="beta1 for adam")
    parser.add_argument("--beta2", type=float, default=0.99, help="beta2 for adam")
    parser.add_argument(
        "--quantizer_lr_multiplier",
        type=float,
        default=1.0,
        help="lr multiplier for quantization",
    )
    parser.add_argument(
        "--compile", type=str2bool, default=False, help="use torch compile"
    )

    # ---- losses ------------------------------------------------------------
    parser.add_argument(
        "--disc_adaptive",
        type=str2bool,
        default=True,
        help="flag of whether to use adaptive discriminator weight",
    )
    parser.add_argument(
        "--disc_loss_start",
        type=float,
        default=0,
        help="starting threshold of adaptive discriminator weight for discriminator training",
    )
    parser.add_argument(
        "--disc_loss_weight", type=float, default=0.8, help="discriminator loss weight"
    )
    parser.add_argument(
        "--gen_disc_loss_weight",
        type=float,
        default=0.1,
        help="discriminator loss weight of generator",
    )
    parser.add_argument(
        "--gen_disc_loss_type",
        type=str,
        default="non-saturating",
        choices=["hinge", "vanilla", "non-saturating"],
        help="generator loss type",
    )
    parser.add_argument(
        "--disc_loss_type",
        type=str,
        default="hinge",
        choices=["hinge", "vanilla", "non-saturating"],
        help="discriminator loss type",
    )
    # Help text used to repeat "discriminator loss type".
    parser.add_argument(
        "--disc_model",
        type=str,
        default="patchgan",
        choices=["patchgan", "stylegan"],
        help="discriminator model type",
    )
    parser.add_argument(
        "--lecam_loss_weight",
        type=float,
        default=0.0,
        help="lecam regularization loss weight of discriminator",
    )
    parser.add_argument(
        "--codebook_loss_weight", type=float, default=1.0, help="codebook loss weight"
    )
    parser.add_argument(
        "--perceptual_loss_weight",
        type=float,
        default=0.1,
        help="perceptual loss weight",
    )
    parser.add_argument(
        "--logit_scale_loss_weight",
        type=float,
        default=0.1,
        help="logit_scale loss weight",
    )
    parser.add_argument(
        "--rec_loss_weight", type=float, default=1.0, help="rec loss weight"
    )
    parser.add_argument(
        "--ent_loss_weight", type=float, default=0.1, help="entropy loss weight"
    )
    parser.add_argument(
        "--ent_loss_weight_end", type=float, default=0.0, help="entropy loss weight"
    )
    parser.add_argument(
        "--ent_loss_start", type=float, default=1.0, help="start to add entropy loss"
    )
    parser.add_argument(
        "--ent_loss_annealing_steps",
        type=float,
        default=2000,
        help="steps to anneal entropy loss weight",
    )
    parser.add_argument(
        "--sem_loss_weight", type=float, default=0.01, help="semantic loss weight"
    )
    parser.add_argument(
        "--ent_sample_min_loss_weight",
        type=float,
        default=1.0,
        help="sample entropy minimization loss weight",
    )
    parser.add_argument(
        "--ent_batch_max_loss_weight",
        type=float,
        default=1.0,
        help="batch entropy maximization loss weight",
    )

    # ---- model -------------------------------------------------------------
    parser.add_argument(
        "--recon_loss",
        type=str,
        default="l1",
        choices=["l1", "l2"],
        help="reconstruction loss",
    )
    parser.add_argument(
        "--quantizer",
        type=str,
        default="vq",
        choices=[
            "vq",
            "gumbel_vq",
            "st_gumbel_vq",
            "ema_vq",
            "oc_vq",
            "diff_vq",
            "diff_vq2",
            "diff_vq_fix",
        ],
        help="quantizer type",
    )
    parser.add_argument(
        "--encoder", type=str, default="dinov2", help="encoder model type"
    )
    parser.add_argument(
        "--decoder", type=str, default="dinov2", help="deocder model type"
    )
    parser.add_argument(
        "--encoder_model",
        type=str,
        default="vit_small_patch14_dinov2.lvd142m",
        help="encoder model name",
    )
    parser.add_argument(
        "--encoder_model_pretrained",
        type=str2bool,
        default=True,
        help="encoder model load pretrained checkpoint",
    )
    parser.add_argument(
        "--encoder_patch_size", type=int, default=16, help="encoder patch size"
    )
    parser.add_argument(
        "--encoder_tuning", type=str, default="lora", help="encoder tuning method"
    )
    parser.add_argument(
        "--encoder_tuning_lora_r", default=8, type=int, help="encoder tuning lora r"
    )
    parser.add_argument(
        "--encoder_drop_path", type=float, default=0.0, help="encoder droppath rate"
    )
    parser.add_argument(
        "--decoder_model",
        type=str,
        default="vit_small_patch14_dinov2.lvd142m",
        help="deocder model name",
    )
    parser.add_argument(
        "--decoder_model_pretrained",
        type=str2bool,
        default=True,
        help="decoder model load pretrained checkpoint",
    )
    parser.add_argument(
        "--decoder_patch_size", type=int, default=16, help="decoder patch size"
    )
    parser.add_argument(
        "--decoder_drop_path", type=float, default=0.0, help="decoder droppath rate"
    )
    parser.add_argument(
        "--decoder_to_pixel",
        type=str,
        default="linear",
        help="decoder to pixel",
        choices=["linear", "conv", "ada_conv", "siren"],
    )
    parser.add_argument(
        "--decoder_use_rope", type=str2bool, default=False, help="decoder use RoPE"
    )
    parser.add_argument(
        "--decoder_cond_latent",
        type=str2bool,
        default=False,
        help="use dino latent to initialize latent tokens (mask token)",
    )
    parser.add_argument(
        "--decoder_tuning", type=str, default="lora", help="deocder tuning method"
    )
    parser.add_argument(
        "--decoder_tuning_lora_r", default=8, type=int, help="decoder tuning lora r"
    )
    parser.add_argument(
        "--pretrained_path", type=str, default=None, help="pretrained model path"
    )
    parser.add_argument(
        "--semantic_guide",
        type=str,
        default="none",
        help="semantic guidance on latent tokens",
    )
    parser.add_argument(
        "--sem_loss_scale", type=float, default=15.0, help="scale for clip loss"
    )
    parser.add_argument(
        "--renorm_input", type=str2bool, default=False, help="normalize input images"
    )

    # ---- quantizer ---------------------------------------------------------
    # NOTE(review): with nargs="+" a value given on the command line arrives as
    # a list, while the default stays a plain int. Left unchanged because the
    # consumers of this option are not visible here.
    parser.add_argument(
        "--vocab_size", type=int, default=4096, nargs="+", help="codebook size"
    )
    parser.add_argument(
        "--z_channels", type=int, default=32, help="latent size of vqvae"
    )
    parser.add_argument(
        "--num_latent_tokens", type=int, default=32, help="number of latent tokens"
    )
    parser.add_argument(
        "--codebook_norm", type=str2bool, default=True, help="normalize codebook"
    )
    parser.add_argument(
        "--use_gumbel",
        type=str2bool,
        default=False,
        help="use gumbel softmax for probs",
    )
    parser.add_argument(
        "--commit_loss_weight", type=float, default=0.0, help="commit loss weight"
    )
    parser.add_argument(
        "--kl_loss_weight", type=float, default=5e-4, help="kl loss weight"
    )
    parser.add_argument(
        "--ema_decay",
        type=float,
        default=0.999,
        help="ema decay for embeddings of ema quantizer",
    )
    # The default used to be the misspelling "cloest", which is not one of the
    # choices; argparse does not validate defaults, so it slipped through.
    parser.add_argument(
        "--oc_anchor",
        type=str,
        default="closest",
        help="online cluster anchor",
        choices=["closest", "random", "projrandom"],
    )
    parser.add_argument(
        "--contrastive_loss_weight",
        type=float,
        default=1.0,
        help="contrastive loss weight",
    )
    parser.add_argument(
        "--freq_loss_weight", type=float, default=0.0, help="freq loss weight"
    )
    parser.add_argument(
        "--disc_r1_gamma", type=float, default=0.0, help="disc do r1 reg"
    )
    parser.add_argument(
        "--use_diffaug", type=str2bool, default=False, help="use diff aug"
    )
    parser.add_argument(
        "--init_logit_scale",
        type=float,
        default=10,
        help="initial logit scale before log",
    )
    parser.add_argument(
        "--max_logit_scale",
        type=float,
        default=200,
        help="maximum logit scale before log",
    )

    # ---- tokenizer / linear probing ---------------------------------------
    parser.add_argument(
        "--v_patch_nums",
        type=int,
        default=[1, 2, 3, 4, 5, 6, 8, 10, 13, 16],
        nargs="+",
        help="number of patch numbers of each scale",
    )
    parser.add_argument(
        "--codebook-size",
        type=int,
        default=16384,
        help="codebook size for vector quantization",
    )
    parser.add_argument(
        "--codebook-embed-dim",
        type=int,
        default=8,
        help="codebook dimension for vector quantization",
    )
    parser.add_argument(
        "--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16"
    )
    parser.add_argument(
        "--vq-ckpt", type=str, default=None, help="ckpt path for resume training"
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="output/linear_probing",
        help="output model path",
    )
    parser.add_argument("--enc_type", type=str, default="cnn")
    parser.add_argument("--dec_type", type=str, default="cnn")

    return parser.parse_args(argv)
|
|
|
|
|
class LinearClassifier(nn.Module):
    """Linear probe head: mean-pool, batch-normalise, then one linear layer.

    The sub-module names ``bn`` and ``linear`` are part of the saved
    ``state_dict`` and must not be renamed.

    Args:
        in_dim: Channel size of the incoming features.
        out_dim: Number of output logits.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # affine=False: the norm has no learnable scale/shift, only statistics.
        self.bn = nn.BatchNorm1d(in_dim, affine=False, eps=1e-6)
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        """Return logits for ``x``.

        ``x`` is averaged over dim 1 before normalisation; the callers in this
        file pass tensors laid out as (batch, tokens, in_dim).
        """
        pooled = x.mean(dim=1)
        return self.linear(self.bn(pooled))
|
|
|
|
|
@torch.no_grad()
def extract_feature(vqvae, images, args):
    """Encode ``images`` into pre-quantisation features without gradients.

    Args:
        vqvae: The tokenizer model. When ``args.distributed`` is true it is
            expected to be wrapped so that the real model is under ``.module``.
        images: Input batch handed to the encoder.
        args: Namespace providing ``distributed`` and ``enc_type``.

    Returns:
        The output of ``quant_conv`` applied to the encoder features.

    Raises:
        ValueError: If ``enc_type`` is ``"dinov2"`` and the number of tokens is
            not a perfect square, so no square grid can be formed.
    """
    # Both code paths were identical apart from the `.module` indirection.
    model = vqvae.module if args.distributed else vqvae

    z_e = model.encoder(images)
    if args.enc_type == "dinov2":
        # Token sequence (b, tokens, c) -> feature map (b, c, side, side).
        # The grid side used to be hard-coded to 16 (256 tokens); deriving it
        # from the token count gives the same result for 256 tokens.
        b, num_tokens, c = z_e.shape
        side = int(round(num_tokens ** 0.5))
        if side * side != num_tokens:
            raise ValueError(
                f"cannot arrange {num_tokens} tokens into a square grid"
            )
        z_e = z_e.reshape(b, side, side, c).permute(0, 3, 1, 2)

    # NOTE(review): the original indentation was lost; quant_conv is applied
    # for every encoder type here because main() sizes the classifier with
    # vqvae.codebook_embed_dim - confirm against the upstream script.
    z_e = model.quant_conv(z_e)
    return z_e
|
|
|
|
|
def train_epoch(
    vqvae, linear_classifier, train_dataloader, optimizer, device, scaler, args
):
    """Train the linear classifier for one pass over ``train_dataloader``.

    Features come from ``extract_feature`` and are detached, so only
    ``linear_classifier`` receives gradients.

    Args:
        vqvae: Tokenizer used as the feature extractor.
        linear_classifier: The probe being trained.
        train_dataloader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer stepped through ``scaler``.
        device: Target device for batches and for the re-normalisation helpers.
        scaler: Gradient scaler used for ``backward``/``step``/``update``.
        args: Namespace providing ``mixed_precision``, ``renorm_input`` and
            ``use_prefetcher`` (and whatever ``is_primary`` reads).

    Returns:
        tuple: ``(mean loss per batch, accuracy in percent)`` accumulated on
        this process; both are ``0.0`` when the loader yields nothing.
    """
    # "fp32" is what the CLI offers and "none" is what this table originally
    # used; accept both so the default choices no longer raise KeyError.
    precision_to_dtype = {
        "none": torch.float32,
        "fp32": torch.float32,
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
    }
    train_dtype = precision_to_dtype[args.mixed_precision]
    # Autocast is only meaningful for reduced precision; disable it for fp32
    # rather than requesting a float32 autocast.
    use_autocast = train_dtype is not torch.float32

    criterion = torch.nn.CrossEntropyLoss()
    linear_classifier.train()

    denormalize = normalize = None
    if args.renorm_input:
        # Undo the loader's 0.5/0.5 normalisation, then apply the
        # 0.485/0.229-style statistics instead.
        denormalize = Denormalize(
            mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], device=device
        )
        normalize = Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], device=device
        )

    total_loss = 0
    total_correct = 0
    total_samples = 0
    num_batches = len(train_dataloader)

    for idx, batch in tqdm(
        enumerate(train_dataloader),
        total=num_batches,
        disable=not is_primary(args),
    ):
        optimizer.zero_grad()

        images, labels = batch
        if not args.use_prefetcher:
            # With the prefetcher the batch is used as delivered (presumably it
            # is already on the device - confirm in PrefetchLoader).
            images = images.to(device)
            labels = labels.to(device)

        if args.renorm_input:
            input_images = normalize(denormalize(images))
        else:
            input_images = images

        with torch.cuda.amp.autocast(enabled=use_autocast, dtype=train_dtype):
            features = extract_feature(vqvae, input_images, args).detach()
            # (b, c, h, w) -> (b, h*w, c) so the classifier can pool over dim 1.
            features = features.flatten(2).permute(0, 2, 1)
            logits = linear_classifier(features)
            loss = criterion(logits, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        total_correct += (logits.argmax(1) == labels).sum().item()
        total_samples += labels.size(0)

        if is_primary(args) and idx % 25 == 0:
            logger.info(f"Training Loss: {loss.item():.4f}")
            logger.info(f"Training Acc: {total_correct / total_samples * 100.0:.4f}")

    # Guard against an empty loader instead of dividing by zero.
    mean_loss = total_loss / num_batches if num_batches else 0.0
    accuracy = total_correct / total_samples * 100.0 if total_samples else 0.0
    return mean_loss, accuracy
|
|
|
|
|
def evaluate(vqvae, linear_classifier, val_dataloader, device, args):
    """Evaluate the linear classifier on ``val_dataloader`` without gradients.

    Args:
        vqvae: Tokenizer used as the feature extractor.
        linear_classifier: The probe being evaluated (switched to eval mode).
        val_dataloader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to.
        args: Namespace providing ``mixed_precision`` and ``renorm_input``
            (and whatever ``is_primary`` reads).

    Returns:
        tuple: ``(mean loss per batch, accuracy in percent)``; both are ``0.0``
        when the loader yields nothing.
    """
    # Accept "fp32" (the CLI spelling) as well as the original "none" key so
    # the lookup no longer raises KeyError for full precision.
    precision_to_dtype = {
        "none": torch.float32,
        "fp32": torch.float32,
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
    }
    dtype = precision_to_dtype[args.mixed_precision]
    use_autocast = dtype is not torch.float32

    criterion = torch.nn.CrossEntropyLoss()
    linear_classifier.eval()

    denormalize = normalize = None
    if args.renorm_input:
        denormalize = Denormalize(
            mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], device=device
        )
        normalize = Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], device=device
        )

    total_loss = 0
    total_correct = 0
    total_samples = 0
    num_batches = len(val_dataloader)

    with torch.no_grad():
        for batch in tqdm(
            val_dataloader, total=num_batches, disable=not is_primary(args)
        ):
            images, labels = batch
            images = images.to(device)
            labels = labels.to(device)

            if args.renorm_input:
                input_images = normalize(denormalize(images))
            else:
                input_images = images

            with torch.cuda.amp.autocast(enabled=use_autocast, dtype=dtype):
                features = extract_feature(vqvae, input_images, args)
                # (b, c, h, w) -> (b, h*w, c) for the classifier's pooling.
                features = features.flatten(2).permute(0, 2, 1)
                logits = linear_classifier(features)
                loss = criterion(logits, labels)

            total_loss += loss.item()
            total_correct += (logits.argmax(1) == labels).sum().item()
            total_samples += labels.size(0)

    # Guard against an empty loader instead of dividing by zero.
    mean_loss = total_loss / num_batches if num_batches else 0.0
    accuracy = total_correct / total_samples * 100.0 if total_samples else 0.0
    return mean_loss, accuracy
|
|
|
|
|
def main():
    """Train and evaluate a linear probe on top of a pretrained VQ tokenizer.

    Steps: parse arguments, set up (optionally distributed) devices, build the
    tokenizer and load its checkpoint, train ``LinearClassifier`` for
    ``args.num_epochs`` epochs with validation after every epoch, save
    per-epoch and best classifier weights, and write the best accuracy to
    ``linear_results.json``.

    Raises:
        ValueError: If ``--vq-ckpt`` was not provided.
        Exception: If the checkpoint has none of the keys ``ema``, ``model``
            or ``state_dict``.
    """
    args = parse_args()
    if args.vq_ckpt is None:
        # Fail early with a clear message instead of inside torch.load(None).
        raise ValueError("--vq-ckpt must point to a pretrained tokenizer checkpoint")

    seed_everything(args.seed)

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    device = init_distributed_device(args)
    if args.distributed:
        logger.info(
            "Training in distributed mode with multiple processes, 1 device per process."
            f"Process {args.rank}, total {args.world_size}, device {args.device}."
        )
        # One cache directory per rank so processes do not share one cache.
        os.environ["HF_HOME"] = f"./hf_cache_{args.rank}/"
        os.environ["TRANSFORMERS_CACHE"] = f"./hf_cache_{args.rank}/"
    else:
        logger.info(f"Training with a single process on 1 device ({args.device}).")
    assert args.rank >= 0

    # ---- frozen feature extractor -----------------------------------------
    logger.info("Creating model")
    vqvae = VQ_models[args.vq_model](
        codebook_size=args.codebook_size,
        codebook_embed_dim=args.codebook_embed_dim,
        v_patch_nums=args.v_patch_nums,
        enc_type=args.enc_type,
        dec_type=args.dec_type,
        semantic_guide=args.semantic_guide,
    )
    vqvae.to(device)
    vqvae.eval()

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
    # Prefer EMA weights, then the known alternative key names.
    for weight_key in ("ema", "model", "state_dict"):
        if weight_key in checkpoint:
            model_weight = checkpoint[weight_key]
            break
    else:
        raise Exception("please check model weight")
    vqvae.load_state_dict(model_weight)
    del checkpoint

    # ---- probe ---------------------------------------------------------------
    linear_classifier = LinearClassifier(vqvae.codebook_embed_dim, args.num_classes)
    linear_classifier = linear_classifier.to(device)

    if args.distributed:
        if is_primary(args):
            logger.info("Using native Torch DistributedDataParallel.")
        # extract_feature reaches the tokenizer through `.module` in
        # distributed mode, so the wrapper has to stay.
        vqvae = NativeDDP(vqvae, device_ids=[device], find_unused_parameters=True)
        linear_classifier = NativeDDP(
            linear_classifier, device_ids=[device], find_unused_parameters=True
        )

    # ---- data ----------------------------------------------------------------
    logger.info("Creating dataset")
    train_dataset = create_dataset(
        args.dataset_name,
        args.data_dir,
        args.image_size,
        is_train=True,
        use_prefetcher=args.use_prefetcher,
    )
    valid_dataset = create_dataset(
        args.dataset_name,
        args.val_data_dir,
        args.image_size,
        is_train=False,
        use_prefetcher=False,
    )

    sampler = None
    if args.distributed:
        sampler = torch.utils.data.DistributedSampler(
            train_dataset, shuffle=True, drop_last=False
        )
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        sampler=sampler,
        # DataLoader forbids shuffle together with a sampler.
        shuffle=sampler is None,
        collate_fn=fast_collate if args.use_prefetcher else default_collate,
    )
    if args.use_prefetcher:
        train_dataloader = PrefetchLoader(
            train_dataloader, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], device=device
        )
    # No sampler here: every process evaluates the whole validation set.
    val_dataloader = DataLoader(
        valid_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        shuffle=False,
    )

    # ---- output locations ----------------------------------------------------
    output_dir = os.path.join(args.output_path, "evaluations")
    os.makedirs(output_dir, exist_ok=True)
    lp_model_dir = os.path.join(output_dir, args.dataset_name)
    os.makedirs(lp_model_dir, exist_ok=True)

    # ---- optimisation --------------------------------------------------------
    optimizer = create_optimizer(
        linear_classifier,
        opt=args.optimizer,
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
    )
    # Loss scaling is only needed for fp16.
    scaler = torch.cuda.amp.GradScaler(enabled=(args.mixed_precision == "fp16"))
    # NOTE(review): with sched="step" timm appears to use `decay_epochs` and
    # to ignore `decay_milestones` (those belong to "multistep") - confirm
    # which schedule was intended. Kept as-is to preserve training behaviour.
    scheduler, _ = create_scheduler(
        sched="step",
        decay_milestones=[
            int(args.num_epochs * 0.3),
            int(args.num_epochs * 0.6),
            int(args.num_epochs * 0.9),
        ],
        optimizer=optimizer,
        patience_epochs=0,
        step_on_epochs=True,
        num_epochs=args.num_epochs,
        warmup_epochs=args.lr_warmup_epochs,
        min_lr=1e-6,
    )

    def classifier_state():
        # Save the bare module's weights so checkpoints load without DDP.
        if args.distributed:
            return linear_classifier.module.state_dict()
        return linear_classifier.state_dict()

    # ---- training loop -------------------------------------------------------
    logger.info("Start training linear classifier")
    max_accuracy = 0
    for epoch in range(args.num_epochs):
        if args.distributed:
            sampler.set_epoch(epoch)
        train_loss, train_acc = train_epoch(
            vqvae, linear_classifier, train_dataloader, optimizer, device, scaler, args
        )
        val_loss, val_acc = evaluate(
            vqvae, linear_classifier, val_dataloader, device, args
        )

        # Track the best accuracy on every rank so the per-rank log lines
        # below report the real value; only the global primary writes files.
        is_best = val_acc > max_accuracy
        if is_best:
            max_accuracy = val_acc

        if is_global_primary(args):
            torch.save(
                classifier_state(),
                os.path.join(lp_model_dir, f"epoch_{epoch}.pth"),
            )
            if is_best:
                logger.info(f"Saving best model with accuracy {max_accuracy}")
                torch.save(
                    classifier_state(),
                    os.path.join(lp_model_dir, "best.pth"),
                )

        if is_primary(args):
            logger.info(
                f"Epoch {epoch}: train_loss={train_loss}, train_acc={train_acc}, val_loss={val_loss}, val_acc={val_acc}"
            )
            logger.info(f"Best accuracy so far: {max_accuracy}")

        scheduler.step(epoch + 1)

    # Built after the loop so it exists even when num_epochs is 0.
    results = {"best_lp_accuracy": max_accuracy}

    if is_primary(args):
        logger.info("Finished training")
        logger.info(f"Best accuracy: {max_accuracy}")

    # A single writer avoids several ranks racing on the same file.
    if is_global_primary(args):
        with open(os.path.join(output_dir, "linear_results.json"), "w") as f:
            json.dump(results, f)


if __name__ == "__main__":
    main()
|
|