Spaces:
Runtime error
Runtime error
| import argparse | |
| import glob | |
| from pathlib import Path | |
| from omegaconf import OmegaConf | |
| import pytorch_lightning as pl | |
| from pytorch_lightning.callbacks import ModelCheckpoint | |
| from pytorch_lightning.loggers import TensorBoardLogger | |
| from lightning import FontLightningModule | |
| from utils import save_files | |
def load_configuration(path_config):
    """Load the hyperparameter (and optional Lightning) config for training.

    ``path_config`` points to a setting YAML whose ``config`` node lists the
    paths of the other config files. The ``dataset``, ``model`` and
    ``logging`` files are merged in that order with ``OmegaConf.merge``
    (later files win on conflicting keys) to form ``hp``.

    If ``config.lightning`` is set, that file is loaded as well; when it nests
    everything under a top-level ``pl_config`` key, only that inner node is
    returned.

    Args:
        path_config: Path to the setting YAML file.

    Returns:
        ``(hp, pl_config)`` when ``config.lightning`` is set, otherwise
        ``hp`` alone.
    """
    setting = OmegaConf.load(path_config)
    paths = setting.config

    # Merge hyperparameters: dataset <- model <- logging.
    hp = OmegaConf.load(paths.dataset)
    hp = OmegaConf.merge(hp, OmegaConf.load(paths.model))
    hp = OmegaConf.merge(hp, OmegaConf.load(paths.logging))

    # Use .get() rather than hasattr(): on OmegaConf 2.0 attribute access to an
    # absent key returns None instead of raising, so hasattr() is always True
    # there and OmegaConf.load(None) would fail. A null entry is treated as
    # "not configured" as well.
    lightning_path = paths.get('lightning')
    if lightning_path is None:
        # Without a lightning setting only the hyperparameters are returned.
        return hp

    pl_config = OmegaConf.load(lightning_path)
    # Some lightning files wrap their content in a top-level `pl_config` key.
    if 'pl_config' in pl_config:
        pl_config = pl_config.pl_config
    return hp, pl_config
def parse_args(argv=None):
    """Parse the command-line options for training.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``config``, ``gpus`` and
        ``resume_checkpoint_path`` attributes. ``gpus`` is ``None`` when the
        option is not given.
    """
    parser = argparse.ArgumentParser(description='Code to train font style transfer')
    parser.add_argument("--config", type=str, default="./config/setting.yaml",
                        help="Config file for training")
    # None means "not given"; the training entry point maps that to using
    # every available GPU, as the help text promises.
    parser.add_argument('-g', '--gpus', type=str, default=None,
                        help="GPUs to use, as accepted by pl.Trainer(gpus=...) "
                             "(e.g. '0,1,2,3'). Will use all if not given.")
    parser.add_argument('-p', '--resume_checkpoint_path', type=str, default=None,
                        help="path of checkpoint for resuming")
    return parser.parse_args(argv)
def main():
    """Train the font style-transfer model using command-line arguments.

    Parses CLI args, loads the merged hyperparameters and the Lightning
    config, builds ``FontLightningModule``, passes the files matched by the
    ``hp.logging.savefiles`` glob patterns to ``save_files``, then fits the
    model with a TensorBoard logger and a ``ModelCheckpoint`` callback.

    Raises:
        ValueError: if the setting file does not reference a lightning config
            (``config.lightning``), which training needs for the checkpoint
            and trainer options.
    """
    args = parse_args()
    loaded = load_configuration(args.config)
    # load_configuration returns a bare hp config (not a tuple) when no
    # lightning config is referenced; fail with a clear message instead of an
    # obscure unpacking error.
    if not isinstance(loaded, tuple):
        raise ValueError(
            f"{args.config} must reference a lightning config under 'config.lightning'"
        )
    hp, pl_config = loaded
    logging_dir = Path(hp.logging.log_dir)

    # call lightning module
    font_pl = FontLightningModule(hp)

    # Repoint hp.logging.log_dir at the tensorboard sub-directory.
    tensorboard_dir = logging_dir / 'tensorboard'
    hp.logging['log_dir'] = tensorboard_dir
    # Expand the savefiles glob patterns into concrete file paths.
    savefiles = [path for pattern in hp.logging.savefiles for path in glob.glob(pattern)]
    # parents=True: the log root may not exist yet on a fresh run.
    tensorboard_dir.mkdir(parents=True, exist_ok=True)
    # presumably copies the matched files into logging_dir — confirm in utils.save_files
    save_files(str(logging_dir), savefiles)

    # NOTE(review): the logger writes under logging_dir/<seed>, not under the
    # tensorboard_dir created above — confirm this is intended.
    logger = TensorBoardLogger(str(logging_dir), name=str(hp.logging.seed))

    # Checkpoints go to logging_dir/checkpoint/<seed>; parents=True because
    # the intermediate 'checkpoint' directory does not exist on a first run.
    weights_save_path = logging_dir / 'checkpoint' / str(hp.logging.seed)
    weights_save_path.mkdir(parents=True, exist_ok=True)
    checkpoint_callback = ModelCheckpoint(
        dirpath=str(weights_save_path),
        **OmegaConf.to_container(pl_config.checkpoint.callback, resolve=True)
    )

    # Pass plain Python containers to Lightning rather than OmegaConf nodes.
    trainer_kwargs = OmegaConf.to_container(pl_config.trainer, resolve=True)
    # An explicit --gpus wins over the config; otherwise keep the config's
    # value, falling back to -1 (all available GPUs). Setting the key here
    # (instead of a separate gpus= kwarg) avoids a duplicate-keyword error
    # when the config also defines it.
    if args.gpus is not None:
        trainer_kwargs['gpus'] = args.gpus
    else:
        trainer_kwargs.setdefault('gpus', -1)
    # Honour --resume_checkpoint_path. This uses the PL 1.x Trainer argument,
    # matching the 1.x-era `gpus` argument already used here.
    if args.resume_checkpoint_path is not None:
        trainer_kwargs['resume_from_checkpoint'] = args.resume_checkpoint_path

    trainer = pl.Trainer(
        logger=logger,
        callbacks=[checkpoint_callback],
        **trainer_kwargs
    )
    # let's train
    trainer.fit(font_pl)


if __name__ == "__main__":
    main()