import os

import hydra
import omegaconf
import pytorch_lightning as pl
import torch
from pytorch_lightning.loggers import TensorBoardLogger

# Enable cuDNN and let it benchmark convolution algorithms; this speeds up
# training when input shapes stay fixed between batches.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True


def hydra_params_to_dotdict(hparams):
    """Flatten a nested Hydra/OmegaConf config into a flat dict with dotted keys,
    e.g. {"task_model": {"name": "cls-ssg"}} -> {"task_model.name": "cls-ssg"}.
    Only primitive values (str, int, float, bool) are kept."""

    def _to_dot_dict(cfg):
        res = {}
        for k, v in cfg.items():
            if isinstance(v, omegaconf.DictConfig):
                # Recurse into nested configs and prefix their keys.
                res.update(
                    {k + "." + subk: subv for subk, subv in _to_dot_dict(v).items()}
                )
            elif isinstance(v, (str, int, float, bool)):
                res[k] = v
        return res

    return _to_dot_dict(hparams)


# Note: the positional config path and the Trainer/ModelCheckpoint arguments
# below (filepath, early_stop_callback, checkpoint_callback,
# distributed_backend) match older Hydra (0.11) and PyTorch Lightning
# (pre-1.0) APIs; later releases renamed or removed them.
@hydra.main("config/config.yaml")
def main(cfg):
    # Instantiate the task model named in the config, passing it the
    # flattened hyperparameter dict.
    model = hydra.utils.instantiate(cfg.task_model, hydra_params_to_dotdict(cfg))

    # Stop early if the monitored validation metric fails to improve
    # for 5 consecutive epochs.
    early_stop_callback = pl.callbacks.EarlyStopping(patience=5)
    # Keep the 2 best checkpoints by validation accuracy, named by epoch,
    # validation loss, and validation accuracy.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        monitor="val_acc",
        mode="max",
        save_top_k=2,
        filepath=os.path.join(
            cfg.task_model.name, "{epoch}-{val_loss:.2f}-{val_acc:.3f}"
        ),
        verbose=True,
    )

    trainer = pl.Trainer(
        gpus=list(cfg.gpus),
        max_epochs=cfg.epochs,
        early_stop_callback=early_stop_callback,
        checkpoint_callback=checkpoint_callback,
        distributed_backend=cfg.distrib_backend,
    )

    trainer.fit(model)


if __name__ == "__main__":
    main()
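
A minimal sketch of the config shape this script expects, based only on the fields it actually reads (gpus, epochs, distrib_backend, task_model.name); all values here are hypothetical. It also shows what hydra_params_to_dotdict produces: nested keys are joined with dots, and non-primitive values such as the gpus list are dropped.

from omegaconf import OmegaConf

from train import hydra_params_to_dotdict  # assuming the script above is saved as train.py

# Hypothetical values; only the key names come from the script itself.
cfg = OmegaConf.create(
    {
        "gpus": [0],
        "epochs": 200,
        "distrib_backend": "dp",
        "task_model": {"name": "cls-ssg"},
    }
)

print(hydra_params_to_dotdict(cfg))
# {'epochs': 200, 'distrib_backend': 'dp', 'task_model.name': 'cls-ssg'}

Because the entry point is wrapped in @hydra.main, any of these keys can also be overridden from the command line with Hydra's key=value syntax, for example (hypothetical values):

python train.py epochs=50 distrib_backend=dp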