import argparse
from datetime import datetime
import random
import os
import multiprocessing

# Set the multiprocessing start method to 'spawn' to avoid CUDA initialization
# issues in forked dataloader worker processes.
multiprocessing.set_start_method('spawn', force=True)

import numpy as np
from omegaconf import OmegaConf
import torch
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import SequentialLR, LambdaLR, CosineAnnealingLR
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.utils import set_seed
from peft import get_peft_model, LoraConfig

from modeling import VMemModel
from modeling.modules.autoencoder import AutoEncoder
from modeling.sampling import DDPMDiscretization, DiscreteDenoiser, create_samplers
from modeling.modules.conditioner import CLIPConditioner
from utils.training_utils import DiffusionTrainer, load_pretrained_model
from data.dataset import RealEstatePoseImageSevaDataset

# Set random seeds for reproducibility.
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)


def parse_args():
    parser = argparse.ArgumentParser(description='Train a model')
    parser.add_argument('--config', type=str, required=True, help='Path to the config file')
    return parser.parse_args()


def generate_current_datetime():
    return datetime.now().strftime("%Y-%m-%d_%H-%M-%S")


def prepare_model(unet, config):
    assert isinstance(unet, VMemModel), "unet should be an instance of VMemModel"
    if config.training.lora_flag:
        # Collect the modules that LoRA adapters should be attached to.
        # PEFT's `target_modules` expects module names, so the ".weight"/".bias"
        # suffixes are stripped from the parameter names. Normalization layers
        # (including the norms at `in_layers.0`/`out_layers.0`) are excluded.
        target_modules = []
        for name, _ in unet.named_parameters():
            if ("transformer" in name or "emb" in name or "layers" in name) \
                    and "norm" not in name and "in_layers.0" not in name and "out_layers.0" not in name:
                name = name.replace(".weight", "").replace(".bias", "")
                if name not in target_modules:
                    target_modules.append(str(name))

        lora_config = LoraConfig(
            r=config.training.lora_r,
            lora_alpha=config.training.lora_alpha,
            target_modules=target_modules,
            lora_dropout=config.training.lora_dropout,
        )
        unet = get_peft_model(unet, lora_config)
        unet.print_trainable_parameters()
    else:
        # Full fine-tuning: every parameter is trainable.
        for param in unet.parameters():
            param.requires_grad = True
        trainable = sum(p.numel() for p in unet.parameters() if p.requires_grad)
        total = sum(p.numel() for p in unet.parameters())
        print(f"trainable parameters percentage: {trainable / total:.4f}")
    return unet
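
# A minimal, self-contained sketch of the warmup + cosine learning-rate schedule
# that main() builds below (linear LambdaLR warmup handed off to CosineAnnealingLR
# via SequentialLR). It is never called by the training run; the toy parameter,
# optimizer, and step counts are illustrative only.
def _demo_warmup_cosine_schedule(base_lr=1e-4, warmup_steps=10, total_steps=100):
    param = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.AdamW([param], lr=base_lr)
    cosine = CosineAnnealingLR(opt, T_max=total_steps - warmup_steps, eta_min=0)
    warmup = LambdaLR(opt, lr_lambda=lambda step: step / max(1, warmup_steps))
    sched = SequentialLR(opt, schedulers=[warmup, cosine], milestones=[warmup_steps])
    lrs = []
    for _ in range(total_steps):
        opt.step()                          # no-op here (no gradients); avoids the step-order warning
        lrs.append(sched.get_last_lr()[0])  # LR in effect for this step
        sched.step()
    # `lrs` ramps linearly from 0 to base_lr over `warmup_steps`,
    # then decays along a cosine curve toward 0.
    return lrs
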
def main():
    args = parse_args()
    config = OmegaConf.load(args.config)  # Load the training configuration.

    num_epochs = config.training.num_epochs
    batch_size = config.training.batch_size
    learning_rate = config.training.learning_rate
    gradient_accumulation_steps = config.training.gradient_accumulation_steps
    num_workers = config.training.num_workers
    warmup_epochs = config.training.warmup_epochs
    max_grad_norm = config.training.max_grad_norm
    validation_interval = config.training.validation_interval
    visualization_flag = config.training.visualization_flag
    visualize_every = config.training.visualize_every
    random_seed = config.training.random_seed
    save_flag = config.training.save_flag
    use_wandb = config.training.use_wandb
    samples_dir = config.training.samples_dir
    weights_save_dir = config.training.weights_save_dir
    resume = config.training.resume

    # Create per-run output directories keyed by the launch timestamp.
    exp_id = generate_current_datetime()
    if visualization_flag:
        run_visualization_dir = f"{samples_dir}/{exp_id}"
        os.makedirs(run_visualization_dir, exist_ok=True)
    else:
        run_visualization_dir = None
    if save_flag:
        run_weights_save_dir = f"{weights_save_dir}/{exp_id}"
        os.makedirs(run_weights_save_dir, exist_ok=True)
    else:
        run_weights_save_dir = None

    accelerator = Accelerator(
        mixed_precision="fp16",
        gradient_accumulation_steps=gradient_accumulation_steps,
        kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=False)],
    )
    num_gpus = accelerator.num_processes
    if random_seed is not None:
        set_seed(random_seed, device_specific=True)
    device = accelerator.device

    model = load_pretrained_model(cache_dir=config.model.cache_dir, device=device)
    model = prepare_model(model, config)
    if resume:
        # Optionally resume from a checkpoint path given in the config.
        model.load_state_dict(torch.load(resume, map_location='cpu'), strict=False)
    torch.cuda.empty_cache()

    train_dataset = RealEstatePoseImageSevaDataset(
        rgb_data_dir=config.dataset.realestate10k.rgb_data_dir,
        meta_info_dir=config.dataset.realestate10k.meta_info_dir,
        num_sample_per_episode=config.dataset.realestate10k.num_sample_per_episode,
        mode='train',
    )
    val_dataset = RealEstatePoseImageSevaDataset(
        rgb_data_dir=config.dataset.realestate10k.rgb_data_dir,
        meta_info_dir=config.dataset.realestate10k.meta_info_dir,
        num_sample_per_episode=config.dataset.realestate10k.val_num_sample_per_episode,
        mode='test',
    )
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                  num_workers=num_workers, multiprocessing_context='spawn')
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True,
                                num_workers=num_workers, multiprocessing_context='spawn')

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate,
                                  weight_decay=config.training.weight_decay)

    # Cosine decay after an optional linear warmup.
    train_steps_per_epoch = len(train_dataloader)
    total_train_steps = num_epochs * train_steps_per_epoch
    warmup_steps = warmup_epochs * train_steps_per_epoch
    lr_scheduler = CosineAnnealingLR(
        optimizer,
        T_max=total_train_steps - warmup_steps,
        eta_min=0,
    )
    if warmup_epochs > 0:
        def warmup_lambda(current_step):
            return float(current_step) / float(max(1, warmup_steps))

        warmup_scheduler = LambdaLR(optimizer, lr_lambda=warmup_lambda)
        # Chain the warmup and cosine schedulers with SequentialLR.
        lr_scheduler = SequentialLR(
            optimizer,
            schedulers=[warmup_scheduler, lr_scheduler],
            milestones=[warmup_steps],
        )

    # Auxiliary modules for the diffusion pipeline; only `model`'s parameters
    # are passed to the optimizer.
    vae = AutoEncoder(chunk_size=1).to(device)
    vae.eval()
    conditioner = CLIPConditioner().to(device)
    discretization = DDPMDiscretization()
    denoiser = DiscreteDenoiser(discretization=discretization, num_idx=1000, device=device)
    sampler = create_samplers(
        guider_types=config.training.guider_types,
        discretization=discretization,
        num_frames=config.model.num_frames,
        num_steps=config.training.inference_num_steps,
        cfg_min=config.training.cfg_min,
        device=device,
    )

    (model, vae, train_dataloader, val_dataloader, optimizer, lr_scheduler) = accelerator.prepare(
        model, vae, train_dataloader, val_dataloader, optimizer, lr_scheduler,
    )

    trainer = DiffusionTrainer(
        network=model,
        ae=vae,
        conditioner=conditioner,
        denoiser=denoiser,
        sampler=sampler,
        discretization=discretization,
        cfg=config.training.cfg,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        ema_decay=config.training.ema_decay,
        device=device,
        accelerator=accelerator,
        max_grad_norm=max_grad_norm,
        save_flag=save_flag,
        visualize_flag=visualization_flag,
    )

    trainer.train(
        train_dataloader,
        num_epochs,
        unconditional_prob=config.training.uncond_prob,
        log_every=10,
        validation_dataloader=val_dataloader,
        validation_interval=validation_interval,
        save_dir=run_weights_save_dir,
        save_interval=config.training.save_every,
        visualize_every=visualize_every,
        visualize_dir=run_visualization_dir,
        use_wandb=use_wandb,
    )


if __name__ == "__main__":
    main()
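
# ---------------------------------------------------------------------------
# Example usage. Everything below is an illustrative sketch: the script file
# name, config path, and all values are placeholders/assumptions, not taken
# from the repository's real configs (only the config *keys* mirror what this
# script actually reads).
#
# Launch (single- or multi-GPU) with Hugging Face Accelerate, e.g.:
#
#   accelerate launch train.py --config configs/realestate10k.yaml
#
# A minimal config sketch covering the keys accessed above:
#
#   model:
#     cache_dir: ./checkpoints          # passed to load_pretrained_model
#     num_frames: 8
#   dataset:
#     realestate10k:
#       rgb_data_dir: /data/realestate10k/rgb
#       meta_info_dir: /data/realestate10k/meta
#       num_sample_per_episode: 1
#       val_num_sample_per_episode: 1
#   training:
#     num_epochs: 10
#     batch_size: 1
#     learning_rate: 1.0e-4
#     weight_decay: 0.01
#     gradient_accumulation_steps: 1
#     num_workers: 4
#     warmup_epochs: 1
#     max_grad_norm: 1.0
#     validation_interval: 1000
#     visualization_flag: true
#     visualize_every: 500
#     random_seed: 42
#     save_flag: true
#     save_every: 1000
#     samples_dir: ./samples
#     weights_save_dir: ./weights
#     resume: null                      # or a checkpoint path to warm-start from
#     use_wandb: false
#     lora_flag: true
#     lora_r: 16
#     lora_alpha: 16
#     lora_dropout: 0.0
#     guider_types: ...                 # consumed by create_samplers; format not shown here
#     inference_num_steps: 50
#     cfg: 2.0
#     cfg_min: 1.2
#     ema_decay: 0.999
#     uncond_prob: 0.1
# ---------------------------------------------------------------------------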