Spaces:
Running
Running
File size: 11,056 Bytes
ac5de5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 |
# config.py
# Configuration management for Dia TTS server
import os
import logging
from dotenv import load_dotenv, find_dotenv, set_key
from typing import Dict, Any, Optional
# Configure logging
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Default configuration values (used if not found in .env or environment).
# Every value is stored as a string: the defaults are written verbatim to the
# .env file as KEY=value lines and are used as fallbacks for os.environ.get(),
# which also yields strings. Numeric settings are converted on read by
# ConfigManager.get_int / ConfigManager.get_float.
# Key order matters only in that it is the order in which a freshly created
# .env file lists the settings.
DEFAULT_CONFIG = {
    # Server Settings
    "HOST": "0.0.0.0",
    "PORT": "8003",
    # Model Source Settings
    "DIA_MODEL_REPO_ID": "ttj/dia-1.6b-safetensors", # Default to safetensors repo
    "DIA_MODEL_CONFIG_FILENAME": "config.json", # Standard config filename
    "DIA_MODEL_WEIGHTS_FILENAME": "dia-v0_1_bf16.safetensors", # Default to BF16 weights
    # Path Settings (relative paths; the path getters in this module pass them
    # through os.path.abspath)
    "DIA_MODEL_CACHE_PATH": "./model_cache",
    "REFERENCE_AUDIO_PATH": "./reference_audio",
    "OUTPUT_PATH": "./outputs",
    # Default Generation Parameters (can be overridden by user in UI/API)
    # These are saved to .env via the UI's "Save Generation Defaults" button
    "GEN_DEFAULT_SPEED_FACTOR": "0.90", # Default speed slightly slower
    "GEN_DEFAULT_CFG_SCALE": "3.0",
    "GEN_DEFAULT_TEMPERATURE": "1.3",
    "GEN_DEFAULT_TOP_P": "0.95",
    "GEN_DEFAULT_CFG_FILTER_TOP_K": "35",
}
class ConfigManager:
    """Manage configuration for the TTS server with .env file support.

    The current settings live in the in-memory ``config`` dict. They are
    populated by :meth:`reload` from the .env file / process environment,
    falling back to ``DEFAULT_CONFIG``, and are only written back to disk
    when :meth:`save` is called.
    """

    def __init__(self):
        """Locate (or create) the .env file and load the configuration.

        If ``find_dotenv()`` finds no file, a new ``.env`` populated with
        ``DEFAULT_CONFIG`` is created in the current working directory.
        """
        self.config: Dict[str, Any] = {}
        self.env_file = find_dotenv()
        if self.env_file:
            logger.info(f"Loading configuration from: {self.env_file}")
        else:
            self.env_file = os.path.join(os.getcwd(), ".env")
            logger.info(
                f"No .env file found, creating one with defaults at {self.env_file}"
            )
            self._create_default_env_file()
        # Load in both cases so ``config`` is populated even when the file
        # was just created (or could not be created at all).
        self.reload()

    def _create_default_env_file(self):
        """Write ``DEFAULT_CONFIG`` to ``self.env_file`` as ``KEY=value`` lines.

        Best effort: an I/O failure is logged and not raised, so the manager
        can still run on in-memory defaults.
        """
        lines = [f"{key}={value}\n" for key, value in DEFAULT_CONFIG.items()]
        try:
            # Explicit UTF-8 so the file does not depend on the platform's
            # default encoding.
            with open(self.env_file, "w", encoding="utf-8") as f:
                f.writelines(lines)
        except OSError as e:
            logger.error(f"Failed to create default .env file: {e}")
        else:
            logger.info("Created default .env file")

    def reload(self) -> Dict[str, Any]:
        """Reload configuration from the .env file and environment variables.

        ``load_dotenv`` is called with ``override=True``, so values from the
        .env file replace variables already present in ``os.environ``. Only
        keys listed in ``DEFAULT_CONFIG`` are read; a key absent from the
        environment takes its default.

        Returns:
            The freshly built configuration dict (the same object that is
            stored on ``self.config``).
        """
        load_dotenv(self.env_file, override=True)
        self.config = {
            key: os.environ.get(key, default_value)
            for key, default_value in DEFAULT_CONFIG.items()
        }
        logger.info("Configuration loaded/reloaded.")
        logger.debug(f"Current config: {self.config}")
        return self.config

    def get(self, key: str, default: Any = None) -> Any:
        """Return the in-memory value for ``key``, or ``default`` if absent."""
        return self.config.get(key, default)

    def set(self, key: str, value: Any) -> None:
        """Set a configuration value in memory (does not save automatically)."""
        self.config[key] = value
        logger.debug(f"Configuration value set in memory: {key}={value}")

    def save(self) -> bool:
        """Save the current in-memory configuration to the .env file.

        Keys from ``DEFAULT_CONFIG`` that are missing in memory are first
        restored to their defaults. Only keys known to ``DEFAULT_CONFIG`` are
        written; any other in-memory key is left out of the file.

        Returns:
            ``True`` on success, ``False`` if the path is unset or writing
            failed (the failure is logged, not raised).
        """
        if not self.env_file:
            logger.error("Cannot save configuration, .env file path not set.")
            return False
        try:
            for key, default_value in DEFAULT_CONFIG.items():
                if key not in self.config:
                    logger.warning(
                        f"Key '{key}' missing from current config, adding default value before saving."
                    )
                    self.config[key] = default_value
            for key, value in self.config.items():
                if key in DEFAULT_CONFIG:
                    set_key(self.env_file, key, str(value))
            logger.info(f"Configuration saved to {self.env_file}")
            return True
        except Exception as e:
            # Boundary handler: callers rely on the boolean result instead of
            # an exception, so log with traceback and report failure.
            logger.error(
                f"Failed to save configuration to {self.env_file}: {e}", exc_info=True
            )
            return False

    def get_all(self) -> Dict[str, Any]:
        """Return a shallow copy of all current configuration values."""
        return self.config.copy()

    def update(self, new_config: Dict[str, Any]) -> None:
        """Update multiple configuration values in memory from a dictionary.

        Keys not present in ``DEFAULT_CONFIG`` are ignored with a warning.
        Nothing is written to disk.
        """
        updated_keys = []
        for key, value in new_config.items():
            if key in DEFAULT_CONFIG:
                self.config[key] = value
                updated_keys.append(key)
            else:
                logger.warning(
                    f"Attempted to update unknown config key: {key}. Ignoring."
                )
        if updated_keys:
            logger.debug(
                f"Configuration values updated in memory for keys: {updated_keys}"
            )

    def _get_number(self, key, default, cast, long_name, short_name, zero, default_ok):
        """Shared implementation of :meth:`get_int` and :meth:`get_float`.

        Args:
            key: Configuration key to read.
            default: Caller-supplied fallback, or ``None``.
            cast: ``int`` or ``float``; converts the stored value.
            long_name: Type name used in the "Invalid ... value" log message.
            short_name: Type name used in the "Cannot parse ... as" message.
            zero: Last-resort return value (``0`` or ``0.0``).
            default_ok: Predicate telling whether ``default`` is an acceptable
                numeric fallback for this type.
        """
        raw = self.get(key)
        if raw is None:  # Key not found at all
            if default is None:
                logger.error(
                    f"Mandatory config key '{key}' not found and no default provided. Returning {zero}."
                )
                return zero
            logger.warning(
                f"Config key '{key}' not found, using provided default: {default}"
            )
            # Normalise acceptable numeric defaults to the requested type;
            # anything else is handed back untouched, as before.
            return cast(default) if default_ok(default) else default
        try:
            return cast(raw)
        except (ValueError, TypeError):
            logger.warning(
                f"Invalid {long_name} value '{raw}' for config key '{key}', using default: {default}"
            )
        if default is None:
            logger.error(
                f"Cannot parse '{raw}' as {short_name} for key '{key}' and no valid default. Returning {zero}."
            )
            return zero
        if default_ok(default):
            return cast(default)
        # A default was provided but is not a usable number.
        logger.error(
            f"Invalid default value type for key '{key}'. Cannot parse '{raw}'. Returning {zero}."
        )
        return zero

    def get_int(self, key: str, default: Optional[int] = None) -> int:
        """Get a configuration value as an integer, with error handling.

        Returns ``int(value)`` when the stored value parses. Otherwise falls
        back to ``default`` when it is an ``int``, and to ``0`` when no usable
        default exists. Every fallback is logged; nothing is raised.
        """
        return self._get_number(
            key, default, int, "integer", "int", 0,
            lambda d: isinstance(d, int),
        )

    def get_float(self, key: str, default: Optional[float] = None) -> float:
        """Get a configuration value as a float, with error handling.

        Returns ``float(value)`` when the stored value parses. Otherwise falls
        back to ``default`` and to ``0.0`` when no usable default exists.
        Integer defaults are accepted and converted to ``float`` (``bool`` is
        not treated as a number here). Every fallback is logged; nothing is
        raised.
        """
        return self._get_number(
            key, default, float, "float", "float", 0.0,
            lambda d: isinstance(d, (int, float)) and not isinstance(d, bool),
        )
# --- Create a singleton instance for global access ---
# Instantiated at import time: ConfigManager.__init__ locates the .env file
# (creating one with defaults in the current working directory if none is
# found) and loads the configuration, so importing this module performs file I/O.
config_manager = ConfigManager()
# --- Export common getters for easy access ---
# Server Settings
def get_host() -> str:
    """Return the configured server host address, or the built-in default."""
    fallback = DEFAULT_CONFIG["HOST"]
    return config_manager.get("HOST", fallback)
def get_port() -> int:
    """Return the configured server port as an integer."""
    # The built-in default is stored as a string; convert it up front so
    # get_int has a real int to fall back on if the configured value is bad.
    fallback_port = int(DEFAULT_CONFIG["PORT"])
    return config_manager.get_int("PORT", fallback_port)
# Model Source Settings
def get_model_repo_id() -> str:
    """Return the Hugging Face repository ID the model is loaded from."""
    key = "DIA_MODEL_REPO_ID"
    return config_manager.get(key, DEFAULT_CONFIG[key])
def get_model_config_filename() -> str:
    """Return the name of the model's configuration file inside the repo."""
    key = "DIA_MODEL_CONFIG_FILENAME"
    return config_manager.get(key, DEFAULT_CONFIG[key])
def get_model_weights_filename() -> str:
    """Return the name of the model's weights file inside the repo."""
    key = "DIA_MODEL_WEIGHTS_FILENAME"
    return config_manager.get(key, DEFAULT_CONFIG[key])
# Path Settings
def get_model_cache_path() -> str:
    """Return the absolute path of the directory used to cache downloaded models."""
    key = "DIA_MODEL_CACHE_PATH"
    configured = config_manager.get(key, DEFAULT_CONFIG[key])
    return os.path.abspath(configured)
def get_reference_audio_path() -> str:
    """Return the absolute path of the directory holding reference audio for cloning."""
    key = "REFERENCE_AUDIO_PATH"
    configured = config_manager.get(key, DEFAULT_CONFIG[key])
    return os.path.abspath(configured)
def get_output_path() -> str:
    """Return the absolute path of the directory where generated audio is saved."""
    key = "OUTPUT_PATH"
    configured = config_manager.get(key, DEFAULT_CONFIG[key])
    return os.path.abspath(configured)
# Default Generation Parameter Getters
def get_gen_default_speed_factor() -> float:
    """Return the default speed factor used for generation."""
    key = "GEN_DEFAULT_SPEED_FACTOR"
    fallback = float(DEFAULT_CONFIG[key])
    return config_manager.get_float(key, fallback)
def get_gen_default_cfg_scale() -> float:
    """Return the default CFG scale used for generation."""
    key = "GEN_DEFAULT_CFG_SCALE"
    fallback = float(DEFAULT_CONFIG[key])
    return config_manager.get_float(key, fallback)
def get_gen_default_temperature() -> float:
    """Return the default sampling temperature used for generation."""
    key = "GEN_DEFAULT_TEMPERATURE"
    fallback = float(DEFAULT_CONFIG[key])
    return config_manager.get_float(key, fallback)
def get_gen_default_top_p() -> float:
    """Return the default top_p value used for generation."""
    key = "GEN_DEFAULT_TOP_P"
    fallback = float(DEFAULT_CONFIG[key])
    return config_manager.get_float(key, fallback)
def get_gen_default_cfg_filter_top_k() -> int:
    """Return the default CFG filter top_k value used for generation."""
    key = "GEN_DEFAULT_CFG_FILTER_TOP_K"
    fallback = int(DEFAULT_CONFIG[key])
    return config_manager.get_int(key, fallback)
|