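# NOTE(review): the three lines below look like page residue from the hosting
# site rather than source; kept as comments so the module parses.
# YuanTang96's picture
# 1
# b30c1d8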
import os
import hydra
import omegaconf
import pytorch_lightning as pl
import torch
from pytorch_lightning.loggers import TensorBoardLogger
# Module-level side effect: switch cuDNN on and turn on its benchmark mode
# before any model is built.
# NOTE(review): per the PyTorch docs, benchmark mode lets cuDNN time several
# convolution algorithms and keep the fastest, which may make runs
# non-deterministic — confirm that is acceptable for this training script.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
def hydra_params_to_dotdict(hparams):
    """Flatten a nested config into a single-level dict with dotted keys.

    Nested keys are joined with ``"."`` so ``{"a": {"b": 1}}`` becomes
    ``{"a.b": 1}``. Only ``str``, ``int``, ``float`` and ``bool`` leaves are
    kept; every other value (lists, ``None``, ...) is silently dropped.

    Args:
        hparams: An ``omegaconf.DictConfig`` or any plain mapping, possibly
            nested. Keys of nested sections must be strings, because they are
            concatenated with ``"."``.

    Returns:
        A new ``dict`` mapping dotted key paths to scalar leaf values, in the
        iteration order of the input.
    """
    # Local import keeps this block self-contained without touching the
    # file-level import list.
    from collections.abc import Mapping

    def _to_dot_dict(cfg):
        res = {}
        for k, v in cfg.items():
            # Scalars are tested first; a config section is never one of
            # these types, so the order does not change the result.
            if isinstance(v, (str, int, float, bool)):
                res[k] = v
            # Recurse into plain mappings as well as DictConfig so nested
            # ordinary dicts are flattened instead of being dropped.
            elif isinstance(v, Mapping) or isinstance(v, omegaconf.DictConfig):
                for subk, subv in _to_dot_dict(v).items():
                    res[k + "." + subk] = subv
        return res

    return _to_dot_dict(hparams)
@hydra.main("config/config.yaml")
def main(cfg):
    """Build the model named in ``cfg`` and train it with PyTorch Lightning.

    Reads ``cfg.task_model`` (including its ``name``), ``cfg.gpus``,
    ``cfg.epochs`` and ``cfg.distrib_backend``. The model is instantiated
    through Hydra and receives the flattened config as its argument.

    Args:
        cfg: The config object passed in by the ``hydra.main`` decorator.
    """
    # Model: instantiate the configured target with the dotted-key hparams.
    model_cfg = cfg.task_model
    flat_hparams = hydra_params_to_dotdict(cfg)
    model = hydra.utils.instantiate(model_cfg, flat_hparams)

    # Callbacks: early stopping with a patience of 5, plus checkpointing that
    # monitors "val_acc" (mode "max", save_top_k=2).
    stopper = pl.callbacks.EarlyStopping(patience=5)
    ckpt_pattern = os.path.join(
        cfg.task_model.name, "{epoch}-{val_loss:.2f}-{val_acc:.3f}"
    )
    checkpointer = pl.callbacks.ModelCheckpoint(
        monitor="val_acc",
        mode="max",
        save_top_k=2,
        filepath=ckpt_pattern,
        verbose=True,
    )

    # Trainer: device list, epoch budget and backend all come from the config.
    trainer_kwargs = {
        "gpus": list(cfg.gpus),
        "max_epochs": cfg.epochs,
        "early_stop_callback": stopper,
        "checkpoint_callback": checkpointer,
        "distributed_backend": cfg.distrib_backend,
    }
    trainer = pl.Trainer(**trainer_kwargs)
    trainer.fit(model)
if __name__ == "__main__":
    # Script entry point: main() is called without arguments because the
    # hydra.main decorator on it supplies the config.
    main()