import os
import random
import sys
import warnings
from pathlib import Path

import hydra
import torch
import wandb
from colorama import Fore
from hydra.core.hydra_config import HydraConfig
from jaxtyping import install_import_hook
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers.wandb import WandbLogger
from lightning.pytorch.plugins.environments import SLURMEnvironment
from lightning.pytorch.strategies import DeepSpeedStrategy
from omegaconf import DictConfig, OmegaConf

# Make the repository root importable when this script is run directly.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.model.model import get_model
from src.misc.weight_modify import checkpoint_filter_fn

warnings.filterwarnings("ignore")

# Configure beartype and jaxtyping.
with install_import_hook(
    ("src",),
    ("beartype", "beartype"),
):
    from src.config import load_typed_root_config
    from src.dataset.data_module import DataModule
    from src.global_cfg import set_cfg
    from src.loss import get_losses
    from src.misc.LocalLogger import LocalLogger
    from src.misc.step_tracker import StepTracker
    from src.misc.wandb_tools import update_checkpoint_path
    from src.model.decoder import get_decoder
    from src.model.encoder import get_encoder
    from src.model.model_wrapper import ModelWrapper


def cyan(text: str) -> str:
    return f"{Fore.CYAN}{text}{Fore.RESET}"


@hydra.main(
    version_base=None,
    config_path="../config",
    config_name="main",
)
def train(cfg_dict: DictConfig):
    cfg = load_typed_root_config(cfg_dict)
    set_cfg(cfg_dict)

    # Set up the output directory.
    output_dir = Path(HydraConfig.get()["runtime"]["output_dir"])
    output_dir.mkdir(parents=True, exist_ok=True)
    print(cyan(f"Saving outputs to {output_dir}."))
    cfg.train.output_path = output_dir

    # Set up logging with wandb.
    callbacks = []
    if cfg_dict.wandb.mode != "disabled":
        logger = WandbLogger(
            project=cfg_dict.wandb.project,
            mode=cfg_dict.wandb.mode,
            name=f"{cfg_dict.wandb.name} ({output_dir.parent.name}/{output_dir.name})",
            tags=cfg_dict.wandb.get("tags", None),
            log_model=False,
            save_dir=output_dir,
            config=OmegaConf.to_container(cfg_dict),
        )
        callbacks.append(LearningRateMonitor("step", True))

        # On rank != 0, wandb.run is None.
        if wandb.run is not None:
            wandb.run.log_code("src")
    else:
        logger = LocalLogger()

    # Set up checkpointing.
    callbacks.append(
        ModelCheckpoint(
            output_dir / "checkpoints",
            every_n_train_steps=cfg.checkpointing.every_n_train_steps,
            save_top_k=cfg.checkpointing.save_top_k,
            save_weights_only=cfg.checkpointing.save_weights_only,
            monitor="info/global_step",
            mode="max",
        )
    )
    # Use "_" rather than the default "=" separator in checkpoint filenames.
    callbacks[-1].CHECKPOINT_EQUALS_CHAR = "_"

    # Prepare the checkpoint for loading.
    checkpoint_path = update_checkpoint_path(cfg.checkpointing.load, cfg.wandb)

    # This allows the current step to be shared with the data loader processes.
    step_tracker = StepTracker()

    trainer = Trainer(
        max_epochs=-1,
        num_nodes=cfg.trainer.num_nodes,
        # num_sanity_val_steps=0,
        accelerator="gpu",
        logger=logger,
        devices="auto",
        strategy=(
            "ddp_find_unused_parameters_true"
            if torch.cuda.device_count() > 1
            else "auto"
        ),
        # strategy="deepspeed_stage_1",
        callbacks=callbacks,
        val_check_interval=cfg.trainer.val_check_interval,
        check_val_every_n_epoch=None,
        enable_progress_bar=False,
        gradient_clip_val=cfg.trainer.gradient_clip_val,
        max_steps=cfg.trainer.max_steps,
        precision=cfg.trainer.precision,
        accumulate_grad_batches=cfg.trainer.accumulate_grad_batches,
        # Uncomment for SLURM auto resubmission (also requires `import signal`).
        # plugins=[SLURMEnvironment(requeue_signal=signal.SIGUSR1)],
        # Allow gradients at test time when pose alignment is enabled.
        inference_mode=not (cfg.mode == "test" and cfg.test.align_pose),
    )

    torch.manual_seed(cfg_dict.seed + trainer.global_rank)

    model = get_model(cfg.model.encoder, cfg.model.decoder)

    model_wrapper = ModelWrapper(
        cfg.optimizer,
        cfg.test,
        cfg.train,
        model,
        get_losses(cfg.loss),
        step_tracker,
    )

    data_module = DataModule(
        cfg.dataset,
        cfg.data_loader,
        step_tracker,
        global_rank=trainer.global_rank,
    )

    if cfg.mode == "train":
        trainer.fit(model_wrapper, datamodule=data_module, ckpt_path=checkpoint_path)
    else:
        trainer.test(
            model_wrapper,
            datamodule=data_module,
            ckpt_path=checkpoint_path,
        )


if __name__ == "__main__":
    train()
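
# Example invocation (illustrative sketch only: the module path and concrete override
# values below are assumptions based on this script's Hydra setup, not a documented CLI;
# `mode`, `wandb.mode`, and `checkpointing.load` are keys read by the code above):
#   python -m src.main mode=train wandb.mode=disabled
#   python -m src.main mode=test checkpointing.load=<path-to-checkpoint>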