Spaces:
Sleeping
Sleeping
| # Copyright (c) 2023 Amphion. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import argparse | |
| import os | |
| import torch | |
| from models.tta.autoencoder.autoencoder_trainer import AutoencoderKLTrainer | |
| from models.tta.ldm.audioldm_trainer import AudioLDMTrainer | |
| from utils.util import load_config | |
def build_trainer(args, cfg):
    """Instantiate the trainer class registered for ``cfg.model_type``.

    Args:
        args: Parsed command-line arguments; passed through to the trainer.
        cfg: Loaded configuration; ``cfg.model_type`` picks the trainer class
            and the whole object is passed through to the trainer.

    Returns:
        The newly constructed trainer instance.

    Raises:
        KeyError: If ``cfg.model_type`` is not one of the registered names.
    """
    # Registry of model-type name -> trainer class.
    trainers_by_model_type = {
        "AutoencoderKL": AutoencoderKLTrainer,
        "AudioLDM": AudioLDMTrainer,
    }
    return trainers_by_model_type[cfg.model_type](args, cfg)
def _parse_args():
    """Define the training script's command-line options and parse ``sys.argv``."""
    parser = argparse.ArgumentParser()
    # --config and --exp_name are mandatory, so they carry no default value:
    # argparse rejects the command line when a required option is missing, which
    # means a default on such an option could never be used.
    parser.add_argument(
        "--config",
        help="json files for configurations.",
        required=True,
    )
    parser.add_argument(
        "--num_workers", type=int, default=6, help="Number of dataloader workers."
    )
    parser.add_argument(
        "--exp_name",
        type=str,
        help="A specific name to note the experiment",
        required=True,
    )
    parser.add_argument(
        "--resume",
        type=str,
        default=None,
        help="The model name to restore",
    )
    parser.add_argument(
        "--log_level", default="info", help="logging level (info, debug, warning)"
    )
    parser.add_argument("--stdout_interval", default=5, type=int)
    parser.add_argument("--local_rank", default=-1, type=int)
    return parser.parse_args()


def main():
    """Parse the command line, load the configuration and run training.

    Steps, in order:
      1. Parse the CLI options and load the config file given by ``--config``.
      2. Record the experiment name on the config and create the log directory
         ``<cfg.log_dir>/<exp_name>`` (no error if it already exists).
      3. Build the trainer selected by ``cfg.model_type``.
      4. Call ``trainer.restore()`` when ``--resume`` has a non-empty value.
      5. Call ``trainer.train()``.
    """
    args = _parse_args()

    cfg = load_config(args.config)
    cfg.exp_name = args.exp_name

    # Model saving dir
    args.log_dir = os.path.join(cfg.log_dir, args.exp_name)
    os.makedirs(args.log_dir, exist_ok=True)

    if not cfg.train.ddp:
        # Without DDP the integer ``local_rank`` is replaced by a CUDA device
        # object. NOTE(review): presumably the trainers use ``args.local_rank``
        # as the target device -- confirm in the trainer classes.
        args.local_rank = torch.device("cuda")

    # Build trainer
    trainer = build_trainer(args, cfg)

    # Restore models
    if args.resume:
        trainer.restore()
    trainer.train()


if __name__ == "__main__":
    main()