dia-tts-server / config.py
Michael Hu
initial check in of the dia tts server
ac5de5b
# config.py
# Configuration management for Dia TTS server
import os
import logging
from dotenv import load_dotenv, find_dotenv, set_key
from typing import Dict, Any, Optional
# Configure logging
logger = logging.getLogger(__name__)
# Default configuration values (used if not found in .env or environment)
DEFAULT_CONFIG = {
# Server Settings
"HOST": "0.0.0.0",
"PORT": "8003",
# Model Source Settings
"DIA_MODEL_REPO_ID": "ttj/dia-1.6b-safetensors", # Default to safetensors repo
"DIA_MODEL_CONFIG_FILENAME": "config.json", # Standard config filename
"DIA_MODEL_WEIGHTS_FILENAME": "dia-v0_1_bf16.safetensors", # Default to BF16 weights
# Path Settings
"DIA_MODEL_CACHE_PATH": "./model_cache",
"REFERENCE_AUDIO_PATH": "./reference_audio",
"OUTPUT_PATH": "./outputs",
# Default Generation Parameters (can be overridden by user in UI/API)
# These are saved to .env via the UI's "Save Generation Defaults" button
"GEN_DEFAULT_SPEED_FACTOR": "0.90", # Default speed slightly slower
"GEN_DEFAULT_CFG_SCALE": "3.0",
"GEN_DEFAULT_TEMPERATURE": "1.3",
"GEN_DEFAULT_TOP_P": "0.95",
"GEN_DEFAULT_CFG_FILTER_TOP_K": "35",
}
class ConfigManager:
"""Manages configuration for the TTS server with .env file support."""
def __init__(self):
"""Initialize the configuration manager."""
self.config = {}
self.env_file = find_dotenv()
if not self.env_file:
self.env_file = os.path.join(os.getcwd(), ".env")
logger.info(
f"No .env file found, creating one with defaults at {self.env_file}"
)
self._create_default_env_file()
else:
logger.info(f"Loading configuration from: {self.env_file}")
self.reload()
def _create_default_env_file(self):
"""Create a default .env file with default values."""
try:
with open(self.env_file, "w") as f:
for key, value in DEFAULT_CONFIG.items():
f.write(f"{key}={value}\n")
logger.info("Created default .env file")
except Exception as e:
logger.error(f"Failed to create default .env file: {e}")
def reload(self):
"""Reload configuration from .env file and environment variables."""
load_dotenv(self.env_file, override=True)
loaded_config = {}
for key, default_value in DEFAULT_CONFIG.items():
loaded_config[key] = os.environ.get(key, default_value)
self.config = loaded_config
logger.info("Configuration loaded/reloaded.")
logger.debug(f"Current config: {self.config}")
return self.config
def get(self, key: str, default: Any = None) -> Any:
"""Get a configuration value by key."""
return self.config.get(key, default)
def set(self, key: str, value: Any) -> None:
"""Set a configuration value in memory (does not save automatically)."""
self.config[key] = value
logger.debug(f"Configuration value set in memory: {key}={value}")
def save(self) -> bool:
"""Save the current in-memory configuration to the .env file."""
if not self.env_file:
logger.error("Cannot save configuration, .env file path not set.")
return False
try:
for key in DEFAULT_CONFIG.keys():
if key not in self.config:
logger.warning(
f"Key '{key}' missing from current config, adding default value before saving."
)
self.config[key] = DEFAULT_CONFIG[key]
for key, value in self.config.items():
if key in DEFAULT_CONFIG:
set_key(self.env_file, key, str(value))
logger.info(f"Configuration saved to {self.env_file}")
return True
except Exception as e:
logger.error(
f"Failed to save configuration to {self.env_file}: {e}", exc_info=True
)
return False
def get_all(self) -> Dict[str, Any]:
"""Get all current configuration values."""
return self.config.copy()
def update(self, new_config: Dict[str, Any]) -> None:
"""Update multiple configuration values in memory from a dictionary."""
updated_keys = []
for key, value in new_config.items():
if key in DEFAULT_CONFIG:
self.config[key] = value
updated_keys.append(key)
else:
logger.warning(
f"Attempted to update unknown config key: {key}. Ignoring."
)
if updated_keys:
logger.debug(
f"Configuration values updated in memory for keys: {updated_keys}"
)
def get_int(self, key: str, default: Optional[int] = None) -> int:
"""Get a configuration value as an integer, with error handling."""
value_str = self.get(key) # Get value which might be from env (str) or default
if value_str is None: # Key not found at all
if default is not None:
logger.warning(
f"Config key '{key}' not found, using provided default: {default}"
)
return default
else:
logger.error(
f"Mandatory config key '{key}' not found and no default provided. Returning 0."
)
return 0 # Or raise error
try:
return int(value_str)
except (ValueError, TypeError):
logger.warning(
f"Invalid integer value '{value_str}' for config key '{key}', using default: {default}"
)
if isinstance(default, int):
return default
elif default is None:
logger.error(
f"Cannot parse '{value_str}' as int for key '{key}' and no valid default. Returning 0."
)
return 0
else: # Default was provided but not an int
logger.error(
f"Invalid default value type for key '{key}'. Cannot parse '{value_str}'. Returning 0."
)
return 0
def get_float(self, key: str, default: Optional[float] = None) -> float:
"""Get a configuration value as a float, with error handling."""
value_str = self.get(key)
if value_str is None:
if default is not None:
logger.warning(
f"Config key '{key}' not found, using provided default: {default}"
)
return default
else:
logger.error(
f"Mandatory config key '{key}' not found and no default provided. Returning 0.0."
)
return 0.0
try:
return float(value_str)
except (ValueError, TypeError):
logger.warning(
f"Invalid float value '{value_str}' for config key '{key}', using default: {default}"
)
if isinstance(default, float):
return default
elif default is None:
logger.error(
f"Cannot parse '{value_str}' as float for key '{key}' and no valid default. Returning 0.0."
)
return 0.0
else:
logger.error(
f"Invalid default value type for key '{key}'. Cannot parse '{value_str}'. Returning 0.0."
)
return 0.0
# --- Create a singleton instance for global access ---
config_manager = ConfigManager()
# --- Export common getters for easy access ---
# Server Settings
def get_host() -> str:
"""Gets the host address for the server."""
return config_manager.get("HOST", DEFAULT_CONFIG["HOST"])
def get_port() -> int:
"""Gets the port number for the server."""
# Ensure default is parsed correctly if get_int fails on env var
return config_manager.get_int("PORT", int(DEFAULT_CONFIG["PORT"]))
# Model Source Settings
def get_model_repo_id() -> str:
"""Gets the Hugging Face repository ID for the model."""
return config_manager.get("DIA_MODEL_REPO_ID", DEFAULT_CONFIG["DIA_MODEL_REPO_ID"])
def get_model_config_filename() -> str:
"""Gets the filename for the model's configuration file within the repo."""
return config_manager.get(
"DIA_MODEL_CONFIG_FILENAME", DEFAULT_CONFIG["DIA_MODEL_CONFIG_FILENAME"]
)
def get_model_weights_filename() -> str:
"""Gets the filename for the model's weights file within the repo."""
return config_manager.get(
"DIA_MODEL_WEIGHTS_FILENAME", DEFAULT_CONFIG["DIA_MODEL_WEIGHTS_FILENAME"]
)
# Path Settings
def get_model_cache_path() -> str:
"""Gets the local directory path for caching downloaded models."""
return os.path.abspath(
config_manager.get(
"DIA_MODEL_CACHE_PATH", DEFAULT_CONFIG["DIA_MODEL_CACHE_PATH"]
)
)
def get_reference_audio_path() -> str:
"""Gets the local directory path for storing reference audio files for cloning."""
return os.path.abspath(
config_manager.get(
"REFERENCE_AUDIO_PATH", DEFAULT_CONFIG["REFERENCE_AUDIO_PATH"]
)
)
def get_output_path() -> str:
"""Gets the local directory path for saving generated audio outputs."""
return os.path.abspath(
config_manager.get("OUTPUT_PATH", DEFAULT_CONFIG["OUTPUT_PATH"])
)
# Default Generation Parameter Getters
def get_gen_default_speed_factor() -> float:
"""Gets the default speed factor for generation."""
return config_manager.get_float(
"GEN_DEFAULT_SPEED_FACTOR", float(DEFAULT_CONFIG["GEN_DEFAULT_SPEED_FACTOR"])
)
def get_gen_default_cfg_scale() -> float:
"""Gets the default CFG scale for generation."""
return config_manager.get_float(
"GEN_DEFAULT_CFG_SCALE", float(DEFAULT_CONFIG["GEN_DEFAULT_CFG_SCALE"])
)
def get_gen_default_temperature() -> float:
"""Gets the default temperature for generation."""
return config_manager.get_float(
"GEN_DEFAULT_TEMPERATURE", float(DEFAULT_CONFIG["GEN_DEFAULT_TEMPERATURE"])
)
def get_gen_default_top_p() -> float:
"""Gets the default top_p for generation."""
return config_manager.get_float(
"GEN_DEFAULT_TOP_P", float(DEFAULT_CONFIG["GEN_DEFAULT_TOP_P"])
)
def get_gen_default_cfg_filter_top_k() -> int:
"""Gets the default CFG filter top_k for generation."""
return config_manager.get_int(
"GEN_DEFAULT_CFG_FILTER_TOP_K",
int(DEFAULT_CONFIG["GEN_DEFAULT_CFG_FILTER_TOP_K"]),
)