|
import os |
|
|
|
import hydra |
|
import omegaconf |
|
import pytorch_lightning as pl |
|
import torch |
|
from pytorch_lightning.loggers import TensorBoardLogger |
|
|
|
# Process-wide cuDNN settings, applied as soon as this module is imported:
# enable cuDNN and turn on its benchmark mode (enabled is assigned first).
# NOTE(review): benchmark mode lets cuDNN autotune algorithms per input shape,
# which can make runs non-reproducible — confirm that trade-off is intended.
torch.backends.cudnn.enabled = torch.backends.cudnn.benchmark = True
|
|
|
|
|
def hydra_params_to_dotdict(hparams):
    """Flatten a nested config into a single-level dict with dotted keys.

    Nested ``omegaconf.DictConfig`` values are expanded recursively, joining
    each parent key to its child keys with ``"."``; e.g. a config
    ``{optim: {lr: 0.001}}`` becomes ``{"optim.lr": 0.001}``. Only leaf
    values that are ``str``, ``int``, ``float`` or ``bool`` are kept; any
    other value (lists, ``None``, ...) is silently dropped.

    Args:
        hparams: Mapping-like config exposing ``.items()``; in this file it
            is the root config that Hydra passes to ``main``.

    Returns:
        dict: Flat mapping from dotted key to scalar value, in the config's
        depth-first key order. If two entries flatten to the same dotted
        key, the later one wins.
    """
    scalar_types = (str, int, float, bool)

    def _walk(node):
        # Yield (dotted_key, value) pairs depth-first, preserving key order.
        for key, value in node.items():
            if isinstance(value, omegaconf.DictConfig):
                for child_key, child_value in _walk(value):
                    # Plain concatenation (not an f-string) so non-str keys
                    # still raise, exactly as before.
                    yield key + "." + child_key, child_value
            elif isinstance(value, scalar_types):
                yield key, value

    return dict(_walk(hparams))
|
|
|
|
|
@hydra.main("config/config.yaml")
def main(cfg):
    """Build the task model, callbacks and trainer from ``cfg``, then train.

    ``cfg`` is the config Hydra builds from ``config/config.yaml``. The model
    is created with ``hydra.utils.instantiate`` from ``cfg.task_model``, with
    the flattened config (``hydra_params_to_dotdict(cfg)``) passed as an extra
    positional argument (presumably forwarded to the model constructor as its
    hparams). The trainer reads ``cfg.gpus``, ``cfg.epochs`` and
    ``cfg.distrib_backend``.
    """
    flat_hparams = hydra_params_to_dotdict(cfg)
    model = hydra.utils.instantiate(cfg.task_model, flat_hparams)

    # Stop after 5 checks without improvement. No `monitor` is given, so the
    # callback's default metric is used.
    # NOTE(review): checkpointing below selects on val_acc — confirm the two
    # callbacks are meant to watch different metrics.
    early_stopping = pl.callbacks.EarlyStopping(patience=5)

    # Keep the 2 best checkpoints by highest val_acc, saved under a directory
    # named after the task model; the {...} placeholders are format fields
    # filled in by ModelCheckpoint.
    checkpoint_template = os.path.join(
        cfg.task_model.name, "{epoch}-{val_loss:.2f}-{val_acc:.3f}"
    )
    checkpointing = pl.callbacks.ModelCheckpoint(
        filepath=checkpoint_template,
        monitor="val_acc",
        mode="max",
        save_top_k=2,
        verbose=True,
    )

    trainer_kwargs = dict(
        # Converted to a plain list; presumably cfg.gpus is a ListConfig of
        # device ids (a scalar or null value would fail here) — confirm.
        gpus=list(cfg.gpus),
        max_epochs=cfg.epochs,
        early_stop_callback=early_stopping,
        checkpoint_callback=checkpointing,
        distributed_backend=cfg.distrib_backend,
    )
    trainer = pl.Trainer(**trainer_kwargs)

    trainer.fit(model)


if __name__ == "__main__":
    # Hydra's decorator parses the command line and supplies `cfg`.
    main()
|
|