import os
from pathlib import Path
import hydra
import torch
import wandb
import random
from colorama import Fore
from jaxtyping import install_import_hook
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers.wandb import WandbLogger
from lightning.pytorch.plugins.environments import SLURMEnvironment
from lightning.pytorch.strategies import DeepSpeedStrategy
from omegaconf import DictConfig, OmegaConf
from hydra.core.hydra_config import HydraConfig
import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.model.model import get_model
from src.misc.weight_modify import checkpoint_filter_fn

import warnings

warnings.filterwarnings("ignore")

# Configure beartype and jaxtyping.
with install_import_hook(
    ("src",),
    ("beartype", "beartype"),
):
    from src.config import load_typed_root_config
    from src.dataset.data_module import DataModule
    from src.global_cfg import set_cfg
    from src.loss import get_losses
    from src.misc.LocalLogger import LocalLogger
    from src.misc.step_tracker import StepTracker
    from src.misc.wandb_tools import update_checkpoint_path
    from src.model.decoder import get_decoder
    from src.model.encoder import get_encoder
    from src.model.model_wrapper import ModelWrapper


def cyan(text: str) -> str:
    return f"{Fore.CYAN}{text}{Fore.RESET}"


@hydra.main(
    version_base=None,
    config_path="../config",
    config_name="main",
)
def train(cfg_dict: DictConfig):
    cfg = load_typed_root_config(cfg_dict)
    set_cfg(cfg_dict)

    # Set up the output directory.
    output_dir = Path(HydraConfig.get()["runtime"]["output_dir"])
    output_dir.mkdir(parents=True, exist_ok=True)
    print(cyan(f"Saving outputs to {output_dir}."))
    cfg.train.output_path = output_dir

    # Set up logging with wandb.
    callbacks = []
    if cfg_dict.wandb.mode != "disabled":
        logger = WandbLogger(
            project=cfg_dict.wandb.project,
            mode=cfg_dict.wandb.mode,
            name=f"{cfg_dict.wandb.name} ({output_dir.parent.name}/{output_dir.name})",
            tags=cfg_dict.wandb.get("tags", None),
            log_model=False,
            save_dir=output_dir,
            config=OmegaConf.to_container(cfg_dict),
        )
        callbacks.append(LearningRateMonitor("step", True))

        # On rank != 0, wandb.run is None.
        if wandb.run is not None:
            wandb.run.log_code("src")
    else:
        logger = LocalLogger()

    # Set up checkpointing.
    callbacks.append(
        ModelCheckpoint(
            output_dir / "checkpoints",
            every_n_train_steps=cfg.checkpointing.every_n_train_steps,
            save_top_k=cfg.checkpointing.save_top_k,
            save_weights_only=cfg.checkpointing.save_weights_only,
            monitor="info/global_step",
            mode="max",
        )
    )
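    # Use "_" instead of "=" between metric names and values in checkpoint filenames.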
    callbacks[-1].CHECKPOINT_EQUALS_CHAR = "_"

    # Prepare the checkpoint for loading.
    checkpoint_path = update_checkpoint_path(cfg.checkpointing.load, cfg.wandb)

    # This allows the current step to be shared with the data loader processes.
    step_tracker = StepTracker()

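    # Configure the Lightning Trainer; use DDP when more than one GPU is visible.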
    trainer = Trainer(
        max_epochs=-1,
        num_nodes=cfg.trainer.num_nodes,
        # num_sanity_val_steps=0,
        accelerator="gpu",
        logger=logger,
        devices="auto",
        strategy=(
            "ddp_find_unused_parameters_true"
            if torch.cuda.device_count() > 1
            else "auto"
        ),
        # strategy="deepspeed_stage_1",
        callbacks=callbacks,
        val_check_interval=cfg.trainer.val_check_interval,
        check_val_every_n_epoch=None,
        enable_progress_bar=False,
        gradient_clip_val=cfg.trainer.gradient_clip_val,
        max_steps=cfg.trainer.max_steps,
        precision=cfg.trainer.precision,
        accumulate_grad_batches=cfg.trainer.accumulate_grad_batches,
        # plugins=[SLURMEnvironment(requeue_signal=signal.SIGUSR1)],  # Uncomment for SLURM auto resubmission (also requires `import signal`).
        # Disable inference mode when gradients are needed for test-time pose alignment.
        inference_mode=not (cfg.mode == "test" and cfg.test.align_pose),
    )
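    # Offset the seed by global rank so each process draws different random numbers.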
    torch.manual_seed(cfg_dict.seed + trainer.global_rank)

    model = get_model(cfg.model.encoder, cfg.model.decoder)
    model_wrapper = ModelWrapper(
        cfg.optimizer,
        cfg.test,
        cfg.train,
        model,
        get_losses(cfg.loss),
        step_tracker,
    )
    data_module = DataModule(
        cfg.dataset,
        cfg.data_loader,
        step_tracker,
        global_rank=trainer.global_rank,
    )

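    # Run training or evaluation, resuming from the prepared checkpoint if one was given.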
    if cfg.mode == "train":
        trainer.fit(model_wrapper, datamodule=data_module, ckpt_path=checkpoint_path)
    else:
        trainer.test(
            model_wrapper,
            datamodule=data_module,
            ckpt_path=checkpoint_path,
        )


if __name__ == "__main__":
    train()