import os

import pytorch_lightning as pl
from omegaconf import OmegaConf
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_only

# When true, rank 0 pauses for five seconds before writing the config files,
# and non-zero ranks never relocate an already-existing log directory.
MULTINODE_HACKS = True


class SetupCallback(Callback):
    """Prepare the run's output directories and persist its configuration.

    At fit start the callback creates the log, checkpoint and config
    directories, prints both configs and writes them as YAML files into
    ``cfgdir``.  When an exception is reported it saves a checkpoint into
    ``ckptdir`` unless ``debug`` is set.

    Args:
        resume: When truthy, an existing log directory is never moved aside
            by the non-zero-rank branch of ``on_fit_start``.
        now: Value formatted into the names of the saved config files
            (``"<now>-project.yaml"`` and ``"<now>-lightning.yaml"``).
        logdir: Log directory of the run.
        ckptdir: Directory that receives checkpoints.
        cfgdir: Directory that receives the saved config files.
        config: Project config; printed and saved through ``OmegaConf``.
        lightning_config: Lightning config; printed, checked for a
            ``"callbacks"`` entry and saved under the ``"lightning"`` key.
        debug: When truthy, no checkpoint is written from ``on_exception``.
        ckpt_name: File name of the checkpoint written from
            ``on_exception``; ``"last.ckpt"`` is used when it is ``None``.
    """

    def __init__(
        self,
        resume,
        now,
        logdir,
        ckptdir,
        cfgdir,
        config,
        lightning_config,
        debug,
        ckpt_name=None,
    ):
        super().__init__()
        # Run identity and behavior switches.
        self.resume = resume
        self.now = now
        self.debug = debug
        self.ckpt_name = ckpt_name
        # Output locations.
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        # Configs that get printed and saved at fit start.
        self.config = config
        self.lightning_config = lightning_config

    @rank_zero_only
    def on_exception(self, trainer: pl.Trainer, pl_module, exception):
        """Report the exception and, outside debug mode, save a checkpoint.

        The checkpoint goes to ``ckptdir`` under ``ckpt_name`` (or
        ``"last.ckpt"`` when no name was given) and is only written when
        ``trainer.global_rank`` is 0.
        """
        print("Exception occurred: {}".format(exception))

        if self.debug or trainer.global_rank != 0:
            return

        print("Summoning checkpoint.")
        file_name = "last.ckpt" if self.ckpt_name is None else self.ckpt_name
        trainer.save_checkpoint(os.path.join(self.ckptdir, file_name))

    @rank_zero_only
    def on_fit_start(self, trainer, pl_module):
        """Create the output directories and save both configs (rank 0).

        On any other rank, a log directory that already exists is moved to
        ``<parent>/child_runs/<name>`` unless ``MULTINODE_HACKS`` or
        ``resume`` is set.
        """
        if trainer.global_rank != 0:
            # NOTE(review): with @rank_zero_only applied to this hook, this
            # branch looks unreachable on non-zero ranks -- confirm whether
            # the decorator or the branch is the intended behavior.
            if MULTINODE_HACKS or self.resume or not os.path.exists(self.logdir):
                return
            # The original comment attributes the pre-existing directory to
            # the ModelCheckpoint callback; it is moved aside, not deleted.
            parent, run_name = os.path.split(self.logdir)
            target = os.path.join(parent, "child_runs", run_name)
            os.makedirs(os.path.dirname(target), exist_ok=True)
            try:
                os.rename(self.logdir, target)
            except FileNotFoundError:
                # The directory vanished between the check and the rename.
                pass
            return

        # Create logdirs and save configs.
        for directory in (self.logdir, self.ckptdir, self.cfgdir):
            os.makedirs(directory, exist_ok=True)

        has_callbacks = "callbacks" in self.lightning_config
        if (
            has_callbacks
            and "metrics_over_trainsteps_checkpoint"
            in self.lightning_config["callbacks"]
        ):
            trainstep_dir = os.path.join(self.ckptdir, "trainstep_checkpoints")
            os.makedirs(trainstep_dir, exist_ok=True)

        project_yaml = os.path.join(
            self.cfgdir, "{}-project.yaml".format(self.now)
        )
        lightning_yaml = os.path.join(
            self.cfgdir, "{}-lightning.yaml".format(self.now)
        )

        print("Project config")
        print(OmegaConf.to_yaml(self.config))
        if MULTINODE_HACKS:
            import time

            time.sleep(5)
        OmegaConf.save(self.config, project_yaml)

        print("Lightning config")
        print(OmegaConf.to_yaml(self.lightning_config))
        wrapped = OmegaConf.create({"lightning": self.lightning_config})
        OmegaConf.save(wrapped, lightning_yaml)