Spaces: Running
Michael Hu committed
Commit ac5de5b · Parent(s): 2493d3b
initial check in of the dia tts server
Files changed:
- .env +35 -0
- Dockerfile +36 -0
- README.md +1 -0
- config.py +295 -0
- dia/__init__.py +0 -0
- dia/audio.py +280 -0
- dia/config.py +206 -0
- dia/layers.py +903 -0
- dia/model.py +956 -0
- docker-compose.yml +23 -0
- documentation.md +549 -0
- download_model.py +41 -0
- engine.py +356 -0
- models.py +97 -0
- requirements.txt +22 -0
- server.py +1061 -0
- ui/index.html +916 -0
- ui/presets.yaml +57 -0
- ui/script.js +593 -0
- utils.py +146 -0
.env
ADDED
@@ -0,0 +1,35 @@
# .env - Configuration for Dia TTS Server
# Values in this file override the defaults set in config.py

# --- Server Settings ---
HOST='0.0.0.0'
PORT='8003'

# --- Path Settings ---
# Defaults are usually fine unless you want custom locations.
DIA_MODEL_CACHE_PATH='./model_cache'
REFERENCE_AUDIO_PATH='./reference_audio'
OUTPUT_PATH='./outputs'

# --- Model Source Settings ---
# Defaulting to BF16 safetensors. Uncomment and modify lines below to use other models.
DIA_MODEL_REPO_ID='ttj/dia-1.6b-safetensors'
DIA_MODEL_CONFIG_FILENAME='config.json'
DIA_MODEL_WEIGHTS_FILENAME='dia-v0_1_bf16.safetensors'

# Example: Use full precision safetensors
# DIA_MODEL_REPO_ID=ttj/dia-1.6b-safetensors
# DIA_MODEL_WEIGHTS_FILENAME=dia-v0_1.safetensors

# Example: Use original Nari Labs .pth model
# DIA_MODEL_REPO_ID=nari-labs/Dia-1.6B
# DIA_MODEL_WEIGHTS_FILENAME=dia-v0_1.pth

# --- Default Generation Parameters ---
# These set the initial values loaded in the UI.
# They can be changed in the UI and saved back here using the 'Save Generation Defaults' button.
GEN_DEFAULT_SPEED_FACTOR='0.9'
GEN_DEFAULT_CFG_SCALE='3'
GEN_DEFAULT_TEMPERATURE='1.3'
GEN_DEFAULT_TOP_P='0.95'
GEN_DEFAULT_CFG_FILTER_TOP_K='35'
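
Note: every value in a .env file is a string; config.py (added below) converts the numeric settings with its get_int/get_float helpers. A minimal sanity-check sketch, assuming python-dotenv is installed (it is already a dependency of config.py); the script name is illustrative, not part of this commit:

# check_env.py - illustrative only
from dotenv import dotenv_values

values = dotenv_values(".env")
print(values["PORT"])                            # '8003' (a string, not an int)
print(float(values["GEN_DEFAULT_TEMPERATURE"]))  # 1.3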
Dockerfile
ADDED
@@ -0,0 +1,36 @@
FROM nvidia/cuda:12.4.0-devel-ubuntu22.04

# Set environment variables
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1
ENV DEBIAN_FRONTEND=noninteractive

# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    libsndfile1 \
    ffmpeg \
    python3 \
    python3-pip \
    python3-dev \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

# Set up working directory
WORKDIR /app

# Install Python dependencies
COPY requirements.txt .
RUN pip3 install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

# Create required directories
RUN mkdir -p model_cache reference_audio outputs

# Expose the port the application will run on (default to 8003 as per config)
EXPOSE 8003

# Command to run the application
CMD ["python3", "server.py"]
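
For local use outside Hugging Face Spaces, the image can be built and run with the standard Docker CLI. A minimal sketch (the image name and volume mounts are illustrative; --gpus assumes the NVIDIA Container Toolkit is installed; the bundled docker-compose.yml presumably wraps something similar):

docker build -t dia-tts-server .
docker run --gpus all -p 8003:8003 \
    -v "$(pwd)/model_cache:/app/model_cache" \
    -v "$(pwd)/outputs:/app/outputs" \
    dia-tts-server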
README.md
CHANGED
@@ -5,6 +5,7 @@ colorFrom: indigo
 colorTo: indigo
 sdk: docker
 pinned: false
+app_port: 8003
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
config.py
ADDED
@@ -0,0 +1,295 @@
# config.py
# Configuration management for Dia TTS server

import os
import logging
from dotenv import load_dotenv, find_dotenv, set_key
from typing import Dict, Any, Optional

# Configure logging
logger = logging.getLogger(__name__)

# Default configuration values (used if not found in .env or environment)
DEFAULT_CONFIG = {
    # Server Settings
    "HOST": "0.0.0.0",
    "PORT": "8003",
    # Model Source Settings
    "DIA_MODEL_REPO_ID": "ttj/dia-1.6b-safetensors",  # Default to safetensors repo
    "DIA_MODEL_CONFIG_FILENAME": "config.json",  # Standard config filename
    "DIA_MODEL_WEIGHTS_FILENAME": "dia-v0_1_bf16.safetensors",  # Default to BF16 weights
    # Path Settings
    "DIA_MODEL_CACHE_PATH": "./model_cache",
    "REFERENCE_AUDIO_PATH": "./reference_audio",
    "OUTPUT_PATH": "./outputs",
    # Default Generation Parameters (can be overridden by user in UI/API)
    # These are saved to .env via the UI's "Save Generation Defaults" button
    "GEN_DEFAULT_SPEED_FACTOR": "0.90",  # Default speed slightly slower
    "GEN_DEFAULT_CFG_SCALE": "3.0",
    "GEN_DEFAULT_TEMPERATURE": "1.3",
    "GEN_DEFAULT_TOP_P": "0.95",
    "GEN_DEFAULT_CFG_FILTER_TOP_K": "35",
}


class ConfigManager:
    """Manages configuration for the TTS server with .env file support."""

    def __init__(self):
        """Initialize the configuration manager."""
        self.config = {}
        self.env_file = find_dotenv()

        if not self.env_file:
            self.env_file = os.path.join(os.getcwd(), ".env")
            logger.info(
                f"No .env file found, creating one with defaults at {self.env_file}"
            )
            self._create_default_env_file()
        else:
            logger.info(f"Loading configuration from: {self.env_file}")

        self.reload()

    def _create_default_env_file(self):
        """Create a default .env file with default values."""
        try:
            with open(self.env_file, "w") as f:
                for key, value in DEFAULT_CONFIG.items():
                    f.write(f"{key}={value}\n")
            logger.info("Created default .env file")
        except Exception as e:
            logger.error(f"Failed to create default .env file: {e}")

    def reload(self):
        """Reload configuration from .env file and environment variables."""
        load_dotenv(self.env_file, override=True)
        loaded_config = {}
        for key, default_value in DEFAULT_CONFIG.items():
            loaded_config[key] = os.environ.get(key, default_value)
        self.config = loaded_config
        logger.info("Configuration loaded/reloaded.")
        logger.debug(f"Current config: {self.config}")
        return self.config

    def get(self, key: str, default: Any = None) -> Any:
        """Get a configuration value by key."""
        return self.config.get(key, default)

    def set(self, key: str, value: Any) -> None:
        """Set a configuration value in memory (does not save automatically)."""
        self.config[key] = value
        logger.debug(f"Configuration value set in memory: {key}={value}")

    def save(self) -> bool:
        """Save the current in-memory configuration to the .env file."""
        if not self.env_file:
            logger.error("Cannot save configuration, .env file path not set.")
            return False
        try:
            for key in DEFAULT_CONFIG.keys():
                if key not in self.config:
                    logger.warning(
                        f"Key '{key}' missing from current config, adding default value before saving."
                    )
                    self.config[key] = DEFAULT_CONFIG[key]
            for key, value in self.config.items():
                if key in DEFAULT_CONFIG:
                    set_key(self.env_file, key, str(value))
            logger.info(f"Configuration saved to {self.env_file}")
            return True
        except Exception as e:
            logger.error(
                f"Failed to save configuration to {self.env_file}: {e}", exc_info=True
            )
            return False

    def get_all(self) -> Dict[str, Any]:
        """Get all current configuration values."""
        return self.config.copy()

    def update(self, new_config: Dict[str, Any]) -> None:
        """Update multiple configuration values in memory from a dictionary."""
        updated_keys = []
        for key, value in new_config.items():
            if key in DEFAULT_CONFIG:
                self.config[key] = value
                updated_keys.append(key)
            else:
                logger.warning(
                    f"Attempted to update unknown config key: {key}. Ignoring."
                )
        if updated_keys:
            logger.debug(
                f"Configuration values updated in memory for keys: {updated_keys}"
            )

    def get_int(self, key: str, default: Optional[int] = None) -> int:
        """Get a configuration value as an integer, with error handling."""
        value_str = self.get(key)  # Get value which might be from env (str) or default
        if value_str is None:  # Key not found at all
            if default is not None:
                logger.warning(
                    f"Config key '{key}' not found, using provided default: {default}"
                )
                return default
            else:
                logger.error(
                    f"Mandatory config key '{key}' not found and no default provided. Returning 0."
                )
                return 0  # Or raise error

        try:
            return int(value_str)
        except (ValueError, TypeError):
            logger.warning(
                f"Invalid integer value '{value_str}' for config key '{key}', using default: {default}"
            )
            if isinstance(default, int):
                return default
            elif default is None:
                logger.error(
                    f"Cannot parse '{value_str}' as int for key '{key}' and no valid default. Returning 0."
                )
                return 0
            else:  # Default was provided but not an int
                logger.error(
                    f"Invalid default value type for key '{key}'. Cannot parse '{value_str}'. Returning 0."
                )
                return 0

    def get_float(self, key: str, default: Optional[float] = None) -> float:
        """Get a configuration value as a float, with error handling."""
        value_str = self.get(key)
        if value_str is None:
            if default is not None:
                logger.warning(
                    f"Config key '{key}' not found, using provided default: {default}"
                )
                return default
            else:
                logger.error(
                    f"Mandatory config key '{key}' not found and no default provided. Returning 0.0."
                )
                return 0.0

        try:
            return float(value_str)
        except (ValueError, TypeError):
            logger.warning(
                f"Invalid float value '{value_str}' for config key '{key}', using default: {default}"
            )
            if isinstance(default, float):
                return default
            elif default is None:
                logger.error(
                    f"Cannot parse '{value_str}' as float for key '{key}' and no valid default. Returning 0.0."
                )
                return 0.0
            else:
                logger.error(
                    f"Invalid default value type for key '{key}'. Cannot parse '{value_str}'. Returning 0.0."
                )
                return 0.0


# --- Create a singleton instance for global access ---
config_manager = ConfigManager()


# --- Export common getters for easy access ---


# Server Settings
def get_host() -> str:
    """Gets the host address for the server."""
    return config_manager.get("HOST", DEFAULT_CONFIG["HOST"])


def get_port() -> int:
    """Gets the port number for the server."""
    # Ensure default is parsed correctly if get_int fails on env var
    return config_manager.get_int("PORT", int(DEFAULT_CONFIG["PORT"]))


# Model Source Settings
def get_model_repo_id() -> str:
    """Gets the Hugging Face repository ID for the model."""
    return config_manager.get("DIA_MODEL_REPO_ID", DEFAULT_CONFIG["DIA_MODEL_REPO_ID"])


def get_model_config_filename() -> str:
    """Gets the filename for the model's configuration file within the repo."""
    return config_manager.get(
        "DIA_MODEL_CONFIG_FILENAME", DEFAULT_CONFIG["DIA_MODEL_CONFIG_FILENAME"]
    )


def get_model_weights_filename() -> str:
    """Gets the filename for the model's weights file within the repo."""
    return config_manager.get(
        "DIA_MODEL_WEIGHTS_FILENAME", DEFAULT_CONFIG["DIA_MODEL_WEIGHTS_FILENAME"]
    )


# Path Settings
def get_model_cache_path() -> str:
    """Gets the local directory path for caching downloaded models."""
    return os.path.abspath(
        config_manager.get(
            "DIA_MODEL_CACHE_PATH", DEFAULT_CONFIG["DIA_MODEL_CACHE_PATH"]
        )
    )


def get_reference_audio_path() -> str:
    """Gets the local directory path for storing reference audio files for cloning."""
    return os.path.abspath(
        config_manager.get(
            "REFERENCE_AUDIO_PATH", DEFAULT_CONFIG["REFERENCE_AUDIO_PATH"]
        )
    )


def get_output_path() -> str:
    """Gets the local directory path for saving generated audio outputs."""
    return os.path.abspath(
        config_manager.get("OUTPUT_PATH", DEFAULT_CONFIG["OUTPUT_PATH"])
    )


# Default Generation Parameter Getters
def get_gen_default_speed_factor() -> float:
    """Gets the default speed factor for generation."""
    return config_manager.get_float(
        "GEN_DEFAULT_SPEED_FACTOR", float(DEFAULT_CONFIG["GEN_DEFAULT_SPEED_FACTOR"])
    )


def get_gen_default_cfg_scale() -> float:
    """Gets the default CFG scale for generation."""
    return config_manager.get_float(
        "GEN_DEFAULT_CFG_SCALE", float(DEFAULT_CONFIG["GEN_DEFAULT_CFG_SCALE"])
    )


def get_gen_default_temperature() -> float:
    """Gets the default temperature for generation."""
    return config_manager.get_float(
        "GEN_DEFAULT_TEMPERATURE", float(DEFAULT_CONFIG["GEN_DEFAULT_TEMPERATURE"])
    )


def get_gen_default_top_p() -> float:
    """Gets the default top_p for generation."""
    return config_manager.get_float(
        "GEN_DEFAULT_TOP_P", float(DEFAULT_CONFIG["GEN_DEFAULT_TOP_P"])
    )


def get_gen_default_cfg_filter_top_k() -> int:
    """Gets the default CFG filter top_k for generation."""
    return config_manager.get_int(
        "GEN_DEFAULT_CFG_FILTER_TOP_K",
        int(DEFAULT_CONFIG["GEN_DEFAULT_CFG_FILTER_TOP_K"]),
    )
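
A minimal usage sketch of the module above, run from the repo root (the printed values assume the defaults from the .env shown earlier):

from config import config_manager, get_port, get_gen_default_temperature

print(get_port())                     # 8003 (int, parsed from the 'PORT' string)
print(get_gen_default_temperature())  # 1.3 (float)

config_manager.update({"GEN_DEFAULT_TEMPERATURE": "1.1"})  # in-memory only
config_manager.save()                                      # persists back to .env via set_key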
dia/__init__.py
ADDED
File without changes
dia/audio.py
ADDED
@@ -0,0 +1,280 @@
import typing as tp

import torch

from .config import DataConfig


def build_delay_indices(B: int, T: int, C: int, delay_pattern: tp.List[int]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
    """
    Precompute (t_idx_BxTxC, indices_BTCx3) so that out[t, c] = in[t - delay[c], c].
    Negative t_idx => BOS; t_idx >= T => PAD.
    """
    delay_arr = torch.tensor(delay_pattern, dtype=torch.int32)

    t_idx_BxT = torch.broadcast_to(
        torch.arange(T, dtype=torch.int32)[None, :],
        [B, T],
    )
    t_idx_BxTx1 = t_idx_BxT[..., None]
    t_idx_BxTxC = t_idx_BxTx1 - delay_arr.view(1, 1, C)

    b_idx_BxTxC = torch.broadcast_to(
        torch.arange(B, dtype=torch.int32).view(B, 1, 1),
        [B, T, C],
    )
    c_idx_BxTxC = torch.broadcast_to(
        torch.arange(C, dtype=torch.int32).view(1, 1, C),
        [B, T, C],
    )

    # We must clamp time indices to [0..T-1] so the gather_nd equivalent won't fail
    t_clamped_BxTxC = torch.clamp(t_idx_BxTxC, 0, T - 1)

    indices_BTCx3 = torch.stack(
        [
            b_idx_BxTxC.reshape(-1),
            t_clamped_BxTxC.reshape(-1),
            c_idx_BxTxC.reshape(-1),
        ],
        dim=1,
    ).long()  # Ensure indices are long type for indexing

    return t_idx_BxTxC, indices_BTCx3


def apply_audio_delay(
    audio_BxTxC: torch.Tensor,
    pad_value: int,
    bos_value: int,
    precomp: tp.Tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
    """
    Applies the delay pattern to batched audio tokens using precomputed indices,
    inserting BOS where t_idx < 0 and PAD where t_idx >= T.

    Args:
        audio_BxTxC: [B, T, C] int16 audio tokens (or int32/float)
        pad_value: the padding token
        bos_value: the BOS token
        precomp: (t_idx_BxTxC, indices_BTCx3) from build_delay_indices

    Returns:
        result_BxTxC: [B, T, C] delayed audio tokens
    """
    device = audio_BxTxC.device  # Get device from input tensor
    t_idx_BxTxC, indices_BTCx3 = precomp
    t_idx_BxTxC = t_idx_BxTxC.to(device)  # Move precomputed indices to device
    indices_BTCx3 = indices_BTCx3.to(device)

    # Equivalent of tf.gather_nd using advanced indexing
    # Ensure indices are long type if not already (build_delay_indices should handle this)
    gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
    gathered_BxTxC = gathered_flat.view(audio_BxTxC.shape)

    # Create masks on the correct device
    mask_bos = t_idx_BxTxC < 0  # => place bos_value
    mask_pad = t_idx_BxTxC >= audio_BxTxC.shape[1]  # => place pad_value

    # Create scalar tensors on the correct device
    bos_tensor = torch.tensor(bos_value, dtype=audio_BxTxC.dtype, device=device)
    pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)

    # If mask_bos, BOS; else if mask_pad, PAD; else original gather
    # All tensors should now be on the same device
    result_BxTxC = torch.where(mask_bos, bos_tensor, torch.where(mask_pad, pad_tensor, gathered_BxTxC))

    return result_BxTxC


@torch.no_grad()
@torch.inference_mode()
def audio_to_codebook(
    model,
    input_values,
    data_config: DataConfig,
    padding_mask=None,
    sample_rate=44100,
):
    """
    Encodes the input audio waveform into discrete codes.

    Args:
        model: The model to use for encoding.
        input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
            Float values of the input audio waveform.
        padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
            Padding mask used to pad the `input_values`.
        sample_rate (`int`, *optional*):
            Signal sampling rate.

    Returns:
        A list of frames containing the discrete encoded codes for the input audio waveform, along with rescaling
        factors for each chunk when `normalize` is True. Each frame is a tuple `(codebook, scale)`, with
        `codebook` of shape `[batch_size, num_codebooks, frames]`.
        Scale is not used here.
    """
    audio_data = model.preprocess(input_values, sample_rate)

    if padding_mask is None:
        padding_mask = torch.ones_like(input_values).bool()

    _, encoded_frame, _, _, _ = model.encode(audio_data, n_quantizers=None)  # 1, C, T
    seq_length = encoded_frame.shape[2]

    t_idx_BxTxC, indices_BTCx3 = build_delay_indices(
        B=1,
        T=seq_length,
        C=data_config.channels,
        delay_pattern=data_config.delay_pattern,
    )

    encoded_frame = apply_audio_delay(
        audio_BxTxC=encoded_frame.transpose(1, 2),  # 1, T, C
        pad_value=data_config.audio_pad_value,
        bos_value=data_config.audio_bos_value,
        precomp=(t_idx_BxTxC, indices_BTCx3),
    )

    return encoded_frame


def build_revert_indices(B: int, T: int, C: int, delay_pattern: tp.List[int]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
    """
    Precompute indices for the revert operation using PyTorch.

    Returns:
        A tuple (t_idx_BxTxC, indices_BTCx3) where:
            - t_idx_BxTxC is a tensor of shape [B, T, C] computed as time indices plus the delay.
            - indices_BTCx3 is a tensor of shape [B*T*C, 3] used for gathering, computed from:
                batch indices, clamped time indices, and channel indices.
    """
    # Use default device unless specified otherwise; assumes inputs might define device later
    device = None  # Or determine dynamically if needed, e.g., from a model parameter

    delay_arr = torch.tensor(delay_pattern, dtype=torch.int32, device=device)

    t_idx_BT1 = torch.broadcast_to(torch.arange(T, device=device).unsqueeze(0), [B, T])
    t_idx_BT1 = t_idx_BT1.unsqueeze(-1)

    t_idx_BxTxC = torch.minimum(
        t_idx_BT1 + delay_arr.view(1, 1, C),
        torch.tensor(T - 1, device=device),
    )
    b_idx_BxTxC = torch.broadcast_to(torch.arange(B, device=device).view(B, 1, 1), [B, T, C])
    c_idx_BxTxC = torch.broadcast_to(torch.arange(C, device=device).view(1, 1, C), [B, T, C])

    indices_BTCx3 = torch.stack(
        [
            b_idx_BxTxC.reshape(-1),
            t_idx_BxTxC.reshape(-1),
            c_idx_BxTxC.reshape(-1),
        ],
        dim=1,
    ).long()  # Ensure indices are long type

    return t_idx_BxTxC, indices_BTCx3


def revert_audio_delay(
    audio_BxTxC: torch.Tensor,
    pad_value: int,
    precomp: tp.Tuple[torch.Tensor, torch.Tensor],
    T: int,
) -> torch.Tensor:
    """
    Reverts a delay pattern from batched audio tokens using precomputed indices (PyTorch version).

    Args:
        audio_BxTxC: Input delayed audio tensor
        pad_value: Padding value for out-of-bounds indices
        precomp: Precomputed revert indices tuple containing:
            - t_idx_BxTxC: Time offset indices tensor
            - indices_BTCx3: Gather indices tensor for original audio
        T: Original sequence length before padding

    Returns:
        Reverted audio tensor with same shape as input
    """
    t_idx_BxTxC, indices_BTCx3 = precomp
    device = audio_BxTxC.device  # Get device from input tensor

    # Move precomputed indices to the same device as audio_BxTxC if they aren't already
    t_idx_BxTxC = t_idx_BxTxC.to(device)
    indices_BTCx3 = indices_BTCx3.to(device)

    # Using PyTorch advanced indexing (equivalent to tf.gather_nd or np equivalent)
    gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
    gathered_BxTxC = gathered_flat.view(audio_BxTxC.size())  # Use .size() for robust reshaping

    # Create pad_tensor on the correct device
    pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)
    # Create T tensor on the correct device for comparison
    T_tensor = torch.tensor(T, device=device)

    result_BxTxC = torch.where(t_idx_BxTxC >= T_tensor, pad_tensor, gathered_BxTxC)

    return result_BxTxC


@torch.no_grad()
@torch.inference_mode()
def decode(
    model,
    audio_codes,
):
    """
    Decodes the given frames into an output audio waveform.
    """
    if len(audio_codes) != 1:
        raise ValueError(f"Expected one frame, got {len(audio_codes)}")

    try:
        audio_values = model.quantizer.from_codes(audio_codes)
        audio_values = model.decode(audio_values[0])

        return audio_values
    except Exception as e:
        print(f"Error in decode method: {str(e)}")
        raise


def codebook_to_audio(generated_codes: torch.Tensor, model, delay_pattern, B=1, T=2600, C=9):
    """Process a single codebook file to generate audio."""
    # Remove BOS token
    generated_codes = generated_codes[:, 1:]

    if generated_codes.shape[1] > T:
        generated_codes = generated_codes[:, :T]

    seq_length = generated_codes.shape[1]

    # Build revert indices
    t_idx_BxTxC, indices_BTCx3 = build_revert_indices(B=B, T=seq_length, C=C, delay_pattern=delay_pattern)

    # Transpose and add batch dimension
    audio_BxTxC = generated_codes.transpose(1, 0).unsqueeze(0)
    reverted_codebook = revert_audio_delay(
        audio_BxTxC=audio_BxTxC,
        pad_value=0,
        precomp=(t_idx_BxTxC, indices_BTCx3),
        T=seq_length,
    )
    reverted_codebook = reverted_codebook[:, :-30, :]

    codebook = reverted_codebook.transpose(1, 2)

    min_valid_index = 0
    max_valid_index = 1023
    invalid_mask = (codebook < min_valid_index) | (codebook > max_valid_index)

    num_invalid = torch.sum(invalid_mask).item()
    if num_invalid > 0:
        print(f"Warning: Clamping {num_invalid} indices outside range [{min_valid_index}, {max_valid_index}] to 0.")

    # Set invalid values to 0 (modify the tensor in-place)
    codebook[invalid_mask] = 0
    audio_array = decode(model, codebook)

    return audio_array
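
A toy sketch of the delay round-trip implemented above (shapes and delay pattern are deliberately tiny and illustrative; real use goes through audio_to_codebook/codebook_to_audio with a codec model):

import torch
from dia.audio import (apply_audio_delay, build_delay_indices,
                       build_revert_indices, revert_audio_delay)

B, T, C = 1, 6, 3
delays = [0, 1, 2]
codes = torch.arange(B * T * C, dtype=torch.int32).reshape(B, T, C)

delayed = apply_audio_delay(codes, pad_value=1025, bos_value=1026,
                            precomp=build_delay_indices(B, T, C, delays))
# Channel c is shifted right by delays[c] steps; the gap is filled with BOS (1026).

restored = revert_audio_delay(delayed, pad_value=1025,
                              precomp=build_revert_indices(B, T, C, delays), T=T)
# Apart from trailing positions clamped at T-1, restored matches codes.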
dia/config.py
ADDED
@@ -0,0 +1,206 @@
"""Configuration management module for the Dia model.

This module provides comprehensive configuration management for the Dia model,
utilizing Pydantic for validation. It defines configurations for data processing,
model architecture (encoder and decoder), and training settings.

Key components:
- DataConfig: Parameters for data loading and preprocessing.
- EncoderConfig: Architecture details for the encoder module.
- DecoderConfig: Architecture details for the decoder module.
- ModelConfig: Combined model architecture settings.
- TrainingConfig: Training hyperparameters and settings.
- DiaConfig: Master configuration combining all components.
"""

import os
from typing import Annotated

from pydantic import BaseModel, BeforeValidator, Field


class DataConfig(BaseModel, frozen=True):
    """Configuration for data loading and preprocessing.

    Attributes:
        text_length: Maximum length of text sequences (must be multiple of 128).
        audio_length: Maximum length of audio sequences (must be multiple of 128).
        channels: Number of audio channels.
        text_pad_value: Value used for padding text sequences.
        audio_eos_value: Value representing the end of audio sequences.
        audio_bos_value: Value representing the beginning of audio sequences.
        audio_pad_value: Value used for padding audio sequences.
        delay_pattern: List of delay values for each audio channel.
    """

    text_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = Field(gt=0, multiple_of=128)
    audio_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = Field(gt=0, multiple_of=128)
    channels: int = Field(default=9, gt=0, multiple_of=1)
    text_pad_value: int = Field(default=0)
    audio_eos_value: int = Field(default=1024)
    audio_pad_value: int = Field(default=1025)
    audio_bos_value: int = Field(default=1026)
    delay_pattern: list[Annotated[int, Field(ge=0)]] = Field(default_factory=lambda: [0, 8, 9, 10, 11, 12, 13, 14, 15])

    def __hash__(self) -> int:
        """Generate a hash based on all fields of the config."""
        return hash(
            (
                self.text_length,
                self.audio_length,
                self.channels,
                self.text_pad_value,
                self.audio_pad_value,
                self.audio_bos_value,
                self.audio_eos_value,
                tuple(self.delay_pattern),
            )
        )


class EncoderConfig(BaseModel, frozen=True):
    """Configuration for the encoder component of the Dia model.

    Attributes:
        n_layer: Number of transformer layers.
        n_embd: Embedding dimension.
        n_hidden: Hidden dimension size in the MLP layers.
        n_head: Number of attention heads.
        head_dim: Dimension per attention head.
        mlp_activations: List of activation functions for the MLP layers.
        use_pre_norm: Whether to use pre-normalization (LayerNorm before attention/MLP).
    """

    n_layer: int = Field(gt=0)
    n_embd: int = Field(gt=0)
    n_hidden: int = Field(gt=0)
    n_head: int = Field(gt=0)
    head_dim: int = Field(gt=0)
    mlp_activations: list[str] = Field(default=["silu", "linear"])
    use_pre_norm: bool = Field(default=False)


class DecoderConfig(BaseModel, frozen=True):
    """Configuration for the decoder component of the Dia model.

    Attributes:
        n_layer: Number of transformer layers.
        n_embd: Embedding dimension.
        n_hidden: Hidden dimension size in the MLP layers.
        gqa_query_heads: Number of query heads for grouped-query self-attention.
        kv_heads: Number of key/value heads for grouped-query self-attention.
        gqa_head_dim: Dimension per query head for grouped-query self-attention.
        cross_query_heads: Number of query heads for cross-attention.
        cross_head_dim: Dimension per cross-attention head.
        mlp_activations: List of activation functions for the MLP layers.
        use_pre_norm: Whether to use pre-normalization.
    """

    n_layer: int = Field(gt=0)
    n_embd: int = Field(gt=0)
    n_hidden: int = Field(gt=0)
    gqa_query_heads: int = Field(gt=0)
    kv_heads: int = Field(gt=0)
    gqa_head_dim: int = Field(gt=0)
    cross_query_heads: int = Field(gt=0)
    cross_head_dim: int = Field(gt=0)
    mlp_activations: list[str] = Field(default=["silu", "linear"])
    use_pre_norm: bool = Field(default=False)


class ModelConfig(BaseModel, frozen=True):
    """Main configuration container for the Dia model architecture.

    Attributes:
        encoder: Configuration for the encoder component.
        decoder: Configuration for the decoder component.
        src_vocab_size: Size of the source (text) vocabulary.
        tgt_vocab_size: Size of the target (audio code) vocabulary.
        dropout: Dropout probability applied within the model.
        normalization_layer_epsilon: Epsilon value for normalization layers (e.g., LayerNorm).
        weight_dtype: Data type for model weights (e.g., "float32", "bfloat16").
        rope_min_timescale: Minimum timescale for Rotary Positional Embeddings (RoPE).
        rope_max_timescale: Maximum timescale for Rotary Positional Embeddings (RoPE).
    """

    encoder: EncoderConfig
    decoder: DecoderConfig
    src_vocab_size: int = Field(default=128, gt=0)
    tgt_vocab_size: int = Field(default=1028, gt=0)
    dropout: float = Field(default=0.0, ge=0.0, lt=1.0)
    normalization_layer_epsilon: float = Field(default=1.0e-5, ge=0.0)
    weight_dtype: str = Field(default="float32", description="Weight precision")
    rope_min_timescale: int = Field(default=1, description="Timescale for global attention")
    rope_max_timescale: int = Field(default=10_000, description="Timescale for global attention")


class TrainingConfig(BaseModel, frozen=True):
    """Training process configuration and hyperparameters.

    Note: This configuration currently only includes precision settings.
    Other training parameters (like batch size, learning rate, optimizer settings)
    are assumed to be handled externally.

    Attributes:
        dtype: Data type for activations during training (e.g., "bfloat16", "float32").
        logits_dot_in_fp32: Whether to compute the final logits dot product in fp32 for stability.
    """

    dtype: str = Field(default="bfloat16", description="Activation precision")
    logits_dot_in_fp32: bool = Field(default=False)


class DiaConfig(BaseModel, frozen=True):
    """Master configuration for the Dia model.

    Combines all sub-configurations into a single validated object.

    Attributes:
        version: Configuration version string.
        model: Model architecture configuration.
        training: Training process configuration (precision settings).
        data: Data loading and processing configuration.
    """

    version: str = Field(default="1.0")
    model: ModelConfig
    training: TrainingConfig
    data: DataConfig

    def save(self, path: str) -> None:
        """Save the current configuration instance to a JSON file.

        Ensures the parent directory exists and the file has a .json extension.

        Args:
            path: The target file path to save the configuration.

        Raises:
            ValueError: If the path is not a file with a .json extension.
        """
        os.makedirs(os.path.dirname(path), exist_ok=True)
        config_json = self.model_dump_json(indent=2)
        with open(path, "w") as f:
            f.write(config_json)

    @classmethod
    def load(cls, path: str) -> "DiaConfig | None":
        """Load and validate a Dia configuration from a JSON file.

        Args:
            path: The path to the configuration file.

        Returns:
            A validated DiaConfig instance if the file exists and is valid,
            otherwise None if the file is not found.

        Raises:
            ValueError: If the path does not point to an existing .json file.
            pydantic.ValidationError: If the JSON content fails validation against the DiaConfig schema.
        """
        try:
            with open(path, "r") as f:
                content = f.read()
            return cls.model_validate_json(content)
        except FileNotFoundError:
            return None
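
A minimal construction/round-trip sketch for the Pydantic configs above. The layer counts and dimensions are made-up toy values, not the real Dia-1.6B hyperparameters, and the output path is illustrative:

from dia.config import (DataConfig, DecoderConfig, DiaConfig,
                        EncoderConfig, ModelConfig, TrainingConfig)

cfg = DiaConfig(
    model=ModelConfig(
        encoder=EncoderConfig(n_layer=2, n_embd=256, n_hidden=1024,
                              n_head=4, head_dim=64),
        decoder=DecoderConfig(n_layer=2, n_embd=256, n_hidden=1024,
                              gqa_query_heads=4, kv_heads=2, gqa_head_dim=64,
                              cross_query_heads=4, cross_head_dim=64),
    ),
    training=TrainingConfig(),
    # BeforeValidator rounds these up to the next multiple of 128 (384 and 1024)
    data=DataConfig(text_length=300, audio_length=1000),
)
cfg.save("./model_cache/config_toy.json")
assert DiaConfig.load("./model_cache/config_toy.json") == cfg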
dia/layers.py
ADDED
@@ -0,0 +1,903 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch import Tensor
|
7 |
+
from torch.nn import RMSNorm
|
8 |
+
|
9 |
+
from .config import DiaConfig
|
10 |
+
|
11 |
+
|
12 |
+
def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]:
|
13 |
+
return tuple(ax if ax >= 0 else ndim + ax for ax in axes)
|
14 |
+
|
15 |
+
|
16 |
+
def _str_to_dtype(dtype_str: str) -> torch.dtype | None:
|
17 |
+
# Allow None for default behavior
|
18 |
+
if dtype_str is None or dtype_str.lower() == "none":
|
19 |
+
return None
|
20 |
+
if dtype_str == "float32":
|
21 |
+
return torch.float32
|
22 |
+
elif dtype_str == "float16":
|
23 |
+
return torch.float16
|
24 |
+
elif dtype_str == "bfloat16":
|
25 |
+
return torch.bfloat16
|
26 |
+
else:
|
27 |
+
raise ValueError(f"Unsupported dtype string: {dtype_str}")
|
28 |
+
|
29 |
+
|
30 |
+
class DenseGeneral(nn.Module):
|
31 |
+
"""
|
32 |
+
PyTorch equivalent of flax.linen.DenseGeneral with shapes defined at init.
|
33 |
+
|
34 |
+
Stores weights (`kernel`) in the same layout as Jax and uses torch.tensordot
|
35 |
+
for the generalized matrix multiplication. Weight/bias shapes are calculated
|
36 |
+
and parameters created during initialization based on config.
|
37 |
+
`load_weights` validates shapes and copies data.
|
38 |
+
|
39 |
+
Attributes:
|
40 |
+
axis (Tuple[int, ...]): Input axis or axes to contract.
|
41 |
+
in_shapes (Tuple[int, ...]): Sizes of the input dimensions specified by `axis`.
|
42 |
+
out_features (Tuple[int, ...]): Shape of the output features (non-contracted dims).
|
43 |
+
use_bias (bool): Whether to add a bias term.
|
44 |
+
weight (nn.Parameter): The kernel parameter.
|
45 |
+
bias (Optional[nn.Parameter]): The bias parameter (if use_bias=True).
|
46 |
+
"""
|
47 |
+
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
in_shapes: tuple[int, ...],
|
51 |
+
out_features: tuple[int, ...],
|
52 |
+
axis: tuple[int, ...] = (-1,),
|
53 |
+
dtype: torch.dtype | None = None,
|
54 |
+
weight_dtype: torch.dtype | None = None,
|
55 |
+
device: torch.device | None = None,
|
56 |
+
):
|
57 |
+
super().__init__()
|
58 |
+
self.in_shapes = in_shapes
|
59 |
+
self.out_features = out_features
|
60 |
+
self.axis = axis
|
61 |
+
self.dtype = dtype
|
62 |
+
self.kernel_shape = self.in_shapes + self.out_features
|
63 |
+
|
64 |
+
factory_kwargs = {"device": device, "dtype": weight_dtype}
|
65 |
+
self.weight = nn.Parameter(torch.empty(self.kernel_shape, **factory_kwargs))
|
66 |
+
self.register_parameter("bias", None)
|
67 |
+
|
68 |
+
def forward(self, inputs: Tensor) -> Tensor:
|
69 |
+
norm_axis = _normalize_axes(self.axis, inputs.ndim)
|
70 |
+
kernel_contract_axes = tuple(range(len(norm_axis)))
|
71 |
+
|
72 |
+
output = torch.tensordot(
|
73 |
+
inputs.float(),
|
74 |
+
self.weight.float(),
|
75 |
+
dims=(norm_axis, kernel_contract_axes),
|
76 |
+
).to(inputs.dtype)
|
77 |
+
return output
|
78 |
+
|
79 |
+
|
80 |
+
def get_activation_fn(activation_string: str) -> nn.Module: # Return Module instance
|
81 |
+
"""Maps activation string to PyTorch activation function module."""
|
82 |
+
if activation_string == "gelu":
|
83 |
+
return nn.GELU()
|
84 |
+
elif activation_string == "relu":
|
85 |
+
return nn.ReLU()
|
86 |
+
elif activation_string == "silu" or activation_string == "swish":
|
87 |
+
return nn.SiLU()
|
88 |
+
elif activation_string == "linear":
|
89 |
+
return nn.Identity()
|
90 |
+
else:
|
91 |
+
raise ValueError(f"Unsupported activation function: {activation_string}")
|
92 |
+
|
93 |
+
|
94 |
+
class MlpBlock(nn.Module):
|
95 |
+
"""MLP block using DenseGeneral."""
|
96 |
+
|
97 |
+
def __init__(
|
98 |
+
self,
|
99 |
+
config: DiaConfig,
|
100 |
+
embed_dim: int,
|
101 |
+
intermediate_dim: int,
|
102 |
+
dropout_rate: float,
|
103 |
+
activations: list[str] = ["silu", "linear"],
|
104 |
+
use_pre_norm: bool = False,
|
105 |
+
):
|
106 |
+
super().__init__()
|
107 |
+
self.use_pre_norm = use_pre_norm
|
108 |
+
num_activations = len(activations)
|
109 |
+
compute_dtype = _str_to_dtype(config.training.dtype)
|
110 |
+
weight_dtype = _str_to_dtype(config.model.weight_dtype)
|
111 |
+
self.dtype = compute_dtype
|
112 |
+
# Assume default device for now, could be passed in config
|
113 |
+
|
114 |
+
if use_pre_norm:
|
115 |
+
self.pre_norm = RMSNorm(
|
116 |
+
embed_dim,
|
117 |
+
eps=config.model.normalization_layer_epsilon,
|
118 |
+
dtype=torch.float32,
|
119 |
+
)
|
120 |
+
|
121 |
+
self.wi_fused = DenseGeneral(
|
122 |
+
in_shapes=(embed_dim,),
|
123 |
+
out_features=(
|
124 |
+
num_activations,
|
125 |
+
intermediate_dim,
|
126 |
+
),
|
127 |
+
axis=(-1,),
|
128 |
+
dtype=compute_dtype,
|
129 |
+
weight_dtype=weight_dtype,
|
130 |
+
)
|
131 |
+
|
132 |
+
self.activation_fn_0 = get_activation_fn(activations[0]) # silu
|
133 |
+
self.activation_fn_1 = get_activation_fn(activations[1]) # linear
|
134 |
+
|
135 |
+
self.dropout = nn.Dropout(dropout_rate)
|
136 |
+
|
137 |
+
# Output layer using DenseGeneral
|
138 |
+
self.wo = DenseGeneral(
|
139 |
+
in_shapes=(intermediate_dim,),
|
140 |
+
out_features=(embed_dim,),
|
141 |
+
axis=(-1,),
|
142 |
+
dtype=compute_dtype,
|
143 |
+
weight_dtype=weight_dtype,
|
144 |
+
)
|
145 |
+
|
146 |
+
def forward(self, x: torch.Tensor, deterministic: bool) -> torch.Tensor:
|
147 |
+
"""Forward pass."""
|
148 |
+
if self.use_pre_norm and hasattr(self, "pre_norm"):
|
149 |
+
x = self.pre_norm(x)
|
150 |
+
|
151 |
+
fused_x = self.wi_fused(x)
|
152 |
+
|
153 |
+
gate_input = fused_x[..., 0, :]
|
154 |
+
up_input = fused_x[..., 1, :]
|
155 |
+
|
156 |
+
gate = self.activation_fn_0(gate_input)
|
157 |
+
up = self.activation_fn_1(up_input)
|
158 |
+
hidden = torch.mul(gate, up).to(self.dtype)
|
159 |
+
|
160 |
+
if not deterministic:
|
161 |
+
hidden = self.dropout(hidden)
|
162 |
+
|
163 |
+
output = self.wo(hidden)
|
164 |
+
return output
|
165 |
+
|
166 |
+
|
167 |
+
class RotaryEmbedding(nn.Module):
|
168 |
+
"""Rotary Position Embedding (RoPE) implementation in PyTorch."""
|
169 |
+
|
170 |
+
def __init__(
|
171 |
+
self,
|
172 |
+
embedding_dims: int,
|
173 |
+
min_timescale: int = 1,
|
174 |
+
max_timescale: int = 10000,
|
175 |
+
dtype: torch.dtype = torch.float32,
|
176 |
+
):
|
177 |
+
super().__init__()
|
178 |
+
if embedding_dims % 2 != 0:
|
179 |
+
raise ValueError("Embedding dim must be even for RoPE.")
|
180 |
+
self.embedding_dims = embedding_dims
|
181 |
+
self.min_timescale = min_timescale
|
182 |
+
self.max_timescale = max_timescale
|
183 |
+
self.dtype = dtype
|
184 |
+
|
185 |
+
half_embedding_dim = embedding_dims // 2
|
186 |
+
fraction = (2.0 * torch.arange(0, half_embedding_dim)) / embedding_dims
|
187 |
+
self.register_buffer(
|
188 |
+
"timescale",
|
189 |
+
self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction,
|
190 |
+
persistent=False,
|
191 |
+
)
|
192 |
+
|
193 |
+
def extra_repr(self) -> str:
|
194 |
+
s = f"{self.timescale.shape}"
|
195 |
+
return s
|
196 |
+
|
197 |
+
def forward(self, inputs: torch.Tensor, position: torch.Tensor):
|
198 |
+
"""Applies RoPE."""
|
199 |
+
position = position.unsqueeze(-1).unsqueeze(-1)
|
200 |
+
timescale = self.timescale.to(inputs.device)
|
201 |
+
sinusoid_inp = position / timescale
|
202 |
+
sin = torch.sin(sinusoid_inp).to(inputs.dtype)
|
203 |
+
cos = torch.cos(sinusoid_inp).to(inputs.dtype)
|
204 |
+
first_half, second_half = torch.chunk(inputs, 2, dim=-1)
|
205 |
+
first_part = first_half * cos - second_half * sin
|
206 |
+
second_part = second_half * cos + first_half * sin
|
207 |
+
return torch.cat((first_part, second_part), dim=-1)
|
208 |
+
|
209 |
+
|
210 |
+
class KVCache:
|
211 |
+
def __init__(self, num_heads, max_len, head_dim, device, k=None, v=None):
|
212 |
+
self.k = (
|
213 |
+
torch.zeros((2, num_heads, max_len, head_dim), device=device)
|
214 |
+
if k is None
|
215 |
+
else k
|
216 |
+
)
|
217 |
+
self.v = (
|
218 |
+
torch.zeros((2, num_heads, max_len, head_dim), device=device)
|
219 |
+
if v is None
|
220 |
+
else v
|
221 |
+
)
|
222 |
+
self.current_idx = 0
|
223 |
+
self.max_len = max_len
|
224 |
+
|
225 |
+
def get_kv_for_attention(self, current_k, current_v):
|
226 |
+
if self.current_idx == 0:
|
227 |
+
return current_k, current_v
|
228 |
+
else:
|
229 |
+
past_k = self.k[:, :, : self.current_idx, :]
|
230 |
+
past_v = self.v[:, :, : self.current_idx, :]
|
231 |
+
attn_k = torch.cat((past_k, current_k), dim=2)
|
232 |
+
attn_v = torch.cat((past_v, current_v), dim=2)
|
233 |
+
return attn_k, attn_v
|
234 |
+
|
235 |
+
def update_cache(self, k, v):
|
236 |
+
assert self.current_idx < self.max_len
|
237 |
+
self.k[:, :, self.current_idx : self.current_idx + 1, :] = k
|
238 |
+
self.v[:, :, self.current_idx : self.current_idx + 1, :] = v
|
239 |
+
self.current_idx += 1
|
240 |
+
|
241 |
+
def prefill_kv(self, k, v):
|
242 |
+
prefill_len = k.shape[2]
|
243 |
+
assert prefill_len <= self.max_len
|
244 |
+
self.k[:, :, :prefill_len, :] = k
|
245 |
+
self.v[:, :, :prefill_len, :] = v
|
246 |
+
self.current_idx = prefill_len
|
247 |
+
|
248 |
+
|
249 |
+
class Attention(nn.Module):
|
250 |
+
"""Attention using DenseGeneral."""
|
251 |
+
|
252 |
+
def __init__(
|
253 |
+
self,
|
254 |
+
config: DiaConfig,
|
255 |
+
q_embed_dim: int,
|
256 |
+
kv_embed_dim: int,
|
257 |
+
num_query_heads: int,
|
258 |
+
num_kv_heads: int,
|
259 |
+
head_dim: int,
|
260 |
+
dropout_rate: float,
|
261 |
+
is_cross_attn: bool = False,
|
262 |
+
out_embed_dim: int | None = None,
|
263 |
+
):
|
264 |
+
super().__init__()
|
265 |
+
self.num_query_heads = num_query_heads
|
266 |
+
self.num_kv_heads = num_kv_heads
|
267 |
+
self.head_dim = head_dim
|
268 |
+
self.is_cross_attn = is_cross_attn
|
269 |
+
self.dropout_rate = dropout_rate
|
270 |
+
compute_dtype = _str_to_dtype(config.training.dtype)
|
271 |
+
weight_dtype = _str_to_dtype(config.model.weight_dtype)
|
272 |
+
self.output_dim = out_embed_dim if out_embed_dim is not None else q_embed_dim
|
273 |
+
self.projected_query_dim = num_query_heads * head_dim
|
274 |
+
if num_query_heads % num_kv_heads != 0:
|
275 |
+
raise ValueError(
|
276 |
+
f"num_query_heads ({num_query_heads}) must be divisible by num_kv_heads ({num_kv_heads})"
|
277 |
+
)
|
278 |
+
self.num_gqa_groups = num_query_heads // num_kv_heads
|
279 |
+
|
280 |
+
# --- Projection Layers using DenseGeneral ---
|
281 |
+
self.q_proj = DenseGeneral(
|
282 |
+
in_shapes=(q_embed_dim,),
|
283 |
+
            out_features=(num_query_heads, head_dim),
            axis=(-1,),
            dtype=compute_dtype,
            weight_dtype=weight_dtype,
        )
        self.k_proj = DenseGeneral(
            in_shapes=(kv_embed_dim,),
            out_features=(num_kv_heads, head_dim),
            axis=(-1,),
            dtype=compute_dtype,
            weight_dtype=weight_dtype,
        )
        self.v_proj = DenseGeneral(
            in_shapes=(kv_embed_dim,),
            out_features=(num_kv_heads, head_dim),
            axis=(-1,),
            dtype=compute_dtype,
            weight_dtype=weight_dtype,
        )
        self.o_proj = DenseGeneral(
            in_shapes=(num_query_heads, head_dim),
            out_features=(self.output_dim,),
            axis=(-2, -1),
            dtype=compute_dtype,
            weight_dtype=weight_dtype,
        )

        # --- Rotary Embedding ---
        self.rotary_emb = RotaryEmbedding(
            embedding_dims=self.head_dim,
            min_timescale=config.model.rope_min_timescale,
            max_timescale=config.model.rope_max_timescale,
            dtype=compute_dtype,
        )

    def forward(
        self,
        Xq: torch.Tensor,  # (B, T, D) T = 1 in AR generation
        Xkv: torch.Tensor,  # (B, S, E) S = 1 in AR generation
        q_positions: torch.Tensor,  # (B, T)
        kv_positions: torch.Tensor | None = None,  # (B, S)
        deterministic: bool = True,
        attn_mask: (
            torch.Tensor | None
        ) = None,  # None in Decoder Self Attention, Valid mask in Others
        cache: KVCache | None = None,  # None in Encoder, KVCache in Decoder
        prefill: bool = False,  # True only when prefilling KV Cache
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
        """
        Performs attention calculation with optional KV caching.

        Args:
            Xq: Query tensor (B, T, D). T=1 during single-step decoding.
            Xkv: Key/Value source tensor (B, S, E). S=1 during single-step decoding for self-attn.
            q_positions: Positions for queries (B, T).
            kv_positions: Positions for keys/values (B, S). If None, uses q_positions.
            deterministic: If True, disable dropout.
            attn_mask: Attention mask.
            cache: KVCache.
            prefill: If True, use prefill mode.

        Returns:
            A tuple containing:
            - output: The attention output tensor (B, T, output_dim).
            - present_kv: The K/V state to be cached for the next step
              ((B, N, S_new, H), (B, N, S_new, H)). For self-attn, S_new = S_past + S.
              For cross-attn, S_new = S_kv.
        """
        if kv_positions is None:
            kv_positions = q_positions
        original_dtype = Xq.dtype

        Xq_BxTxNxH = self.q_proj(Xq)
        Xq_BxTxNxH = self.rotary_emb(Xq_BxTxNxH, position=q_positions)
        Xq_BxNxTxH = Xq_BxTxNxH.transpose(1, 2)

        # Input values into attention calculation
        attn_k: torch.Tensor | None = None
        attn_v: torch.Tensor | None = None
        new_kv_cache: tuple[torch.Tensor, torch.Tensor] | None = None

        # Decoder Cross Attention
        if self.is_cross_attn:
            # Directly use cache (no need to check index)
            attn_k, attn_v = cache.k, cache.v
            if (
                attn_k.shape[1] != self.num_query_heads
                or attn_v.shape[1] != self.num_query_heads
            ):
                raise ValueError(
                    f"Cross-attention cache head dimension ({attn_k.shape[1]}) "
                    f"does not match num_query_heads ({self.num_query_heads}). "
                    "Cache should be pre-repeated for GQA."
                )
        # Self Attention
        else:
            Xk_BxSxKxH = self.k_proj(Xkv)  # (B, S, K, H)
            Xv_BxSxKxH = self.v_proj(Xkv)  # (B, S, K, H)
            Xk_BxSxKxH = self.rotary_emb(
                Xk_BxSxKxH, position=kv_positions
            )  # (B, S, K, H)

            Xk_BxKxSxH = Xk_BxSxKxH.transpose(1, 2)  # (B, K, S, H)
            Xv_BxKxSxH = Xv_BxSxKxH.transpose(1, 2)  # (B, K, S, H)
            # S=1 for Decode Step

            if self.num_gqa_groups > 1:
                Xk_BxNxSxH = Xk_BxKxSxH.repeat_interleave(self.num_gqa_groups, dim=1)
                Xv_BxNxSxH = Xv_BxKxSxH.repeat_interleave(self.num_gqa_groups, dim=1)
            else:
                Xk_BxNxSxH = Xk_BxKxSxH
                Xv_BxNxSxH = Xv_BxKxSxH

            # Encoder Self Attention
            if cache is None:
                attn_k = Xk_BxNxSxH
                attn_v = Xv_BxNxSxH
            # Decoder Self Attention
            else:
                # In prefill mode, we fill in cache until prefill length
                if prefill:
                    attn_k, attn_v = Xk_BxNxSxH, Xv_BxNxSxH
                    cache.prefill_kv(attn_k, attn_v)
                # In decode step, we add current K/V to cache step by step
                else:
                    new_kv_cache = Xk_BxNxSxH, Xv_BxNxSxH
                    attn_k, attn_v = cache.get_kv_for_attention(Xk_BxNxSxH, Xv_BxNxSxH)

        # Add the dtype conversion here - after both cross-attention and self-attention paths
        if attn_k is not None and attn_v is not None:
            attn_k = attn_k.to(Xq_BxNxTxH.dtype)
            attn_v = attn_v.to(Xq_BxNxTxH.dtype)

        attn_output = F.scaled_dot_product_attention(
            Xq_BxNxTxH,
            attn_k,
            attn_v,
            attn_mask=attn_mask,
            dropout_p=self.dropout_rate if not deterministic else 0.0,
            scale=1.0,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()  # (B, T, N, H)
        output = self.o_proj(attn_output)

        return output.to(original_dtype), new_kv_cache
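
# Illustrative sketch (not part of the original file): grouped-query attention (GQA)
# repeats each K/V head num_gqa_groups times so K/V shapes line up with the query
# heads before scaled_dot_product_attention:
#   k = torch.zeros(2, 4, 10, 64)           # (B, K=4 kv heads, S, H)
#   k_rep = k.repeat_interleave(4, dim=1)   # num_gqa_groups = N // K = 4
#   k_rep.shape                             # torch.Size([2, 16, 10, 64]) -> N=16 query heads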


class EncoderLayer(nn.Module):
    """Transformer Encoder Layer using DenseGeneral."""

    def __init__(self, config: DiaConfig):
        super().__init__()
        self.config = config
        model_config = config.model
        enc_config = config.model.encoder
        embed_dim = enc_config.n_embd

        self.pre_sa_norm = RMSNorm(
            embed_dim,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )
        self.self_attention = Attention(
            config=config,
            q_embed_dim=embed_dim,
            kv_embed_dim=embed_dim,
            num_query_heads=enc_config.n_head,
            num_kv_heads=enc_config.n_head,
            head_dim=enc_config.head_dim,
            dropout_rate=model_config.dropout,
            is_cross_attn=False,
            out_embed_dim=embed_dim,
        )
        self.post_sa_norm = RMSNorm(
            embed_dim,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )
        self.mlp = MlpBlock(
            config=config,
            embed_dim=embed_dim,
            intermediate_dim=enc_config.n_hidden,
            activations=enc_config.mlp_activations,
            dropout_rate=model_config.dropout,
            use_pre_norm=enc_config.use_pre_norm,
        )
        self.dropout = nn.Dropout(model_config.dropout)

    def forward(
        self,
        x: torch.Tensor,
        src_positions: torch.Tensor | None = None,
        deterministic: bool = True,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        residual = x
        x_norm = self.pre_sa_norm(x)

        sa_out, _ = self.self_attention(
            Xq=x_norm,
            Xkv=x_norm,
            q_positions=src_positions,
            kv_positions=src_positions,
            deterministic=deterministic,
            attn_mask=attn_mask,
        )
        x = residual + sa_out

        residual = x
        x_norm = self.post_sa_norm(x)
        mlp_out = self.mlp(x_norm, deterministic=deterministic)
        x = residual + mlp_out

        if not deterministic:
            x = self.dropout(x)
        return x


class Encoder(nn.Module):
    """Transformer Encoder Stack using DenseGeneral."""

    def __init__(self, config: DiaConfig):
        super().__init__()
        self.config = config
        model_config = config.model
        enc_config = config.model.encoder
        compute_dtype = _str_to_dtype(config.training.dtype)

        self.embedding = nn.Embedding(
            model_config.src_vocab_size,
            enc_config.n_embd,
            dtype=compute_dtype,
        )
        self.dropout = nn.Dropout(model_config.dropout)
        self.layers = nn.ModuleList(
            [EncoderLayer(config=config) for _ in range(enc_config.n_layer)]
        )
        self.norm = RMSNorm(
            enc_config.n_embd,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )

    def forward(
        self,
        x_ids: torch.Tensor,
        src_positions: torch.Tensor | None = None,
        deterministic: bool = True,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        x = self.embedding(x_ids)

        if not deterministic:
            x = self.dropout(x)

        for layer in self.layers:
            x = layer(
                x,
                src_positions=src_positions,
                deterministic=deterministic,
                attn_mask=attn_mask,
            )
        x = self.norm(x)
        if not deterministic:
            x = self.dropout(x)
        return x


class DecoderLayer(nn.Module):
    """Transformer Decoder Layer using DenseGeneral."""

    def __init__(self, config: DiaConfig):
        super().__init__()
        self.config = config
        model_config = config.model
        dec_config = config.model.decoder
        enc_config = config.model.encoder
        dec_embed_dim = dec_config.n_embd
        enc_embed_dim = enc_config.n_embd

        # Norms
        self.pre_sa_norm = RMSNorm(
            dec_embed_dim,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )
        self.pre_ca_norm = RMSNorm(
            dec_embed_dim,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )
        self.pre_mlp_norm = RMSNorm(
            dec_embed_dim,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )

        # Self-Attention (GQA) with Causal Masking
        self.self_attention = Attention(
            config=config,
            q_embed_dim=dec_embed_dim,
            kv_embed_dim=dec_embed_dim,
            num_query_heads=dec_config.gqa_query_heads,
            num_kv_heads=dec_config.kv_heads,
            head_dim=dec_config.gqa_head_dim,
            dropout_rate=model_config.dropout,
            is_cross_attn=False,
            out_embed_dim=dec_embed_dim,
        )
        # Cross-Attention (MHA)
        self.cross_attention = Attention(
            config=config,
            q_embed_dim=dec_embed_dim,
            kv_embed_dim=enc_embed_dim,  # Note kv_embed_dim
            num_query_heads=dec_config.cross_query_heads,
            num_kv_heads=dec_config.cross_query_heads,
            head_dim=dec_config.cross_head_dim,
            dropout_rate=model_config.dropout,
            is_cross_attn=True,
            out_embed_dim=dec_embed_dim,
        )
        # MLP
        self.mlp = MlpBlock(
            config=config,
            embed_dim=dec_embed_dim,
            intermediate_dim=dec_config.n_hidden,
            activations=dec_config.mlp_activations,
            dropout_rate=model_config.dropout,
            use_pre_norm=dec_config.use_pre_norm,
        )

    def forward(
        self,
        x: torch.Tensor,
        encoder_out: torch.Tensor,
        tgt_positions: torch.Tensor,
        src_positions: torch.Tensor | None,
        deterministic: bool,
        self_attn_mask: torch.Tensor,
        cross_attn_mask: torch.Tensor,
        self_attn_cache: KVCache,
        cross_attn_cache: KVCache,
        prefill: bool = False,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
        # 1. Self-Attention
        residual = x
        x_norm = self.pre_sa_norm(x)

        sa_out, new_kv_cache = self.self_attention(
            Xq=x_norm,  # (2, 1, D)
            Xkv=x_norm,  # (2, 1, D)
            q_positions=tgt_positions,  # (2, 1)
            kv_positions=tgt_positions,  # (2, 1)
            deterministic=deterministic,
            attn_mask=self_attn_mask,  # (2, 1, 1, S_max)
            cache=self_attn_cache,
            prefill=prefill,
        )

        x = residual + sa_out

        # 2. Cross-Attention
        residual = x
        x_norm = self.pre_ca_norm(x)
        ca_out, _ = self.cross_attention(
            Xq=x_norm,
            Xkv=encoder_out,
            q_positions=tgt_positions,
            kv_positions=src_positions,
            deterministic=deterministic,
            attn_mask=cross_attn_mask,
            cache=cross_attn_cache,
        )
        x = residual + ca_out

        # 3. MLP
        residual = x
        x_norm = self.pre_mlp_norm(x)
        mlp_out = self.mlp(x_norm, deterministic=deterministic)
        x = residual + mlp_out

        return x, new_kv_cache


class Decoder(nn.Module):
    """Transformer Decoder Stack using DenseGeneral."""

    def __init__(self, config: DiaConfig):
        super().__init__()
        self.config = config
        model_config = config.model
        dec_config = config.model.decoder
        train_config = config.training
        data_config = config.data
        compute_dtype = _str_to_dtype(config.training.dtype)
        weight_dtype = _str_to_dtype(config.model.weight_dtype)
        self.num_channels = data_config.channels
        self.num_layers = dec_config.n_layer

        self.embeddings = nn.ModuleList(
            [
                nn.Embedding(
                    model_config.tgt_vocab_size, dec_config.n_embd, dtype=compute_dtype
                )
                for _ in range(self.num_channels)
            ]
        )
        self.dropout = nn.Dropout(model_config.dropout)
        self.layers = nn.ModuleList(
            [DecoderLayer(config=config) for _ in range(self.num_layers)]
        )
        self.norm = RMSNorm(
            dec_config.n_embd,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )

        # Final Logits Projection using DenseGeneral
        self.logits_dense = DenseGeneral(
            in_shapes=(dec_config.n_embd,),
            out_features=(self.num_channels, model_config.tgt_vocab_size),
            axis=(-1,),
            dtype=(torch.float32 if train_config.logits_dot_in_fp32 else compute_dtype),
            weight_dtype=weight_dtype,
        )
        self.logits_in_fp32 = train_config.logits_dot_in_fp32

    def precompute_cross_attention_kv(
        self,
        max_len: int,
        encoder_out: torch.Tensor,  # (B, S, E)
        src_positions: torch.Tensor | None,  # (B, S)
    ) -> list[KVCache]:
        """
        Computes the Key and Value tensors for cross-attention for each layer from the encoder output.
        """
        per_layer_kv_cache: list[KVCache] = []

        for layer in self.layers:
            cross_attn_module = layer.cross_attention
            k_proj = cross_attn_module.k_proj(encoder_out)
            v_proj = cross_attn_module.v_proj(encoder_out)

            k_proj = cross_attn_module.rotary_emb(k_proj, position=src_positions)
            k = k_proj.transpose(1, 2)
            v = v_proj.transpose(1, 2)

            per_layer_kv_cache.append(
                KVCache(
                    cross_attn_module.num_kv_heads,
                    max_len,
                    cross_attn_module.head_dim,
                    k.device,
                    k=k,
                    v=v,
                )
            )

        return per_layer_kv_cache
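
    # Note (illustrative, not part of the original file): because encoder_out is fixed
    # for the whole generation, these cross-attention K/V tensors are computed exactly
    # once and then reused read-only at every decode step, so each step only pays for
    # the query projection on the cross-attention path.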

    def decode_step(
        self,
        tgt_ids_Bx1xC: torch.Tensor,  # [B, 1, C]
        tgt_pos_Bx1: torch.Tensor,  # [B, 1]
        encoder_out: torch.Tensor,  # [B, S, E]
        self_attn_mask: Any,  # None
        cross_attn_mask: torch.Tensor,  # [B, 1, 1, S]
        self_attention_cache: list[KVCache],
        cross_attention_cache: list[KVCache],
    ) -> tuple[torch.Tensor, list[tuple[torch.Tensor, torch.Tensor] | None]]:
        """
        Performs a single decoding step, managing KV caches layer by layer.

        Returns:
            A tuple containing:
            - logits_Bx1xCxV: The final output logits for the current step (B, 1, C, V), cast to float32.
            - new_cache: Per-layer (K, V) tensors for the current step, used by the caller
              to update the self-attention caches.
        """
        assert (
            self_attn_mask is None
        ), "Self-attention mask should be None, kept for pattern"

        x = None
        for i in range(self.num_channels):
            channel_tokens = tgt_ids_Bx1xC[..., i]
            channel_embed = self.embeddings[i](channel_tokens)
            x = channel_embed if x is None else x + channel_embed

        new_cache = []

        for i, layer in enumerate(self.layers):
            self_cache = self_attention_cache[i]
            cross_cache = cross_attention_cache[i]
            x, new_kv_cache = layer(
                x,  # (2, 1, D)
                encoder_out,  # (2, S, E)
                src_positions=None,  # CA KV is already computed
                tgt_positions=tgt_pos_Bx1,  # (2, 1)
                deterministic=True,
                self_attn_mask=None,
                cross_attn_mask=cross_attn_mask,
                self_attn_cache=self_cache,
                cross_attn_cache=cross_cache,
            )
            new_cache.append(new_kv_cache)

        x = self.norm(x)
        logits_Bx1xCxV = self.logits_dense(x)

        return logits_Bx1xCxV.to(torch.float32), new_cache
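
    # Illustrative sketch (hypothetical shapes, not part of the original file): the C
    # codebook channels are embedded separately and summed into one hidden state:
    #   tgt_ids_Bx1xC[..., i]      -> (B, 1) token ids for channel i
    #   self.embeddings[i](...)    -> (B, 1, D)
    #   sum over i = 0..C-1        -> (B, 1, D) single decoder input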

    def forward(
        self,
        tgt_ids_BxTxC: torch.Tensor,
        encoder_out: torch.Tensor,
        tgt_positions: torch.Tensor,
        src_positions: torch.Tensor,
        deterministic: bool,
        self_attn_mask: torch.Tensor,
        cross_attn_mask: torch.Tensor,
        self_attention_cache: list[KVCache],
        cross_attention_cache: list[KVCache],
    ) -> torch.Tensor:
        """
        Forward pass for the Decoder stack, managing KV caches.

        Args:
            tgt_ids_BxTxC: Target token IDs (B, T, C).
            encoder_out: Output from the encoder (B, S, E).
            tgt_positions: Positions for target sequence (B, T).
            src_positions: Positions for source sequence (B, S).
            deterministic: Disable dropout if True.
            self_attn_mask: Mask for self-attention.
            cross_attn_mask: Mask for cross-attention.
            self_attention_cache: List containing the self-attention KVCache for each
                layer; it is filled in place during this (prefill) pass.
                `len(self_attention_cache)` should equal `num_layers`.
            cross_attention_cache: List containing the pre-computed cross-attention
                KVCache (derived from `encoder_out`) for each layer.

        Returns:
            logits: The final output logits (B, T, C, V), cast to float32.
        """
        _, _, num_channels_in = tgt_ids_BxTxC.shape
        assert num_channels_in == self.num_channels, "Input channels mismatch"

        # Embeddings
        x = None
        for i in range(self.num_channels):
            channel_tokens = tgt_ids_BxTxC[..., i]
            channel_embed = self.embeddings[i](channel_tokens)
            x = channel_embed if x is None else x + channel_embed

        if not deterministic:
            x = self.dropout(x)

        for i, layer in enumerate(self.layers):
            x, _ = layer(
                x,
                encoder_out,
                tgt_positions=tgt_positions,
                src_positions=src_positions,
                deterministic=deterministic,
                self_attn_mask=self_attn_mask,
                cross_attn_mask=cross_attn_mask,
                self_attn_cache=self_attention_cache[i],
                cross_attn_cache=cross_attention_cache[i],
                prefill=True,
            )

        # Final Norm
        x = self.norm(x)
        logits_BxTxCxV = self.logits_dense(x)

        return logits_BxTxCxV.to(torch.float32)


class DiaModel(nn.Module):
    """PyTorch Dia Model using DenseGeneral."""

    def __init__(self, config: DiaConfig):
        super().__init__()
        self.config = config
        self.encoder = Encoder(config)
        self.decoder = Decoder(config)

    def forward(
        self,
        src_BxS: torch.Tensor,
        tgt_BxTxC: torch.Tensor,
        src_positions: torch.Tensor | None = None,
        tgt_positions: torch.Tensor | None = None,
        enc_self_attn_mask: torch.Tensor | None = None,
        dec_self_attn_mask: torch.Tensor | None = None,
        dec_cross_attn_mask: torch.Tensor | None = None,
        enable_dropout: bool = True,
    ):
        deterministic = not enable_dropout

        # --- Encoder Pass ---
        encoder_out = self.encoder(
            x_ids=src_BxS,
            src_positions=src_positions,
            deterministic=deterministic,
            attn_mask=enc_self_attn_mask,
        )

        # --- Decoder Pass ---
        # Decoder.forward requires per-layer KV caches and returns the logits tensor
        # directly, so build the caches here: the cross-attention cache is precomputed
        # from encoder_out, and empty self-attention caches are allocated for the full
        # target length (they are filled in place via the prefill path).
        max_len = tgt_BxTxC.shape[1]
        cross_attention_cache = self.decoder.precompute_cross_attention_kv(
            max_len, encoder_out, src_positions
        )
        self_attention_cache = [
            KVCache(
                self.config.model.decoder.gqa_query_heads,
                max_len,
                self.config.model.decoder.gqa_head_dim,
                tgt_BxTxC.device,
            )
            for _ in range(self.decoder.num_layers)
        ]
        logits = self.decoder(
            tgt_ids_BxTxC=tgt_BxTxC,
            encoder_out=encoder_out,
            tgt_positions=tgt_positions,
            src_positions=src_positions,
            deterministic=deterministic,
            self_attn_mask=dec_self_attn_mask,
            cross_attn_mask=dec_cross_attn_mask,
            self_attention_cache=self_attention_cache,
            cross_attention_cache=cross_attention_cache,
        )

        return logits
dia/model.py
ADDED
@@ -0,0 +1,956 @@
# dia/model.py

import os
import logging
import time
import dac  # Keep this import name
import numpy as np
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file  # Import for safetensors support

from .audio import audio_to_codebook, codebook_to_audio
from .config import (
    DiaConfig,
)  # Assuming this is the Pydantic config for model structure
from .layers import DiaModel, KVCache  # Assuming these are the nn.Module definitions

# --- Get a logger instance for this module ---
logger = logging.getLogger(__name__)

# Optional: Add a check after import to verify the library looks correct
# Note: We now expect 'utils' based on original code
if (
    not hasattr(dac, "utils")
    or not hasattr(dac.utils, "download")
    or not hasattr(dac, "DAC")
):
    logger.warning(
        "The imported 'dac' module does not appear to have the 'utils.download' structure expected by the original Dia code."
    )
    logger.warning(
        "Ensure 'descript-audio-codec' is installed correctly (pip install descript-audio-codec)."
    )
    # If this check fails, _load_dac_model will likely raise an error later anyway.


def _sample_next_token(
    logits_BCxV: torch.Tensor,
    temperature: float,
    top_p: float,
    use_cfg_filter: bool,
    cfg_filter_top_k: int | None = None,
) -> torch.Tensor:
    """Samples the next token based on logits, temperature, and top_p."""
    if temperature == 0.0:
        # Greedy sampling
        return torch.argmax(logits_BCxV, dim=-1)

    # Apply temperature scaling
    logits_BCxV = logits_BCxV / temperature

    # Apply CFG Top-K filtering (optional)
    if use_cfg_filter and cfg_filter_top_k is not None:
        # Get top K values and indices
        _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=cfg_filter_top_k, dim=-1)
        # Create a mask to keep only top K logits
        mask = torch.ones_like(logits_BCxV, dtype=torch.bool)
        mask.scatter_(
            dim=-1, index=top_k_indices_BCxV, value=False
        )  # Set top K positions to False (don't mask)
        # Mask out logits not in the top K
        logits_BCxV = logits_BCxV.masked_fill(mask, -torch.inf)

    # Apply Top-P (Nucleus) sampling
    if top_p < 1.0:
        # Convert logits to probabilities
        probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
        # Sort probabilities in descending order
        sorted_probs_BCxV, sorted_indices_BCxV = torch.sort(
            probs_BCxV, dim=-1, descending=True
        )
        # Calculate cumulative probabilities
        cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1)

        # Create mask for tokens to remove (those exceeding top_p threshold)
        sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p
        # Shift the mask: keep the first token that crosses the threshold
        sorted_indices_to_remove_BCxV[..., 1:] = sorted_indices_to_remove_BCxV[
            ..., :-1
        ].clone()
        sorted_indices_to_remove_BCxV[..., 0] = 0  # Always keep the most probable token

        # Scatter the mask back to the original order
        indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV)
        indices_to_remove_BCxV.scatter_(
            dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV
        )
        # Apply the mask to the logits
        logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf)

    # Calculate final probabilities after filtering
    final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1)

    # Sample from the filtered distribution
    # multinomial expects probabilities for each item in the batch
    sampled_indices_BC = torch.multinomial(
        final_probs_BCxV, num_samples=1
    )  # Shape [B*C, 1]
    sampled_indices_C = sampled_indices_BC.squeeze(
        -1
    )  # Shape [B*C] -> should be [C] if input was [C,V]
    return sampled_indices_C
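

# Worked example (illustrative, not part of the original file) of the top-p filter
# above: with probabilities [0.5, 0.3, 0.15, 0.05] and top_p = 0.9, the cumulative
# sums are [0.5, 0.8, 0.95, 1.0]; after shifting the "> top_p" mask right by one
# position, only the last token is removed, so sampling is restricted to the
# renormalized set {0.5, 0.3, 0.15}.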


class Dia:
    """
    Main class for the Dia Text-to-Speech model, handling loading and generation.
    """

    def __init__(self, config: DiaConfig, device: torch.device = torch.device("cuda")):
        """
        Initializes the Dia model structure based on the provided configuration.
        Does not load weights here.

        Args:
            config: The DiaConfig object defining model parameters.
            device: The torch device (e.g., 'cuda', 'cpu') the model should eventually run on.
                Note: The model is instantiated but not moved to the device here.
        """
        super().__init__()
        logger.info(
            f"Initializing Dia model structure with config version: {config.version}"
        )
        self.config = config
        # Store the target device, but don't move the model yet. Loading weights will handle device placement.
        self.target_device = device
        # Instantiate the underlying PyTorch model based on the config
        self.model = DiaModel(config)
        self.dac_model = None  # DAC model will be loaded separately
        logger.info("Dia model structure initialized.")

    @classmethod
    def load_model_from_files(
        cls,
        config_path: str,
        weights_path: str,
        device: torch.device = torch.device("cuda"),
    ) -> "Dia":
        """
        Loads the Dia model from local configuration and weights files.
        Handles both .pth and .safetensors weight formats.

        Args:
            config_path: Path to the configuration JSON file (e.g., 'config.json').
            weights_path: Path to the model weights file (e.g., 'model.pth' or 'model.safetensors').
            device: The torch device ('cuda', 'cpu', etc.) to load the model onto.

        Returns:
            An instance of the Dia model loaded with weights and set to eval mode.

        Raises:
            FileNotFoundError: If the config or weights file is not found.
            ValueError: If the weights file format is unsupported.
            RuntimeError: If there is an error loading the config, weights, or DAC model.
        """
        logger.info("Loading Dia model from local files:")
        logger.info(f"  Config: {config_path}")
        logger.info(f"  Weights: {weights_path}")
        logger.info(f"  Target Device: {device}")

        # 1. Load Configuration
        try:
            config = DiaConfig.load(config_path)
            if config is None:
                # DiaConfig.load returns None on FileNotFoundError
                logger.error(f"Configuration file not found at {config_path}")
                raise FileNotFoundError(
                    f"Configuration file not found at {config_path}"
                )
            logger.info("Configuration loaded successfully.")
        except Exception as e:
            logger.error(
                f"Error loading or validating configuration from {config_path}: {e}",
                exc_info=True,
            )
            raise RuntimeError(
                f"Failed to load configuration from {config_path}"
            ) from e

        # 2. Instantiate Model Structure
        # Pass the target device during instantiation if the underlying DiaModel supports it,
        # otherwise, we move it later. Assuming __init__ doesn't take device for now.
        dia_instance = cls(
            config, device
        )  # Pass device mainly for storing target_device

        # 3. Load Weights (State Dictionary)
        try:
            logger.info(f"Loading weights from: {weights_path}")
            weights_filename = os.path.basename(weights_path)
            state_dict = None

            if weights_filename.endswith(".safetensors"):
                logger.info(
                    "Detected .safetensors format. Loading using safetensors library."
                )
                # load_file loads directly to the specified device
                state_dict = load_file(weights_path, device=str(device))
                logger.info("Safetensors weights loaded.")
            elif weights_filename.endswith(".pth"):
                logger.info("Detected .pth format. Loading using torch.load.")
                # torch.load needs map_location to load onto the correct device
                state_dict = torch.load(weights_path, map_location=device)
                logger.info("PyTorch weights (.pth) loaded.")
            else:
                logger.error(
                    f"Unsupported weights file format: {weights_filename}. Expected .pth or .safetensors."
                )
                raise ValueError(f"Unsupported weights file format: {weights_filename}")

            # Load the state dictionary into the model structure
            logger.info("Applying loaded weights to the model structure...")
            # Use strict=True by default to catch mismatches. Can be set to False if needed for specific conversions (e.g., BF16 -> FP32 partial loads)
            dia_instance.model.load_state_dict(state_dict, strict=True)
            logger.info("Weights applied successfully.")

        except FileNotFoundError:
            logger.error(f"Weights file not found at {weights_path}")
            raise FileNotFoundError(f"Weights file not found at {weights_path}")
        except Exception as e:
            logger.error(
                f"Error loading weights from {weights_path}: {e}", exc_info=True
            )
            raise RuntimeError(f"Error loading weights from {weights_path}") from e

        # 4. Move Model to Device and Set Eval Mode
        logger.info(f"Moving model to device: {device}...")
        dia_instance.model.to(device)
        logger.info("Setting model to evaluation mode...")
        dia_instance.model.eval()

        # 5. Load Associated DAC Model
        logger.info("Loading associated DAC model...")
        dia_instance._load_dac_model()  # This will log its own progress/errors

        logger.info("Dia model fully loaded and ready.")
        return dia_instance
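
    # Usage sketch (hypothetical paths; in this project, engine.py locates the files
    # and drives the real loading):
    #   model = Dia.load_model_from_files(
    #       "path/to/config.json",
    #       "path/to/model.safetensors",
    #       device=torch.device("cuda"),
    #   )
    #   audio = model.generate("[S1] Hello from Dia.")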

    # REMOVED from_pretrained - Responsibility moved to engine.py
    # @classmethod
    # def from_pretrained(...) -> "Dia": ...

    def _load_dac_model(self):
        """Loads the Descript Audio Codec (DAC) model using the original project's method."""
        if self.dac_model is not None:
            logger.info("DAC model already loaded.")
            return

        # Verify the imported module has the necessary structure expected by original code
        if (
            not hasattr(dac, "utils")
            or not hasattr(dac.utils, "download")
            or not hasattr(dac, "DAC")
        ):
            logger.error(
                "Imported 'dac' module structure mismatch. Expected 'dac.utils.download()' and 'dac.DAC'."
            )
            logger.error(
                "Ensure 'descript-audio-codec' is installed correctly via pip."
            )
            raise RuntimeError(
                "Failed to load DAC model: required functions/structure missing from 'dac' module."
            )

        try:
            # Use the original method found in the Dia repository
            logger.info("Downloading/finding DAC model using dac.utils.download()...")
            # This assumes dac.utils.download() handles caching internally
            dac_model_path = dac.utils.download()
            logger.info(f"DAC model path determined: {dac_model_path}")

            logger.info("Loading DAC model from path...")
            # Load DAC model and move it to the same device as the main Dia model
            dac_model = dac.DAC.load(dac_model_path).to(self.target_device)
            logger.info("DAC model loaded successfully.")

        except AttributeError as ae:
            logger.error(
                f"AttributeError loading DAC model: '{ae}'. The installed 'descript-audio-codec' version might be incompatible with Dia's original code which expects 'dac.utils.download()'."
            )
            logger.error(
                "Please check for specific version requirements of 'descript-audio-codec' for Dia, or potential installation issues."
            )
            raise RuntimeError(
                "Failed to load DAC model due to incompatible library version or structure"
            ) from ae
        except Exception as e:
            logger.error(f"General error loading DAC model: {e}", exc_info=True)
            raise RuntimeError("Failed to load DAC model") from e

        self.dac_model = dac_model

    def _create_attn_mask(
        self,
        q_padding_mask_1d: torch.Tensor,
        k_padding_mask_1d: torch.Tensor,
        is_causal: bool = False,
    ) -> torch.Tensor:
        """
        Creates the attention mask (self or cross) based on padding masks.
        Mimics JAX segment ID logic where attention is allowed between non-padding tokens
        OR between padding tokens, but not across the boundary.

        Args:
            q_padding_mask_1d: Boolean tensor [Batch, SeqLenQ] where True indicates non-padding.
            k_padding_mask_1d: Boolean tensor [Batch, SeqLenK] where True indicates non-padding.
            is_causal: If True, applies an additional causal mask (for decoder self-attention).

        Returns:
            Boolean attention mask tensor [Batch, 1, SeqLenQ, SeqLenK] ready for F.scaled_dot_product_attention.
        """
        B1, Tq = q_padding_mask_1d.shape
        B2, Tk = k_padding_mask_1d.shape
        if B1 != B2:
            logger.warning(
                f"Query ({B1}) and key ({B2}) batch dimensions do not match in _create_attn_mask"
            )
        assert B1 == B2, "Query and key batch dimensions must match"

        # Expand masks for broadcasting: [B, Tq, 1] and [B, 1, Tk]
        p_mask_q = q_padding_mask_1d.unsqueeze(2)
        p_mask_k = k_padding_mask_1d.unsqueeze(1)

        # True where a non-padding query token attends to a non-padding key token
        non_pad_attends_non_pad = p_mask_q & p_mask_k  # Shape [B, Tq, Tk]
        # True where a padding query token attends to a padding key token
        pad_attends_pad = (~p_mask_q) & (~p_mask_k)  # Shape [B, Tq, Tk]

        # Combine: Attention is allowed if tokens are both non-padding OR both padding.
        mask = non_pad_attends_non_pad | pad_attends_pad  # Shape [B, Tq, Tk]

        if is_causal:
            # Apply causal mask for self-attention (query cannot attend to future keys)
            if Tq != Tk:
                logger.warning(f"Causal mask requested but Tq ({Tq}) != Tk ({Tk})")
            assert (
                Tq == Tk
            ), "Causal mask requires query and key sequence lengths to be equal"
            # Create lower triangular matrix (True allows attention)
            causal_mask_2d = torch.tril(
                torch.ones((Tq, Tk), dtype=torch.bool, device=self.target_device)
            )
            # Combine with padding compatibility mask
            mask = mask & causal_mask_2d  # Shape [B, Tq, Tk]

        # Add head dimension for broadcasting: [B, 1, Tq, Tk]
        return mask.unsqueeze(1)
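
    # Illustrative example (not part of the original file): with key padding mask
    # [True, True, False] (last key token is padding), a non-padding query may attend
    # to keys 0 and 1 only, while a padding query may attend to key 2 only; with
    # is_causal=True, query i is additionally restricted to keys j <= i.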

    def _prepare_text_input(
        self, text: str
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Encodes text prompt into byte tokens, pads to max length,
        and creates position IDs and padding mask.

        Args:
            text: The input text string.

        Returns:
            Tuple containing:
            - src_tokens: Padded token IDs [1, SeqLen].
            - src_positions: Position IDs [1, SeqLen].
            - src_padding_mask: Boolean mask (True=non-pad) [1, SeqLen].
            - enc_self_attn_mask: Attention mask for encoder [1, 1, SeqLen, SeqLen].
        """
        text_pad_value = self.config.data.text_pad_value
        max_len = self.config.data.text_length
        logger.debug(
            f"Preparing text input. Max length: {max_len}, Pad value: {text_pad_value}"
        )
        logger.debug(f"Original text (start): '{text[:100]}...'")

        # Convert text to bytes and replace special speaker tokens
        byte_text = text.encode("utf-8")
        # Assuming Dia uses byte values 1 and 2 for S1/S2 based on original code context
        replaced_bytes = byte_text.replace(b"[S1]", b"\x01").replace(b"[S2]", b"\x02")
        text_tokens = list(replaced_bytes)  # List of integer byte values
        logger.debug(
            f"Text tokens after byte conversion (first 10): {text_tokens[:10]}"
        )

        # Pad or truncate sequence
        current_len = len(text_tokens)
        padding_needed = max_len - current_len
        if padding_needed <= 0:
            if current_len > max_len:
                logger.warning(
                    f"Input text length ({current_len}) exceeds max length ({max_len}). Truncating."
                )
            text_tokens = text_tokens[:max_len]
            padded_text_np = np.array(text_tokens, dtype=np.uint8)
        else:
            logger.debug(f"Padding text input with {padding_needed} pad tokens.")
            padded_text_np = np.pad(
                text_tokens,
                (0, padding_needed),
                mode="constant",
                constant_values=text_pad_value,
            ).astype(np.uint8)

        # Convert to tensors and add batch dimension [1, SeqLen]
        src_tokens = (
            torch.from_numpy(padded_text_np)
            .to(torch.long)
            .to(self.target_device)
            .unsqueeze(0)
        )
        src_positions = (
            torch.arange(max_len, device=self.target_device).to(torch.long).unsqueeze(0)
        )

        # Create padding mask (True where token is NOT the pad value)
        src_padding_mask = src_tokens != text_pad_value  # Shape [1, SeqLen]

        # Create attention mask for the encoder (non-causal self-attention)
        # Needs shape [B, 1, Tq, Tk] -> [1, 1, SeqLen, SeqLen]
        enc_self_attn_mask = self._create_attn_mask(
            src_padding_mask, src_padding_mask, is_causal=False
        )

        logger.debug(f"Prepared src_tokens shape: {src_tokens.shape}")
        logger.debug(f"Prepared src_positions shape: {src_positions.shape}")
        logger.debug(
            f"Prepared src_padding_mask shape: {src_padding_mask.shape} (True means non-padding)"
        )
        logger.debug(f"Prepared enc_self_attn_mask shape: {enc_self_attn_mask.shape}")

        return src_tokens, src_positions, src_padding_mask, enc_self_attn_mask
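
    # Illustrative example (not part of the original file) of the byte tokenization:
    #   "[S1] Hi".encode("utf-8").replace(b"[S1]", b"\x01")  ->  b"\x01 Hi"
    #   list(b"\x01 Hi")                                     ->  [1, 32, 72, 105]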

    @torch.inference_mode()
    def generate(
        self,
        text: str,
        max_tokens: int | None = None,
        cfg_scale: float = 3.0,
        temperature: float = 1.3,
        top_p: float = 0.95,
        use_cfg_filter: bool = True,
        use_torch_compile: bool = False,  # Default to False for broader compatibility
        cfg_filter_top_k: int = 35,
        audio_prompt_path: str | None = None,
    ) -> np.ndarray:
        """
        Generates audio waveform from a text prompt, optionally conditioned on an audio prompt.

        Args:
            text: The input text string. For dialogue, use [S1]/[S2] markers.
                For voice cloning, prepend the transcript of the audio prompt.
            max_tokens: Maximum number of audio tokens (frames) to generate. Defaults to config value.
            cfg_scale: Classifier-Free Guidance scale. Higher values increase adherence to text.
            temperature: Sampling temperature. Higher values increase randomness.
            top_p: Nucleus sampling probability. Filters vocabulary during sampling.
            use_cfg_filter: Whether to apply Top-K filtering based on CFG logits.
            use_torch_compile: If True, attempts to compile the decoder step for potential speedup.
            cfg_filter_top_k: The 'K' value for CFG Top-K filtering.
            audio_prompt_path: Path to an audio file (e.g., WAV, MP3) to use as a voice prompt/clone target.

        Returns:
            A 1D NumPy array containing the generated audio waveform (float32).
        """
        start_time_gen = time.time()
        logger.info("Starting audio generation...")
        logger.info(f"  Text (start): '{text[:100]}...'")
        logger.info(
            f"  Max tokens: {max_tokens if max_tokens is not None else 'Model Default'}"
        )
        logger.info(f"  CFG Scale: {cfg_scale}")
        logger.info(f"  Temperature: {temperature}")
        logger.info(f"  Top P: {top_p}")
        logger.info(f"  Use CFG Filter: {use_cfg_filter}, Top K: {cfg_filter_top_k}")
        logger.info(
            f"  Audio Prompt: {audio_prompt_path if audio_prompt_path else 'None'}"
        )
        logger.info(f"  Use torch.compile: {use_torch_compile}")
        logger.info(f"  Target Device: {self.target_device}")

        # --- Parameter Setup ---
        num_channels = self.config.data.channels
        audio_bos_value = self.config.data.audio_bos_value
        audio_eos_value = self.config.data.audio_eos_value
        audio_pad_value = self.config.data.audio_pad_value
        delay_pattern = self.config.data.delay_pattern
        # Use model's default audio length if max_tokens not provided
        effective_max_tokens = (
            max_tokens if max_tokens is not None else self.config.data.audio_length
        )
        logger.info(f"  Effective max_tokens for generation: {effective_max_tokens}")

        # Ensure delay pattern is usable
        if not isinstance(delay_pattern, list) or not delay_pattern:
            logger.warning("Delay pattern is invalid or empty. Using default [0].")
            delay_pattern = [
                0
            ] * num_channels  # Fallback, though config should provide default

        delay_tensor = torch.tensor(
            delay_pattern, dtype=torch.long, device=self.target_device
        )
        max_delay_pattern = max(delay_pattern) if delay_pattern else 0
        self.model.eval()  # Ensure model is in eval mode

        # --- Prepare Conditional and Unconditional Inputs ---
        logger.info(
            "Preparing text inputs for conditional and unconditional generation..."
        )
        (
            cond_src_BxS,
            cond_src_positions_BxS,
            cond_src_padding_mask_BxS,
            cond_enc_self_attn_mask_Bx1xSxS,
        ) = self._prepare_text_input(text)

        # Create unconditional input (batch of zeros representing padding)
        # Assuming pad value 0 for text based on config default
        unc_src_BxS = torch.full_like(
            cond_src_BxS, fill_value=self.config.data.text_pad_value
        )
        # Batch conditional and unconditional inputs together [2, SeqLen]
        src_BxS = torch.cat([unc_src_BxS, cond_src_BxS], dim=0)
        # Expand other inputs to match batch size 2
        src_positions_BxS = cond_src_positions_BxS.expand(2, -1)
        src_padding_mask_BxS = torch.cat(
            [
                torch.zeros_like(cond_src_padding_mask_BxS[0:1]),
                cond_src_padding_mask_BxS,
            ],
            dim=0,
        )  # Uncond mask is all False (padding)
        # Encoder mask needs to handle the batched input correctly. For CFG, the
        # unconditional branch attends to nothing useful from the text, but the
        # structure must be maintained, so we reuse the conditional mask structure;
        # the actual attention scores are based on the zeroed-out unconditional input.
        enc_self_attn_mask_Bx1xSxS = cond_enc_self_attn_mask_Bx1xSxS.expand(
            2, -1, -1, -1
        )
        logger.info("Text inputs prepared (batch size 2 for CFG).")

        # --- Encoder Pass ---
        logger.info("Running encoder pass...")
        start_time_enc = time.time()
        # Potentially use autocast for mixed precision if supported and beneficial on device
        # Example: with torch.autocast(device_type=self.target_device.type, dtype=torch.bfloat16 if self.target_device.type == 'cuda' else torch.float32):
        encoder_out = self.model.encoder(
            x_ids=src_BxS,  # Shape [2, S]
            src_positions=src_positions_BxS,  # Shape [2, S]
            deterministic=True,  # No dropout during inference
            attn_mask=enc_self_attn_mask_Bx1xSxS,  # Shape [2, 1, S, S]
        )
        logger.info(
            f"Encoder pass completed in {time.time() - start_time_enc:.3f}s. Output shape: {encoder_out.shape}"
        )  # Shape: [2, S, E]

        # --- Prepare Decoder Inputs & KV Cache ---
        logger.info("Preparing decoder inputs and KV cache...")
        start_time_kv = time.time()
        # 3-1. Precompute Cross-Attention KV Cache (Static) from encoder output
        # This cache is computed once and reused for every decoding step.
        decoder_cross_attention_cache: list[KVCache] = (
            self.model.decoder.precompute_cross_attention_kv(
                effective_max_tokens, encoder_out, src_positions_BxS
            )
        )
        logger.debug(
            f"Precomputed cross-attention KV cache for {len(decoder_cross_attention_cache)} layers."
        )

        # 3-2. Initialize Self-Attention KV Cache (Dynamic, grows with each step)
        decoder_self_attention_cache: list[KVCache] = []
        for i in range(self.model.decoder.num_layers):
            decoder_self_attention_cache.append(
                KVCache(
                    self.config.model.decoder.gqa_query_heads,
                    effective_max_tokens,  # Max length the cache can hold
                    self.config.model.decoder.gqa_head_dim,
                    self.target_device,  # Cache tensors should be on the target device
                )
            )
        logger.debug(
            f"Initialized self-attention KV cache for {len(decoder_self_attention_cache)} layers."
        )
        logger.info(
            f"KV cache preparation completed in {time.time() - start_time_kv:.3f}s."
        )

        # 3-3. Initialize Decoder Start Tokens (BOS)
        # Shape [2, 1, C] (Batch=2 for cond/uncond, T=1 for first step, C=channels)
        generated_tokens_history = torch.full(
            (2, 1, num_channels),
            fill_value=audio_bos_value,
            dtype=torch.long,
            device=self.target_device,
        )
        logger.debug(f"Initial decoder input (BOS): {generated_tokens_history.shape}")

        current_step_index = (
            0  # Index of the step we are currently generating (starts at 0)
        )
        prompt_len_inc_bos = 1  # Length of the initial prompt (just BOS initially)

        # 3-4. Handle Audio Prompt (Prefill KV Cache)
        if audio_prompt_path is not None:
            logger.info("Processing audio prompt for prefilling...")
            start_time_prompt = time.time()
            try:
                # Load and potentially resample audio
                audio_prompt_waveform, sr = torchaudio.load(audio_prompt_path)
                logger.debug(
                    f"Loaded audio prompt: {audio_prompt_waveform.shape}, Sample Rate: {sr}"
                )
                if sr != 44100:
                    logger.info(f"Resampling audio prompt from {sr}Hz to 44100Hz")
                    audio_prompt_waveform = torchaudio.functional.resample(
                        audio_prompt_waveform, sr, 44100
                    )
                # Ensure correct shape [B, C, T_audio] and device
                # Assuming DAC expects channels first, add batch dim
                if audio_prompt_waveform.ndim == 1:  # Mono
                    audio_prompt_waveform = audio_prompt_waveform.unsqueeze(
                        0
                    )  # Add channel dim
                audio_prompt_waveform = audio_prompt_waveform.unsqueeze(0).to(
                    self.target_device
                )  # Add batch dim

                # Encode audio prompt to codes using DAC
                logger.info("Encoding audio prompt to codes using DAC...")
                if self.dac_model is None:
                    raise RuntimeError(
                        "DAC model not loaded, required for audio prompt."
                    )
                # audio_to_codebook returns [B, T_codes, C]
                audio_prompt_codes = audio_to_codebook(
                    self.dac_model, audio_prompt_waveform, data_config=self.config.data
                )  # Shape [1, T_codes, C]
                logger.info(
                    f"Encoded audio prompt to codes: {audio_prompt_codes.shape}"
                )

                # Concatenate BOS tokens with prompt codes
                # Expand prompt codes to batch size 2 (for cond/uncond)
                generated_tokens_history = torch.cat(
                    [generated_tokens_history, audio_prompt_codes.expand(2, -1, -1)],
                    dim=1,
                )  # Shape [2, 1 + T_codes, C]
                logger.debug(
                    f"Decoder input history after prompt concatenation: {generated_tokens_history.shape}"
                )

                prefill_len = generated_tokens_history.shape[
                    1
                ]  # Length including BOS + prompt
                prompt_len_inc_bos = prefill_len
                logger.info(f"Prefilling KV cache with length {prefill_len}...")

                # Prepare inputs for prefill forward pass
                prefill_tgt_pos = (
                    torch.arange(prefill_len, device=self.target_device)
                    .unsqueeze(0)
                    .expand(2, -1)
                )  # Shape [2, T_prefill]
                # Padding mask based on actual tokens (BOS and prompt codes are not PAD)
                # Shape [2, T_prefill] (True where not PAD)
                prefill_tgt_padding_mask = (
                    generated_tokens_history != audio_pad_value
                ).any(dim=2)

                # Create attention masks for prefill
                # Shape [2, 1, T_prefill, T_prefill]
                prefill_self_attn_mask = self._create_attn_mask(
                    prefill_tgt_padding_mask,
                    prefill_tgt_padding_mask,
                    is_causal=True,
                )
                # Shape [2, 1, T_prefill, S]
                prefill_cross_attn_mask = self._create_attn_mask(
                    prefill_tgt_padding_mask,
                    src_padding_mask_BxS,
                    is_causal=False,
                )

                # Run forward pass through decoder to fill the self-attention KV cache
                # We discard the logits from prefill
                _ = self.model.decoder.forward(
                    tgt_ids_BxTxC=generated_tokens_history,  # Pass the full history [2, T_prefill, C]
                    encoder_out=encoder_out,
                    tgt_positions=prefill_tgt_pos,
                    src_positions=src_positions_BxS,
                    deterministic=True,
                    self_attn_mask=prefill_self_attn_mask,
                    cross_attn_mask=prefill_cross_attn_mask,
                    self_attention_cache=decoder_self_attention_cache,  # Pass cache to be filled
                    cross_attention_cache=decoder_cross_attention_cache,  # Pass precomputed cache
                )

                # Update the current step index. The next token to generate is at index prefill_len.
                current_step_index = prefill_len
                logger.info(
                    f"KV cache prefilled in {time.time() - start_time_prompt:.3f}s. Next step index: {current_step_index}"
                )

            except Exception as e:
                logger.error(f"Error processing audio prompt: {e}", exc_info=True)
                raise RuntimeError("Failed to process audio prompt") from e

        # --- Autoregressive Generation Loop ---
        logger.info("Starting autoregressive generation loop...")
        start_time_loop = time.time()

        eos_detected_channel_0 = False
        eos_countdown = -1  # Countdown after EOS detected in channel 0
        extra_steps_after_eos = (
            30  # Generate a few extra steps for delay pattern completion
        )

        # Pre-allocate tensor for storing *newly* generated tokens for efficiency
        # We already have the prompt in generated_tokens_history
        num_steps_to_generate = effective_max_tokens
        newly_generated_tokens = torch.full(
            (2, num_steps_to_generate, num_channels),
            fill_value=audio_pad_value,  # Fill with pad initially
            dtype=torch.long,
            device=self.target_device,
        )
        logger.debug(
            f"Allocated tensor for newly generated tokens: {newly_generated_tokens.shape}"
        )

        # --- Compile decode_step if requested ---
        decode_step_fn = self.model.decoder.decode_step
        if use_torch_compile:
            logger.info("Compiling decoder step function with torch.compile...")
            try:
                # Experiment with modes: "default", "reduce-overhead", "max-autotune"
                decode_step_fn = torch.compile(decode_step_fn, mode="reduce-overhead")
                logger.info("Decoder step function compiled.")
            except Exception as e:
                logger.warning(
                    f"torch.compile failed: {e}. Using eager mode.", exc_info=True
                )

        # --- Prepare static cross-attention mask for single-step decoding ---
        # Query mask is always [B, 1] (True, as generated tokens are not PAD)
        step_tgt_padding_mask = torch.ones(
            (2, 1), dtype=torch.bool, device=self.target_device
        )
        # Shape [2, 1, 1, S]
        step_decoder_cross_attn_mask = self._create_attn_mask(
            step_tgt_padding_mask,
            src_padding_mask_BxS,
            is_causal=False,
        )

        # --- Generation Loop ---
        steps_taken = 0
        for step_offset in range(num_steps_to_generate):
            # Absolute step index considering prompt length
            current_absolute_step = current_step_index + step_offset

            # Get the token IDs for the *previous* step to predict the current one
            # Shape [2, 1, C]
            # If step_offset is 0, use the last token from the prompt history
            if step_offset == 0:
                input_token_ids = generated_tokens_history[:, -1, :].unsqueeze(1)
            else:
                # Use the token generated in the previous iteration of this loop
                input_token_ids = newly_generated_tokens[
                    :, step_offset - 1, :
                ].unsqueeze(1)

            # Position ID for the current absolute step
            # Shape [2, 1]
            tgt_pos_Bx1 = torch.full(
                (2, 1),
                fill_value=current_absolute_step,
                dtype=torch.long,
                device=self.target_device,
            )

            # --- Call Decoder Step ---
            # self_attn_mask is None because KV cache handles causality implicitly in single-step decoding
            logits_Bx1xCxV, new_self_kv_cache_list = decode_step_fn(
                tgt_ids_Bx1xC=input_token_ids,
                tgt_pos_Bx1=tgt_pos_Bx1,
                encoder_out=encoder_out,
                self_attn_mask=None,
                cross_attn_mask=step_decoder_cross_attn_mask,
                self_attention_cache=decoder_self_attention_cache,
                cross_attention_cache=decoder_cross_attention_cache,
            )  # Logits shape: [2, 1, C, V]

            # --- Update Self-Attention KV Cache ---
            for i, layer_cache in enumerate(decoder_self_attention_cache):
                if (
                    new_self_kv_cache_list
                    and i < len(new_self_kv_cache_list)
                    and new_self_kv_cache_list[i] is not None
                ):
                    # new_self_kv_cache_list[i] is a tuple (k_tensor, v_tensor) for the current step
                    # k_tensor shape: [2, NumHeads, 1, HeadDim]
                    # v_tensor shape: [2, NumHeads, 1, HeadDim]
                    layer_cache.update_cache(
                        new_self_kv_cache_list[i][0], new_self_kv_cache_list[i][1]
                    )
                else:
                    logger.warning(
                        f"Missing KV cache update for layer {i} at step {current_absolute_step}"
                    )

            # --- Sampling ---
            V = self.config.model.tgt_vocab_size
            # Get logits for the generated step [2, C, V]
            logits_last_BxCxV = logits_Bx1xCxV.squeeze(1)

            # Separate conditional and unconditional logits
            uncond_logits_CxV = logits_last_BxCxV[0, :, :]  # Shape [C, V]
            cond_logits_CxV = logits_last_BxCxV[1, :, :]  # Shape [C, V]

            # Apply Classifier-Free Guidance (CFG)
            cfg_logits_CxV = cond_logits_CxV + cfg_scale * (
                cond_logits_CxV - uncond_logits_CxV
            )  # Shape [C, V]
|
826 |
+
|
827 |
+
# --- Prevent sampling PAD/EOS/BOS tokens inappropriately ---
|
828 |
+
logits_for_sampling_CxV = (
|
829 |
+
cfg_logits_CxV.clone()
|
830 |
+
) # Clone to avoid modifying original logits
|
831 |
+
logits_for_sampling_CxV[:, audio_pad_value] = -torch.inf # Never sample PAD
|
832 |
+
logits_for_sampling_CxV[:, audio_bos_value] = (
|
833 |
+
-torch.inf
|
834 |
+
) # Never sample BOS after start
|
835 |
+
# Allow EOS only if not already detected or in countdown
|
836 |
+
if eos_detected_channel_0 and eos_countdown <= 0:
|
837 |
+
logits_for_sampling_CxV[:, audio_eos_value] = -torch.inf
|
838 |
+
|
839 |
+
# --- Sample the next token for each channel ---
|
840 |
+
pred_C = _sample_next_token(
|
841 |
+
logits_for_sampling_CxV.float(), # Ensure float32 for sampling stability
|
842 |
+
temperature=temperature,
|
843 |
+
top_p=top_p,
|
844 |
+
use_cfg_filter=use_cfg_filter,
|
845 |
+
cfg_filter_top_k=cfg_filter_top_k,
|
846 |
+
) # Shape [C]
|
847 |
+
|
848 |
+
# --- Handle Delay Pattern (Only if no audio prompt was given) ---
|
849 |
+
# If there's no prompt, the first few tokens should be BOS according to delay
|
850 |
+
# generation_step_index is how many tokens generated *after* prompt/initial BOS
|
851 |
+
generation_step_index = step_offset
|
852 |
+
if audio_prompt_path is None:
|
853 |
+
is_before_delay = generation_step_index < delay_tensor # Shape [C]
|
854 |
+
pred_C = torch.where(
|
855 |
+
is_before_delay,
|
856 |
+
torch.tensor(
|
857 |
+
audio_bos_value, device=self.target_device, dtype=torch.long
|
858 |
+
),
|
859 |
+
pred_C,
|
860 |
+
)
|
861 |
+
|
862 |
+
# --- Store the predicted token in the newly_generated_tokens tensor ---
|
863 |
+
newly_generated_tokens[:, step_offset, :] = pred_C.unsqueeze(0).expand(
|
864 |
+
2, -1
|
865 |
+
)
|
866 |
+
|
867 |
+
steps_taken += 1 # Increment steps taken in this loop
|
868 |
+
|
869 |
+
# --- EOS Handling ---
|
870 |
+
if not eos_detected_channel_0 and pred_C[0] == audio_eos_value:
|
871 |
+
logger.info(
|
872 |
+
f"EOS token detected in channel 0 at step {current_absolute_step}. Starting countdown."
|
873 |
+
)
|
874 |
+
eos_detected_channel_0 = True
|
875 |
+
eos_countdown = extra_steps_after_eos
|
876 |
+
|
877 |
+
if eos_countdown > 0:
|
878 |
+
step_after_eos = extra_steps_after_eos - eos_countdown
|
879 |
+
logger.debug(
|
880 |
+
f"EOS countdown: {eos_countdown}, Step after EOS: {step_after_eos}"
|
881 |
+
)
|
882 |
+
# Modify the token *just generated* if needed for EOS/PAD forcing
|
883 |
+
current_new_tokens = newly_generated_tokens[
|
884 |
+
:, step_offset, :
|
885 |
+
] # Shape [2, C]
|
886 |
+
for i, d in enumerate(delay_pattern):
|
887 |
+
if step_after_eos == d:
|
888 |
+
logger.debug(
|
889 |
+
f" Forcing EOS in channel {i} at step {current_absolute_step}"
|
890 |
+
)
|
891 |
+
current_new_tokens[:, i] = audio_eos_value
|
892 |
+
elif step_after_eos > d:
|
893 |
+
logger.debug(
|
894 |
+
f" Forcing PAD in channel {i} at step {current_absolute_step}"
|
895 |
+
)
|
896 |
+
current_new_tokens[:, i] = audio_pad_value
|
897 |
+
# Put the potentially modified tokens back
|
898 |
+
newly_generated_tokens[:, step_offset, :] = current_new_tokens
|
899 |
+
|
900 |
+
eos_countdown -= 1
|
901 |
+
if eos_countdown == 0:
|
902 |
+
logger.info(
|
903 |
+
f"EOS countdown finished at step {current_absolute_step}. Stopping generation."
|
904 |
+
)
|
905 |
+
break # Stop generation loop
|
906 |
+
|
907 |
+
# Check if we reached the max *new* tokens requested
|
908 |
+
if steps_taken >= num_steps_to_generate:
|
909 |
+
logger.info(
|
910 |
+
f"Reached max generation steps ({num_steps_to_generate}). Stopping."
|
911 |
+
)
|
912 |
+
break
|
913 |
+
|
914 |
+
logger.info(
|
915 |
+
f"Autoregressive loop finished after {steps_taken} steps in {time.time() - start_time_loop:.3f}s."
|
916 |
+
)
|
917 |
+
|
918 |
+
# --- Extract Generated Codes ---
|
919 |
+
# Get the conditional generation result (index 1) from the *newly* generated tokens
|
920 |
+
# Only take the number of steps actually taken
|
921 |
+
final_new_codes = newly_generated_tokens[
|
922 |
+
1, :steps_taken, :
|
923 |
+
] # Shape [T_generated, C]
|
924 |
+
logger.info(f"Extracted newly generated codes shape: {final_new_codes.shape}")
|
925 |
+
|
926 |
+
# --- Convert Codes to Audio using DAC ---
|
927 |
+
logger.info("Converting generated codes to audio using DAC...")
|
928 |
+
start_time_decode = time.time()
|
929 |
+
if self.dac_model is None:
|
930 |
+
raise RuntimeError("DAC model not loaded, required for audio decoding.")
|
931 |
+
|
932 |
+
# codebook_to_audio expects codes shape [C, T]
|
933 |
+
generated_codes_CxT = final_new_codes.transpose(0, 1) # Shape [C, T_generated]
|
934 |
+
|
935 |
+
if generated_codes_CxT.numel() == 0:
|
936 |
+
logger.warning("No new codes were generated. Returning empty audio.")
|
937 |
+
return np.array([], dtype=np.float32)
|
938 |
+
|
939 |
+
# Call the decoding function (handles delay reversal and DAC decoding)
|
940 |
+
audio_waveform = codebook_to_audio(
|
941 |
+
generated_codes_CxT,
|
942 |
+
self.dac_model,
|
943 |
+
delay_pattern,
|
944 |
+
B=1, # Batch size for decoding is 1
|
945 |
+
T=generated_codes_CxT.shape[1], # Pass the actual length of generated codes
|
946 |
+
C=num_channels,
|
947 |
+
) # Returns shape [1, T_audio] or [T_audio]
|
948 |
+
|
949 |
+
# Ensure output is a 1D numpy array on CPU
|
950 |
+
final_audio_np = audio_waveform.squeeze().cpu().numpy()
|
951 |
+
logger.info(
|
952 |
+
f"Audio decoding completed in {time.time() - start_time_decode:.3f}s. Output shape: {final_audio_np.shape}"
|
953 |
+
)
|
954 |
+
logger.info(f"Total generation time: {time.time() - start_time_gen:.3f}s")
|
955 |
+
|
956 |
+
return final_audio_np
|
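The CFG update above is what steers sampling: the conditional logits are pushed away from the unconditional ones before temperature/top-p sampling. A minimal standalone sketch of that arithmetic, with made-up numbers rather than real model outputs:

import torch

cond = torch.tensor([2.0, 0.5])    # hypothetical conditional logits for two vocab entries
uncond = torch.tensor([1.0, 0.4])  # hypothetical unconditional logits
cfg_scale = 3.0
guided = cond + cfg_scale * (cond - uncond)  # same update as cfg_logits_CxV above
print(guided)  # tensor([5.0000, 0.8000]): tokens favored by the conditioning are amplified
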
docker-compose.yml
ADDED
@@ -0,0 +1,23 @@
version: '3.8'

services:
  dia-tts-server:
    build:
      context: .
      dockerfile: Dockerfile
    ports:
      - "${PORT:-8003}:${PORT:-8003}"
    volumes:
      - ./model_cache:/app/model_cache
      - ./reference_audio:/app/reference_audio
      - ./outputs:/app/outputs
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
    restart: unless-stopped
    env_file:
      - .env
documentation.md
ADDED
@@ -0,0 +1,549 @@
# Dia TTS Server - Technical Documentation

**Version:** 1.0.0
**Date:** 2025-04-22

**Table of Contents:**

1. [Overview](#1-overview)
2. [Visual Overview](#2-visual-overview)
    * [Directory Structure](#21-directory-structure)
    * [Component Diagram](#22-component-diagram)
3. [System Prerequisites](#3-system-prerequisites)
4. [Installation and Setup](#4-installation-and-setup)
    * [Cloning the Repository](#41-cloning-the-repository)
    * [Setting up Python Virtual Environment](#42-setting-up-python-virtual-environment)
        * [Windows Setup](#421-windows-setup)
        * [Linux Setup (Debian/Ubuntu Example)](#422-linux-setup-debianubuntu-example)
    * [Installing Dependencies](#43-installing-dependencies)
    * [NVIDIA Driver and CUDA Setup (Required for GPU Acceleration)](#44-nvidia-driver-and-cuda-setup-required-for-gpu-acceleration)
        * [Step 1: Check/Install NVIDIA Drivers](#441-step-1-checkinstall-nvidia-drivers)
        * [Step 2: Install PyTorch with CUDA Support](#442-step-2-install-pytorch-with-cuda-support)
        * [Step 3: Verify PyTorch CUDA Installation](#443-step-3-verify-pytorch-cuda-installation)
5. [Configuration](#5-configuration)
    * [Configuration Files (`.env` and `config.py`)](#51-configuration-files-env-and-configpy)
    * [Configuration Parameters](#52-configuration-parameters)
6. [Running the Server](#6-running-the-server)
7. [Usage](#7-usage)
    * [Web User Interface (Web UI)](#71-web-user-interface-web-ui)
        * [Main Generation Form](#711-main-generation-form)
        * [Presets](#712-presets)
        * [Voice Cloning](#713-voice-cloning)
        * [Generation Parameters](#714-generation-parameters)
        * [Server Configuration (UI)](#715-server-configuration-ui)
        * [Generated Audio Player](#716-generated-audio-player)
        * [Theme Toggle](#717-theme-toggle)
    * [API Endpoints](#72-api-endpoints)
        * [POST /v1/audio/speech (OpenAI Compatible)](#721-post-v1audiospeech-openai-compatible)
        * [POST /tts (Custom Parameters)](#722-post-tts-custom-parameters)
        * [Configuration & Helper Endpoints](#723-configuration--helper-endpoints)
8. [Troubleshooting](#8-troubleshooting)
9. [Project Architecture](#9-project-architecture)
10. [License and Disclaimer](#10-license-and-disclaimer)

---

## 1. Overview

The Dia TTS Server provides a backend service and web interface for generating high-fidelity speech, including dialogue with multiple speakers and non-verbal sounds, using the Dia text-to-speech model family (originally from Nari Labs, with support for community conversions like SafeTensors).

This server is built using the FastAPI framework and offers both a RESTful API (including an OpenAI-compatible endpoint) and an interactive web UI powered by Jinja2, Tailwind CSS, and JavaScript. It supports voice cloning via audio prompts and allows configuration of various generation parameters.

**Key Features:**

* **High-Quality TTS:** Leverages the Dia model for realistic speech synthesis.
* **Dialogue Generation:** Supports `[S1]` and `[S2]` tags for multi-speaker dialogue.
* **Non-Verbal Sounds:** Can generate sounds like `(laughs)`, `(sighs)`, etc., when included in the text.
* **Voice Cloning:** Allows conditioning the output voice on a provided reference audio file.
* **Flexible Model Loading:** Supports loading models from Hugging Face repositories, including both `.pth` and `.safetensors` formats (defaults to BF16 SafeTensors for efficiency).
* **API Access:** Provides a custom API endpoint (`/tts`) and an OpenAI-compatible endpoint (`/v1/audio/speech`).
* **Web Interface:** Offers an easy-to-use UI for text input, parameter adjustment, preset loading, reference audio management, and audio playback.
* **Configuration:** Server settings, model sources, paths, and default generation parameters are configurable via an `.env` file.
* **GPU Acceleration:** Utilizes NVIDIA GPUs via CUDA for significantly faster inference when available, falling back to CPU otherwise.

---

## 2. Visual Overview

### 2.1 Directory Structure

```
dia-tts-server/
│
├── .env                  # Local configuration overrides (user-created)
├── config.py             # Default configuration and management class
├── engine.py             # Core model loading and generation logic
├── models.py             # Pydantic models for API requests
├── requirements.txt      # Python dependencies
├── server.py             # Main FastAPI application, API endpoints, UI routes
├── utils.py              # Utility functions (audio encoding, saving, etc.)
│
├── dia/                  # Core Dia model implementation package
│   ├── __init__.py
│   ├── audio.py          # Audio processing helpers (delay, codebook conversion)
│   ├── config.py         # Pydantic models for Dia model architecture config
│   ├── layers.py         # Custom PyTorch layers for the Dia model
│   └── model.py          # Dia model class wrapper (loading, generation)
│
├── static/               # Static assets (e.g., favicon.ico)
│   └── favicon.ico
│
├── ui/                   # Web User Interface files
│   ├── index.html        # Main HTML template (Jinja2)
│   ├── presets.yaml      # Predefined UI examples
│   ├── script.js         # Frontend JavaScript logic
│   └── style.css         # Frontend CSS styling (Tailwind via CDN/build)
│
├── model_cache/          # Default directory for downloaded model files (configurable)
├── outputs/              # Default directory for saved audio output (configurable)
└── reference_audio/      # Default directory for voice cloning reference files (configurable)
```

### 2.2 Component Diagram

```
┌────────────────────────────┐
│ User (Web UI / API Client) │
└─────────────┬──────────────┘
              │
              ▼
┌────────────────────────────┐  Uses  ┌──────────────────────┐  Reads  ┌───────────┐
│ FastAPI Server (server.py) │ ─────→ │ Configuration        │ ──────→ │ .env File │
└─────────────┬──────────────┘        │ (config.py)          │         └───────────┘
              │ Also uses:            └──────────────────────┘
              │ Utilities (utils.py),
              │ Web UI files (ui/) rendered via Jinja2 templates
              ▼
┌────────────────────────────┐
│ TTS Engine (engine.py)     │  (also uses Configuration)
└─────────────┬──────────────┘
              │
              ▼
┌────────────────────────────┐  Uses  ┌──────────────────────┐  Uses  ┌──────────────────┐
│ Dia Model Wrapper          │ ─────→ │ Dia Model Layers     │ ─────→ │ PyTorch / CUDA   │
│ (dia/model.py)             │        │ (dia/layers.py)      │        └──────────────────┘
└─────────────┬──────────────┘        └──────────────────────┘
              │ Uses
              ▼
┌────────────────────────────┐
│ DAC Model                  │
│ (descript-audio-codec)     │
└────────────────────────────┘
```

**Diagram Legend:**

* Boxes represent major components or file groups.
* Arrows (`→`) indicate primary data flow or control flow.
* Lines with "Uses" indicate dependencies or function calls.

---

## 3. System Prerequisites

Before installing and running the Dia TTS Server, ensure your system meets the following requirements:

* **Operating System:**
    * Windows 10/11 (64-bit)
    * Linux (Debian/Ubuntu recommended, other distributions may require adjustments)
* **Python:** Python 3.10 or later (3.10.x recommended). Ensure Python and Pip are added to your system's PATH.
* **Version Control:** Git (for cloning the repository).
* **Internet Connection:** Required for downloading dependencies and model files.
* **(Optional but Highly Recommended for Performance):**
    * **NVIDIA GPU:** A CUDA-compatible NVIDIA GPU (Maxwell architecture or newer). Check compatibility [here](https://developer.nvidia.com/cuda-gpus). Sufficient VRAM is needed (the BF16 model requires ~5-6 GB, full precision ~10 GB).
    * **NVIDIA Drivers:** Latest appropriate drivers for your GPU and OS.
    * **CUDA Toolkit:** Version compatible with the chosen PyTorch build (e.g., 11.8, 12.1). See [Section 4.4](#44-nvidia-driver-and-cuda-setup-required-for-gpu-acceleration).
* **(Linux System Libraries):**
    * `libsndfile1`: Required by the `soundfile` Python library for audio I/O. Install using your package manager (e.g., `sudo apt install libsndfile1` on Debian/Ubuntu).

---

## 4. Installation and Setup

Follow these steps to set up the project environment and install necessary dependencies.

### 4.1 Cloning the Repository

Open your terminal or command prompt and navigate to the directory where you want to store the project. Then, clone the repository:

```bash
git clone https://github.com/devnen/dia-tts-server.git # Replace with the actual repo URL if different
cd dia-tts-server
```

### 4.2 Setting up Python Virtual Environment

Using a virtual environment is strongly recommended to isolate project dependencies.

#### 4.2.1 Windows Setup

1. **Open PowerShell or Command Prompt** in the project directory (`dia-tts-server`).
2. **Create the virtual environment:**
    ```powershell
    python -m venv venv
    ```
3. **Activate the virtual environment:**
    ```powershell
    .\venv\Scripts\activate
    ```
    Your terminal prompt should now be prefixed with `(venv)`.

#### 4.2.2 Linux Setup (Debian/Ubuntu Example)

1. **Install prerequisites (if not already present):**
    ```bash
    sudo apt update
    sudo apt install python3 python3-venv python3-pip libsndfile1 -y
    ```
2. **Open your terminal** in the project directory (`dia-tts-server`).
3. **Create the virtual environment:**
    ```bash
    python3 -m venv venv
    ```
4. **Activate the virtual environment:**
    ```bash
    source venv/bin/activate
    ```
    Your terminal prompt should now be prefixed with `(venv)`.

### 4.3 Installing Dependencies

With your virtual environment activated (`(venv)` prefix visible), install the required Python packages:

```bash
# Upgrade pip first (optional but good practice)
pip install --upgrade pip

# Install all dependencies from requirements.txt
pip install -r requirements.txt
```

**Note:** This command installs the CPU-only version of PyTorch by default. If you have a compatible NVIDIA GPU and want acceleration, proceed to [Section 4.4](#44-nvidia-driver-and-cuda-setup-required-for-gpu-acceleration) **before** running the server.

### 4.4 NVIDIA Driver and CUDA Setup (Required for GPU Acceleration)

Follow these steps **only if you have a compatible NVIDIA GPU** and want faster inference.

#### 4.4.1 Step 1: Check/Install NVIDIA Drivers

1. **Check Existing Driver:** Open Command Prompt (Windows) or Terminal (Linux) and run:
    ```bash
    nvidia-smi
    ```
2. **Interpret Output:**
    * If the command runs successfully, note the **Driver Version** and the **CUDA Version** listed in the top right corner. This CUDA version is the *maximum* supported by your current driver.
    * If the command fails ("not recognized"), you need to install or update your NVIDIA drivers.
3. **Install/Update Drivers:** Go to the [NVIDIA Driver Downloads](https://www.nvidia.com/Download/index.aspx) page. Select your GPU model and OS, then download and install the latest recommended driver (Game Ready or Studio). **Reboot your computer** after installation. Run `nvidia-smi` again to confirm it works.

#### 4.4.2 Step 2: Install PyTorch with CUDA Support

1. **Go to PyTorch Website:** Visit [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/).
2. **Configure:** Select:
    * **PyTorch Build:** Stable
    * **Your OS:** Windows or Linux
    * **Package:** Pip
    * **Language:** Python
    * **Compute Platform:** Choose the CUDA version **equal to or lower than** the version reported by `nvidia-smi`. For example, if `nvidia-smi` shows `CUDA Version: 12.4`, select `CUDA 12.1`. If it shows `11.8`, select `CUDA 11.8`. **Do not select a version higher than your driver supports.** (CUDA 12.1 or 11.8 are common stable choices.)
3. **Copy Command:** Copy the generated installation command. It will look similar to:
    ```bash
    # Example for CUDA 12.1 (Windows/Linux):
    pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
    # Example for CUDA 11.8 (Windows/Linux):
    pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
    ```
    *(Use `pip` instead of `pip3` if that's your command.)*
4. **Install in Activated venv:**
    * Ensure your `(venv)` is active.
    * **Uninstall CPU PyTorch first:**
        ```bash
        pip uninstall torch torchvision torchaudio -y
        ```
    * **Paste and run the copied command** from the PyTorch website.

#### 4.4.3 Step 3: Verify PyTorch CUDA Installation

1. With the `(venv)` still active, start a Python interpreter:
    ```bash
    python
    ```
2. Run the following Python code:
    ```python
    import torch

    print(f"PyTorch version: {torch.__version__}")
    cuda_available = torch.cuda.is_available()
    print(f"CUDA available: {cuda_available}")
    if cuda_available:
        print(f"CUDA version used by PyTorch: {torch.version.cuda}")
        print(f"Device count: {torch.cuda.device_count()}")
        print(f"Current device index: {torch.cuda.current_device()}")
        print(f"Device name: {torch.cuda.get_device_name(torch.cuda.current_device())}")
    else:
        print("CUDA not available to PyTorch. Ensure drivers and CUDA-enabled PyTorch are installed correctly.")
    exit()
    ```
3. If `CUDA available:` shows `True`, the setup was successful. If `False`, review driver installation and the PyTorch installation command.

---

## 5. Configuration

The server's behavior, including model selection, paths, and default generation parameters, is controlled via configuration settings.

### 5.1 Configuration Files (`.env` and `config.py`)

* **`config.py`:** Defines the *default* values for all configuration parameters in the `DEFAULT_CONFIG` dictionary. It also contains the `ConfigManager` class and getter functions used by the application.
* **`.env` File:** This file, located in the project root directory (`dia-tts-server/.env`), allows you to *override* the default values. Create this file if it doesn't exist. Settings are defined as `KEY=VALUE` pairs, one per line. The server reads this file on startup using `python-dotenv`.

**Priority:** Values set in the `.env` file take precedence over the defaults in `config.py`. Environment variables set directly in your system also override `.env` file values (though using `.env` is generally recommended for project-specific settings).

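As an illustration of that precedence chain, a getter can be sketched with `python-dotenv` roughly as follows (a minimal sketch; the exact names in `config.py` may differ):

```python
import os
from dotenv import load_dotenv

load_dotenv()  # reads .env; by default it does NOT override variables already set in the OS environment

# OS environment beats .env, which beats the hard-coded default
port = int(os.getenv("PORT", "8003"))
```
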
### 5.2 Configuration Parameters

The following parameters can be set in your `.env` file:

| Parameter Name (in `.env`) | Default Value (`config.py`) | Description | Example `.env` Value |
| :--- | :--- | :--- | :--- |
| **Server Settings** | | | |
| `HOST` | `0.0.0.0` | The network interface address the server listens on. `0.0.0.0` makes it accessible on your local network. | `127.0.0.1` (localhost only) |
| `PORT` | `8003` | The port number the server listens on. | `8080` |
| **Model Source Settings** | | | |
| `DIA_MODEL_REPO_ID` | `ttj/dia-1.6b-safetensors` | The Hugging Face repository ID containing the model files. | `nari-labs/Dia-1.6B` |
| `DIA_MODEL_CONFIG_FILENAME` | `config.json` | The filename of the model's configuration JSON within the repository. | `config.json` |
| `DIA_MODEL_WEIGHTS_FILENAME` | `dia-v0_1_bf16.safetensors` | The filename of the model weights file (`.safetensors` or `.pth`) within the repository to load. | `dia-v0_1.safetensors` or `dia-v0_1.pth` |
| **Path Settings** | | | |
| `DIA_MODEL_CACHE_PATH` | `./model_cache` | Local directory to store downloaded model files. Relative paths are based on the project root. | `/path/to/shared/cache` |
| `REFERENCE_AUDIO_PATH` | `./reference_audio` | Local directory to store reference audio files (`.wav`, `.mp3`) used for voice cloning. | `./voices` |
| `OUTPUT_PATH` | `./outputs` | Local directory where generated audio files from the Web UI are saved. | `./generated_speech` |
| **Default Generation Parameters** | | *(These set the initial UI values and can be saved via the UI)* | |
| `GEN_DEFAULT_SPEED_FACTOR` | `0.90` | Default playback speed factor applied *after* generation (UI slider initial value). | `1.0` |
| `GEN_DEFAULT_CFG_SCALE` | `3.0` | Default Classifier-Free Guidance scale (UI slider initial value). | `2.5` |
| `GEN_DEFAULT_TEMPERATURE` | `1.3` | Default sampling temperature (UI slider initial value). | `1.2` |
| `GEN_DEFAULT_TOP_P` | `0.95` | Default nucleus sampling probability (UI slider initial value). | `0.9` |
| `GEN_DEFAULT_CFG_FILTER_TOP_K` | `35` | Default Top-K value for CFG filtering (UI slider initial value). | `40` |

**Example `.env` File (Using Original Nari Labs Model):**

```dotenv
# .env
# Example configuration to use the original Nari Labs model

HOST=0.0.0.0
PORT=8003

DIA_MODEL_REPO_ID=nari-labs/Dia-1.6B
DIA_MODEL_CONFIG_FILENAME=config.json
DIA_MODEL_WEIGHTS_FILENAME=dia-v0_1.pth

# Keep other paths as default or specify custom ones
# DIA_MODEL_CACHE_PATH=./model_cache
# REFERENCE_AUDIO_PATH=./reference_audio
# OUTPUT_PATH=./outputs

# Keep default generation parameters or override them
# GEN_DEFAULT_SPEED_FACTOR=0.90
# GEN_DEFAULT_CFG_SCALE=3.0
# GEN_DEFAULT_TEMPERATURE=1.3
# GEN_DEFAULT_TOP_P=0.95
# GEN_DEFAULT_CFG_FILTER_TOP_K=35
```

**Important:** You must **restart the server** after making changes to the `.env` file for them to take effect.

---

## 6. Running the Server

1. **Activate Virtual Environment:** Ensure your virtual environment is activated (`(venv)` prefix).
    * Windows: `.\venv\Scripts\activate`
    * Linux: `source venv/bin/activate`
2. **Navigate to Project Root:** Make sure your terminal is in the `dia-tts-server` directory.
3. **Run the Server:**
    ```bash
    python server.py
    ```
4. **Server Output:** You should see log messages indicating the server is starting, including:
    * The configuration being used (repo ID, filenames, paths).
    * The device being used (CPU or CUDA).
    * Model loading progress (downloading if necessary).
    * Confirmation that the server is running (e.g., `Uvicorn running on http://0.0.0.0:8003`).
    * URLs for accessing the Web UI and API Docs.
5. **Accessing the Server:**
    * **Web UI:** Open your web browser and go to `http://localhost:PORT` (e.g., `http://localhost:8003` if using the default port). If running on a different machine or VM, replace `localhost` with the server's IP address.
    * **API Docs:** Access the interactive API documentation (Swagger UI) at `http://localhost:PORT/docs`.
6. **Stopping the Server:** Press `CTRL+C` in the terminal where the server is running.

**Auto-Reload:** The server is configured to run with `reload=True`. This means Uvicorn will automatically restart the server if it detects changes in `.py`, `.html`, `.css`, `.js`, `.env`, or `.yaml` files within the project or `ui` directory. This is useful for development but should generally be disabled in production.

---

## 7. Usage

The Dia TTS Server can be used via its Web UI or its API endpoints.

### 7.1 Web User Interface (Web UI)

Access the UI by navigating to the server's base URL (e.g., `http://localhost:8003`).

#### 7.1.1 Main Generation Form

* **Text to speak:** Enter the text you want to synthesize.
    * Use `[S1]` and `[S2]` tags to indicate speaker turns for dialogue.
    * Include non-verbal cues like `(laughs)`, `(sighs)`, `(clears throat)` directly in the text where desired (see the example below this list).
    * For voice cloning, **prepend the exact transcript** of the selected reference audio before the text you want generated (e.g., `[S1] Reference transcript text. [S1] This is the new text to generate in the cloned voice.`).
* **Voice Mode:** Select the desired generation mode:
    * **Single / Dialogue (Use [S1]/[S2]):** Use this for single-speaker text (you can use `[S1]` or omit tags if the model handles it) or multi-speaker dialogue (using `[S1]` and `[S2]`).
    * **Voice Clone (from Reference):** Enables voice cloning based on a selected audio file. Requires selecting a file below and prepending its transcript to the text input.
* **Generate Speech Button:** Submits the text and settings to the server to start generation.

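For example, a short input combining dialogue tags and a non-verbal cue might look like:

```
[S1] Did you hear the announcement this morning? (laughs)
[S2] I did! I still can't believe it.
```
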
#### 7.1.2 Presets

* Located below the Voice Mode selection.
* Clicking a preset button (e.g., "Standard Dialogue", "Expressive Narration") will automatically populate the "Text to speak" area and the "Generation Parameters" sliders with predefined values, demonstrating different use cases.

#### 7.1.3 Voice Cloning

* This section appears only when "Voice Clone" mode is selected.
* **Reference Audio File Dropdown:** Lists available `.wav` and `.mp3` files found in the configured `REFERENCE_AUDIO_PATH`. Select the file whose voice you want to clone. Remember to prepend its transcript to the main text input.
* **Load Button:** Click this to open your system's file browser. You can select one or more `.wav` or `.mp3` files to upload. The selected files will be copied to the server's `REFERENCE_AUDIO_PATH`, and the dropdown list will refresh automatically. The first newly uploaded file will be selected in the dropdown.

#### 7.1.4 Generation Parameters

* Expand this section to fine-tune the generation process. These values correspond to the parameters used by the underlying Dia model.
* **Sliders:** Adjust Speed Factor, CFG Scale, Temperature, Top P, and CFG Filter Top K. The current value is displayed next to the label.
* **Save Generation Defaults Button:** Saves the *current* values of these sliders to the `.env` file (as `GEN_DEFAULT_...` keys). These saved values will become the default settings loaded into the UI the next time the server starts.

#### 7.1.5 Server Configuration (UI)

* Expand this section to view and modify server-level settings stored in the `.env` file.
* **Fields:** Edit Model Repo ID, Config/Weights Filenames, Cache/Reference/Output Paths, Host, and Port.
* **Save Server Configuration Button:** Saves the values currently shown in these fields to the `.env` file. **A server restart is required** for most of these changes (especially model source or paths) to take effect.
* **Restart Server Button:** (Appears after saving) Attempts to trigger a server restart. This works best if the server was started with `reload=True` or is managed by a process manager like systemd or Supervisor.

#### 7.1.6 Generated Audio Player

* Appears below the main form after a successful generation.
* **Waveform:** Visual representation of the generated audio.
* **Play/Pause Button:** Controls audio playback.
* **Download WAV Button:** Downloads the generated audio as a `.wav` file.
* **Info:** Displays the voice mode used, generation time, and audio duration.

#### 7.1.7 Theme Toggle

* Located in the top-right navigation bar.
* Click the Sun/Moon icon to switch between Light and Dark themes. Your preference is saved in your browser's `localStorage`.

### 7.2 API Endpoints

Access the interactive API documentation via the `/docs` path (e.g., `http://localhost:8003/docs`).

#### 7.2.1 POST `/v1/audio/speech` (OpenAI Compatible)

* **Purpose:** Provides an endpoint compatible with the basic OpenAI TTS API for easier integration with existing tools.
* **Request Body:** (`application/json`) - Uses the `OpenAITTSRequest` model.

    | Field | Type | Required | Description | Example |
    | :--- | :--- | :--- | :--- | :--- |
    | `model` | string | No | Ignored by this server (always uses Dia). Included for compatibility. Defaults to `dia-1.6b`. | `"dia-1.6b"` |
    | `input` | string | Yes | The text to synthesize. Use `[S1]`/`[S2]` tags for dialogue. For cloning, prepend reference transcript. | `"Hello [S1] world."` |
    | `voice` | string | No | Maps to Dia modes. Use `"S1"`, `"S2"`, `"dialogue"`, or the filename of a reference audio (e.g., `"my_ref.wav"`) for cloning. Defaults to `S1`. | `"dialogue"` or `"ref.mp3"` |
    | `response_format` | `"opus"` \| `"wav"` | No | Desired audio output format. Defaults to `opus`. | `"wav"` |
    | `speed` | float | No | Playback speed factor (0.5-2.0). Applied *after* generation. Defaults to `1.0`. | `0.9` |

* **Response:**
    * **Success (200 OK):** `StreamingResponse` containing the binary audio data (`audio/opus` or `audio/wav`).
    * **Error:** Standard FastAPI JSON error response (e.g., 400, 404, 500).

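For example, a request for WAV output using `curl` (assuming the default host and port):

```bash
curl -X POST http://localhost:8003/v1/audio/speech \
  -H "Content-Type: application/json" \
  -d '{"input": "[S1] Hello from the Dia TTS server!", "response_format": "wav", "speed": 0.9}' \
  --output speech.wav
```
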
#### 7.2.2 POST `/tts` (Custom Parameters)

* **Purpose:** Allows generation using all specific Dia generation parameters.
* **Request Body:** (`application/json`) - Uses the `CustomTTSRequest` model.

    | Field | Type | Required | Description | Default |
    | :--- | :--- | :--- | :--- | :--- |
    | `text` | string | Yes | The text to synthesize. Use `[S1]`/`[S2]` tags. Prepend transcript for cloning. | |
    | `voice_mode` | `"dialogue"` \| `"clone"` | No | Generation mode. Note: `single_s1`/`single_s2` are handled via `dialogue` mode with appropriate tags in the text. | `dialogue` |
    | `clone_reference_filename` | string \| null | No | Filename of reference audio in `REFERENCE_AUDIO_PATH`. **Required if `voice_mode` is `clone`**. | `null` |
    | `output_format` | `"opus"` \| `"wav"` | No | Desired audio output format. | `opus` |
    | `max_tokens` | integer \| null | No | Maximum audio tokens to generate. `null` uses the model's default. | `null` |
    | `cfg_scale` | float | No | Classifier-Free Guidance scale. | `3.0` |
    | `temperature` | float | No | Sampling temperature. | `1.3` |
    | `top_p` | float | No | Nucleus sampling probability. | `0.95` |
    | `speed_factor` | float | No | Playback speed factor (0.5-2.0). Applied *after* generation. | `0.90` |
    | `cfg_filter_top_k` | integer | No | Top-K value for CFG filtering. | `35` |

* **Response:**
    * **Success (200 OK):** `StreamingResponse` containing the binary audio data (`audio/opus` or `audio/wav`).
    * **Error:** Standard FastAPI JSON error response (e.g., 400, 404, 500).

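For example, a dialogue request with explicit generation parameters (again assuming the default host and port):

```bash
curl -X POST http://localhost:8003/tts \
  -H "Content-Type: application/json" \
  -d '{"text": "[S1] Testing the custom endpoint. [S2] Sounds good to me!", "voice_mode": "dialogue", "output_format": "wav", "cfg_scale": 3.0, "temperature": 1.3, "top_p": 0.95}' \
  --output dialogue.wav
```
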
#### 7.2.3 Configuration & Helper Endpoints

* **GET `/get_config`:** Returns the current server configuration as JSON.
* **POST `/save_config`:** Saves server configuration settings provided in the JSON request body to the `.env` file. Requires server restart.
* **POST `/save_generation_defaults`:** Saves default generation parameters provided in the JSON request body to the `.env` file. Affects UI defaults on next load.
* **POST `/restart_server`:** Attempts to trigger a server restart (reliability depends on execution environment).
* **POST `/upload_reference`:** Uploads one or more audio files (`.wav`, `.mp3`) as `multipart/form-data` to the reference audio directory. Returns JSON with status and updated file list.
* **GET `/health`:** Basic health check endpoint. Returns `{"status": "healthy", "model_loaded": true/false}`.

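For example, the health endpoint is convenient as a readiness probe:

```bash
curl http://localhost:8003/health
# Expected response: {"status": "healthy", "model_loaded": true}
```
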
---

## 8. Troubleshooting

* **Error: `CUDA available: False` or Slow Performance:**
    * Verify NVIDIA drivers are installed correctly (`nvidia-smi` command).
    * Ensure you installed the correct PyTorch version with CUDA support matching your driver (see [Section 4.4](#44-nvidia-driver-and-cuda-setup-required-for-gpu-acceleration)). Reinstall PyTorch using the command from the official website if unsure.
    * Check if another process is using all GPU VRAM.
* **Error: `ImportError: No module named 'dac'` (or `safetensors`, `yaml`, etc.):**
    * Make sure your virtual environment is activated.
    * Run `pip install -r requirements.txt` again to install missing dependencies.
    * Specifically for `dac`, ensure you installed `descript-audio-codec` and not a different package named `dac`. Run `pip uninstall dac -y && pip install descript-audio-codec`.
* **Error: `libsndfile library not found` (or similar `soundfile` error, mainly on Linux):**
    * Install the system library: `sudo apt update && sudo apt install libsndfile1` (Debian/Ubuntu) or the equivalent for your distribution.
* **Error: Model Download Fails (e.g., `HTTPError`, `ConnectionError`):**
    * Check your internet connection.
    * Verify the `DIA_MODEL_REPO_ID`, `DIA_MODEL_CONFIG_FILENAME`, and `DIA_MODEL_WEIGHTS_FILENAME` in your `.env` file (or defaults in `config.py`) are correct and accessible on Hugging Face Hub.
    * Check Hugging Face Hub status if multiple downloads fail.
    * Ensure the cache directory (`DIA_MODEL_CACHE_PATH`) is writable.
* **Error: `RuntimeError: Failed to load DAC model...`:**
    * This usually indicates an issue with the `descript-audio-codec` installation or version incompatibility. Ensure it's installed correctly (see `ImportError` above).
    * Check logs for specific `AttributeError` messages (like missing `utils` or `download`) which might indicate version mismatches between the Dia code's expectation and the installed library. The current code expects `dac.utils.download()`.
* **Error: `FileNotFoundError` during generation (Reference Audio):**
    * Ensure the filename selected/provided for voice cloning exists in the configured `REFERENCE_AUDIO_PATH`.
    * Check that the path in `config.py` or `.env` is correct and the server has permission to read from it.
* **Error: Cannot Save Output/Reference Files (`PermissionError`, etc.):**
    * Ensure the directories specified by `OUTPUT_PATH` and `REFERENCE_AUDIO_PATH` exist and the server process has write permissions to them.
* **Web UI Issues (Buttons don't work, styles missing):**
    * Clear your browser cache.
    * Check the browser's developer console (usually F12) for JavaScript errors.
    * Ensure `ui/script.js` and `ui/style.css` are being loaded correctly (check the network tab in developer tools).
* **Generation Cancel Button Doesn't Stop Process:**
    * This is expected ("Fake Cancel"). The button currently only prevents the UI from processing the result when it eventually arrives. True cancellation is complex and not implemented. Clicking "Generate" again *will* cancel the *previous UI request's result processing* before starting the new one.

---

## 9. Project Architecture

* **`server.py`:** The main entry point using FastAPI. Defines API routes, serves the Web UI using Jinja2, handles requests, and orchestrates calls to the engine.
* **`engine.py`:** Responsible for loading the Dia model (including downloading files via `huggingface_hub`), managing the model instance, preparing inputs for the model's `generate` method based on user requests (handling voice modes), and calling the model's generation function. Also handles post-processing like speed adjustment.
* **`config.py`:** Manages all configuration settings using default values and overrides from a `.env` file. Provides getter functions for easy access to settings.
* **`dia/` package:** Contains the core implementation of the Dia model itself.
    * `model.py`: Defines the `Dia` class, which wraps the underlying PyTorch model (`DiaModel`). It handles loading weights (`.pth` or `.safetensors`), loading the required DAC model, preparing inputs specifically for the `DiaModel` forward pass (including CFG logic), and running the autoregressive generation loop.
    * `config.py` (within `dia/`): Defines Pydantic models representing the *structure* and hyperparameters of the Dia model architecture (encoder, decoder, data parameters). This is loaded from the `config.json` file associated with the model weights.
    * `layers.py`: Contains custom PyTorch `nn.Module` implementations used within the `DiaModel` (e.g., Attention blocks, MLP blocks, RoPE).
    * `audio.py`: Includes helper functions for audio processing specific to the model's tokenization and delay patterns (e.g., `audio_to_codebook`, `codebook_to_audio`, `apply_audio_delay`).
* **`ui/` directory:** Contains all files related to the Web UI.
    * `index.html`: The main Jinja2 template.
    * `script.js`: Frontend JavaScript for interactivity, API calls, theme switching, etc.
    * `presets.yaml`: Definitions for the UI preset examples.
* **`utils.py`:** General utility functions, such as audio encoding (`encode_audio`) and saving (`save_audio_to_file`) using the `soundfile` library.
* **Dependencies:** Relies heavily on `FastAPI`, `Uvicorn`, `PyTorch`, `torchaudio`, `huggingface_hub`, `safetensors`, `descript-audio-codec`, `soundfile`, `PyYAML`, `python-dotenv`, `pydantic`, and `Jinja2`.

---

## 10. License and Disclaimer

* **License:** This project is licensed under the MIT License.
* **Disclaimer:** This project offers a high-fidelity speech generation model intended solely for research and educational use. The following uses are **strictly forbidden**:
    * **Identity Misuse:** Do not produce audio resembling real individuals without permission.
    * **Deceptive Content:** Do not use this model to generate misleading content (e.g., fake news).
    * **Illegal or Malicious Use:** Do not use this model for activities that are illegal or intended to cause harm.

By using this model, you agree to uphold relevant legal standards and ethical responsibilities. The creators **are not responsible** for any misuse and firmly oppose any unethical usage of this technology.

---
download_model.py
ADDED
@@ -0,0 +1,41 @@
# download_model.py
# Utility script to download the Dia model and dependencies without starting the server.

import logging
import os
import engine  # Import the engine module to trigger its loading logic

# Configure basic logging for the script
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
)
logger = logging.getLogger("ModelDownloader")

if __name__ == "__main__":
    logger.info("--- Starting Dia Model Download ---")

    # Ensure cache directory exists (redundant if engine.load_model does it, but safe)
    try:
        from config import get_model_cache_path

        cache_path = get_model_cache_path()
        os.makedirs(cache_path, exist_ok=True)
        logger.info(
            f"Ensured model cache directory exists: {os.path.abspath(cache_path)}"
        )
    except Exception as e:
        logger.warning(f"Could not ensure cache directory exists: {e}")

    # Trigger the model loading function from the engine
    logger.info("Calling engine.load_model() to initiate download if necessary...")
    success = engine.load_model()

    if success:
        logger.info("--- Model download/load process completed successfully ---")
    else:
        logger.error(
            "--- Model download/load process failed. Check logs for details. ---"
        )
        exit(1)  # Exit with error code

    logger.info("You can now start the server using 'python server.py'")
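With the virtual environment activated, the script above can be run on its own to pre-populate the model cache before the first server start:

    python download_model.py
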
engine.py
ADDED
@@ -0,0 +1,356 @@
1 |
+
# engine.py
|
2 |
+
# Core Dia TTS model loading and generation logic
|
3 |
+
|
4 |
+
import logging
|
5 |
+
import time
|
6 |
+
import os
|
7 |
+
import torch
|
8 |
+
import numpy as np
|
9 |
+
from typing import Optional, Tuple
|
10 |
+
from huggingface_hub import hf_hub_download # Import downloader
|
11 |
+
|
12 |
+
# Import Dia model class and config
|
13 |
+
try:
|
14 |
+
from dia.model import Dia
|
15 |
+
from dia.config import DiaConfig
|
16 |
+
except ImportError as e:
|
17 |
+
# Log critical error if core components are missing
|
18 |
+
logging.critical(
|
19 |
+
f"Failed to import Dia model components: {e}. Ensure the 'dia' package exists and is importable.",
|
20 |
+
exc_info=True,
|
21 |
+
)
|
22 |
+
|
23 |
+
# Define dummy classes/functions to prevent server crash on import,
|
24 |
+
# but generation will fail later if these are used.
|
25 |
+
class Dia:
|
26 |
+
@staticmethod
|
27 |
+
def load_model_from_files(*args, **kwargs):
|
28 |
+
raise RuntimeError("Dia model package not available or failed to import.")
|
29 |
+
|
30 |
+
def generate(*args, **kwargs):
|
31 |
+
raise RuntimeError("Dia model package not available or failed to import.")
|
32 |
+
|
33 |
+
class DiaConfig:
|
34 |
+
pass
|
35 |
+
|
36 |
+
|
37 |
+
# Import configuration getters from our project's config.py
|
38 |
+
from config import (
|
39 |
+
get_model_repo_id,
|
40 |
+
get_model_cache_path,
|
41 |
+
get_reference_audio_path,
|
42 |
+
get_model_config_filename,
|
43 |
+
get_model_weights_filename,
|
44 |
+
)
|
45 |
+
|
46 |
+
logger = logging.getLogger(__name__) # Use standard logger name
|
47 |
+
|
48 |
+
# --- Global Variables ---
|
49 |
+
dia_model: Optional[Dia] = None
|
50 |
+
# model_config is now loaded within Dia.load_model_from_files, maybe remove global?
|
51 |
+
# Let's keep it for now if needed elsewhere, but populate it after loading.
|
52 |
+
model_config_instance: Optional[DiaConfig] = None
|
53 |
+
model_device: Optional[torch.device] = None
|
54 |
+
MODEL_LOADED = False
|
55 |
+
EXPECTED_SAMPLE_RATE = 44100 # Dia model and DAC typically operate at 44.1kHz
|
56 |
+
|
57 |
+
# --- Model Loading ---
|
58 |
+
|
59 |
+
|
60 |
+
def get_device() -> torch.device:
|
61 |
+
"""Determines the optimal torch device (CUDA > MPS > CPU)."""
|
62 |
+
if torch.cuda.is_available():
|
63 |
+
logger.info("CUDA is available, using GPU.")
|
64 |
+
return torch.device("cuda")
|
65 |
+
# Add MPS check for Apple Silicon GPUs
|
66 |
+
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
67 |
+
# Basic check is usually sufficient
|
68 |
+
logger.info("MPS is available, using Apple Silicon GPU.")
|
69 |
+
return torch.device("mps")
|
70 |
+
else:
|
71 |
+
logger.info("CUDA and MPS not available, using CPU.")
|
72 |
+
return torch.device("cpu")
|
73 |
+
|
74 |
+
|
75 |
+
def load_model():
    """
    Loads the Dia TTS model and associated DAC model.
    Downloads model files based on configuration if they don't exist locally.
    Handles both .pth and .safetensors formats.
    """
    global dia_model, model_config_instance, model_device, MODEL_LOADED

    if MODEL_LOADED:
        logger.info("Dia model already loaded.")
        return True

    # Get configuration values
    repo_id = get_model_repo_id()
    config_filename = get_model_config_filename()
    weights_filename = get_model_weights_filename()
    cache_path = get_model_cache_path()  # Already absolute path
    model_device = get_device()

    logger.info("Attempting to load Dia model:")
    logger.info(f"  Repo ID: {repo_id}")
    logger.info(f"  Config File: {config_filename}")
    logger.info(f"  Weights File: {weights_filename}")
    logger.info(f"  Cache Directory: {cache_path}")
    logger.info(f"  Target Device: {model_device}")

    # Ensure cache directory exists
    try:
        os.makedirs(cache_path, exist_ok=True)
    except OSError as e:
        logger.error(
            f"Failed to create cache directory '{cache_path}': {e}", exc_info=True
        )
        # Depending on severity, might want to return False here
        # return False
        pass  # Continue and let hf_hub_download handle potential issues

    try:
        start_time = time.time()

        # --- Download Model Files ---
        logger.info(
            f"Downloading/finding configuration file '{config_filename}' from repo '{repo_id}'..."
        )
        local_config_path = hf_hub_download(
            repo_id=repo_id,
            filename=config_filename,
            cache_dir=cache_path,
            # force_download=False,  # Default: only download if missing or outdated
            # resume_download=True,  # Default: resume interrupted downloads
        )
        logger.info(f"Configuration file path: {local_config_path}")

        logger.info(
            f"Downloading/finding weights file '{weights_filename}' from repo '{repo_id}'..."
        )
        local_weights_path = hf_hub_download(
            repo_id=repo_id,
            filename=weights_filename,
            cache_dir=cache_path,
        )
        logger.info(f"Weights file path: {local_weights_path}")

        # --- Load Model using the class method ---
        # The Dia class method now handles config loading, instantiation, weight loading, etc.
        dia_model = Dia.load_model_from_files(
            config_path=local_config_path,
            weights_path=local_weights_path,
            device=model_device,
        )

        # Store the config instance if needed globally (optional)
        model_config_instance = dia_model.config

        end_time = time.time()
        logger.info(
            f"Dia model loaded successfully in {end_time - start_time:.2f} seconds."
        )
        MODEL_LOADED = True
        return True

    except FileNotFoundError as e:
        logger.error(
            f"Model loading failed: Required file not found. {e}", exc_info=True
        )
        MODEL_LOADED = False
        return False
    except ImportError:
        # This catches if the 'dia' package itself is missing
        logger.critical(
            "Failed to load model: Dia package or its core dependencies not found.",
            exc_info=True,
        )
        MODEL_LOADED = False
        return False
    except Exception as e:
        # Catch other potential errors during download or loading
        logger.error(
            f"Error loading Dia model from repo '{repo_id}': {e}", exc_info=True
        )
        dia_model = None
        model_config_instance = None
        MODEL_LOADED = False
        return False

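For reference, a minimal sketch (an editor's illustration, not part of engine.py) of the download step load_model relies on. hf_hub_download only fetches a file that is missing or outdated in the cache directory, so repeated server restarts reuse the local copy; the getters are the same ones engine.py imports from config.py:

from huggingface_hub import hf_hub_download
from config import get_model_repo_id, get_model_weights_filename, get_model_cache_path

# Returns the absolute path of the cached weights file, downloading only if needed.
path = hf_hub_download(
    repo_id=get_model_repo_id(),
    filename=get_model_weights_filename(),
    cache_dir=get_model_cache_path(),
)
print(path)
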
# --- Speech Generation ---


def generate_speech(
    text: str,
    voice_mode: str = "single_s1",
    clone_reference_filename: Optional[str] = None,
    max_tokens: Optional[int] = None,
    cfg_scale: float = 3.0,
    temperature: float = 1.3,
    top_p: float = 0.95,
    speed_factor: float = 0.94,  # Keep speed factor separate from model generation params
    cfg_filter_top_k: int = 35,
) -> Optional[Tuple[np.ndarray, int]]:
    """
    Generates speech using the loaded Dia model, handling voice modes and speed adjustment.

    Args:
        text: Text to synthesize.
        voice_mode: 'dialogue', 'single_s1', 'single_s2', 'clone'.
        clone_reference_filename: Filename for voice cloning (if mode is 'clone'). Located in reference audio path.
        max_tokens: Max generation tokens for the model's generate method.
        cfg_scale: CFG scale for the model's generate method.
        temperature: Sampling temperature for the model's generate method.
        top_p: Nucleus sampling p for the model's generate method.
        speed_factor: Factor to adjust the playback speed *after* generation (e.g., 0.9 = slower, 1.1 = faster).
        cfg_filter_top_k: CFG filter top K for the model's generate method.

    Returns:
        Tuple of (numpy_audio_array, sample_rate), or None on failure.
    """
    if not MODEL_LOADED or dia_model is None:
        logger.error("Dia model is not loaded. Cannot generate speech.")
        return None

    logger.info(f"Generating speech with mode: {voice_mode}")
    logger.debug(f"Input text (start): '{text[:100]}...'")
    # Log model generation parameters
    logger.debug(
        f"Model Params: max_tokens={max_tokens}, cfg={cfg_scale}, temp={temperature}, top_p={top_p}, top_k={cfg_filter_top_k}"
    )
    # Log post-processing parameters
    logger.debug(f"Post-processing Params: speed_factor={speed_factor}")

    audio_prompt_path = None
    processed_text = text  # Start with original text

    # --- Handle Voice Mode ---
    if voice_mode == "clone":
        if not clone_reference_filename:
            logger.error("Clone mode selected but no reference filename provided.")
            return None
        ref_base_path = get_reference_audio_path()  # Gets absolute path
        potential_path = os.path.join(ref_base_path, clone_reference_filename)
        if os.path.isfile(potential_path):
            audio_prompt_path = potential_path
            logger.info(f"Using audio prompt for cloning: {audio_prompt_path}")
            # Dia requires the transcript of the clone audio to be prepended to the target text.
            # The UI/API caller is responsible for constructing this combined text.
            logger.warning(
                "Clone mode active. Ensure the 'text' input includes the transcript of the reference audio for best results (e.g., '[S1] Reference transcript. [S1] Target text...')."
            )
            processed_text = text  # Use the combined text provided by the caller
        else:
            logger.error(f"Reference audio file not found: {potential_path}")
            return None  # Fail generation if reference file is missing
    elif voice_mode == "dialogue":
        # Assume text already contains [S1]/[S2] tags as required by the model
        logger.info("Using dialogue mode. Expecting [S1]/[S2] tags in input text.")
        if "[S1]" not in text and "[S2]" not in text:
            logger.warning(
                "Dialogue mode selected, but no [S1] or [S2] tags found in the input text."
            )
        processed_text = text  # Pass directly
    elif voice_mode == "single_s1":
        logger.info("Using single voice mode (S1).")
        # Check if text *already* contains tags, warn if so, as it might confuse the model
        if "[S1]" in text or "[S2]" in text:
            logger.warning(
                "Input text contains dialogue tags ([S1]/[S2]), but 'single_s1' mode was selected. Model behavior might be unexpected."
            )
        # Dia likely expects tags even for single speaker. Prepending [S1] might be safer.
        # Let's assume for now the model handles untagged text as S1, but this could be adjusted.
        # Consider: processed_text = f"[S1] {text}"  # Option to enforce S1 tag
        processed_text = text  # Pass directly for now
    elif voice_mode == "single_s2":
        logger.info("Using single voice mode (S2).")
        if "[S1]" in text or "[S2]" in text:
            logger.warning(
                "Input text contains dialogue tags ([S1]/[S2]), but 'single_s2' mode was selected."
            )
        # Similar to S1, how to signal S2? Prepending [S2] seems logical if needed.
        # Consider: processed_text = f"[S2] {text}"  # Option to enforce S2 tag
        processed_text = text  # Pass directly for now
    else:
        logger.error(
            f"Unsupported voice_mode: {voice_mode}. Defaulting to 'single_s1'."
        )
        processed_text = text  # Fallback

    # --- Call Dia Generate ---
    try:
        start_time = time.time()
        logger.info("Calling Dia model generate method...")

        # Call the model's generate method with appropriate parameters
        generated_audio_np = dia_model.generate(
            text=processed_text,
            audio_prompt_path=audio_prompt_path,
            max_tokens=max_tokens,  # Pass None if not specified, Dia uses its default
            cfg_scale=cfg_scale,
            temperature=temperature,
            top_p=top_p,
            use_cfg_filter=True,  # Default from Dia's app.py, seems reasonable
            cfg_filter_top_k=cfg_filter_top_k,
            use_torch_compile=False,  # Keep False for stability unless specifically tested/enabled
        )
        gen_end_time = time.time()
        logger.info(
            f"Dia model generation finished in {gen_end_time - start_time:.2f} seconds."
        )

        if generated_audio_np is None or generated_audio_np.size == 0:
            logger.warning("Dia model returned None or empty audio array.")
            return None

        # --- Apply Speed Factor (Post-processing) ---
        # This mimics the logic in Dia's original app.py
        if speed_factor != 1.0:
            logger.info(f"Applying speed factor: {speed_factor}")
            original_len = len(generated_audio_np)
            # Ensure speed_factor is within a reasonable range to avoid extreme distortion
            # Adjust range based on observed quality (e.g., 0.5 to 2.0)
            speed_factor = max(0.5, min(speed_factor, 2.0))
            target_len = int(original_len / speed_factor)

            if target_len > 0 and target_len != original_len:
                logger.debug(
                    f"Resampling audio from {original_len} to {target_len} samples."
                )
                # Create time axes for original and resampled audio
                x_original = np.linspace(0, original_len - 1, original_len)
                x_resampled = np.linspace(0, original_len - 1, target_len)
                # Interpolate using numpy
                resampled_audio_np = np.interp(
                    x_resampled, x_original, generated_audio_np
                )
                final_audio_np = resampled_audio_np.astype(np.float32)  # Ensure float32
                logger.info(f"Audio resampled for {speed_factor:.2f}x speed.")
            else:
                logger.warning(
                    f"Skipping speed adjustment (factor: {speed_factor:.2f}). Target length invalid ({target_len}) or no change needed."
                )
                final_audio_np = generated_audio_np  # Use original audio
        else:
            logger.info("Speed factor is 1.0, no speed adjustment needed.")
            final_audio_np = generated_audio_np  # No speed change needed

        # Ensure output is float32 (DAC output should be, but good practice)
        if final_audio_np.dtype != np.float32:
            logger.warning(
                f"Generated audio was not float32 ({final_audio_np.dtype}), converting."
            )
            final_audio_np = final_audio_np.astype(np.float32)

        logger.info(
            f"Final audio ready. Shape: {final_audio_np.shape}, dtype: {final_audio_np.dtype}"
        )
        # Return the processed audio and the expected sample rate
        return final_audio_np, EXPECTED_SAMPLE_RATE

    except Exception as e:
        logger.error(
            f"Error during Dia generation or post-processing: {e}", exc_info=True
        )
        return None  # Return None on failure
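A minimal usage sketch of the engine module above (an editor's illustration, not part of the commit). It loads the model once, synthesizes a short dialogue line, and writes the float32 44.1 kHz result with soundfile, which requirements.txt already pulls in; the output filename is illustrative:

import soundfile as sf
import engine

# Load the model once at startup, then generate on demand.
if engine.load_model():
    result = engine.generate_speech(
        text="[S1] Hello there. [S2] Hi, how are you?",
        voice_mode="dialogue",
        speed_factor=0.94,
    )
    if result is not None:
        audio, sr = result  # numpy float32 array, 44100 Hz
        sf.write("example_output.wav", audio, sr)
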
models.py
ADDED
@@ -0,0 +1,97 @@
# models.py
# Pydantic models for API requests and potentially responses

from pydantic import BaseModel, Field
from typing import Optional, Literal

# --- Request Models ---


class OpenAITTSRequest(BaseModel):
    """Request model compatible with the OpenAI TTS API."""

    model: str = Field(
        default="dia-1.6b",
        description="Model identifier (ignored by this server, always uses Dia). Included for compatibility.",
    )
    input: str = Field(..., description="The text to synthesize.")
    voice: str = Field(
        default="S1",
        description="Voice mode or reference audio filename. Examples: 'S1', 'S2', 'dialogue', 'my_reference.wav'.",
    )
    response_format: Literal["opus", "wav"] = Field(
        default="opus", description="The desired audio output format."
    )
    speed: float = Field(
        default=1.0,
        ge=0.8,
        le=1.2,  # Dia speed factor range seems narrower
        description="Adjusts the speed of the generated audio (0.8 to 1.2).",
    )


class CustomTTSRequest(BaseModel):
    """Request model for the custom /tts endpoint."""

    text: str = Field(
        ...,
        description="The text to synthesize. For 'dialogue' mode, include [S1]/[S2] tags.",
    )
    voice_mode: Literal["dialogue", "single_s1", "single_s2", "clone"] = Field(
        default="single_s1", description="Specifies the generation mode."
    )
    clone_reference_filename: Optional[str] = Field(
        default=None,
        description="Filename of the reference audio within the configured reference path (required if voice_mode is 'clone').",
    )
    output_format: Literal["opus", "wav"] = Field(
        default="opus", description="The desired audio output format."
    )
    # Dia-specific generation parameters
    max_tokens: Optional[int] = Field(
        default=None,
        gt=0,
        description="Maximum number of audio tokens to generate (defaults to model's internal config value).",
    )
    cfg_scale: float = Field(
        default=3.0,
        ge=1.0,
        le=5.0,
        description="Classifier-Free Guidance scale (1.0-5.0).",
    )
    temperature: float = Field(
        default=1.3, ge=1.0, le=1.5, description="Sampling temperature (1.0-1.5)."
    )
    top_p: float = Field(
        default=0.95,
        ge=0.8,
        le=1.0,
        description="Nucleus sampling probability (0.8-1.0).",
    )
    speed_factor: float = Field(
        default=0.94,
        ge=0.8,
        le=1.0,  # Dia's default range seems to be <= 1.0
        description="Adjusts the speed of the generated audio (0.8 to 1.0).",
    )
    cfg_filter_top_k: int = Field(
        default=35, ge=15, le=50, description="Top k filter for CFG guidance (15-50)."
    )


# --- Response Models (Optional, can be simple dicts too) ---


class TTSResponse(BaseModel):
    """Basic response model for successful generation (if returning JSON)."""

    request_id: str
    status: str = "completed"
    generation_time_sec: float
    output_url: Optional[str] = None  # If saving file and returning URL


class ErrorResponse(BaseModel):
    """Error response model."""

    detail: str
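A quick sketch (editor's illustration, not part of models.py) of how these Pydantic models enforce the parameter ranges: out-of-range values raise a ValidationError before a request ever reaches the engine. The example values are arbitrary:

from pydantic import ValidationError
from models import CustomTTSRequest

req = CustomTTSRequest(text="[S1] Testing.", voice_mode="single_s1")
print(req.cfg_scale, req.speed_factor)  # Defaults: 3.0 0.94

try:
    CustomTTSRequest(text="[S1] Too hot.", temperature=9.9)  # violates le=1.5
except ValidationError as e:
    print(e.errors()[0]["loc"])  # ('temperature',)
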
requirements.txt
ADDED
@@ -0,0 +1,22 @@
# requirements.txt

# Core Web Framework
fastapi
uvicorn[standard]

# Machine Learning & Audio
torch
torchaudio
numpy
soundfile  # Requires libsndfile system library (e.g., sudo apt-get install libsndfile1 on Debian/Ubuntu)
huggingface_hub
descript-audio-codec
safetensors

# Configuration & Utilities
pydantic
python-dotenv
Jinja2
python-multipart  # For potential file uploads in UI
requests  # For health checks or other potential uses
PyYAML  # For parsing presets.yaml
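Because soundfile depends on the libsndfile system library noted above, a one-line check (an editor's sketch, not part of the repo) can confirm the native dependency is present before starting the server; the import itself fails with an OSError when libsndfile is missing:

import soundfile as sf

print(sf.__libsndfile_version__)  # e.g. "1.2.x" when libsndfile is installed
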
server.py
ADDED
@@ -0,0 +1,1061 @@
# server.py
# Main FastAPI server for Dia TTS

import sys
import logging
import time
import os
import io
import uuid
import shutil  # For file copying
import yaml  # For loading presets
from datetime import datetime
from contextlib import asynccontextmanager
from typing import Optional, Literal, List, Dict, Any
import webbrowser
import threading

from fastapi import (
    FastAPI,
    HTTPException,
    Request,
    Response,
    Form,
    UploadFile,
    File,
    BackgroundTasks,
)
from fastapi.responses import (
    StreamingResponse,
    JSONResponse,
    HTMLResponse,
    RedirectResponse,
)
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import uvicorn
import numpy as np

# Internal imports
from config import (
    config_manager,
    get_host,
    get_port,
    get_output_path,
    get_reference_audio_path,
    # register_config_routes is now defined locally
    get_model_cache_path,
    get_model_repo_id,
    get_model_config_filename,
    get_model_weights_filename,
    # Generation default getters
    get_gen_default_speed_factor,
    get_gen_default_cfg_scale,
    get_gen_default_temperature,
    get_gen_default_top_p,
    get_gen_default_cfg_filter_top_k,
    DEFAULT_CONFIG,
)
from models import OpenAITTSRequest, CustomTTSRequest, ErrorResponse
import engine
from engine import (
    load_model as load_dia_model,
    generate_speech,
    EXPECTED_SAMPLE_RATE,
)
from utils import encode_audio, save_audio_to_file, PerformanceMonitor

# Configure logging (Basic setup, can be enhanced)
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
)
# Reduce verbosity of noisy libraries if needed
# logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
# logging.getLogger("watchfiles").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)  # Logger for this module

# --- Global Variables & Constants ---
PRESETS_FILE = "ui/presets.yaml"
loaded_presets: List[Dict[str, Any]] = []  # Cache presets in memory
startup_complete_event = threading.Event()

# --- Helper Functions ---


def load_presets():
    """Loads presets from the YAML file."""
    global loaded_presets
    try:
        if os.path.exists(PRESETS_FILE):
            with open(PRESETS_FILE, "r", encoding="utf-8") as f:
                loaded_presets = yaml.safe_load(f)
            if not isinstance(loaded_presets, list):
                logger.error(
                    f"Presets file '{PRESETS_FILE}' should contain a list, but found {type(loaded_presets)}. No presets loaded."
                )
                loaded_presets = []
            else:
                logger.info(
                    f"Successfully loaded {len(loaded_presets)} presets from {PRESETS_FILE}."
                )
        else:
            logger.warning(
                f"Presets file not found at '{PRESETS_FILE}'. No presets will be available."
            )
            loaded_presets = []
    except yaml.YAMLError as e:
        logger.error(
            f"Error parsing presets YAML file '{PRESETS_FILE}': {e}", exc_info=True
        )
        loaded_presets = []
    except Exception as e:
        logger.error(f"Error loading presets file '{PRESETS_FILE}': {e}", exc_info=True)
        loaded_presets = []


def get_valid_reference_files() -> list[str]:
    """Gets a list of valid audio files (.wav, .mp3) from the reference directory."""
    ref_path = get_reference_audio_path()
    valid_files = []
    allowed_extensions = (".wav", ".mp3")
    try:
        if os.path.isdir(ref_path):
            for filename in os.listdir(ref_path):
                if filename.lower().endswith(allowed_extensions):
                    # Optional: Add check for file size or basic validity if needed
                    valid_files.append(filename)
        else:
            logger.warning(f"Reference audio directory not found: {ref_path}")
    except Exception as e:
        logger.error(
            f"Error reading reference audio directory '{ref_path}': {e}", exc_info=True
        )
    return sorted(valid_files)


def sanitize_filename(filename: str) -> str:
    """Removes potentially unsafe characters and path components from a filename."""
    # Remove directory separators
    filename = os.path.basename(filename)
    # Keep only alphanumeric, underscore, hyphen, dot. Replace others with underscore.
    safe_chars = set(
        "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789._-"
    )
    sanitized = "".join(c if c in safe_chars else "_" for c in filename)
    # Prevent names starting with dot or consisting only of dots/spaces
    if not sanitized or sanitized.lstrip("._ ") == "":
        return f"uploaded_file_{uuid.uuid4().hex[:8]}"  # Generate a safe fallback name
    # Limit length
    max_len = 100
    if len(sanitized) > max_len:
        name, ext = os.path.splitext(sanitized)
        sanitized = name[: max_len - len(ext)] + ext
    return sanitized

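For illustration (an editor's sketch, not part of the commit), the helper above keeps uploads from escaping the reference directory: path components are dropped, unsafe characters become underscores, and degenerate names fall back to a generated one:

from server import sanitize_filename  # note: importing server runs its module-level setup

assert sanitize_filename("../../etc/passwd") == "passwd"            # basename only
assert sanitize_filename("my voice (1).wav") == "my_voice__1_.wav"  # unsafe chars -> "_"
assert sanitize_filename("...").startswith("uploaded_file_")        # safe fallback name
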
# --- Application Lifespan (Startup/Shutdown) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager for startup/shutdown."""
    model_loaded_successfully = False  # Flag to track success
    try:
        logger.info("Starting Dia TTS server initialization...")
        # Ensure base directories exist
        os.makedirs(get_output_path(), exist_ok=True)
        os.makedirs(get_reference_audio_path(), exist_ok=True)
        os.makedirs(get_model_cache_path(), exist_ok=True)
        os.makedirs("ui", exist_ok=True)
        os.makedirs("static", exist_ok=True)

        # Load presets from YAML file
        load_presets()

        # Load the main TTS model during startup
        if not load_dia_model():
            # Model loading failed
            error_msg = (
                "CRITICAL: Failed to load Dia model on startup. Server cannot start."
            )
            logger.critical(error_msg)
            # Option 1: Raise an exception to stop Uvicorn startup cleanly
            raise RuntimeError(error_msg)
            # Option 2: Force exit (less clean, might bypass some Uvicorn shutdown)
            # sys.exit(1)
        else:
            logger.info("Dia model loaded successfully.")
            model_loaded_successfully = True

        # Create and start a delayed browser opening thread
        # IMPORTANT: Create this thread AFTER model loading completes
        host = get_host()
        port = get_port()
        browser_thread = threading.Thread(
            target=lambda: _delayed_browser_open(host, port), daemon=True
        )
        browser_thread.start()

        # --- Signal completion AFTER potentially long operations ---
        logger.info("Application startup sequence finished. Signaling readiness.")
        startup_complete_event.set()

        yield  # Application runs here

    except Exception as e:
        # Catch the RuntimeError we raised or any other startup error
        logger.error(f"Fatal error during application startup: {e}", exc_info=True)
        # Do NOT set the event here if startup failed
        # Re-raise the exception or exit to ensure the server stops
        raise e  # Re-raising ensures Uvicorn knows startup failed
        # Alternatively: sys.exit(1)
    finally:
        # Cleanup on shutdown
        logger.info("Application shutdown initiated...")
        # Add any specific cleanup needed
        logger.info("Application shutdown complete.")


def _delayed_browser_open(host, port):
    """Opens browser after a short delay to ensure server is ready."""
    try:
        # Small delay to ensure Uvicorn is fully ready
        time.sleep(2)

        display_host = "localhost" if host == "0.0.0.0" else host
        browser_url = f"http://{display_host}:{port}/"

        # Log to file for debugging
        with open("browser_thread_debug.log", "a") as f:
            f.write(f"[{time.time()}] Opening browser at {browser_url}\n")

        # Try to use logger as well (might work at this point)
        try:
            logger.info(f"Opening browser at {browser_url}")
        except Exception:
            pass

        # Open browser directly without health checks
        webbrowser.open(browser_url)

    except Exception as e:
        with open("browser_thread_debug.log", "a") as f:
            f.write(f"[{time.time()}] Browser open error: {str(e)}\n")

# --- FastAPI App Initialization ---
app = FastAPI(
    title="Dia TTS Server",
    description="Text-to-Speech server using the Dia model, providing API and Web UI.",
    version="1.1.0",  # Incremented version
    lifespan=lifespan,
)

# List of folders to check/create
folders = ["reference_audio", "model_cache", "outputs"]

# Check each folder and create if it doesn't exist
for folder in folders:
    if not os.path.exists(folder):
        os.makedirs(folder)
        print(f"Created directory: {folder}")

# --- Static Files and Templates ---
# Serve generated audio files from the configured output path
app.mount("/outputs", StaticFiles(directory=get_output_path()), name="outputs")
# Serve UI files (CSS, JS) from the 'ui' directory
app.mount("/ui", StaticFiles(directory="ui"), name="ui_static")
# Initialize Jinja2 templates to look in the 'ui' directory
templates = Jinja2Templates(directory="ui")

# --- Configuration Routes Definition ---
# Defined locally now instead of importing from config.py
def register_config_routes(app: FastAPI):
    """Adds configuration management endpoints to the FastAPI app."""
    logger.info(
        "Registering configuration routes (/get_config, /save_config, /restart_server, /save_generation_defaults)."
    )

    @app.get(
        "/get_config",
        tags=["Configuration"],
        summary="Get current server configuration",
    )
    async def get_current_config():
        """Returns the current server configuration values (from .env or defaults)."""
        logger.info("Request received for /get_config")
        return JSONResponse(content=config_manager.get_all())

    @app.post(
        "/save_config", tags=["Configuration"], summary="Save server configuration"
    )
    async def save_new_config(request: Request):
        """
        Saves updated server configuration values (Host, Port, Model paths, etc.)
        to the .env file. Requires server restart to apply most changes.
        """
        logger.info("Request received for /save_config")
        try:
            new_config_data = await request.json()
            if not isinstance(new_config_data, dict):
                raise ValueError("Request body must be a JSON object.")
            logger.debug(f"Received server config data to save: {new_config_data}")

            # Filter data to only include keys present in DEFAULT_CONFIG
            filtered_data = {
                k: v for k, v in new_config_data.items() if k in DEFAULT_CONFIG
            }
            unknown_keys = set(new_config_data.keys()) - set(filtered_data.keys())
            if unknown_keys:
                logger.warning(
                    f"Ignoring unknown keys in save_config request: {unknown_keys}"
                )

            config_manager.update(filtered_data)  # Update in memory first
            if config_manager.save():  # Attempt to save to .env
                logger.info("Server configuration saved successfully to .env.")
                return JSONResponse(
                    content={
                        "message": "Server configuration saved. Restart server to apply changes."
                    }
                )
            else:
                logger.error("Failed to save server configuration to .env file.")
                raise HTTPException(
                    status_code=500, detail="Failed to save configuration file."
                )
        except ValueError as ve:
            logger.error(f"Invalid data format for /save_config: {ve}")
            raise HTTPException(
                status_code=400, detail=f"Invalid request data: {str(ve)}"
            )
        except Exception as e:
            logger.error(f"Error processing /save_config request: {e}", exc_info=True)
            raise HTTPException(
                status_code=500, detail=f"Internal server error during save: {str(e)}"
            )

    @app.post(
        "/save_generation_defaults",
        tags=["Configuration"],
        summary="Save default generation parameters",
    )
    async def save_generation_defaults(request: Request):
        """
        Saves the provided generation parameters (speed, cfg, temp, etc.)
        as the new defaults in the .env file. These are loaded by the UI on startup.
        """
        logger.info("Request received for /save_generation_defaults")
        try:
            gen_params = await request.json()
            if not isinstance(gen_params, dict):
                raise ValueError("Request body must be a JSON object.")
            logger.debug(f"Received generation defaults to save: {gen_params}")

            # Map received keys (e.g., 'speed_factor') to .env keys (e.g., 'GEN_DEFAULT_SPEED_FACTOR')
            defaults_to_save = {}
            key_map = {
                "speed_factor": "GEN_DEFAULT_SPEED_FACTOR",
                "cfg_scale": "GEN_DEFAULT_CFG_SCALE",
                "temperature": "GEN_DEFAULT_TEMPERATURE",
                "top_p": "GEN_DEFAULT_TOP_P",
                "cfg_filter_top_k": "GEN_DEFAULT_CFG_FILTER_TOP_K",
            }
            valid_keys_found = False
            for ui_key, env_key in key_map.items():
                if ui_key in gen_params:
                    # Basic validation could be added here (e.g., check if float/int)
                    defaults_to_save[env_key] = str(
                        gen_params[ui_key]
                    )  # Ensure saving as string
                    valid_keys_found = True
                else:
                    logger.warning(
                        f"Missing expected key '{ui_key}' in save_generation_defaults request."
                    )

            if not valid_keys_found:
                raise ValueError("No valid generation parameters found in the request.")

            config_manager.update(defaults_to_save)  # Update in memory
            if (
                config_manager.save()
            ):  # Save all current config (including these) to .env
                logger.info("Generation defaults saved successfully to .env.")
                return JSONResponse(content={"message": "Generation defaults saved."})
            else:
                logger.error("Failed to save generation defaults to .env file.")
                raise HTTPException(
                    status_code=500, detail="Failed to save configuration file."
                )
        except ValueError as ve:
            logger.error(f"Invalid data format for /save_generation_defaults: {ve}")
            raise HTTPException(
                status_code=400, detail=f"Invalid request data: {str(ve)}"
            )
        except Exception as e:
            logger.error(
                f"Error processing /save_generation_defaults request: {e}",
                exc_info=True,
            )
            raise HTTPException(
                status_code=500, detail=f"Internal server error during save: {str(e)}"
            )

    @app.post(
        "/restart_server",
        tags=["Configuration"],
        summary="Attempt to restart the server",
    )
    async def trigger_server_restart(background_tasks: BackgroundTasks):
        """
        Attempts to restart the server process.
        NOTE: This is highly dependent on how the server is run (e.g., with uvicorn --reload,
        or managed by systemd/supervisor). A simple exit might just stop the process.
        This implementation attempts a clean exit, relying on the runner to restart it.
        """
        logger.warning("Received request to restart server via API.")

        def _do_restart():
            time.sleep(1)  # Short delay to allow response to be sent
            logger.warning("Attempting clean exit for restart...")
            # Option 1: Clean exit (relies on Uvicorn reload or process manager)
            sys.exit(0)
            # Option 2: Forceful re-execution (use with caution, might not work as expected)
            # try:
            #     logger.warning("Attempting os.execv for restart...")
            #     os.execv(sys.executable, ['python'] + sys.argv)
            # except Exception as exec_e:
            #     logger.error(f"os.execv failed: {exec_e}. Server may not restart automatically.")
            #     # Fallback to sys.exit if execv fails
            #     sys.exit(1)

        background_tasks.add_task(_do_restart)
        return JSONResponse(
            content={
                "message": "Restart signal sent. Server should restart shortly if run with auto-reload."
            }
        )


# --- Register Configuration Routes ---
register_config_routes(app)

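A client-side sketch (an editor's illustration, not part of the commit) of the configuration routes above, using the requests package already listed in requirements.txt. The base URL assumes the server's default host and port; the parameter keys match key_map:

import requests

BASE = "http://localhost:8003"  # assumed default HOST/PORT

# Read the current configuration.
print(requests.get(f"{BASE}/get_config").json())

# Persist new generation defaults to .env.
resp = requests.post(
    f"{BASE}/save_generation_defaults",
    json={"speed_factor": 0.9, "cfg_scale": 3, "temperature": 1.3,
          "top_p": 0.95, "cfg_filter_top_k": 35},
)
print(resp.json())  # {"message": "Generation defaults saved."}
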
# --- API Endpoints ---


@app.post(
    "/v1/audio/speech",
    response_class=StreamingResponse,
    tags=["TTS Generation"],
    summary="Generate speech (OpenAI compatible)",
)
async def openai_tts_endpoint(request: OpenAITTSRequest):
    """
    Generates speech audio from text, compatible with the OpenAI TTS API structure.
    Maps the 'voice' parameter to Dia's voice modes ('S1', 'S2', 'dialogue', or filename for clone).
    """
    monitor = PerformanceMonitor()
    monitor.record("Request received")
    logger.info(
        f"Received OpenAI request: voice='{request.voice}', speed={request.speed}, format='{request.response_format}'"
    )
    logger.debug(f"Input text (start): '{request.input[:100]}...'")

    voice_mode = "single_s1"  # Default if mapping fails
    clone_ref_file = None
    ref_path = get_reference_audio_path()

    # --- Map OpenAI 'voice' parameter to Dia's modes ---
    voice_param = request.voice.strip()
    if voice_param.lower() == "dialogue":
        voice_mode = "dialogue"
    elif voice_param.lower() == "s1":
        voice_mode = "single_s1"
    elif voice_param.lower() == "s2":
        voice_mode = "single_s2"
    # Check if it looks like a filename for cloning (allow .wav or .mp3)
    elif voice_param.lower().endswith((".wav", ".mp3")):
        potential_path = os.path.join(ref_path, voice_param)
        # Check if the file actually exists in the reference directory
        if os.path.isfile(potential_path):
            voice_mode = "clone"
            clone_ref_file = voice_param  # Use the provided filename
            logger.info(
                f"OpenAI request mapped to clone mode with file: {clone_ref_file}"
            )
        else:
            logger.warning(
                f"Reference file '{voice_param}' specified in OpenAI request not found in '{ref_path}'. Defaulting voice mode."
            )
            # Fallback to default 'single_s1' if file not found
    else:
        logger.warning(
            f"Unrecognized OpenAI voice parameter '{voice_param}'. Defaulting voice mode to 'single_s1'."
        )
        # Fallback for any other value

    monitor.record("Parameters processed")

    try:
        # Call the core engine function using mapped parameters
        result = generate_speech(
            text=request.input,
            voice_mode=voice_mode,
            clone_reference_filename=clone_ref_file,
            speed_factor=request.speed,  # Pass speed factor for post-processing
            # Use Dia's configured defaults for other generation params unless mapped
            max_tokens=None,  # Let Dia use its default unless specified otherwise
            cfg_scale=get_gen_default_cfg_scale(),  # Use saved defaults
            temperature=get_gen_default_temperature(),
            top_p=get_gen_default_top_p(),
            cfg_filter_top_k=get_gen_default_cfg_filter_top_k(),
        )
        monitor.record("Generation complete")

        if result is None:
            logger.error("Speech generation failed (engine returned None).")
            raise HTTPException(status_code=500, detail="Speech generation failed.")

        audio_array, sample_rate = result

        if sample_rate != EXPECTED_SAMPLE_RATE:
            logger.warning(
                f"Engine returned sample rate {sample_rate}, but expected {EXPECTED_SAMPLE_RATE}. Encoding might assume {EXPECTED_SAMPLE_RATE}."
            )
            # Use EXPECTED_SAMPLE_RATE for encoding as it's what the model is trained for
            sample_rate = EXPECTED_SAMPLE_RATE

        # Encode the audio in memory to the requested format
        encoded_audio = encode_audio(audio_array, sample_rate, request.response_format)
        monitor.record("Audio encoding complete")

        if encoded_audio is None:
            logger.error(f"Failed to encode audio to format: {request.response_format}")
            raise HTTPException(
                status_code=500,
                detail=f"Failed to encode audio to {request.response_format}",
            )

        # Determine the correct media type for the response header
        media_type = "audio/opus" if request.response_format == "opus" else "audio/wav"
        # Note: OpenAI uses audio/opus, not audio/ogg;codecs=opus. Let's match OpenAI.

        logger.info(
            f"Successfully generated {len(encoded_audio)} bytes in format {request.response_format}"
        )
        logger.debug(monitor.report())

        # Stream the encoded audio back to the client
        return StreamingResponse(io.BytesIO(encoded_audio), media_type=media_type)

    except HTTPException as http_exc:
        # Re-raise HTTPExceptions directly (e.g., from parameter validation)
        logger.error(f"HTTP exception during OpenAI request: {http_exc.detail}")
        raise http_exc
    except Exception as e:
        logger.error(f"Error processing OpenAI TTS request: {e}", exc_info=True)
        logger.debug(monitor.report())
        # Return generic server error for unexpected issues
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

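A minimal client sketch (an editor's illustration, not part of the commit) for the OpenAI-compatible route above. Per the mapping logic, the voice value switches to clone mode when it names an existing file in the reference directory; host and port are assumed defaults:

import requests

resp = requests.post(
    "http://localhost:8003/v1/audio/speech",  # assumed default host/port
    json={
        "input": "[S1] Hello from the Dia TTS server.",
        "voice": "S1",  # or "dialogue", "S2", or "my_reference.wav" for clone mode
        "response_format": "wav",
        "speed": 1.0,
    },
)
resp.raise_for_status()
with open("speech.wav", "wb") as f:
    f.write(resp.content)
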
@app.post(
    "/tts",
    response_class=StreamingResponse,
    tags=["TTS Generation"],
    summary="Generate speech (Custom parameters)",
)
async def custom_tts_endpoint(request: CustomTTSRequest):
    """
    Generates speech audio from text using explicit Dia parameters.
    """
    monitor = PerformanceMonitor()
    monitor.record("Request received")
    logger.info(
        f"Received custom TTS request: mode='{request.voice_mode}', format='{request.output_format}'"
    )
    logger.debug(f"Input text (start): '{request.text[:100]}...'")
    logger.debug(
        f"Params: max_tokens={request.max_tokens}, cfg={request.cfg_scale}, temp={request.temperature}, top_p={request.top_p}, speed={request.speed_factor}, top_k={request.cfg_filter_top_k}"
    )

    clone_ref_file = None
    if request.voice_mode == "clone":
        if not request.clone_reference_filename:
            raise HTTPException(
                status_code=400,  # Bad request
                detail="Missing 'clone_reference_filename' which is required for clone mode.",
            )
        ref_path = get_reference_audio_path()
        potential_path = os.path.join(ref_path, request.clone_reference_filename)
        if not os.path.isfile(potential_path):
            logger.error(
                f"Reference audio file not found for clone mode: {potential_path}"
            )
            raise HTTPException(
                status_code=404,  # Not found
                detail=f"Reference audio file not found: {request.clone_reference_filename}",
            )
        clone_ref_file = request.clone_reference_filename
        logger.info(f"Custom request using clone mode with file: {clone_ref_file}")

    monitor.record("Parameters processed")

    try:
        # Call the core engine function with parameters from the request
        result = generate_speech(
            text=request.text,
            voice_mode=request.voice_mode,
            clone_reference_filename=clone_ref_file,
            max_tokens=request.max_tokens,  # Pass user value or None
            cfg_scale=request.cfg_scale,
            temperature=request.temperature,
            top_p=request.top_p,
            speed_factor=request.speed_factor,  # For post-processing
            cfg_filter_top_k=request.cfg_filter_top_k,
        )
        monitor.record("Generation complete")

        if result is None:
            logger.error("Speech generation failed (engine returned None).")
            raise HTTPException(status_code=500, detail="Speech generation failed.")

        audio_array, sample_rate = result

        if sample_rate != EXPECTED_SAMPLE_RATE:
            logger.warning(
                f"Engine returned sample rate {sample_rate}, expected {EXPECTED_SAMPLE_RATE}. Encoding will use {EXPECTED_SAMPLE_RATE}."
            )
            sample_rate = EXPECTED_SAMPLE_RATE

        # Encode the audio in memory
        encoded_audio = encode_audio(audio_array, sample_rate, request.output_format)
        monitor.record("Audio encoding complete")

        if encoded_audio is None:
            logger.error(f"Failed to encode audio to format: {request.output_format}")
            raise HTTPException(
                status_code=500,
                detail=f"Failed to encode audio to {request.output_format}",
            )

        # Determine media type
        media_type = "audio/opus" if request.output_format == "opus" else "audio/wav"

        logger.info(
            f"Successfully generated {len(encoded_audio)} bytes in format {request.output_format}"
        )
        logger.debug(monitor.report())

        # Stream the response
        return StreamingResponse(io.BytesIO(encoded_audio), media_type=media_type)

    except HTTPException as http_exc:
        logger.error(f"HTTP exception during custom TTS request: {http_exc.detail}")
        raise http_exc
    except Exception as e:
        logger.error(f"Error processing custom TTS request: {e}", exc_info=True)
        logger.debug(monitor.report())
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

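And the equivalent sketch (again an editor's illustration, not part of the commit) for the custom /tts route, which exposes the full CustomTTSRequest parameter set; values shown are the documented defaults:

import requests

resp = requests.post(
    "http://localhost:8003/tts",  # assumed default host/port
    json={
        "text": "[S1] First speaker. [S2] Second speaker.",
        "voice_mode": "dialogue",
        "output_format": "opus",
        "cfg_scale": 3.0,
        "temperature": 1.3,
        "top_p": 0.95,
        "speed_factor": 0.94,
        "cfg_filter_top_k": 35,
    },
)
resp.raise_for_status()
with open("dialogue.opus", "wb") as f:
    f.write(resp.content)
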
# --- Web UI Endpoints ---


@app.get("/", response_class=HTMLResponse, include_in_schema=False)
async def get_web_ui(request: Request):
    """Serves the main TTS web interface."""
    logger.info("Serving TTS Web UI (index.html)")
    # Get current list of reference files for the clone dropdown
    reference_files = get_valid_reference_files()
    # Get current server config and default generation params
    current_config = config_manager.get_all()
    default_gen_params = {
        "speed_factor": get_gen_default_speed_factor(),
        "cfg_scale": get_gen_default_cfg_scale(),
        "temperature": get_gen_default_temperature(),
        "top_p": get_gen_default_top_p(),
        "cfg_filter_top_k": get_gen_default_cfg_filter_top_k(),
    }

    return templates.TemplateResponse(
        "index.html",  # Use the renamed file
        {
            "request": request,
            "reference_files": reference_files,
            "config": current_config,  # Pass current server config
            "presets": loaded_presets,  # Pass loaded presets
            "default_gen_params": default_gen_params,  # Pass default gen params
            # Add other variables needed by the template for initial state
            "error": None,
            "success": None,
            "output_file_url": None,
            "generation_time": None,
            "submitted_text": "",
            "submitted_voice_mode": "dialogue",  # Default to combined mode
            "submitted_clone_file": None,
            # Initial generation params will be set by default_gen_params
        },
    )


@app.post("/web/generate", response_class=HTMLResponse, include_in_schema=False)
async def handle_web_ui_generate(
    request: Request,
    text: str = Form(...),
    voice_mode: Literal["dialogue", "clone"] = Form(...),  # Updated modes
    clone_reference_select: Optional[str] = Form(None),
    # Generation parameters from form
    speed_factor: float = Form(...),  # Make required or use Depends with default
    cfg_scale: float = Form(...),
    temperature: float = Form(...),
    top_p: float = Form(...),
    cfg_filter_top_k: int = Form(...),
):
    """Handles the generation request from the web UI form."""
    logger.info(f"Web UI generation request: mode='{voice_mode}'")
    monitor = PerformanceMonitor()
    monitor.record("Web request received")

    output_file_url = None
    generation_time = None
    error_message = None
    success_message = None
    output_filename_base = "dia_output"  # Default base name

    # --- Pre-generation Validation ---
    if not text.strip():
        error_message = "Please enter some text to synthesize."

    clone_ref_file = None
    if voice_mode == "clone":
        if not clone_reference_select or clone_reference_select == "none":
            error_message = "Please select a reference audio file for clone mode."
        else:
            # Verify selected file still exists (important if files can be deleted)
            ref_path = get_reference_audio_path()
            potential_path = os.path.join(ref_path, clone_reference_select)
            if not os.path.isfile(potential_path):
                error_message = f"Selected reference file '{clone_reference_select}' no longer exists. Please refresh or upload."
                # Invalidate selection
                clone_ref_file = None
                clone_reference_select = None  # Clear submitted value for re-rendering
            else:
                clone_ref_file = clone_reference_select
                logger.info(f"Using selected reference file: {clone_ref_file}")

    # If validation failed, re-render the page with error and submitted values
    if error_message:
        logger.warning(f"Web UI validation error: {error_message}")
        reference_files = get_valid_reference_files()
        current_config = config_manager.get_all()
        default_gen_params = {  # Pass defaults again for consistency
            "speed_factor": get_gen_default_speed_factor(),
            "cfg_scale": get_gen_default_cfg_scale(),
            "temperature": get_gen_default_temperature(),
            "top_p": get_gen_default_top_p(),
            "cfg_filter_top_k": get_gen_default_cfg_filter_top_k(),
        }
        # Pass back the values the user submitted
        submitted_gen_params = {
            "speed_factor": speed_factor,
            "cfg_scale": cfg_scale,
            "temperature": temperature,
            "top_p": top_p,
            "cfg_filter_top_k": cfg_filter_top_k,
        }

        return templates.TemplateResponse(
            "index.html",
            {
                "request": request,
                "error": error_message,
                "reference_files": reference_files,
                "config": current_config,
                "presets": loaded_presets,
                "default_gen_params": default_gen_params,  # Base defaults
                # Submitted values to repopulate form
                "submitted_text": text,
                "submitted_voice_mode": voice_mode,
                "submitted_clone_file": clone_reference_select,  # Use potentially invalidated value
                "submitted_gen_params": submitted_gen_params,  # Pass submitted params back
                # Ensure other necessary template variables are passed
                "success": None,
                "output_file_url": None,
                "generation_time": None,
            },
        )

    # --- Generation ---
    try:
        monitor.record("Parameters processed")
        # Call the core engine function
        result = generate_speech(
            text=text,
            voice_mode=voice_mode,
            clone_reference_filename=clone_ref_file,
            speed_factor=speed_factor,
            cfg_scale=cfg_scale,
            temperature=temperature,
            top_p=top_p,
            cfg_filter_top_k=cfg_filter_top_k,
            max_tokens=None,  # Use model default for UI simplicity
        )
        monitor.record("Generation complete")

        if result:
            audio_array, sample_rate = result
            output_path_base = get_output_path()
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            # Create a more descriptive filename
            mode_tag = voice_mode
            if voice_mode == "clone" and clone_ref_file:
                safe_ref_name = sanitize_filename(os.path.splitext(clone_ref_file)[0])
                mode_tag = f"clone_{safe_ref_name[:20]}"  # Limit length
            output_filename = (
                f"{mode_tag}_{timestamp}.wav"  # Always save as WAV for simplicity
            )
            output_filepath = os.path.join(output_path_base, output_filename)

            # Save the audio to a WAV file
            saved = save_audio_to_file(audio_array, sample_rate, output_filepath)
            monitor.record("Audio saved")

            if saved:
                output_file_url = (
                    f"/outputs/{output_filename}"  # URL path for browser access
                )
                generation_time = (
                    monitor.events[-1][1] - monitor.start_time
                )  # Time until save complete
                success_message = "Audio generated successfully!"
                logger.info(f"Web UI generated audio saved to: {output_filepath}")
            else:
                error_message = "Failed to save generated audio file."
                logger.error("Failed to save audio file from web UI request.")
        else:
            error_message = "Speech generation failed (engine returned None)."
            logger.error("Speech generation failed for web UI request.")

    except Exception as e:
        logger.error(f"Error processing web UI TTS request: {e}", exc_info=True)
        error_message = f"An unexpected error occurred: {str(e)}"

    logger.debug(monitor.report())

    # --- Re-render Template with Results ---
    reference_files = get_valid_reference_files()
    current_config = config_manager.get_all()
    default_gen_params = {
        "speed_factor": get_gen_default_speed_factor(),
        "cfg_scale": get_gen_default_cfg_scale(),
        "temperature": get_gen_default_temperature(),
        "top_p": get_gen_default_top_p(),
        "cfg_filter_top_k": get_gen_default_cfg_filter_top_k(),
    }
    # Pass back submitted values to repopulate form correctly
    submitted_gen_params = {
        "speed_factor": speed_factor,
        "cfg_scale": cfg_scale,
        "temperature": temperature,
        "top_p": top_p,
        "cfg_filter_top_k": cfg_filter_top_k,
    }

    return templates.TemplateResponse(
        "index.html",
        {
            "request": request,
            "error": error_message,
            "success": success_message,
            "output_file_url": output_file_url,
            "generation_time": f"{generation_time:.2f}" if generation_time else None,
            "reference_files": reference_files,
            "config": current_config,
            "presets": loaded_presets,
            "default_gen_params": default_gen_params,  # Base defaults
            # Pass back submitted values
            "submitted_text": text,
            "submitted_voice_mode": voice_mode,
            "submitted_clone_file": clone_ref_file,  # Pass the validated filename back
            "submitted_gen_params": submitted_gen_params,  # Pass submitted params back
        },
    )

# --- Reference Audio Upload Endpoint ---
|
890 |
+
@app.post(
|
891 |
+
"/upload_reference", tags=["Web UI Helpers"], summary="Upload reference audio files"
|
892 |
+
)
|
893 |
+
async def upload_reference_audio(files: List[UploadFile] = File(...)):
|
894 |
+
"""Handles uploading of reference audio files (.wav, .mp3) for voice cloning."""
|
895 |
+
logger.info(f"Received request to upload {len(files)} reference audio file(s).")
|
896 |
+
ref_path = get_reference_audio_path()
|
897 |
+
uploaded_filenames = []
|
898 |
+
errors = []
|
899 |
+
allowed_mime_types = [
|
900 |
+
"audio/wav",
|
901 |
+
"audio/mpeg",
|
902 |
+
"audio/x-wav",
|
903 |
+
] # Common WAV/MP3 types
|
904 |
+
allowed_extensions = [".wav", ".mp3"]
|
905 |
+
|
906 |
+
for file in files:
|
907 |
+
try:
|
908 |
+
# Basic validation
|
909 |
+
if not file.filename:
|
910 |
+
errors.append("Received file with no filename.")
|
911 |
+
continue
|
912 |
+
|
913 |
+
# Sanitize filename
|
914 |
+
safe_filename = sanitize_filename(file.filename)
|
915 |
+
_, ext = os.path.splitext(safe_filename)
|
916 |
+
if ext.lower() not in allowed_extensions:
|
917 |
+
errors.append(
|
918 |
+
f"File '{file.filename}' has unsupported extension '{ext}'. Allowed: {allowed_extensions}"
|
919 |
+
)
|
920 |
+
continue
|
921 |
+
|
922 |
+
# Check MIME type (more reliable than extension)
|
923 |
+
if file.content_type not in allowed_mime_types:
|
924 |
+
errors.append(
|
925 |
+
f"File '{file.filename}' has unsupported content type '{file.content_type}'. Allowed: {allowed_mime_types}"
|
926 |
+
)
|
927 |
+
continue
|
928 |
+
|
929 |
+
# Construct full save path
|
930 |
+
destination_path = os.path.join(ref_path, safe_filename)
|
931 |
+
|
932 |
+
# Prevent overwriting existing files (optional, could add counter)
|
933 |
+
if os.path.exists(destination_path):
|
934 |
+
# Simple approach: skip if exists
|
935 |
+
logger.warning(
|
936 |
+
f"Reference file '{safe_filename}' already exists. Skipping upload."
|
937 |
+
)
|
938 |
+
# Add to list so UI knows it's available, even if not newly uploaded this time
|
939 |
+
if safe_filename not in uploaded_filenames:
|
940 |
+
uploaded_filenames.append(safe_filename)
|
941 |
+
continue
|
942 |
+
# Alternative: add counter like file_1.wav, file_2.wav
|
943 |
+
|
944 |
+
# Save the file using shutil.copyfileobj for efficiency with large files
|
945 |
+
try:
|
946 |
+
with open(destination_path, "wb") as buffer:
|
947 |
+
shutil.copyfileobj(file.file, buffer)
|
948 |
+
logger.info(f"Successfully saved reference file: {destination_path}")
|
949 |
+
uploaded_filenames.append(safe_filename)
|
950 |
+
except Exception as save_exc:
|
951 |
+
errors.append(f"Failed to save file '{safe_filename}': {save_exc}")
|
952 |
+
logger.error(
|
953 |
+
f"Failed to save uploaded file '{safe_filename}' to '{destination_path}': {save_exc}",
|
954 |
+
exc_info=True,
|
955 |
+
)
|
956 |
+
finally:
|
957 |
+
# Ensure the UploadFile resource is closed
|
958 |
+
await file.close()
|
959 |
+
|
960 |
+
except Exception as e:
|
961 |
+
errors.append(
|
962 |
+
f"Error processing file '{getattr(file, 'filename', 'unknown')}': {e}"
|
963 |
+
)
|
964 |
+
logger.error(
|
965 |
+
f"Unexpected error processing uploaded file: {e}", exc_info=True
|
966 |
+
)
|
967 |
+
# Ensure file is closed even if other errors occur
|
968 |
+
if file:
|
969 |
+
await file.close()
|
970 |
+
|
971 |
+
# Get the updated list of all valid files in the directory
|
972 |
+
updated_file_list = get_valid_reference_files()
|
973 |
+
|
974 |
+
response_data = {
|
975 |
+
"message": f"Processed {len(files)} file(s).",
|
976 |
+
"uploaded_files": uploaded_filenames, # List of successfully saved *new* files this request
|
977 |
+
"all_reference_files": updated_file_list, # Complete current list
|
978 |
+
"errors": errors,
|
979 |
+
}
|
980 |
+
|
981 |
+
status_code = (
|
982 |
+
200 if not errors or len(errors) < len(files) else 400
|
983 |
+
) # OK if at least one succeeded, else Bad Request
|
984 |
+
if errors:
|
985 |
+
logger.warning(f"Upload completed with errors: {errors}")
|
986 |
+
|
987 |
+
return JSONResponse(content=response_data, status_code=status_code)
|
988 |
+
|
989 |
+
|
990 |
+
# --- Health Check Endpoint ---
|
991 |
+
@app.get("/health", tags=["Server Status"], summary="Check server health")
|
992 |
+
async def health_check():
|
993 |
+
"""Basic health check, indicates if the server is running and if the model is loaded."""
|
994 |
+
# Access the MODEL_LOADED variable *directly* from the engine module
|
995 |
+
# each time the endpoint is called to get the current status.
|
996 |
+
current_model_status = getattr(engine, "MODEL_LOADED", False) # Safely get status
|
997 |
+
logger.debug(
|
998 |
+
f"Health check returning model_loaded status: {current_model_status}"
|
999 |
+
) # Add debug log
|
1000 |
+
return {"status": "healthy", "model_loaded": current_model_status}
|
1001 |
+
|
1002 |
+
|
1003 |
+
# --- Main Execution ---
|
1004 |
+
if __name__ == "__main__":
|
1005 |
+
host = get_host()
|
1006 |
+
port = get_port()
|
1007 |
+
logger.info(f"Starting Dia TTS server on {host}:{port}")
|
1008 |
+
logger.info(f"Model Repository: {get_model_repo_id()}")
|
1009 |
+
logger.info(f"Model Config File: {get_model_config_filename()}")
|
1010 |
+
logger.info(f"Model Weights File: {get_model_weights_filename()}")
|
1011 |
+
logger.info(f"Model Cache Path: {get_model_cache_path()}")
|
1012 |
+
logger.info(f"Reference Audio Path: {get_reference_audio_path()}")
|
1013 |
+
logger.info(f"Output Path: {get_output_path()}")
|
1014 |
+
# Determine the host to display in logs and use for browser opening
|
1015 |
+
display_host = "localhost" if host == "0.0.0.0" else host
|
1016 |
+
logger.info(f"Web UI will be available at http://{display_host}:{port}/")
|
1017 |
+
logger.info(f"API Docs available at http://{display_host}:{port}/docs")
|
1018 |
+
|
1019 |
+
# Ensure UI directory and index.html exist for UI
|
1020 |
+
ui_dir = "ui"
|
1021 |
+
index_file = os.path.join(ui_dir, "index.html")
|
1022 |
+
if not os.path.isdir(ui_dir) or not os.path.isfile(index_file):
|
1023 |
+
logger.warning(
|
1024 |
+
f"'{ui_dir}' directory or '{index_file}' not found. Web UI may not work."
|
1025 |
+
)
|
1026 |
+
# Optionally create dummy files/dirs if needed for startup
|
1027 |
+
os.makedirs(ui_dir, exist_ok=True)
|
1028 |
+
if not os.path.isfile(index_file):
|
1029 |
+
try:
|
1030 |
+
with open(index_file, "w") as f:
|
1031 |
+
f.write(
|
1032 |
+
"<html><body>Web UI template missing. See project source for index.html.</body></html>"
|
1033 |
+
)
|
1034 |
+
logger.info(f"Created dummy {index_file}.")
|
1035 |
+
except Exception as e:
|
1036 |
+
logger.error(f"Failed to create dummy {index_file}: {e}")
|
1037 |
+
|
1038 |
+
# --- Create synchronization event ---
|
1039 |
+
# This event will be set by the lifespan manager once startup (incl. model loading) is complete.
|
1040 |
+
startup_complete_event = threading.Event()
|
1041 |
+
|
1042 |
+
# Run Uvicorn server
|
1043 |
+
# The lifespan context manager ('lifespan="on"') will run during startup.
|
1044 |
+
# The 'lifespan' function is responsible for loading models and setting the 'startup_complete_event'.
|
1045 |
+
uvicorn.run(
|
1046 |
+
"server:app", # Use the format 'module:app_instance'
|
1047 |
+
host=host,
|
1048 |
+
port=port,
|
1049 |
+
reload=False, # Set reload as needed for development/production
|
1050 |
+
# reload_dirs=[".", "ui"], # Only use reload=True with reload_dirs/includes for development
|
1051 |
+
# reload_includes=[
|
1052 |
+
# "*.py",
|
1053 |
+
# "*.html",
|
1054 |
+
# "*.css",
|
1055 |
+
# "*.js",
|
1056 |
+
# ".env",
|
1057 |
+
# "*.yaml",
|
1058 |
+
# ],
|
1059 |
+
lifespan="on", # Use the lifespan context manager defined in this file
|
1060 |
+
# workers=1 # Keep workers=1 when using reload=True or complex global state/models
|
1061 |
+
)
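
The comments around uvicorn.run above refer to a lifespan context manager defined earlier in server.py (outside this excerpt) that loads the model and then sets startup_complete_event. As a rough sketch of that pattern only (the project's actual implementation will differ, and engine.load_dia_model is a hypothetical name):

import threading
from contextlib import asynccontextmanager
from fastapi import FastAPI
import engine  # the project's engine module

startup_complete_event = threading.Event()

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: load model weights before the server accepts requests.
    engine.load_dia_model()        # hypothetical call; the real loader lives in engine.py
    startup_complete_event.set()   # signal any thread waiting on startup
    yield                          # the application serves requests while suspended here
    # Shutdown: release model/GPU resources here if needed.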
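For reference, the two endpoints above can be exercised with a small client. A minimal sketch using the requests library, assuming the .env defaults (localhost:8003) and a local ref.wav file:

import requests

BASE = "http://localhost:8003"

# Health check: reports whether the Dia model has finished loading.
print(requests.get(f"{BASE}/health").json())

# Upload a reference clip for voice cloning. The multipart field must be named
# "files" to match `files: List[UploadFile] = File(...)`, and the MIME type must
# be one of audio/wav, audio/x-wav, or audio/mpeg per the server-side checks.
with open("ref.wav", "rb") as f:
    resp = requests.post(
        f"{BASE}/upload_reference",
        files=[("files", ("ref.wav", f, "audio/wav"))],
    )
print(resp.status_code, resp.json())

Note that the endpoint returns 200 as long as at least one file succeeds; inspect the "errors" list in the JSON body rather than relying on the status code alone.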
ui/index.html
ADDED
@@ -0,0 +1,916 @@
<!DOCTYPE html>
<html lang="en" class="dark"> <!-- Default to dark mode class -->

<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Dia TTS Server | Text-to-Dialogue</title>
    <link rel="icon" href="/static/favicon.ico" type="image/x-icon">
    <!-- Tailwind CSS (CDN for simplicity, processes styles in <style type="text/tailwindcss"> below) -->
    <script src="https://cdn.tailwindcss.com"></script>
    <script>
        // Configure Tailwind CSS
        tailwind.config = {
            darkMode: 'class', // Enable class-based dark mode
            theme: {
                extend: {
                    colors: {
                        // Define color palettes used in style.css
                        // Light Mode Colors (Examples - Adjust as needed)
                        gray: { 50: '#f9fafb', 100: '#f3f4f6', 200: '#e5e7eb', 300: '#d1d5db', 400: '#9ca3af', 500: '#6b7280', 600: '#4b5563', 700: '#374151', 800: '#1f2937', 900: '#111827' },
                        sky: { 50: '#f0f9ff', 100: '#e0f2fe', 200: '#bae6fd', 300: '#7dd3fc', 400: '#38bdf8', 500: '#0ea5e9', 600: '#0284c7', 700: '#0369a1', 800: '#075985', 900: '#0c4a6e' },
                        indigo: { 50: '#eef2ff', 100: '#e0e7ff', 200: '#c7d2fe', 300: '#a5b4fc', 400: '#818cf8', 500: '#6366f1', 600: '#4f46e5', 700: '#4338ca', 800: '#3730a3', 900: '#312e81' },
                        red: { 100: '#fee2e2', 300: '#fca5a5', 500: '#ef4444', 600: '#dc2626', 800: '#991b1b', 900: '#7f1d1d' },
                        green: { 100: '#dcfce7', 300: '#86efac', 500: '#22c55e', 800: '#166534', 900: '#14532d' },
                        yellow: { 100: '#fef9c3', 300: '#fcd34d', 500: '#eab308', 700: '#b45309', 900: '#78350f' },

                        // Dark Mode Colors (Copied from previous inline config)
                        primary: { 50: '#f0f9ff', 100: '#e0f2fe', 200: '#bae6fd', 300: '#7dd3fc', 400: '#38bdf8', 500: '#0ea5e9', 600: '#0284c7', 700: '#0369a1', 800: '#075985', 900: '#0c4a6e' },
                        purple: { 50: '#faf5ff', 100: '#f3e8ff', 200: '#e9d5ff', 300: '#d8b4fe', 400: '#c084fc', 500: '#a855f7', 600: '#9333ea', 700: '#7e22ce', 800: '#6b21a8', 900: '#581c87' },
                        dark: { 50: '#f9fafb', 100: '#f3f4f6', 200: '#e5e7eb', 300: '#d1d5db', 400: '#9ca3af', 500: '#6b7280', 600: '#4b5563', 700: '#374151', 800: '#1f2937', 900: '#111827', 950: '#030712', 1000: '#0f1729' }
                    }
                }
            }
        }
    </script>
    <!-- Removed External Stylesheet Link: <link rel="stylesheet" href="/ui/style.css"> -->
    <!-- Wavesurfer for audio visualization -->
    <script src="https://unpkg.com/wavesurfer.js@7"></script>

    <style type="text/tailwindcss">
        /* ui/style.css */

        /* Import Tailwind base, components, and utilities */
        @tailwind base;
        @tailwind components;
        @tailwind utilities;

        /* Define custom components/utilities */
        @layer components {

            /* Base styles (Light Mode) */
            .body-base {
                @apply h-full bg-gray-100 text-gray-900;
            }

            .nav-base {
                @apply bg-gradient-to-r from-white to-sky-100 border-b border-sky-200 shadow-md;
            }

            .nav-link {
                @apply text-sky-700 hover:text-sky-900 px-3 py-2 rounded-md text-sm font-medium;
            }

            .title-link {
                @apply text-gray-900 text-xl font-bold;
            }

            .card-base {
                @apply bg-white shadow-lg rounded-lg overflow-hidden border border-gray-200;
            }

            .card-header {
                @apply text-lg font-medium text-gray-900 mb-4;
            }

            .card-footer {
                @apply bg-gray-50 px-6 py-4 flex items-center justify-between border-t border-gray-200;
            }

            .label-base {
                @apply block text-sm font-medium text-gray-700 mb-1;
            }

            .input-base {
                @apply block w-full rounded-md border-gray-300 shadow-sm focus:border-sky-500 focus:ring-sky-500 sm:text-sm px-3 py-2 bg-white text-gray-900 placeholder-gray-400;
            }

            .textarea-base {
                @apply input-base;
                /* Inherit base input styles */
            }

            .select-base {
                @apply input-base appearance-none pr-8;
                /* Add padding for arrow */
                /* Consider adding a background SVG for the dropdown arrow */
            }

            .button-base {
                @apply inline-flex items-center justify-center px-4 py-2 border border-transparent rounded-md shadow-sm text-sm font-medium focus:outline-none focus:ring-2 focus:ring-offset-2 transition-colors disabled:opacity-50 disabled:cursor-not-allowed whitespace-nowrap flex-shrink-0;
                /* Added whitespace-nowrap and flex-shrink-0 for button text */
            }

            .btn-primary {
                @apply button-base bg-sky-600 text-white hover:bg-sky-700 focus:ring-sky-500;
            }

            .btn-secondary {
                @apply button-base bg-gray-200 text-gray-700 border-gray-300 hover:bg-gray-300 focus:ring-indigo-500;
                /* Example secondary */
            }

            .btn-danger {
                @apply button-base bg-red-600 text-white hover:bg-red-700 focus:ring-red-500;
            }

            .btn-purple {
                @apply button-base bg-purple-600 text-white hover:bg-purple-700 focus:ring-purple-500;
            }

            .slider-base {
                @apply w-full h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer;
                /* Need to style the thumb separately per browser */
            }

            .slider-thumb {
                /* Basic thumb styling */
                @apply appearance-none w-5 h-5 bg-sky-600 rounded-full cursor-pointer;
            }

            .radio-label {
                @apply flex items-center space-x-2 cursor-pointer border border-gray-300 bg-white hover:border-sky-400 p-3 rounded-md transition-colors;
            }

            .radio-label-text {
                @apply text-gray-700;
            }

            /* Apply checked styles directly using peer-checked utility on the container/text span */
            /* .radio-label input:checked+span {
                @apply text-sky-600 font-semibold;
            }

            .radio-label-checked {
                @apply border-sky-500 ring-2 ring-sky-500;
            } */
            /* Replaced these custom classes with Tailwind peer utilities in the HTML */


            .preset-button {
                @apply button-base bg-indigo-100 text-indigo-700 border-indigo-200 hover:bg-indigo-200 focus:ring-indigo-500 text-xs px-3 py-1;
            }

            .notification-base {
                @apply px-4 py-3 rounded relative shadow-md flex items-center mb-3;
                /* Reduced margin bottom */
            }

            .notification-success {
                @apply notification-base bg-green-100 border border-green-300 text-green-800;
            }

            .notification-error {
                @apply notification-base bg-red-100 border border-red-300 text-red-800;
            }

            .notification-warning {
                @apply notification-base bg-yellow-100 border border-yellow-300 text-yellow-800;
            }

            .notification-info {
                /* Added info style */
                @apply notification-base bg-sky-100 border border-sky-300 text-sky-800;
            }

            .code-inline {
                @apply bg-gray-200 px-1 rounded text-sm font-mono text-gray-800;
            }

            .tooltip {
                /* Basic tooltip styling */
                @apply absolute hidden group-hover:block bg-gray-700 text-white text-xs rounded py-1 px-2 z-10 -mt-8;
            }

            .loading-overlay-base {
                @apply fixed inset-0 bg-gray-600 bg-opacity-75 flex items-center justify-center z-50 transition-opacity duration-300;
            }

            .loading-box-base {
                @apply bg-white p-6 rounded-lg shadow-xl flex flex-col items-center border border-gray-300;
            }

            .loading-spinner {
                @apply animate-spin h-10 w-10 text-sky-600 mb-4;
            }

            .loading-text {
                @apply text-gray-900 text-lg mb-2;
            }

            .loading-status {
                @apply text-gray-600 text-sm mb-4 text-center max-w-xs;
                /* Limit width */
            }

            .waveform-container {
                @apply w-full h-24 bg-gray-100 rounded;
            }

            .audio-player-card {
                @apply card-base mt-8;
                /* Margin top for spacing */
            }

            .audio-player-controls {
                @apply flex flex-wrap items-center justify-between gap-4;
            }

            .audio-player-buttons {
                @apply flex items-center space-x-2 sm:space-x-4;
                /* Adjust spacing */
            }

            .audio-player-info {
                @apply text-sm text-gray-600 text-right;
            }

            .theme-switch {
                @apply p-2 rounded-md text-gray-600 hover:bg-gray-200 hover:text-gray-800 focus:outline-none focus:ring-2 focus:ring-sky-500 focus:ring-offset-2;
            }


            /* Dark Mode Overrides using 'dark:' prefix */
            .dark .body-base {
                @apply bg-[#0f1729] text-white;
                /* Original dark bg */
            }

            .dark .nav-base {
                @apply bg-gradient-to-r from-dark-900 to-purple-900 border-b border-purple-800 shadow-lg;
            }

            .dark .nav-link {
                @apply text-primary-300 hover:text-white;
            }

            .dark .title-link {
                @apply text-white;
            }

            .dark .card-base {
                @apply bg-dark-800 border border-dark-700;
            }

            .dark .card-header {
                @apply text-white;
            }

            .dark .card-footer {
                @apply bg-dark-900 border-t border-dark-700;
            }

            .dark .label-base {
                @apply text-gray-300;
                /* Lighter gray for dark */
            }

            .dark .input-base {
                @apply border-dark-600 bg-dark-700 text-white placeholder-gray-500 focus:ring-offset-dark-800;
            }

            .dark .select-base {
                /* Dark mode arrow styling if needed */
            }

            .dark .btn-primary {
                @apply bg-primary-600 text-white hover:bg-primary-700 focus:ring-primary-500 focus:ring-offset-dark-800;
            }

            .dark .btn-secondary {
                @apply bg-dark-700 text-white border-dark-600 hover:bg-dark-600 focus:ring-purple-500 focus:ring-offset-dark-800;
            }

            .dark .btn-danger {
                @apply bg-red-600 text-white hover:bg-red-700 focus:ring-red-500 focus:ring-offset-dark-800;
            }

            .dark .btn-purple {
                @apply bg-purple-600 text-white hover:bg-purple-700 focus:ring-purple-500 focus:ring-offset-dark-800;
            }

            .dark .slider-base {
                @apply bg-dark-600;
            }

            .dark .slider-thumb {
                @apply bg-primary-500;
            }

            .dark .radio-label {
                @apply border-dark-600 bg-dark-800 hover:border-primary-400;
            }

            .dark .radio-label-text {
                @apply text-gray-300;
            }

            /* Apply checked styles directly using peer-checked utility on the container/text span */
            /* .dark .radio-label input:checked+span {
                @apply text-primary-400;
            }

            .dark .radio-label-checked {
                @apply border-primary-500 ring-primary-500;
            } */
            /* Replaced these custom classes with Tailwind peer utilities in the HTML */


            .dark .preset-button {
                @apply bg-indigo-900 text-indigo-200 border-indigo-700 hover:bg-indigo-800 focus:ring-indigo-500 focus:ring-offset-dark-800;
            }

            .dark .notification-success {
                @apply notification-base bg-green-900 border border-green-700 text-green-100;
            }

            .dark .notification-error {
                @apply notification-base bg-red-900 border border-red-700 text-red-100;
            }

            .dark .notification-warning {
                @apply notification-base bg-yellow-900 border border-yellow-700 text-yellow-100;
            }

            .dark .notification-info {
                /* Added info style */
                @apply notification-base bg-sky-900 border border-sky-700 text-sky-100;
            }

            .dark .code-inline {
                @apply bg-dark-900 text-purple-300;
            }

            .dark .tooltip {
                @apply bg-dark-950;
            }

            .dark .loading-overlay-base {
                @apply bg-dark-900 bg-opacity-75;
            }

            .dark .loading-box-base {
                @apply bg-dark-800 border border-dark-700;
            }

            .dark .loading-spinner {
                @apply text-primary-500;
            }

            .dark .loading-text {
                @apply text-white;
            }

            .dark .loading-status {
                @apply text-gray-400;
            }

            .dark .waveform-container {
                @apply bg-dark-900;
            }

            .dark .audio-player-info {
                @apply text-purple-300;
            }

            .dark .theme-switch {
                @apply text-gray-400 hover:bg-dark-700 hover:text-white focus:ring-offset-dark-900;
            }

        }

        /* Specific slider thumb styling per browser */
        /* Apply these within the <style> tag as they target pseudo-elements */
        input[type="range"].slider-base::-webkit-slider-thumb {
            @apply slider-thumb;
        }

        input[type="range"].slider-base::-moz-range-thumb {
            @apply slider-thumb;
        }

        /* Dark mode thumbs need specific overrides if needed */
        .dark input[type="range"].slider-base::-webkit-slider-thumb {
            /* Apply dark mode thumb styles directly */
            background-color: theme('colors.primary.500');
            /* Replaced @apply dark:slider-thumb */
            /* Inherit other base thumb styles if needed (like size, border-radius) or re-apply */
            @apply appearance-none w-5 h-5 rounded-full cursor-pointer;
        }

        .dark input[type="range"].slider-base::-moz-range-thumb {
            /* Apply dark mode thumb styles directly */
            background-color: theme('colors.primary.500');
            /* Replaced @apply dark:slider-thumb */
            /* Inherit other base thumb styles if needed or re-apply */
            @apply appearance-none w-5 h-5 rounded-full cursor-pointer;
        }
    </style>
</head>

<body class="body-base">
    <div class="min-h-full">
        <!-- Navigation -->
        <nav class="nav-base">
            <div class="mx-auto max-w-7xl px-4 sm:px-6 lg:px-8">
                <div class="flex h-16 items-center justify-between">
                    <div class="flex items-center">
                        <div class="flex-shrink-0">
                            <!-- Make title clickable -->
                            <a href="/" class="title-link">Dia TTS Server</a>
                        </div>
                    </div>
                    <div class="flex items-center space-x-2 sm:space-x-4">
                        <a href="/docs" target="_blank" class="nav-link">API Docs</a>
                        <!-- Theme Toggle Button -->
                        <button id="theme-toggle-btn" type="button"
                            class="relative inline-flex items-center p-1 rounded-full bg-gray-200 dark:bg-dark-700 h-8 w-16 transition-colors"
                            title="Toggle light/dark mode">
                            <span class="sr-only">Toggle theme</span>
                            <span class="absolute inset-0 rounded-full transition-colors"></span>
                            <!-- Toggle thumb with icons -->
                            <span
                                class="relative rounded-full w-6 h-6 bg-white dark:bg-purple-600 transform transition-transform duration-200 ease-in-out translate-x-0 dark:translate-x-8 flex items-center justify-center shadow-md">
                                <!-- Sun icon (for light mode) -->
                                <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor"
                                    class="w-4 h-4 text-yellow-500 dark:opacity-0 transition-opacity">
                                    <path
                                        d="M10 2a.75.75 0 0 1 .75.75v1.5a.75.75 0 0 1-1.5 0v-1.5A.75.75 0 0 1 10 2ZM10 15a.75.75 0 0 1 .75.75v1.5a.75.75 0 0 1-1.5 0v-1.5A.75.75 0 0 1 10 15ZM10 7a3 3 0 1 0 0 6 3 3 0 0 0 0-6Z" />
                                </svg>
                                <!-- Moon icon (for dark mode) -->
                                <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor"
                                    class="w-4 h-4 text-white opacity-0 dark:opacity-100 transition-opacity">
                                    <path
                                        d="M7.455 1.75A8.5 8.5 0 0 1 18.25 12.55a8.5 8.5 0 0 1-8.46 8.46A8.5 8.5 0 0 1 1.75 12.55a8.5 8.5 0 0 1 5.705-10.8Z" />
                                </svg>
                            </span>
                        </button>
                    </div>
                </div>
            </div>
        </nav>

        <!-- Main content -->
        <main>
            <div class="mx-auto max-w-7xl px-4 py-8 sm:px-6 lg:px-8">

                <!-- Notification area -->
                <div id="notification-area" class="mb-6 space-y-3">
                    {% if error %}
                    <div class="notification-error" role="alert">
                        <svg class="h-5 w-5 text-red-500 mr-2 flex-shrink-0" viewBox="0 0 20 20" fill="currentColor">
                            <path fill-rule="evenodd"
                                d="M10 18a8 8 0 100-16 8 8 0 000 16zM8.707 7.293a1 1 0 00-1.414 1.414L8.586 10l-1.293 1.293a1 1 0 101.414 1.414L10 11.414l1.293 1.293a1 1 0 001.414-1.414L11.414 10l1.293-1.293a1 1 0 00-1.414-1.414L10 8.586 8.707 7.293z"
                                clip-rule="evenodd" />
                        </svg>
                        <span class="block sm:inline">{{ error }}</span>
                    </div>
                    {% endif %}
                    {% if success %}
                    <div class="notification-success" role="alert">
                        <svg class="h-5 w-5 text-green-500 mr-2 flex-shrink-0" viewBox="0 0 20 20" fill="currentColor">
                            <path fill-rule="evenodd"
                                d="M10 18a8 8 0 100-16 8 8 0 000 16zm3.707-9.293a1 1 0 00-1.414-1.414L9 10.586 7.707 9.293a1 1 0 00-1.414 1.414l2 2a1 1 0 001.414 0l4-4z"
                                clip-rule="evenodd" />
                        </svg>
                        <span class="block sm:inline">{{ success }}</span>
                    </div>
                    {% endif %}
                </div>

                <!-- TTS form -->
                <div class="card-base">
                    <form id="tts-form" action="/web/generate" method="post" class="flex flex-col">
                        <div class="p-6">
                            <h2 class="card-header">Generate Speech with Dia</h2>

                            <!-- Text input -->
                            <div class="mb-6">
                                <label for="text" class="label-base">Text to speak</label>
                                <p class="text-xs text-purple-500 dark:text-purple-300 mb-2">
                                    Use <code class="code-inline">[S1]</code> and <code class="code-inline">[S2]</code>
                                    tags for speaker turns. Add non-verbals like <code
                                        class="code-inline">(laughs)</code>.
                                </p>
                                <div class="relative">
                                    <textarea name="text" id="text" rows="5" maxlength="8192" class="textarea-base"
                                        placeholder="Example: [S1] Hello there! [S2] Hi! How are you? [S1] I'm doing well, thanks. (laughs)"
                                        required>{{ submitted_text if submitted_text else "" }}</textarea>
                                    <div class="absolute bottom-2 right-2 text-xs text-gray-500 dark:text-purple-300">
                                        <span id="char-count">0</span> / 8192
                                    </div>
                                </div>
                            </div>

                            <!-- Voice Mode Selection -->
                            <div class="mb-6">
                                <label class="label-base mb-2">Voice Mode</label>
                                <div class="grid grid-cols-1 md:grid-cols-2 gap-4">
                                    <!-- Combined Dialogue / Single Speaker Mode -->
                                    <label
                                        class="radio-label peer-checked:border-sky-500 peer-checked:dark:border-primary-500 peer-checked:ring-2 peer-checked:ring-sky-500 peer-checked:dark:ring-primary-500">
                                        <input type="radio" name="voice_mode" value="dialogue" class="hidden peer" {% if
                                            submitted_voice_mode=='dialogue' or not submitted_voice_mode %}checked{%
                                            endif %} onchange="toggleCloneOptions()">
                                        <span
                                            class="radio-label-text peer-checked:text-sky-600 dark:peer-checked:text-primary-400 peer-checked:font-semibold">
                                            Single / Dialogue (Use [S1]/[S2])
                                        </span>
                                    </label>
                                    <!-- Clone Mode -->
                                    <label
                                        class="radio-label peer-checked:border-sky-500 peer-checked:dark:border-primary-500 peer-checked:ring-2 peer-checked:ring-sky-500 peer-checked:dark:ring-primary-500">
                                        <input type="radio" name="voice_mode" value="clone" class="hidden peer" {% if
                                            submitted_voice_mode=='clone' %}checked{% endif %}
                                            onchange="toggleCloneOptions()">
                                        <span
                                            class="radio-label-text peer-checked:text-sky-600 dark:peer-checked:text-primary-400 peer-checked:font-semibold">
                                            Voice Clone (from Reference)
                                        </span>
                                    </label>
                                </div>
                            </div>

                            <!-- Presets Section -->
                            <div class="mb-6">
                                <label class="label-base mb-2">Load Example Preset</label>
                                <div id="presets-container" class="flex flex-wrap gap-2">
                                    {% if presets %}
                                    {% for preset in presets %}
                                    <button type="button" id="preset-btn-{{ loop.index0 }}" class="preset-button"
                                        title="Load '{{ preset.name }}' text and settings">
                                        {{ preset.name }}
                                    </button>
                                    {% endfor %}
                                    {% else %}
                                    <p class="text-sm text-gray-500 dark:text-gray-400">No presets loaded. Check
                                        presets.yaml.</p>
                                    {% endif %}
                                </div>
                            </div>


                            <!-- Clone Options (Hidden by default) -->
                            <div id="clone-options" class="mb-6 hidden">
                                <label for="clone_reference_select" class="label-base">Reference Audio File</label>
                                <p class="text-xs text-purple-500 dark:text-purple-300 mb-2">
                                    Select a <code class="code-inline">.wav</code> or <code
                                        class="code-inline">.mp3</code> file from the <code
                                        class="code-inline">reference_audio</code> folder.
                                    <strong class="dark:text-yellow-300 text-yellow-600">Important:</strong> Prepend the
                                    exact transcript of this audio to your text input above for best results.
                                </p>
                                <div class="flex items-center gap-2">
                                    <select id="clone_reference_select" name="clone_reference_select"
                                        class="select-base flex-grow">
                                        <option value="none" {% if not submitted_clone_file %}selected{% endif %}>--
                                            Select Reference File --</option>
                                        {% for filename in reference_files %}
                                        <option value="{{ filename }}" {% if submitted_clone_file==filename %}selected{%
                                            endif %}>{{ filename }}</option>
                                        {% endfor %}
                                    </select>
                                    <!-- Hidden file input triggered by the button -->
                                    <input type="file" id="clone-file-input" class="hidden" multiple accept=".wav,.mp3"
                                        aria-label="Upload reference audio file">
                                    <!-- Modified Load Button -->
                                    <button type="button" id="clone-load-button" class="btn-secondary hidden"
                                        title="Upload new reference files">
                                        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor"
                                            class="w-5 h-5 mr-1">
                                            <path
                                                d="M9.25 13.25a.75.75 0 0 0 1.5 0V4.636l2.955 3.129a.75.75 0 0 0 1.09-1.03l-4.25-4.5a.75.75 0 0 0-1.09 0l-4.25 4.5a.75.75 0 1 0 1.09 1.03L9.25 4.636v8.614Z" />
                                            <path
                                                d="M3.5 12.75a.75.75 0 0 0-1.5 0v2.5A2.75 2.75 0 0 0 4.75 18h10.5A2.75 2.75 0 0 0 18 15.25v-2.5a.75.75 0 0 0-1.5 0v2.5c0 .69-.56 1.25-1.25 1.25H4.75c-.69 0-1.25-.56-1.25-1.25v-2.5Z" />
                                        </svg>
                                        Load
                                    </button>
                                </div>
                            </div>


                            <!-- Generation Parameters -->
                            <div class="mb-6">
                                <details class="group">
                                    <summary class="list-none flex cursor-pointer items-center">
                                        <span class="text-sm font-medium label-base">Generation Parameters</span>
                                        <span class="ml-2 text-purple-500 dark:text-purple-300">
                                            <svg class="group-open:rotate-180 h-5 w-5 transition-transform"
                                                viewBox="0 0 20 20" fill="currentColor">
                                                <path fill-rule="evenodd"
                                                    d="M5.293 7.293a1 1 0 011.414 0L10 10.586l3.293-3.293a1 1 0 111.414 1.414l-4 4a1 1 0 01-1.414 0l-4-4a1 1 0 010-1.414z"
                                                    clip-rule="evenodd" />
                                            </svg>
                                        </span>
                                    </summary>
                                    <div class="mt-4 grid grid-cols-1 md:grid-cols-2 gap-x-6 gap-y-4">
                                        <!-- Use default_gen_params passed from server for initial values -->
                                        {% set current_gen_params = submitted_gen_params if submitted_gen_params else
                                        default_gen_params %}
                                        <!-- Speed Factor -->
                                        <div>
                                            <label for="speed_factor" class="label-base">Speed Factor (<span
                                                    id="speed_factor_value">{{ current_gen_params.speed_factor
                                                    }}</span>)</label>
                                            <input type="range" id="speed_factor" name="speed_factor" min="0.5"
                                                max="2.0" step="0.01" value="{{ current_gen_params.speed_factor }}"
                                                class="slider-base">
                                        </div>
                                        <!-- CFG Scale -->
                                        <div>
                                            <label for="cfg_scale" class="label-base">CFG Scale (<span
                                                    id="cfg_scale_value">{{ current_gen_params.cfg_scale
                                                    }}</span>)</label>
                                            <input type="range" id="cfg_scale" name="cfg_scale" min="1.0" max="5.0"
                                                step="0.1" value="{{ current_gen_params.cfg_scale }}"
                                                class="slider-base">
                                        </div>
                                        <!-- Temperature -->
                                        <div>
                                            <label for="temperature" class="label-base">Temperature (<span
                                                    id="temperature_value">{{ current_gen_params.temperature
                                                    }}</span>)</label>
                                            <input type="range" id="temperature" name="temperature" min="1.0" max="1.5"
                                                step="0.05" value="{{ current_gen_params.temperature }}"
                                                class="slider-base">
                                        </div>
                                        <!-- Top P -->
                                        <div>
                                            <label for="top_p" class="label-base">Top P (<span id="top_p_value">{{
                                                    current_gen_params.top_p }}</span>)</label>
                                            <input type="range" id="top_p" name="top_p" min="0.8" max="1.0" step="0.01"
                                                value="{{ current_gen_params.top_p }}" class="slider-base">
                                        </div>
                                        <!-- CFG Filter Top K -->
                                        <div>
                                            <label for="cfg_filter_top_k" class="label-base">CFG Filter Top K (<span
                                                    id="cfg_filter_top_k_value">{{ current_gen_params.cfg_filter_top_k
                                                    }}</span>)</label>
                                            <input type="range" id="cfg_filter_top_k" name="cfg_filter_top_k" min="15"
                                                max="50" step="1" value="{{ current_gen_params.cfg_filter_top_k }}"
                                                class="slider-base">
                                        </div>
                                        <!-- Save Gen Defaults Button -->
                                        <div class="col-span-1 md:col-span-2 mt-4 flex items-center gap-4">
                                            <button id="save-gen-defaults-btn" type="button" class="btn-secondary">
                                                Save Generation Defaults
                                            </button>
                                            <span id="gen-defaults-status" class="text-xs hidden"></span>
                                        </div>
                                    </div>
                                </details>
                            </div>

                            <!-- Server Configuration (Collapsible) -->
                            <div class="mb-6">
                                <details class="group">
                                    <summary class="list-none flex cursor-pointer items-center">
                                        <span class="text-sm font-medium label-base">Server Configuration</span>
                                        <span class="ml-2 text-purple-500 dark:text-purple-300">
                                            <svg class="group-open:rotate-180 h-5 w-5 transition-transform"
                                                viewBox="0 0 20 20" fill="currentColor">
                                                <path fill-rule="evenodd"
                                                    d="M5.293 7.293a1 1 0 011.414 0L10 10.586l3.293-3.293a1 1 0 111.414 1.414l-4 4a1 1 0 01-1.414 0l-4-4a1 1 0 010-1.414z"
                                                    clip-rule="evenodd" />
                                            </svg>
                                        </span>
                                    </summary>
                                    <div id="server-config-form"
                                        class="mt-4 border-t border-gray-200 dark:border-dark-700 pt-4">
                                        <p class="text-xs text-purple-500 dark:text-purple-300 mb-3">
                                            These settings are saved to the <code class="code-inline">.env</code> file.
                                            Restart the server to apply changes.
                                        </p>
                                        <div class="grid grid-cols-1 md:grid-cols-2 gap-4">
                                            <!-- Dia Model Repo ID -->
                                            <div>
                                                <label for="config_model_repo" class="label-base text-xs">Model Repo
                                                    ID</label>
                                                <input type="text" id="config_model_repo" name="DIA_MODEL_REPO_ID"
                                                    value="{{ config.DIA_MODEL_REPO_ID }}"
                                                    placeholder="ttj/dia-1.6b-safetensors" class="input-base text-sm">
                                            </div>
                                            <!-- Model Config Filename -->
                                            <div>
                                                <label for="config_model_config" class="label-base text-xs">Model Config
                                                    Filename</label>
                                                <input type="text" id="config_model_config"
                                                    name="DIA_MODEL_CONFIG_FILENAME"
                                                    value="{{ config.DIA_MODEL_CONFIG_FILENAME }}"
                                                    placeholder="config.json" class="input-base text-sm">
                                            </div>
                                            <!-- Model Weights Filename -->
                                            <div>
                                                <label for="config_model_weights" class="label-base text-xs">Model
                                                    Weights Filename</label>
                                                <input type="text" id="config_model_weights"
                                                    name="DIA_MODEL_WEIGHTS_FILENAME"
                                                    value="{{ config.DIA_MODEL_WEIGHTS_FILENAME }}"
                                                    placeholder="dia-v0_1_bf16.safetensors" class="input-base text-sm">
                                            </div>
                                            <!-- Model Cache Path -->
                                            <div>
                                                <label for="config_model_cache" class="label-base text-xs">Model Cache
                                                    Path</label>
                                                <input type="text" id="config_model_cache" name="DIA_MODEL_CACHE_PATH"
                                                    value="{{ config.DIA_MODEL_CACHE_PATH }}"
                                                    placeholder="./model_cache" class="input-base text-sm">
                                            </div>
                                            <!-- Reference Audio Path -->
                                            <div>
                                                <label for="config_ref_audio" class="label-base text-xs">Reference Audio
                                                    Path</label>
                                                <input type="text" id="config_ref_audio" name="REFERENCE_AUDIO_PATH"
                                                    value="{{ config.REFERENCE_AUDIO_PATH }}"
                                                    placeholder="./reference_audio" class="input-base text-sm">
                                            </div>
                                            <!-- Output Path -->
                                            <div>
                                                <label for="config_output_path" class="label-base text-xs">Output
                                                    Path</label>
                                                <input type="text" id="config_output_path" name="OUTPUT_PATH"
                                                    value="{{ config.OUTPUT_PATH }}" placeholder="./outputs"
                                                    class="input-base text-sm">
                                            </div>
                                            <!-- Server Host -->
                                            <div>
                                                <label for="config_host" class="label-base text-xs">Server Host</label>
                                                <input type="text" id="config_host" name="HOST"
                                                    value="{{ config.HOST }}" placeholder="0.0.0.0"
                                                    class="input-base text-sm">
                                            </div>
                                            <!-- Server Port -->
                                            <div>
                                                <label for="config_port" class="label-base text-xs">Server Port</label>
                                                <input type="number" id="config_port" name="PORT"
                                                    value="{{ config.PORT }}" min="1024" max="65535" step="1"
                                                    class="input-base text-sm">
                                            </div>
                                            <!-- Save/Restart Buttons -->
                                            <div
                                                class="col-span-1 md:col-span-2 mt-4 flex flex-col md:flex-row gap-4 items-center">
                                                <button id="save-config-btn" type="button"
                                                    class="btn-purple w-full md:w-auto">
                                                    Save Server Configuration
                                                </button>
                                                <button id="restart-server-btn" type="button"
                                                    class="btn-danger w-full md:w-auto hidden">
                                                    <svg xmlns="http://www.w3.org/2000/svg" fill="none"
                                                        viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor"
                                                        class="w-5 h-5 mr-1 inline-block">
                                                        <path stroke-linecap="round" stroke-linejoin="round"
                                                            d="M16.023 9.348h4.992v-.001M2.985 19.644v-4.992m0 0h4.992m-4.993 0 3.181 3.183a8.25 8.25 0 0 0 13.803-3.7M4.031 9.865a8.25 8.25 0 0 1 13.803-3.7l3.181 3.182m0-4.991v4.99" />
                                                    </svg>
                                                    Restart Server
                                                </button>
                                                <span id="config-status" class="text-xs ml-2 hidden"></span>
                                            </div>
                                        </div>
                                    </div>
                                </details>
                            </div>

                        </div> <!-- End p-6 -->

                        <!-- Form Actions -->
                        <div class="card-footer">
                            <div class="text-sm text-gray-600 dark:text-purple-300">
                                <p>Use <code class="code-inline">[S1]</code>/<code class="code-inline">[S2]</code> for
                                    dialogue. Add <code class="code-inline">(laughs)</code> etc.</p>
                            </div>
                            <button type="submit" id="generate-btn" class="btn-primary">
                                <svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24"
                                    stroke-width="1.5" stroke="currentColor" class="w-5 h-5 mr-1 inline-block">
                                    <path stroke-linecap="round" stroke-linejoin="round"
                                        d="M19.114 5.636a9 9 0 0 1 0 12.728M16.463 8.288a5.25 5.25 0 0 1 0 7.424M6.75 8.25l4.72-4.72a.75.75 0 0 1 1.28.53v15.88a.75.75 0 0 1-1.28.53l-4.72-4.72H4.51c-.88 0-1.704-.507-1.938-1.354A9.009 9.009 0 0 1 2.25 12c0-.83.112-1.633.322-2.396C2.806 8.756 3.63 8.25 4.51 8.25H6.75Z" />
                                </svg>
                                Generate Speech
                            </button>
                        </div>
                    </form>
                </div> <!-- End TTS Form Card -->

                <!-- Audio player container - Populated by JavaScript if generation is successful -->
                <div id="audio-player-container" class="mt-8">
                    {% if output_file_url %}
                    <!-- Template for initial load if result is passed from server -->
                    <!-- Add data attribute to signal JS that result is present -->
                    <div id="output-file-url-data" data-initial-audio-url="{{ output_file_url }}" class="hidden"></div>
                    <div class="audio-player-card">
                        <div class="p-6">
                            <h2 class="card-header">Generated Audio</h2>
                            <div class="mb-4">
                                <div id="waveform" class="waveform-container"></div>
                            </div>
                            <div class="audio-player-controls">
                                <div class="audio-player-buttons">
                                    <button id="play-btn" class="btn-primary" disabled>
                                        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor"
                                            class="w-5 h-5 mr-1">
                                            <path fill-rule="evenodd"
                                                d="M2 10a8 8 0 1 1 16 0 8 8 0 0 1-16 0Zm6.39-2.908a.75.75 0 0 1 .766.027l3.5 2.25a.75.75 0 0 1 0 1.262l-3.5 2.25A.75.75 0 0 1 8 12.25v-4.5a.75.75 0 0 1 .39-.658Z"
                                                clip-rule="evenodd" />
                                        </svg>
                                        Play
                                    </button>
                                    <a id="download-link" href="{{ output_file_url }}"
                                        download="{{ output_file_url.split('/')[-1] }}" class="btn-secondary">
                                        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor"
                                            class="w-5 h-5 mr-1">
                                            <path
                                                d="M10.75 2.75a.75.75 0 0 0-1.5 0v8.614L6.295 8.235a.75.75 0 1 0-1.09 1.03l4.25 4.5a.75.75 0 0 0 1.09 0l4.25-4.5a.75.75 0 0 0-1.09-1.03l-2.955 3.129V2.75Z" />
                                            <path
                                                d="M3.5 12.75a.75.75 0 0 0-1.5 0v2.5A2.75 2.75 0 0 0 4.75 18h10.5A2.75 2.75 0 0 0 18 15.25v-2.5a.75.75 0 0 0-1.5 0v2.5c0 .69-.56 1.25-1.25 1.25H4.75c-.69 0-1.25-.56-1.25-1.25v-2.5Z" />
                                        </svg>
                                        Download WAV
                                    </a>
                                </div>
                                <div class="audio-player-info">
                                    Mode: <span class="font-medium">{{ submitted_voice_mode }}</span>
                                    {% if submitted_voice_mode == 'clone' and submitted_clone_file %}
                                    (<span class="font-medium">{{ submitted_clone_file }}</span>)
                                    {% endif %}
                                    • Gen Time: <span class="font-medium">{{ generation_time }}s</span>
                                    • Duration: <span id="audio-duration" class="font-medium">--:--</span>
                                </div>
                            </div>
                        </div>
                    </div>
                    {% endif %}
                </div>

                <!-- Tips Section -->
                <div class="mt-8">
                    <h2 class="card-header mb-4">Tips & Tricks for Dia</h2>
                    <div class="card-base">
                        <div class="p-6">
                            <ul class="list-disc pl-5 text-sm text-gray-700 dark:text-purple-300 space-y-2">
                                <li>For <strong>Dialogue</strong> mode, clearly mark speaker turns using <code
                                        class="code-inline">[S1]</code> and <code class="code-inline">[S2]</code>.</li>
                                <li>Add non-verbal sounds like <code class="code-inline">(laughs)</code>, <code
                                        class="code-inline">(sighs)</code>, <code
                                        class="code-inline">(clears throat)</code> within the text where desired.</li>
                                <li>For <strong>Voice Clone</strong> mode, upload a clean reference audio file (<code
                                        class="code-inline">.wav</code>/<code class="code-inline">.mp3</code>) using the
                                    "Load" button. <strong class="dark:text-yellow-300 text-yellow-600">Crucially,
                                        include the exact transcript of the reference audio at the beginning of your
                                        text input</strong> (e.g., <code
                                        class="code-inline">[S1] Reference transcript. [S1] Target text...</code>).</li>
                                <li>Experiment with <strong>CFG Scale</strong> (higher = more adherence to text,
                                    potentially less natural) and <strong>Temperature</strong> (higher = more
                                    random/varied).</li>
                                <li>The <strong>Speed Factor</strong> adjusts playback speed (0.8 = slower, 1.0 =
                                    original).</li>
                                <li>Use the <code class="code-inline">/v1/audio/speech</code> endpoint for OpenAI
                                    compatibility. Use the <code class="code-inline">voice</code> parameter to specify
                                    mode ('S1', 'S2', 'dialogue', 'reference_file.wav').</li>
                            </ul>
                        </div>
                    </div>
                </div>
            </div>
        </main>

        <footer class="nav-base py-6 mt-12">
            <div class="mx-auto max-w-7xl px-4 sm:px-6 lg:px-8">
                <div class="flex justify-center">
                    <a href="https://github.com/devnen/Dia-TTS-Server"
                        class="flex items-center gap-2 text-gray-600 dark:text-purple-300 text-sm hover:text-sky-600 dark:hover:text-primary-400 transition-colors">
                        <!-- GitHub icon -->
                        <svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" fill="currentColor"
                            viewBox="0 0 16 16" class="flex-shrink-0">
                            <path
                                d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.012 8.012 0 0 0 16 8c0-4.42-3.58-8-8-8z" />
                        </svg>
                        <span>Dia TTS Server | Powered by FastAPI</span>
                    </a>
                </div>
            </div>
        </footer>
    </div>

    <!-- Loading spinner template (hidden by default) -->
    <div id="loading-overlay" class="loading-overlay-base hidden">
        <div class="loading-box-base">
            <svg class="loading-spinner" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24">
                <circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle>
                <path class="opacity-75" fill="currentColor"
                    d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z">
                </path>
            </svg>
            <p id="loading-message" class="loading-text">Generating audio...</p>
            <p id="loading-status" class="loading-status">Please wait.</p>
            <button id="loading-cancel-btn" type="button" class="btn-secondary mt-4">Cancel</button>
        </div>
    </div>

    <!-- Pass data from server to JavaScript -->
    <script>
        // Make presets data available to script.js
        // Ensure this is correctly populated by your Jinja2 template context
        window.appPresets = {{ presets | tojson | safe }};
    </script>

    <!-- Link External JavaScript (Ensure it's loaded AFTER the DOM) -->
    <script src="/ui/script.js" defer></script>
</body>

</html>
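The Tips list in this template points users at the server's OpenAI-compatible /v1/audio/speech endpoint. The request model is defined in models.py, which is outside this excerpt, so the JSON field names below are assumptions borrowed from the OpenAI audio API that the endpoint mirrors; treat this as a sketch to verify against models.py:

import requests

# Sketch of a call to the OpenAI-compatible endpoint named in the Tips list.
# Field names follow the OpenAI /v1/audio/speech schema (assumption); the
# server's actual schema lives in models.py.
resp = requests.post(
    "http://localhost:8003/v1/audio/speech",
    json={
        "input": "[S1] Hello from the API! [S2] Hi there. (laughs)",
        "voice": "dialogue",       # per the Tips: 'S1', 'S2', 'dialogue', or 'reference_file.wav'
        "response_format": "wav",  # assumption: WAV, matching the server's default output
    },
)
with open("speech.wav", "wb") as f:
    f.write(resp.content)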
ui/presets.yaml
ADDED
@@ -0,0 +1,57 @@
# ui/presets.yaml
# Predefined examples for the Dia TTS UI

- name: "Standard Dialogue"
  voice_mode: "dialogue"
  text: |
    [S1] Hey, how's it going?
    [S2] Pretty good! Just grabbing some coffee. You?
    [S1] Same here. Need the fuel! (laughs)
  params:
    cfg_scale: 3.0
    temperature: 1.3
    top_p: 0.95
    cfg_filter_top_k: 35
    # speed_factor uses the saved default

- name: "Expressive Narration"
  voice_mode: "dialogue" # Use dialogue mode with single speaker tag
  text: |
    [S1] The old house stood on a windswept hill, its windows like empty eyes staring out at the stormy sea. (sighs) It felt... lonely.
  params:
    cfg_scale: 3.0
    temperature: 1.2 # Slightly lower temp for clarity
    top_p: 0.95
    cfg_filter_top_k: 35

- name: "Quick Announcement"
  voice_mode: "dialogue" # Use dialogue mode with single speaker tag
  text: |
    [S1] Attention shoppers! The store will be closing in 15 minutes. Please bring your final purchases to the checkout.
  params:
    cfg_scale: 2.8 # Slightly lower CFG for potentially more natural tone
    temperature: 1.3
    top_p: 0.95
    cfg_filter_top_k: 35

- name: "Funny Exchange"
  voice_mode: "dialogue"
  text: |
    [S1] Did you remember to buy the alien repellent?
    [S2] The what now? (laughs) I thought you were joking!
    [S1] Joking? They're landing tonight! (clears throat) Probably.
  params:
    cfg_scale: 3.2 # Slightly higher CFG
    temperature: 1.35 # Slightly higher temp
    top_p: 0.95
    cfg_filter_top_k: 35

- name: "Simple Sentence"
  voice_mode: "dialogue" # Use dialogue mode with single speaker tag
  text: |
    [S1] This is a test of the text to speech system.
  params:
    cfg_scale: 3.0
    temperature: 1.3
    top_p: 0.95
    cfg_filter_top_k: 35
ui/script.js
ADDED
@@ -0,0 +1,593 @@
// ui/script.js

document.addEventListener('DOMContentLoaded', function () {
    // --- Global Flags ---
    let isGenerating = false;
    let isGenerationCancelled = false;
    let wavesurfer = null; // Global wavesurfer instance

    // --- Element Selectors ---
    const ttsForm = document.getElementById('tts-form');
    const textArea = document.getElementById('text');
    const charCount = document.getElementById('char-count');
    const voiceModeRadios = document.querySelectorAll('input[name="voice_mode"]');
    const cloneOptionsDiv = document.getElementById('clone-options');
    const cloneReferenceSelect = document.getElementById('clone_reference_select');
    const cloneLoadButton = document.getElementById('clone-load-button'); // New ID
    const cloneFileInput = document.getElementById('clone-file-input'); // New ID
    const generateBtn = document.getElementById('generate-btn');
    const loadingOverlay = document.getElementById('loading-overlay');
    const loadingMessage = document.getElementById('loading-message');
    const loadingStatus = document.getElementById('loading-status'); // New element for status
    const loadingCancelBtn = document.getElementById('loading-cancel-btn'); // New ID
    const notificationArea = document.getElementById('notification-area');
    const audioPlayerContainer = document.getElementById('audio-player-container');
    const configSaveBtn = document.getElementById('save-config-btn');
    const configRestartBtn = document.getElementById('restart-server-btn');
    const configStatus = document.getElementById('config-status');
    const genDefaultsSaveBtn = document.getElementById('save-gen-defaults-btn'); // New ID
    const genDefaultsStatus = document.getElementById('gen-defaults-status'); // New ID
    const themeToggleButton = document.getElementById('theme-toggle-btn'); // New ID
    const themeIconLight = document.getElementById('theme-icon-light'); // New ID
    const themeIconDark = document.getElementById('theme-icon-dark'); // New ID
    const presetsContainer = document.getElementById('presets-container'); // New ID

    // --- Initial Setup ---

    // Character counter
    function updateCharCount() {
        if (textArea && charCount) {
            charCount.textContent = textArea.value.length;
        }
    }
    if (textArea) {
        textArea.addEventListener('input', updateCharCount);
        updateCharCount(); // Initial count
    }

    // Toggle clone options visibility & required attribute
    function toggleCloneOptions() {
        const selectedMode = document.querySelector('input[name="voice_mode"]:checked')?.value;
        if (cloneOptionsDiv && cloneReferenceSelect && cloneLoadButton) {
            if (selectedMode === 'clone') {
                cloneOptionsDiv.classList.remove('hidden');
                cloneReferenceSelect.required = true;
                cloneLoadButton.classList.remove('hidden');
            } else {
                cloneOptionsDiv.classList.add('hidden');
                cloneReferenceSelect.required = false;
                // cloneReferenceSelect.value = 'none'; // Don't reset, in case the user switches back
                cloneLoadButton.classList.add('hidden');
            }
        }
    }
    voiceModeRadios.forEach(radio => radio.addEventListener('change', toggleCloneOptions));
    toggleCloneOptions(); // Initial check

    // Update slider value displays dynamically
    const sliders = [
        { id: 'speed_factor', valueId: 'speed_factor_value' },
        { id: 'cfg_scale', valueId: 'cfg_scale_value' },
        { id: 'temperature', valueId: 'temperature_value' },
        { id: 'top_p', valueId: 'top_p_value' },
        { id: 'cfg_filter_top_k', valueId: 'cfg_filter_top_k_value' },
    ];
    sliders.forEach(sliderInfo => {
        const slider = document.getElementById(sliderInfo.id);
        const valueDisplay = document.getElementById(sliderInfo.valueId);
        if (slider && valueDisplay) {
            // Set initial display from the slider's current value (set by the template)
            valueDisplay.textContent = slider.value;
            // Update the display whenever the slider moves
            slider.addEventListener('input', () => valueDisplay.textContent = slider.value);
        }
    });

    // --- Notifications ---
    function showNotification(message, type = 'success', duration = 5000) {
        if (!notificationArea) return;
        // notificationArea.innerHTML = ''; // Intentionally not cleared, so multiple notifications can stack
        const colors = {
            success: 'notification-success',
            error: 'notification-error',
            warning: 'notification-warning',
            info: 'notification-info'
        };
        const icons = { // Inline SVG icons
            success: '<svg class="h-5 w-5 text-green-500 mr-2 flex-shrink-0" viewBox="0 0 20 20" fill="currentColor"><path fill-rule="evenodd" d="M10 18a8 8 0 100-16 8 8 0 000 16zm3.707-9.293a1 1 0 00-1.414-1.414L9 10.586 7.707 9.293a1 1 0 00-1.414 1.414l2 2a1 1 0 001.414 0l4-4z" clip-rule="evenodd" /></svg>',
            error: '<svg class="h-5 w-5 text-red-500 mr-2 flex-shrink-0" viewBox="0 0 20 20" fill="currentColor"><path fill-rule="evenodd" d="M10 18a8 8 0 100-16 8 8 0 000 16zM8.707 7.293a1 1 0 00-1.414 1.414L8.586 10l-1.293 1.293a1 1 0 101.414 1.414L10 11.414l1.293 1.293a1 1 0 001.414-1.414L11.414 10l1.293-1.293a1 1 0 00-1.414-1.414L10 8.586 8.707 7.293z" clip-rule="evenodd" /></svg>',
            warning: '<svg class="h-5 w-5 text-yellow-500 mr-2 flex-shrink-0" viewBox="0 0 20 20" fill="currentColor"><path fill-rule="evenodd" d="M8.485 2.495c.673-1.167 2.357-1.167 3.03 0l6.28 10.875c.673 1.167-.17 2.625-1.516 2.625H3.72c-1.347 0-2.189-1.458-1.515-2.625L8.485 2.495zM10 5a.75.75 0 01.75.75v3.5a.75.75 0 01-1.5 0v-3.5A.75.75 0 0110 5zm0 9a1 1 0 100-2 1 1 0 000 2z" clip-rule="evenodd" /></svg>',
            info: '<svg class="h-5 w-5 text-sky-500 mr-2 flex-shrink-0" viewBox="0 0 20 20" fill="currentColor"><path fill-rule="evenodd" d="M18 10a8 8 0 11-16 0 8 8 0 0116 0zm-7-4a1 1 0 11-2 0 1 1 0 012 0zM9 9a.75.75 0 000 1.5h.253a.25.25 0 01.244.304l-.459 2.066A1.75 1.75 0 0010.747 15H11a.75.75 0 000-1.5h-.253a.25.25 0 01-.244-.304l.459-2.066A1.75 1.75 0 009.253 9H9z" clip-rule="evenodd" /></svg>'
        };

        const notificationDiv = document.createElement('div');
        notificationDiv.className = colors[type] || colors['info']; // Default to info style
        notificationDiv.innerHTML = `${icons[type] || icons['info']} <span class="block sm:inline">${message}</span>`;
        notificationArea.appendChild(notificationDiv);

        // Auto-hide after the specified duration (0 = persistent)
        if (duration > 0) {
            setTimeout(() => {
                notificationDiv.style.transition = 'opacity 0.5s ease-out';
                notificationDiv.style.opacity = '0';
                setTimeout(() => notificationDiv.remove(), 500);
            }, duration);
        }
        return notificationDiv; // Return the element in case manual removal is needed
    }

    // --- Presets ---
    function applyPreset(presetData) {
        console.log("Applying preset:", presetData);
        if (!presetData) return;

        // Update text area
        if (textArea && presetData.text !== undefined) {
            textArea.value = presetData.text;
            updateCharCount(); // Update counter
        }

        // Update voice mode
        if (presetData.voice_mode) {
            const radio = document.querySelector(`input[name="voice_mode"][value="${presetData.voice_mode}"]`);
            if (radio) {
                radio.checked = true;
                toggleCloneOptions(); // Update UI based on the new mode
            }
        }

        // Update generation parameters
        if (presetData.params) {
            for (const [key, value] of Object.entries(presetData.params)) {
                const slider = document.getElementById(key); // Assumes the slider ID matches the param key
                const valueDisplay = document.getElementById(`${key}_value`);
                if (slider) {
                    slider.value = value;
                    if (valueDisplay) {
                        valueDisplay.textContent = value; // Update display
                    }
                } else {
                    console.warn(`Slider element not found for preset parameter: ${key}`);
                }
            }
        }
        showNotification(`Preset "${presetData.name}" loaded.`, 'info', 3000);
    }

    // Wire up preset buttons. Presets data is injected by the template via
    // `window.appPresets = {{ presets | tojson | safe }};`
    if (window.appPresets && presetsContainer) {
        window.appPresets.forEach((preset, index) => {
            const button = document.getElementById(`preset-btn-${index}`);
            if (button) {
                button.addEventListener('click', () => applyPreset(preset));
            }
        });
    } else if (presetsContainer) {
        console.warn("Presets data (window.appPresets) not found; preset buttons will not work.");
    }

    // --- Audio Player ---
    function initializeWaveSurfer(audioUrl) {
        if (wavesurfer) {
            wavesurfer.destroy();
        }
        const waveformDiv = document.getElementById('waveform');
        const playBtn = document.getElementById('play-btn');
        const durationSpan = document.getElementById('audio-duration');

        if (!waveformDiv || !playBtn || !durationSpan) {
            console.error("Audio player elements not found in the container.");
            // Clear the container if elements are missing after generation
            if (audioPlayerContainer) audioPlayerContainer.innerHTML = '<p class="text-red-500 dark:text-red-400">Error displaying audio player.</p>';
            return;
        }

        // Ensure button text doesn't wrap
        playBtn.classList.add('whitespace-nowrap', 'flex-shrink-0');
        const downloadLink = document.getElementById('download-link');
        if (downloadLink) downloadLink.classList.add('whitespace-nowrap', 'flex-shrink-0');

        wavesurfer = WaveSurfer.create({
            container: waveformDiv,
            waveColor: document.documentElement.classList.contains('dark') ? '#38bdf8' : '#0ea5e9', // primary-400 (dark) / primary-500 (light)
            progressColor: document.documentElement.classList.contains('dark') ? '#0284c7' : '#0369a1', // primary-600 (dark) / primary-700 (light)
            cursorColor: document.documentElement.classList.contains('dark') ? '#a855f7' : '#9333ea', // purple-500 (dark) / purple-600 (light)
            barWidth: 3,
            barRadius: 3,
            cursorWidth: 1,
            height: 80,
            barGap: 2,
            responsive: true,
            url: audioUrl,
            mediaControls: false, // Use custom controls
            normalize: true,
        });

        wavesurfer.on('ready', () => {
            const duration = wavesurfer.getDuration();
            const minutes = Math.floor(duration / 60);
            const seconds = Math.floor(duration % 60);
            durationSpan.textContent = `${minutes}:${seconds < 10 ? '0' : ''}${seconds}`;
            playBtn.disabled = false;
            playBtn.innerHTML = `<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor" class="w-5 h-5 mr-1"><path fill-rule="evenodd" d="M2 10a8 8 0 1 1 16 0 8 8 0 0 1-16 0Zm6.39-2.908a.75.75 0 0 1 .766.027l3.5 2.25a.75.75 0 0 1 0 1.262l-3.5 2.25A.75.75 0 0 1 8 12.25v-4.5a.75.75 0 0 1 .39-.658Z" clip-rule="evenodd" /></svg> Play`;
        });

        wavesurfer.on('play', () => {
            playBtn.innerHTML = `<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor" class="w-5 h-5 mr-1"><path fill-rule="evenodd" d="M2 10a8 8 0 1 1 16 0 8 8 0 0 1-16 0Zm5-2.25A.75.75 0 0 1 7.75 7h4.5a.75.75 0 0 1 .75.75v4.5a.75.75 0 0 1-.75.75h-4.5a.75.75 0 0 1-.75-.75v-4.5Z" clip-rule="evenodd" /></svg> Pause`;
        });

        wavesurfer.on('pause', () => {
            playBtn.innerHTML = `<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor" class="w-5 h-5 mr-1"><path fill-rule="evenodd" d="M2 10a8 8 0 1 1 16 0 8 8 0 0 1-16 0Zm6.39-2.908a.75.75 0 0 1 .766.027l3.5 2.25a.75.75 0 0 1 0 1.262l-3.5 2.25A.75.75 0 0 1 8 12.25v-4.5a.75.75 0 0 1 .39-.658Z" clip-rule="evenodd" /></svg> Play`;
        });
        wavesurfer.on('finish', () => {
            playBtn.innerHTML = `<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor" class="w-5 h-5 mr-1"><path fill-rule="evenodd" d="M2 10a8 8 0 1 1 16 0 8 8 0 0 1-16 0Zm6.39-2.908a.75.75 0 0 1 .766.027l3.5 2.25a.75.75 0 0 1 0 1.262l-3.5 2.25A.75.75 0 0 1 8 12.25v-4.5a.75.75 0 0 1 .39-.658Z" clip-rule="evenodd" /></svg> Play`;
        });

        playBtn.onclick = () => {
            wavesurfer.playPause();
        };

        // Scroll to the player after initialization
        setTimeout(() => {
            audioPlayerContainer.scrollIntoView({ behavior: 'smooth', block: 'center' });
        }, 100); // Short delay to ensure rendering
    }

    // Initialize the player if an audio URL is present on initial page load.
    // The player markup is added dynamically by the template, so we look for
    // a data attribute and initialize WaveSurfer when it is present.
    const initialAudioUrlElement = document.querySelector('[data-initial-audio-url]');
    if (initialAudioUrlElement && initialAudioUrlElement.dataset.initialAudioUrl) {
        console.log("Initializing WaveSurfer for initially loaded audio.");
        initializeWaveSurfer(initialAudioUrlElement.dataset.initialAudioUrl);
    }

    // --- Form Submission & Cancellation ---
    if (ttsForm) {
        ttsForm.addEventListener('submit', function (event) {
            // Client-side validation
            const text = textArea.value.trim();
            const mode = document.querySelector('input[name="voice_mode"]:checked')?.value;
            const cloneFile = cloneReferenceSelect?.value;

            if (!text) {
                showNotification("Please enter some text.", 'error');
                event.preventDefault(); return;
            }
            if (mode === 'clone' && (!cloneFile || cloneFile === 'none')) {
                showNotification("Please select a reference file for clone mode.", 'error');
                event.preventDefault(); return;
            }

            // Handle cancellation of the previous request if Generate is clicked again
            if (isGenerating) {
                console.log("Generate clicked while a previous generation is in progress. Setting cancel flag.");
                showNotification("Cancelling previous request...", 'warning', 2000);
                isGenerationCancelled = true;
                // The backend is not actually stopped here ("fake" cancel),
                // but the result processing will ignore the previous result.
            }

            // Reset flags and show the loading overlay for the new request
            isGenerating = true;
            isGenerationCancelled = false; // Reset cancel flag for the new request
            if (loadingOverlay && generateBtn && loadingCancelBtn) {
                loadingMessage.textContent = 'Generating audio...'; // Initial status
                loadingStatus.textContent = 'Please wait.';
                loadingOverlay.classList.remove('hidden');
                generateBtn.disabled = true;
                generateBtn.classList.add('opacity-50', 'cursor-not-allowed');
                loadingCancelBtn.disabled = false; // Enable cancel button
            }
            // Allow default form submission to proceed;
            // the page will reload with results rendered by the template.
        });
    }

    // Handle Cancel button click
    if (loadingCancelBtn) {
        loadingCancelBtn.addEventListener('click', () => {
            if (isGenerating) {
                console.log("Cancel button clicked.");
                isGenerationCancelled = true;
                isGenerating = false; // Stop treating the request as "generating" from the UI's perspective
                if (loadingOverlay && generateBtn) {
                    loadingOverlay.classList.add('hidden'); // Hide overlay
                    generateBtn.disabled = false; // Re-enable generate button
                    generateBtn.classList.remove('opacity-50', 'cursor-not-allowed');
                }
                showNotification("Generation cancelled by user.", 'info');
                // Note: the backend request continues, but its result will be ignored on page reload/update
            }
        });
    }

    // --- Result Handling (on page load after form submission) ---
    // This logic runs every time the page loads. We check whether elements
    // indicating a successful generation are present.
    const outputUrlElement = document.getElementById('output-file-url-data'); // This element must be rendered by the HTML template
    if (outputUrlElement && outputUrlElement.dataset.url) {
        const outputUrl = outputUrlElement.dataset.url;
        console.log("Page loaded with generation result:", outputUrl);

        if (isGenerationCancelled) {
            console.log("Generation was cancelled; ignoring result.");
            showNotification("Previous generation was cancelled.", "warning");
            // Reset flag after checking
            isGenerationCancelled = false;
        } else {
            console.log("Processing successful generation result.");
            // The audio player structure is rendered by the template;
            // we only need to initialize WaveSurfer for it.
            initializeWaveSurfer(outputUrl);
        }
    }
    // Always reset the generating flag on page load, as any active generation is now finished or irrelevant
    isGenerating = false;
    if (generateBtn) { // Re-enable the button if the page reloads for any reason
        generateBtn.disabled = false;
        generateBtn.classList.remove('opacity-50', 'cursor-not-allowed');
    }

    // --- Configuration Management ---
    function updateConfigStatus(button, statusElement, message, success = true, duration = 5000) {
        const successClass = 'text-green-500 dark:text-green-400';
        const errorClass = 'text-red-500 dark:text-red-400';
        const savingClass = 'text-yellow-500 dark:text-yellow-400';

        // In-progress messages ('Saving...' / 'Restarting...') get the yellow style;
        // everything else is colored by success/failure.
        const inProgress = message.startsWith('Saving') || message.startsWith('Restarting');
        statusElement.textContent = message;
        statusElement.className = `text-xs ml-2 ${inProgress ? savingClass : (success ? successClass : errorClass)}`;
        statusElement.classList.remove('hidden');
        if (button) button.disabled = true; // Disable the button while processing

        // Clear the status after the duration and re-enable the button
        if (duration > 0) {
            setTimeout(() => {
                statusElement.classList.add('hidden');
                if (button) button.disabled = false;
            }, duration);
        }
    }

    // Save server configuration
    if (configSaveBtn) {
        configSaveBtn.addEventListener('click', async () => {
            const configData = {};
            document.querySelectorAll('#server-config-form input[name]').forEach(input => { // Inputs live within the server-config form/div
                configData[input.name] = input.value;
            });

            updateConfigStatus(configSaveBtn, configStatus, 'Saving...', true, 0); // Indefinite until success/error

            try {
                const response = await fetch('/save_config', {
                    method: 'POST',
                    headers: { 'Content-Type': 'application/json' },
                    body: JSON.stringify(configData)
                });
                const result = await response.json();
                if (!response.ok) throw new Error(result.detail || 'Failed to save');

                updateConfigStatus(configSaveBtn, configStatus, result.message, true);
                if (configRestartBtn) configRestartBtn.classList.remove('hidden'); // Show restart button
            } catch (error) {
                console.error('Error saving server config:', error);
                updateConfigStatus(configSaveBtn, configStatus, `Error: ${error.message}`, false);
            }
        });
    }

    // Restart server
    if (configRestartBtn) {
        configRestartBtn.addEventListener('click', async () => {
            configRestartBtn.disabled = true;
            configRestartBtn.innerHTML = `
                <svg class="animate-spin h-5 w-5 mr-1 inline-block" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24">
                    <circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle>
                    <path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path>
                </svg>
                Restarting...`;
            updateConfigStatus(configRestartBtn, configStatus, 'Restarting...', true, 0); // Indefinite

            try {
                const response = await fetch('/restart_server', { method: 'POST' });
                const result = await response.json();
                if (!response.ok) throw new Error(result.detail || 'Failed to trigger restart');

                updateConfigStatus(configRestartBtn, configStatus, result.message + " Page will attempt reload.", true, 15000); // Show longer
                // Show the main loading overlay while waiting for the restart
                if (loadingOverlay) {
                    loadingMessage.textContent = 'Server restarting...';
                    loadingStatus.textContent = 'Waiting for server to respond...';
                    loadingCancelBtn.disabled = true; // Disable cancel during restart
                    loadingOverlay.classList.remove('hidden');
                }

                // Poll for server readiness
                let attempts = 0;
                const maxAttempts = 45; // Wait up to ~45 seconds

                function handleRestartTimeout() {
                    // Throwing here would become an unhandled promise rejection,
                    // so surface the failure to the user explicitly instead.
                    updateConfigStatus(configRestartBtn, configStatus, 'Restart Error: server did not respond after restart.', false);
                    if (loadingOverlay) loadingOverlay.classList.add('hidden');
                    configRestartBtn.disabled = false;
                }

                function checkServerReady() {
                    attempts++;
                    console.log(`Checking server readiness (Attempt ${attempts}/${maxAttempts})...`);
                    loadingStatus.textContent = `Waiting for server... (${attempts}/${maxAttempts})`;
                    fetch('/health?cache=' + Date.now(), { cache: 'no-store', headers: { 'pragma': 'no-cache' } })
                        .then(res => {
                            if (res.ok) {
                                console.log("Server is ready. Reloading page.");
                                window.location.reload(true); // Force reload from server
                            } else if (attempts < maxAttempts) {
                                setTimeout(checkServerReady, 1000); // Check again in 1 second
                            } else {
                                handleRestartTimeout();
                            }
                        })
                        .catch(() => {
                            if (attempts < maxAttempts) {
                                setTimeout(checkServerReady, 1000); // Check again on connection error
                            } else {
                                handleRestartTimeout();
                            }
                        });
                }
                setTimeout(checkServerReady, 3000); // Start checking after 3 seconds
            } catch (error) {
                console.error('Error restarting server:', error);
                updateConfigStatus(configRestartBtn, configStatus, `Restart Error: ${error.message}`, false);
                configRestartBtn.disabled = false; // Re-enable button on error
                configRestartBtn.innerHTML = `
                    <svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" class="w-5 h-5 mr-1 inline-block"><path stroke-linecap="round" stroke-linejoin="round" d="M16.023 9.348h4.992v-.001M2.985 19.644v-4.992m0 0h4.992m-4.993 0 3.181 3.183a8.25 8.25 0 0 0 13.803-3.7M4.031 9.865a8.25 8.25 0 0 1 13.803-3.7l3.181 3.182m0-4.991v4.99" /></svg>
                    Restart Server`;
                if (loadingOverlay) loadingOverlay.classList.add('hidden');
            }
        });
    }

    // Save generation defaults
    if (genDefaultsSaveBtn) {
        genDefaultsSaveBtn.addEventListener('click', async () => {
            const genParams = {};
            sliders.forEach(s => {
                const slider = document.getElementById(s.id);
                if (slider) genParams[s.id] = slider.value;
            });

            updateConfigStatus(genDefaultsSaveBtn, genDefaultsStatus, 'Saving...', true, 0);

            try {
                const response = await fetch('/save_generation_defaults', {
                    method: 'POST',
                    headers: { 'Content-Type': 'application/json' },
                    body: JSON.stringify(genParams)
                });
                const result = await response.json();
                if (!response.ok) throw new Error(result.detail || 'Failed to save');
                updateConfigStatus(genDefaultsSaveBtn, genDefaultsStatus, result.message, true);
            } catch (error) {
                console.error('Error saving generation defaults:', error);
                updateConfigStatus(genDefaultsSaveBtn, genDefaultsStatus, `Error: ${error.message}`, false);
            }
        });
    }

    // --- Reference Audio Upload ---
    if (cloneLoadButton && cloneFileInput && cloneReferenceSelect) {
        cloneLoadButton.addEventListener('click', () => {
            cloneFileInput.click(); // Trigger the hidden file input
        });

        cloneFileInput.addEventListener('change', async (event) => {
            const files = event.target.files;
            if (!files || files.length === 0) {
                return; // No files selected
            }

            cloneLoadButton.disabled = true;
            cloneLoadButton.textContent = 'Uploading...';
            showNotification(`Uploading ${files.length} file(s)...`, 'info', 0); // Persistent until results arrive

            const formData = new FormData();
            for (const file of files) {
                formData.append('files', file);
            }

            try {
                const response = await fetch('/upload_reference', {
                    method: 'POST',
                    body: formData
                    // Content-Type is set automatically for FormData
                });

                const result = await response.json();

                // Clear existing notifications before showing results
                notificationArea.innerHTML = '';

                if (!response.ok) {
                    throw new Error(result.message || `Upload failed with status ${response.status}`);
                }

                // Process results
                if (result.errors && result.errors.length > 0) {
                    result.errors.forEach(err => showNotification(err, 'error'));
                }
                if (result.uploaded_files && result.uploaded_files.length > 0) {
                    showNotification(`Successfully uploaded: ${result.uploaded_files.join(', ')}`, 'success');
                } else if (!result.errors || result.errors.length === 0) {
                    showNotification("Files processed, but no new files were added (they may already exist).", 'info');
                }

                // Update the dropdown
                const currentSelection = cloneReferenceSelect.value;
                cloneReferenceSelect.innerHTML = '<option value="none">-- Select Reference File --</option>'; // Clear existing options
                result.all_reference_files.forEach(filename => {
                    const option = document.createElement('option');
                    option.value = filename;
                    option.textContent = filename;
                    cloneReferenceSelect.appendChild(option);
                });

                // Select the first newly uploaded file, or keep the current selection if still valid
                const firstUploaded = result.uploaded_files ? result.uploaded_files[0] : null;
                if (firstUploaded) {
                    cloneReferenceSelect.value = firstUploaded;
                } else if (result.all_reference_files.includes(currentSelection)) {
                    cloneReferenceSelect.value = currentSelection; // Restore the previous valid selection
                } else {
                    cloneReferenceSelect.value = 'none'; // Default if nothing else matches
                }
            } catch (error) {
                console.error('Error uploading reference files:', error);
                showNotification(`Upload Error: ${error.message}`, 'error');
            } finally {
                cloneLoadButton.disabled = false;
                cloneLoadButton.textContent = 'Load';
                cloneFileInput.value = ''; // Reset the file input
            }
        });
    }

    // --- Theme Toggle ---
    function applyTheme(theme) {
        if (theme === 'light') {
            document.documentElement.classList.remove('dark');
            if (themeIconLight) themeIconLight.classList.remove('hidden');
            if (themeIconDark) themeIconDark.classList.add('hidden');
        } else {
            document.documentElement.classList.add('dark');
            if (themeIconLight) themeIconLight.classList.add('hidden');
            if (themeIconDark) themeIconDark.classList.remove('hidden');
        }
        // Update WaveSurfer colors if the player exists
        if (wavesurfer) {
            wavesurfer.setOptions({
                waveColor: theme === 'light' ? '#0ea5e9' : '#38bdf8',
                progressColor: theme === 'light' ? '#0369a1' : '#0284c7',
                cursorColor: theme === 'light' ? '#9333ea' : '#a855f7',
            });
        }
    }

    if (themeToggleButton) {
        // Check localStorage on load
        const savedTheme = localStorage.getItem('theme') || 'dark'; // Default to dark
        applyTheme(savedTheme);

        themeToggleButton.addEventListener('click', () => {
            const isDark = document.documentElement.classList.contains('dark');
            const newTheme = isDark ? 'light' : 'dark';
            applyTheme(newTheme);
            localStorage.setItem('theme', newTheme); // Save preference
        });
    }

}); // End DOMContentLoaded
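The fetch calls above assume a small JSON contract from the server: `result.message`/`result.detail` from the config endpoints, and `uploaded_files`, `errors`, and `all_reference_files` from `/upload_reference`. The real handlers live in server.py; the following is only a sketch of the response shapes the front-end relies on (FastAPI assumed, field values illustrative):

from typing import List
from fastapi import FastAPI, File, UploadFile

app = FastAPI()

@app.post("/save_generation_defaults")
async def save_generation_defaults(params: dict):
    # script.js displays result.message on success and result.detail on failure.
    return {"message": "Generation defaults saved."}

@app.post("/upload_reference")
async def upload_reference(files: List[UploadFile] = File(...)):
    # script.js rebuilds the clone-reference dropdown from all_reference_files
    # and auto-selects the first entry of uploaded_files.
    saved = [f.filename for f in files]
    return {"uploaded_files": saved, "errors": [], "all_reference_files": saved}
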
utils.py
ADDED
@@ -0,0 +1,146 @@
# utils.py
# Utility functions for the Dia TTS server

import logging
import time
import os
import io
import numpy as np
import soundfile as sf
from typing import Optional, Tuple

logger = logging.getLogger(__name__)

# --- Audio Processing ---


def encode_audio(
    audio_array: np.ndarray, sample_rate: int, output_format: str = "opus"
) -> Optional[bytes]:
    """
    Encodes a NumPy audio array into the specified format in memory.

    Args:
        audio_array: NumPy array containing audio data (float32, range [-1, 1]).
        sample_rate: Sample rate of the audio data.
        output_format: Desired output format ('opus' or 'wav').

    Returns:
        Bytes object containing the encoded audio, or None on failure.
    """
    if audio_array is None or audio_array.size == 0:
        logger.warning("encode_audio received empty or None audio array.")
        return None

    start_time = time.time()
    output_buffer = io.BytesIO()

    try:
        if output_format == "opus":
            # soundfile usually expects int16 for Opus, but we try float32 first;
            # it may convert internally or require a specific subtype.
            # If this fails, convert to int16 first:
            # audio_int16 = (audio_array * 32767).astype(np.int16)
            # sf.write(output_buffer, audio_int16, sample_rate, format='ogg', subtype='opus')
            sf.write(
                output_buffer, audio_array, sample_rate, format="ogg", subtype="opus"
            )
            content_type = "audio/ogg; codecs=opus"
        elif output_format == "wav":
            # WAV typically uses int16
            audio_int16 = (audio_array * 32767).astype(np.int16)
            sf.write(
                output_buffer, audio_int16, sample_rate, format="wav", subtype="pcm_16"
            )
            content_type = "audio/wav"
        else:
            logger.error(f"Unsupported output format requested: {output_format}")
            return None

        encoded_bytes = output_buffer.getvalue()
        end_time = time.time()
        logger.info(
            f"Encoded {len(encoded_bytes)} bytes to {output_format} in {end_time - start_time:.3f} seconds."
        )
        return encoded_bytes

    except ImportError:
        logger.critical(
            "`soundfile` or its dependency `libsndfile` not found/installed correctly. Cannot encode audio."
        )
        raise  # Re-raise critical error
    except Exception as e:
        logger.error(f"Error encoding audio to {output_format}: {e}", exc_info=True)
        return None


def save_audio_to_file(
    audio_array: np.ndarray, sample_rate: int, file_path: str
) -> bool:
    """
    Saves a NumPy audio array to a WAV file.

    Args:
        audio_array: NumPy array containing audio data (float32, range [-1, 1]).
        sample_rate: Sample rate of the audio data.
        file_path: Path to save the WAV file.

    Returns:
        True if saving was successful, False otherwise.
    """
    if audio_array is None or audio_array.size == 0:
        logger.warning("save_audio_to_file received empty or None audio array.")
        return False
    if not file_path.lower().endswith(".wav"):
        logger.warning(
            f"File path '{file_path}' does not end with .wav. Saving as WAV anyway."
        )
        # Optionally change the extension: file_path += ".wav"

    start_time = time.time()
    try:
        # Ensure the output directory exists. Skip when the path has no
        # directory component; os.makedirs("") would raise.
        output_dir = os.path.dirname(file_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # WAV typically uses int16
        audio_int16 = (audio_array * 32767).astype(np.int16)
        sf.write(file_path, audio_int16, sample_rate, format="wav", subtype="pcm_16")

        end_time = time.time()
        logger.info(
            f"Saved WAV file to {file_path} in {end_time - start_time:.3f} seconds."
        )
        return True
    except ImportError:
        logger.critical(
            "`soundfile` or its dependency `libsndfile` not found/installed correctly. Cannot save audio."
        )
        return False  # Indicate failure
    except Exception as e:
        logger.error(f"Error saving WAV file to {file_path}: {e}", exc_info=True)
        return False


# --- Other Utilities (Optional) ---


class PerformanceMonitor:
    """Simple performance monitoring."""

    def __init__(self):
        self.start_time = time.time()
        self.events = []

    def record(self, event_name: str):
        self.events.append((event_name, time.time()))

    def report(self) -> str:
        report_lines = ["Performance Report:"]
        last_time = self.start_time
        total_duration = time.time() - self.start_time
        for name, timestamp in self.events:
            duration = timestamp - last_time
            report_lines.append(f"  - {name}: {duration:.3f}s")
            last_time = timestamp
        report_lines.append(f"Total Duration: {total_duration:.3f}s")
        return "\n".join(report_lines)
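A quick usage sketch of these helpers. The values are illustrative only: a 440 Hz test tone stands in for real model output, and 44100 Hz is an assumed sample rate, not necessarily what the Dia engine produces.

import numpy as np
# from utils import encode_audio, save_audio_to_file, PerformanceMonitor

monitor = PerformanceMonitor()

# One second of a 440 Hz sine wave as float32 in [-1, 1]
sr = 44100
t = np.linspace(0, 1, sr, endpoint=False)
audio = (0.5 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)
monitor.record("tone generated")

opus_bytes = encode_audio(audio, sr, output_format="opus")
monitor.record("opus encoded")

saved = save_audio_to_file(audio, sr, "./outputs/test_tone.wav")
monitor.record("wav saved")

print(f"opus bytes: {len(opus_bytes) if opus_bytes else 0}, wav saved: {saved}")
print(monitor.report())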