Spaces:
Running
on
Zero
Running
on
Zero
from pytorch_lightning.callbacks import Callback | |
import pytorch_lightning as pl | |
import os | |
from omegaconf import OmegaConf | |
from pytorch_lightning.utilities import rank_zero_only | |
# Module-level switch read by SetupCallback.on_fit_start. While it is true:
#   * global rank 0 sleeps for 5 seconds before writing the project config, and
#   * every other rank skips moving its log directory under "child_runs".
# NOTE(review): the reason for the delay is not documented in this file --
# presumably a multi-node filesystem/startup workaround; confirm before changing.
MULTINODE_HACKS = True
class SetupCallback(Callback):
    """Callback that sets up run directories, dumps configs and rescues state.

    On fit start, global rank 0 creates the log/checkpoint/config directories
    and writes the project and lightning configs as YAML; other ranks may move
    an already existing log directory under ``child_runs``. When an exception
    reaches ``on_exception``, global rank 0 writes a checkpoint unless
    ``debug`` is truthy.

    Args:
        resume: When truthy, non-zero ranks never move the log directory.
        now: Prefix used for the saved config file names.
        logdir: Run log directory.
        ckptdir: Directory that receives checkpoints.
        cfgdir: Directory that receives the saved YAML configs.
        config: Project config; printed and saved as ``<now>-project.yaml``.
        lightning_config: Lightning config; printed and saved under the
            ``lightning`` key as ``<now>-lightning.yaml``.
        debug: When truthy, no checkpoint is written on exception.
        ckpt_name: File name of the exception checkpoint; ``last.ckpt`` is
            used when this is ``None``.
    """

    def __init__(
        self,
        resume,
        now,
        logdir,
        ckptdir,
        cfgdir,
        config,
        lightning_config,
        debug,
        ckpt_name=None,
    ):
        super().__init__()
        # Run identity and mode flags.
        self.now = now
        self.resume = resume
        self.debug = debug
        # Output locations.
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        self.ckpt_name = ckpt_name
        # Configs that get dumped when fitting starts.
        self.config = config
        self.lightning_config = lightning_config

    def on_exception(self, trainer: pl.Trainer, pl_module, exception):
        """Report the exception and, on global rank 0, save a checkpoint."""
        print("Exception occurred: {}".format(exception))
        # Debug runs and non-zero ranks only report; they never write a file.
        if self.debug or trainer.global_rank != 0:
            return

        print("Summoning checkpoint.")
        filename = "last.ckpt" if self.ckpt_name is None else self.ckpt_name
        trainer.save_checkpoint(os.path.join(self.ckptdir, filename))

    def on_fit_start(self, trainer, pl_module):
        """Create directories and save configs (rank 0) or move the logdir."""
        if trainer.global_rank != 0:
            # Original note: the ModelCheckpoint callback created the log
            # directory on this rank, so it is moved out of the way.
            if MULTINODE_HACKS or self.resume or not os.path.exists(self.logdir):
                return
            parent, run_name = os.path.split(self.logdir)
            target = os.path.join(parent, "child_runs", run_name)
            os.makedirs(os.path.split(target)[0], exist_ok=True)
            try:
                os.rename(self.logdir, target)
            except FileNotFoundError:
                # The directory vanished between the existence check and the
                # rename; nothing is left to move.
                pass
            return

        # Global rank 0: create the log dirs and save the configs.
        for directory in (self.logdir, self.ckptdir, self.cfgdir):
            os.makedirs(directory, exist_ok=True)

        if (
            "callbacks" in self.lightning_config
            and "metrics_over_trainsteps_checkpoint"
            in self.lightning_config["callbacks"]
        ):
            trainstep_dir = os.path.join(self.ckptdir, "trainstep_checkpoints")
            os.makedirs(trainstep_dir, exist_ok=True)

        print("Project config")
        print(OmegaConf.to_yaml(self.config))
        if MULTINODE_HACKS:
            import time

            # Fixed pause before the first config file is written.
            time.sleep(5)
        project_yaml = os.path.join(self.cfgdir, "{}-project.yaml".format(self.now))
        OmegaConf.save(self.config, project_yaml)

        print("Lightning config")
        print(OmegaConf.to_yaml(self.lightning_config))
        wrapped = OmegaConf.create({"lightning": self.lightning_config})
        lightning_yaml = os.path.join(
            self.cfgdir, "{}-lightning.yaml".format(self.now)
        )
        OmegaConf.save(wrapped, lightning_yaml)