Spaces:
Build error
Build error
| import json | |
| import os | |
| import re | |
| from typing import Dict | |
| import fsspec | |
| import yaml | |
| from coqpit import Coqpit | |
| from TTS.config.shared_configs import * | |
| from TTS.utils.generic_utils import find_module | |
def read_json_with_comments(json_path):
    """Load a JSON file that may contain ``//`` line comments.

    Kept for backward compatibility with config files that are not strict
    JSON. Backslash line continuations are joined and ``//`` comments are
    removed before the text is handed to ``json.loads``.

    Args:
        json_path (str): Location of the file, passed to ``fsspec.open``.

    Raises:
        json.JSONDecodeError: The cleaned text is still not valid JSON.

    Returns:
        The object decoded from the cleaned JSON text.
    """
    with fsspec.open(json_path, "r", encoding="utf-8") as f:
        input_str = f.read()
    # Join lines that end with a backslash continuation.
    input_str = re.sub(r"\\\n", "", input_str)
    # Strip `//` comments. String literals are matched first and put back
    # unchanged, so values such as "https://host/path" are not truncated.
    # The newline itself is never consumed, which also covers a comment on
    # the last line of a file that has no trailing newline.
    input_str = re.sub(
        r'("(?:[^"\\\n]|\\.)*")|//[^\n]*',
        lambda match: match.group(1) or "",
        input_str,
    )
    return json.loads(input_str)
def register_config(model_name: str) -> Coqpit:
    """Look up the config class that belongs to a model name.

    Args:
        model_name (str): Model name.

    Raises:
        ModuleNotFoundError: No matching config for the model name.

    Returns:
        Coqpit: config class.
    """
    module_name = model_name + "_config"
    found = None

    # TODO: fix this
    if model_name == "xtts":
        from TTS.tts.configs.xtts_config import XttsConfig

        found = XttsConfig

    # Every package is probed, with no early exit: a hit in a later package
    # replaces a hit from an earlier one (and the xtts special case above).
    for package in ("TTS.tts.configs", "TTS.vocoder.configs", "TTS.encoder.configs", "TTS.vc.configs"):
        try:
            found = find_module(package, module_name)
        except ModuleNotFoundError:
            continue

    if found is not None:
        return found
    raise ModuleNotFoundError(f" [!] Config for {model_name} cannot be found.")
def _process_model_name(config_dict: Dict) -> str:
    """Return the model name with the old `vocoder` suffixes removed.

    Args:
        config_dict (Dict): A dictionary including the config fields. The name
            is read from "model" when that key is present, otherwise from
            "generator_model".

    Returns:
        str: Model name without "_generator" / "_discriminator".
    """
    if "model" in config_dict:
        raw_name = config_dict["model"]
    else:
        raw_name = config_dict["generator_model"]

    cleaned = raw_name
    # Order matters only in theory; it is kept the same as the legacy chain.
    for suffix in ("_generator", "_discriminator"):
        cleaned = cleaned.replace(suffix, "")
    return cleaned
def load_config(config_path: str) -> Coqpit:
    """Build a TTS config object from a `json` or `yaml` file.

    The file is parsed into a `dict`, the model name stored in it selects the
    matching Config class, and a new instance of that class is filled from
    the dict.

    Args:
        config_path (str): path to the config file.

    Raises:
        TypeError: given config file has an unknown type.

    Returns:
        Coqpit: TTS config object.
    """
    ext = os.path.splitext(config_path)[1]
    if ext not in (".yml", ".yaml", ".json"):
        raise TypeError(f" [!] Unknown config file type {ext}")

    if ext == ".json":
        try:
            with fsspec.open(config_path, "r", encoding="utf-8") as handle:
                loaded = json.load(handle)
        except json.JSONDecodeError:
            # backwards compat.: retry with the reader that tolerates comments
            loaded = read_json_with_comments(config_path)
    else:
        with fsspec.open(config_path, "r", encoding="utf-8") as handle:
            loaded = yaml.safe_load(handle)

    config_dict = {}
    config_dict.update(loaded)

    resolved_name = _process_model_name(config_dict).lower()
    config = register_config(resolved_name)()
    config.from_dict(config_dict)
    return config
def check_config_and_model_args(config, arg_name, value):
    """Tell whether `arg_name` is set to `value` in the config.

    `config.model_args` is consulted first when it exists and contains the
    argument; otherwise the attribute on `config` itself is used.
    Returns False when the argument is found in neither place.

    This is to patch up the compatibility between models with and without `model_args`.
    TODO: Remove this in the future with a unified approach.
    """
    if hasattr(config, "model_args") and arg_name in config.model_args:
        holder = config.model_args
    elif hasattr(config, arg_name):
        holder = config
    else:
        return False
    return holder[arg_name] == value
def get_from_config_or_model_args(config, arg_name):
    """Return `arg_name` from `config.model_args` when present there, else from `config`."""
    in_model_args = hasattr(config, "model_args") and arg_name in config.model_args
    return config.model_args[arg_name] if in_model_args else config[arg_name]
def get_from_config_or_model_args_with_default(config, arg_name, def_val):
    """Return `arg_name` from `config.model_args` or `config`, else `def_val`.

    `config.model_args` is checked first when it exists; the attribute on
    `config` is the second choice; `def_val` is returned when neither has it.
    """
    in_model_args = hasattr(config, "model_args") and arg_name in config.model_args
    if in_model_args:
        return config.model_args[arg_name]
    return config[arg_name] if hasattr(config, arg_name) else def_val