|
import os |
|
import hydra |
|
import wandb |
|
from os.path import isfile, join |
|
from shutil import copyfile |
|
|
|
import torch |
|
|
|
from omegaconf import OmegaConf |
|
from hydra.core.hydra_config import HydraConfig |
|
from hydra.utils import instantiate |
|
from pytorch_lightning.callbacks import LearningRateMonitor |
|
from lightning_fabric.utilities.rank_zero import _get_rank |
|
from callbacks import EMACallback, FixNANinGrad, IncreaseDataEpoch |
|
from models.module import RandomGeolocalizer |
|
|
|
# Trade a little matmul precision for speed on Ampere+ GPUs (TF32 kernels).
torch.set_float32_matmul_precision("high")


# Allow "${eval:...}" expressions inside Hydra/OmegaConf config files.
# NOTE(review): this evaluates arbitrary Python from the config — configs
# must come from a trusted source.
OmegaConf.register_new_resolver("eval", eval)
|
|
|
|
|
def wandb_init(cfg):
    """Create or restore the wandb run id used for resumable logging.

    The id is persisted in ``<cfg.checkpoints.dirpath>/wandb_id.txt`` so a
    restarted job resumes the same wandb run; only rank 0 (or a
    non-distributed run, where rank is None) writes the file.

    Args:
        cfg: Hydra config; reads ``cfg.checkpoints.dirpath`` and
            ``cfg.logger_suffix``.

    Returns:
        str: the wandb run id.
    """
    directory = cfg.checkpoints.dirpath
    id_file = join(directory, "wandb_id.txt")
    if isfile(id_file) and cfg.logger_suffix == "":
        with open(id_file, "r") as f:
            # strip() guards against stray whitespace/newlines in the file
            # (the original read the raw line back verbatim).
            wandb_id = f.readline().strip()
    else:
        rank = _get_rank()
        wandb_id = wandb.util.generate_id()
        print(f"Generated wandb id: {wandb_id}")
        # NOTE(review): every rank generates its own id here but only rank 0
        # persists it, so non-zero ranks hold a different id until the next
        # restart — confirm the wandb logger is rank-zero-only.
        if rank == 0 or rank is None:
            with open(id_file, "w") as f:
                f.write(str(wandb_id))

    return wandb_id
|
|
|
|
|
def load_model(cfg, dict_config, wandb_id, callbacks):
    """Build the logger, model and trainer, resuming from ``last.ckpt`` if present.

    Args:
        cfg: full Hydra config.
        dict_config: plain-dict version of ``cfg`` (used to log the
            model/dataset config to wandb on fresh runs).
        wandb_id: wandb run id (see ``wandb_init``).
        callbacks: list of Lightning callbacks to attach to the trainer.

    Returns:
        tuple: ``(trainer, model, ckpt_path)`` where ``ckpt_path`` is the
        checkpoint to hand to ``trainer.fit`` (None when starting fresh).
    """
    directory = cfg.checkpoints.dirpath
    # Compute the checkpoint path once (the original joined it twice).
    ckpt_path = join(directory, "last.ckpt")
    logger = instantiate(cfg.logger, id=wandb_id, resume="allow")
    if isfile(ckpt_path):
        model = RandomGeolocalizer.load_from_checkpoint(ckpt_path, cfg=cfg.model)
        # Fixed typo: "Loading form checkpoint" -> "Loading from checkpoint".
        print(f"Loading from checkpoint ... {ckpt_path}")
    else:
        ckpt_path = None
        # Only fresh runs push the config to wandb; resumed runs keep the
        # config of the original run.
        log_dict = {"model": dict_config["model"], "dataset": dict_config["dataset"]}
        logger._wandb_init.update({"config": log_dict})
        model = RandomGeolocalizer(cfg.model)

    trainer = instantiate(
        cfg.trainer,
        strategy=cfg.trainer.strategy,
        logger=logger,
        callbacks=callbacks,
    )
    return trainer, model, ckpt_path
|
|
|
|
|
def project_init(cfg):
    """Prepare the run directory and archive the composed Hydra config.

    Creates ``cfg.checkpoints.dirpath`` (if needed) and copies the config
    Hydra wrote to ``.hydra/config.yaml`` in the current working directory
    next to the checkpoints.
    """
    print("Working directory set to {}".format(os.getcwd()))
    ckpt_dir = cfg.checkpoints.dirpath
    os.makedirs(ckpt_dir, exist_ok=True)
    # Keep a copy of the resolved config alongside the checkpoints.
    copyfile(".hydra/config.yaml", join(ckpt_dir, "config.yaml"))
|
|
|
|
|
def callback_init(cfg):
    """Instantiate the Lightning callbacks used during training.

    Returns:
        list: checkpointing, progress bar, LR monitoring, EMA weight
        tracking, NaN-gradient repair and data-epoch scheduling callbacks,
        in that order.
    """
    # EMA shadows "network" into "ema_network" with the configured decay.
    ema = EMACallback(
        "network",
        "ema_network",
        decay=cfg.model.ema_decay,
        start_ema_step=cfg.model.start_ema_step,
        init_ema_random=False,
    )
    return [
        instantiate(cfg.checkpoints),
        instantiate(cfg.progress_bar),
        LearningRateMonitor(),
        ema,
        FixNANinGrad(monitor=["train/loss"]),
        IncreaseDataEpoch(),
    ]
|
|
|
|
|
def init_datamodule(cfg):
    """Build and return the datamodule described by ``cfg.datamodule``."""
    return instantiate(cfg.datamodule)
|
|
|
|
|
def hydra_boilerplate(cfg):
    """Assemble all training components from the Hydra config.

    Order matters: ``project_init`` creates the checkpoint directory that
    ``wandb_init`` writes its id file into.

    Returns:
        tuple: ``(trainer, model, datamodule, ckpt_path)``.
    """
    resolved = OmegaConf.to_container(cfg, resolve=True)
    cbs = callback_init(cfg)
    dm = init_datamodule(cfg)
    project_init(cfg)
    run_id = wandb_init(cfg)
    trainer, model, ckpt = load_model(cfg, resolved, run_id, cbs)
    return trainer, model, dm, ckpt
|
|
|
|
|
@hydra.main(config_path="configs", config_name="config", version_base=None)
def main(cfg):
    """Entry point: train and/or evaluate according to ``cfg.mode``.

    Supported modes: ``"train"``, ``"eval"``, and ``"traineval"`` (train,
    then test).
    """
    if "stage" in cfg and cfg.stage == "debug":
        # Pretty tensor reprs while debugging only (optional dependency).
        import lovely_tensors as lt

        lt.monkey_patch()
    trainer, model, datamodule, ckpt_path = hydra_boilerplate(cfg)
    model.datamodule = datamodule

    if cfg.mode == "train":
        trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
    elif cfg.mode == "eval":
        trainer.test(model, datamodule=datamodule)
    elif cfg.mode == "traineval":
        # The original reassigned cfg.mode to "train" and then "test" around
        # these calls; nothing reads cfg.mode afterwards (and "test" was not
        # even a recognized mode), so those dead writes are dropped.
        trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
        trainer.test(model, datamodule=datamodule)
    else:
        # Previously an unknown mode fell through silently and did nothing.
        raise ValueError(f"Unknown mode: {cfg.mode}")
|
|
|
|
|
if __name__ == "__main__":

    # Hydra parses CLI overrides and invokes main with the composed config.
    main()
|
|