Spaces:
Running
Running
# config.py | |
# Configuration management for Dia TTS server | |
import os | |
import logging | |
from dotenv import load_dotenv, find_dotenv, set_key | |
from typing import Dict, Any, Optional | |
# Configure logging | |
logger = logging.getLogger(__name__) | |
# Default configuration values (used if not found in .env or environment) | |
DEFAULT_CONFIG = { | |
# Server Settings | |
"HOST": "0.0.0.0", | |
"PORT": "8003", | |
# Model Source Settings | |
"DIA_MODEL_REPO_ID": "ttj/dia-1.6b-safetensors", # Default to safetensors repo | |
"DIA_MODEL_CONFIG_FILENAME": "config.json", # Standard config filename | |
"DIA_MODEL_WEIGHTS_FILENAME": "dia-v0_1_bf16.safetensors", # Default to BF16 weights | |
# Path Settings | |
"DIA_MODEL_CACHE_PATH": "./model_cache", | |
"REFERENCE_AUDIO_PATH": "./reference_audio", | |
"OUTPUT_PATH": "./outputs", | |
# Default Generation Parameters (can be overridden by user in UI/API) | |
# These are saved to .env via the UI's "Save Generation Defaults" button | |
"GEN_DEFAULT_SPEED_FACTOR": "0.90", # Default speed slightly slower | |
"GEN_DEFAULT_CFG_SCALE": "3.0", | |
"GEN_DEFAULT_TEMPERATURE": "1.3", | |
"GEN_DEFAULT_TOP_P": "0.95", | |
"GEN_DEFAULT_CFG_FILTER_TOP_K": "35", | |
} | |
class ConfigManager: | |
"""Manages configuration for the TTS server with .env file support.""" | |
def __init__(self): | |
"""Initialize the configuration manager.""" | |
self.config = {} | |
self.env_file = find_dotenv() | |
if not self.env_file: | |
self.env_file = os.path.join(os.getcwd(), ".env") | |
logger.info( | |
f"No .env file found, creating one with defaults at {self.env_file}" | |
) | |
self._create_default_env_file() | |
else: | |
logger.info(f"Loading configuration from: {self.env_file}") | |
self.reload() | |
def _create_default_env_file(self): | |
"""Create a default .env file with default values.""" | |
try: | |
with open(self.env_file, "w") as f: | |
for key, value in DEFAULT_CONFIG.items(): | |
f.write(f"{key}={value}\n") | |
logger.info("Created default .env file") | |
except Exception as e: | |
logger.error(f"Failed to create default .env file: {e}") | |
def reload(self): | |
"""Reload configuration from .env file and environment variables.""" | |
load_dotenv(self.env_file, override=True) | |
loaded_config = {} | |
for key, default_value in DEFAULT_CONFIG.items(): | |
loaded_config[key] = os.environ.get(key, default_value) | |
self.config = loaded_config | |
logger.info("Configuration loaded/reloaded.") | |
logger.debug(f"Current config: {self.config}") | |
return self.config | |
def get(self, key: str, default: Any = None) -> Any: | |
"""Get a configuration value by key.""" | |
return self.config.get(key, default) | |
def set(self, key: str, value: Any) -> None: | |
"""Set a configuration value in memory (does not save automatically).""" | |
self.config[key] = value | |
logger.debug(f"Configuration value set in memory: {key}={value}") | |
def save(self) -> bool: | |
"""Save the current in-memory configuration to the .env file.""" | |
if not self.env_file: | |
logger.error("Cannot save configuration, .env file path not set.") | |
return False | |
try: | |
for key in DEFAULT_CONFIG.keys(): | |
if key not in self.config: | |
logger.warning( | |
f"Key '{key}' missing from current config, adding default value before saving." | |
) | |
self.config[key] = DEFAULT_CONFIG[key] | |
for key, value in self.config.items(): | |
if key in DEFAULT_CONFIG: | |
set_key(self.env_file, key, str(value)) | |
logger.info(f"Configuration saved to {self.env_file}") | |
return True | |
except Exception as e: | |
logger.error( | |
f"Failed to save configuration to {self.env_file}: {e}", exc_info=True | |
) | |
return False | |
def get_all(self) -> Dict[str, Any]: | |
"""Get all current configuration values.""" | |
return self.config.copy() | |
def update(self, new_config: Dict[str, Any]) -> None: | |
"""Update multiple configuration values in memory from a dictionary.""" | |
updated_keys = [] | |
for key, value in new_config.items(): | |
if key in DEFAULT_CONFIG: | |
self.config[key] = value | |
updated_keys.append(key) | |
else: | |
logger.warning( | |
f"Attempted to update unknown config key: {key}. Ignoring." | |
) | |
if updated_keys: | |
logger.debug( | |
f"Configuration values updated in memory for keys: {updated_keys}" | |
) | |
def get_int(self, key: str, default: Optional[int] = None) -> int: | |
"""Get a configuration value as an integer, with error handling.""" | |
value_str = self.get(key) # Get value which might be from env (str) or default | |
if value_str is None: # Key not found at all | |
if default is not None: | |
logger.warning( | |
f"Config key '{key}' not found, using provided default: {default}" | |
) | |
return default | |
else: | |
logger.error( | |
f"Mandatory config key '{key}' not found and no default provided. Returning 0." | |
) | |
return 0 # Or raise error | |
try: | |
return int(value_str) | |
except (ValueError, TypeError): | |
logger.warning( | |
f"Invalid integer value '{value_str}' for config key '{key}', using default: {default}" | |
) | |
if isinstance(default, int): | |
return default | |
elif default is None: | |
logger.error( | |
f"Cannot parse '{value_str}' as int for key '{key}' and no valid default. Returning 0." | |
) | |
return 0 | |
else: # Default was provided but not an int | |
logger.error( | |
f"Invalid default value type for key '{key}'. Cannot parse '{value_str}'. Returning 0." | |
) | |
return 0 | |
def get_float(self, key: str, default: Optional[float] = None) -> float: | |
"""Get a configuration value as a float, with error handling.""" | |
value_str = self.get(key) | |
if value_str is None: | |
if default is not None: | |
logger.warning( | |
f"Config key '{key}' not found, using provided default: {default}" | |
) | |
return default | |
else: | |
logger.error( | |
f"Mandatory config key '{key}' not found and no default provided. Returning 0.0." | |
) | |
return 0.0 | |
try: | |
return float(value_str) | |
except (ValueError, TypeError): | |
logger.warning( | |
f"Invalid float value '{value_str}' for config key '{key}', using default: {default}" | |
) | |
if isinstance(default, float): | |
return default | |
elif default is None: | |
logger.error( | |
f"Cannot parse '{value_str}' as float for key '{key}' and no valid default. Returning 0.0." | |
) | |
return 0.0 | |
else: | |
logger.error( | |
f"Invalid default value type for key '{key}'. Cannot parse '{value_str}'. Returning 0.0." | |
) | |
return 0.0 | |
# --- Create a singleton instance for global access --- | |
config_manager = ConfigManager() | |
# --- Export common getters for easy access --- | |
# Server Settings | |
def get_host() -> str: | |
"""Gets the host address for the server.""" | |
return config_manager.get("HOST", DEFAULT_CONFIG["HOST"]) | |
def get_port() -> int: | |
"""Gets the port number for the server.""" | |
# Ensure default is parsed correctly if get_int fails on env var | |
return config_manager.get_int("PORT", int(DEFAULT_CONFIG["PORT"])) | |
# Model Source Settings | |
def get_model_repo_id() -> str: | |
"""Gets the Hugging Face repository ID for the model.""" | |
return config_manager.get("DIA_MODEL_REPO_ID", DEFAULT_CONFIG["DIA_MODEL_REPO_ID"]) | |
def get_model_config_filename() -> str: | |
"""Gets the filename for the model's configuration file within the repo.""" | |
return config_manager.get( | |
"DIA_MODEL_CONFIG_FILENAME", DEFAULT_CONFIG["DIA_MODEL_CONFIG_FILENAME"] | |
) | |
def get_model_weights_filename() -> str: | |
"""Gets the filename for the model's weights file within the repo.""" | |
return config_manager.get( | |
"DIA_MODEL_WEIGHTS_FILENAME", DEFAULT_CONFIG["DIA_MODEL_WEIGHTS_FILENAME"] | |
) | |
# Path Settings | |
def get_model_cache_path() -> str: | |
"""Gets the local directory path for caching downloaded models.""" | |
return os.path.abspath( | |
config_manager.get( | |
"DIA_MODEL_CACHE_PATH", DEFAULT_CONFIG["DIA_MODEL_CACHE_PATH"] | |
) | |
) | |
def get_reference_audio_path() -> str: | |
"""Gets the local directory path for storing reference audio files for cloning.""" | |
return os.path.abspath( | |
config_manager.get( | |
"REFERENCE_AUDIO_PATH", DEFAULT_CONFIG["REFERENCE_AUDIO_PATH"] | |
) | |
) | |
def get_output_path() -> str: | |
"""Gets the local directory path for saving generated audio outputs.""" | |
return os.path.abspath( | |
config_manager.get("OUTPUT_PATH", DEFAULT_CONFIG["OUTPUT_PATH"]) | |
) | |
# Default Generation Parameter Getters | |
def get_gen_default_speed_factor() -> float: | |
"""Gets the default speed factor for generation.""" | |
return config_manager.get_float( | |
"GEN_DEFAULT_SPEED_FACTOR", float(DEFAULT_CONFIG["GEN_DEFAULT_SPEED_FACTOR"]) | |
) | |
def get_gen_default_cfg_scale() -> float: | |
"""Gets the default CFG scale for generation.""" | |
return config_manager.get_float( | |
"GEN_DEFAULT_CFG_SCALE", float(DEFAULT_CONFIG["GEN_DEFAULT_CFG_SCALE"]) | |
) | |
def get_gen_default_temperature() -> float: | |
"""Gets the default temperature for generation.""" | |
return config_manager.get_float( | |
"GEN_DEFAULT_TEMPERATURE", float(DEFAULT_CONFIG["GEN_DEFAULT_TEMPERATURE"]) | |
) | |
def get_gen_default_top_p() -> float: | |
"""Gets the default top_p for generation.""" | |
return config_manager.get_float( | |
"GEN_DEFAULT_TOP_P", float(DEFAULT_CONFIG["GEN_DEFAULT_TOP_P"]) | |
) | |
def get_gen_default_cfg_filter_top_k() -> int: | |
"""Gets the default CFG filter top_k for generation.""" | |
return config_manager.get_int( | |
"GEN_DEFAULT_CFG_FILTER_TOP_K", | |
int(DEFAULT_CONFIG["GEN_DEFAULT_CFG_FILTER_TOP_K"]), | |
) | |