# AnySplat/src/main.py
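# Entry point for AnySplat training and evaluation, driven by Hydra.
# Example invocation (hypothetical overrides; the available keys are defined by ../config/main.yaml):
#   python src/main.py mode=train wandb.mode=disabled
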
import os
from pathlib import Path
import hydra
import torch
import wandb
import random
from colorama import Fore
from jaxtyping import install_import_hook
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers.wandb import WandbLogger
from lightning.pytorch.plugins.environments import SLURMEnvironment
from lightning.pytorch.strategies import DeepSpeedStrategy
from omegaconf import DictConfig, OmegaConf
from hydra.core.hydra_config import HydraConfig
import sys
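# Make the repository root importable so the absolute "src.*" imports below resolve
# when this file is run directly as a script.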
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.model.model import get_model
from src.misc.weight_modify import checkpoint_filter_fn
import warnings
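# Silence Python warnings globally to keep the console output readable.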
warnings.filterwarnings("ignore")
# Configure beartype and jaxtyping.
with install_import_hook(
("src",),
("beartype", "beartype"),
):
from src.config import load_typed_root_config
from src.dataset.data_module import DataModule
from src.global_cfg import set_cfg
from src.loss import get_losses
from src.misc.LocalLogger import LocalLogger
from src.misc.step_tracker import StepTracker
from src.misc.wandb_tools import update_checkpoint_path
from src.model.decoder import get_decoder
from src.model.encoder import get_encoder
from src.model.model_wrapper import ModelWrapper
def cyan(text: str) -> str:
return f"{Fore.CYAN}{text}{Fore.RESET}"
@hydra.main(
    version_base=None,
    config_path="../config",
    config_name="main",
)
def train(cfg_dict: DictConfig):
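    # Convert the Hydra DictConfig into the project's typed config and expose the raw config globally.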
    cfg = load_typed_root_config(cfg_dict)
    set_cfg(cfg_dict)

    # Set up the output directory.
    output_dir = Path(HydraConfig.get()["runtime"]["output_dir"])
    output_dir.mkdir(parents=True, exist_ok=True)
    print(cyan(f"Saving outputs to {output_dir}."))
    cfg.train.output_path = output_dir

    # Set up logging with wandb.
    callbacks = []
    if cfg_dict.wandb.mode != "disabled":
        logger = WandbLogger(
            project=cfg_dict.wandb.project,
            mode=cfg_dict.wandb.mode,
            name=f"{cfg_dict.wandb.name} ({output_dir.parent.name}/{output_dir.name})",
            tags=cfg_dict.wandb.get("tags", None),
            log_model=False,
            save_dir=output_dir,
            config=OmegaConf.to_container(cfg_dict),
        )
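        # Record the learning rate (and momentum) at every training step.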
        callbacks.append(LearningRateMonitor("step", True))

        # On rank != 0, wandb.run is None.
        if wandb.run is not None:
            wandb.run.log_code("src")
    else:
        logger = LocalLogger()

    # Set up checkpointing.
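    # Monitoring "info/global_step" with mode="max" means save_top_k retains the most recent checkpoints.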
    callbacks.append(
        ModelCheckpoint(
            output_dir / "checkpoints",
            every_n_train_steps=cfg.checkpointing.every_n_train_steps,
            save_top_k=cfg.checkpointing.save_top_k,
            save_weights_only=cfg.checkpointing.save_weights_only,
            monitor="info/global_step",
            mode="max",
        )
    )
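    # Use "_" instead of the default "=" when metric names are formatted into checkpoint filenames.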
    callbacks[-1].CHECKPOINT_EQUALS_CHAR = "_"

    # Prepare the checkpoint for loading.
    checkpoint_path = update_checkpoint_path(cfg.checkpointing.load, cfg.wandb)

    # This allows the current step to be shared with the data loader processes.
    step_tracker = StepTracker()

    trainer = Trainer(
        max_epochs=-1,
        num_nodes=cfg.trainer.num_nodes,
        # num_sanity_val_steps=0,
        accelerator="gpu",
        logger=logger,
        devices="auto",
        strategy=(
            "ddp_find_unused_parameters_true"
            if torch.cuda.device_count() > 1
            else "auto"
        ),
        # strategy="deepspeed_stage_1",
        callbacks=callbacks,
        val_check_interval=cfg.trainer.val_check_interval,
        check_val_every_n_epoch=None,
        enable_progress_bar=False,
        gradient_clip_val=cfg.trainer.gradient_clip_val,
        max_steps=cfg.trainer.max_steps,
        precision=cfg.trainer.precision,
        accumulate_grad_batches=cfg.trainer.accumulate_grad_batches,
        # plugins=[SLURMEnvironment(requeue_signal=signal.SIGUSR1)],  # Uncomment for SLURM auto-resubmission.
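        # Test-time pose alignment needs gradients, so inference_mode is turned off in that case.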
        inference_mode=not (cfg.mode == "test" and cfg.test.align_pose),
    )
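    # Offset the seed by global rank so each process draws a different random stream.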
    torch.manual_seed(cfg_dict.seed + trainer.global_rank)

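    # Assemble the encoder/decoder model and wrap it with the losses, optimizer settings, and step tracker.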
    model = get_model(cfg.model.encoder, cfg.model.decoder)
    model_wrapper = ModelWrapper(
        cfg.optimizer,
        cfg.test,
        cfg.train,
        model,
        get_losses(cfg.loss),
        step_tracker,
    )
    data_module = DataModule(
        cfg.dataset,
        cfg.data_loader,
        step_tracker,
        global_rank=trainer.global_rank,
    )

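    # Run training (optionally resuming from checkpoint_path) or the test loop, depending on cfg.mode.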
if cfg.mode == "train":
trainer.fit(model_wrapper, datamodule=data_module, ckpt_path=checkpoint_path)
else:
trainer.test(
model_wrapper,
datamodule=data_module,
ckpt_path=checkpoint_path,
)
if __name__ == "__main__":
    train()