Spaces:
Running
Running
| # Copyright (c) 2023 Amphion. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import argparse | |
| from argparse import ArgumentParser | |
| import os | |
| from models.tta.ldm.audioldm_inference import AudioLDMInference | |
| from utils.util import save_config, load_model_config, load_config | |
| import numpy as np | |
| import torch | |
def build_inference(args, cfg):
    """Instantiate the inference class selected by ``cfg.model_type``.

    Args:
        args: Parsed command-line arguments, forwarded unchanged to the
            inference class constructor.
        cfg: Configuration object; its ``model_type`` attribute selects
            which inference class is built.

    Returns:
        An instance of the inference class registered for
        ``cfg.model_type``, constructed as ``inference_class(args, cfg)``.

    Raises:
        KeyError: If ``cfg.model_type`` is not one of the supported types.
    """
    supported_inference = {
        "AudioLDM": AudioLDMInference,
    }
    model_type = cfg.model_type
    try:
        inference_class = supported_inference[model_type]
    except KeyError:
        # Keep the exception type (callers may catch KeyError) but explain
        # what went wrong instead of surfacing only the bare key.
        raise KeyError(
            f"Unsupported model_type {model_type!r}; "
            f"supported types: {sorted(supported_inference)}"
        ) from None
    return inference_class(args, cfg)
def build_parser():
    """Build the command-line argument parser for this inference script.

    Returns:
        argparse.ArgumentParser: A parser accepting ``--config`` (required),
        ``--text``, ``--checkpoint_path``, ``--vocoder_path``,
        ``--vocoder_config_path``, ``--output_dir``, ``--num_steps``,
        ``--guidance_scale`` and ``--local_rank``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="JSON/YAML file for configurations.",
    )
    parser.add_argument(
        "--text",
        help="Text to be synthesized",
        type=str,
        default="Text to be synthesized.",
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        help="Checkpoint path of the model",
    )
    parser.add_argument(
        "--vocoder_path", type=str, help="Checkpoint path of the vocoder"
    )
    parser.add_argument(
        "--vocoder_config_path", type=str, help="Config path of the vocoder"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Output dir for saving generated results",
    )
    parser.add_argument(
        "--num_steps",
        type=int,
        default=200,
        help="The total number of denoising steps",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=4.0,
        help="The scale of classifier free guidance",
    )
    # NOTE: main() replaces the parsed value with a torch.device before the
    # arguments are handed to the inference class.
    parser.add_argument("--local_rank", default=-1, type=int)
    return parser
def main():
    """Command-line entry point: parse arguments, load config, run inference.

    Reads the CLI arguments, loads the configuration file named by
    ``--config``, stores the selected torch device on ``args.local_rank``,
    prints the resulting arguments, then builds the inference object and
    calls its ``inference()`` method.
    """
    cli_args = build_parser().parse_args()
    cfg = load_config(cli_args.config)

    # The integer ``local_rank`` from the CLI is overwritten with a
    # torch.device: CUDA when available, CPU otherwise.
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    cli_args.local_rank = torch.device(device_name)
    print("args: ", cli_args)

    # Build the inference object for the configured model type and run it.
    build_inference(cli_args, cfg).inference()


if __name__ == "__main__":
    main()