import os
import random
import sys
import warnings
from pathlib import Path

import hydra
import torch
import wandb
from colorama import Fore
from hydra.core.hydra_config import HydraConfig
from jaxtyping import install_import_hook
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers.wandb import WandbLogger
from lightning.pytorch.plugins.environments import SLURMEnvironment
from lightning.pytorch.strategies import DeepSpeedStrategy
from omegaconf import DictConfig, OmegaConf

# Make the repository root importable so that the `src` package resolves when
# this file is executed directly.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.model.model import get_model
from src.misc.weight_modify import checkpoint_filter_fn

warnings.filterwarnings("ignore")
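
# Configure beartype and jaxtyping so that everything imported from `src` is
# runtime type- and shape-checked.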
with install_import_hook(
    ("src",),
    ("beartype", "beartype"),
):
    from src.config import load_typed_root_config
    from src.dataset.data_module import DataModule
    from src.global_cfg import set_cfg
    from src.loss import get_losses
    from src.misc.LocalLogger import LocalLogger
    from src.misc.step_tracker import StepTracker
    from src.misc.wandb_tools import update_checkpoint_path
    from src.model.decoder import get_decoder
    from src.model.encoder import get_encoder
    from src.model.model_wrapper import ModelWrapper


def cyan(text: str) -> str:
    return f"{Fore.CYAN}{text}{Fore.RESET}"


@hydra.main(
    version_base=None,
    config_path="../config",
    config_name="main",
)
def train(cfg_dict: DictConfig):
    cfg = load_typed_root_config(cfg_dict)
    set_cfg(cfg_dict)
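
    # Set up the output directory chosen by Hydra for this run.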
    output_dir = Path(HydraConfig.get()["runtime"]["output_dir"])
    output_dir.mkdir(parents=True, exist_ok=True)
    print(cyan(f"Saving outputs to {output_dir}."))
    cfg.train.output_path = output_dir
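
    # Set up logging: use wandb when enabled, otherwise fall back to a local logger.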
    callbacks = []
    if cfg_dict.wandb.mode != "disabled":
        logger = WandbLogger(
            project=cfg_dict.wandb.project,
            mode=cfg_dict.wandb.mode,
            name=f"{cfg_dict.wandb.name} ({output_dir.parent.name}/{output_dir.name})",
            tags=cfg_dict.wandb.get("tags", None),
            log_model=False,
            save_dir=output_dir,
            config=OmegaConf.to_container(cfg_dict),
        )
        callbacks.append(LearningRateMonitor(logging_interval="step", log_momentum=True))

        # On rank != 0, wandb.run is None.
        if wandb.run is not None:
            wandb.run.log_code("src")
    else:
        logger = LocalLogger()
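
    # Set up checkpointing. Monitoring "info/global_step" with mode="max" keeps
    # the most recent save_top_k checkpoints.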
    callbacks.append(
        ModelCheckpoint(
            output_dir / "checkpoints",
            every_n_train_steps=cfg.checkpointing.every_n_train_steps,
            save_top_k=cfg.checkpointing.save_top_k,
            save_weights_only=cfg.checkpointing.save_weights_only,
            monitor="info/global_step",
            mode="max",
        )
    )
    # Use "_" rather than "=" in checkpoint filenames (e.g. step_100.ckpt instead
    # of step=100.ckpt).
    callbacks[-1].CHECKPOINT_EQUALS_CHAR = "_"
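
    # Prepare the checkpoint for loading.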
    checkpoint_path = update_checkpoint_path(cfg.checkpointing.load, cfg.wandb)
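
    # This allows the current global step to be shared with the data loader processes.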
    step_tracker = StepTracker()
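
    # Build the trainer. Training length is governed by max_steps, so max_epochs
    # is left unbounded (-1).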
    trainer = Trainer(
        max_epochs=-1,
        num_nodes=cfg.trainer.num_nodes,
        accelerator="gpu",
        logger=logger,
        devices="auto",
        strategy=(
            "ddp_find_unused_parameters_true"
            if torch.cuda.device_count() > 1
            else "auto"
        ),
        callbacks=callbacks,
        val_check_interval=cfg.trainer.val_check_interval,
        check_val_every_n_epoch=None,
        enable_progress_bar=False,
        gradient_clip_val=cfg.trainer.gradient_clip_val,
        max_steps=cfg.trainer.max_steps,
        precision=cfg.trainer.precision,
        accumulate_grad_batches=cfg.trainer.accumulate_grad_batches,
        # Keep inference mode off when poses are aligned at test time, since the
        # alignment needs gradients.
        inference_mode=not (cfg.mode == "test" and cfg.test.align_pose),
    )
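
    # Give every rank its own seed, derived from the configured base seed.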
    torch.manual_seed(cfg_dict.seed + trainer.global_rank)
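
    # Build the model and wrap it together with the losses, optimizer settings,
    # and step tracker for Lightning.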
    model = get_model(cfg.model.encoder, cfg.model.decoder)
    model_wrapper = ModelWrapper(
        cfg.optimizer,
        cfg.test,
        cfg.train,
        model,
        get_losses(cfg.loss),
        step_tracker,
    )
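
    # The data module provides the dataloaders and also receives the step tracker
    # and the global rank.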
    data_module = DataModule(
        cfg.dataset,
        cfg.data_loader,
        step_tracker,
        global_rank=trainer.global_rank,
    )
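
    # Dispatch to training or testing depending on cfg.mode; both load
    # checkpoint_path when it is set.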
    if cfg.mode == "train":
        trainer.fit(model_wrapper, datamodule=data_module, ckpt_path=checkpoint_path)
    else:
        trainer.test(
            model_wrapper,
            datamodule=data_module,
            ckpt_path=checkpoint_path,
        )


if __name__ == "__main__":
    train()