Michael Hu committed
Commit ac5de5b · 1 Parent(s): 2493d3b

initial check in of the dia tts server
Files changed (20)
  1. .env +35 -0
  2. Dockerfile +36 -0
  3. README.md +1 -0
  4. config.py +295 -0
  5. dia/__init__.py +0 -0
  6. dia/audio.py +280 -0
  7. dia/config.py +206 -0
  8. dia/layers.py +903 -0
  9. dia/model.py +956 -0
  10. docker-compose.yml +23 -0
  11. documentation.md +549 -0
  12. download_model.py +41 -0
  13. engine.py +356 -0
  14. models.py +97 -0
  15. requirements.txt +22 -0
  16. server.py +1061 -0
  17. ui/index.html +916 -0
  18. ui/presets.yaml +57 -0
  19. ui/script.js +593 -0
  20. utils.py +146 -0
.env ADDED
@@ -0,0 +1,35 @@
+ # .env - Configuration for Dia TTS Server
+ # Values in this file override the defaults set in config.py
+
+ # --- Server Settings ---
+ HOST='0.0.0.0'
+ PORT='8003'
+
+ # --- Path Settings ---
+ # Defaults are usually fine unless you want custom locations.
+ DIA_MODEL_CACHE_PATH='./model_cache'
+ REFERENCE_AUDIO_PATH='./reference_audio'
+ OUTPUT_PATH='./outputs'
+
+ # --- Model Source Settings ---
+ # Defaulting to BF16 safetensors. Uncomment and modify lines below to use other models.
+ DIA_MODEL_REPO_ID='ttj/dia-1.6b-safetensors'
+ DIA_MODEL_CONFIG_FILENAME='config.json'
+ DIA_MODEL_WEIGHTS_FILENAME='dia-v0_1_bf16.safetensors'
+
+ # Example: Use full precision safetensors
+ # DIA_MODEL_REPO_ID=ttj/dia-1.6b-safetensors
+ # DIA_MODEL_WEIGHTS_FILENAME=dia-v0_1.safetensors
+
+ # Example: Use original Nari Labs .pth model
+ # DIA_MODEL_REPO_ID=nari-labs/Dia-1.6B
+ # DIA_MODEL_WEIGHTS_FILENAME=dia-v0_1.pth
+
+ # --- Default Generation Parameters ---
+ # These set the initial values loaded in the UI.
+ # They can be changed in the UI and saved back here using the 'Save Generation Defaults' button.
+ GEN_DEFAULT_SPEED_FACTOR='0.9'
+ GEN_DEFAULT_CFG_SCALE='3'
+ GEN_DEFAULT_TEMPERATURE='1.3'
+ GEN_DEFAULT_TOP_P='0.95'
+ GEN_DEFAULT_CFG_FILTER_TOP_K='35'
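A note on precedence: config.py (shown below in this commit) loads this file with override=True, so values in .env win over variables already set in the shell. A minimal sketch, assuming python-dotenv is installed and this .env sits in the working directory:

import os
from dotenv import load_dotenv

os.environ["PORT"] = "9999"          # value inherited from the shell
load_dotenv(".env", override=True)   # mirrors config.py's reload(); .env wins
print(os.environ["PORT"])            # prints '8003', the value from .env above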
Dockerfile ADDED
@@ -0,0 +1,36 @@
+ FROM nvidia/cuda:12.4.0-devel-ubuntu22.04
+
+ # Set environment variables
+ ENV PYTHONDONTWRITEBYTECODE=1
+ ENV PYTHONUNBUFFERED=1
+ ENV DEBIAN_FRONTEND=noninteractive
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     build-essential \
+     libsndfile1 \
+     ffmpeg \
+     python3 \
+     python3-pip \
+     python3-dev \
+     && apt-get clean \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Set up working directory
+ WORKDIR /app
+
+ # Install Python dependencies
+ COPY requirements.txt .
+ RUN pip3 install --no-cache-dir -r requirements.txt
+
+ # Copy application code
+ COPY . .
+
+ # Create required directories
+ RUN mkdir -p model_cache reference_audio outputs
+
+ # Expose the port the application will run on (default to 8003 as per config)
+ EXPOSE 8003
+
+ # Command to run the application
+ CMD ["python3", "server.py"]
README.md CHANGED
@@ -5,6 +5,7 @@ colorFrom: indigo
  colorTo: indigo
  sdk: docker
  pinned: false
+ app_port: 8003
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
config.py ADDED
@@ -0,0 +1,295 @@
+ # config.py
+ # Configuration management for Dia TTS server
+
+ import os
+ import logging
+ from dotenv import load_dotenv, find_dotenv, set_key
+ from typing import Dict, Any, Optional
+
+ # Configure logging
+ logger = logging.getLogger(__name__)
+
+ # Default configuration values (used if not found in .env or environment)
+ DEFAULT_CONFIG = {
+     # Server Settings
+     "HOST": "0.0.0.0",
+     "PORT": "8003",
+     # Model Source Settings
+     "DIA_MODEL_REPO_ID": "ttj/dia-1.6b-safetensors",  # Default to safetensors repo
+     "DIA_MODEL_CONFIG_FILENAME": "config.json",  # Standard config filename
+     "DIA_MODEL_WEIGHTS_FILENAME": "dia-v0_1_bf16.safetensors",  # Default to BF16 weights
+     # Path Settings
+     "DIA_MODEL_CACHE_PATH": "./model_cache",
+     "REFERENCE_AUDIO_PATH": "./reference_audio",
+     "OUTPUT_PATH": "./outputs",
+     # Default Generation Parameters (can be overridden by user in UI/API)
+     # These are saved to .env via the UI's "Save Generation Defaults" button
+     "GEN_DEFAULT_SPEED_FACTOR": "0.90",  # Default speed slightly slower
+     "GEN_DEFAULT_CFG_SCALE": "3.0",
+     "GEN_DEFAULT_TEMPERATURE": "1.3",
+     "GEN_DEFAULT_TOP_P": "0.95",
+     "GEN_DEFAULT_CFG_FILTER_TOP_K": "35",
+ }
+
+
+ class ConfigManager:
+     """Manages configuration for the TTS server with .env file support."""
+
+     def __init__(self):
+         """Initialize the configuration manager."""
+         self.config = {}
+         self.env_file = find_dotenv()
+
+         if not self.env_file:
+             self.env_file = os.path.join(os.getcwd(), ".env")
+             logger.info(
+                 f"No .env file found, creating one with defaults at {self.env_file}"
+             )
+             self._create_default_env_file()
+         else:
+             logger.info(f"Loading configuration from: {self.env_file}")
+
+         self.reload()
+
+     def _create_default_env_file(self):
+         """Create a default .env file with default values."""
+         try:
+             with open(self.env_file, "w") as f:
+                 for key, value in DEFAULT_CONFIG.items():
+                     f.write(f"{key}={value}\n")
+             logger.info("Created default .env file")
+         except Exception as e:
+             logger.error(f"Failed to create default .env file: {e}")
+
+     def reload(self):
+         """Reload configuration from .env file and environment variables."""
+         load_dotenv(self.env_file, override=True)
+         loaded_config = {}
+         for key, default_value in DEFAULT_CONFIG.items():
+             loaded_config[key] = os.environ.get(key, default_value)
+         self.config = loaded_config
+         logger.info("Configuration loaded/reloaded.")
+         logger.debug(f"Current config: {self.config}")
+         return self.config
+
+     def get(self, key: str, default: Any = None) -> Any:
+         """Get a configuration value by key."""
+         return self.config.get(key, default)
+
+     def set(self, key: str, value: Any) -> None:
+         """Set a configuration value in memory (does not save automatically)."""
+         self.config[key] = value
+         logger.debug(f"Configuration value set in memory: {key}={value}")
+
+     def save(self) -> bool:
+         """Save the current in-memory configuration to the .env file."""
+         if not self.env_file:
+             logger.error("Cannot save configuration, .env file path not set.")
+             return False
+         try:
+             for key in DEFAULT_CONFIG.keys():
+                 if key not in self.config:
+                     logger.warning(
+                         f"Key '{key}' missing from current config, adding default value before saving."
+                     )
+                     self.config[key] = DEFAULT_CONFIG[key]
+             for key, value in self.config.items():
+                 if key in DEFAULT_CONFIG:
+                     set_key(self.env_file, key, str(value))
+             logger.info(f"Configuration saved to {self.env_file}")
+             return True
+         except Exception as e:
+             logger.error(
+                 f"Failed to save configuration to {self.env_file}: {e}", exc_info=True
+             )
+             return False
+
+     def get_all(self) -> Dict[str, Any]:
+         """Get all current configuration values."""
+         return self.config.copy()
+
+     def update(self, new_config: Dict[str, Any]) -> None:
+         """Update multiple configuration values in memory from a dictionary."""
+         updated_keys = []
+         for key, value in new_config.items():
+             if key in DEFAULT_CONFIG:
+                 self.config[key] = value
+                 updated_keys.append(key)
+             else:
+                 logger.warning(
+                     f"Attempted to update unknown config key: {key}. Ignoring."
+                 )
+         if updated_keys:
+             logger.debug(
+                 f"Configuration values updated in memory for keys: {updated_keys}"
+             )
+
+     def get_int(self, key: str, default: Optional[int] = None) -> int:
+         """Get a configuration value as an integer, with error handling."""
+         value_str = self.get(key)  # Get value which might be from env (str) or default
+         if value_str is None:  # Key not found at all
+             if default is not None:
+                 logger.warning(
+                     f"Config key '{key}' not found, using provided default: {default}"
+                 )
+                 return default
+             else:
+                 logger.error(
+                     f"Mandatory config key '{key}' not found and no default provided. Returning 0."
+                 )
+                 return 0  # Or raise error
+
+         try:
+             return int(value_str)
+         except (ValueError, TypeError):
+             logger.warning(
+                 f"Invalid integer value '{value_str}' for config key '{key}', using default: {default}"
+             )
+             if isinstance(default, int):
+                 return default
+             elif default is None:
+                 logger.error(
+                     f"Cannot parse '{value_str}' as int for key '{key}' and no valid default. Returning 0."
+                 )
+                 return 0
+             else:  # Default was provided but not an int
+                 logger.error(
+                     f"Invalid default value type for key '{key}'. Cannot parse '{value_str}'. Returning 0."
+                 )
+                 return 0
+
+     def get_float(self, key: str, default: Optional[float] = None) -> float:
+         """Get a configuration value as a float, with error handling."""
+         value_str = self.get(key)
+         if value_str is None:
+             if default is not None:
+                 logger.warning(
+                     f"Config key '{key}' not found, using provided default: {default}"
+                 )
+                 return default
+             else:
+                 logger.error(
+                     f"Mandatory config key '{key}' not found and no default provided. Returning 0.0."
+                 )
+                 return 0.0
+
+         try:
+             return float(value_str)
+         except (ValueError, TypeError):
+             logger.warning(
+                 f"Invalid float value '{value_str}' for config key '{key}', using default: {default}"
+             )
+             if isinstance(default, float):
+                 return default
+             elif default is None:
+                 logger.error(
+                     f"Cannot parse '{value_str}' as float for key '{key}' and no valid default. Returning 0.0."
+                 )
+                 return 0.0
+             else:
+                 logger.error(
+                     f"Invalid default value type for key '{key}'. Cannot parse '{value_str}'. Returning 0.0."
+                 )
+                 return 0.0
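A short usage sketch for the typed getters above (hypothetical session, not part of the commit), showing the fallback behavior when a stored value cannot be parsed:

cm = ConfigManager()
cm.set("PORT", "not-a-number")   # simulate a malformed .env entry
cm.get_int("PORT", 8003)         # logs a warning, returns the fallback 8003
cm.get_int("PORT")               # no fallback given: logs an error, returns 0
cm.set("PORT", "8003")
cm.save()                        # writes only known DEFAULT_CONFIG keys back to .env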
+ # --- Create a singleton instance for global access ---
+ config_manager = ConfigManager()
+
+
+ # --- Export common getters for easy access ---
+
+
+ # Server Settings
+ def get_host() -> str:
+     """Gets the host address for the server."""
+     return config_manager.get("HOST", DEFAULT_CONFIG["HOST"])
+
+
+ def get_port() -> int:
+     """Gets the port number for the server."""
+     # Ensure default is parsed correctly if get_int fails on env var
+     return config_manager.get_int("PORT", int(DEFAULT_CONFIG["PORT"]))
+
+
+ # Model Source Settings
+ def get_model_repo_id() -> str:
+     """Gets the Hugging Face repository ID for the model."""
+     return config_manager.get("DIA_MODEL_REPO_ID", DEFAULT_CONFIG["DIA_MODEL_REPO_ID"])
+
+
+ def get_model_config_filename() -> str:
+     """Gets the filename for the model's configuration file within the repo."""
+     return config_manager.get(
+         "DIA_MODEL_CONFIG_FILENAME", DEFAULT_CONFIG["DIA_MODEL_CONFIG_FILENAME"]
+     )
+
+
+ def get_model_weights_filename() -> str:
+     """Gets the filename for the model's weights file within the repo."""
+     return config_manager.get(
+         "DIA_MODEL_WEIGHTS_FILENAME", DEFAULT_CONFIG["DIA_MODEL_WEIGHTS_FILENAME"]
+     )
+
+
+ # Path Settings
+ def get_model_cache_path() -> str:
+     """Gets the local directory path for caching downloaded models."""
+     return os.path.abspath(
+         config_manager.get(
+             "DIA_MODEL_CACHE_PATH", DEFAULT_CONFIG["DIA_MODEL_CACHE_PATH"]
+         )
+     )
+
+
+ def get_reference_audio_path() -> str:
+     """Gets the local directory path for storing reference audio files for cloning."""
+     return os.path.abspath(
+         config_manager.get(
+             "REFERENCE_AUDIO_PATH", DEFAULT_CONFIG["REFERENCE_AUDIO_PATH"]
+         )
+     )
+
+
+ def get_output_path() -> str:
+     """Gets the local directory path for saving generated audio outputs."""
+     return os.path.abspath(
+         config_manager.get("OUTPUT_PATH", DEFAULT_CONFIG["OUTPUT_PATH"])
+     )
+
+
+ # Default Generation Parameter Getters
+ def get_gen_default_speed_factor() -> float:
+     """Gets the default speed factor for generation."""
+     return config_manager.get_float(
+         "GEN_DEFAULT_SPEED_FACTOR", float(DEFAULT_CONFIG["GEN_DEFAULT_SPEED_FACTOR"])
+     )
+
+
+ def get_gen_default_cfg_scale() -> float:
+     """Gets the default CFG scale for generation."""
+     return config_manager.get_float(
+         "GEN_DEFAULT_CFG_SCALE", float(DEFAULT_CONFIG["GEN_DEFAULT_CFG_SCALE"])
+     )
+
+
+ def get_gen_default_temperature() -> float:
+     """Gets the default temperature for generation."""
+     return config_manager.get_float(
+         "GEN_DEFAULT_TEMPERATURE", float(DEFAULT_CONFIG["GEN_DEFAULT_TEMPERATURE"])
+     )
+
+
+ def get_gen_default_top_p() -> float:
+     """Gets the default top_p for generation."""
+     return config_manager.get_float(
+         "GEN_DEFAULT_TOP_P", float(DEFAULT_CONFIG["GEN_DEFAULT_TOP_P"])
+     )
+
+
+ def get_gen_default_cfg_filter_top_k() -> int:
+     """Gets the default CFG filter top_k for generation."""
+     return config_manager.get_int(
+         "GEN_DEFAULT_CFG_FILTER_TOP_K",
+         int(DEFAULT_CONFIG["GEN_DEFAULT_CFG_FILTER_TOP_K"]),
+     )
+ 
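For orientation, the rest of the server presumably consumes these module-level getters rather than the singleton directly; a sketch (the real call sites live in server.py and engine.py, whose contents are not displayed in this excerpt):

from config import get_host, get_port, get_model_repo_id, get_gen_default_temperature

print(f"Serving on {get_host()}:{get_port()}")   # '0.0.0.0:8003' with the defaults
print(get_model_repo_id())                       # 'ttj/dia-1.6b-safetensors'
print(get_gen_default_temperature())             # 1.3, parsed from the string default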
dia/__init__.py ADDED
File without changes
dia/audio.py ADDED
@@ -0,0 +1,280 @@
+ import typing as tp
+
+ import torch
+
+ from .config import DataConfig
+
+
+ def build_delay_indices(B: int, T: int, C: int, delay_pattern: tp.List[int]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+     """
+     Precompute (t_idx_BxTxC, indices_BTCx3) so that out[t, c] = in[t - delay[c], c].
+     Negative t_idx => BOS; t_idx >= T => PAD.
+     """
+     delay_arr = torch.tensor(delay_pattern, dtype=torch.int32)
+
+     t_idx_BxT = torch.broadcast_to(
+         torch.arange(T, dtype=torch.int32)[None, :],
+         [B, T],
+     )
+     t_idx_BxTx1 = t_idx_BxT[..., None]
+     t_idx_BxTxC = t_idx_BxTx1 - delay_arr.view(1, 1, C)
+
+     b_idx_BxTxC = torch.broadcast_to(
+         torch.arange(B, dtype=torch.int32).view(B, 1, 1),
+         [B, T, C],
+     )
+     c_idx_BxTxC = torch.broadcast_to(
+         torch.arange(C, dtype=torch.int32).view(1, 1, C),
+         [B, T, C],
+     )
+
+     # We must clamp time indices to [0..T-1] so gather_nd equivalent won't fail
+     t_clamped_BxTxC = torch.clamp(t_idx_BxTxC, 0, T - 1)
+
+     indices_BTCx3 = torch.stack(
+         [
+             b_idx_BxTxC.reshape(-1),
+             t_clamped_BxTxC.reshape(-1),
+             c_idx_BxTxC.reshape(-1),
+         ],
+         dim=1,
+     ).long()  # Ensure indices are long type for indexing
+
+     return t_idx_BxTxC, indices_BTCx3
+
+
+ def apply_audio_delay(
+     audio_BxTxC: torch.Tensor,
+     pad_value: int,
+     bos_value: int,
+     precomp: tp.Tuple[torch.Tensor, torch.Tensor],
+ ) -> torch.Tensor:
+     """
+     Applies the delay pattern to batched audio tokens using precomputed indices,
+     inserting BOS where t_idx < 0 and PAD where t_idx >= T.
+
+     Args:
+         audio_BxTxC: [B, T, C] int16 audio tokens (or int32/float)
+         pad_value: the padding token
+         bos_value: the BOS token
+         precomp: (t_idx_BxTxC, indices_BTCx3) from build_delay_indices
+
+     Returns:
+         result_BxTxC: [B, T, C] delayed audio tokens
+     """
+     device = audio_BxTxC.device  # Get device from input tensor
+     t_idx_BxTxC, indices_BTCx3 = precomp
+     t_idx_BxTxC = t_idx_BxTxC.to(device)  # Move precomputed indices to device
+     indices_BTCx3 = indices_BTCx3.to(device)
+
+     # Equivalent of tf.gather_nd using advanced indexing
+     # Ensure indices are long type if not already (build_delay_indices should handle this)
+     gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
+     gathered_BxTxC = gathered_flat.view(audio_BxTxC.shape)
+
+     # Create masks on the correct device
+     mask_bos = t_idx_BxTxC < 0  # => place bos_value
+     mask_pad = t_idx_BxTxC >= audio_BxTxC.shape[1]  # => place pad_value
+
+     # Create scalar tensors on the correct device
+     bos_tensor = torch.tensor(bos_value, dtype=audio_BxTxC.dtype, device=device)
+     pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)
+
+     # If mask_bos, BOS; else if mask_pad, PAD; else original gather
+     # All tensors should now be on the same device
+     result_BxTxC = torch.where(mask_bos, bos_tensor, torch.where(mask_pad, pad_tensor, gathered_BxTxC))
+
+     return result_BxTxC
+
+
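To make the delay pattern concrete, a small sketch on toy data (assuming the two functions above are imported from dia.audio; the real BOS/PAD values 1026/1025 come from DataConfig):

import torch
from dia.audio import build_delay_indices, apply_audio_delay

codes = torch.arange(6).view(1, 3, 2)   # B=1, T=3, C=2
precomp = build_delay_indices(B=1, T=3, C=2, delay_pattern=[0, 1])
delayed = apply_audio_delay(codes, pad_value=1025, bos_value=1026, precomp=precomp)
# Channel 0 (delay 0) is unchanged: [0, 2, 4].
# Channel 1 (delay 1) becomes [1026, 1, 3]: BOS first, then its tokens one step late.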
+ @torch.no_grad()
+ @torch.inference_mode()
+ def audio_to_codebook(
+     model,
+     input_values,
+     data_config: DataConfig,
+     padding_mask=None,
+     sample_rate=44100,
+ ):
+     """
+     Encodes the input audio waveform into discrete codes.
+
+     Args:
+         model: The model to use for encoding.
+         input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
+             Float values of the input audio waveform.
+         padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
+             Padding mask used to pad the `input_values`.
+         sample_rate (`int`, *optional*):
+             Signal sampling rate.
+
+     Returns:
+         A list of frames containing the discrete encoded codes for the input audio waveform, along with rescaling
+         factors for each chunk when `normalize` is True. Each frame is a tuple `(codebook, scale)`, with
+         `codebook` of shape `[batch_size, num_codebooks, frames]`.
+         Scale is not used here.
+     """
+     audio_data = model.preprocess(input_values, sample_rate)
+
+     if padding_mask is None:
+         padding_mask = torch.ones_like(input_values).bool()
+
+     _, encoded_frame, _, _, _ = model.encode(audio_data, n_quantizers=None)  # 1, C, T
+     seq_length = encoded_frame.shape[2]
+
+     t_idx_BxTxC, indices_BTCx3 = build_delay_indices(
+         B=1,
+         T=seq_length,
+         C=data_config.channels,
+         delay_pattern=data_config.delay_pattern,
+     )
+
+     encoded_frame = apply_audio_delay(
+         audio_BxTxC=encoded_frame.transpose(1, 2),  # 1, T, C
+         pad_value=data_config.audio_pad_value,
+         bos_value=data_config.audio_bos_value,
+         precomp=(t_idx_BxTxC, indices_BTCx3),
+     )
+
+     return encoded_frame
+
+
+ def build_revert_indices(B: int, T: int, C: int, delay_pattern: tp.List[int]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+     """
+     Precompute indices for the revert operation using PyTorch.
+
+     Returns:
+         A tuple (t_idx_BxTxC, indices_BTCx3) where:
+             - t_idx_BxTxC is a tensor of shape [B, T, C] computed as time indices plus the delay.
+             - indices_BTCx3 is a tensor of shape [B*T*C, 3] used for gathering, computed from:
+               batch indices, clamped time indices, and channel indices.
+     """
+     # Use default device unless specified otherwise; assumes inputs might define device later
+     device = None  # Or determine dynamically if needed, e.g., from a model parameter
+
+     delay_arr = torch.tensor(delay_pattern, dtype=torch.int32, device=device)
+
+     t_idx_BT1 = torch.broadcast_to(torch.arange(T, device=device).unsqueeze(0), [B, T])
+     t_idx_BT1 = t_idx_BT1.unsqueeze(-1)
+
+     t_idx_BxTxC = torch.minimum(
+         t_idx_BT1 + delay_arr.view(1, 1, C),
+         torch.tensor(T - 1, device=device),
+     )
+     b_idx_BxTxC = torch.broadcast_to(torch.arange(B, device=device).view(B, 1, 1), [B, T, C])
+     c_idx_BxTxC = torch.broadcast_to(torch.arange(C, device=device).view(1, 1, C), [B, T, C])
+
+     indices_BTCx3 = torch.stack(
+         [
+             b_idx_BxTxC.reshape(-1),
+             t_idx_BxTxC.reshape(-1),
+             c_idx_BxTxC.reshape(-1),
+         ],
+         dim=1,
+     ).long()  # Ensure indices are long type
+
+     return t_idx_BxTxC, indices_BTCx3
+
+
+ def revert_audio_delay(
+     audio_BxTxC: torch.Tensor,
+     pad_value: int,
+     precomp: tp.Tuple[torch.Tensor, torch.Tensor],
+     T: int,
+ ) -> torch.Tensor:
+     """
+     Reverts a delay pattern from batched audio tokens using precomputed indices (PyTorch version).
+
+     Args:
+         audio_BxTxC: Input delayed audio tensor
+         pad_value: Padding value for out-of-bounds indices
+         precomp: Precomputed revert indices tuple containing:
+             - t_idx_BxTxC: Time offset indices tensor
+             - indices_BTCx3: Gather indices tensor for original audio
+         T: Original sequence length before padding
+
+     Returns:
+         Reverted audio tensor with same shape as input
+     """
+     t_idx_BxTxC, indices_BTCx3 = precomp
+     device = audio_BxTxC.device  # Get device from input tensor
+
+     # Move precomputed indices to the same device as audio_BxTxC if they aren't already
+     t_idx_BxTxC = t_idx_BxTxC.to(device)
+     indices_BTCx3 = indices_BTCx3.to(device)
+
+     # Using PyTorch advanced indexing (equivalent to tf.gather_nd or np equivalent)
+     gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
+     gathered_BxTxC = gathered_flat.view(audio_BxTxC.size())  # Use .size() for robust reshaping
+
+     # Create pad_tensor on the correct device
+     pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)
+     # Create T tensor on the correct device for comparison
+     T_tensor = torch.tensor(T, device=device)
+
+     result_BxTxC = torch.where(t_idx_BxTxC >= T_tensor, pad_tensor, gathered_BxTxC)
+
+     return result_BxTxC
+
+
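A sanity-check sketch of the round trip (again assuming imports from dia.audio): applying a delay and then reverting it restores the original tokens wherever the shifted index stays in range:

import torch
from dia.audio import (
    build_delay_indices, apply_audio_delay,
    build_revert_indices, revert_audio_delay,
)

B, T, C = 1, 8, 2
codes = torch.randint(0, 1024, (B, T, C))
delayed = apply_audio_delay(codes, 1025, 1026, build_delay_indices(B, T, C, [0, 1]))
restored = revert_audio_delay(delayed, 1025, build_revert_indices(B, T, C, [0, 1]), T=T)
# restored[:, t, c] == codes[:, t, c] for t <= T - 1 - delay[c]; the final delay[c]
# steps repeat the last valid frame because the revert indices are clamped to T - 1.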
+ @torch.no_grad()
+ @torch.inference_mode()
+ def decode(
+     model,
+     audio_codes,
+ ):
+     """
+     Decodes the given frames into an output audio waveform.
+     """
+     if len(audio_codes) != 1:
+         raise ValueError(f"Expected one frame, got {len(audio_codes)}")
+
+     try:
+         audio_values = model.quantizer.from_codes(audio_codes)
+         audio_values = model.decode(audio_values[0])
+
+         return audio_values
+     except Exception as e:
+         print(f"Error in decode method: {str(e)}")
+         raise
+
+
+ def codebook_to_audio(generated_codes: torch.Tensor, model, delay_pattern, B=1, T=2600, C=9):
+     """Process a single codebook file to generate audio."""
+     # Remove BOS token
+     generated_codes = generated_codes[:, 1:]
+
+     if generated_codes.shape[1] > T:
+         generated_codes = generated_codes[:, :T]
+
+     seq_length = generated_codes.shape[1]
+
+     # Build revert indices
+     t_idx_BxTxC, indices_BTCx3 = build_revert_indices(B=B, T=seq_length, C=C, delay_pattern=delay_pattern)
+
+     # Transpose and add batch dimension
+     audio_BxTxC = generated_codes.transpose(1, 0).unsqueeze(0)
+     reverted_codebook = revert_audio_delay(
+         audio_BxTxC=audio_BxTxC,
+         pad_value=0,
+         precomp=(t_idx_BxTxC, indices_BTCx3),
+         T=seq_length,
+     )
+     reverted_codebook = reverted_codebook[:, :-30, :]
+
+     codebook = reverted_codebook.transpose(1, 2)
+
+     min_valid_index = 0
+     max_valid_index = 1023
+     invalid_mask = (codebook < min_valid_index) | (codebook > max_valid_index)
+
+     num_invalid = torch.sum(invalid_mask).item()
+     if num_invalid > 0:
+         print(f"Warning: Clamping {num_invalid} indices outside range [{min_valid_index}, {max_valid_index}] to 0.")
+
+     # Set invalid values to 0 (modify the tensor in-place)
+     codebook[invalid_mask] = 0
+     audio_array = decode(model, codebook)
+
+     return audio_array
dia/config.py ADDED
@@ -0,0 +1,206 @@
+ """Configuration management module for the Dia model.
+
+ This module provides comprehensive configuration management for the Dia model,
+ utilizing Pydantic for validation. It defines configurations for data processing,
+ model architecture (encoder and decoder), and training settings.
+
+ Key components:
+ - DataConfig: Parameters for data loading and preprocessing.
+ - EncoderConfig: Architecture details for the encoder module.
+ - DecoderConfig: Architecture details for the decoder module.
+ - ModelConfig: Combined model architecture settings.
+ - TrainingConfig: Training hyperparameters and settings.
+ - DiaConfig: Master configuration combining all components.
+ """
+
+ import os
+ from typing import Annotated
+
+ from pydantic import BaseModel, BeforeValidator, Field
+
+
+ class DataConfig(BaseModel, frozen=True):
+     """Configuration for data loading and preprocessing.
+
+     Attributes:
+         text_length: Maximum length of text sequences (must be multiple of 128).
+         audio_length: Maximum length of audio sequences (must be multiple of 128).
+         channels: Number of audio channels.
+         text_pad_value: Value used for padding text sequences.
+         audio_eos_value: Value representing the end of audio sequences.
+         audio_bos_value: Value representing the beginning of audio sequences.
+         audio_pad_value: Value used for padding audio sequences.
+         delay_pattern: List of delay values for each audio channel.
+     """
+
+     text_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = Field(gt=0, multiple_of=128)
+     audio_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = Field(gt=0, multiple_of=128)
+     channels: int = Field(default=9, gt=0, multiple_of=1)
+     text_pad_value: int = Field(default=0)
+     audio_eos_value: int = Field(default=1024)
+     audio_pad_value: int = Field(default=1025)
+     audio_bos_value: int = Field(default=1026)
+     delay_pattern: list[Annotated[int, Field(ge=0)]] = Field(default_factory=lambda: [0, 8, 9, 10, 11, 12, 13, 14, 15])
+
+     def __hash__(self) -> int:
+         """Generate a hash based on all fields of the config."""
+         return hash(
+             (
+                 self.text_length,
+                 self.audio_length,
+                 self.channels,
+                 self.text_pad_value,
+                 self.audio_pad_value,
+                 self.audio_bos_value,
+                 self.audio_eos_value,
+                 tuple(self.delay_pattern),
+             )
+         )
+
+
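One non-obvious detail above: the BeforeValidator rounds lengths up to the next multiple of 128 before the multiple_of=128 constraint runs, so arbitrary lengths validate instead of failing. A quick illustration:

cfg = DataConfig(text_length=1000, audio_length=3000)
assert cfg.text_length == 1024    # (1000 + 127) // 128 * 128
assert cfg.audio_length == 3072   # (3000 + 127) // 128 * 128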
+ class EncoderConfig(BaseModel, frozen=True):
+     """Configuration for the encoder component of the Dia model.
+
+     Attributes:
+         n_layer: Number of transformer layers.
+         n_embd: Embedding dimension.
+         n_hidden: Hidden dimension size in the MLP layers.
+         n_head: Number of attention heads.
+         head_dim: Dimension per attention head.
+         mlp_activations: List of activation functions for the MLP layers.
+         use_pre_norm: Whether to use pre-normalization (LayerNorm before attention/MLP).
+     """
+
+     n_layer: int = Field(gt=0)
+     n_embd: int = Field(gt=0)
+     n_hidden: int = Field(gt=0)
+     n_head: int = Field(gt=0)
+     head_dim: int = Field(gt=0)
+     mlp_activations: list[str] = Field(default=["silu", "linear"])
+     use_pre_norm: bool = Field(default=False)
+
+
+ class DecoderConfig(BaseModel, frozen=True):
+     """Configuration for the decoder component of the Dia model.
+
+     Attributes:
+         n_layer: Number of transformer layers.
+         n_embd: Embedding dimension.
+         n_hidden: Hidden dimension size in the MLP layers.
+         gqa_query_heads: Number of query heads for grouped-query self-attention.
+         kv_heads: Number of key/value heads for grouped-query self-attention.
+         gqa_head_dim: Dimension per query head for grouped-query self-attention.
+         cross_query_heads: Number of query heads for cross-attention.
+         cross_head_dim: Dimension per cross-attention head.
+         mlp_activations: List of activation functions for the MLP layers.
+         use_pre_norm: Whether to use pre-normalization.
+     """
+
+     n_layer: int = Field(gt=0)
+     n_embd: int = Field(gt=0)
+     n_hidden: int = Field(gt=0)
+     gqa_query_heads: int = Field(gt=0)
+     kv_heads: int = Field(gt=0)
+     gqa_head_dim: int = Field(gt=0)
+     cross_query_heads: int = Field(gt=0)
+     cross_head_dim: int = Field(gt=0)
+     mlp_activations: list[str] = Field(default=["silu", "linear"])
+     use_pre_norm: bool = Field(default=False)
+
+
+ class ModelConfig(BaseModel, frozen=True):
+     """Main configuration container for the Dia model architecture.
+
+     Attributes:
+         encoder: Configuration for the encoder component.
+         decoder: Configuration for the decoder component.
+         src_vocab_size: Size of the source (text) vocabulary.
+         tgt_vocab_size: Size of the target (audio code) vocabulary.
+         dropout: Dropout probability applied within the model.
+         normalization_layer_epsilon: Epsilon value for normalization layers (e.g., LayerNorm).
+         weight_dtype: Data type for model weights (e.g., "float32", "bfloat16").
+         rope_min_timescale: Minimum timescale for Rotary Positional Embeddings (RoPE).
+         rope_max_timescale: Maximum timescale for Rotary Positional Embeddings (RoPE).
+     """
+
+     encoder: EncoderConfig
+     decoder: DecoderConfig
+     src_vocab_size: int = Field(default=128, gt=0)
+     tgt_vocab_size: int = Field(default=1028, gt=0)
+     dropout: float = Field(default=0.0, ge=0.0, lt=1.0)
+     normalization_layer_epsilon: float = Field(default=1.0e-5, ge=0.0)
+     weight_dtype: str = Field(default="float32", description="Weight precision")
+     rope_min_timescale: int = Field(default=1, description="Timescale for global attention")
+     rope_max_timescale: int = Field(default=10_000, description="Timescale for global attention")
+
+
+ class TrainingConfig(BaseModel, frozen=True):
+     """Training process configuration and hyperparameters.
+
+     Note: This configuration currently only includes precision settings.
+     Other training parameters (like batch size, learning rate, optimizer settings)
+     are assumed to be handled externally.
+
+     Attributes:
+         dtype: Data type for activations during training (e.g., "bfloat16", "float32").
+         logits_dot_in_fp32: Whether to compute the final logits dot product in fp32 for stability.
+     """
+
+     dtype: str = Field(default="bfloat16", description="Activation precision")
+     logits_dot_in_fp32: bool = Field(default=False)
+
+
+ class DiaConfig(BaseModel, frozen=True):
+     """Master configuration for the Dia model.
+
+     Combines all sub-configurations into a single validated object.
+
+     Attributes:
+         version: Configuration version string.
+         model: Model architecture configuration.
+         training: Training process configuration (precision settings).
+         data: Data loading and processing configuration.
+     """
+
+     version: str = Field(default="1.0")
+     model: ModelConfig
+     training: TrainingConfig
+     data: DataConfig
+
+     def save(self, path: str) -> None:
+         """Save the current configuration instance to a JSON file.
+
+         Ensures the parent directory exists and the file has a .json extension.
+
+         Args:
+             path: The target file path to save the configuration.
+
+         Raises:
+             ValueError: If the path is not a file with a .json extension.
+         """
+         os.makedirs(os.path.dirname(path), exist_ok=True)
+         config_json = self.model_dump_json(indent=2)
+         with open(path, "w") as f:
+             f.write(config_json)
+
+     @classmethod
+     def load(cls, path: str) -> "DiaConfig | None":
+         """Load and validate a Dia configuration from a JSON file.
+
+         Args:
+             path: The path to the configuration file.
+
+         Returns:
+             A validated DiaConfig instance if the file exists and is valid,
+             otherwise None if the file is not found.
+
+         Raises:
+             ValueError: If the path does not point to an existing .json file.
+             pydantic.ValidationError: If the JSON content fails validation against the DiaConfig schema.
+         """
+         try:
+             with open(path, "r") as f:
+                 content = f.read()
+             return cls.model_validate_json(content)
+         except FileNotFoundError:
+             return None
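A save/load round-trip sketch (the hyperparameter values below are made up purely for illustration; the real values ship in the repo's config.json):

cfg = DiaConfig(
    model=ModelConfig(
        encoder=EncoderConfig(n_layer=2, n_embd=64, n_hidden=128, n_head=4, head_dim=16),
        decoder=DecoderConfig(
            n_layer=2, n_embd=64, n_hidden=128,
            gqa_query_heads=4, kv_heads=2, gqa_head_dim=16,
            cross_query_heads=4, cross_head_dim=16,
        ),
    ),
    training=TrainingConfig(),
    data=DataConfig(text_length=128, audio_length=256),
)
cfg.save("tmp/dia_config.json")                      # hypothetical path
assert DiaConfig.load("tmp/dia_config.json") == cfg  # load() returns None if the file is missing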
dia/layers.py ADDED
@@ -0,0 +1,903 @@
+ from typing import Any
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch import Tensor
+ from torch.nn import RMSNorm
+
+ from .config import DiaConfig
+
+
+ def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]:
+     return tuple(ax if ax >= 0 else ndim + ax for ax in axes)
+
+
+ def _str_to_dtype(dtype_str: str) -> torch.dtype | None:
+     # Allow None for default behavior
+     if dtype_str is None or dtype_str.lower() == "none":
+         return None
+     if dtype_str == "float32":
+         return torch.float32
+     elif dtype_str == "float16":
+         return torch.float16
+     elif dtype_str == "bfloat16":
+         return torch.bfloat16
+     else:
+         raise ValueError(f"Unsupported dtype string: {dtype_str}")
+
+
+ class DenseGeneral(nn.Module):
+     """
+     PyTorch equivalent of flax.linen.DenseGeneral with shapes defined at init.
+
+     Stores weights (`kernel`) in the same layout as Jax and uses torch.tensordot
+     for the generalized matrix multiplication. Weight/bias shapes are calculated
+     and parameters created during initialization based on config.
+     `load_weights` validates shapes and copies data.
+
+     Attributes:
+         axis (Tuple[int, ...]): Input axis or axes to contract.
+         in_shapes (Tuple[int, ...]): Sizes of the input dimensions specified by `axis`.
+         out_features (Tuple[int, ...]): Shape of the output features (non-contracted dims).
+         use_bias (bool): Whether to add a bias term.
+         weight (nn.Parameter): The kernel parameter.
+         bias (Optional[nn.Parameter]): The bias parameter (if use_bias=True).
+     """
+
+     def __init__(
+         self,
+         in_shapes: tuple[int, ...],
+         out_features: tuple[int, ...],
+         axis: tuple[int, ...] = (-1,),
+         dtype: torch.dtype | None = None,
+         weight_dtype: torch.dtype | None = None,
+         device: torch.device | None = None,
+     ):
+         super().__init__()
+         self.in_shapes = in_shapes
+         self.out_features = out_features
+         self.axis = axis
+         self.dtype = dtype
+         self.kernel_shape = self.in_shapes + self.out_features
+
+         factory_kwargs = {"device": device, "dtype": weight_dtype}
+         self.weight = nn.Parameter(torch.empty(self.kernel_shape, **factory_kwargs))
+         self.register_parameter("bias", None)
+
+     def forward(self, inputs: Tensor) -> Tensor:
+         norm_axis = _normalize_axes(self.axis, inputs.ndim)
+         kernel_contract_axes = tuple(range(len(norm_axis)))
+
+         output = torch.tensordot(
+             inputs.float(),
+             self.weight.float(),
+             dims=(norm_axis, kernel_contract_axes),
+         ).to(inputs.dtype)
+         return output
+
+
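For intuition: with axis=(-1,) and a single output dimension, DenseGeneral reduces to a bias-free linear layer with a JAX-style (in, out) kernel layout; the tensordot form only earns its keep when several axes are contracted or produced at once, as in the attention projections further below. A minimal check:

import torch

dg = DenseGeneral(in_shapes=(16,), out_features=(32,), axis=(-1,))
torch.nn.init.normal_(dg.weight)   # the module leaves weights uninitialized (torch.empty)
x = torch.randn(4, 16)
assert torch.allclose(dg(x), x @ dg.weight, atol=1e-5)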
+ def get_activation_fn(activation_string: str) -> nn.Module:  # Return Module instance
+     """Maps activation string to PyTorch activation function module."""
+     if activation_string == "gelu":
+         return nn.GELU()
+     elif activation_string == "relu":
+         return nn.ReLU()
+     elif activation_string == "silu" or activation_string == "swish":
+         return nn.SiLU()
+     elif activation_string == "linear":
+         return nn.Identity()
+     else:
+         raise ValueError(f"Unsupported activation function: {activation_string}")
+
+
+ class MlpBlock(nn.Module):
+     """MLP block using DenseGeneral."""
+
+     def __init__(
+         self,
+         config: DiaConfig,
+         embed_dim: int,
+         intermediate_dim: int,
+         dropout_rate: float,
+         activations: list[str] = ["silu", "linear"],
+         use_pre_norm: bool = False,
+     ):
+         super().__init__()
+         self.use_pre_norm = use_pre_norm
+         num_activations = len(activations)
+         compute_dtype = _str_to_dtype(config.training.dtype)
+         weight_dtype = _str_to_dtype(config.model.weight_dtype)
+         self.dtype = compute_dtype
+         # Assume default device for now, could be passed in config
+
+         if use_pre_norm:
+             self.pre_norm = RMSNorm(
+                 embed_dim,
+                 eps=config.model.normalization_layer_epsilon,
+                 dtype=torch.float32,
+             )
+
+         self.wi_fused = DenseGeneral(
+             in_shapes=(embed_dim,),
+             out_features=(
+                 num_activations,
+                 intermediate_dim,
+             ),
+             axis=(-1,),
+             dtype=compute_dtype,
+             weight_dtype=weight_dtype,
+         )
+
+         self.activation_fn_0 = get_activation_fn(activations[0])  # silu
+         self.activation_fn_1 = get_activation_fn(activations[1])  # linear
+
+         self.dropout = nn.Dropout(dropout_rate)
+
+         # Output layer using DenseGeneral
+         self.wo = DenseGeneral(
+             in_shapes=(intermediate_dim,),
+             out_features=(embed_dim,),
+             axis=(-1,),
+             dtype=compute_dtype,
+             weight_dtype=weight_dtype,
+         )
+
+     def forward(self, x: torch.Tensor, deterministic: bool) -> torch.Tensor:
+         """Forward pass."""
+         if self.use_pre_norm and hasattr(self, "pre_norm"):
+             x = self.pre_norm(x)
+
+         fused_x = self.wi_fused(x)
+
+         gate_input = fused_x[..., 0, :]
+         up_input = fused_x[..., 1, :]
+
+         gate = self.activation_fn_0(gate_input)
+         up = self.activation_fn_1(up_input)
+         hidden = torch.mul(gate, up).to(self.dtype)
+
+         if not deterministic:
+             hidden = self.dropout(hidden)
+
+         output = self.wo(hidden)
+         return output
+
+
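The block above is a gated MLP (the default silu/linear pair makes it SwiGLU-style): one fused projection produces both branches, which are split, activated, and multiplied. A functional restatement, with w_gate/w_up standing in for the two slices of wi_fused:

import torch
import torch.nn.functional as F

def gated_mlp_reference(x, w_gate, w_up, w_out):
    # Equivalent to MlpBlock.forward with activations=["silu", "linear"]
    return (F.silu(x @ w_gate) * (x @ w_up)) @ w_out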
+ class RotaryEmbedding(nn.Module):
+     """Rotary Position Embedding (RoPE) implementation in PyTorch."""
+
+     def __init__(
+         self,
+         embedding_dims: int,
+         min_timescale: int = 1,
+         max_timescale: int = 10000,
+         dtype: torch.dtype = torch.float32,
+     ):
+         super().__init__()
+         if embedding_dims % 2 != 0:
+             raise ValueError("Embedding dim must be even for RoPE.")
+         self.embedding_dims = embedding_dims
+         self.min_timescale = min_timescale
+         self.max_timescale = max_timescale
+         self.dtype = dtype
+
+         half_embedding_dim = embedding_dims // 2
+         fraction = (2.0 * torch.arange(0, half_embedding_dim)) / embedding_dims
+         self.register_buffer(
+             "timescale",
+             self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction,
+             persistent=False,
+         )
+
+     def extra_repr(self) -> str:
+         s = f"{self.timescale.shape}"
+         return s
+
+     def forward(self, inputs: torch.Tensor, position: torch.Tensor):
+         """Applies RoPE."""
+         position = position.unsqueeze(-1).unsqueeze(-1)
+         timescale = self.timescale.to(inputs.device)
+         sinusoid_inp = position / timescale
+         sin = torch.sin(sinusoid_inp).to(inputs.dtype)
+         cos = torch.cos(sinusoid_inp).to(inputs.dtype)
+         first_half, second_half = torch.chunk(inputs, 2, dim=-1)
+         first_part = first_half * cos - second_half * sin
+         second_part = second_half * cos + first_half * sin
+         return torch.cat((first_part, second_part), dim=-1)
+
+
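Two properties worth sanity-checking in the implementation above: position 0 is the identity (sin 0 = 0, cos 0 = 1), and each (first_half, second_half) pair is rotated, so norms are preserved. A quick check:

import torch

rope = RotaryEmbedding(embedding_dims=8)
x = torch.randn(1, 4, 2, 8)              # (B, T, heads, head_dim)
pos = torch.arange(4).unsqueeze(0)       # (B, T)
y = rope(x, position=pos)
assert torch.allclose(y[:, 0], x[:, 0])                            # identity at position 0
assert torch.allclose(y.norm(dim=-1), x.norm(dim=-1), atol=1e-5)   # norms preserved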
+ class KVCache:
+     def __init__(self, num_heads, max_len, head_dim, device, k=None, v=None):
+         self.k = (
+             torch.zeros((2, num_heads, max_len, head_dim), device=device)
+             if k is None
+             else k
+         )
+         self.v = (
+             torch.zeros((2, num_heads, max_len, head_dim), device=device)
+             if v is None
+             else v
+         )
+         self.current_idx = 0
+         self.max_len = max_len
+
+     def get_kv_for_attention(self, current_k, current_v):
+         if self.current_idx == 0:
+             return current_k, current_v
+         else:
+             past_k = self.k[:, :, : self.current_idx, :]
+             past_v = self.v[:, :, : self.current_idx, :]
+             attn_k = torch.cat((past_k, current_k), dim=2)
+             attn_v = torch.cat((past_v, current_v), dim=2)
+             return attn_k, attn_v
+
+     def update_cache(self, k, v):
+         assert self.current_idx < self.max_len
+         self.k[:, :, self.current_idx : self.current_idx + 1, :] = k
+         self.v[:, :, self.current_idx : self.current_idx + 1, :] = v
+         self.current_idx += 1
+
+     def prefill_kv(self, k, v):
+         prefill_len = k.shape[2]
+         assert prefill_len <= self.max_len
+         self.k[:, :, :prefill_len, :] = k
+         self.v[:, :, :prefill_len, :] = v
+         self.current_idx = prefill_len
+
+
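Usage sketch for the two-phase cache above: prefill writes a whole prompt's K/V at once, then each decode step appends one position. The hard-coded leading dimension of 2 appears to correspond to the batched pair of sequences used at inference (see the (2, 1, D) comments in DecoderLayer below):

import torch

cache = KVCache(num_heads=4, max_len=16, head_dim=8, device="cpu")
k = torch.randn(2, 4, 5, 8)
cache.prefill_kv(k, k.clone())            # write a 5-step prompt in one shot
assert cache.current_idx == 5

k1 = torch.randn(2, 4, 1, 8)              # one new decode step
attn_k, attn_v = cache.get_kv_for_attention(k1, k1.clone())
assert attn_k.shape[2] == 6               # 5 cached steps + the current one
cache.update_cache(k1, k1.clone())
assert cache.current_idx == 6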
+ class Attention(nn.Module):
+     """Attention using DenseGeneral."""
+
+     def __init__(
+         self,
+         config: DiaConfig,
+         q_embed_dim: int,
+         kv_embed_dim: int,
+         num_query_heads: int,
+         num_kv_heads: int,
+         head_dim: int,
+         dropout_rate: float,
+         is_cross_attn: bool = False,
+         out_embed_dim: int | None = None,
+     ):
+         super().__init__()
+         self.num_query_heads = num_query_heads
+         self.num_kv_heads = num_kv_heads
+         self.head_dim = head_dim
+         self.is_cross_attn = is_cross_attn
+         self.dropout_rate = dropout_rate
+         compute_dtype = _str_to_dtype(config.training.dtype)
+         weight_dtype = _str_to_dtype(config.model.weight_dtype)
+         self.output_dim = out_embed_dim if out_embed_dim is not None else q_embed_dim
+         self.projected_query_dim = num_query_heads * head_dim
+         if num_query_heads % num_kv_heads != 0:
+             raise ValueError(
+                 f"num_query_heads ({num_query_heads}) must be divisible by num_kv_heads ({num_kv_heads})"
+             )
+         self.num_gqa_groups = num_query_heads // num_kv_heads
+
+         # --- Projection Layers using DenseGeneral ---
+         self.q_proj = DenseGeneral(
+             in_shapes=(q_embed_dim,),
+             out_features=(num_query_heads, head_dim),
+             axis=(-1,),
+             dtype=compute_dtype,
+             weight_dtype=weight_dtype,
+         )
+         self.k_proj = DenseGeneral(
+             in_shapes=(kv_embed_dim,),
+             out_features=(num_kv_heads, head_dim),
+             axis=(-1,),
+             dtype=compute_dtype,
+             weight_dtype=weight_dtype,
+         )
+         self.v_proj = DenseGeneral(
+             in_shapes=(kv_embed_dim,),
+             out_features=(num_kv_heads, head_dim),
+             axis=(-1,),
+             dtype=compute_dtype,
+             weight_dtype=weight_dtype,
+         )
+         self.o_proj = DenseGeneral(
+             in_shapes=(num_query_heads, head_dim),
+             out_features=(self.output_dim,),
+             axis=(-2, -1),
+             dtype=compute_dtype,
+             weight_dtype=weight_dtype,
+         )
+
+         # --- Rotary Embedding ---
+         self.rotary_emb = RotaryEmbedding(
+             embedding_dims=self.head_dim,
+             min_timescale=config.model.rope_min_timescale,
+             max_timescale=config.model.rope_max_timescale,
+             dtype=compute_dtype,
+         )
+
+     def forward(
+         self,
+         Xq: torch.Tensor,  # (B, T, D) T = 1 in AR generation
+         Xkv: torch.Tensor,  # (B, S, E) S = 1 in AR generation
+         q_positions: torch.Tensor,  # (B, T)
+         kv_positions: torch.Tensor | None = None,  # (B, S)
+         deterministic: bool = True,
+         attn_mask: (
+             torch.Tensor | None
+         ) = None,  # None in Decoder Self Attention, Valid mask in Others
+         cache: KVCache | None = None,  # None in Encoder, KVCache in Decoder
+         prefill: bool = False,  # True only when prefilling KV Cache
+     ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
+         """
+         Performs attention calculation with optional KV caching.
+
+         Args:
+             Xq: Query tensor (B, T, D). T=1 during single-step decoding.
+             Xkv: Key/Value source tensor (B, S, E). S=1 during single-step decoding for self-attn.
+             q_positions: Positions for queries (B, T).
+             kv_positions: Positions for keys/values (B, S). If None, uses q_positions.
+             deterministic: If True, disable dropout.
+             attn_mask: Attention mask.
+             cache: KVCache.
+             prefill: If True, use prefill mode.
+
+         Returns:
+             A tuple containing:
+             - output: The attention output tensor (B, T, output_dim).
+             - present_kv: The K/V state to be cached for the next step ((B, N, S_new, H), (B, N, S_new, H)). For self-attn, S_new = S_past + S. For cross-attn, S_new = S_kv.
+         """
+         if kv_positions is None:
+             kv_positions = q_positions
+         original_dtype = Xq.dtype
+
+         Xq_BxTxNxH = self.q_proj(Xq)
+         Xq_BxTxNxH = self.rotary_emb(Xq_BxTxNxH, position=q_positions)
+         Xq_BxNxTxH = Xq_BxTxNxH.transpose(1, 2)
+
+         # Input values into attention calculation
+         attn_k: torch.Tensor | None = None
+         attn_v: torch.Tensor | None = None
+         new_kv_cache: tuple[torch.Tensor, torch.Tensor] | None = None
+
+         # Decoder Cross Attention
+         if self.is_cross_attn:
+             # Directly use cache (no need to check index)
+             attn_k, attn_v = cache.k, cache.v
+             if (
+                 attn_k.shape[1] != self.num_query_heads
+                 or attn_v.shape[1] != self.num_query_heads
+             ):
+                 raise ValueError(
+                     f"Cross-attention cache head dimension ({attn_k.shape[1]}) "
+                     f"does not match num_query_heads ({self.num_query_heads}). "
+                     "Cache should be pre-repeated for GQA."
+                 )
+         # Self Attention
+         else:
+             Xk_BxSxKxH = self.k_proj(Xkv)  # (B, S, K, H)
+             Xv_BxSxKxH = self.v_proj(Xkv)  # (B, S, K, H)
+             Xk_BxSxKxH = self.rotary_emb(
+                 Xk_BxSxKxH, position=kv_positions
+             )  # (B, S, K, H)
+
+             Xk_BxKxSxH = Xk_BxSxKxH.transpose(1, 2)  # (B, K, S, H)
+             Xv_BxKxSxH = Xv_BxSxKxH.transpose(1, 2)  # (B, K, S, H)
+             # S=1 for Decode Step
+
+             if self.num_gqa_groups > 1:
+                 Xk_BxNxSxH = Xk_BxKxSxH.repeat_interleave(self.num_gqa_groups, dim=1)
+                 Xv_BxNxSxH = Xv_BxKxSxH.repeat_interleave(self.num_gqa_groups, dim=1)
+             else:
+                 Xk_BxNxSxH = Xk_BxKxSxH
+                 Xv_BxNxSxH = Xv_BxKxSxH
+
+             # Encoder Self Attention
+             if cache is None:
+                 attn_k = Xk_BxNxSxH
+                 attn_v = Xv_BxNxSxH
+             # Decoder Self Attention
+             else:
+                 # In prefill mode, we fill in cache until prefill length
+                 if prefill:
+                     attn_k, attn_v = Xk_BxNxSxH, Xv_BxNxSxH
+                     cache.prefill_kv(attn_k, attn_v)
+                 # In decode step, we add current K/V to cache step by step
+                 else:
+                     new_kv_cache = Xk_BxNxSxH, Xv_BxNxSxH
+                     attn_k, attn_v = cache.get_kv_for_attention(Xk_BxNxSxH, Xv_BxNxSxH)
+
+         # Add the dtype conversion here - after both cross-attention and self-attention paths
+         if attn_k is not None and attn_v is not None:
+             attn_k = attn_k.to(Xq_BxNxTxH.dtype)
+             attn_v = attn_v.to(Xq_BxNxTxH.dtype)
+
+         attn_output = F.scaled_dot_product_attention(
+             Xq_BxNxTxH,
+             attn_k,
+             attn_v,
+             attn_mask=attn_mask,
+             dropout_p=self.dropout_rate if not deterministic else 0.0,
+             scale=1.0,
+         )
+
+         attn_output = attn_output.transpose(1, 2).contiguous()  # (B, T, N, H)
+         output = self.o_proj(attn_output)
+
+         return output.to(original_dtype), new_kv_cache
+
+
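A small shape check for the GQA expansion in the self-attention path above: repeat_interleave turns each of the K KV heads into num_gqa_groups consecutive copies so that scaled_dot_product_attention sees matching head counts:

import torch

B, K, S, H = 2, 4, 10, 16      # 4 KV heads
groups = 16 // K               # 16 query heads -> 4 queries share each KV head
k = torch.randn(B, K, S, H)
k_expanded = k.repeat_interleave(groups, dim=1)
assert k_expanded.shape == (B, 16, S, H)
# Expanded heads 0-3 are copies of KV head 0, heads 4-7 of KV head 1, and so on.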
+ class EncoderLayer(nn.Module):
+     """Transformer Encoder Layer using DenseGeneral."""
+
+     def __init__(self, config: DiaConfig):
+         super().__init__()
+         self.config = config
+         model_config = config.model
+         enc_config = config.model.encoder
+         embed_dim = enc_config.n_embd
+
+         self.pre_sa_norm = RMSNorm(
+             embed_dim,
+             eps=model_config.normalization_layer_epsilon,
+             dtype=torch.float32,
+         )
+         self.self_attention = Attention(
+             config=config,
+             q_embed_dim=embed_dim,
+             kv_embed_dim=embed_dim,
+             num_query_heads=enc_config.n_head,
+             num_kv_heads=enc_config.n_head,
+             head_dim=enc_config.head_dim,
+             dropout_rate=model_config.dropout,
+             is_cross_attn=False,
+             out_embed_dim=embed_dim,
+         )
+         self.post_sa_norm = RMSNorm(
+             embed_dim,
+             eps=model_config.normalization_layer_epsilon,
+             dtype=torch.float32,
+         )
+         self.mlp = MlpBlock(
+             config=config,
+             embed_dim=embed_dim,
+             intermediate_dim=enc_config.n_hidden,
+             activations=enc_config.mlp_activations,
+             dropout_rate=model_config.dropout,
+             use_pre_norm=enc_config.use_pre_norm,
+         )
+         self.dropout = nn.Dropout(model_config.dropout)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         src_positions: torch.Tensor | None = None,
+         deterministic: bool = True,
+         attn_mask: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         residual = x
+         x_norm = self.pre_sa_norm(x)
+
+         sa_out, _ = self.self_attention(
+             Xq=x_norm,
+             Xkv=x_norm,
+             q_positions=src_positions,
+             kv_positions=src_positions,
+             deterministic=deterministic,
+             attn_mask=attn_mask,
+         )
+         x = residual + sa_out
+
+         residual = x
+         x_norm = self.post_sa_norm(x)
+         mlp_out = self.mlp(x_norm, deterministic=deterministic)
+         x = residual + mlp_out
+
+         if not deterministic:
+             x = self.dropout(x)
+         return x
+
+
+ class Encoder(nn.Module):
+     """Transformer Encoder Stack using DenseGeneral."""
+
+     def __init__(self, config: DiaConfig):
+         super().__init__()
+         self.config = config
+         model_config = config.model
+         enc_config = config.model.encoder
+         compute_dtype = _str_to_dtype(config.training.dtype)
+
+         self.embedding = nn.Embedding(
+             model_config.src_vocab_size,
+             enc_config.n_embd,
+             dtype=compute_dtype,
+         )
+         self.dropout = nn.Dropout(model_config.dropout)
+         self.layers = nn.ModuleList(
+             [EncoderLayer(config=config) for _ in range(enc_config.n_layer)]
+         )
+         self.norm = RMSNorm(
+             enc_config.n_embd,
+             eps=model_config.normalization_layer_epsilon,
+             dtype=torch.float32,
+         )
+
+     def forward(
+         self,
+         x_ids: torch.Tensor,
+         src_positions: torch.Tensor | None = None,
+         deterministic: bool = True,
+         attn_mask: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         x = self.embedding(x_ids)
+
+         if not deterministic:
+             x = self.dropout(x)
+
+         for layer in self.layers:
+             x = layer(
+                 x,
+                 src_positions=src_positions,
+                 deterministic=deterministic,
+                 attn_mask=attn_mask,
+             )
+         x = self.norm(x)
+         if not deterministic:
+             x = self.dropout(x)
+         return x
+
+
550
+ class DecoderLayer(nn.Module):
551
+ """Transformer Decoder Layer using DenseGeneral."""
552
+
553
+ def __init__(self, config: DiaConfig):
554
+ super().__init__()
555
+ self.config = config
556
+ model_config = config.model
557
+ dec_config = config.model.decoder
558
+ enc_config = config.model.encoder
559
+ dec_embed_dim = dec_config.n_embd
560
+ enc_embed_dim = enc_config.n_embd
561
+
562
+ # Norms
563
+ self.pre_sa_norm = RMSNorm(
564
+ dec_embed_dim,
565
+ eps=model_config.normalization_layer_epsilon,
566
+ dtype=torch.float32,
567
+ )
568
+ self.pre_ca_norm = RMSNorm(
569
+ dec_embed_dim,
570
+ eps=model_config.normalization_layer_epsilon,
571
+ dtype=torch.float32,
572
+ )
573
+ self.pre_mlp_norm = RMSNorm(
574
+ dec_embed_dim,
575
+ eps=model_config.normalization_layer_epsilon,
576
+ dtype=torch.float32,
577
+ )
578
+
579
+ # Self-Attention (GQA) with Causal Masking
580
+ self.self_attention = Attention(
581
+ config=config,
582
+ q_embed_dim=dec_embed_dim,
583
+ kv_embed_dim=dec_embed_dim,
584
+ num_query_heads=dec_config.gqa_query_heads,
585
+ num_kv_heads=dec_config.kv_heads,
586
+ head_dim=dec_config.gqa_head_dim,
587
+ dropout_rate=model_config.dropout,
588
+ is_cross_attn=False,
589
+ out_embed_dim=dec_embed_dim,
590
+ )
591
+ # Cross-Attention (MHA)
592
+ self.cross_attention = Attention(
593
+ config=config,
594
+ q_embed_dim=dec_embed_dim,
595
+ kv_embed_dim=enc_embed_dim, # Note kv_embed_dim
596
+ num_query_heads=dec_config.cross_query_heads,
597
+ num_kv_heads=dec_config.cross_query_heads,
598
+ head_dim=dec_config.cross_head_dim,
599
+ dropout_rate=model_config.dropout,
600
+ is_cross_attn=True,
601
+ out_embed_dim=dec_embed_dim,
602
+ )
603
+ # MLP
604
+ self.mlp = MlpBlock(
605
+ config=config,
606
+ embed_dim=dec_embed_dim,
607
+ intermediate_dim=dec_config.n_hidden,
608
+ activations=dec_config.mlp_activations,
609
+ dropout_rate=model_config.dropout,
610
+ use_pre_norm=dec_config.use_pre_norm,
611
+ )
612
+
613
+ def forward(
614
+ self,
615
+ x: torch.Tensor,
616
+ encoder_out: torch.Tensor,
617
+ tgt_positions: torch.Tensor,
618
+ src_positions: torch.Tensor | None,
619
+ deterministic: bool,
620
+ self_attn_mask: torch.Tensor,
621
+ cross_attn_mask: torch.Tensor,
622
+ self_attn_cache: KVCache,
623
+ cross_attn_cache: KVCache,
624
+ prefill: bool = False,
625
+ ) -> torch.Tensor:
626
+ residual = x
627
+ x_norm = self.pre_sa_norm(x)
628
+
629
+ sa_out, new_kv_cache = self.self_attention(
630
+ Xq=x_norm, # (2, 1, D)
631
+ Xkv=x_norm, # (2, 1, D)
632
+ q_positions=tgt_positions, # (2, 1)
633
+ kv_positions=tgt_positions, # (2, 1)
634
+ deterministic=deterministic,
635
+ attn_mask=self_attn_mask, # (2, 1, 1, S_max)
636
+ cache=self_attn_cache,
637
+ prefill=prefill,
638
+ )
639
+
640
+ x = residual + sa_out
641
+
642
+ # 2. Cross-Attention
643
+ residual = x
644
+ x_norm = self.pre_ca_norm(x)
645
+ ca_out, _ = self.cross_attention(
646
+ Xq=x_norm,
647
+ Xkv=encoder_out,
648
+ q_positions=tgt_positions,
649
+ kv_positions=src_positions,
650
+ deterministic=deterministic,
651
+ attn_mask=cross_attn_mask,
652
+ cache=cross_attn_cache,
653
+ )
654
+ x = residual + ca_out
655
+
656
+ # 3. MLP
657
+ residual = x
658
+ x_norm = self.pre_mlp_norm(x)
659
+ mlp_out = self.mlp(x_norm, deterministic=deterministic)
660
+ x = residual + mlp_out
661
+
662
+ return x, new_kv_cache
663
+
664
+
665
+ class Decoder(nn.Module):
666
+ """Transformer Decoder Stack using DenseGeneral."""
667
+
668
+ def __init__(self, config: DiaConfig):
669
+ super().__init__()
670
+ self.config = config
671
+ model_config = config.model
672
+ dec_config = config.model.decoder
673
+ train_config = config.training
674
+ data_config = config.data
675
+ compute_dtype = _str_to_dtype(config.training.dtype)
676
+ weight_dtype = _str_to_dtype(config.model.weight_dtype)
677
+ self.num_channels = data_config.channels
678
+ self.num_layers = dec_config.n_layer
679
+
680
+ self.embeddings = nn.ModuleList(
681
+ [
682
+ nn.Embedding(
683
+ model_config.tgt_vocab_size, dec_config.n_embd, dtype=compute_dtype
684
+ )
685
+ for _ in range(self.num_channels)
686
+ ]
687
+ )
688
+ self.dropout = nn.Dropout(model_config.dropout)
689
+ self.layers = nn.ModuleList(
690
+ [DecoderLayer(config=config) for _ in range(self.num_layers)]
691
+ )
692
+ self.norm = RMSNorm(
693
+ dec_config.n_embd,
694
+ eps=model_config.normalization_layer_epsilon,
695
+ dtype=torch.float32,
696
+ )
697
+
698
+ # Final Logits Projection using DenseGeneral
699
+ self.logits_dense = DenseGeneral(
700
+ in_shapes=(dec_config.n_embd,),
701
+ out_features=(self.num_channels, model_config.tgt_vocab_size),
702
+ axis=(-1,),
703
+ dtype=(torch.float32 if train_config.logits_dot_in_fp32 else compute_dtype),
704
+ weight_dtype=weight_dtype,
705
+ )
706
+ self.logits_in_fp32 = train_config.logits_dot_in_fp32
707
+
708
+ def precompute_cross_attention_kv(
709
+ self,
710
+ max_len: int,
711
+ encoder_out: torch.Tensor, # (B, S, E)
712
+ src_positions: torch.Tensor | None, # (B, S)
713
+ ) -> list[KVCache]:
714
+ """
715
+ Computes the Key and Value tensors for cross-attention for each layer from the encoder output.
716
+ """
717
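+ # The cross-attention keys/values depend only on encoder_out, which is
+ # fixed for the entire generation, so they are projected once here rather
+ # than recomputed at every decode step.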
+ per_layer_kv_cache: list[KVCache] = []
718
+
719
+ for layer in self.layers:
720
+ cross_attn_module = layer.cross_attention
721
+ k_proj = cross_attn_module.k_proj(encoder_out)
722
+ v_proj = cross_attn_module.v_proj(encoder_out)
723
+
724
+ k_proj = cross_attn_module.rotary_emb(k_proj, position=src_positions)
725
+ k = k_proj.transpose(1, 2)
726
+ v = v_proj.transpose(1, 2)
727
+
728
+ per_layer_kv_cache.append(
729
+ KVCache(
730
+ cross_attn_module.num_kv_heads,
731
+ max_len,
732
+ cross_attn_module.head_dim,
733
+ k.device,
734
+ k=k,
735
+ v=v,
736
+ )
737
+ )
738
+
739
+ return per_layer_kv_cache
740
+
741
+ def decode_step(
742
+ self,
743
+ tgt_ids_Bx1xC: torch.Tensor, # [B, 1, C]
744
+ tgt_pos_Bx1: torch.Tensor, # [B, 1]
745
+ encoder_out: torch.Tensor, # [B, S, E]
746
+ self_attn_mask: Any, # None
747
+ cross_attn_mask: torch.Tensor, # [B, 1, 1, S]
748
+ self_attention_cache: list[KVCache],
749
+ cross_attention_cache: list[KVCache],
750
+ ) -> tuple[torch.Tensor, list]:
751
+ """
752
+ Performs a single decoding step, managing KV caches layer by layer.
753
+
754
+ Returns:
755
+ A tuple containing:
756
+ - logits_Bx1xCxV: The final output logits for the current step (B, 1, C, V), cast to float32.
+ - new_cache: A list with each layer's self-attention K/V tensors from this step.
757
+ """
758
+ assert (
759
+ self_attn_mask is None
760
+ ), "Self-attention mask should be None, kept for pattern"
761
+
762
+ x = None
763
+ for i in range(self.num_channels):
764
+ channel_tokens = tgt_ids_Bx1xC[..., i]
765
+ channel_embed = self.embeddings[i](channel_tokens)
766
+ x = channel_embed if x is None else x + channel_embed
767
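+ # Note: the per-channel codebook embeddings are summed into a single
+ # residual stream rather than concatenated (same scheme as in forward()).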
+
768
+ new_cache = []
769
+
770
+ for i, layer in enumerate(self.layers):
771
+ self_cache = self_attention_cache[i]
772
+ cross_cache = cross_attention_cache[i]
773
+ x, new_kv_cache = layer(
774
+ x, # (2, 1, D)
775
+ encoder_out, # (2, S, E)
776
+ src_positions=None, # CA KV is already computed
777
+ tgt_positions=tgt_pos_Bx1, # (2, 1)
778
+ deterministic=True,
779
+ self_attn_mask=None,
780
+ cross_attn_mask=cross_attn_mask,
781
+ self_attn_cache=self_cache,
782
+ cross_attn_cache=cross_cache,
783
+ )
784
+ new_cache.append(new_kv_cache)
785
+
786
+ x = self.norm(x)
787
+ logits_Bx1xCxV = self.logits_dense(x)
788
+
789
+ return logits_Bx1xCxV.to(torch.float32), new_cache
790
+
791
+ def forward(
792
+ self,
793
+ tgt_ids_BxTxC: torch.Tensor,
794
+ encoder_out: torch.Tensor,
795
+ tgt_positions: torch.Tensor,
796
+ src_positions: torch.Tensor,
797
+ deterministic: bool,
798
+ self_attn_mask: torch.Tensor,
799
+ cross_attn_mask: torch.Tensor,
800
+ self_attention_cache: list[KVCache],
801
+ cross_attention_cache: list[KVCache],
802
+ ) -> torch.Tensor:
803
+ """
804
+ Forward pass for the Decoder stack, managing KV caches.
805
+
806
+ Args:
807
+ tgt_ids_BxTxC: Target token IDs (B, T, C).
808
+ encoder_out: Output from the encoder (B, S, E).
809
+ tgt_positions: Positions for target sequence (B, T).
810
+ src_positions: Positions for source sequence (B, S).
811
+ deterministic: Disable dropout if True.
812
+ self_attn_mask: Mask for self-attention.
813
+ cross_attn_mask: Mask for cross-attention.
814
+ self_attention_cache: List containing the self-attention KVCache for each
+ layer; its length must equal num_layers.
+ cross_attention_cache: List containing the pre-computed cross-attention
+ KVCache (derived from encoder_out) for each layer.
+
+ Returns:
+ logits: The final output logits (B, T, C, V), cast to float32.
826
+ """
827
+ _, _, num_channels_in = tgt_ids_BxTxC.shape
828
+ assert num_channels_in == self.num_channels, "Input channels mismatch"
829
+
830
+ # Embeddings
831
+ x = None
832
+ for i in range(self.num_channels):
833
+ channel_tokens = tgt_ids_BxTxC[..., i]
834
+ channel_embed = self.embeddings[i](channel_tokens)
835
+ x = channel_embed if x is None else x + channel_embed
836
+
837
+ if not deterministic:
838
+ x = self.dropout(x)
839
+
840
+ for i, layer in enumerate(self.layers):
841
+ x, _ = layer(
842
+ x,
843
+ encoder_out,
844
+ tgt_positions=tgt_positions,
845
+ src_positions=src_positions,
846
+ deterministic=deterministic,
847
+ self_attn_mask=self_attn_mask,
848
+ cross_attn_mask=cross_attn_mask,
849
+ self_attn_cache=self_attention_cache[i],
850
+ cross_attn_cache=cross_attention_cache[i],
851
+ prefill=True,
852
+ )
853
+
854
+ # Final Norm
855
+ x = self.norm(x)
856
+ logits_BxTxCxV = self.logits_dense(x)
857
+
858
+ return logits_BxTxCxV.to(torch.float32)
859
+
860
+
861
+ class DiaModel(nn.Module):
862
+ """PyTorch Dia Model using DenseGeneral."""
863
+
864
+ def __init__(self, config: DiaConfig):
865
+ super().__init__()
866
+ self.config = config
867
+ self.encoder = Encoder(config)
868
+ self.decoder = Decoder(config)
869
+
870
+ def forward(
871
+ self,
872
+ src_BxS: torch.Tensor,
873
+ tgt_BxTxC: torch.Tensor,
874
+ src_positions: torch.Tensor | None = None,
875
+ tgt_positions: torch.Tensor | None = None,
876
+ enc_self_attn_mask: torch.Tensor | None = None,
877
+ dec_self_attn_mask: torch.Tensor | None = None,
878
+ dec_cross_attn_mask: torch.Tensor | None = None,
879
+ enable_dropout: bool = True,
880
+ ):
881
+ deterministic = not enable_dropout
882
+
883
+ # --- Encoder Pass ---
884
+ encoder_out = self.encoder(
885
+ x_ids=src_BxS,
886
+ src_positions=src_positions,
887
+ deterministic=deterministic,
888
+ attn_mask=enc_self_attn_mask,
889
+ )
890
+
891
+ # --- Decoder Pass ---
892
+ logits = self.decoder(  # Decoder.forward returns logits only
893
+ tgt_ids_BxTxC=tgt_BxTxC,
894
+ encoder_out=encoder_out,
895
+ tgt_positions=tgt_positions,
896
+ src_positions=src_positions,
897
+ deterministic=deterministic,
898
+ self_attn_mask=dec_self_attn_mask,
899
+ cross_attn_mask=dec_cross_attn_mask,
900
+ # NOTE: Decoder.forward also requires self_attention_cache and
+ # cross_attention_cache lists; callers of this full forward pass must
+ # construct and pass them (see the prefill path in Dia.generate).
901
+ )
902
+
903
+ return logits
dia/model.py ADDED
@@ -0,0 +1,956 @@
1
+ # dia/model.py
2
+
3
+ import os
4
+ import logging
5
+ import time
6
+ import dac # Keep this import name
7
+ import numpy as np
8
+ import torch
9
+ import torchaudio
10
+ from huggingface_hub import hf_hub_download
11
+ from safetensors.torch import load_file # <<< ADDED Import for safetensors
12
+
13
+ from .audio import audio_to_codebook, codebook_to_audio
14
+ from .config import (
15
+ DiaConfig,
16
+ ) # Assuming this is the Pydantic config for model structure
17
+ from .layers import DiaModel, KVCache # Assuming these are the nn.Module definitions
18
+
19
+ # --- Get a logger instance for this module ---
20
+ logger = logging.getLogger(__name__)
21
+
22
+ # Optional: Add a check after import to verify the library looks correct
23
+ # Note: We now expect 'utils' based on original code
24
+ if (
25
+ not hasattr(dac, "utils")
26
+ or not hasattr(dac.utils, "download")
27
+ or not hasattr(dac, "DAC")
28
+ ):
29
+ logger.warning(
30
+ "The imported 'dac' module does not appear to have the 'utils.download' structure expected by the original Dia code."
31
+ )
32
+ logger.warning(
33
+ "Ensure 'descript-audio-codec' is installed correctly (pip install descript-audio-codec)."
34
+ )
35
+ # If this check fails, _load_dac_model will likely raise an error later anyway.
36
+
37
+
38
+ def _sample_next_token(
39
+ logits_BCxV: torch.Tensor,
40
+ temperature: float,
41
+ top_p: float,
42
+ use_cfg_filter: bool,
43
+ cfg_filter_top_k: int | None = None,
44
+ ) -> torch.Tensor:
45
+ """Samples the next token based on logits, temperature, and top_p."""
46
+ if temperature == 0.0:
47
+ # Greedy sampling
48
+ return torch.argmax(logits_BCxV, dim=-1)
49
+
50
+ # Apply temperature scaling
51
+ logits_BCxV = logits_BCxV / temperature
52
+
53
+ # Apply CFG Top-K filtering (optional)
54
+ if use_cfg_filter and cfg_filter_top_k is not None:
55
+ # Get top K values and indices
56
+ _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=cfg_filter_top_k, dim=-1)
57
+ # Create a mask to keep only top K logits
58
+ mask = torch.ones_like(logits_BCxV, dtype=torch.bool)
59
+ mask.scatter_(
60
+ dim=-1, index=top_k_indices_BCxV, value=False
61
+ ) # Set top K positions to False (don't mask)
62
+ # Mask out logits not in the top K
63
+ logits_BCxV = logits_BCxV.masked_fill(mask, -torch.inf)
64
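+ # Net effect: only the cfg_filter_top_k highest-scoring logits remain
+ # finite; everything else is excluded before top-p filtering below.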
+
65
+ # Apply Top-P (Nucleus) sampling
66
+ if top_p < 1.0:
67
+ # Convert logits to probabilities
68
+ probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
69
+ # Sort probabilities in descending order
70
+ sorted_probs_BCxV, sorted_indices_BCxV = torch.sort(
71
+ probs_BCxV, dim=-1, descending=True
72
+ )
73
+ # Calculate cumulative probabilities
74
+ cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1)
75
+
76
+ # Create mask for tokens to remove (those exceeding top_p threshold)
77
+ sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p
78
+ # Shift the mask: keep the first token that crosses the threshold
79
+ sorted_indices_to_remove_BCxV[..., 1:] = sorted_indices_to_remove_BCxV[
80
+ ..., :-1
81
+ ].clone()
82
+ sorted_indices_to_remove_BCxV[..., 0] = 0 # Always keep the most probable token
83
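+ # Illustrative example (not in the original source): with sorted probs
+ # [0.5, 0.3, 0.15, 0.05] and top_p=0.8, cumulative sums are
+ # [0.5, 0.8, 0.95, 1.0]; after the shift, 0.5, 0.3 and 0.15 survive (the
+ # first token crossing the threshold is kept) and 0.05 is masked out.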
+
84
+ # Scatter the mask back to the original order
85
+ indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV)
86
+ indices_to_remove_BCxV.scatter_(
87
+ dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV
88
+ )
89
+ # Apply the mask to the logits
90
+ logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf)
91
+
92
+ # Calculate final probabilities after filtering
93
+ final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
94
+
95
+ # Sample from the filtered distribution
96
+ # multinomial expects probabilities for each item in the batch
97
+ sampled_indices_BC = torch.multinomial(
98
+ final_probs_BCxV, num_samples=1
99
+ ) # Shape [B*C, 1]
100
+ sampled_indices_C = sampled_indices_BC.squeeze(
101
+ -1
102
+ ) # Shape [B*C] -> should be [C] if input was [C,V]
103
+ return sampled_indices_C
104
+
105
+
106
+ class Dia:
107
+ """
108
+ Main class for the Dia Text-to-Speech model, handling loading and generation.
109
+ """
110
+
111
+ def __init__(self, config: DiaConfig, device: torch.device = torch.device("cuda")):
112
+ """
113
+ Initializes the Dia model structure based on the provided configuration.
114
+ Does not load weights here.
115
+
116
+ Args:
117
+ config: The DiaConfig object defining model parameters.
118
+ device: The torch device (e.g., 'cuda', 'cpu') the model should eventually run on.
119
+ Note: The model is instantiated but not moved to the device here.
120
+ """
121
+ super().__init__()
122
+ logger.info(
123
+ f"Initializing Dia model structure with config version: {config.version}"
124
+ )
125
+ self.config = config
126
+ # Store the target device, but don't move the model yet. Loading weights will handle device placement.
127
+ self.target_device = device
128
+ # Instantiate the underlying PyTorch model based on the config
129
+ self.model = DiaModel(config)
130
+ self.dac_model = None # DAC model will be loaded separately
131
+ logger.info("Dia model structure initialized.")
132
+
133
+ @classmethod
134
+ def load_model_from_files(
135
+ cls,
136
+ config_path: str,
137
+ weights_path: str,
138
+ device: torch.device = torch.device("cuda"),
139
+ ) -> "Dia":
140
+ """
141
+ Loads the Dia model from local configuration and weights files.
142
+ Handles both .pth and .safetensors weight formats.
143
+
144
+ Args:
145
+ config_path: Path to the configuration JSON file (e.g., 'config.json').
146
+ weights_path: Path to the model weights file (e.g., 'model.pth' or 'model.safetensors').
147
+ device: The torch device ('cuda', 'cpu', etc.) to load the model onto.
148
+
149
+ Returns:
150
+ An instance of the Dia model loaded with weights and set to eval mode.
151
+
152
+ Raises:
153
+ FileNotFoundError: If the config or weights file is not found.
154
+ ValueError: If the weights file format is unsupported.
155
+ RuntimeError: If there is an error loading the config, weights, or DAC model.
156
+ """
157
+ logger.info("Loading Dia model from local files:")
158
+ logger.info(f" Config: {config_path}")
159
+ logger.info(f" Weights: {weights_path}")
160
+ logger.info(f" Target Device: {device}")
161
+
162
+ # 1. Load Configuration
163
+ try:
164
+ config = DiaConfig.load(config_path)
165
+ if config is None:
166
+ # DiaConfig.load returns None on FileNotFoundError
167
+ logger.error(f"Configuration file not found at {config_path}")
168
+ raise FileNotFoundError(
169
+ f"Configuration file not found at {config_path}"
170
+ )
171
+ logger.info("Configuration loaded successfully.")
172
+ except Exception as e:
173
+ logger.error(
174
+ f"Error loading or validating configuration from {config_path}: {e}",
175
+ exc_info=True,
176
+ )
177
+ raise RuntimeError(
178
+ f"Failed to load configuration from {config_path}"
179
+ ) from e
180
+
181
+ # 2. Instantiate Model Structure
182
+ # Pass the target device during instantiation if the underlying DiaModel supports it,
183
+ # otherwise, we move it later. Assuming __init__ doesn't take device for now.
184
+ dia_instance = cls(
185
+ config, device
186
+ ) # Pass device mainly for storing target_device
187
+
188
+ # 3. Load Weights (State Dictionary)
189
+ try:
190
+ logger.info(f"Loading weights from: {weights_path}")
191
+ weights_filename = os.path.basename(weights_path)
192
+ state_dict = None
193
+
194
+ if weights_filename.endswith(".safetensors"):
195
+ logger.info(
196
+ "Detected .safetensors format. Loading using safetensors library."
197
+ )
198
+ # load_file loads directly to the specified device
199
+ state_dict = load_file(weights_path, device=str(device))
200
+ logger.info("Safetensors weights loaded.")
201
+ elif weights_filename.endswith(".pth"):
202
+ logger.info("Detected .pth format. Loading using torch.load.")
203
+ # torch.load needs map_location to load onto the correct device
204
+ state_dict = torch.load(weights_path, map_location=device)
205
+ logger.info("PyTorch weights (.pth) loaded.")
206
+ else:
207
+ logger.error(
208
+ f"Unsupported weights file format: {weights_filename}. Expected .pth or .safetensors."
209
+ )
210
+ raise ValueError(f"Unsupported weights file format: {weights_filename}")
211
+
212
+ # Load the state dictionary into the model structure
213
+ logger.info("Applying loaded weights to the model structure...")
214
+ # Use strict=True by default to catch mismatches. Can be set to False if needed for specific conversions (e.g., BF16 -> FP32 partial loads)
215
+ dia_instance.model.load_state_dict(state_dict, strict=True)
216
+ logger.info("Weights applied successfully.")
217
+
218
+ except FileNotFoundError:
219
+ logger.error(f"Weights file not found at {weights_path}")
220
+ raise FileNotFoundError(f"Weights file not found at {weights_path}")
221
+ except Exception as e:
222
+ logger.error(
223
+ f"Error loading weights from {weights_path}: {e}", exc_info=True
224
+ )
225
+ raise RuntimeError(f"Error loading weights from {weights_path}") from e
226
+
227
+ # 4. Move Model to Device and Set Eval Mode
228
+ logger.info(f"Moving model to device: {device}...")
229
+ dia_instance.model.to(device)
230
+ logger.info("Setting model to evaluation mode...")
231
+ dia_instance.model.eval()
232
+
233
+ # 5. Load Associated DAC Model
234
+ logger.info("Loading associated DAC model...")
235
+ dia_instance._load_dac_model() # This will log its own progress/errors
236
+
237
+ logger.info("Dia model fully loaded and ready.")
238
+ return dia_instance
239
+
240
+ # REMOVED from_pretrained - Responsibility moved to engine.py
241
+ # @classmethod
242
+ # def from_pretrained(...) -> "Dia": ...
243
+
244
+ def _load_dac_model(self):
245
+ """Loads the Descript Audio Codec (DAC) model using the original project's method."""
246
+ if self.dac_model is not None:
247
+ logger.info("DAC model already loaded.")
248
+ return
249
+
250
+ # Verify the imported module has the necessary structure expected by original code
251
+ if (
252
+ not hasattr(dac, "utils")
253
+ or not hasattr(dac.utils, "download")
254
+ or not hasattr(dac, "DAC")
255
+ ):
256
+ logger.error(
257
+ "Imported 'dac' module structure mismatch. Expected 'dac.utils.download()' and 'dac.DAC'."
258
+ )
259
+ logger.error(
260
+ "Ensure 'descript-audio-codec' is installed correctly via pip."
261
+ )
262
+ raise RuntimeError(
263
+ "Failed to load DAC model: required functions/structure missing from 'dac' module."
264
+ )
265
+
266
+ try:
267
+ # Use the original method found in the Dia repository
268
+ logger.info("Downloading/finding DAC model using dac.utils.download()...")
269
+ # This assumes dac.utils.download() handles caching internally
270
+ dac_model_path = dac.utils.download()
271
+ logger.info(f"DAC model path determined: {dac_model_path}")
272
+
273
+ logger.info("Loading DAC model from path...")
274
+ # Load DAC model and move it to the same device as the main Dia model
275
+ dac_model = dac.DAC.load(dac_model_path).to(self.target_device)
276
+ logger.info("DAC model loaded successfully.")
277
+
278
+ except AttributeError as ae:
279
+ logger.error(
280
+ f"AttributeError loading DAC model: '{ae}'. The installed 'descript-audio-codec' version might be incompatible with Dia's original code which expects 'dac.utils.download()'."
281
+ )
282
+ logger.error(
283
+ "Please check for specific version requirements of 'descript-audio-codec' for Dia, or potential installation issues."
284
+ )
285
+ raise RuntimeError(
286
+ "Failed to load DAC model due to incompatible library version or structure"
287
+ ) from ae
288
+ except Exception as e:
289
+ logger.error(f"General error loading DAC model: {e}", exc_info=True)
290
+ raise RuntimeError("Failed to load DAC model") from e
291
+
292
+ self.dac_model = dac_model
293
+
294
+ def _create_attn_mask(
295
+ self,
296
+ q_padding_mask_1d: torch.Tensor,
297
+ k_padding_mask_1d: torch.Tensor,
298
+ is_causal: bool = False,
299
+ ) -> torch.Tensor:
300
+ """
301
+ Creates the attention mask (self or cross) based on padding masks.
302
+ Mimics JAX segment ID logic where attention is allowed between non-padding tokens
303
+ OR between padding tokens, but not across the boundary.
304
+
305
+ Args:
306
+ q_padding_mask_1d: Boolean tensor [Batch, SeqLenQ] where True indicates non-padding.
307
+ k_padding_mask_1d: Boolean tensor [Batch, SeqLenK] where True indicates non-padding.
308
+ is_causal: If True, applies an additional causal mask (for decoder self-attention).
309
+
310
+ Returns:
311
+ Boolean attention mask tensor [Batch, 1, SeqLenQ, SeqLenK] ready for F.scaled_dot_product_attention.
312
+ """
313
+ B1, Tq = q_padding_mask_1d.shape
314
+ B2, Tk = k_padding_mask_1d.shape
315
+ if B1 != B2:
316
+ logger.warning(
317
+ f"Query ({B1}) and key ({B2}) batch dimensions do not match in _create_attn_mask"
318
+ )
319
+ assert B1 == B2, "Query and key batch dimensions must match"
320
+
321
+ # Expand masks for broadcasting: [B, Tq, 1] and [B, 1, Tk]
322
+ p_mask_q = q_padding_mask_1d.unsqueeze(2)
323
+ p_mask_k = k_padding_mask_1d.unsqueeze(1)
324
+
325
+ # True where a non-padding query token attends to a non-padding key token
326
+ non_pad_attends_non_pad = p_mask_q & p_mask_k # Shape [B, Tq, Tk]
327
+ # True where a padding query token attends to a padding key token
328
+ pad_attends_pad = (~p_mask_q) & (~p_mask_k) # Shape [B, Tq, Tk]
329
+
330
+ # Combine: Attention is allowed if tokens are both non-padding OR both padding.
331
+ mask = non_pad_attends_non_pad | pad_attends_pad # Shape [B, Tq, Tk]
332
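+ # Worked example (illustrative): with q_padding_mask [T, T, F] and
+ # k_padding_mask [T, F, F], query 0 (non-pad) may attend only to key 0,
+ # while query 2 (pad) may attend to keys 1 and 2 (the other pad positions).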
+
333
+ if is_causal:
334
+ # Apply causal mask for self-attention (query cannot attend to future keys)
335
+ if Tq != Tk:
336
+ logger.warning(f"Causal mask requested but Tq ({Tq}) != Tk ({Tk})")
337
+ assert (
338
+ Tq == Tk
339
+ ), "Causal mask requires query and key sequence lengths to be equal"
340
+ # Create lower triangular matrix (True allows attention)
341
+ causal_mask_2d = torch.tril(
342
+ torch.ones((Tq, Tk), dtype=torch.bool, device=self.target_device)
343
+ )
344
+ # Combine with padding compatibility mask
345
+ mask = mask & causal_mask_2d # Shape [B, Tq, Tk]
346
+
347
+ # Add head dimension for broadcasting: [B, 1, Tq, Tk]
348
+ return mask.unsqueeze(1)
349
+
350
+ def _prepare_text_input(
351
+ self, text: str
352
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
353
+ """
354
+ Encodes text prompt into byte tokens, pads to max length,
355
+ and creates position IDs and padding mask.
356
+
357
+ Args:
358
+ text: The input text string.
359
+
360
+ Returns:
361
+ Tuple containing:
362
+ - src_tokens: Padded token IDs [1, SeqLen].
363
+ - src_positions: Position IDs [1, SeqLen].
364
+ - src_padding_mask: Boolean mask (True=non-pad) [1, SeqLen].
365
+ - enc_self_attn_mask: Attention mask for encoder [1, 1, SeqLen, SeqLen].
366
+ """
367
+ text_pad_value = self.config.data.text_pad_value
368
+ max_len = self.config.data.text_length
369
+ logger.debug(
370
+ f"Preparing text input. Max length: {max_len}, Pad value: {text_pad_value}"
371
+ )
372
+ logger.debug(f"Original text (start): '{text[:100]}...'")
373
+
374
+ # Convert text to bytes and replace special speaker tokens
375
+ byte_text = text.encode("utf-8")
376
+ # Assuming Dia uses byte values 1 and 2 for S1/S2 based on original code context
377
+ replaced_bytes = byte_text.replace(b"[S1]", b"\x01").replace(b"[S2]", b"\x02")
378
+ text_tokens = list(replaced_bytes) # List of integer byte values
379
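+ # Illustrative example: "[S1] Hi" becomes b"\x01 Hi", i.e. byte token
+ # IDs [1, 32, 72, 105].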
+ logger.debug(
380
+ f"Text tokens after byte conversion (first 10): {text_tokens[:10]}"
381
+ )
382
+
383
+ # Pad or truncate sequence
384
+ current_len = len(text_tokens)
385
+ padding_needed = max_len - current_len
386
+ if padding_needed <= 0:
387
+ if current_len > max_len:
388
+ logger.warning(
389
+ f"Input text length ({current_len}) exceeds max length ({max_len}). Truncating."
390
+ )
391
+ text_tokens = text_tokens[:max_len]
392
+ padded_text_np = np.array(text_tokens, dtype=np.uint8)
393
+ else:
394
+ logger.debug(f"Padding text input with {padding_needed} pad tokens.")
395
+ padded_text_np = np.pad(
396
+ text_tokens,
397
+ (0, padding_needed),
398
+ mode="constant",
399
+ constant_values=text_pad_value,
400
+ ).astype(np.uint8)
401
+
402
+ # Convert to tensors and add batch dimension [1, SeqLen]
403
+ src_tokens = (
404
+ torch.from_numpy(padded_text_np)
405
+ .to(torch.long)
406
+ .to(self.target_device)
407
+ .unsqueeze(0)
408
+ )
409
+ src_positions = (
410
+ torch.arange(max_len, device=self.target_device).to(torch.long).unsqueeze(0)
411
+ )
412
+
413
+ # Create padding mask (True where token is NOT the pad value)
414
+ src_padding_mask = src_tokens != text_pad_value # Shape [1, SeqLen]
415
+
416
+ # Create attention mask for the encoder (non-causal self-attention)
417
+ # Needs shape [B, 1, Tq, Tk] -> [1, 1, SeqLen, SeqLen]
418
+ enc_self_attn_mask = self._create_attn_mask(
419
+ src_padding_mask, src_padding_mask, is_causal=False
420
+ )
421
+
422
+ logger.debug(f"Prepared src_tokens shape: {src_tokens.shape}")
423
+ logger.debug(f"Prepared src_positions shape: {src_positions.shape}")
424
+ logger.debug(
425
+ f"Prepared src_padding_mask shape: {src_padding_mask.shape} (True means non-padding)"
426
+ )
427
+ logger.debug(f"Prepared enc_self_attn_mask shape: {enc_self_attn_mask.shape}")
428
+
429
+ return src_tokens, src_positions, src_padding_mask, enc_self_attn_mask
430
+
431
+ @torch.inference_mode()
432
+ def generate(
433
+ self,
434
+ text: str,
435
+ max_tokens: int | None = None,
436
+ cfg_scale: float = 3.0,
437
+ temperature: float = 1.3,
438
+ top_p: float = 0.95,
439
+ use_cfg_filter: bool = True,
440
+ use_torch_compile: bool = False, # Default to False for broader compatibility
441
+ cfg_filter_top_k: int = 35,
442
+ audio_prompt_path: str | None = None,
443
+ ) -> np.ndarray:
444
+ """
445
+ Generates audio waveform from a text prompt, optionally conditioned on an audio prompt.
446
+
447
+ Args:
448
+ text: The input text string. For dialogue, use [S1]/[S2] markers.
449
+ For voice cloning, prepend the transcript of the audio prompt.
450
+ max_tokens: Maximum number of audio tokens (frames) to generate. Defaults to config value.
451
+ cfg_scale: Classifier-Free Guidance scale. Higher values increase adherence to text.
452
+ temperature: Sampling temperature. Higher values increase randomness.
453
+ top_p: Nucleus sampling probability. Filters vocabulary during sampling.
454
+ use_cfg_filter: Whether to apply Top-K filtering based on CFG logits.
455
+ use_torch_compile: If True, attempts to compile the decoder step for potential speedup.
456
+ cfg_filter_top_k: The 'K' value for CFG Top-K filtering.
457
+ audio_prompt_path: Path to an audio file (e.g., WAV, MP3) to use as a voice prompt/clone target.
458
+
459
+ Returns:
460
+ A 1D NumPy array containing the generated audio waveform (float32).
461
+ """
462
+ start_time_gen = time.time()
463
+ logger.info("Starting audio generation...")
464
+ logger.info(f" Text (start): '{text[:100]}...'")
465
+ logger.info(
466
+ f" Max tokens: {max_tokens if max_tokens is not None else 'Model Default'}"
467
+ )
468
+ logger.info(f" CFG Scale: {cfg_scale}")
469
+ logger.info(f" Temperature: {temperature}")
470
+ logger.info(f" Top P: {top_p}")
471
+ logger.info(f" Use CFG Filter: {use_cfg_filter}, Top K: {cfg_filter_top_k}")
472
+ logger.info(
473
+ f" Audio Prompt: {audio_prompt_path if audio_prompt_path else 'None'}"
474
+ )
475
+ logger.info(f" Use torch.compile: {use_torch_compile}")
476
+ logger.info(f" Target Device: {self.target_device}")
477
+
478
+ # --- Parameter Setup ---
479
+ num_channels = self.config.data.channels
480
+ audio_bos_value = self.config.data.audio_bos_value
481
+ audio_eos_value = self.config.data.audio_eos_value
482
+ audio_pad_value = self.config.data.audio_pad_value
483
+ delay_pattern = self.config.data.delay_pattern
484
+ # Use model's default audio length if max_tokens not provided
485
+ effective_max_tokens = (
486
+ max_tokens if max_tokens is not None else self.config.data.audio_length
487
+ )
488
+ logger.info(f" Effective max_tokens for generation: {effective_max_tokens}")
489
+
490
+ # Ensure delay pattern is usable
491
+ if not isinstance(delay_pattern, list) or not delay_pattern:
492
+ logger.warning("Delay pattern is invalid or empty. Using default [0].")
493
+ delay_pattern = [
494
+ 0
495
+ ] * num_channels # Fallback, though config should provide default
496
+
497
+ delay_tensor = torch.tensor(
498
+ delay_pattern, dtype=torch.long, device=self.target_device
499
+ )
500
+ max_delay_pattern = max(delay_pattern) if delay_pattern else 0
501
+ self.model.eval() # Ensure model is in eval mode
502
+
503
+ # --- Prepare Conditional and Unconditional Inputs ---
504
+ logger.info(
505
+ "Preparing text inputs for conditional and unconditional generation..."
506
+ )
507
+ (
508
+ cond_src_BxS,
509
+ cond_src_positions_BxS,
510
+ cond_src_padding_mask_BxS,
511
+ cond_enc_self_attn_mask_Bx1xSxS,
512
+ ) = self._prepare_text_input(text)
513
+
514
+ # Create unconditional input (batch of zeros representing padding)
515
+ # Assuming pad value 0 for text based on config default
516
+ unc_src_BxS = torch.full_like(
517
+ cond_src_BxS, fill_value=self.config.data.text_pad_value
518
+ )
519
+ # Batch conditional and unconditional inputs together [2, SeqLen]
520
+ src_BxS = torch.cat([unc_src_BxS, cond_src_BxS], dim=0)
521
+ # Expand other inputs to match batch size 2
522
+ src_positions_BxS = cond_src_positions_BxS.expand(2, -1)
523
+ src_padding_mask_BxS = torch.cat(
524
+ [
525
+ torch.zeros_like(cond_src_padding_mask_BxS[0:1]),
526
+ cond_src_padding_mask_BxS,
527
+ ],
528
+ dim=0,
529
+ ) # Uncond mask is all False (padding)
530
+ # Encoder mask needs to handle the batched input correctly
531
+ # For CFG, typically the unconditional branch attends to nothing useful from text,
532
+ # but the structure needs to be maintained. We can reuse the conditional mask structure,
533
+ # but the actual attention scores will be based on the zeroed-out unconditional input.
534
+ # Alternatively, create a specific mask for the unconditional part if needed.
535
+ # Let's expand the conditional mask for simplicity, assuming the model handles zero inputs appropriately.
536
+ enc_self_attn_mask_Bx1xSxS = cond_enc_self_attn_mask_Bx1xSxS.expand(
537
+ 2, -1, -1, -1
538
+ )
539
+ logger.info("Text inputs prepared (batch size 2 for CFG).")
540
+
541
+ # --- Encoder Pass ---
542
+ logger.info("Running encoder pass...")
543
+ start_time_enc = time.time()
544
+ # Potentially use autocast for mixed precision if supported and beneficial on device
545
+ # Example: with torch.autocast(device_type=self.target_device.type, dtype=torch.bfloat16 if self.target_device.type == 'cuda' else torch.float32):
546
+ encoder_out = self.model.encoder(
547
+ x_ids=src_BxS, # Shape [2, S]
548
+ src_positions=src_positions_BxS, # Shape [2, S]
549
+ deterministic=True, # No dropout during inference
550
+ attn_mask=enc_self_attn_mask_Bx1xSxS, # Shape [2, 1, S, S]
551
+ )
552
+ logger.info(
553
+ f"Encoder pass completed in {time.time() - start_time_enc:.3f}s. Output shape: {encoder_out.shape}"
554
+ ) # Shape: [2, S, E]
555
+
556
+ # --- Prepare Decoder Inputs & KV Cache ---
557
+ logger.info("Preparing decoder inputs and KV cache...")
558
+ start_time_kv = time.time()
559
+ # 3-1. Precompute Cross-Attention KV Cache (Static) from encoder output
560
+ # This cache is computed once and reused for every decoding step.
561
+ decoder_cross_attention_cache: list[KVCache] = (
562
+ self.model.decoder.precompute_cross_attention_kv(
563
+ effective_max_tokens, encoder_out, src_positions_BxS
564
+ )
565
+ )
566
+ logger.debug(
567
+ f"Precomputed cross-attention KV cache for {len(decoder_cross_attention_cache)} layers."
568
+ )
569
+
570
+ # 3-2. Initialize Self-Attention KV Cache (Dynamic, grows with each step)
571
+ decoder_self_attention_cache: list[KVCache] = []
572
+ for i in range(self.model.decoder.num_layers):
573
+ decoder_self_attention_cache.append(
574
+ KVCache(
575
+ self.config.model.decoder.gqa_query_heads,
576
+ effective_max_tokens, # Max length the cache can hold
577
+ self.config.model.decoder.gqa_head_dim,
578
+ self.target_device, # Cache tensors should be on the target device
579
+ )
580
+ )
581
+ logger.debug(
582
+ f"Initialized self-attention KV cache for {len(decoder_self_attention_cache)} layers."
583
+ )
584
+ logger.info(
585
+ f"KV cache preparation completed in {time.time() - start_time_kv:.3f}s."
586
+ )
587
+
588
+ # 3-3. Initialize Decoder Start Tokens (BOS)
589
+ # Shape [2, 1, C] (Batch=2 for cond/uncond, T=1 for first step, C=channels)
590
+ generated_tokens_history = torch.full(
591
+ (2, 1, num_channels),
592
+ fill_value=audio_bos_value,
593
+ dtype=torch.long,
594
+ device=self.target_device,
595
+ )
596
+ logger.debug(f"Initial decoder input (BOS): {generated_tokens_history.shape}")
597
+
598
+ current_step_index = (
599
+ 0 # Index of the step we are currently generating (starts at 0)
600
+ )
601
+ prompt_len_inc_bos = 1 # Length of the initial prompt (just BOS initially)
602
+
603
+ # 3-4. Handle Audio Prompt (Prefill KV Cache)
604
+ if audio_prompt_path is not None:
605
+ logger.info("Processing audio prompt for prefilling...")
606
+ start_time_prompt = time.time()
607
+ try:
608
+ # Load and potentially resample audio
609
+ audio_prompt_waveform, sr = torchaudio.load(audio_prompt_path)
610
+ logger.debug(
611
+ f"Loaded audio prompt: {audio_prompt_waveform.shape}, Sample Rate: {sr}"
612
+ )
613
+ if sr != 44100:
614
+ logger.info(f"Resampling audio prompt from {sr}Hz to 44100Hz")
615
+ audio_prompt_waveform = torchaudio.functional.resample(
616
+ audio_prompt_waveform, sr, 44100
617
+ )
618
+ # Ensure correct shape [B, C, T_audio] and device
619
+ # Assuming DAC expects channels first, add batch dim
620
+ if audio_prompt_waveform.ndim == 1: # Mono
621
+ audio_prompt_waveform = audio_prompt_waveform.unsqueeze(
622
+ 0
623
+ ) # Add channel dim
624
+ audio_prompt_waveform = audio_prompt_waveform.unsqueeze(0).to(
625
+ self.target_device
626
+ ) # Add batch dim
627
+
628
+ # Encode audio prompt to codes using DAC
629
+ logger.info("Encoding audio prompt to codes using DAC...")
630
+ if self.dac_model is None:
631
+ raise RuntimeError(
632
+ "DAC model not loaded, required for audio prompt."
633
+ )
634
+ # audio_to_codebook returns [B, T_codes, C]
635
+ audio_prompt_codes = audio_to_codebook(
636
+ self.dac_model, audio_prompt_waveform, data_config=self.config.data
637
+ ) # Shape [1, T_codes, C]
638
+ logger.info(
639
+ f"Encoded audio prompt to codes: {audio_prompt_codes.shape}"
640
+ )
641
+
642
+ # Concatenate BOS tokens with prompt codes
643
+ # Expand prompt codes to batch size 2 (for cond/uncond)
644
+ generated_tokens_history = torch.cat(
645
+ [generated_tokens_history, audio_prompt_codes.expand(2, -1, -1)],
646
+ dim=1,
647
+ ) # Shape [2, 1 + T_codes, C]
648
+ logger.debug(
649
+ f"Decoder input history after prompt concatenation: {generated_tokens_history.shape}"
650
+ )
651
+
652
+ prefill_len = generated_tokens_history.shape[
653
+ 1
654
+ ] # Length including BOS + prompt
655
+ prompt_len_inc_bos = prefill_len
656
+ logger.info(f"Prefilling KV cache with length {prefill_len}...")
657
+
658
+ # Prepare inputs for prefill forward pass
659
+ prefill_tgt_pos = (
660
+ torch.arange(prefill_len, device=self.target_device)
661
+ .unsqueeze(0)
662
+ .expand(2, -1)
663
+ ) # Shape [2, T_prefill]
664
+ # Padding mask based on actual tokens (BOS and prompt codes are not PAD)
665
+ # Shape [2, T_prefill] (True where not PAD)
666
+ prefill_tgt_padding_mask = (
667
+ generated_tokens_history != audio_pad_value
668
+ ).any(dim=2)
669
+
670
+ # Create attention masks for prefill
671
+ # Shape [2, 1, T_prefill, T_prefill]
672
+ prefill_self_attn_mask = self._create_attn_mask(
673
+ prefill_tgt_padding_mask,
674
+ prefill_tgt_padding_mask,
675
+ is_causal=True,
676
+ )
677
+ # Shape [2, 1, T_prefill, S]
678
+ prefill_cross_attn_mask = self._create_attn_mask(
679
+ prefill_tgt_padding_mask,
680
+ src_padding_mask_BxS,
681
+ is_causal=False,
682
+ )
683
+
684
+ # Run forward pass through decoder to fill the self-attention KV cache
685
+ # We discard the logits from prefill
686
+ _ = self.model.decoder.forward(
687
+ tgt_ids_BxTxC=generated_tokens_history, # Pass the full history [2, T_prefill, C]
688
+ encoder_out=encoder_out,
689
+ tgt_positions=prefill_tgt_pos,
690
+ src_positions=src_positions_BxS,
691
+ deterministic=True,
692
+ self_attn_mask=prefill_self_attn_mask,
693
+ cross_attn_mask=prefill_cross_attn_mask,
694
+ self_attention_cache=decoder_self_attention_cache, # Pass cache to be filled
695
+ cross_attention_cache=decoder_cross_attention_cache, # Pass precomputed cache
696
+ # prefill=True # Pass prefill flag if decoder layer uses it
697
+ )
698
+
699
+ # Update the current step index. The next token to generate is at index prefill_len.
700
+ current_step_index = prefill_len
701
+ logger.info(
702
+ f"KV cache prefilled in {time.time() - start_time_prompt:.3f}s. Next step index: {current_step_index}"
703
+ )
704
+
705
+ except Exception as e:
706
+ logger.error(f"Error processing audio prompt: {e}", exc_info=True)
707
+ raise RuntimeError("Failed to process audio prompt") from e
708
+
709
+ # --- Autoregressive Generation Loop ---
710
+ logger.info("Starting autoregressive generation loop...")
711
+ start_time_loop = time.time()
712
+
713
+ eos_detected_channel_0 = False
714
+ eos_countdown = -1 # Countdown after EOS detected in channel 0
715
+ extra_steps_after_eos = (
716
+ 30 # Generate a few extra steps for delay pattern completion
717
+ )
718
+
719
+ # Pre-allocate tensor for storing *newly* generated tokens for efficiency
720
+ # We already have the prompt in generated_tokens_history
721
+ num_steps_to_generate = effective_max_tokens
722
+ newly_generated_tokens = torch.full(
723
+ (2, num_steps_to_generate, num_channels),
724
+ fill_value=audio_pad_value, # Fill with pad initially
725
+ dtype=torch.long,
726
+ device=self.target_device,
727
+ )
728
+ logger.debug(
729
+ f"Allocated tensor for newly generated tokens: {newly_generated_tokens.shape}"
730
+ )
731
+
732
+ # --- Compile decode_step if requested ---
733
+ decode_step_fn = self.model.decoder.decode_step
734
+ if use_torch_compile:
735
+ logger.info("Compiling decoder step function with torch.compile...")
736
+ try:
737
+ # Experiment with modes: "default", "reduce-overhead", "max-autotune"
738
+ decode_step_fn = torch.compile(decode_step_fn, mode="reduce-overhead")
739
+ logger.info("Decoder step function compiled.")
740
+ except Exception as e:
741
+ logger.warning(
742
+ f"torch.compile failed: {e}. Using eager mode.", exc_info=True
743
+ )
744
+
745
+ # --- Prepare static cross-attention mask for single-step decoding ---
746
+ # Query mask is always [B, 1] (True, as generated tokens are not PAD)
747
+ step_tgt_padding_mask = torch.ones(
748
+ (2, 1), dtype=torch.bool, device=self.target_device
749
+ )
750
+ # Shape [2, 1, 1, S]
751
+ step_decoder_cross_attn_mask = self._create_attn_mask(
752
+ step_tgt_padding_mask,
753
+ src_padding_mask_BxS,
754
+ is_causal=False,
755
+ )
756
+
757
+ # --- Generation Loop ---
758
+ steps_taken = 0
759
+ for step_offset in range(num_steps_to_generate):
760
+ # Absolute step index considering prompt length
761
+ current_absolute_step = current_step_index + step_offset
762
+
763
+ # Get the token IDs for the *previous* step to predict the current one
764
+ # Shape [2, 1, C]
765
+ # If step_offset is 0, use the last token from the prompt history
766
+ if step_offset == 0:
767
+ input_token_ids = generated_tokens_history[:, -1, :].unsqueeze(1)
768
+ else:
769
+ # Use the token generated in the previous iteration of this loop
770
+ input_token_ids = newly_generated_tokens[
771
+ :, step_offset - 1, :
772
+ ].unsqueeze(1)
773
+
774
+ # Position ID for the current absolute step
775
+ # Shape [2, 1]
776
+ tgt_pos_Bx1 = torch.full(
777
+ (2, 1),
778
+ fill_value=current_absolute_step,
779
+ dtype=torch.long,
780
+ device=self.target_device,
781
+ )
782
+
783
+ # --- Call Decoder Step ---
784
+ # self_attn_mask is None because KV cache handles causality implicitly in single-step decoding
785
+ logits_Bx1xCxV, new_self_kv_cache_list = decode_step_fn(
786
+ tgt_ids_Bx1xC=input_token_ids,
787
+ tgt_pos_Bx1=tgt_pos_Bx1,
788
+ encoder_out=encoder_out,
789
+ self_attn_mask=None,
790
+ cross_attn_mask=step_decoder_cross_attn_mask,
791
+ self_attention_cache=decoder_self_attention_cache,
792
+ cross_attention_cache=decoder_cross_attention_cache,
793
+ ) # Logits shape: [2, 1, C, V]
794
+
795
+ # --- Update Self-Attention KV Cache ---
796
+ for i, layer_cache in enumerate(decoder_self_attention_cache):
797
+ if (
798
+ new_self_kv_cache_list
799
+ and i < len(new_self_kv_cache_list)
800
+ and new_self_kv_cache_list[i] is not None
801
+ ):
802
+ # new_self_kv_cache_list[i] is a tuple (k_tensor, v_tensor) for the current step
803
+ # k_tensor shape: [2, NumHeads, 1, HeadDim]
804
+ # v_tensor shape: [2, NumHeads, 1, HeadDim]
805
+ layer_cache.update_cache(
806
+ new_self_kv_cache_list[i][0], new_self_kv_cache_list[i][1]
807
+ )
808
+ else:
809
+ logger.warning(
810
+ f"Missing KV cache update for layer {i} at step {current_absolute_step}"
811
+ )
812
+
813
+ # --- Sampling ---
814
+ V = self.config.model.tgt_vocab_size
815
+ # Get logits for the generated step [2, C, V]
816
+ logits_last_BxCxV = logits_Bx1xCxV.squeeze(1)
817
+
818
+ # Separate conditional and unconditional logits
819
+ uncond_logits_CxV = logits_last_BxCxV[0, :, :] # Shape [C, V]
820
+ cond_logits_CxV = logits_last_BxCxV[1, :, :] # Shape [C, V]
821
+
822
+ # Apply Classifier-Free Guidance (CFG)
823
+ cfg_logits_CxV = cond_logits_CxV + cfg_scale * (
824
+ cond_logits_CxV - uncond_logits_CxV
825
+ ) # Shape [C, V]
826
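+ # Note: cond + cfg_scale * (cond - uncond) == uncond + (1 + cfg_scale) * (cond - uncond):
+ # the logits are extrapolated past the unconditional prediction in the
+ # direction suggested by the text; cfg_scale = 0 recovers the plain
+ # conditional logits.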
+
827
+ # --- Prevent sampling PAD/EOS/BOS tokens inappropriately ---
828
+ logits_for_sampling_CxV = (
829
+ cfg_logits_CxV.clone()
830
+ ) # Clone to avoid modifying original logits
831
+ logits_for_sampling_CxV[:, audio_pad_value] = -torch.inf # Never sample PAD
832
+ logits_for_sampling_CxV[:, audio_bos_value] = (
833
+ -torch.inf
834
+ ) # Never sample BOS after start
835
+ # Allow EOS only if not already detected or in countdown
836
+ if eos_detected_channel_0 and eos_countdown <= 0:
837
+ logits_for_sampling_CxV[:, audio_eos_value] = -torch.inf
838
+
839
+ # --- Sample the next token for each channel ---
840
+ pred_C = _sample_next_token(
841
+ logits_for_sampling_CxV.float(), # Ensure float32 for sampling stability
842
+ temperature=temperature,
843
+ top_p=top_p,
844
+ use_cfg_filter=use_cfg_filter,
845
+ cfg_filter_top_k=cfg_filter_top_k,
846
+ ) # Shape [C]
847
+
848
+ # --- Handle Delay Pattern (Only if no audio prompt was given) ---
849
+ # If there's no prompt, the first few tokens should be BOS according to delay
850
+ # generation_step_index is how many tokens generated *after* prompt/initial BOS
851
+ generation_step_index = step_offset
852
+ if audio_prompt_path is None:
853
+ is_before_delay = generation_step_index < delay_tensor # Shape [C]
854
+ pred_C = torch.where(
855
+ is_before_delay,
856
+ torch.tensor(
857
+ audio_bos_value, device=self.target_device, dtype=torch.long
858
+ ),
859
+ pred_C,
860
+ )
861
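+ # Rationale (inferred from the decoding path): channel c is pinned to BOS
+ # for its first delay_pattern[c] steps so that, when codebook_to_audio
+ # later reverses the delay pattern, all channels align on the same frame.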
+
862
+ # --- Store the predicted token in the newly_generated_tokens tensor ---
863
+ newly_generated_tokens[:, step_offset, :] = pred_C.unsqueeze(0).expand(
864
+ 2, -1
865
+ )
866
+
867
+ steps_taken += 1 # Increment steps taken in this loop
868
+
869
+ # --- EOS Handling ---
870
+ if not eos_detected_channel_0 and pred_C[0] == audio_eos_value:
871
+ logger.info(
872
+ f"EOS token detected in channel 0 at step {current_absolute_step}. Starting countdown."
873
+ )
874
+ eos_detected_channel_0 = True
875
+ eos_countdown = extra_steps_after_eos
876
+
877
+ if eos_countdown > 0:
878
+ step_after_eos = extra_steps_after_eos - eos_countdown
879
+ logger.debug(
880
+ f"EOS countdown: {eos_countdown}, Step after EOS: {step_after_eos}"
881
+ )
882
+ # Modify the token *just generated* if needed for EOS/PAD forcing
883
+ current_new_tokens = newly_generated_tokens[
884
+ :, step_offset, :
885
+ ] # Shape [2, C]
886
+ for i, d in enumerate(delay_pattern):
887
+ if step_after_eos == d:
888
+ logger.debug(
889
+ f" Forcing EOS in channel {i} at step {current_absolute_step}"
890
+ )
891
+ current_new_tokens[:, i] = audio_eos_value
892
+ elif step_after_eos > d:
893
+ logger.debug(
894
+ f" Forcing PAD in channel {i} at step {current_absolute_step}"
895
+ )
896
+ current_new_tokens[:, i] = audio_pad_value
897
+ # Put the potentially modified tokens back
898
+ newly_generated_tokens[:, step_offset, :] = current_new_tokens
899
+
900
+ eos_countdown -= 1
901
+ if eos_countdown == 0:
902
+ logger.info(
903
+ f"EOS countdown finished at step {current_absolute_step}. Stopping generation."
904
+ )
905
+ break # Stop generation loop
906
+
907
+ # Check if we reached the max *new* tokens requested
908
+ if steps_taken >= num_steps_to_generate:
909
+ logger.info(
910
+ f"Reached max generation steps ({num_steps_to_generate}). Stopping."
911
+ )
912
+ break
913
+
914
+ logger.info(
915
+ f"Autoregressive loop finished after {steps_taken} steps in {time.time() - start_time_loop:.3f}s."
916
+ )
917
+
918
+ # --- Extract Generated Codes ---
919
+ # Get the conditional generation result (index 1) from the *newly* generated tokens
920
+ # Only take the number of steps actually taken
921
+ final_new_codes = newly_generated_tokens[
922
+ 1, :steps_taken, :
923
+ ] # Shape [T_generated, C]
924
+ logger.info(f"Extracted newly generated codes shape: {final_new_codes.shape}")
925
+
926
+ # --- Convert Codes to Audio using DAC ---
927
+ logger.info("Converting generated codes to audio using DAC...")
928
+ start_time_decode = time.time()
929
+ if self.dac_model is None:
930
+ raise RuntimeError("DAC model not loaded, required for audio decoding.")
931
+
932
+ # codebook_to_audio expects codes shape [C, T]
933
+ generated_codes_CxT = final_new_codes.transpose(0, 1) # Shape [C, T_generated]
934
+
935
+ if generated_codes_CxT.numel() == 0:
936
+ logger.warning("No new codes were generated. Returning empty audio.")
937
+ return np.array([], dtype=np.float32)
938
+
939
+ # Call the decoding function (handles delay reversal and DAC decoding)
940
+ audio_waveform = codebook_to_audio(
941
+ generated_codes_CxT,
942
+ self.dac_model,
943
+ delay_pattern,
944
+ B=1, # Batch size for decoding is 1
945
+ T=generated_codes_CxT.shape[1], # Pass the actual length of generated codes
946
+ C=num_channels,
947
+ ) # Returns shape [1, T_audio] or [T_audio]
948
+
949
+ # Ensure output is a 1D numpy array on CPU
950
+ final_audio_np = audio_waveform.squeeze().cpu().numpy()
951
+ logger.info(
952
+ f"Audio decoding completed in {time.time() - start_time_decode:.3f}s. Output shape: {final_audio_np.shape}"
953
+ )
954
+ logger.info(f"Total generation time: {time.time() - start_time_gen:.3f}s")
955
+
956
+ return final_audio_np
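+
+ # Example usage (sketch; file paths are placeholders, not shipped defaults):
+ # model = Dia.load_model_from_files("config.json", "model.safetensors",
+ #                                   device=torch.device("cuda"))
+ # audio = model.generate("[S1] Hello! [S2] Hi there. (laughs)")  # 1D float32 waveform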
docker-compose.yml ADDED
@@ -0,0 +1,23 @@
1
+ version: '3.8'
2
+
3
+ services:
4
+ dia-tts-server:
5
+ build:
6
+ context: .
7
+ dockerfile: Dockerfile
8
+ ports:
9
+ - "${PORT:-8003}:${PORT:-8003}"
10
+ volumes:
11
+ - ./model_cache:/app/model_cache
12
+ - ./reference_audio:/app/reference_audio
13
+ - ./outputs:/app/outputs
14
+ deploy:
15
+ resources:
16
+ reservations:
17
+ devices:
18
+ - driver: nvidia
19
+ count: 1
20
+ capabilities: [gpu]
21
+ restart: unless-stopped
22
+ env_file:
23
+ - .env
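+
+ # Assuming Docker Engine with the NVIDIA Container Toolkit is installed,
+ # the stack can be built and started with: docker compose up --build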
documentation.md ADDED
@@ -0,0 +1,549 @@
1
+ # Dia TTS Server - Technical Documentation
2
+
3
+ **Version:** 1.0.0
4
+ **Date:** 2025-04-22
5
+
6
+ **Table of Contents:**
7
+
8
+ 1. [Overview](#1-overview)
9
+ 2. [Visual Overview](#2-visual-overview)
10
+ * [Directory Structure](#21-directory-structure)
11
+ * [Component Diagram](#22-component-diagram)
12
+ 3. [System Prerequisites](#3-system-prerequisites)
13
+ 4. [Installation and Setup](#4-installation-and-setup)
14
+ * [Cloning the Repository](#41-cloning-the-repository)
15
+ * [Setting up Python Virtual Environment](#42-setting-up-python-virtual-environment)
16
+ * [Windows Setup](#421-windows-setup)
17
+ * [Linux Setup (Debian/Ubuntu Example)](#422-linux-setup-debianubuntu-example)
18
+ * [Installing Dependencies](#43-installing-dependencies)
19
+ * [NVIDIA Driver and CUDA Setup (Required for GPU Acceleration)](#44-nvidia-driver-and-cuda-setup-required-for-gpu-acceleration)
20
+ * [Step 1: Check/Install NVIDIA Drivers](#441-step-1-checkinstall-nvidia-drivers)
21
+ * [Step 2: Install PyTorch with CUDA Support](#442-step-2-install-pytorch-with-cuda-support)
22
+ * [Step 3: Verify PyTorch CUDA Installation](#443-step-3-verify-pytorch-cuda-installation)
23
+ 5. [Configuration](#5-configuration)
24
+ * [Configuration Files (`.env` and `config.py`)](#51-configuration-files-env-and-configpy)
25
+ * [Configuration Parameters](#52-configuration-parameters)
26
+ 6. [Running the Server](#6-running-the-server)
27
+ 7. [Usage](#7-usage)
28
+ * [Web User Interface (Web UI)](#71-web-user-interface-web-ui)
29
+ * [Main Generation Form](#711-main-generation-form)
30
+ * [Presets](#712-presets)
31
+ * [Voice Cloning](#713-voice-cloning)
32
+ * [Generation Parameters](#714-generation-parameters)
33
+ * [Server Configuration (UI)](#715-server-configuration-ui)
34
+ * [Generated Audio Player](#716-generated-audio-player)
35
+ * [Theme Toggle](#717-theme-toggle)
36
+ * [API Endpoints](#72-api-endpoints)
37
+ * [POST /v1/audio/speech (OpenAI Compatible)](#721-post-v1audiospeech-openai-compatible)
38
+ * [POST /tts (Custom Parameters)](#722-post-tts-custom-parameters)
39
+ * [Configuration & Helper Endpoints](#723-configuration--helper-endpoints)
40
+ 8. [Troubleshooting](#8-troubleshooting)
41
+ 9. [Project Architecture](#9-project-architecture)
42
+ 10. [License and Disclaimer](#10-license-and-disclaimer)
43
+
44
+ ---
45
+
46
+ ## 1. Overview
47
+
48
+ The Dia TTS Server provides a backend service and web interface for generating high-fidelity speech, including dialogue with multiple speakers and non-verbal sounds, using the Dia text-to-speech model family (originally from Nari Labs, with support for community conversions like SafeTensors).
49
+
50
+ This server is built using the FastAPI framework and offers both a RESTful API (including an OpenAI-compatible endpoint) and an interactive web UI powered by Jinja2, Tailwind CSS, and JavaScript. It supports voice cloning via audio prompts and allows configuration of various generation parameters.
51
+
52
+ **Key Features:**
53
+
54
+ * **High-Quality TTS:** Leverages the Dia model for realistic speech synthesis.
55
+ * **Dialogue Generation:** Supports `[S1]` and `[S2]` tags for multi-speaker dialogue.
56
+ * **Non-Verbal Sounds:** Can generate sounds like `(laughs)`, `(sighs)`, etc., when included in the text.
57
+ * **Voice Cloning:** Allows conditioning the output voice on a provided reference audio file.
58
+ * **Flexible Model Loading:** Supports loading models from Hugging Face repositories, including both `.pth` and `.safetensors` formats (defaults to BF16 SafeTensors for efficiency).
59
+ * **API Access:** Provides a custom API endpoint (`/tts`) and an OpenAI-compatible endpoint (`/v1/audio/speech`); see the request sketch after this list.
60
+ * **Web Interface:** Offers an easy-to-use UI for text input, parameter adjustment, preset loading, reference audio management, and audio playback.
61
+ * **Configuration:** Server settings, model sources, paths, and default generation parameters are configurable via an `.env` file.
62
+ * **GPU Acceleration:** Utilizes NVIDIA GPUs via CUDA for significantly faster inference when available, falling back to CPU otherwise.
63
+
64
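+ As a quick orientation, a call to the OpenAI-compatible endpoint might look like the sketch below. This is illustrative only: the exact request schema is defined in `models.py`, and the `input` and `response_format` fields shown here are assumed from the OpenAI speech API shape rather than verified against it.
+
+ ```bash
+ curl -X POST http://localhost:8003/v1/audio/speech \
+   -H "Content-Type: application/json" \
+   -d '{"input": "[S1] Hello there! [S2] Hi! (laughs)", "response_format": "wav"}' \
+   --output speech.wav
+ ```
+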
+ ---
65
+
66
+ ## 2. Visual Overview
67
+
68
+ ### 2.1 Directory Structure
69
+
70
+ ```
71
+ dia-tts-server/
72
+
73
+ ├── .env # Local configuration overrides (user-created)
74
+ ├── config.py # Default configuration and management class
75
+ ├── engine.py # Core model loading and generation logic
76
+ ├── models.py # Pydantic models for API requests
77
+ ├── requirements.txt # Python dependencies
78
+ ├── server.py # Main FastAPI application, API endpoints, UI routes
79
+ ├── utils.py # Utility functions (audio encoding, saving, etc.)
80
+
81
+ ├── dia/ # Core Dia model implementation package
82
+ │ ├── __init__.py
83
+ │ ├── audio.py # Audio processing helpers (delay, codebook conversion)
84
+ │ ├── config.py # Pydantic models for Dia model architecture config
85
+ │ ├── layers.py # Custom PyTorch layers for the Dia model
86
+ │ └── model.py # Dia model class wrapper (loading, generation)
87
+
88
+ ├── static/ # Static assets (e.g., favicon.ico)
89
+ │ └── favicon.ico
90
+
91
+ ├── ui/ # Web User Interface files
92
+ │ ├── index.html # Main HTML template (Jinja2)
93
+ │ ├── presets.yaml # Predefined UI examples
94
+ │ ├── script.js # Frontend JavaScript logic
95
+ │ └── style.css # Frontend CSS styling (Tailwind via CDN/build)
96
+
97
+ ├── model_cache/ # Default directory for downloaded model files (configurable)
98
+ ├── outputs/ # Default directory for saved audio output (configurable)
99
+ └── reference_audio/ # Default directory for voice cloning reference files (configurable)
100
+ ```
101
+
102
+ ### 2.2 Component Diagram
103
+
104
+ ```
105
+ ┌───────────────────┐ ┌───────────────────┐ ┌───────────────────┐ ┌───────────────────┐
106
+ │ User (Web UI / │────→ │ FastAPI Server │────→ │ TTS Engine │────→ │ Dia Model Wrapper │
107
+ │ API Client) │ │ (server.py) │ │ (engine.py) │ │ (dia/model.py) │
108
+ └───────────────────┘ └─────────┬─────────┘ └─────────┬─────────┘ └─────────┬─────────┘
109
+ │ │ │
110
+ │ Uses │ Uses │ Uses
111
+ ▼ ▼ ▼
112
+ ┌───────────────────┐ ┌───────────────────┐ ┌───────────────────┐
113
+ │ Configuration │ ←─── │ .env File │ │ Dia Model Layers │
114
+ │ (config.py) │ └───────────────────┘ │ (dia/layers.py) │
115
+ └───────────────────┘ └───────────────────┘
116
+ │ │ Uses
117
+ │ Uses │
118
+ ▼ │
119
+ ┌───────────────────┐ │ Uses
120
+ │ Utilities │ ▼
121
+ │ (utils.py) │ ┌───────────────────┐
122
+ └───────────────────┘ │ PyTorch / CUDA │
123
+ ▲ └───────────────────┘
124
+ │ Uses │ Uses
125
+ │ ▼
126
+ ┌───────────────────┐ ┌───────────────────┐ ┌───────────────────┐
127
+ │ Web UI Files │ ←─── │ Jinja2 Templates │ │ DAC Model │
128
+ │ (ui/) │ └───────────────────┘ │ (descript-audio..)│
129
+ └───────────────────┘ ▲ └───────────────────┘
130
+ │ Renders ▲
131
+ │ │ Uses
132
+ └────────────────────────────────────────────────┘
133
+ ```
134
+
135
+ **Diagram Legend:**
136
+
137
+ * Boxes represent major components or file groups.
138
+ * Arrows (`→`) indicate primary data flow or control flow.
139
+ * Lines with "Uses" indicate dependencies or function calls.
140
+
141
+ ---
142
+
143
+ ## 3. System Prerequisites
144
+
145
+ Before installing and running the Dia TTS Server, ensure your system meets the following requirements:
146
+
147
+ * **Operating System:**
148
+ * Windows 10/11 (64-bit)
149
+ * Linux (Debian/Ubuntu recommended, other distributions may require adjustments)
150
+ * **Python:** Python 3.10 or later (3.10.x recommended). Ensure Python and Pip are added to your system's PATH.
151
+ * **Version Control:** Git (for cloning the repository).
152
+ * **Internet Connection:** Required for downloading dependencies and model files.
153
+ * **(Optional but Highly Recommended for Performance):**
154
+ * **NVIDIA GPU:** A CUDA-compatible NVIDIA GPU (Maxwell architecture or newer). Check compatibility [here](https://developer.nvidia.com/cuda-gpus). Sufficient VRAM is needed (BF16 model requires ~5-6GB, full precision ~10GB).
155
+ * **NVIDIA Drivers:** Latest appropriate drivers for your GPU and OS.
156
+ * **CUDA Toolkit:** Version compatible with the chosen PyTorch build (e.g., 11.8, 12.1). See [Section 4.4](#44-nvidia-driver-and-cuda-setup-required-for-gpu-acceleration).
157
+ * **(Linux System Libraries):**
158
+ * `libsndfile1`: Required by the `soundfile` Python library for audio I/O. Install using your package manager (e.g., `sudo apt install libsndfile1` on Debian/Ubuntu).
159
+
160
+ ---
161
+
162
+ ## 4. Installation and Setup
163
+
164
+ Follow these steps to set up the project environment and install necessary dependencies.
165
+
166
+ ### 4.1 Cloning the Repository
167
+
168
+ Open your terminal or command prompt and navigate to the directory where you want to store the project. Then, clone the repository:
169
+
170
+ ```bash
171
+ git clone https://github.com/devnen/dia-tts-server.git # Replace with the actual repo URL if different
172
+ cd dia-tts-server
173
+ ```
174
+
175
+ ### 4.2 Setting up Python Virtual Environment
176
+
177
+ Using a virtual environment is strongly recommended to isolate project dependencies.
178
+
179
+ #### 4.2.1 Windows Setup
180
+
181
+ 1. **Open PowerShell or Command Prompt** in the project directory (`dia-tts-server`).
182
+ 2. **Create the virtual environment:**
183
+ ```powershell
184
+ python -m venv venv
185
+ ```
186
+ 3. **Activate the virtual environment:**
187
+ ```powershell
188
+ .\venv\Scripts\activate
189
+ ```
190
+ Your terminal prompt should now be prefixed with `(venv)`.
191
+
192
+ #### 4.2.2 Linux Setup (Debian/Ubuntu Example)
193
+
194
+ 1. **Install prerequisites (if not already present):**
195
+ ```bash
196
+ sudo apt update
197
+ sudo apt install python3 python3-venv python3-pip libsndfile1 -y
198
+ ```
199
+ 2. **Open your terminal** in the project directory (`dia-tts-server`).
200
+ 3. **Create the virtual environment:**
201
+ ```bash
202
+ python3 -m venv venv
203
+ ```
204
+ 4. **Activate the virtual environment:**
205
+ ```bash
206
+ source venv/bin/activate
207
+ ```
208
+ Your terminal prompt should now be prefixed with `(venv)`.
209
+
210
+ ### 4.3 Installing Dependencies
211
+
212
+ With your virtual environment activated (`(venv)` prefix visible), install the required Python packages:
213
+
214
+ ```bash
215
+ # Upgrade pip first (optional but good practice)
216
+ pip install --upgrade pip
217
+
218
+ # Install all dependencies from requirements.txt
219
+ pip install -r requirements.txt
220
+ ```
221
+
222
+ **Note:** Depending on your platform, this command may install a CPU-only build of PyTorch (the default wheel on Windows). If you have a compatible NVIDIA GPU and want acceleration, proceed to [Section 4.4](#44-nvidia-driver-and-cuda-setup-required-for-gpu-acceleration) **before** running the server.
223
+
224
+ ### 4.4 NVIDIA Driver and CUDA Setup (Required for GPU Acceleration)
225
+
226
+ Follow these steps **only if you have a compatible NVIDIA GPU** and want faster inference.
227
+
228
+ #### 4.4.1 Step 1: Check/Install NVIDIA Drivers
229
+
230
+ 1. **Check Existing Driver:** Open Command Prompt (Windows) or Terminal (Linux) and run:
231
+ ```bash
232
+ nvidia-smi
233
+ ```
234
+ 2. **Interpret Output:**
235
+ * If the command runs successfully, note the **Driver Version** and the **CUDA Version** listed in the top right corner. This CUDA version is the *maximum* supported by your current driver.
236
+ * If the command fails ("not recognized"), you need to install or update your NVIDIA drivers.
237
+ 3. **Install/Update Drivers:** Go to the [NVIDIA Driver Downloads](https://www.nvidia.com/Download/index.aspx) page. Select your GPU model and OS, then download and install the latest recommended driver (Game Ready or Studio). **Reboot your computer** after installation. Run `nvidia-smi` again to confirm it works.
238
+
239
+ #### 4.4.2 Step 2: Install PyTorch with CUDA Support
240
+
241
+ 1. **Go to PyTorch Website:** Visit [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/).
242
+ 2. **Configure:** Select:
243
+ * **PyTorch Build:** Stable
244
+ * **Your OS:** Windows or Linux
245
+ * **Package:** Pip
246
+ * **Language:** Python
247
+ * **Compute Platform:** Choose the CUDA version **equal to or lower than** the version reported by `nvidia-smi`. For example, if `nvidia-smi` shows `CUDA Version: 12.4`, select `CUDA 12.1`. If it shows `11.8`, select `CUDA 11.8`. **Do not select a version higher than your driver supports.** (CUDA 12.1 or 11.8 are common stable choices).
248
+ 3. **Copy Command:** Copy the generated installation command. It will look similar to:
249
+ ```bash
250
+ # Example for CUDA 12.1 (Windows/Linux):
251
+ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
252
+ # Example for CUDA 11.8 (Windows/Linux):
253
+ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
254
+ ```
255
+ *(Use `pip` instead of `pip3` if that's your command)*
256
+ 4. **Install in Activated venv:**
257
+ * Ensure your `(venv)` is active.
258
+ * **Uninstall CPU PyTorch first:**
259
+ ```bash
260
+ pip uninstall torch torchvision torchaudio -y
261
+ ```
262
+ * **Paste and run the copied command** from the PyTorch website.
263
+
264
+ #### 4.4.3 Step 3: Verify PyTorch CUDA Installation
265
+
266
+ 1. With the `(venv)` still active, start a Python interpreter:
267
+ ```bash
268
+ python
269
+ ```
270
+ 2. Run the following Python code:
271
+ ```python
272
+ import torch
273
+ print(f"PyTorch version: {torch.__version__}")
274
+ cuda_available = torch.cuda.is_available()
275
+ print(f"CUDA available: {cuda_available}")
276
+ if cuda_available:
277
+ print(f"CUDA version used by PyTorch: {torch.version.cuda}")
278
+ print(f"Device count: {torch.cuda.device_count()}")
279
+ print(f"Current device index: {torch.cuda.current_device()}")
280
+ print(f"Device name: {torch.cuda.get_device_name(torch.cuda.current_device())}")
281
+ else:
282
+ print("CUDA not available to PyTorch. Ensure drivers and CUDA-enabled PyTorch are installed correctly.")
283
+ exit()
284
+ ```
285
+ 3. If `CUDA available:` shows `True`, the setup was successful. If `False`, review driver installation and the PyTorch installation command.
286
+
287
+ ---
288
+
289
+ ## 5. Configuration
290
+
291
+ The server's behavior, including model selection, paths, and default generation parameters, is controlled via configuration settings.
292
+
293
+ ### 5.1 Configuration Files (`.env` and `config.py`)
294
+
295
+ * **`config.py`:** Defines the *default* values for all configuration parameters in the `DEFAULT_CONFIG` dictionary. It also contains the `ConfigManager` class and getter functions used by the application.
296
+ * **`.env` File:** This file, located in the project root directory (`dia-tts-server/.env`), allows you to *override* the default values. Create this file if it doesn't exist. Settings are defined as `KEY=VALUE` pairs, one per line. The server reads this file on startup using `python-dotenv`.
297
+
298
+ **Priority:** Values set in the `.env` file take precedence over the defaults in `config.py`. Environment variables set directly in your system also override `.env` file values (though using `.env` is generally recommended for project-specific settings).
299
+
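+ In application code, these settings are read through the getter functions exposed by `config.py`. A minimal sketch (the getter names below are the ones imported by `server.py` and `engine.py` elsewhere in this document):
+
+ ```python
+ # Sketch: reading the effective configuration (defaults merged with .env overrides).
+ from config import get_host, get_port, get_model_repo_id, get_model_weights_filename
+
+ print(f"Server address: {get_host()}:{get_port()}")
+ print(f"Model source: {get_model_repo_id()} / {get_model_weights_filename()}")
+ ```
+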
300
+ ### 5.2 Configuration Parameters
301
+
302
+ The following parameters can be set in your `.env` file:
303
+
304
+ | Parameter Name (in `.env`) | Default Value (`config.py`) | Description | Example `.env` Value |
305
+ | :--------------------------------- | :--------------------------------- | :--------------------------------------------------------------------------------------------------------- | :----------------------------------- |
306
+ | **Server Settings** | | | |
307
+ | `HOST` | `0.0.0.0` | The network interface address the server listens on. `0.0.0.0` makes it accessible on your local network. | `127.0.0.1` (localhost only) |
308
+ | `PORT` | `8003` | The port number the server listens on. | `8080` |
309
+ | **Model Source Settings** | | | |
310
+ | `DIA_MODEL_REPO_ID` | `ttj/dia-1.6b-safetensors` | The Hugging Face repository ID containing the model files. | `nari-labs/Dia-1.6B` |
311
+ | `DIA_MODEL_CONFIG_FILENAME` | `config.json` | The filename of the model's configuration JSON within the repository. | `config.json` |
312
+ | `DIA_MODEL_WEIGHTS_FILENAME` | `dia-v0_1_bf16.safetensors` | The filename of the model weights file (`.safetensors` or `.pth`) within the repository to load. | `dia-v0_1.safetensors` or `dia-v0_1.pth` |
313
+ | **Path Settings** | | | |
314
+ | `DIA_MODEL_CACHE_PATH` | `./model_cache` | Local directory to store downloaded model files. Relative paths are based on the project root. | `/path/to/shared/cache` |
315
+ | `REFERENCE_AUDIO_PATH` | `./reference_audio` | Local directory to store reference audio files (`.wav`, `.mp3`) used for voice cloning. | `./voices` |
316
+ | `OUTPUT_PATH` | `./outputs` | Local directory where generated audio files from the Web UI are saved. | `./generated_speech` |
317
+ | **Default Generation Parameters** | | *(These set the initial UI values and can be saved via the UI)* | |
318
+ | `GEN_DEFAULT_SPEED_FACTOR` | `0.90` | Default playback speed factor applied *after* generation (UI slider initial value). | `1.0` |
319
+ | `GEN_DEFAULT_CFG_SCALE` | `3.0` | Default Classifier-Free Guidance scale (UI slider initial value). | `2.5` |
320
+ | `GEN_DEFAULT_TEMPERATURE` | `1.3` | Default sampling temperature (UI slider initial value). | `1.2` |
321
+ | `GEN_DEFAULT_TOP_P` | `0.95` | Default nucleus sampling probability (UI slider initial value). | `0.9` |
322
+ | `GEN_DEFAULT_CFG_FILTER_TOP_K` | `35` | Default Top-K value for CFG filtering (UI slider initial value). | `40` |
323
+
324
+ **Example `.env` File (Using Original Nari Labs Model):**
325
+
326
+ ```dotenv
327
+ # .env
328
+ # Example configuration to use the original Nari Labs model
329
+
330
+ HOST=0.0.0.0
331
+ PORT=8003
332
+
333
+ DIA_MODEL_REPO_ID=nari-labs/Dia-1.6B
334
+ DIA_MODEL_CONFIG_FILENAME=config.json
335
+ DIA_MODEL_WEIGHTS_FILENAME=dia-v0_1.pth
336
+
337
+ # Keep other paths as default or specify custom ones
338
+ # DIA_MODEL_CACHE_PATH=./model_cache
339
+ # REFERENCE_AUDIO_PATH=./reference_audio
340
+ # OUTPUT_PATH=./outputs
341
+
342
+ # Keep default generation parameters or override them
343
+ # GEN_DEFAULT_SPEED_FACTOR=0.90
344
+ # GEN_DEFAULT_CFG_SCALE=3.0
345
+ # GEN_DEFAULT_TEMPERATURE=1.3
346
+ # GEN_DEFAULT_TOP_P=0.95
347
+ # GEN_DEFAULT_CFG_FILTER_TOP_K=35
348
+ ```
349
+
350
+ **Important:** You must **restart the server** after making changes to the `.env` file for them to take effect.
351
+
352
+ ---
353
+
354
+ ## 6. Running the Server
355
+
356
+ 1. **Activate Virtual Environment:** Ensure your virtual environment is activated (`(venv)` prefix).
357
+ * Windows: `.\venv\Scripts\activate`
358
+ * Linux: `source venv/bin/activate`
359
+ 2. **Navigate to Project Root:** Make sure your terminal is in the `dia-tts-server` directory.
360
+ 3. **Run the Server:**
361
+ ```bash
362
+ python server.py
363
+ ```
364
+ 4. **Server Output:** You should see log messages indicating the server is starting, including:
365
+ * The configuration being used (repo ID, filenames, paths).
366
+ * The device being used (CPU or CUDA).
367
+ * Model loading progress (downloading if necessary).
368
+ * Confirmation that the server is running (e.g., `Uvicorn running on http://0.0.0.0:8003`).
369
+ * URLs for accessing the Web UI and API Docs.
370
+
371
+ 5. **Accessing the Server:**
372
+ * **Web UI:** Open your web browser and go to `http://localhost:PORT` (e.g., `http://localhost:8003` if using the default port). If running on a different machine or VM, replace `localhost` with the server's IP address.
373
+ * **API Docs:** Access the interactive API documentation (Swagger UI) at `http://localhost:PORT/docs`.
374
+ 6. **Stopping the Server:** Press `CTRL+C` in the terminal where the server is running.
375
+
376
+ **Auto-Reload:** The server is configured to run with `reload=True`. This means Uvicorn will automatically restart the server if it detects changes in `.py`, `.html`, `.css`, `.js`, `.env`, or `.yaml` files within the project or `ui` directory. This is useful for development but should generally be disabled in production.
377
+
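+ For reference, the launch logic implied by this behavior presumably resembles the following `uvicorn.run` call (a sketch, not the exact code in `server.py`; `"server:app"` assumes the FastAPI instance is named `app`):
+
+ ```python
+ # Sketch: how server.py likely hands control to Uvicorn with auto-reload enabled.
+ import uvicorn
+ from config import get_host, get_port
+
+ if __name__ == "__main__":
+     uvicorn.run("server:app", host=get_host(), port=get_port(), reload=True)
+ ```
+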
378
+ ---
379
+
380
+ ## 7. Usage
381
+
382
+ The Dia TTS Server can be used via its Web UI or its API endpoints.
383
+
384
+ ### 7.1 Web User Interface (Web UI)
385
+
386
+ Access the UI by navigating to the server's base URL (e.g., `http://localhost:8003`).
387
+
388
+ #### 7.1.1 Main Generation Form
389
+
390
+ * **Text to speak:** Enter the text you want to synthesize (see the example after this list).
391
+ * Use `[S1]` and `[S2]` tags to indicate speaker turns for dialogue.
392
+ * Include non-verbal cues like `(laughs)`, `(sighs)`, `(clears throat)` directly in the text where desired.
393
+ * For voice cloning, **prepend the exact transcript** of the selected reference audio before the text you want generated (e.g., `[S1] Reference transcript text. [S1] This is the new text to generate in the cloned voice.`).
394
+ * **Voice Mode:** Select the desired generation mode:
395
+ * **Single / Dialogue (Use [S1]/[S2]):** Use this for single-speaker text (you can use `[S1]` or omit tags if the model handles it) or multi-speaker dialogue (using `[S1]` and `[S2]`).
396
+ * **Voice Clone (from Reference):** Enables voice cloning based on a selected audio file. Requires selecting a file below and prepending its transcript to the text input.
397
+ * **Generate Speech Button:** Submits the text and settings to the server to start generation.
398
+
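+ As an illustration, a dialogue input combining speaker tags and non-verbal cues might look like this (hypothetical sample text):
+
+ ```
+ [S1] Did you catch the demo this morning? [S2] I did. (laughs) The cloned voice was uncanny. [S1] (sighs) We still need to tune the guidance scale, though.
+ ```
+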
399
+ #### 7.1.2 Presets
400
+
401
+ * Located below the Voice Mode selection.
402
+ * Clicking a preset button (e.g., "Standard Dialogue", "Expressive Narration") will automatically populate the "Text to speak" area and the "Generation Parameters" sliders with predefined values, demonstrating different use cases.
403
+
404
+ #### 7.1.3 Voice Cloning
405
+
406
+ * This section appears only when "Voice Clone" mode is selected.
407
+ * **Reference Audio File Dropdown:** Lists available `.wav` and `.mp3` files found in the configured `REFERENCE_AUDIO_PATH`. Select the file whose voice you want to clone. Remember to prepend its transcript to the main text input.
408
+ * **Load Button:** Click this to open your system's file browser. You can select one or more `.wav` or `.mp3` files to upload. The selected files will be copied to the server's `REFERENCE_AUDIO_PATH`, and the dropdown list will refresh automatically. The first newly uploaded file will be selected in the dropdown.
409
+
410
+ #### 7.1.4 Generation Parameters
411
+
412
+ * Expand this section to fine-tune the generation process. These values correspond to the parameters used by the underlying Dia model.
413
+ * **Sliders:** Adjust Speed Factor, CFG Scale, Temperature, Top P, and CFG Filter Top K. The current value is displayed next to the label.
414
+ * **Save Generation Defaults Button:** Saves the *current* values of these sliders to the `.env` file (as `GEN_DEFAULT_...` keys). These saved values will become the default settings loaded into the UI the next time the server starts.
415
+
416
+ #### 7.1.5 Server Configuration (UI)
417
+
418
+ * Expand this section to view and modify server-level settings stored in the `.env` file.
419
+ * **Fields:** Edit Model Repo ID, Config/Weights Filenames, Cache/Reference/Output Paths, Host, and Port.
420
+ * **Save Server Configuration Button:** Saves the values currently shown in these fields to the `.env` file. **A server restart is required** for most of these changes (especially model source or paths) to take effect.
421
+ * **Restart Server Button:** (Appears after saving) Attempts to trigger a server restart. This works best if the server was started with `reload=True` or is managed by a process manager like systemd or Supervisor.
422
+
423
+ #### 7.1.6 Generated Audio Player
424
+
425
+ * Appears below the main form after a successful generation.
426
+ * **Waveform:** Visual representation of the generated audio.
427
+ * **Play/Pause Button:** Controls audio playback.
428
+ * **Download WAV Button:** Downloads the generated audio as a `.wav` file.
429
+ * **Info:** Displays the voice mode used, generation time, and audio duration.
430
+
431
+ #### 7.1.7 Theme Toggle
432
+
433
+ * Located in the top-right navigation bar.
434
+ * Click the Sun/Moon icon to switch between Light and Dark themes. Your preference is saved in your browser's `localStorage`.
435
+
436
+ ### 7.2 API Endpoints
437
+
438
+ Access the interactive API documentation via the `/docs` path (e.g., `http://localhost:8003/docs`).
439
+
440
+ #### 7.2.1 POST `/v1/audio/speech` (OpenAI Compatible)
441
+
442
+ * **Purpose:** Provides an endpoint compatible with the basic OpenAI TTS API for easier integration with existing tools.
443
+ * **Request Body:** (`application/json`) - Uses the `OpenAITTSRequest` model.
444
+ | Field | Type | Required | Description | Example |
445
+ | :---------------- | :----------------------- | :------- | :---------------------------------------------------------------------------------------------------------------------------------------- | :-------------------------- |
446
+ | `model` | string | No | Ignored by this server (always uses Dia). Included for compatibility. Defaults to `dia-1.6b`. | `"dia-1.6b"` |
447
+ | `input` | string | Yes | The text to synthesize. Use `[S1]`/`[S2]` tags for dialogue. For cloning, prepend reference transcript. | `"Hello [S1] world."` |
448
+ | `voice` | string | No | Maps to Dia modes. Use `"S1"`, `"S2"`, `"dialogue"`, or the filename of a reference audio (e.g., `"my_ref.wav"`) for cloning. Defaults to `S1`. | `"dialogue"` or `"ref.mp3"` |
449
+ | `response_format` | `"opus"` \| `"wav"` | No | Desired audio output format. Defaults to `opus`. | `"wav"` |
450
+ | `speed` | float | No | Playback speed factor (0.5-2.0). Applied *after* generation. Defaults to `1.0`. | `0.9` |
451
+ * **Response:**
452
+ * **Success (200 OK):** `StreamingResponse` containing the binary audio data (`audio/opus` or `audio/wav`).
453
+ * **Error:** Standard FastAPI JSON error response (e.g., 400, 404, 500).
454
+
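+ A minimal Python client sketch using the `requests` package (already listed in `requirements.txt`); the host, port, and output filename below are assumptions:
+
+ ```python
+ # Sketch: calling the OpenAI-compatible endpoint and saving the audio.
+ # Assumes the server is running locally on the default port 8003.
+ import requests
+
+ payload = {
+     "input": "[S1] Hello there! [S2] Hi! (laughs)",
+     "voice": "dialogue",        # or "S1", "S2", or a reference filename for cloning
+     "response_format": "wav",
+     "speed": 1.0,
+ }
+ resp = requests.post("http://localhost:8003/v1/audio/speech", json=payload)
+ resp.raise_for_status()
+ with open("speech.wav", "wb") as f:
+     f.write(resp.content)  # raw audio bytes (audio/wav in this case)
+ ```
+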
455
+ #### 7.2.2 POST `/tts` (Custom Parameters)
456
+
457
+ * **Purpose:** Allows generation using all specific Dia generation parameters.
458
+ * **Request Body:** (`application/json`) - Uses the `CustomTTSRequest` model.
459
+ | Field | Type | Required | Description | Default |
460
+ | :------------------------- | :------------------------------------- | :------- | :---------------------------------------------------------------------------------------------------------------------------------------- | :---------- |
461
+ | `text` | string | Yes | The text to synthesize. Use `[S1]`/`[S2]` tags. Prepend transcript for cloning. | |
462
+ | `voice_mode` | `"dialogue"` \| `"clone"` | No | Generation mode. Note: `single_s1`/`single_s2` are handled via `dialogue` mode with appropriate tags in the text. | `dialogue` |
463
+ | `clone_reference_filename` | string \| null | No | Filename of reference audio in `REFERENCE_AUDIO_PATH`. **Required if `voice_mode` is `clone`**. | `null` |
464
+ | `output_format` | `"opus"` \| `"wav"` | No | Desired audio output format. | `opus` |
465
+ | `max_tokens` | integer \| null | No | Maximum audio tokens to generate. `null` uses the model's default. | `null` |
466
+ | `cfg_scale` | float | No | Classifier-Free Guidance scale. | `3.0` |
467
+ | `temperature` | float | No | Sampling temperature. | `1.3` |
468
+ | `top_p` | float | No | Nucleus sampling probability. | `0.95` |
469
+ | `speed_factor` | float | No | Playback speed factor (0.5-2.0). Applied *after* generation. | `0.90` |
470
+ | `cfg_filter_top_k` | integer | No | Top-K value for CFG filtering. | `35` |
471
+ * **Response:**
472
+ * **Success (200 OK):** `StreamingResponse` containing the binary audio data (`audio/opus` or `audio/wav`).
473
+ * **Error:** Standard FastAPI JSON error response (e.g., 400, 404, 500).
474
+
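+ A corresponding sketch for the custom endpoint, exercising the Dia-specific parameters (values mirror the documented defaults; the reference filename is hypothetical):
+
+ ```python
+ # Sketch: calling /tts with explicit generation parameters (clone mode).
+ import requests
+
+ payload = {
+     "text": "[S1] Reference transcript. [S1] New text in the cloned voice.",
+     "voice_mode": "clone",
+     "clone_reference_filename": "my_ref.wav",  # must exist in REFERENCE_AUDIO_PATH
+     "output_format": "wav",
+     "cfg_scale": 3.0,
+     "temperature": 1.3,
+     "top_p": 0.95,
+     "speed_factor": 0.9,
+     "cfg_filter_top_k": 35,
+ }
+ resp = requests.post("http://localhost:8003/tts", json=payload)
+ resp.raise_for_status()
+ with open("cloned.wav", "wb") as f:
+     f.write(resp.content)
+ ```
+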
475
+ #### 7.2.3 Configuration & Helper Endpoints
476
+
477
+ * **GET `/get_config`:** Returns the current server configuration as JSON.
478
+ * **POST `/save_config`:** Saves server configuration settings provided in the JSON request body to the `.env` file. Requires server restart.
479
+ * **POST `/save_generation_defaults`:** Saves default generation parameters provided in the JSON request body to the `.env` file. Affects UI defaults on next load.
480
+ * **POST `/restart_server`:** Attempts to trigger a server restart (reliability depends on execution environment).
481
+ * **POST `/upload_reference`:** Uploads one or more audio files (`.wav`, `.mp3`) as `multipart/form-data` to the reference audio directory. Returns JSON with status and updated file list.
482
+ * **GET `/health`:** Basic health check endpoint. Returns `{"status": "healthy", "model_loaded": true/false}`.
483
+
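+ For example, uploading a reference clip and then checking server health could look like this (a sketch; the file path and the multipart field name `files` are assumptions):
+
+ ```python
+ # Sketch: upload a reference audio file, then poll the health endpoint.
+ import requests
+
+ base = "http://localhost:8003"
+
+ with open("my_voice_sample.wav", "rb") as f:
+     # /upload_reference expects multipart/form-data per the list above;
+     # the form field name "files" is an assumption.
+     r = requests.post(
+         f"{base}/upload_reference",
+         files={"files": ("my_voice_sample.wav", f, "audio/wav")},
+     )
+ r.raise_for_status()
+ print(r.json())  # status plus the refreshed reference file list
+
+ print(requests.get(f"{base}/health").json())  # {"status": "healthy", "model_loaded": ...}
+ ```
+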
484
+ ---
485
+
486
+ ## 8. Troubleshooting
487
+
488
+ * **Error: `CUDA available: False` or Slow Performance:**
489
+ * Verify NVIDIA drivers are installed correctly (`nvidia-smi` command).
490
+ * Ensure you installed the correct PyTorch version with CUDA support matching your driver (See [Section 4.4](#44-nvidia-driver-and-cuda-setup-required-for-gpu-acceleration)). Reinstall PyTorch using the command from the official website if unsure.
491
+ * Check if another process is using all GPU VRAM.
492
+ * **Error: `ImportError: No module named 'dac'` (or `safetensors`, `yaml`, etc.):**
493
+ * Make sure your virtual environment is activated.
494
+ * Run `pip install -r requirements.txt` again to install missing dependencies.
495
+ * Specifically for `dac`, ensure you installed `descript-audio-codec` and not a different package named `dac`. Run `pip uninstall dac -y && pip install descript-audio-codec`.
496
+ * **Error: `libsndfile library not found` (or similar `soundfile` error, mainly on Linux):**
497
+ * Install the system library: `sudo apt update && sudo apt install libsndfile1` (Debian/Ubuntu) or the equivalent for your distribution.
498
+ * **Error: Model Download Fails (e.g., `HTTPError`, `ConnectionError`):**
499
+ * Check your internet connection.
500
+ * Verify the `DIA_MODEL_REPO_ID`, `DIA_MODEL_CONFIG_FILENAME`, and `DIA_MODEL_WEIGHTS_FILENAME` in your `.env` file (or defaults in `config.py`) are correct and accessible on Hugging Face Hub.
501
+ * Check Hugging Face Hub status if multiple downloads fail.
502
+ * Ensure the cache directory (`DIA_MODEL_CACHE_PATH`) is writable.
503
+ * **Error: `RuntimeError: Failed to load DAC model...`:**
504
+ * This usually indicates an issue with the `descript-audio-codec` installation or version incompatibility. Ensure it's installed correctly (see `ImportError` above).
505
+ * Check logs for specific `AttributeError` messages (like missing `utils` or `download`) which might indicate version mismatches between the Dia code's expectation and the installed library. The current code expects `dac.utils.download()`.
506
+ * **Error: `FileNotFoundError` during generation (Reference Audio):**
507
+ * Ensure the filename selected/provided for voice cloning exists in the configured `REFERENCE_AUDIO_PATH`.
508
+ * Check that the path in `config.py` or `.env` is correct and the server has permission to read from it.
509
+ * **Error: Cannot Save Output/Reference Files (`PermissionError`, etc.):**
510
+ * Ensure the directories specified by `OUTPUT_PATH` and `REFERENCE_AUDIO_PATH` exist and the server process has write permissions to them.
511
+ * **Web UI Issues (Buttons don't work, styles missing):**
512
+ * Clear your browser cache.
513
+ * Check the browser's developer console (usually F12) for JavaScript errors.
514
+ * Ensure `ui/script.js` and `ui/style.css` are being loaded correctly (check network tab in developer tools).
515
+ * **Generation Cancel Button Doesn't Stop Process:**
516
+ * This is expected ("Fake Cancel"). The button currently only prevents the UI from processing the result when it eventually arrives. True cancellation is complex and not implemented. Clicking "Generate" again *will* cancel the *previous UI request's result processing* before starting the new one.
517
+
518
+ ---
519
+
520
+ ## 9. Project Architecture
521
+
522
+ * **`server.py`:** The main entry point using FastAPI. Defines API routes, serves the Web UI using Jinja2, handles requests, and orchestrates calls to the engine.
523
+ * **`engine.py`:** Responsible for loading the Dia model (including downloading files via `huggingface_hub`), managing the model instance, preparing inputs for the model's `generate` method based on user requests (handling voice modes), and calling the model's generation function. Also handles post-processing like speed adjustment.
524
+ * **`config.py`:** Manages all configuration settings using default values and overrides from a `.env` file. Provides getter functions for easy access to settings.
525
+ * **`dia/` package:** Contains the core implementation of the Dia model itself.
526
+ * `model.py`: Defines the `Dia` class, which wraps the underlying PyTorch model (`DiaModel`). It handles loading weights (`.pth` or `.safetensors`), loading the required DAC model, preparing inputs specifically for the `DiaModel` forward pass (including CFG logic), and running the autoregressive generation loop.
527
+ * `config.py` (within `dia/`): Defines Pydantic models representing the *structure* and hyperparameters of the Dia model architecture (encoder, decoder, data parameters). This is loaded from the `config.json` file associated with the model weights.
528
+ * `layers.py`: Contains custom PyTorch `nn.Module` implementations used within the `DiaModel` (e.g., Attention blocks, MLP blocks, RoPE).
529
+ * `audio.py`: Includes helper functions for audio processing specific to the model's tokenization and delay patterns (e.g., `audio_to_codebook`, `codebook_to_audio`, `apply_audio_delay`).
530
+ * **`ui/` directory:** Contains all files related to the Web UI.
531
+ * `index.html`: The main Jinja2 template.
532
+ * `script.js`: Frontend JavaScript for interactivity, API calls, theme switching, etc.
533
+ * `presets.yaml`: Definitions for the UI preset examples.
534
+ * **`utils.py`:** General utility functions, such as audio encoding (`encode_audio`) and saving (`save_audio_to_file`) using the `soundfile` library.
535
+ * **Dependencies:** Relies heavily on `FastAPI`, `Uvicorn`, `PyTorch`, `torchaudio`, `huggingface_hub`, `safetensors`, `descript-audio-codec`, `soundfile`, `PyYAML`, `python-dotenv`, `pydantic`, and `Jinja2`.
536
+
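+ The pieces above can also be driven without the HTTP layer. A minimal programmatic sketch based on the `engine.py` functions shown later in this document (`load_model`, `generate_speech`, `EXPECTED_SAMPLE_RATE`):
+
+ ```python
+ # Sketch: using the engine directly, bypassing FastAPI.
+ import soundfile as sf  # already a project dependency
+ from engine import load_model, generate_speech, EXPECTED_SAMPLE_RATE
+
+ if load_model():  # downloads/loads the Dia model per the .env configuration
+     result = generate_speech(
+         "[S1] A quick smoke test. [S2] Sounds good to me.",
+         voice_mode="dialogue",
+         speed_factor=0.94,
+     )
+     if result is not None:
+         audio, sample_rate = result  # sample_rate == EXPECTED_SAMPLE_RATE (44100)
+         sf.write("smoke_test.wav", audio, sample_rate)
+ ```
+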
537
+ ---
538
+
539
+ ## 10. License and Disclaimer
540
+
541
+ * **License:** This project is licensed under the MIT License.
542
+ * **Disclaimer:** This project offers a high-fidelity speech generation model intended solely for research and educational use. The following uses are **strictly forbidden**:
543
+ * **Identity Misuse**: Do not produce audio resembling real individuals without permission.
544
+ * **Deceptive Content**: Do not use this model to generate misleading content (e.g., fake news).
545
+ * **Illegal or Malicious Use**: Do not use this model for activities that are illegal or intended to cause harm.
546
+
547
+ By using this model, you agree to uphold relevant legal standards and ethical responsibilities. The creators **are not responsible** for any misuse and firmly oppose any unethical usage of this technology.
548
+
549
+ ---
download_model.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # download_model.py
2
+ # Utility script to download the Dia model and dependencies without starting the server.
3
+
4
+ import logging
5
+ import os
6
+ import engine # Import the engine module to trigger its loading logic
7
+
8
+ # Configure basic logging for the script
9
+ logging.basicConfig(
10
+ level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
11
+ )
12
+ logger = logging.getLogger("ModelDownloader")
13
+
14
+ if __name__ == "__main__":
15
+ logger.info("--- Starting Dia Model Download ---")
16
+
17
+ # Ensure cache directory exists (redundant if engine.load_model does it, but safe)
18
+ try:
19
+ from config import get_model_cache_path
20
+
21
+ cache_path = get_model_cache_path()
22
+ os.makedirs(cache_path, exist_ok=True)
23
+ logger.info(
24
+ f"Ensured model cache directory exists: {os.path.abspath(cache_path)}"
25
+ )
26
+ except Exception as e:
27
+ logger.warning(f"Could not ensure cache directory exists: {e}")
28
+
29
+ # Trigger the model loading function from the engine
30
+ logger.info("Calling engine.load_model() to initiate download if necessary...")
31
+ success = engine.load_model()
32
+
33
+ if success:
34
+ logger.info("--- Model download/load process completed successfully ---")
35
+ else:
36
+ logger.error(
37
+ "--- Model download/load process failed. Check logs for details. ---"
38
+ )
39
+ exit(1) # Exit with error code
40
+
41
+ logger.info("You can now start the server using 'python server.py'")
engine.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # engine.py
2
+ # Core Dia TTS model loading and generation logic
3
+
4
+ import logging
5
+ import time
6
+ import os
7
+ import torch
8
+ import numpy as np
9
+ from typing import Optional, Tuple
10
+ from huggingface_hub import hf_hub_download # Import downloader
11
+
12
+ # Import Dia model class and config
13
+ try:
14
+ from dia.model import Dia
15
+ from dia.config import DiaConfig
16
+ except ImportError as e:
17
+ # Log critical error if core components are missing
18
+ logging.critical(
19
+ f"Failed to import Dia model components: {e}. Ensure the 'dia' package exists and is importable.",
20
+ exc_info=True,
21
+ )
22
+
23
+ # Define dummy classes/functions to prevent server crash on import,
24
+ # but generation will fail later if these are used.
25
+ class Dia:
26
+ @staticmethod
27
+ def load_model_from_files(*args, **kwargs):
28
+ raise RuntimeError("Dia model package not available or failed to import.")
29
+
30
+ def generate(*args, **kwargs):
31
+ raise RuntimeError("Dia model package not available or failed to import.")
32
+
33
+ class DiaConfig:
34
+ pass
35
+
36
+
37
+ # Import configuration getters from our project's config.py
38
+ from config import (
39
+ get_model_repo_id,
40
+ get_model_cache_path,
41
+ get_reference_audio_path,
42
+ get_model_config_filename,
43
+ get_model_weights_filename,
44
+ )
45
+
46
+ logger = logging.getLogger(__name__) # Use standard logger name
47
+
48
+ # --- Global Variables ---
49
+ dia_model: Optional[Dia] = None
50
+ # model_config is now loaded within Dia.load_model_from_files, maybe remove global?
51
+ # Let's keep it for now if needed elsewhere, but populate it after loading.
52
+ model_config_instance: Optional[DiaConfig] = None
53
+ model_device: Optional[torch.device] = None
54
+ MODEL_LOADED = False
55
+ EXPECTED_SAMPLE_RATE = 44100 # Dia model and DAC typically operate at 44.1kHz
56
+
57
+ # --- Model Loading ---
58
+
59
+
60
+ def get_device() -> torch.device:
61
+ """Determines the optimal torch device (CUDA > MPS > CPU)."""
62
+ if torch.cuda.is_available():
63
+ logger.info("CUDA is available, using GPU.")
64
+ return torch.device("cuda")
65
+ # Add MPS check for Apple Silicon GPUs
66
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
67
+ # Basic check is usually sufficient
68
+ logger.info("MPS is available, using Apple Silicon GPU.")
69
+ return torch.device("mps")
70
+ else:
71
+ logger.info("CUDA and MPS not available, using CPU.")
72
+ return torch.device("cpu")
73
+
74
+
75
+ def load_model():
76
+ """
77
+ Loads the Dia TTS model and associated DAC model.
78
+ Downloads model files based on configuration if they don't exist locally.
79
+ Handles both .pth and .safetensors formats.
80
+ """
81
+ global dia_model, model_config_instance, model_device, MODEL_LOADED
82
+
83
+ if MODEL_LOADED:
84
+ logger.info("Dia model already loaded.")
85
+ return True
86
+
87
+ # Get configuration values
88
+ repo_id = get_model_repo_id()
89
+ config_filename = get_model_config_filename()
90
+ weights_filename = get_model_weights_filename()
91
+ cache_path = get_model_cache_path() # Already absolute path
92
+ model_device = get_device()
93
+
94
+ logger.info(f"Attempting to load Dia model:")
95
+ logger.info(f" Repo ID: {repo_id}")
96
+ logger.info(f" Config File: {config_filename}")
97
+ logger.info(f" Weights File: {weights_filename}")
98
+ logger.info(f" Cache Directory: {cache_path}")
99
+ logger.info(f" Target Device: {model_device}")
100
+
101
+ # Ensure cache directory exists
102
+ try:
103
+ os.makedirs(cache_path, exist_ok=True)
104
+ except OSError as e:
105
+ logger.error(
106
+ f"Failed to create cache directory '{cache_path}': {e}", exc_info=True
107
+ )
108
+ # Depending on severity, might want to return False here
109
+ # return False
110
+ pass # Continue and let hf_hub_download handle potential issues
111
+
112
+ try:
113
+ start_time = time.time()
114
+
115
+ # --- Download Model Files ---
116
+ logger.info(
117
+ f"Downloading/finding configuration file '{config_filename}' from repo '{repo_id}'..."
118
+ )
119
+ local_config_path = hf_hub_download(
120
+ repo_id=repo_id,
121
+ filename=config_filename,
122
+ cache_dir=cache_path,
123
+ # force_download=False, # Default: only download if missing or outdated
124
+ # resume_download=True, # Default: resume interrupted downloads
125
+ )
126
+ logger.info(f"Configuration file path: {local_config_path}")
127
+
128
+ logger.info(
129
+ f"Downloading/finding weights file '{weights_filename}' from repo '{repo_id}'..."
130
+ )
131
+ local_weights_path = hf_hub_download(
132
+ repo_id=repo_id,
133
+ filename=weights_filename,
134
+ cache_dir=cache_path,
135
+ )
136
+ logger.info(f"Weights file path: {local_weights_path}")
137
+
138
+ # --- Load Model using the class method ---
139
+ # The Dia class method now handles config loading, instantiation, weight loading, etc.
140
+ dia_model = Dia.load_model_from_files(
141
+ config_path=local_config_path,
142
+ weights_path=local_weights_path,
143
+ device=model_device,
144
+ )
145
+
146
+ # Store the config instance if needed globally (optional)
147
+ model_config_instance = dia_model.config
148
+
149
+ end_time = time.time()
150
+ logger.info(
151
+ f"Dia model loaded successfully in {end_time - start_time:.2f} seconds."
152
+ )
153
+ MODEL_LOADED = True
154
+ return True
155
+
156
+ except FileNotFoundError as e:
157
+ logger.error(
158
+ f"Model loading failed: Required file not found. {e}", exc_info=True
159
+ )
160
+ MODEL_LOADED = False
161
+ return False
162
+ except ImportError:
163
+ # This catches if the 'dia' package itself is missing
164
+ logger.critical(
165
+ "Failed to load model: Dia package or its core dependencies not found.",
166
+ exc_info=True,
167
+ )
168
+ MODEL_LOADED = False
169
+ return False
170
+ except Exception as e:
171
+ # Catch other potential errors during download or loading
172
+ logger.error(
173
+ f"Error loading Dia model from repo '{repo_id}': {e}", exc_info=True
174
+ )
175
+ dia_model = None
176
+ model_config_instance = None
177
+ MODEL_LOADED = False
178
+ return False
179
+
180
+
181
+ # --- Speech Generation ---
182
+
183
+
184
+ def generate_speech(
185
+ text: str,
186
+ voice_mode: str = "single_s1",
187
+ clone_reference_filename: Optional[str] = None,
188
+ max_tokens: Optional[int] = None,
189
+ cfg_scale: float = 3.0,
190
+ temperature: float = 1.3,
191
+ top_p: float = 0.95,
192
+ speed_factor: float = 0.94, # Keep speed factor separate from model generation params
193
+ cfg_filter_top_k: int = 35,
194
+ ) -> Optional[Tuple[np.ndarray, int]]:
195
+ """
196
+ Generates speech using the loaded Dia model, handling voice modes and speed adjustment.
197
+
198
+ Args:
199
+ text: Text to synthesize.
200
+ voice_mode: 'dialogue', 'single_s1', 'single_s2', 'clone'.
201
+ clone_reference_filename: Filename for voice cloning (if mode is 'clone'). Located in reference audio path.
202
+ max_tokens: Max generation tokens for the model's generate method.
203
+ cfg_scale: CFG scale for the model's generate method.
204
+ temperature: Sampling temperature for the model's generate method.
205
+ top_p: Nucleus sampling p for the model's generate method.
206
+ speed_factor: Factor to adjust the playback speed *after* generation (e.g., 0.9 = slower, 1.1 = faster).
207
+ cfg_filter_top_k: CFG filter top K for the model's generate method.
208
+
209
+ Returns:
210
+ Tuple of (numpy_audio_array, sample_rate), or None on failure.
211
+ """
212
+ if not MODEL_LOADED or dia_model is None:
213
+ logger.error("Dia model is not loaded. Cannot generate speech.")
214
+ return None
215
+
216
+ logger.info(f"Generating speech with mode: {voice_mode}")
217
+ logger.debug(f"Input text (start): '{text[:100]}...'")
218
+ # Log model generation parameters
219
+ logger.debug(
220
+ f"Model Params: max_tokens={max_tokens}, cfg={cfg_scale}, temp={temperature}, top_p={top_p}, top_k={cfg_filter_top_k}"
221
+ )
222
+ # Log post-processing parameters
223
+ logger.debug(f"Post-processing Params: speed_factor={speed_factor}")
224
+
225
+ audio_prompt_path = None
226
+ processed_text = text # Start with original text
227
+
228
+ # --- Handle Voice Mode ---
229
+ if voice_mode == "clone":
230
+ if not clone_reference_filename:
231
+ logger.error("Clone mode selected but no reference filename provided.")
232
+ return None
233
+ ref_base_path = get_reference_audio_path() # Gets absolute path
234
+ potential_path = os.path.join(ref_base_path, clone_reference_filename)
235
+ if os.path.isfile(potential_path):
236
+ audio_prompt_path = potential_path
237
+ logger.info(f"Using audio prompt for cloning: {audio_prompt_path}")
238
+ # Dia requires the transcript of the clone audio to be prepended to the target text.
239
+ # The UI/API caller is responsible for constructing this combined text.
240
+ logger.warning(
241
+ "Clone mode active. Ensure the 'text' input includes the transcript of the reference audio for best results (e.g., '[S1] Reference transcript. [S1] Target text...')."
242
+ )
243
+ processed_text = text # Use the combined text provided by the caller
244
+ else:
245
+ logger.error(f"Reference audio file not found: {potential_path}")
246
+ return None # Fail generation if reference file is missing
247
+ elif voice_mode == "dialogue":
248
+ # Assume text already contains [S1]/[S2] tags as required by the model
249
+ logger.info("Using dialogue mode. Expecting [S1]/[S2] tags in input text.")
250
+ if "[S1]" not in text and "[S2]" not in text:
251
+ logger.warning(
252
+ "Dialogue mode selected, but no [S1] or [S2] tags found in the input text."
253
+ )
254
+ processed_text = text # Pass directly
255
+ elif voice_mode == "single_s1":
256
+ logger.info("Using single voice mode (S1).")
257
+ # Check if text *already* contains tags, warn if so, as it might confuse the model
258
+ if "[S1]" in text or "[S2]" in text:
259
+ logger.warning(
260
+ "Input text contains dialogue tags ([S1]/[S2]), but 'single_s1' mode was selected. Model behavior might be unexpected."
261
+ )
262
+ # Dia likely expects tags even for single speaker. Prepending [S1] might be safer.
263
+ # Let's assume for now the model handles untagged text as S1, but this could be adjusted.
264
+ # Consider: processed_text = f"[S1] {text}" # Option to enforce S1 tag
265
+ processed_text = text # Pass directly for now
266
+ elif voice_mode == "single_s2":
267
+ logger.info("Using single voice mode (S2).")
268
+ if "[S1]" in text or "[S2]" in text:
269
+ logger.warning(
270
+ "Input text contains dialogue tags ([S1]/[S2]), but 'single_s2' mode was selected."
271
+ )
272
+ # Similar to S1, how to signal S2? Prepending [S2] seems logical if needed.
273
+ # Consider: processed_text = f"[S2] {text}" # Option to enforce S2 tag
274
+ processed_text = text # Pass directly for now
275
+ else:
276
+ logger.error(
277
+ f"Unsupported voice_mode: {voice_mode}. Defaulting to 'single_s1'."
278
+ )
279
+ processed_text = text # Fallback
280
+
281
+ # --- Call Dia Generate ---
282
+ try:
283
+ start_time = time.time()
284
+ logger.info("Calling Dia model generate method...")
285
+
286
+ # Call the model's generate method with appropriate parameters
287
+ generated_audio_np = dia_model.generate(
288
+ text=processed_text,
289
+ audio_prompt_path=audio_prompt_path,
290
+ max_tokens=max_tokens, # Pass None if not specified, Dia uses its default
291
+ cfg_scale=cfg_scale,
292
+ temperature=temperature,
293
+ top_p=top_p,
294
+ use_cfg_filter=True, # Default from Dia's app.py, seems reasonable
295
+ cfg_filter_top_k=cfg_filter_top_k,
296
+ use_torch_compile=False, # Keep False for stability unless specifically tested/enabled
297
+ )
298
+ gen_end_time = time.time()
299
+ logger.info(
300
+ f"Dia model generation finished in {gen_end_time - start_time:.2f} seconds."
301
+ )
302
+
303
+ if generated_audio_np is None or generated_audio_np.size == 0:
304
+ logger.warning("Dia model returned None or empty audio array.")
305
+ return None
306
+
307
+ # --- Apply Speed Factor (Post-processing) ---
308
+ # This mimics the logic in Dia's original app.py
309
+ if speed_factor != 1.0:
310
+ logger.info(f"Applying speed factor: {speed_factor}")
311
+ original_len = len(generated_audio_np)
312
+ # Ensure speed_factor is within a reasonable range to avoid extreme distortion
313
+ # Adjust range based on observed quality (e.g., 0.5 to 2.0)
314
+ speed_factor = max(0.5, min(speed_factor, 2.0))
315
+ target_len = int(original_len / speed_factor)
316
+
317
+ if target_len > 0 and target_len != original_len:
318
+ logger.debug(
319
+ f"Resampling audio from {original_len} to {target_len} samples."
320
+ )
321
+ # Create time axes for original and resampled audio
322
+ x_original = np.linspace(0, original_len - 1, original_len)
323
+ x_resampled = np.linspace(0, original_len - 1, target_len)
324
+ # Interpolate using numpy
325
+ resampled_audio_np = np.interp(
326
+ x_resampled, x_original, generated_audio_np
327
+ )
328
+ final_audio_np = resampled_audio_np.astype(np.float32) # Ensure float32
329
+ logger.info(f"Audio resampled for {speed_factor:.2f}x speed.")
330
+ else:
331
+ logger.warning(
332
+ f"Skipping speed adjustment (factor: {speed_factor:.2f}). Target length invalid ({target_len}) or no change needed."
333
+ )
334
+ final_audio_np = generated_audio_np # Use original audio
335
+ else:
336
+ logger.info("Speed factor is 1.0, no speed adjustment needed.")
337
+ final_audio_np = generated_audio_np # No speed change needed
338
+
339
+ # Ensure output is float32 (DAC output should be, but good practice)
340
+ if final_audio_np.dtype != np.float32:
341
+ logger.warning(
342
+ f"Generated audio was not float32 ({final_audio_np.dtype}), converting."
343
+ )
344
+ final_audio_np = final_audio_np.astype(np.float32)
345
+
346
+ logger.info(
347
+ f"Final audio ready. Shape: {final_audio_np.shape}, dtype: {final_audio_np.dtype}"
348
+ )
349
+ # Return the processed audio and the expected sample rate
350
+ return final_audio_np, EXPECTED_SAMPLE_RATE
351
+
352
+ except Exception as e:
353
+ logger.error(
354
+ f"Error during Dia generation or post-processing: {e}", exc_info=True
355
+ )
356
+ return None # Return None on failure
models.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # models.py
2
+ # Pydantic models for API requests and potentially responses
3
+
4
+ from pydantic import BaseModel, Field
5
+ from typing import Optional, Literal
6
+
7
+ # --- Request Models ---
8
+
9
+
10
+ class OpenAITTSRequest(BaseModel):
11
+ """Request model compatible with the OpenAI TTS API."""
12
+
13
+ model: str = Field(
14
+ default="dia-1.6b",
15
+ description="Model identifier (ignored by this server, always uses Dia). Included for compatibility.",
16
+ )
17
+ input: str = Field(..., description="The text to synthesize.")
18
+ voice: str = Field(
19
+ default="S1",
20
+ description="Voice mode or reference audio filename. Examples: 'S1', 'S2', 'dialogue', 'my_reference.wav'.",
21
+ )
22
+ response_format: Literal["opus", "wav"] = Field(
23
+ default="opus", description="The desired audio output format."
24
+ )
25
+ speed: float = Field(
26
+ default=1.0,
27
+ ge=0.8,
28
+ le=1.2, # Dia speed factor range seems narrower
29
+ description="Adjusts the speed of the generated audio (0.8 to 1.2).",
30
+ )
31
+
32
+
33
+ class CustomTTSRequest(BaseModel):
34
+ """Request model for the custom /tts endpoint."""
35
+
36
+ text: str = Field(
37
+ ...,
38
+ description="The text to synthesize. For 'dialogue' mode, include [S1]/[S2] tags.",
39
+ )
40
+ voice_mode: Literal["dialogue", "single_s1", "single_s2", "clone"] = Field(
41
+ default="single_s1", description="Specifies the generation mode."
42
+ )
43
+ clone_reference_filename: Optional[str] = Field(
44
+ default=None,
45
+ description="Filename of the reference audio within the configured reference path (required if voice_mode is 'clone').",
46
+ )
47
+ output_format: Literal["opus", "wav"] = Field(
48
+ default="opus", description="The desired audio output format."
49
+ )
50
+ # Dia-specific generation parameters
51
+ max_tokens: Optional[int] = Field(
52
+ default=None,
53
+ gt=0,
54
+ description="Maximum number of audio tokens to generate (defaults to model's internal config value).",
55
+ )
56
+ cfg_scale: float = Field(
57
+ default=3.0,
58
+ ge=1.0,
59
+ le=5.0,
60
+ description="Classifier-Free Guidance scale (1.0-5.0).",
61
+ )
62
+ temperature: float = Field(
63
+ default=1.3, ge=1.0, le=1.5, description="Sampling temperature (1.0-1.5)."
64
+ )
65
+ top_p: float = Field(
66
+ default=0.95,
67
+ ge=0.8,
68
+ le=1.0,
69
+ description="Nucleus sampling probability (0.8-1.0).",
70
+ )
71
+ speed_factor: float = Field(
72
+ default=0.94,
73
+ ge=0.8,
74
+ le=1.0, # Dia's default range seems to be <= 1.0
75
+ description="Adjusts the speed of the generated audio (0.8 to 1.0).",
76
+ )
77
+ cfg_filter_top_k: int = Field(
78
+ default=35, ge=15, le=50, description="Top k filter for CFG guidance (15-50)."
79
+ )
80
+
81
+
82
+ # --- Response Models (Optional, can be simple dicts too) ---
83
+
84
+
85
+ class TTSResponse(BaseModel):
86
+ """Basic response model for successful generation (if returning JSON)."""
87
+
88
+ request_id: str
89
+ status: str = "completed"
90
+ generation_time_sec: float
91
+ output_url: Optional[str] = None # If saving file and returning URL
92
+
93
+
94
+ class ErrorResponse(BaseModel):
95
+ """Error response model."""
96
+
97
+ detail: str
requirements.txt ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # requirements.txt
2
+
3
+ # Core Web Framework
4
+ fastapi
5
+ uvicorn[standard]
6
+
7
+ # Machine Learning & Audio
8
+ torch
9
+ torchaudio
10
+ numpy
11
+ soundfile # Requires libsndfile system library (e.g., sudo apt-get install libsndfile1 on Debian/Ubuntu)
12
+ huggingface_hub
13
+ descript-audio-codec
14
+ safetensors
15
+
16
+ # Configuration & Utilities
17
+ pydantic
18
+ python-dotenv
19
+ Jinja2
20
+ python-multipart # For potential file uploads in UI
21
+ requests # For health checks or other potential uses
22
+ PyYAML # For parsing presets.yaml
server.py ADDED
@@ -0,0 +1,1061 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # server.py
2
+ # Main FastAPI server for Dia TTS
3
+
4
+ import sys
5
+ import logging
6
+ import time
7
+ import os
8
+ import io
9
+ import uuid
11
+ import shutil # For file copying
12
+ import yaml # For loading presets
13
+ from datetime import datetime
14
+ from contextlib import asynccontextmanager
15
+ from typing import Optional, Literal, List, Dict, Any
16
+ import webbrowser
17
+ import threading
19
+
20
+ from fastapi import (
21
+ FastAPI,
22
+ HTTPException,
23
+ Request,
24
+ Response,
25
+ Form,
26
+ UploadFile,
27
+ File,
28
+ BackgroundTasks,
29
+ )
30
+ from fastapi.responses import (
31
+ StreamingResponse,
32
+ JSONResponse,
33
+ HTMLResponse,
34
+ RedirectResponse,
35
+ )
36
+ from fastapi.staticfiles import StaticFiles
37
+ from fastapi.templating import Jinja2Templates
38
+ import uvicorn
39
+ import numpy as np
40
+
41
+ # Internal imports
42
+ from config import (
43
+ config_manager,
44
+ get_host,
45
+ get_port,
46
+ get_output_path,
47
+ get_reference_audio_path,
48
+ # register_config_routes is now defined locally
49
+ get_model_cache_path,
50
+ get_model_repo_id,
51
+ get_model_config_filename,
52
+ get_model_weights_filename,
53
+ # Generation default getters
54
+ get_gen_default_speed_factor,
55
+ get_gen_default_cfg_scale,
56
+ get_gen_default_temperature,
57
+ get_gen_default_top_p,
58
+ get_gen_default_cfg_filter_top_k,
59
+ DEFAULT_CONFIG,
60
+ )
61
+ from models import OpenAITTSRequest, CustomTTSRequest, ErrorResponse
62
+ import engine
63
+ from engine import (
64
+ load_model as load_dia_model,
65
+ generate_speech,
66
+ EXPECTED_SAMPLE_RATE,
67
+ )
68
+ from utils import encode_audio, save_audio_to_file, PerformanceMonitor
69
+
70
+ # Configure logging (Basic setup, can be enhanced)
71
+ logging.basicConfig(
72
+ level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
73
+ )
74
+ # Reduce verbosity of noisy libraries if needed
75
+ # logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
76
+ # logging.getLogger("watchfiles").setLevel(logging.WARNING)
77
+ logger = logging.getLogger(__name__) # Logger for this module
78
+
79
+ # --- Global Variables & Constants ---
80
+ PRESETS_FILE = "ui/presets.yaml"
81
+ loaded_presets: List[Dict[str, Any]] = [] # Cache presets in memory
82
+ startup_complete_event = threading.Event()
83
+
84
+ # --- Helper Functions ---
85
+
86
+
87
+ def load_presets():
88
+ """Loads presets from the YAML file."""
89
+ global loaded_presets
90
+ try:
91
+ if os.path.exists(PRESETS_FILE):
92
+ with open(PRESETS_FILE, "r", encoding="utf-8") as f:
93
+ loaded_presets = yaml.safe_load(f)
94
+ if not isinstance(loaded_presets, list):
95
+ logger.error(
96
+ f"Presets file '{PRESETS_FILE}' should contain a list, but found {type(loaded_presets)}. No presets loaded."
97
+ )
98
+ loaded_presets = []
99
+ else:
100
+ logger.info(
101
+ f"Successfully loaded {len(loaded_presets)} presets from {PRESETS_FILE}."
102
+ )
103
+ else:
104
+ logger.warning(
105
+ f"Presets file not found at '{PRESETS_FILE}'. No presets will be available."
106
+ )
107
+ loaded_presets = []
108
+ except yaml.YAMLError as e:
109
+ logger.error(
110
+ f"Error parsing presets YAML file '{PRESETS_FILE}': {e}", exc_info=True
111
+ )
112
+ loaded_presets = []
113
+ except Exception as e:
114
+ logger.error(f"Error loading presets file '{PRESETS_FILE}': {e}", exc_info=True)
115
+ loaded_presets = []
116
+
117
+
118
+ def get_valid_reference_files() -> list[str]:
+     """Gets a list of valid audio files (.wav, .mp3) from the reference directory."""
+     ref_path = get_reference_audio_path()
+     valid_files = []
+     allowed_extensions = (".wav", ".mp3")
+     try:
+         if os.path.isdir(ref_path):
+             for filename in os.listdir(ref_path):
+                 if filename.lower().endswith(allowed_extensions):
+                     # Optional: add a check for file size or basic validity if needed
+                     valid_files.append(filename)
+         else:
+             logger.warning(f"Reference audio directory not found: {ref_path}")
+     except Exception as e:
+         logger.error(
+             f"Error reading reference audio directory '{ref_path}': {e}", exc_info=True
+         )
+     return sorted(valid_files)
+
+
+ def sanitize_filename(filename: str) -> str:
+     """Removes potentially unsafe characters and path components from a filename."""
+     # Remove directory separators
+     filename = os.path.basename(filename)
+     # Keep only alphanumeric, underscore, hyphen, dot. Replace others with underscore.
+     safe_chars = set(
+         "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789._-"
+     )
+     sanitized = "".join(c if c in safe_chars else "_" for c in filename)
+     # Prevent names starting with a dot or consisting only of dots/spaces
+     if not sanitized or sanitized.lstrip("._ ") == "":
+         return f"uploaded_file_{uuid.uuid4().hex[:8]}"  # Generate a safe fallback name
+     # Limit length
+     max_len = 100
+     if len(sanitized) > max_len:
+         name, ext = os.path.splitext(sanitized)
+         sanitized = name[: max_len - len(ext)] + ext
+     return sanitized
+
+
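As a quick illustration of the helper above (an aside, not part of the module; the fallback suffix is a random hex string, so the last value varies per call):

    from server import sanitize_filename  # assumes the module is importable as 'server'

    print(sanitize_filename("../../etc/passwd"))   # 'passwd' (path components stripped)
    print(sanitize_filename("my voice (v2).wav"))  # 'my_voice__v2_.wav' (unsafe chars become '_')
    print(sanitize_filename("...."))               # 'uploaded_file_<8 hex chars>' fallback
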
158
+ # --- Application Lifespan (Startup/Shutdown) ---
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     """Application lifespan manager for startup/shutdown."""
+     model_loaded_successfully = False  # Flag to track success
+     try:
+         logger.info("Starting Dia TTS server initialization...")
+         # Ensure base directories exist
+         os.makedirs(get_output_path(), exist_ok=True)
+         os.makedirs(get_reference_audio_path(), exist_ok=True)
+         os.makedirs(get_model_cache_path(), exist_ok=True)
+         os.makedirs("ui", exist_ok=True)
+         os.makedirs("static", exist_ok=True)
+
+         # Load presets from the YAML file
+         load_presets()
+
+         # Load the main TTS model during startup
+         if not load_dia_model():
+             # Model loading failed
+             error_msg = (
+                 "CRITICAL: Failed to load Dia model on startup. Server cannot start."
+             )
+             logger.critical(error_msg)
+             # Option 1: Raise an exception to stop Uvicorn startup cleanly
+             raise RuntimeError(error_msg)
+             # Option 2: Force exit (less clean, might bypass some Uvicorn shutdown)
+             # sys.exit(1)
+         else:
+             logger.info("Dia model loaded successfully.")
+             model_loaded_successfully = True
+
+         # Create and start a delayed browser-opening thread.
+         # IMPORTANT: create this thread AFTER model loading completes.
+         host = get_host()
+         port = get_port()
+         browser_thread = threading.Thread(
+             target=lambda: _delayed_browser_open(host, port), daemon=True
+         )
+         browser_thread.start()
+
+         # --- Signal completion AFTER potentially long operations ---
+         logger.info("Application startup sequence finished. Signaling readiness.")
+         startup_complete_event.set()
+
+         yield  # Application runs here
+
+     except Exception as e:
+         # Catch the RuntimeError we raised or any other startup error
+         logger.error(f"Fatal error during application startup: {e}", exc_info=True)
+         # Do NOT set the event here if startup failed.
+         # Re-raise the exception (or exit) to ensure the server stops.
+         raise e  # Re-raising ensures Uvicorn knows startup failed
+         # Alternatively: sys.exit(1)
+     finally:
+         # Cleanup on shutdown
+         logger.info("Application shutdown initiated...")
+         # Add any specific cleanup needed
+         logger.info("Application shutdown complete.")
+
+
+ def _delayed_browser_open(host, port):
+     """Opens the browser after a short delay to ensure the server is ready."""
+     try:
+         # Small delay to ensure Uvicorn is fully ready
+         time.sleep(2)
+
+         display_host = "localhost" if host == "0.0.0.0" else host
+         browser_url = f"http://{display_host}:{port}/"
+
+         # Log to a file for debugging
+         with open("browser_thread_debug.log", "a") as f:
+             f.write(f"[{time.time()}] Opening browser at {browser_url}\n")
+
+         # Try to use the logger as well (it might work at this point)
+         try:
+             logger.info(f"Opening browser at {browser_url}")
+         except Exception:
+             pass
+
+         # Open the browser directly without health checks
+         webbrowser.open(browser_url)
+
+     except Exception as e:
+         with open("browser_thread_debug.log", "a") as f:
+             f.write(f"[{time.time()}] Browser open error: {str(e)}\n")
+
+
+ # --- FastAPI App Initialization ---
+ app = FastAPI(
+     title="Dia TTS Server",
+     description="Text-to-Speech server using the Dia model, providing API and Web UI.",
+     version="1.1.0",  # Incremented version
+     lifespan=lifespan,
+ )
+
+ # Ensure required directories exist before the static mounts below run at import
+ # time. Use the configured paths (not hard-coded names) so custom locations work.
+ for folder in [get_reference_audio_path(), get_model_cache_path(), get_output_path()]:
+     if not os.path.exists(folder):
+         os.makedirs(folder, exist_ok=True)
+         logger.info(f"Created directory: {folder}")
+
+ # --- Static Files and Templates ---
+ # Serve generated audio files from the configured output path
+ app.mount("/outputs", StaticFiles(directory=get_output_path()), name="outputs")
+ # Serve UI files (CSS, JS) from the 'ui' directory
+ app.mount("/ui", StaticFiles(directory="ui"), name="ui_static")
+ # Initialize Jinja2 templates to look in the 'ui' directory
+ templates = Jinja2Templates(directory="ui")
+
+
+ # --- Configuration Routes Definition ---
+ # Defined locally now instead of importing from config.py
+ def register_config_routes(app: FastAPI):
+     """Adds configuration management endpoints to the FastAPI app."""
+     logger.info(
+         "Registering configuration routes (/get_config, /save_config, /restart_server, /save_generation_defaults)."
+     )
+
+     @app.get(
+         "/get_config",
+         tags=["Configuration"],
+         summary="Get current server configuration",
+     )
+     async def get_current_config():
+         """Returns the current server configuration values (from .env or defaults)."""
+         logger.info("Request received for /get_config")
+         return JSONResponse(content=config_manager.get_all())
+
+     @app.post(
+         "/save_config", tags=["Configuration"], summary="Save server configuration"
+     )
+     async def save_new_config(request: Request):
+         """
+         Saves updated server configuration values (Host, Port, Model paths, etc.)
+         to the .env file. Requires a server restart to apply most changes.
+         """
+         logger.info("Request received for /save_config")
+         try:
+             new_config_data = await request.json()
+             if not isinstance(new_config_data, dict):
+                 raise ValueError("Request body must be a JSON object.")
+             logger.debug(f"Received server config data to save: {new_config_data}")
+
+             # Filter data to only include keys present in DEFAULT_CONFIG
+             filtered_data = {
+                 k: v for k, v in new_config_data.items() if k in DEFAULT_CONFIG
+             }
+             unknown_keys = set(new_config_data.keys()) - set(filtered_data.keys())
+             if unknown_keys:
+                 logger.warning(
+                     f"Ignoring unknown keys in save_config request: {unknown_keys}"
+                 )
+
+             config_manager.update(filtered_data)  # Update in memory first
+             if config_manager.save():  # Attempt to save to .env
+                 logger.info("Server configuration saved successfully to .env.")
+                 return JSONResponse(
+                     content={
+                         "message": "Server configuration saved. Restart server to apply changes."
+                     }
+                 )
+             else:
+                 logger.error("Failed to save server configuration to .env file.")
+                 raise HTTPException(
+                     status_code=500, detail="Failed to save configuration file."
+                 )
+         except ValueError as ve:
+             logger.error(f"Invalid data format for /save_config: {ve}")
+             raise HTTPException(
+                 status_code=400, detail=f"Invalid request data: {str(ve)}"
+             )
+         except Exception as e:
+             logger.error(f"Error processing /save_config request: {e}", exc_info=True)
+             raise HTTPException(
+                 status_code=500, detail=f"Internal server error during save: {str(e)}"
+             )
+
+     @app.post(
+         "/save_generation_defaults",
+         tags=["Configuration"],
+         summary="Save default generation parameters",
+     )
+     async def save_generation_defaults(request: Request):
+         """
+         Saves the provided generation parameters (speed, cfg, temp, etc.)
+         as the new defaults in the .env file. These are loaded by the UI on startup.
+         """
+         logger.info("Request received for /save_generation_defaults")
+         try:
+             gen_params = await request.json()
+             if not isinstance(gen_params, dict):
+                 raise ValueError("Request body must be a JSON object.")
+             logger.debug(f"Received generation defaults to save: {gen_params}")
+
+             # Map received keys (e.g., 'speed_factor') to .env keys (e.g., 'GEN_DEFAULT_SPEED_FACTOR')
+             defaults_to_save = {}
+             key_map = {
+                 "speed_factor": "GEN_DEFAULT_SPEED_FACTOR",
+                 "cfg_scale": "GEN_DEFAULT_CFG_SCALE",
+                 "temperature": "GEN_DEFAULT_TEMPERATURE",
+                 "top_p": "GEN_DEFAULT_TOP_P",
+                 "cfg_filter_top_k": "GEN_DEFAULT_CFG_FILTER_TOP_K",
+             }
+             valid_keys_found = False
+             for ui_key, env_key in key_map.items():
+                 if ui_key in gen_params:
+                     # Basic validation could be added here (e.g., check if float/int)
+                     defaults_to_save[env_key] = str(gen_params[ui_key])  # Ensure saving as string
+                     valid_keys_found = True
+                 else:
+                     logger.warning(
+                         f"Missing expected key '{ui_key}' in save_generation_defaults request."
+                     )
+
+             if not valid_keys_found:
+                 raise ValueError("No valid generation parameters found in the request.")
+
+             config_manager.update(defaults_to_save)  # Update in memory
+             if config_manager.save():  # Save all current config (including these) to .env
+                 logger.info("Generation defaults saved successfully to .env.")
+                 return JSONResponse(content={"message": "Generation defaults saved."})
+             else:
+                 logger.error("Failed to save generation defaults to .env file.")
+                 raise HTTPException(
+                     status_code=500, detail="Failed to save configuration file."
+                 )
+         except ValueError as ve:
+             logger.error(f"Invalid data format for /save_generation_defaults: {ve}")
+             raise HTTPException(
+                 status_code=400, detail=f"Invalid request data: {str(ve)}"
+             )
+         except Exception as e:
+             logger.error(
+                 f"Error processing /save_generation_defaults request: {e}",
+                 exc_info=True,
+             )
+             raise HTTPException(
+                 status_code=500, detail=f"Internal server error during save: {str(e)}"
+             )
+
406
+     @app.post(
+         "/restart_server",
+         tags=["Configuration"],
+         summary="Attempt to restart the server",
+     )
+     async def trigger_server_restart(background_tasks: BackgroundTasks):
+         """
+         Attempts to restart the server process.
+         NOTE: This is highly dependent on how the server is run (e.g., with uvicorn --reload,
+         or managed by systemd/supervisor). A simple exit might just stop the process.
+         This implementation attempts a clean exit, relying on the runner to restart it.
+         """
+         logger.warning("Received request to restart server via API.")
+
+         def _do_restart():
+             time.sleep(1)  # Short delay to allow the response to be sent
+             logger.warning("Attempting clean exit for restart...")
+             # Option 1: Clean exit (relies on Uvicorn reload or a process manager)
+             sys.exit(0)
+             # Option 2: Forceful re-execution (use with caution, might not work as expected)
+             # try:
+             #     logger.warning("Attempting os.execv for restart...")
+             #     os.execv(sys.executable, ['python'] + sys.argv)
+             # except Exception as exec_e:
+             #     logger.error(f"os.execv failed: {exec_e}. Server may not restart automatically.")
+             #     # Fall back to sys.exit if execv fails
+             #     sys.exit(1)
+
+         background_tasks.add_task(_do_restart)
+         return JSONResponse(
+             content={
+                 "message": "Restart signal sent. Server should restart shortly if run with auto-reload."
+             }
+         )
+
+
+ # --- Register Configuration Routes ---
+ register_config_routes(app)
+
+
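To illustrate the configuration routes just registered, a small client sketch (editorial, not part of this file; the host and port are the defaults from .env, and the values are simply the shipped generation defaults):

    import requests

    BASE = "http://localhost:8003"

    # Read the current configuration.
    print(requests.get(f"{BASE}/get_config", timeout=10).json())

    # Persist new UI generation defaults to .env.
    resp = requests.post(
        f"{BASE}/save_generation_defaults",
        json={
            "speed_factor": 0.9,
            "cfg_scale": 3.0,
            "temperature": 1.3,
            "top_p": 0.95,
            "cfg_filter_top_k": 35,
        },
        timeout=10,
    )
    resp.raise_for_status()
    print(resp.json()["message"])  # "Generation defaults saved."
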
446
+ # --- API Endpoints ---
+
+
+ @app.post(
+     "/v1/audio/speech",
+     response_class=StreamingResponse,
+     tags=["TTS Generation"],
+     summary="Generate speech (OpenAI compatible)",
+ )
+ async def openai_tts_endpoint(request: OpenAITTSRequest):
+     """
+     Generates speech audio from text, compatible with the OpenAI TTS API structure.
+     Maps the 'voice' parameter to Dia's voice modes ('S1', 'S2', 'dialogue', or a filename for cloning).
+     """
+     monitor = PerformanceMonitor()
+     monitor.record("Request received")
+     logger.info(
+         f"Received OpenAI request: voice='{request.voice}', speed={request.speed}, format='{request.response_format}'"
+     )
+     logger.debug(f"Input text (start): '{request.input[:100]}...'")
+
+     voice_mode = "single_s1"  # Default if mapping fails
+     clone_ref_file = None
+     ref_path = get_reference_audio_path()
+
+     # --- Map OpenAI 'voice' parameter to Dia's modes ---
+     voice_param = request.voice.strip()
+     if voice_param.lower() == "dialogue":
+         voice_mode = "dialogue"
+     elif voice_param.lower() == "s1":
+         voice_mode = "single_s1"
+     elif voice_param.lower() == "s2":
+         voice_mode = "single_s2"
+     # Check if it looks like a filename for cloning (allow .wav or .mp3)
+     elif voice_param.lower().endswith((".wav", ".mp3")):
+         potential_path = os.path.join(ref_path, voice_param)
+         # Check if the file actually exists in the reference directory
+         if os.path.isfile(potential_path):
+             voice_mode = "clone"
+             clone_ref_file = voice_param  # Use the provided filename
+             logger.info(
+                 f"OpenAI request mapped to clone mode with file: {clone_ref_file}"
+             )
+         else:
+             logger.warning(
+                 f"Reference file '{voice_param}' specified in OpenAI request not found in '{ref_path}'. Defaulting voice mode."
+             )
+             # Fall back to the default 'single_s1' if the file is not found
+     else:
+         logger.warning(
+             f"Unrecognized OpenAI voice parameter '{voice_param}'. Defaulting voice mode to 'single_s1'."
+         )
+         # Fallback for any other value
+
+     monitor.record("Parameters processed")
+
+     try:
+         # Call the core engine function using the mapped parameters
+         result = generate_speech(
+             text=request.input,
+             voice_mode=voice_mode,
+             clone_reference_filename=clone_ref_file,
+             speed_factor=request.speed,  # Pass speed factor for post-processing
+             # Use Dia's configured defaults for other generation params unless mapped
+             max_tokens=None,  # Let Dia use its default unless specified otherwise
+             cfg_scale=get_gen_default_cfg_scale(),  # Use saved defaults
+             temperature=get_gen_default_temperature(),
+             top_p=get_gen_default_top_p(),
+             cfg_filter_top_k=get_gen_default_cfg_filter_top_k(),
+         )
+         monitor.record("Generation complete")
+
+         if result is None:
+             logger.error("Speech generation failed (engine returned None).")
+             raise HTTPException(status_code=500, detail="Speech generation failed.")
+
+         audio_array, sample_rate = result
+
+         if sample_rate != EXPECTED_SAMPLE_RATE:
+             logger.warning(
+                 f"Engine returned sample rate {sample_rate}, but expected {EXPECTED_SAMPLE_RATE}. Encoding might assume {EXPECTED_SAMPLE_RATE}."
+             )
+             # Use EXPECTED_SAMPLE_RATE for encoding, as that is what the model is trained for
+             sample_rate = EXPECTED_SAMPLE_RATE
+
+         # Encode the audio in memory to the requested format
+         encoded_audio = encode_audio(audio_array, sample_rate, request.response_format)
+         monitor.record("Audio encoding complete")
+
+         if encoded_audio is None:
+             logger.error(f"Failed to encode audio to format: {request.response_format}")
+             raise HTTPException(
+                 status_code=500,
+                 detail=f"Failed to encode audio to {request.response_format}",
+             )
+
+         # Determine the correct media type for the response header
+         media_type = "audio/opus" if request.response_format == "opus" else "audio/wav"
+         # Note: OpenAI uses audio/opus, not audio/ogg;codecs=opus, so match OpenAI here.
+
+         logger.info(
+             f"Successfully generated {len(encoded_audio)} bytes in format {request.response_format}"
+         )
+         logger.debug(monitor.report())
+
+         # Stream the encoded audio back to the client
+         return StreamingResponse(io.BytesIO(encoded_audio), media_type=media_type)
+
+     except HTTPException as http_exc:
+         # Re-raise HTTPExceptions directly (e.g., from parameter validation)
+         logger.error(f"HTTP exception during OpenAI request: {http_exc.detail}")
+         raise http_exc
+     except Exception as e:
+         logger.error(f"Error processing OpenAI TTS request: {e}", exc_info=True)
+         logger.debug(monitor.report())
+         # Return a generic server error for unexpected issues
+         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+
+
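A minimal client sketch for the OpenAI-compatible endpoint above (an editorial illustration, not part of this file): the field names follow the handler's use of 'request.input', 'request.voice', 'request.speed', and 'request.response_format'; the host, port, and reference filename are assumptions, and OpenAITTSRequest may accept additional fields not shown here.

    import requests

    resp = requests.post(
        "http://localhost:8003/v1/audio/speech",
        json={
            "input": "[S1] Hello from the OpenAI-compatible endpoint. [S2] Hi there!",
            "voice": "dialogue",       # or "S1", "S2", or e.g. "my_ref.wav" to clone
            "response_format": "wav",  # "opus" is the other format handled above
            "speed": 0.9,
        },
        timeout=300,  # generation can take a while on first run
    )
    resp.raise_for_status()
    with open("speech.wav", "wb") as f:
        f.write(resp.content)
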
565
+ @app.post(
+     "/tts",
+     response_class=StreamingResponse,
+     tags=["TTS Generation"],
+     summary="Generate speech (Custom parameters)",
+ )
+ async def custom_tts_endpoint(request: CustomTTSRequest):
+     """
+     Generates speech audio from text using explicit Dia parameters.
+     """
+     monitor = PerformanceMonitor()
+     monitor.record("Request received")
+     logger.info(
+         f"Received custom TTS request: mode='{request.voice_mode}', format='{request.output_format}'"
+     )
+     logger.debug(f"Input text (start): '{request.text[:100]}...'")
+     logger.debug(
+         f"Params: max_tokens={request.max_tokens}, cfg={request.cfg_scale}, temp={request.temperature}, top_p={request.top_p}, speed={request.speed_factor}, top_k={request.cfg_filter_top_k}"
+     )
+
+     clone_ref_file = None
+     if request.voice_mode == "clone":
+         if not request.clone_reference_filename:
+             raise HTTPException(
+                 status_code=400,  # Bad request
+                 detail="Missing 'clone_reference_filename', which is required for clone mode.",
+             )
+         ref_path = get_reference_audio_path()
+         potential_path = os.path.join(ref_path, request.clone_reference_filename)
+         if not os.path.isfile(potential_path):
+             logger.error(
+                 f"Reference audio file not found for clone mode: {potential_path}"
+             )
+             raise HTTPException(
+                 status_code=404,  # Not found
+                 detail=f"Reference audio file not found: {request.clone_reference_filename}",
+             )
+         clone_ref_file = request.clone_reference_filename
+         logger.info(f"Custom request using clone mode with file: {clone_ref_file}")
+
+     monitor.record("Parameters processed")
+
+     try:
+         # Call the core engine function with parameters from the request
+         result = generate_speech(
+             text=request.text,
+             voice_mode=request.voice_mode,
+             clone_reference_filename=clone_ref_file,
+             max_tokens=request.max_tokens,  # Pass the user value or None
+             cfg_scale=request.cfg_scale,
+             temperature=request.temperature,
+             top_p=request.top_p,
+             speed_factor=request.speed_factor,  # For post-processing
+             cfg_filter_top_k=request.cfg_filter_top_k,
+         )
+         monitor.record("Generation complete")
+
+         if result is None:
+             logger.error("Speech generation failed (engine returned None).")
+             raise HTTPException(status_code=500, detail="Speech generation failed.")
+
+         audio_array, sample_rate = result
+
+         if sample_rate != EXPECTED_SAMPLE_RATE:
+             logger.warning(
+                 f"Engine returned sample rate {sample_rate}, expected {EXPECTED_SAMPLE_RATE}. Encoding will use {EXPECTED_SAMPLE_RATE}."
+             )
+             sample_rate = EXPECTED_SAMPLE_RATE
+
+         # Encode the audio in memory
+         encoded_audio = encode_audio(audio_array, sample_rate, request.output_format)
+         monitor.record("Audio encoding complete")
+
+         if encoded_audio is None:
+             logger.error(f"Failed to encode audio to format: {request.output_format}")
+             raise HTTPException(
+                 status_code=500,
+                 detail=f"Failed to encode audio to {request.output_format}",
+             )
+
+         # Determine the media type
+         media_type = "audio/opus" if request.output_format == "opus" else "audio/wav"
+
+         logger.info(
+             f"Successfully generated {len(encoded_audio)} bytes in format {request.output_format}"
+         )
+         logger.debug(monitor.report())
+
+         # Stream the response
+         return StreamingResponse(io.BytesIO(encoded_audio), media_type=media_type)
+
+     except HTTPException as http_exc:
+         logger.error(f"HTTP exception during custom TTS request: {http_exc.detail}")
+         raise http_exc
+     except Exception as e:
+         logger.error(f"Error processing custom TTS request: {e}", exc_info=True)
+         logger.debug(monitor.report())
+         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+
+
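Likewise for the custom endpoint, a client sketch with explicit Dia parameters (editorial; the parameter values are the project's shipped defaults, and the host and port are assumptions):

    import requests

    payload = {
        "text": "[S1] Custom endpoint test. (laughs)",
        "voice_mode": "dialogue",  # "clone" also requires clone_reference_filename
        "output_format": "wav",    # or "opus"
        "speed_factor": 0.9,
        "cfg_scale": 3.0,
        "temperature": 1.3,
        "top_p": 0.95,
        "cfg_filter_top_k": 35,
        "max_tokens": None,        # let the engine use its default
    }
    resp = requests.post("http://localhost:8003/tts", json=payload, timeout=300)
    resp.raise_for_status()
    with open("custom_tts.wav", "wb") as f:
        f.write(resp.content)
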
665
+ # --- Web UI Endpoints ---
+
+
+ @app.get("/", response_class=HTMLResponse, include_in_schema=False)
+ async def get_web_ui(request: Request):
+     """Serves the main TTS web interface."""
+     logger.info("Serving TTS Web UI (index.html)")
+     # Get the current list of reference files for the clone dropdown
+     reference_files = get_valid_reference_files()
+     # Get the current server config and default generation params
+     current_config = config_manager.get_all()
+     default_gen_params = {
+         "speed_factor": get_gen_default_speed_factor(),
+         "cfg_scale": get_gen_default_cfg_scale(),
+         "temperature": get_gen_default_temperature(),
+         "top_p": get_gen_default_top_p(),
+         "cfg_filter_top_k": get_gen_default_cfg_filter_top_k(),
+     }
+
+     return templates.TemplateResponse(
+         "index.html",  # Use the renamed file
+         {
+             "request": request,
+             "reference_files": reference_files,
+             "config": current_config,  # Pass current server config
+             "presets": loaded_presets,  # Pass loaded presets
+             "default_gen_params": default_gen_params,  # Pass default gen params
+             # Other variables needed by the template for its initial state
+             "error": None,
+             "success": None,
+             "output_file_url": None,
+             "generation_time": None,
+             "submitted_text": "",
+             "submitted_voice_mode": "dialogue",  # Default to combined mode
+             "submitted_clone_file": None,
+             # Initial generation params are taken from default_gen_params
+         },
+     )
+
+
+ @app.post("/web/generate", response_class=HTMLResponse, include_in_schema=False)
+ async def handle_web_ui_generate(
+     request: Request,
+     text: str = Form(...),
+     voice_mode: Literal["dialogue", "clone"] = Form(...),  # Updated modes
+     clone_reference_select: Optional[str] = Form(None),
+     # Generation parameters from the form
+     speed_factor: float = Form(...),  # Required; use Depends with a default if preferred
+     cfg_scale: float = Form(...),
+     temperature: float = Form(...),
+     top_p: float = Form(...),
+     cfg_filter_top_k: int = Form(...),
+ ):
+     """Handles the generation request from the web UI form."""
+     logger.info(f"Web UI generation request: mode='{voice_mode}'")
+     monitor = PerformanceMonitor()
+     monitor.record("Web request received")
+
+     output_file_url = None
+     generation_time = None
+     error_message = None
+     success_message = None
+     output_filename_base = "dia_output"  # Default base name
+
+     # --- Pre-generation Validation ---
+     if not text.strip():
+         error_message = "Please enter some text to synthesize."
+
+     clone_ref_file = None
+     if voice_mode == "clone":
+         if not clone_reference_select or clone_reference_select == "none":
+             error_message = "Please select a reference audio file for clone mode."
+         else:
+             # Verify the selected file still exists (important if files can be deleted)
+             ref_path = get_reference_audio_path()
+             potential_path = os.path.join(ref_path, clone_reference_select)
+             if not os.path.isfile(potential_path):
+                 error_message = f"Selected reference file '{clone_reference_select}' no longer exists. Please refresh or upload."
+                 # Invalidate the selection
+                 clone_ref_file = None
+                 clone_reference_select = None  # Clear submitted value for re-rendering
+             else:
+                 clone_ref_file = clone_reference_select
+                 logger.info(f"Using selected reference file: {clone_ref_file}")
+
+     # If validation failed, re-render the page with the error and submitted values
+     if error_message:
+         logger.warning(f"Web UI validation error: {error_message}")
+         reference_files = get_valid_reference_files()
+         current_config = config_manager.get_all()
+         default_gen_params = {  # Pass defaults again for consistency
+             "speed_factor": get_gen_default_speed_factor(),
+             "cfg_scale": get_gen_default_cfg_scale(),
+             "temperature": get_gen_default_temperature(),
+             "top_p": get_gen_default_top_p(),
+             "cfg_filter_top_k": get_gen_default_cfg_filter_top_k(),
+         }
+         # Pass back the values the user submitted
+         submitted_gen_params = {
+             "speed_factor": speed_factor,
+             "cfg_scale": cfg_scale,
+             "temperature": temperature,
+             "top_p": top_p,
+             "cfg_filter_top_k": cfg_filter_top_k,
+         }
+
+         return templates.TemplateResponse(
+             "index.html",
+             {
+                 "request": request,
+                 "error": error_message,
+                 "reference_files": reference_files,
+                 "config": current_config,
+                 "presets": loaded_presets,
+                 "default_gen_params": default_gen_params,  # Base defaults
+                 # Submitted values to repopulate the form
+                 "submitted_text": text,
+                 "submitted_voice_mode": voice_mode,
+                 "submitted_clone_file": clone_reference_select,  # Possibly invalidated value
+                 "submitted_gen_params": submitted_gen_params,  # Pass submitted params back
+                 # Ensure other necessary template variables are passed
+                 "success": None,
+                 "output_file_url": None,
+                 "generation_time": None,
+             },
+         )
+
+     # --- Generation ---
+     try:
+         monitor.record("Parameters processed")
+         # Call the core engine function
+         result = generate_speech(
+             text=text,
+             voice_mode=voice_mode,
+             clone_reference_filename=clone_ref_file,
+             speed_factor=speed_factor,
+             cfg_scale=cfg_scale,
+             temperature=temperature,
+             top_p=top_p,
+             cfg_filter_top_k=cfg_filter_top_k,
+             max_tokens=None,  # Use the model default for UI simplicity
+         )
+         monitor.record("Generation complete")
+
+         if result:
+             audio_array, sample_rate = result
+             output_path_base = get_output_path()
+             timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+             # Create a more descriptive filename
+             mode_tag = voice_mode
+             if voice_mode == "clone" and clone_ref_file:
+                 safe_ref_name = sanitize_filename(os.path.splitext(clone_ref_file)[0])
+                 mode_tag = f"clone_{safe_ref_name[:20]}"  # Limit length
+             output_filename = f"{mode_tag}_{timestamp}.wav"  # Always save as WAV for simplicity
+             output_filepath = os.path.join(output_path_base, output_filename)
+
+             # Save the audio to a WAV file
+             saved = save_audio_to_file(audio_array, sample_rate, output_filepath)
+             monitor.record("Audio saved")
+
+             if saved:
+                 output_file_url = f"/outputs/{output_filename}"  # URL path for browser access
+                 generation_time = monitor.events[-1][1] - monitor.start_time  # Time until save completed
+                 success_message = "Audio generated successfully!"
+                 logger.info(f"Web UI generated audio saved to: {output_filepath}")
+             else:
+                 error_message = "Failed to save generated audio file."
+                 logger.error("Failed to save audio file from web UI request.")
+         else:
+             error_message = "Speech generation failed (engine returned None)."
+             logger.error("Speech generation failed for web UI request.")
+
+     except Exception as e:
+         logger.error(f"Error processing web UI TTS request: {e}", exc_info=True)
+         error_message = f"An unexpected error occurred: {str(e)}"
+
+     logger.debug(monitor.report())
+
+     # --- Re-render Template with Results ---
+     reference_files = get_valid_reference_files()
+     current_config = config_manager.get_all()
+     default_gen_params = {
+         "speed_factor": get_gen_default_speed_factor(),
+         "cfg_scale": get_gen_default_cfg_scale(),
+         "temperature": get_gen_default_temperature(),
+         "top_p": get_gen_default_top_p(),
+         "cfg_filter_top_k": get_gen_default_cfg_filter_top_k(),
+     }
+     # Pass back submitted values to repopulate the form correctly
+     submitted_gen_params = {
+         "speed_factor": speed_factor,
+         "cfg_scale": cfg_scale,
+         "temperature": temperature,
+         "top_p": top_p,
+         "cfg_filter_top_k": cfg_filter_top_k,
+     }
+
+     return templates.TemplateResponse(
+         "index.html",
+         {
+             "request": request,
+             "error": error_message,
+             "success": success_message,
+             "output_file_url": output_file_url,
+             "generation_time": f"{generation_time:.2f}" if generation_time else None,
+             "reference_files": reference_files,
+             "config": current_config,
+             "presets": loaded_presets,
+             "default_gen_params": default_gen_params,  # Base defaults
+             # Pass back submitted values
+             "submitted_text": text,
+             "submitted_voice_mode": voice_mode,
+             "submitted_clone_file": clone_ref_file,  # Pass the validated filename back
+             "submitted_gen_params": submitted_gen_params,  # Pass submitted params back
+         },
+     )
+
+
889
+ # --- Reference Audio Upload Endpoint ---
+ @app.post(
+     "/upload_reference", tags=["Web UI Helpers"], summary="Upload reference audio files"
+ )
+ async def upload_reference_audio(files: List[UploadFile] = File(...)):
+     """Handles uploading of reference audio files (.wav, .mp3) for voice cloning."""
+     logger.info(f"Received request to upload {len(files)} reference audio file(s).")
+     ref_path = get_reference_audio_path()
+     uploaded_filenames = []
+     errors = []
+     allowed_mime_types = [
+         "audio/wav",
+         "audio/mpeg",
+         "audio/x-wav",
+     ]  # Common WAV/MP3 types
+     allowed_extensions = [".wav", ".mp3"]
+
+     for file in files:
+         try:
+             # Basic validation
+             if not file.filename:
+                 errors.append("Received file with no filename.")
+                 continue
+
+             # Sanitize the filename
+             safe_filename = sanitize_filename(file.filename)
+             _, ext = os.path.splitext(safe_filename)
+             if ext.lower() not in allowed_extensions:
+                 errors.append(
+                     f"File '{file.filename}' has unsupported extension '{ext}'. Allowed: {allowed_extensions}"
+                 )
+                 continue
+
+             # Check the client-reported MIME type as an additional filter on top of the extension
+             if file.content_type not in allowed_mime_types:
+                 errors.append(
+                     f"File '{file.filename}' has unsupported content type '{file.content_type}'. Allowed: {allowed_mime_types}"
+                 )
+                 continue
+
+             # Construct the full save path
+             destination_path = os.path.join(ref_path, safe_filename)
+
+             # Prevent overwriting existing files (optional; a counter suffix would also work)
+             if os.path.exists(destination_path):
+                 # Simple approach: skip if the file exists
+                 logger.warning(
+                     f"Reference file '{safe_filename}' already exists. Skipping upload."
+                 )
+                 # Add to the list so the UI knows it's available, even if not newly uploaded this time
+                 if safe_filename not in uploaded_filenames:
+                     uploaded_filenames.append(safe_filename)
+                 continue
+                 # Alternative: add a counter, e.g. file_1.wav, file_2.wav
+
+             # Save the file using shutil.copyfileobj for efficiency with large files
+             try:
+                 with open(destination_path, "wb") as buffer:
+                     shutil.copyfileobj(file.file, buffer)
+                 logger.info(f"Successfully saved reference file: {destination_path}")
+                 uploaded_filenames.append(safe_filename)
+             except Exception as save_exc:
+                 errors.append(f"Failed to save file '{safe_filename}': {save_exc}")
+                 logger.error(
+                     f"Failed to save uploaded file '{safe_filename}' to '{destination_path}': {save_exc}",
+                     exc_info=True,
+                 )
+             finally:
+                 # Ensure the UploadFile resource is closed
+                 await file.close()
+
+         except Exception as e:
+             errors.append(
+                 f"Error processing file '{getattr(file, 'filename', 'unknown')}': {e}"
+             )
+             logger.error(
+                 f"Unexpected error processing uploaded file: {e}", exc_info=True
+             )
+             # Ensure the file is closed even if other errors occur
+             if file:
+                 await file.close()
+
+     # Get the updated list of all valid files in the directory
+     updated_file_list = get_valid_reference_files()
+
+     response_data = {
+         "message": f"Processed {len(files)} file(s).",
+         "uploaded_files": uploaded_filenames,  # New files successfully saved in this request
+         "all_reference_files": updated_file_list,  # Complete current list
+         "errors": errors,
+     }
+
+     status_code = (
+         200 if not errors or len(errors) < len(files) else 400
+     )  # OK if at least one file succeeded, else Bad Request
+     if errors:
+         logger.warning(f"Upload completed with errors: {errors}")
+
+     return JSONResponse(content=response_data, status_code=status_code)
+
+
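A sketch of uploading a reference clip to the endpoint above (editorial; the multipart field must be named 'files' to match the List[UploadFile] parameter, and the filename, host, and port are assumptions):

    import requests

    with open("my_voice.wav", "rb") as f:
        resp = requests.post(
            "http://localhost:8003/upload_reference",
            files=[("files", ("my_voice.wav", f, "audio/wav"))],
            timeout=60,
        )
    resp.raise_for_status()
    print(resp.json()["all_reference_files"])  # current contents of reference_audio/
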
990
+ # --- Health Check Endpoint ---
+ @app.get("/health", tags=["Server Status"], summary="Check server health")
+ async def health_check():
+     """Basic health check; indicates whether the server is running and the model is loaded."""
+     # Access the MODEL_LOADED variable *directly* from the engine module each
+     # time the endpoint is called, so we always report the current status.
+     current_model_status = getattr(engine, "MODEL_LOADED", False)  # Safely get status
+     logger.debug(
+         f"Health check returning model_loaded status: {current_model_status}"
+     )
+     return {"status": "healthy", "model_loaded": current_model_status}
+
+
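Because model loading happens in the lifespan hook, scripts can poll /health as a readiness gate before sending generation requests. A small sketch (editorial; host and port assumed):

    import time

    import requests

    while True:
        try:
            status = requests.get("http://localhost:8003/health", timeout=5).json()
            if status.get("model_loaded"):
                break
        except requests.RequestException:
            pass  # server not up yet
        time.sleep(2)
    print("Server ready.")
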
1003
+ # --- Main Execution ---
+ if __name__ == "__main__":
+     host = get_host()
+     port = get_port()
+     logger.info(f"Starting Dia TTS server on {host}:{port}")
+     logger.info(f"Model Repository: {get_model_repo_id()}")
+     logger.info(f"Model Config File: {get_model_config_filename()}")
+     logger.info(f"Model Weights File: {get_model_weights_filename()}")
+     logger.info(f"Model Cache Path: {get_model_cache_path()}")
+     logger.info(f"Reference Audio Path: {get_reference_audio_path()}")
+     logger.info(f"Output Path: {get_output_path()}")
+     # Determine the host to display in logs and use for browser opening
+     display_host = "localhost" if host == "0.0.0.0" else host
+     logger.info(f"Web UI will be available at http://{display_host}:{port}/")
+     logger.info(f"API Docs available at http://{display_host}:{port}/docs")
+
+     # Ensure the UI directory and index.html exist
+     ui_dir = "ui"
+     index_file = os.path.join(ui_dir, "index.html")
+     if not os.path.isdir(ui_dir) or not os.path.isfile(index_file):
+         logger.warning(
+             f"'{ui_dir}' directory or '{index_file}' not found. Web UI may not work."
+         )
+         # Optionally create dummy files/dirs if needed for startup
+         os.makedirs(ui_dir, exist_ok=True)
+         if not os.path.isfile(index_file):
+             try:
+                 with open(index_file, "w") as f:
+                     f.write(
+                         "<html><body>Web UI template missing. See project source for index.html.</body></html>"
+                     )
+                 logger.info(f"Created dummy {index_file}.")
+             except Exception as e:
+                 logger.error(f"Failed to create dummy {index_file}: {e}")
+
+     # The module-level 'startup_complete_event' (created near the top of this
+     # file) is set by the lifespan manager once startup, including model
+     # loading, is complete; there is no need to re-create it here.
+
+     # Run the Uvicorn server.
+     # The lifespan context manager ('lifespan="on"') runs during startup and is
+     # responsible for loading the model and setting 'startup_complete_event'.
+     uvicorn.run(
+         "server:app",  # Use the format 'module:app_instance'
+         host=host,
+         port=port,
+         reload=False,  # Set reload as needed for development/production
+         # reload_dirs=[".", "ui"],  # Only use reload=True with reload_dirs/includes for development
+         # reload_includes=[
+         #     "*.py",
+         #     "*.html",
+         #     "*.css",
+         #     "*.js",
+         #     ".env",
+         #     "*.yaml",
+         # ],
+         lifespan="on",  # Use the lifespan context manager defined in this file
+         # workers=1  # Keep workers=1 when using reload=True or complex global state/models
+     )
ui/index.html ADDED
@@ -0,0 +1,916 @@
+ <!DOCTYPE html>
2
+ <html lang="en" class="dark"> <!-- Default to dark mode class -->
3
+
4
+ <head>
5
+ <meta charset="UTF-8">
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
7
+ <title>Dia TTS Server | Text-to-Dialogue</title>
8
+ <link rel="icon" href="/static/favicon.ico" type="image/x-icon">
9
+ <!-- Tailwind CSS (CDN for simplicity, processes styles in <style type="text/tailwindcss"> below) -->
10
+ <script src="https://cdn.tailwindcss.com"></script>
11
+ <script>
12
+ // Configure Tailwind CSS
13
+ tailwind.config = {
14
+ darkMode: 'class', // Enable class-based dark mode
15
+ theme: {
16
+ extend: {
17
+ colors: {
18
+ // Define color palettes used in style.css
19
+ // Light Mode Colors (Examples - Adjust as needed)
20
+ gray: { 50: '#f9fafb', 100: '#f3f4f6', 200: '#e5e7eb', 300: '#d1d5db', 400: '#9ca3af', 500: '#6b7280', 600: '#4b5563', 700: '#374151', 800: '#1f2937', 900: '#111827' },
21
+ sky: { 50: '#f0f9ff', 100: '#e0f2fe', 200: '#bae6fd', 300: '#7dd3fc', 400: '#38bdf8', 500: '#0ea5e9', 600: '#0284c7', 700: '#0369a1', 800: '#075985', 900: '#0c4a6e' },
22
+ indigo: { 50: '#eef2ff', 100: '#e0e7ff', 200: '#c7d2fe', 300: '#a5b4fc', 400: '#818cf8', 500: '#6366f1', 600: '#4f46e5', 700: '#4338ca', 800: '#3730a3', 900: '#312e81' },
23
+ red: { 100: '#fee2e2', 300: '#fca5a5', 500: '#ef4444', 600: '#dc2626', 800: '#991b1b', 900: '#7f1d1d' },
24
+ green: { 100: '#dcfce7', 300: '#86efac', 500: '#22c55e', 800: '#166534', 900: '#14532d' },
25
+ yellow: { 100: '#fef9c3', 300: '#fcd34d', 500: '#eab308', 700: '#b45309', 900: '#78350f' },
26
+
27
+ // Dark Mode Colors (Copied from previous inline config)
28
+ primary: { 50: '#f0f9ff', 100: '#e0f2fe', 200: '#bae6fd', 300: '#7dd3fc', 400: '#38bdf8', 500: '#0ea5e9', 600: '#0284c7', 700: '#0369a1', 800: '#075985', 900: '#0c4a6e' },
29
+ purple: { 50: '#faf5ff', 100: '#f3e8ff', 200: '#e9d5ff', 300: '#d8b4fe', 400: '#c084fc', 500: '#a855f7', 600: '#9333ea', 700: '#7e22ce', 800: '#6b21a8', 900: '#581c87' },
30
+ dark: { 50: '#f9fafb', 100: '#f3f4f6', 200: '#e5e7eb', 300: '#d1d5db', 400: '#9ca3af', 500: '#6b7280', 600: '#4b5563', 700: '#374151', 800: '#1f2937', 900: '#111827', 950: '#030712', 1000: '#0f1729' }
31
+ }
32
+ }
33
+ }
34
+ }
35
+ </script>
36
+ <!-- Removed External Stylesheet Link: <link rel="stylesheet" href="/ui/style.css"> -->
37
+ <!-- Wavesurfer for audio visualization -->
38
+ <script src="https://unpkg.com/wavesurfer.js@7"></script>
39
+
40
+ <style type="text/tailwindcss">
41
+ /* ui/style.css */
42
+
43
+ /* Import Tailwind base, components, and utilities */
44
+ @tailwind base;
45
+ @tailwind components;
46
+ @tailwind utilities;
47
+
48
+ /* Define custom components/utilities */
49
+ @layer components {
50
+
51
+ /* Base styles (Light Mode) */
52
+ .body-base {
53
+ @apply h-full bg-gray-100 text-gray-900;
54
+ }
55
+
56
+ .nav-base {
57
+ @apply bg-gradient-to-r from-white to-sky-100 border-b border-sky-200 shadow-md;
58
+ }
59
+
60
+ .nav-link {
61
+ @apply text-sky-700 hover:text-sky-900 px-3 py-2 rounded-md text-sm font-medium;
62
+ }
63
+
64
+ .title-link {
65
+ @apply text-gray-900 text-xl font-bold;
66
+ }
67
+
68
+ .card-base {
69
+ @apply bg-white shadow-lg rounded-lg overflow-hidden border border-gray-200;
70
+ }
71
+
72
+ .card-header {
73
+ @apply text-lg font-medium text-gray-900 mb-4;
74
+ }
75
+
76
+ .card-footer {
77
+ @apply bg-gray-50 px-6 py-4 flex items-center justify-between border-t border-gray-200;
78
+ }
79
+
80
+ .label-base {
81
+ @apply block text-sm font-medium text-gray-700 mb-1;
82
+ }
83
+
84
+ .input-base {
85
+ @apply block w-full rounded-md border-gray-300 shadow-sm focus:border-sky-500 focus:ring-sky-500 sm:text-sm px-3 py-2 bg-white text-gray-900 placeholder-gray-400;
86
+ }
87
+
88
+ .textarea-base {
89
+ @apply input-base;
90
+ /* Inherit base input styles */
91
+ }
92
+
93
+ .select-base {
94
+ @apply input-base appearance-none pr-8;
95
+ /* Add padding for arrow */
96
+ /* Consider adding a background SVG for the dropdown arrow */
97
+ }
98
+
99
+ .button-base {
100
+ @apply inline-flex items-center justify-center px-4 py-2 border border-transparent rounded-md shadow-sm text-sm font-medium focus:outline-none focus:ring-2 focus:ring-offset-2 transition-colors disabled:opacity-50 disabled:cursor-not-allowed whitespace-nowrap flex-shrink-0;
101
+ /* Added whitespace-nowrap and flex-shrink-0 for button text */
102
+ }
103
+
104
+ .btn-primary {
105
+ @apply button-base bg-sky-600 text-white hover:bg-sky-700 focus:ring-sky-500;
106
+ }
107
+
108
+ .btn-secondary {
109
+ @apply button-base bg-gray-200 text-gray-700 border-gray-300 hover:bg-gray-300 focus:ring-indigo-500;
110
+ /* Example secondary */
111
+ }
112
+
113
+ .btn-danger {
114
+ @apply button-base bg-red-600 text-white hover:bg-red-700 focus:ring-red-500;
115
+ }
116
+
117
+ .btn-purple {
118
+ @apply button-base bg-purple-600 text-white hover:bg-purple-700 focus:ring-purple-500;
119
+ }
120
+
121
+ .slider-base {
122
+ @apply w-full h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer;
123
+ /* Need to style the thumb separately per browser */
124
+ }
125
+
126
+ .slider-thumb {
127
+ /* Basic thumb styling */
128
+ @apply appearance-none w-5 h-5 bg-sky-600 rounded-full cursor-pointer;
129
+ }
130
+
131
+ .radio-label {
132
+ @apply flex items-center space-x-2 cursor-pointer border border-gray-300 bg-white hover:border-sky-400 p-3 rounded-md transition-colors;
133
+ }
134
+
135
+ .radio-label-text {
136
+ @apply text-gray-700;
137
+ }
138
+
139
+ /* Apply checked styles directly using peer-checked utility on the container/text span */
140
+ /* .radio-label input:checked+span {
141
+ @apply text-sky-600 font-semibold;
142
+ }
143
+
144
+ .radio-label-checked {
145
+ @apply border-sky-500 ring-2 ring-sky-500;
146
+ } */
147
+ /* Replaced these custom classes with Tailwind peer utilities in the HTML */
148
+
149
+
150
+ .preset-button {
151
+ @apply button-base bg-indigo-100 text-indigo-700 border-indigo-200 hover:bg-indigo-200 focus:ring-indigo-500 text-xs px-3 py-1;
152
+ }
153
+
154
+ .notification-base {
155
+ @apply px-4 py-3 rounded relative shadow-md flex items-center mb-3;
156
+ /* Reduced margin bottom */
157
+ }
158
+
159
+ .notification-success {
160
+ @apply notification-base bg-green-100 border border-green-300 text-green-800;
161
+ }
162
+
163
+ .notification-error {
164
+ @apply notification-base bg-red-100 border border-red-300 text-red-800;
165
+ }
166
+
167
+ .notification-warning {
168
+ @apply notification-base bg-yellow-100 border border-yellow-300 text-yellow-800;
169
+ }
170
+
171
+ .notification-info {
172
+ /* Added info style */
173
+ @apply notification-base bg-sky-100 border border-sky-300 text-sky-800;
174
+ }
175
+
176
+ .code-inline {
177
+ @apply bg-gray-200 px-1 rounded text-sm font-mono text-gray-800;
178
+ }
179
+
180
+ .tooltip {
181
+ /* Basic tooltip styling */
182
+ @apply absolute hidden group-hover:block bg-gray-700 text-white text-xs rounded py-1 px-2 z-10 -mt-8;
183
+ }
184
+
185
+ .loading-overlay-base {
186
+ @apply fixed inset-0 bg-gray-600 bg-opacity-75 flex items-center justify-center z-50 transition-opacity duration-300;
187
+ }
188
+
189
+ .loading-box-base {
190
+ @apply bg-white p-6 rounded-lg shadow-xl flex flex-col items-center border border-gray-300;
191
+ }
192
+
193
+ .loading-spinner {
194
+ @apply animate-spin h-10 w-10 text-sky-600 mb-4;
195
+ }
196
+
197
+ .loading-text {
198
+ @apply text-gray-900 text-lg mb-2;
199
+ }
200
+
201
+ .loading-status {
202
+ @apply text-gray-600 text-sm mb-4 text-center max-w-xs;
203
+ /* Limit width */
204
+ }
205
+
206
+ .waveform-container {
207
+ @apply w-full h-24 bg-gray-100 rounded;
208
+ }
209
+
210
+ .audio-player-card {
211
+ @apply card-base mt-8;
212
+ /* Margin top for spacing */
213
+ }
214
+
215
+ .audio-player-controls {
216
+ @apply flex flex-wrap items-center justify-between gap-4;
217
+ }
218
+
219
+ .audio-player-buttons {
220
+ @apply flex items-center space-x-2 sm:space-x-4;
221
+ /* Adjust spacing */
222
+ }
223
+
224
+ .audio-player-info {
225
+ @apply text-sm text-gray-600 text-right;
226
+ }
227
+
228
+ .theme-switch {
229
+ @apply p-2 rounded-md text-gray-600 hover:bg-gray-200 hover:text-gray-800 focus:outline-none focus:ring-2 focus:ring-sky-500 focus:ring-offset-2;
230
+ }
231
+
232
+
233
+ /* Dark Mode Overrides using 'dark:' prefix */
234
+ .dark .body-base {
235
+ @apply bg-[#0f1729] text-white;
236
+ /* Original dark bg */
237
+ }
238
+
239
+ .dark .nav-base {
240
+ @apply bg-gradient-to-r from-dark-900 to-purple-900 border-b border-purple-800 shadow-lg;
241
+ }
242
+
243
+ .dark .nav-link {
244
+ @apply text-primary-300 hover:text-white;
245
+ }
246
+
247
+ .dark .title-link {
248
+ @apply text-white;
249
+ }
250
+
251
+ .dark .card-base {
252
+ @apply bg-dark-800 border border-dark-700;
253
+ }
254
+
255
+ .dark .card-header {
256
+ @apply text-white;
257
+ }
258
+
259
+ .dark .card-footer {
260
+ @apply bg-dark-900 border-t border-dark-700;
261
+ }
262
+
263
+ .dark .label-base {
264
+ @apply text-gray-300;
265
+ /* Lighter gray for dark */
266
+ }
267
+
268
+ .dark .input-base {
269
+ @apply border-dark-600 bg-dark-700 text-white placeholder-gray-500 focus:ring-offset-dark-800;
270
+ }
271
+
272
+ .dark .select-base {
273
+ /* Dark mode arrow styling if needed */
274
+ }
275
+
276
+ .dark .btn-primary {
277
+ @apply bg-primary-600 text-white hover:bg-primary-700 focus:ring-primary-500 focus:ring-offset-dark-800;
278
+ }
279
+
280
+ .dark .btn-secondary {
281
+ @apply bg-dark-700 text-white border-dark-600 hover:bg-dark-600 focus:ring-purple-500 focus:ring-offset-dark-800;
282
+ }
283
+
284
+ .dark .btn-danger {
285
+ @apply bg-red-600 text-white hover:bg-red-700 focus:ring-red-500 focus:ring-offset-dark-800;
286
+ }
287
+
288
+ .dark .btn-purple {
289
+ @apply bg-purple-600 text-white hover:bg-purple-700 focus:ring-purple-500 focus:ring-offset-dark-800;
290
+ }
291
+
292
+ .dark .slider-base {
293
+ @apply bg-dark-600;
294
+ }
295
+
296
+ .dark .slider-thumb {
297
+ @apply bg-primary-500;
298
+ }
299
+
300
+ .dark .radio-label {
301
+ @apply border-dark-600 bg-dark-800 hover:border-primary-400;
302
+ }
303
+
304
+ .dark .radio-label-text {
305
+ @apply text-gray-300;
306
+ }
307
+
308
+ /* Apply checked styles directly using peer-checked utility on the container/text span */
309
+ /* .dark .radio-label input:checked+span {
310
+ @apply text-primary-400;
311
+ }
312
+
313
+ .dark .radio-label-checked {
314
+ @apply border-primary-500 ring-primary-500;
315
+ } */
316
+ /* Replaced these custom classes with Tailwind peer utilities in the HTML */
317
+
318
+
319
+ .dark .preset-button {
320
+ @apply bg-indigo-900 text-indigo-200 border-indigo-700 hover:bg-indigo-800 focus:ring-indigo-500 focus:ring-offset-dark-800;
321
+ }
322
+
323
+ .dark .notification-success {
324
+ @apply notification-base bg-green-900 border border-green-700 text-green-100;
325
+ }
326
+
327
+ .dark .notification-error {
328
+ @apply notification-base bg-red-900 border border-red-700 text-red-100;
329
+ }
330
+
331
+ .dark .notification-warning {
332
+ @apply notification-base bg-yellow-900 border border-yellow-700 text-yellow-100;
333
+ }
334
+
335
+ .dark .notification-info {
336
+ /* Added info style */
337
+ @apply notification-base bg-sky-900 border border-sky-700 text-sky-100;
338
+ }
339
+
340
+ .dark .code-inline {
341
+ @apply bg-dark-900 text-purple-300;
342
+ }
343
+
344
+ .dark .tooltip {
345
+ @apply bg-dark-950;
346
+ }
347
+
348
+ .dark .loading-overlay-base {
349
+ @apply bg-dark-900 bg-opacity-75;
350
+ }
351
+
352
+ .dark .loading-box-base {
353
+ @apply bg-dark-800 border border-dark-700;
354
+ }
355
+
356
+ .dark .loading-spinner {
357
+ @apply text-primary-500;
358
+ }
359
+
360
+ .dark .loading-text {
361
+ @apply text-white;
362
+ }
363
+
364
+ .dark .loading-status {
365
+ @apply text-gray-400;
366
+ }
367
+
368
+ .dark .waveform-container {
369
+ @apply bg-dark-900;
370
+ }
371
+
372
+ .dark .audio-player-info {
373
+ @apply text-purple-300;
374
+ }
375
+
376
+ .dark .theme-switch {
377
+ @apply text-gray-400 hover:bg-dark-700 hover:text-white focus:ring-offset-dark-900;
378
+ }
379
+
380
+ }
381
+
382
+ /* Specific slider thumb styling per browser */
383
+ /* Apply these within the <style> tag as they target pseudo-elements */
384
+ input[type="range"].slider-base::-webkit-slider-thumb {
385
+ @apply slider-thumb;
386
+ }
387
+
388
+ input[type="range"].slider-base::-moz-range-thumb {
389
+ @apply slider-thumb;
390
+ }
391
+
392
+ /* Dark mode thumbs need specific overrides if needed */
393
+ .dark input[type="range"].slider-base::-webkit-slider-thumb {
394
+ /* Apply dark mode thumb styles directly */
395
+ background-color: theme('colors.primary.500');
396
+ /* Replaced @apply dark:slider-thumb */
397
+ /* Inherit other base thumb styles if needed (like size, border-radius) or re-apply */
398
+ @apply appearance-none w-5 h-5 rounded-full cursor-pointer;
399
+ }
400
+
401
+ .dark input[type="range"].slider-base::-moz-range-thumb {
402
+ /* Apply dark mode thumb styles directly */
403
+ background-color: theme('colors.primary.500');
404
+ /* Replaced @apply dark:slider-thumb */
405
+ /* Inherit other base thumb styles if needed or re-apply */
406
+ @apply appearance-none w-5 h-5 rounded-full cursor-pointer;
407
+ }
408
+ </style>
409
+ </head>
410
+
411
+ <body class="body-base">
412
+ <div class="min-h-full">
413
+ <!-- Navigation -->
414
+ <nav class="nav-base">
415
+ <div class="mx-auto max-w-7xl px-4 sm:px-6 lg:px-8">
416
+ <div class="flex h-16 items-center justify-between">
417
+ <div class="flex items-center">
418
+ <div class="flex-shrink-0">
419
+ <!-- Make title clickable -->
420
+ <a href="/" class="title-link">Dia TTS Server</a>
421
+ </div>
422
+ </div>
423
+ <div class="flex items-center space-x-2 sm:space-x-4">
424
+ <a href="/docs" target="_blank" class="nav-link">API Docs</a>
425
+ <!-- Theme Toggle Button -->
426
+ <button id="theme-toggle-btn" type="button"
427
+ class="relative inline-flex items-center p-1 rounded-full bg-gray-200 dark:bg-dark-700 h-8 w-16 transition-colors"
428
+ title="Toggle light/dark mode">
429
+ <span class="sr-only">Toggle theme</span>
430
+ <span class="absolute inset-0 rounded-full transition-colors"></span>
431
+ <!-- Toggle thumb with icons -->
432
+ <span
433
+ class="relative rounded-full w-6 h-6 bg-white dark:bg-purple-600 transform transition-transform duration-200 ease-in-out translate-x-0 dark:translate-x-8 flex items-center justify-center shadow-md">
434
+ <!-- Sun icon (for light mode) -->
435
+ <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor"
436
+ class="w-4 h-4 text-yellow-500 dark:opacity-0 transition-opacity">
437
+ <path
438
+ d="M10 2a.75.75 0 0 1 .75.75v1.5a.75.75 0 0 1-1.5 0v-1.5A.75.75 0 0 1 10 2ZM10 15a.75.75 0 0 1 .75.75v1.5a.75.75 0 0 1-1.5 0v-1.5A.75.75 0 0 1 10 15ZM10 7a3 3 0 1 0 0 6 3 3 0 0 0 0-6Z" />
439
+ </svg>
440
+ <!-- Moon icon (for dark mode) -->
441
+ <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor"
442
+ class="w-4 h-4 text-white opacity-0 dark:opacity-100 transition-opacity">
443
+ <path
444
+ d="M7.455 1.75A8.5 8.5 0 0 1 18.25 12.55a8.5 8.5 0 0 1-8.46 8.46A8.5 8.5 0 0 1 1.75 12.55a8.5 8.5 0 0 1 5.705-10.8Z" />
445
+ </svg>
446
+ </span>
447
+ </button>
448
+ </div>
449
+ </div>
450
+ </div>
451
+ </nav>
452
+
453
+ <!-- Main content -->
454
+ <main>
455
+ <div class="mx-auto max-w-7xl px-4 py-8 sm:px-6 lg:px-8">
456
+
457
+ <!-- Notification area -->
458
+ <div id="notification-area" class="mb-6 space-y-3">
459
+ {% if error %}
460
+ <div class="notification-error" role="alert">
461
+ <svg class="h-5 w-5 text-red-500 mr-2 flex-shrink-0" viewBox="0 0 20 20" fill="currentColor">
462
+ <path fill-rule="evenodd"
463
+ d="M10 18a8 8 0 100-16 8 8 0 000 16zM8.707 7.293a1 1 0 00-1.414 1.414L8.586 10l-1.293 1.293a1 1 0 101.414 1.414L10 11.414l1.293 1.293a1 1 0 001.414-1.414L11.414 10l1.293-1.293a1 1 0 00-1.414-1.414L10 8.586 8.707 7.293z"
464
+ clip-rule="evenodd" />
465
+ </svg>
466
+ <span class="block sm:inline">{{ error }}</span>
467
+ </div>
468
+ {% endif %}
469
+ {% if success %}
470
+ <div class="notification-success" role="alert">
471
+ <svg class="h-5 w-5 text-green-500 mr-2 flex-shrink-0" viewBox="0 0 20 20" fill="currentColor">
472
+ <path fill-rule="evenodd"
473
+ d="M10 18a8 8 0 100-16 8 8 0 000 16zm3.707-9.293a1 1 0 00-1.414-1.414L9 10.586 7.707 9.293a1 1 0 00-1.414 1.414l2 2a1 1 0 001.414 0l4-4z"
474
+ clip-rule="evenodd" />
475
+ </svg>
476
+ <span class="block sm:inline">{{ success }}</span>
477
+ </div>
478
+ {% endif %}
479
+ </div>
480
+
481
+ <!-- TTS form -->
482
+ <div class="card-base">
483
+ <form id="tts-form" action="/web/generate" method="post" class="flex flex-col">
484
+ <div class="p-6">
485
+ <h2 class="card-header">Generate Speech with Dia</h2>
486
+
487
+ <!-- Text input -->
488
+ <div class="mb-6">
489
+ <label for="text" class="label-base">Text to speak</label>
490
+ <p class="text-xs text-purple-500 dark:text-purple-300 mb-2">
491
+ Use <code class="code-inline">[S1]</code> and <code class="code-inline">[S2]</code>
492
+ tags for speaker turns. Add non-verbals like <code
493
+ class="code-inline">(laughs)</code>.
494
+ </p>
495
+ <div class="relative">
496
+ <textarea name="text" id="text" rows="5" maxlength="8192" class="textarea-base"
497
+ placeholder="Example: [S1] Hello there! [S2] Hi! How are you? [S1] I'm doing well, thanks. (laughs)"
498
+ required>{{ submitted_text if submitted_text else "" }}</textarea>
499
+ <div class="absolute bottom-2 right-2 text-xs text-gray-500 dark:text-purple-300">
500
+ <span id="char-count">0</span> / 8192
501
+ </div>
502
+ </div>
503
+ </div>
504
+
505
+ <!-- Voice Mode Selection -->
506
+ <div class="mb-6">
507
+ <label class="label-base mb-2">Voice Mode</label>
508
+ <div class="grid grid-cols-1 md:grid-cols-2 gap-4">
509
+ <!-- Combined Dialogue / Single Speaker Mode -->
510
+ <label
511
+ class="radio-label peer-checked:border-sky-500 peer-checked:dark:border-primary-500 peer-checked:ring-2 peer-checked:ring-sky-500 peer-checked:dark:ring-primary-500">
512
+ <input type="radio" name="voice_mode" value="dialogue" class="hidden peer" {% if
513
+ submitted_voice_mode=='dialogue' or not submitted_voice_mode %}checked{%
514
+ endif %} onchange="toggleCloneOptions()">
515
+ <span
516
+ class="radio-label-text peer-checked:text-sky-600 dark:peer-checked:text-primary-400 peer-checked:font-semibold">
517
+ Single / Dialogue (Use [S1]/[S2])
518
+ </span>
519
+ </label>
520
+ <!-- Clone Mode -->
521
+ <label
522
+ class="radio-label peer-checked:border-sky-500 peer-checked:dark:border-primary-500 peer-checked:ring-2 peer-checked:ring-sky-500 peer-checked:dark:ring-primary-500">
523
+ <input type="radio" name="voice_mode" value="clone" class="hidden peer" {% if
524
+ submitted_voice_mode=='clone' %}checked{% endif %}
525
+ onchange="toggleCloneOptions()">
526
+ <span
527
+ class="radio-label-text peer-checked:text-sky-600 dark:peer-checked:text-primary-400 peer-checked:font-semibold">
528
+ Voice Clone (from Reference)
529
+ </span>
530
+ </label>
531
+ </div>
532
+ </div>
533
+
534
+ <!-- Presets Section -->
535
+ <div class="mb-6">
536
+ <label class="label-base mb-2">Load Example Preset</label>
537
+ <div id="presets-container" class="flex flex-wrap gap-2">
538
+ {% if presets %}
539
+ {% for preset in presets %}
540
+ <button type="button" id="preset-btn-{{ loop.index0 }}" class="preset-button"
541
+ title="Load '{{ preset.name }}' text and settings">
542
+ {{ preset.name }}
543
+ </button>
544
+ {% endfor %}
545
+ {% else %}
546
+ <p class="text-sm text-gray-500 dark:text-gray-400">No presets loaded. Check
547
+ presets.yaml.</p>
548
+ {% endif %}
549
+ </div>
550
+ </div>
551
+
552
+
553
+ <!-- Clone Options (Hidden by default) -->
554
+ <div id="clone-options" class="mb-6 hidden">
555
+ <label for="clone_reference_select" class="label-base">Reference Audio File</label>
556
+ <p class="text-xs text-purple-500 dark:text-purple-300 mb-2">
557
+ Select a <code class="code-inline">.wav</code> or <code
558
+ class="code-inline">.mp3</code> file from the <code
559
+ class="code-inline">reference_audio</code> folder.
560
+ <strong class="dark:text-yellow-300 text-yellow-600">Important:</strong> Prepend the
561
+ exact transcript of this audio to your text input above for best results.
562
+ </p>
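+ <!-- Hypothetical clone-mode input, illustrating the note above:
+      "[S1] Exact transcript of the selected reference audio. [S1] New sentence to speak in that voice." -->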
563
+ <div class="flex items-center gap-2">
564
+ <select id="clone_reference_select" name="clone_reference_select"
565
+ class="select-base flex-grow">
566
+ <option value="none" {% if not submitted_clone_file %}selected{% endif %}>--
567
+ Select Reference File --</option>
568
+ {% for filename in reference_files %}
569
+ <option value="{{ filename }}" {% if submitted_clone_file==filename %}selected{%
570
+ endif %}>{{ filename }}</option>
571
+ {% endfor %}
572
+ </select>
573
+ <!-- Hidden file input triggered by the button -->
574
+ <input type="file" id="clone-file-input" class="hidden" multiple accept=".wav,.mp3"
575
+ aria-label="Upload reference audio file">
576
+ <!-- Modified Load Button -->
577
+ <button type="button" id="clone-load-button" class="btn-secondary hidden"
578
+ title="Upload new reference files">
579
+ <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor"
580
+ class="w-5 h-5 mr-1">
581
+ <path
582
+ d="M9.25 13.25a.75.75 0 0 0 1.5 0V4.636l2.955 3.129a.75.75 0 0 0 1.09-1.03l-4.25-4.5a.75.75 0 0 0-1.09 0l-4.25 4.5a.75.75 0 1 0 1.09 1.03L9.25 4.636v8.614Z" />
583
+ <path
584
+ d="M3.5 12.75a.75.75 0 0 0-1.5 0v2.5A2.75 2.75 0 0 0 4.75 18h10.5A2.75 2.75 0 0 0 18 15.25v-2.5a.75.75 0 0 0-1.5 0v2.5c0 .69-.56 1.25-1.25 1.25H4.75c-.69 0-1.25-.56-1.25-1.25v-2.5Z" />
585
+ </svg>
586
+ Load
587
+ </button>
588
+ </div>
589
+ </div>
590
+
591
+
592
+ <!-- Generation Parameters -->
593
+ <div class="mb-6">
594
+ <details class="group">
595
+ <summary class="list-none flex cursor-pointer items-center">
596
+ <span class="text-sm font-medium label-base">Generation Parameters</span>
597
+ <span class="ml-2 text-purple-500 dark:text-purple-300">
598
+ <svg class="group-open:rotate-180 h-5 w-5 transition-transform"
599
+ viewBox="0 0 20 20" fill="currentColor">
600
+ <path fill-rule="evenodd"
601
+ d="M5.293 7.293a1 1 0 011.414 0L10 10.586l3.293-3.293a1 1 0 111.414 1.414l-4 4a1 1 0 01-1.414 0l-4-4a1 1 0 010-1.414z"
602
+ clip-rule="evenodd" />
603
+ </svg>
604
+ </span>
605
+ </summary>
606
+ <div class="mt-4 grid grid-cols-1 md:grid-cols-2 gap-x-6 gap-y-4">
607
+ <!-- Use default_gen_params passed from server for initial values -->
608
+ {% set current_gen_params = submitted_gen_params if submitted_gen_params else
609
+ default_gen_params %}
610
+ <!-- Speed Factor -->
611
+ <div>
612
+ <label for="speed_factor" class="label-base">Speed Factor (<span
613
+ id="speed_factor_value">{{ current_gen_params.speed_factor
614
+ }}</span>)</label>
615
+ <input type="range" id="speed_factor" name="speed_factor" min="0.5"
616
+ max="2.0" step="0.01" value="{{ current_gen_params.speed_factor }}"
617
+ class="slider-base">
618
+ </div>
619
+ <!-- CFG Scale -->
620
+ <div>
621
+ <label for="cfg_scale" class="label-base">CFG Scale (<span
622
+ id="cfg_scale_value">{{ current_gen_params.cfg_scale
623
+ }}</span>)</label>
624
+ <input type="range" id="cfg_scale" name="cfg_scale" min="1.0" max="5.0"
625
+ step="0.1" value="{{ current_gen_params.cfg_scale }}"
626
+ class="slider-base">
627
+ </div>
628
+ <!-- Temperature -->
629
+ <div>
630
+ <label for="temperature" class="label-base">Temperature (<span
631
+ id="temperature_value">{{ current_gen_params.temperature
632
+ }}</span>)</label>
633
+ <input type="range" id="temperature" name="temperature" min="1.0" max="1.5"
634
+ step="0.05" value="{{ current_gen_params.temperature }}"
635
+ class="slider-base">
636
+ </div>
637
+ <!-- Top P -->
638
+ <div>
639
+ <label for="top_p" class="label-base">Top P (<span id="top_p_value">{{
640
+ current_gen_params.top_p }}</span>)</label>
641
+ <input type="range" id="top_p" name="top_p" min="0.8" max="1.0" step="0.01"
642
+ value="{{ current_gen_params.top_p }}" class="slider-base">
643
+ </div>
644
+ <!-- CFG Filter Top K -->
645
+ <div>
646
+ <label for="cfg_filter_top_k" class="label-base">CFG Filter Top K (<span
647
+ id="cfg_filter_top_k_value">{{ current_gen_params.cfg_filter_top_k
648
+ }}</span>)</label>
649
+ <input type="range" id="cfg_filter_top_k" name="cfg_filter_top_k" min="15"
650
+ max="50" step="1" value="{{ current_gen_params.cfg_filter_top_k }}"
651
+ class="slider-base">
652
+ </div>
653
+ <!-- Save Gen Defaults Button -->
654
+ <div class="col-span-1 md:col-span-2 mt-4 flex items-center gap-4">
655
+ <button id="save-gen-defaults-btn" type="button" class="btn-secondary">
656
+ Save Generation Defaults
657
+ </button>
658
+ <span id="gen-defaults-status" class="text-xs hidden"></span>
659
+ </div>
660
+ </div>
661
+ </details>
662
+ </div>
663
+
664
+ <!-- Server Configuration (Collapsible) -->
665
+ <div class="mb-6">
666
+ <details class="group">
667
+ <summary class="list-none flex cursor-pointer items-center">
668
+ <span class="text-sm font-medium label-base">Server Configuration</span>
669
+ <span class="ml-2 text-purple-500 dark:text-purple-300">
670
+ <svg class="group-open:rotate-180 h-5 w-5 transition-transform"
671
+ viewBox="0 0 20 20" fill="currentColor">
672
+ <path fill-rule="evenodd"
673
+ d="M5.293 7.293a1 1 0 011.414 0L10 10.586l3.293-3.293a1 1 0 111.414 1.414l-4 4a1 1 0 01-1.414 0l-4-4a1 1 0 010-1.414z"
674
+ clip-rule="evenodd" />
675
+ </svg>
676
+ </span>
677
+ </summary>
678
+ <div id="server-config-form"
679
+ class="mt-4 border-t border-gray-200 dark:border-dark-700 pt-4">
680
+ <p class="text-xs text-purple-500 dark:text-purple-300 mb-3">
681
+ These settings are saved to the <code class="code-inline">.env</code> file.
682
+ Restart the server to apply changes.
683
+ </p>
684
+ <div class="grid grid-cols-1 md:grid-cols-2 gap-4">
685
+ <!-- Dia Model Repo ID -->
686
+ <div>
687
+ <label for="config_model_repo" class="label-base text-xs">Model Repo
688
+ ID</label>
689
+ <input type="text" id="config_model_repo" name="DIA_MODEL_REPO_ID"
690
+ value="{{ config.DIA_MODEL_REPO_ID }}"
691
+ placeholder="ttj/dia-1.6b-safetensors" class="input-base text-sm">
692
+ </div>
693
+ <!-- Model Config Filename -->
694
+ <div>
695
+ <label for="config_model_config" class="label-base text-xs">Model Config
696
+ Filename</label>
697
+ <input type="text" id="config_model_config"
698
+ name="DIA_MODEL_CONFIG_FILENAME"
699
+ value="{{ config.DIA_MODEL_CONFIG_FILENAME }}"
700
+ placeholder="config.json" class="input-base text-sm">
701
+ </div>
702
+ <!-- Model Weights Filename -->
703
+ <div>
704
+ <label for="config_model_weights" class="label-base text-xs">Model
705
+ Weights Filename</label>
706
+ <input type="text" id="config_model_weights"
707
+ name="DIA_MODEL_WEIGHTS_FILENAME"
708
+ value="{{ config.DIA_MODEL_WEIGHTS_FILENAME }}"
709
+ placeholder="dia-v0_1_bf16.safetensors" class="input-base text-sm">
710
+ </div>
711
+ <!-- Model Cache Path -->
712
+ <div>
713
+ <label for="config_model_cache" class="label-base text-xs">Model Cache
714
+ Path</label>
715
+ <input type="text" id="config_model_cache" name="DIA_MODEL_CACHE_PATH"
716
+ value="{{ config.DIA_MODEL_CACHE_PATH }}"
717
+ placeholder="./model_cache" class="input-base text-sm">
718
+ </div>
719
+ <!-- Reference Audio Path -->
720
+ <div>
721
+ <label for="config_ref_audio" class="label-base text-xs">Reference Audio
722
+ Path</label>
723
+ <input type="text" id="config_ref_audio" name="REFERENCE_AUDIO_PATH"
724
+ value="{{ config.REFERENCE_AUDIO_PATH }}"
725
+ placeholder="./reference_audio" class="input-base text-sm">
726
+ </div>
727
+ <!-- Output Path -->
728
+ <div>
729
+ <label for="config_output_path" class="label-base text-xs">Output
730
+ Path</label>
731
+ <input type="text" id="config_output_path" name="OUTPUT_PATH"
732
+ value="{{ config.OUTPUT_PATH }}" placeholder="./outputs"
733
+ class="input-base text-sm">
734
+ </div>
735
+ <!-- Server Host -->
736
+ <div>
737
+ <label for="config_host" class="label-base text-xs">Server Host</label>
738
+ <input type="text" id="config_host" name="HOST"
739
+ value="{{ config.HOST }}" placeholder="0.0.0.0"
740
+ class="input-base text-sm">
741
+ </div>
742
+ <!-- Server Port -->
743
+ <div>
744
+ <label for="config_port" class="label-base text-xs">Server Port</label>
745
+ <input type="number" id="config_port" name="PORT"
746
+ value="{{ config.PORT }}" min="1024" max="65535" step="1"
747
+ class="input-base text-sm">
748
+ </div>
749
+ <!-- Save/Restart Buttons -->
750
+ <div
751
+ class="col-span-1 md:col-span-2 mt-4 flex flex-col md:flex-row gap-4 items-center">
752
+ <button id="save-config-btn" type="button"
753
+ class="btn-purple w-full md:w-auto">
754
+ Save Server Configuration
755
+ </button>
756
+ <button id="restart-server-btn" type="button"
757
+ class="btn-danger w-full md:w-auto hidden">
758
+ <svg xmlns="http://www.w3.org/2000/svg" fill="none"
759
+ viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor"
760
+ class="w-5 h-5 mr-1 inline-block">
761
+ <path stroke-linecap="round" stroke-linejoin="round"
762
+ d="M16.023 9.348h4.992v-.001M2.985 19.644v-4.992m0 0h4.992m-4.993 0 3.181 3.183a8.25 8.25 0 0 0 13.803-3.7M4.031 9.865a8.25 8.25 0 0 1 13.803-3.7l3.181 3.182m0-4.991v4.99" />
763
+ </svg>
764
+ Restart Server
765
+ </button>
766
+ <span id="config-status" class="text-xs ml-2 hidden"></span>
767
+ </div>
768
+ </div>
769
+ </div>
770
+ </details>
771
+ </div>
772
+
773
+ </div> <!-- End p-6 -->
774
+
775
+ <!-- Form Actions -->
776
+ <div class="card-footer">
777
+ <div class="text-sm text-gray-600 dark:text-purple-300">
778
+ <p>Use <code class="code-inline">[S1]</code>/<code class="code-inline">[S2]</code> for
779
+ dialogue. Add <code class="code-inline">(laughs)</code> etc.</p>
780
+ </div>
781
+ <button type="submit" id="generate-btn" class="btn-primary">
782
+ <svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24"
783
+ stroke-width="1.5" stroke="currentColor" class="w-5 h-5 mr-1 inline-block">
784
+ <path stroke-linecap="round" stroke-linejoin="round"
785
+ d="M19.114 5.636a9 9 0 0 1 0 12.728M16.463 8.288a5.25 5.25 0 0 1 0 7.424M6.75 8.25l4.72-4.72a.75.75 0 0 1 1.28.53v15.88a.75.75 0 0 1-1.28.53l-4.72-4.72H4.51c-.88 0-1.704-.507-1.938-1.354A9.009 9.009 0 0 1 2.25 12c0-.83.112-1.633.322-2.396C2.806 8.756 3.63 8.25 4.51 8.25H6.75Z" />
786
+ </svg>
787
+ Generate Speech
788
+ </button>
789
+ </div>
790
+ </form>
791
+ </div> <!-- End TTS Form Card -->
792
+
793
+ <!-- Audio player container - Populated by JavaScript if generation is successful -->
794
+ <div id="audio-player-container" class="mt-8">
795
+ {% if output_file_url %}
796
+ <!-- Template for initial load if result is passed from server -->
797
+ <!-- Add data attribute to signal JS that result is present -->
798
+ <div id="output-file-url-data" data-initial-audio-url="{{ output_file_url }}" class="hidden"></div>
799
+ <div class="audio-player-card">
800
+ <div class="p-6">
801
+ <h2 class="card-header">Generated Audio</h2>
802
+ <div class="mb-4">
803
+ <div id="waveform" class="waveform-container"></div>
804
+ </div>
805
+ <div class="audio-player-controls">
806
+ <div class="audio-player-buttons">
807
+ <button id="play-btn" class="btn-primary" disabled>
808
+ <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor"
809
+ class="w-5 h-5 mr-1">
810
+ <path fill-rule="evenodd"
811
+ d="M2 10a8 8 0 1 1 16 0 8 8 0 0 1-16 0Zm6.39-2.908a.75.75 0 0 1 .766.027l3.5 2.25a.75.75 0 0 1 0 1.262l-3.5 2.25A.75.75 0 0 1 8 12.25v-4.5a.75.75 0 0 1 .39-.658Z"
812
+ clip-rule="evenodd" />
813
+ </svg>
814
+ Play
815
+ </button>
816
+ <a id="download-link" href="{{ output_file_url }}"
817
+ download="{{ output_file_url.split('/')[-1] }}" class="btn-secondary">
818
+ <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor"
819
+ class="w-5 h-5 mr-1">
820
+ <path
821
+ d="M10.75 2.75a.75.75 0 0 0-1.5 0v8.614L6.295 8.235a.75.75 0 1 0-1.09 1.03l4.25 4.5a.75.75 0 0 0 1.09 0l4.25-4.5a.75.75 0 0 0-1.09-1.03l-2.955 3.129V2.75Z" />
822
+ <path
823
+ d="M3.5 12.75a.75.75 0 0 0-1.5 0v2.5A2.75 2.75 0 0 0 4.75 18h10.5A2.75 2.75 0 0 0 18 15.25v-2.5a.75.75 0 0 0-1.5 0v2.5c0 .69-.56 1.25-1.25 1.25H4.75c-.69 0-1.25-.56-1.25-1.25v-2.5Z" />
824
+ </svg>
825
+ Download WAV
826
+ </a>
827
+ </div>
828
+ <div class="audio-player-info">
829
+ Mode: <span class="font-medium">{{ submitted_voice_mode }}</span>
830
+ {% if submitted_voice_mode == 'clone' and submitted_clone_file %}
831
+ (<span class="font-medium">{{ submitted_clone_file }}</span>)
832
+ {% endif %}
833
+ • Gen Time: <span class="font-medium">{{ generation_time }}s</span>
834
+ • Duration: <span id="audio-duration" class="font-medium">--:--</span>
835
+ </div>
836
+ </div>
837
+ </div>
838
+ </div>
839
+ {% endif %}
840
+ </div>
841
+
842
+ <!-- Tips Section -->
843
+ <div class="mt-8">
844
+ <h2 class="card-header mb-4">Tips & Tricks for Dia</h2>
845
+ <div class="card-base">
846
+ <div class="p-6">
847
+ <ul class="list-disc pl-5 text-sm text-gray-700 dark:text-purple-300 space-y-2">
848
+ <li>For <strong>Dialogue</strong> mode, clearly mark speaker turns using <code
849
+ class="code-inline">[S1]</code> and <code class="code-inline">[S2]</code>.</li>
850
+ <li>Add non-verbal sounds like <code class="code-inline">(laughs)</code>, <code
851
+ class="code-inline">(sighs)</code>, <code
852
+ class="code-inline">(clears throat)</code> within the text where desired.</li>
853
+ <li>For <strong>Voice Clone</strong> mode, upload a clean reference audio file (<code
854
+ class="code-inline">.wav</code>/<code class="code-inline">.mp3</code>) using the
855
+ "Load" button. <strong class="dark:text-yellow-300 text-yellow-600">Crucially,
856
+ include the exact transcript of the reference audio at the beginning of your
857
+ text input</strong> (e.g., <code
858
+ class="code-inline">[S1] Reference transcript. [S1] Target text...</code>).</li>
859
+ <li>Experiment with <strong>CFG Scale</strong> (higher = more adherence to text, potentially less
860
+ natural) and <strong>Temperature</strong> (higher = more random/varied).</li>
861
+ <li>The <strong>Speed Factor</strong> adjusts playback speed (values below 1.0 slow playback; 1.0 keeps the original speed).</li>
862
+ <li>Use the <code class="code-inline">/v1/audio/speech</code> endpoint for OpenAI
863
+ compatibility. Use the <code class="code-inline">voice</code> parameter to specify
864
+ mode ('S1', 'S2', 'dialogue', 'reference_file.wav').</li>
865
+ </ul>
866
+ </div>
867
+ </div>
868
+ </div>
869
+ </div>
870
+ </main>
871
+
872
+ <footer class="nav-base py-6 mt-12">
873
+ <div class="mx-auto max-w-7xl px-4 sm:px-6 lg:px-8">
874
+ <div class="flex justify-center">
875
+ <a href="https://github.com/devnen/Dia-TTS-Server"
876
+ class="flex items-center gap-2 text-gray-600 dark:text-purple-300 text-sm hover:text-sky-600 dark:hover:text-primary-400 transition-colors">
877
+ <!-- GitHub icon -->
878
+ <svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" fill="currentColor"
879
+ viewBox="0 0 16 16" class="flex-shrink-0">
880
+ <path
881
+ d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.012 8.012 0 0 0 16 8c0-4.42-3.58-8-8-8z" />
882
+ </svg>
883
+ <span>Dia TTS Server | Powered by FastAPI</span>
884
+ </a>
885
+ </div>
886
+ </div>
887
+ </footer>
888
+ </div>
889
+
890
+ <!-- Loading spinner template (hidden by default) -->
891
+ <div id="loading-overlay" class="loading-overlay-base hidden">
892
+ <div class="loading-box-base">
893
+ <svg class="loading-spinner" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24">
894
+ <circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle>
895
+ <path class="opacity-75" fill="currentColor"
896
+ d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z">
897
+ </path>
898
+ </svg>
899
+ <p id="loading-message" class="loading-text">Generating audio...</p>
900
+ <p id="loading-status" class="loading-status">Please wait.</p>
901
+ <button id="loading-cancel-btn" type="button" class="btn-secondary mt-4">Cancel</button>
902
+ </div>
903
+ </div>
904
+
905
+ <!-- Pass data from server to JavaScript -->
906
+ <script>
907
+ // Make presets data available to script.js
908
+ // Ensure this is correctly populated by your Jinja2 template context
909
+ window.appPresets = {{ presets | tojson | safe }};
910
+ </script>
911
+
912
+ <!-- Link External JavaScript (Ensure it's loaded AFTER the DOM) -->
913
+ <script src="/ui/script.js" defer></script>
914
+ </body>
915
+
916
+ </html>
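
The tips above reference the OpenAI-compatible /v1/audio/speech endpoint and its voice parameter. A minimal client sketch follows; the endpoint path and voice values come from the tips, while the remaining payload fields are assumptions modeled on the OpenAI audio API schema -- see documentation.md and server.py for the authoritative contract.

# Hypothetical client for the /v1/audio/speech endpoint mentioned in the tips.
# Field names other than `voice` are assumptions based on the OpenAI schema.
import requests

resp = requests.post(
    "http://localhost:8003/v1/audio/speech",  # host/port from the default .env
    json={
        "input": "[S1] Hello there! [S2] Hi, how are you? (laughs)",
        "voice": "dialogue",       # 'S1', 'S2', 'dialogue', or 'reference_file.wav'
        "response_format": "wav",  # assumed; utils.encode_audio supports 'wav'/'opus'
        "speed": 0.9,              # assumed to map to the speed_factor parameter
    },
    timeout=300,
)
resp.raise_for_status()
with open("speech.wav", "wb") as f:
    f.write(resp.content)
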
ui/presets.yaml ADDED
@@ -0,0 +1,57 @@
1
+ # ui/presets.yaml
2
+ # Predefined examples for the Dia TTS UI
3
+
4
+ - name: "Standard Dialogue"
5
+ voice_mode: "dialogue"
6
+ text: |
7
+ [S1] Hey, how's it going?
8
+ [S2] Pretty good! Just grabbing some coffee. You?
9
+ [S1] Same here. Need the fuel! (laughs)
10
+ params:
11
+ cfg_scale: 3.0
12
+ temperature: 1.3
13
+ top_p: 0.95
14
+ cfg_filter_top_k: 35
15
+ # speed_factor uses the saved default
16
+
17
+ - name: "Expressive Narration"
18
+ voice_mode: "dialogue" # Use dialogue mode with single speaker tag
19
+ text: |
20
+ [S1] The old house stood on a windswept hill, its windows like empty eyes staring out at the stormy sea. (sighs) It felt... lonely.
21
+ params:
22
+ cfg_scale: 3.0
23
+ temperature: 1.2 # Slightly lower temp for clarity
24
+ top_p: 0.95
25
+ cfg_filter_top_k: 35
26
+
27
+ - name: "Quick Announcement"
28
+ voice_mode: "dialogue" # Use dialogue mode with single speaker tag
29
+ text: |
30
+ [S1] Attention shoppers! The store will be closing in 15 minutes. Please bring your final purchases to the checkout.
31
+ params:
32
+ cfg_scale: 2.8 # Slightly lower CFG for potentially more natural tone
33
+ temperature: 1.3
34
+ top_p: 0.95
35
+ cfg_filter_top_k: 35
36
+
37
+ - name: "Funny Exchange"
38
+ voice_mode: "dialogue"
39
+ text: |
40
+ [S1] Did you remember to buy the alien repellent?
41
+ [S2] The what now? (laughs) I thought you were joking!
42
+ [S1] Joking? They're landing tonight! (clears throat) Probably.
43
+ params:
44
+ cfg_scale: 3.2 # Slightly higher CFG
45
+ temperature: 1.35 # Slightly higher temp
46
+ top_p: 0.95
47
+ cfg_filter_top_k: 35
48
+
49
+ - name: "Simple Sentence"
50
+ voice_mode: "dialogue" # Use dialogue mode with single speaker tag
51
+ text: |
52
+ [S1] This is a test of the text to speech system.
53
+ params:
54
+ cfg_scale: 3.0
55
+ temperature: 1.3
56
+ top_p: 0.95
57
+ cfg_filter_top_k: 35
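
For context, a minimal sketch of how this file could be parsed into the `presets` list that the template serializes into window.appPresets. The real loader lives in server.py (not shown in this section), so the function below is an illustrative assumption, not the actual implementation.

# Illustrative loader for ui/presets.yaml (the real one lives in server.py).
import yaml  # PyYAML; listed in requirements.txt

def load_presets(path="ui/presets.yaml"):
    """Load UI presets; an empty list makes the UI show 'No presets loaded.'"""
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f) or []
    except (FileNotFoundError, yaml.YAMLError):
        return []
    # Keep only well-formed entries; each preset needs at least a name and text.
    return [p for p in data if isinstance(p, dict) and "name" in p and "text" in p]
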
ui/script.js ADDED
@@ -0,0 +1,593 @@
1
+ // ui/script.js
2
+
3
+ document.addEventListener('DOMContentLoaded', function () {
4
+ // --- Global Flags ---
5
+ let isGenerating = false;
6
+ let isGenerationCancelled = false;
7
+ let wavesurfer = null; // Global wavesurfer instance
8
+
9
+ // --- Element Selectors ---
10
+ const ttsForm = document.getElementById('tts-form');
11
+ const textArea = document.getElementById('text');
12
+ const charCount = document.getElementById('char-count');
13
+ const voiceModeRadios = document.querySelectorAll('input[name="voice_mode"]');
14
+ const cloneOptionsDiv = document.getElementById('clone-options');
15
+ const cloneReferenceSelect = document.getElementById('clone_reference_select');
16
+ const cloneLoadButton = document.getElementById('clone-load-button'); // New ID
17
+ const cloneFileInput = document.getElementById('clone-file-input'); // New ID
18
+ const generateBtn = document.getElementById('generate-btn');
19
+ const loadingOverlay = document.getElementById('loading-overlay');
20
+ const loadingMessage = document.getElementById('loading-message');
21
+ const loadingStatus = document.getElementById('loading-status'); // New element for status
22
+ const loadingCancelBtn = document.getElementById('loading-cancel-btn'); // New ID
23
+ const notificationArea = document.getElementById('notification-area');
24
+ const audioPlayerContainer = document.getElementById('audio-player-container');
25
+ const configSaveBtn = document.getElementById('save-config-btn');
26
+ const configRestartBtn = document.getElementById('restart-server-btn');
27
+ const configStatus = document.getElementById('config-status');
28
+ const genDefaultsSaveBtn = document.getElementById('save-gen-defaults-btn'); // New ID
29
+ const genDefaultsStatus = document.getElementById('gen-defaults-status'); // New ID
30
+ const themeToggleButton = document.getElementById('theme-toggle-btn'); // New ID
31
+ const themeIconLight = document.getElementById('theme-icon-light'); // New ID
32
+ const themeIconDark = document.getElementById('theme-icon-dark'); // New ID
33
+ const presetsContainer = document.getElementById('presets-container'); // New ID
34
+
35
+ // --- Initial Setup ---
36
+
37
+ // Character counter
38
+ function updateCharCount() {
39
+ if (textArea && charCount) {
40
+ charCount.textContent = textArea.value.length;
41
+ }
42
+ }
43
+ if (textArea) {
44
+ textArea.addEventListener('input', updateCharCount);
45
+ updateCharCount(); // Initial count
46
+ }
47
+
48
+ // Toggle Clone Options Visibility & Required Attribute
49
+ function toggleCloneOptions() {
50
+ const selectedMode = document.querySelector('input[name="voice_mode"]:checked')?.value;
51
+ if (cloneOptionsDiv && cloneReferenceSelect && cloneLoadButton) {
52
+ if (selectedMode === 'clone') {
53
+ cloneOptionsDiv.classList.remove('hidden');
54
+ cloneReferenceSelect.required = true;
55
+ cloneLoadButton.classList.remove('hidden');
56
+ } else {
57
+ cloneOptionsDiv.classList.add('hidden');
58
+ cloneReferenceSelect.required = false;
59
+ // cloneReferenceSelect.value = 'none'; // Don't reset if user might switch back
60
+ cloneLoadButton.classList.add('hidden');
61
+ }
62
+ }
63
+ }
64
+ voiceModeRadios.forEach(radio => radio.addEventListener('change', toggleCloneOptions));
65
+ toggleCloneOptions(); // Initial check
66
+
67
+ // Update slider value displays dynamically
68
+ const sliders = [
69
+ { id: 'speed_factor', valueId: 'speed_factor_value' },
70
+ { id: 'cfg_scale', valueId: 'cfg_scale_value' },
71
+ { id: 'temperature', valueId: 'temperature_value' },
72
+ { id: 'top_p', valueId: 'top_p_value' },
73
+ { id: 'cfg_filter_top_k', valueId: 'cfg_filter_top_k_value' },
74
+ ];
75
+ sliders.forEach(sliderInfo => {
76
+ const slider = document.getElementById(sliderInfo.id);
77
+ const valueDisplay = document.getElementById(sliderInfo.valueId);
78
+ if (slider && valueDisplay) {
79
+ // Set initial display from slider's current value (set by template)
80
+ valueDisplay.textContent = slider.value;
81
+ // Add event listener to update display on change
82
+ slider.addEventListener('input', () => valueDisplay.textContent = slider.value);
83
+ }
84
+ });
85
+
86
+ // --- Notifications ---
87
+ function showNotification(message, type = 'success', duration = 5000) {
88
+ if (!notificationArea) return;
89
+ // notificationArea.innerHTML = ''; // Clear previous? Or allow multiple? Let's allow multiple for now.
90
+ const colors = {
91
+ success: 'notification-success',
92
+ error: 'notification-error',
93
+ warning: 'notification-warning',
94
+ info: 'notification-info' // Add info style if needed
95
+ };
96
+ const icons = { // SVG icons or classes
97
+ success: '<svg class="h-5 w-5 text-green-500 mr-2 flex-shrink-0" viewBox="0 0 20 20" fill="currentColor"><path fill-rule="evenodd" d="M10 18a8 8 0 100-16 8 8 0 000 16zm3.707-9.293a1 1 0 00-1.414-1.414L9 10.586 7.707 9.293a1 1 0 00-1.414 1.414l2 2a1 1 0 001.414 0l4-4z" clip-rule="evenodd" /></svg>',
98
+ error: '<svg class="h-5 w-5 text-red-500 mr-2 flex-shrink-0" viewBox="0 0 20 20" fill="currentColor"><path fill-rule="evenodd" d="M10 18a8 8 0 100-16 8 8 0 000 16zM8.707 7.293a1 1 0 00-1.414 1.414L8.586 10l-1.293 1.293a1 1 0 101.414 1.414L10 11.414l1.293 1.293a1 1 0 001.414-1.414L11.414 10l1.293-1.293a1 1 0 00-1.414-1.414L10 8.586 8.707 7.293z" clip-rule="evenodd" /></svg>',
99
+ warning: '<svg class="h-5 w-5 text-yellow-500 mr-2 flex-shrink-0" viewBox="0 0 20 20" fill="currentColor"><path fill-rule="evenodd" d="M8.485 2.495c.673-1.167 2.357-1.167 3.03 0l6.28 10.875c.673 1.167-.17 2.625-1.516 2.625H3.72c-1.347 0-2.189-1.458-1.515-2.625L8.485 2.495zM10 5a.75.75 0 01.75.75v3.5a.75.75 0 01-1.5 0v-3.5A.75.75 0 0110 5zm0 9a1 1 0 100-2 1 1 0 000 2z" clip-rule="evenodd" /></svg>',
100
+ info: '<svg class="h-5 w-5 text-sky-500 mr-2 flex-shrink-0" viewBox="0 0 20 20" fill="currentColor"><path fill-rule="evenodd" d="M18 10a8 8 0 11-16 0 8 8 0 0116 0zm-7-4a1 1 0 11-2 0 1 1 0 012 0zM9 9a.75.75 0 000 1.5h.253a.25.25 0 01.244.304l-.459 2.066A1.75 1.75 0 0010.747 15H11a.75.75 0 000-1.5h-.253a.25.25 0 01-.244-.304l.459-2.066A1.75 1.75 0 009.253 9H9z" clip-rule="evenodd" /></svg>'
101
+ };
102
+
103
+ const notificationDiv = document.createElement('div');
104
+ notificationDiv.className = colors[type] || colors['info']; // Default to info style
105
+ notificationDiv.innerHTML = `${icons[type] || icons['info']} <span class="block sm:inline">${message}</span>`;
106
+ notificationArea.appendChild(notificationDiv);
107
+
108
+ // Auto-hide after specified duration
109
+ if (duration > 0) {
110
+ setTimeout(() => {
111
+ notificationDiv.style.transition = 'opacity 0.5s ease-out';
112
+ notificationDiv.style.opacity = '0';
113
+ setTimeout(() => notificationDiv.remove(), 500);
114
+ }, duration);
115
+ }
116
+ return notificationDiv; // Return the element if manual removal is needed
117
+ }
118
+
119
+ // --- Presets ---
120
+ function applyPreset(presetData) {
121
+ console.log("Applying preset:", presetData);
122
+ if (!presetData) return;
123
+
124
+ // Update text area
125
+ if (textArea && presetData.text !== undefined) {
126
+ textArea.value = presetData.text;
127
+ updateCharCount(); // Update counter
128
+ }
129
+
130
+ // Update voice mode
131
+ if (presetData.voice_mode) {
132
+ const radio = document.querySelector(`input[name="voice_mode"][value="${presetData.voice_mode}"]`);
133
+ if (radio) {
134
+ radio.checked = true;
135
+ toggleCloneOptions(); // Update UI based on new mode
136
+ }
137
+ }
138
+
139
+ // Update generation parameters
140
+ if (presetData.params) {
141
+ for (const [key, value] of Object.entries(presetData.params)) {
142
+ const slider = document.getElementById(key); // Assumes slider ID matches param key
143
+ const valueDisplay = document.getElementById(`${key}_value`);
144
+ if (slider) {
145
+ slider.value = value;
146
+ if (valueDisplay) {
147
+ valueDisplay.textContent = value; // Update display
148
+ }
149
+ } else {
150
+ console.warn(`Slider element not found for preset parameter: ${key}`);
151
+ }
152
+ }
153
+ }
154
+ showNotification(`Preset "${presetData.name}" loaded.`, 'info', 3000);
155
+ }
156
+
157
+ // Add event listeners to preset buttons (assuming they exist)
158
+ // Presets data should be available globally, e.g., from template `window.appPresets = {{ presets | tojson }};`
159
+ if (window.appPresets && presetsContainer) {
160
+ window.appPresets.forEach((preset, index) => {
161
+ const button = document.getElementById(`preset-btn-${index}`);
162
+ if (button) {
163
+ button.addEventListener('click', () => applyPreset(preset));
164
+ }
165
+ });
166
+ } else if (presetsContainer) {
167
+ console.warn("Presets data (window.appPresets) not found, preset buttons will not work.");
168
+ }
169
+
170
+
171
+ // --- Audio Player ---
172
+ function initializeWaveSurfer(audioUrl) {
173
+ if (wavesurfer) {
174
+ wavesurfer.destroy();
175
+ }
176
+ const waveformDiv = document.getElementById('waveform');
177
+ const playBtn = document.getElementById('play-btn');
178
+ const durationSpan = document.getElementById('audio-duration');
179
+
180
+ if (!waveformDiv || !playBtn || !durationSpan) {
181
+ console.error("Audio player elements not found in the container.");
182
+ // Clear the container if elements are missing after generation
183
+ if (audioPlayerContainer) audioPlayerContainer.innerHTML = '<p class="text-red-500 dark:text-red-400">Error displaying audio player.</p>';
184
+ return;
185
+ }
186
+
187
+ // Ensure button text doesn't wrap
188
+ playBtn.classList.add('whitespace-nowrap', 'flex-shrink-0');
189
+ const downloadLink = document.getElementById('download-link');
190
+ if (downloadLink) downloadLink.classList.add('whitespace-nowrap', 'flex-shrink-0');
191
+
192
+
193
+ wavesurfer = WaveSurfer.create({
194
+ container: waveformDiv,
195
+ waveColor: document.documentElement.classList.contains('dark') ? '#38bdf8' : '#0ea5e9', // primary-400(dark) / primary-500(light)
196
+ progressColor: document.documentElement.classList.contains('dark') ? '#0284c7' : '#0369a1', // primary-600(dark) / primary-700(light)
197
+ cursorColor: document.documentElement.classList.contains('dark') ? '#a855f7' : '#9333ea', // purple-500(dark) / purple-600(light)
198
+ barWidth: 3,
199
+ barRadius: 3,
200
+ cursorWidth: 1,
201
+ height: 80,
202
+ barGap: 2,
203
+ responsive: true,
204
+ url: audioUrl,
205
+ mediaControls: false, // Use custom controls
206
+ normalize: true,
207
+ });
208
+
209
+ wavesurfer.on('ready', () => {
210
+ const duration = wavesurfer.getDuration();
211
+ const minutes = Math.floor(duration / 60);
212
+ const seconds = Math.floor(duration % 60);
213
+ durationSpan.textContent = `${minutes}:${seconds < 10 ? '0' : ''}${seconds}`;
214
+ playBtn.disabled = false;
215
+ playBtn.innerHTML = `<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor" class="w-5 h-5 mr-1"><path fill-rule="evenodd" d="M2 10a8 8 0 1 1 16 0 8 8 0 0 1-16 0Zm6.39-2.908a.75.75 0 0 1 .766.027l3.5 2.25a.75.75 0 0 1 0 1.262l-3.5 2.25A.75.75 0 0 1 8 12.25v-4.5a.75.75 0 0 1 .39-.658Z" clip-rule="evenodd" /></svg> Play`;
216
+ });
217
+
218
+ wavesurfer.on('play', () => {
219
+ playBtn.innerHTML = `<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor" class="w-5 h-5 mr-1"><path fill-rule="evenodd" d="M2 10a8 8 0 1 1 16 0 8 8 0 0 1-16 0Zm5-2.25A.75.75 0 0 1 7.75 7h4.5a.75.75 0 0 1 .75.75v4.5a.75.75 0 0 1-.75.75h-4.5a.75.75 0 0 1-.75-.75v-4.5Z" clip-rule="evenodd" /></svg> Pause`;
220
+ });
221
+
222
+ wavesurfer.on('pause', () => {
223
+ playBtn.innerHTML = `<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor" class="w-5 h-5 mr-1"><path fill-rule="evenodd" d="M2 10a8 8 0 1 1 16 0 8 8 0 0 1-16 0Zm6.39-2.908a.75.75 0 0 1 .766.027l3.5 2.25a.75.75 0 0 1 0 1.262l-3.5 2.25A.75.75 0 0 1 8 12.25v-4.5a.75.75 0 0 1 .39-.658Z" clip-rule="evenodd" /></svg> Play`;
224
+ });
225
+ wavesurfer.on('finish', () => {
226
+ playBtn.innerHTML = `<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor" class="w-5 h-5 mr-1"><path fill-rule="evenodd" d="M2 10a8 8 0 1 1 16 0 8 8 0 0 1-16 0Zm6.39-2.908a.75.75 0 0 1 .766.027l3.5 2.25a.75.75 0 0 1 0 1.262l-3.5 2.25A.75.75 0 0 1 8 12.25v-4.5a.75.75 0 0 1 .39-.658Z" clip-rule="evenodd" /></svg> Play`;
227
+ });
228
+
229
+ playBtn.onclick = () => {
230
+ wavesurfer.playPause();
231
+ };
232
+
233
+ // Scroll to the player after initialization
234
+ setTimeout(() => {
235
+ audioPlayerContainer.scrollIntoView({ behavior: 'smooth', block: 'center' });
236
+ }, 100); // Short delay to ensure rendering
237
+ }
238
+
239
+ // Initialization of a player for a result rendered on page load happens in
+ // the "Result Handling" section below, which reads the same
+ // data-initial-audio-url attribute and also honors the cancel flag.
247
+
248
+
249
+ // --- Form Submission & Cancellation ---
250
+ if (ttsForm) {
251
+ ttsForm.addEventListener('submit', function (event) {
252
+ // Client-side validation
253
+ const text = textArea.value.trim();
254
+ const mode = document.querySelector('input[name="voice_mode"]:checked')?.value;
255
+ const cloneFile = cloneReferenceSelect?.value;
256
+
257
+ if (!text) {
258
+ showNotification("Please enter some text.", 'error');
259
+ event.preventDefault(); return;
260
+ }
261
+ if (mode === 'clone' && (!cloneFile || cloneFile === 'none')) {
262
+ showNotification("Please select a reference file for clone mode.", 'error');
263
+ event.preventDefault(); return;
264
+ }
265
+
266
+ // Handle cancellation of previous request if Generate is clicked again
267
+ if (isGenerating) {
268
+ console.log("Generate clicked while previous generation in progress. Setting cancel flag.");
269
+ showNotification("Cancelling previous request...", 'warning', 2000);
270
+ isGenerationCancelled = true;
271
+ // We don't actually stop the backend here (Fake Cancel)
272
+ // but the result processing will ignore the previous result.
273
+ }
274
+
275
+ // Reset flags and show loading overlay for the new request
276
+ isGenerating = true;
277
+ isGenerationCancelled = false; // Reset cancel flag for the new request
278
+ if (loadingOverlay && generateBtn && loadingCancelBtn) {
279
+ loadingMessage.textContent = 'Generating audio...'; // Initial status
280
+ loadingStatus.textContent = 'Please wait.';
281
+ loadingOverlay.classList.remove('hidden');
282
+ generateBtn.disabled = true;
283
+ generateBtn.classList.add('opacity-50', 'cursor-not-allowed');
284
+ loadingCancelBtn.disabled = false; // Enable cancel button
285
+ }
286
+ // Allow default form submission to proceed
287
+ // The page will reload with results rendered by the template
288
+ });
289
+ }
290
+
291
+ // Handle Cancel button click
292
+ if (loadingCancelBtn) {
293
+ loadingCancelBtn.addEventListener('click', () => {
294
+ if (isGenerating) {
295
+ console.log("Cancel button clicked.");
296
+ isGenerationCancelled = true;
297
+ isGenerating = false; // Stop considering it "generating" from UI perspective
298
+ if (loadingOverlay && generateBtn) {
299
+ loadingOverlay.classList.add('hidden'); // Hide overlay
300
+ generateBtn.disabled = false; // Re-enable generate button
301
+ generateBtn.classList.remove('opacity-50', 'cursor-not-allowed');
302
+ }
303
+ showNotification("Generation cancelled by user.", 'info');
304
+ // Note: Backend request continues, but result will be ignored on page reload/update
305
+ }
306
+ });
307
+ }
308
+
309
+ // --- Result Handling (on page load after form submission) ---
310
+ // This logic runs every time the page loads. We check if specific elements
311
+ // indicating a successful generation are present.
312
+ const outputUrlElement = document.getElementById('output-file-url-data'); // Rendered by the template when a result is present
313
+ if (outputUrlElement && outputUrlElement.dataset.initialAudioUrl) { // key matches data-initial-audio-url in index.html
314
+ const outputUrl = outputUrlElement.dataset.initialAudioUrl;
315
+ console.log("Page loaded with generation result:", outputUrl);
316
+
317
+ if (isGenerationCancelled) {
318
+ console.log("Generation was cancelled, ignoring result.");
319
+ showNotification("Previous generation was cancelled.", "warning");
320
+ // Reset flag after checking
321
+ isGenerationCancelled = false;
322
+ } else {
323
+ console.log("Processing successful generation result.");
324
+ // The audio player structure should be rendered by the template.
325
+ // We just need to initialize wavesurfer for it.
326
+ initializeWaveSurfer(outputUrl);
327
+ }
328
+ }
329
+ // Always reset generating flag on page load, as any active generation is now finished or irrelevant
330
+ isGenerating = false;
331
+ if (generateBtn) { // Re-enable button if page reloads for any reason
332
+ generateBtn.disabled = false;
333
+ generateBtn.classList.remove('opacity-50', 'cursor-not-allowed');
334
+ }
335
+
336
+
337
+ // --- Configuration Management ---
338
+ async function updateConfigStatus(button, statusElement, message, success = true, duration = 5000) {
339
+ const successClass = 'text-green-500 dark:text-green-400';
340
+ const errorClass = 'text-red-500 dark:text-red-400';
341
+ const savingClass = 'text-yellow-500 dark:text-yellow-400';
342
+
343
+ statusElement.textContent = message;
344
+ statusElement.className = `text-xs ml-2 ${success ? successClass : (message.startsWith('Saving') || message.startsWith('Restarting') ? savingClass : errorClass)}`;
345
+ statusElement.classList.remove('hidden');
346
+ if (button) button.disabled = true; // Disable button while processing
347
+
348
+ // Clear status after duration, re-enable button
349
+ if (duration > 0) {
350
+ setTimeout(() => {
351
+ statusElement.classList.add('hidden');
352
+ if (button) button.disabled = false;
353
+ }, duration);
354
+ }
355
+ }
356
+
357
+ // Save Server Configuration
358
+ if (configSaveBtn) {
359
+ configSaveBtn.addEventListener('click', async () => {
360
+ const configData = {};
361
+ document.querySelectorAll('#server-config-form input[name]').forEach(input => { // Assume inputs are within a form/div
362
+ configData[input.name] = input.value;
363
+ });
364
+
365
+ updateConfigStatus(configSaveBtn, configStatus, 'Saving...', true, 0); // Indefinite until success/error
366
+
367
+ try {
368
+ const response = await fetch('/save_config', {
369
+ method: 'POST',
370
+ headers: { 'Content-Type': 'application/json' },
371
+ body: JSON.stringify(configData)
372
+ });
373
+ const result = await response.json();
374
+ if (!response.ok) throw new Error(result.detail || 'Failed to save');
375
+
376
+ updateConfigStatus(configSaveBtn, configStatus, result.message, true);
377
+ if (configRestartBtn) configRestartBtn.classList.remove('hidden'); // Show restart button
378
+
379
+ } catch (error) {
380
+ console.error('Error saving server config:', error);
381
+ updateConfigStatus(configSaveBtn, configStatus, `Error: ${error.message}`, false);
382
+ }
383
+ });
384
+ }
385
+
386
+ // Restart Server
387
+ if (configRestartBtn) {
388
+ configRestartBtn.addEventListener('click', async () => {
389
+ configRestartBtn.disabled = true;
390
+ configRestartBtn.innerHTML = `
391
+ <svg class="animate-spin h-5 w-5 mr-1 inline-block" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24">
392
+ <circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle>
393
+ <path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path>
394
+ </svg>
395
+ Restarting...`;
396
+ updateConfigStatus(configRestartBtn, configStatus, 'Restarting...', true, 0); // Indefinite
397
+
398
+ try {
399
+ const response = await fetch('/restart_server', { method: 'POST' });
400
+ const result = await response.json();
401
+ if (!response.ok) throw new Error(result.detail || 'Failed to trigger restart');
402
+
403
+ updateConfigStatus(configRestartBtn, configStatus, result.message + " Page will attempt reload.", true, 15000); // Show longer
404
+ // Show main loading overlay during restart check
405
+ if (loadingOverlay) {
406
+ loadingMessage.textContent = 'Server restarting...';
407
+ loadingStatus.textContent = 'Waiting for server to respond...';
408
+ loadingCancelBtn.disabled = true; // Disable cancel during restart
409
+ loadingOverlay.classList.remove('hidden');
410
+ }
411
+
412
+ // Poll for server readiness
413
+ let attempts = 0;
414
+ const maxAttempts = 45; // Wait up to 45 seconds
415
+ function checkServerReady() {
416
+ attempts++;
417
+ console.log(`Checking server readiness (Attempt ${attempts}/${maxAttempts})...`);
418
+ loadingStatus.textContent = `Waiting for server... (${attempts}/${maxAttempts})`;
419
+ fetch('/health?cache=' + Date.now(), { cache: 'no-store', headers: { 'pragma': 'no-cache' } })
420
+ .then(res => {
421
+ if (res.ok) {
422
+ console.log("Server is ready. Reloading page.");
423
+ window.location.reload(true); // Force reload from server
424
+ } else if (attempts < maxAttempts) {
425
+ setTimeout(checkServerReady, 1000); // Check again in 1 second
426
+ } else {
427
+ // Throwing here would surface as an unhandled promise rejection
+ // (we are outside the try/catch above), so report the failure directly.
+ showNotification('Server did not become ready after restart. Please reload the page manually.', 'error');
+ if (loadingOverlay) loadingOverlay.classList.add('hidden');
428
+ }
429
+ })
430
+ .catch(() => {
431
+ if (attempts < maxAttempts) {
432
+ setTimeout(checkServerReady, 1000); // Check again on connection error
433
+ } else {
434
+ showNotification('Server did not respond after restart. Please reload the page manually.', 'error');
+ if (loadingOverlay) loadingOverlay.classList.add('hidden');
435
+ }
436
+ });
437
+ }
438
+ setTimeout(checkServerReady, 3000); // Start checking after 3 seconds
439
+
440
+ } catch (error) {
441
+ console.error('Error restarting server:', error);
442
+ updateConfigStatus(configRestartBtn, configStatus, `Restart Error: ${error.message}`, false);
443
+ configRestartBtn.disabled = false; // Re-enable button on error
444
+ configRestartBtn.innerHTML = `
445
+ <svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" class="w-5 h-5 mr-1 inline-block"><path stroke-linecap="round" stroke-linejoin="round" d="M16.023 9.348h4.992v-.001M2.985 19.644v-4.992m0 0h4.992m-4.993 0 3.181 3.183a8.25 8.25 0 0 0 13.803-3.7M4.031 9.865a8.25 8.25 0 0 1 13.803-3.7l3.181 3.182m0-4.991v4.99" /></svg>
446
+ Restart Server`;
447
+ if (loadingOverlay) loadingOverlay.classList.add('hidden');
448
+ }
449
+ });
450
+ }
451
+
452
+ // Save Generation Defaults
453
+ if (genDefaultsSaveBtn) {
454
+ genDefaultsSaveBtn.addEventListener('click', async () => {
455
+ const genParams = {};
456
+ sliders.forEach(s => {
457
+ const slider = document.getElementById(s.id);
458
+ if (slider) genParams[s.id] = slider.value;
459
+ });
460
+
461
+ updateConfigStatus(genDefaultsSaveBtn, genDefaultsStatus, 'Saving...', true, 0);
462
+
463
+ try {
464
+ const response = await fetch('/save_generation_defaults', {
465
+ method: 'POST',
466
+ headers: { 'Content-Type': 'application/json' },
467
+ body: JSON.stringify(genParams)
468
+ });
469
+ const result = await response.json();
470
+ if (!response.ok) throw new Error(result.detail || 'Failed to save');
471
+ updateConfigStatus(genDefaultsSaveBtn, genDefaultsStatus, result.message, true);
472
+
473
+ } catch (error) {
474
+ console.error('Error saving generation defaults:', error);
475
+ updateConfigStatus(genDefaultsSaveBtn, genDefaultsStatus, `Error: ${error.message}`, false);
476
+ }
477
+ });
478
+ }
479
+
480
+ // --- Reference Audio Upload ---
481
+ if (cloneLoadButton && cloneFileInput && cloneReferenceSelect) {
482
+ cloneLoadButton.addEventListener('click', () => {
483
+ cloneFileInput.click(); // Trigger hidden file input
484
+ });
485
+
486
+ cloneFileInput.addEventListener('change', async (event) => {
487
+ const files = event.target.files;
488
+ if (!files || files.length === 0) {
489
+ return; // No files selected
490
+ }
491
+
492
+ cloneLoadButton.disabled = true;
493
+ cloneLoadButton.textContent = 'Uploading...';
494
+ showNotification(`Uploading ${files.length} file(s)...`, 'info', 0); // Indefinite
495
+
496
+ const formData = new FormData();
497
+ for (const file of files) {
498
+ formData.append('files', file);
499
+ }
500
+
501
+ try {
502
+ const response = await fetch('/upload_reference', {
503
+ method: 'POST',
504
+ body: formData
505
+ // Content-Type is set automatically for FormData
506
+ });
507
+
508
+ const result = await response.json();
509
+
510
+ // Clear existing notifications before showing results
511
+ notificationArea.innerHTML = '';
512
+
513
+ if (!response.ok) {
514
+ throw new Error(result.message || `Upload failed with status ${response.status}`);
515
+ }
516
+
517
+ // Process results
518
+ if (result.errors && result.errors.length > 0) {
519
+ result.errors.forEach(err => showNotification(err, 'error'));
520
+ }
521
+ if (result.uploaded_files && result.uploaded_files.length > 0) {
522
+ showNotification(`Successfully uploaded: ${result.uploaded_files.join(', ')}`, 'success');
523
+ } else if (!result.errors || result.errors.length === 0) {
524
+ showNotification("Files processed, but no new files were added (might already exist).", 'info');
525
+ }
526
+
527
+
528
+ // Update dropdown
529
+ const currentSelection = cloneReferenceSelect.value;
530
+ cloneReferenceSelect.innerHTML = '<option value="none">-- Select Reference File --</option>'; // Clear existing options
531
+ result.all_reference_files.forEach(filename => {
532
+ const option = document.createElement('option');
533
+ option.value = filename;
534
+ option.textContent = filename;
535
+ cloneReferenceSelect.appendChild(option);
536
+ });
537
+
538
+ // Select the first newly uploaded file, or keep current selection if still valid
539
+ const firstUploaded = result.uploaded_files ? result.uploaded_files[0] : null;
540
+ if (firstUploaded) {
541
+ cloneReferenceSelect.value = firstUploaded;
542
+ } else if (result.all_reference_files.includes(currentSelection)) {
543
+ cloneReferenceSelect.value = currentSelection; // Restore previous valid selection
544
+ } else {
545
+ cloneReferenceSelect.value = 'none'; // Default if nothing else matches
546
+ }
547
+
548
+ } catch (error) {
549
+ console.error('Error uploading reference files:', error);
550
+ showNotification(`Upload Error: ${error.message}`, 'error');
551
+ } finally {
552
+ cloneLoadButton.disabled = false;
553
+ cloneLoadButton.textContent = 'Load';
554
+ cloneFileInput.value = ''; // Reset file input
555
+ }
556
+ });
557
+ }
558
+
559
+ // --- Theme Toggle ---
560
+ function applyTheme(theme) {
561
+ if (theme === 'light') {
562
+ document.documentElement.classList.remove('dark');
563
+ if (themeIconLight) themeIconLight.classList.remove('hidden');
564
+ if (themeIconDark) themeIconDark.classList.add('hidden');
565
+ } else {
566
+ document.documentElement.classList.add('dark');
567
+ if (themeIconLight) themeIconLight.classList.add('hidden');
568
+ if (themeIconDark) themeIconDark.classList.remove('hidden');
569
+ }
570
+ // Update wavesurfer colors if it exists
571
+ if (wavesurfer) {
572
+ wavesurfer.setOptions({
573
+ waveColor: theme === 'light' ? '#0ea5e9' : '#38bdf8',
574
+ progressColor: theme === 'light' ? '#0369a1' : '#0284c7',
575
+ cursorColor: theme === 'light' ? '#9333ea' : '#a855f7',
576
+ });
577
+ }
578
+ }
579
+
580
+ if (themeToggleButton) {
581
+ // Check localStorage on load
582
+ const savedTheme = localStorage.getItem('theme') || 'dark'; // Default to dark
583
+ applyTheme(savedTheme);
584
+
585
+ themeToggleButton.addEventListener('click', () => {
586
+ const isDark = document.documentElement.classList.contains('dark');
587
+ const newTheme = isDark ? 'light' : 'dark';
588
+ applyTheme(newTheme);
589
+ localStorage.setItem('theme', newTheme); // Save preference
590
+ });
591
+ }
592
+
593
+ }); // End DOMContentLoaded
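
The fetch() calls above assume a set of server routes (/health, /save_config, /save_generation_defaults, /restart_server, /upload_reference) implemented in server.py, outside this section. A FastAPI stub sketching the response shapes the script relies on -- routes inferred from the client code, not copied from the server:

# Stub of the contract script.js expects; shapes inferred from the fetch() calls.
from fastapi import FastAPI

app = FastAPI()

@app.get("/health")
async def health():
    # Polled by checkServerReady() after a restart; any 200 response counts as ready.
    return {"status": "ok"}

@app.post("/save_generation_defaults")
async def save_generation_defaults(params: dict):
    # Receives {slider_id: value, ...}; the UI displays the returned `message`.
    return {"message": "Generation defaults saved."}

@app.post("/restart_server")
async def restart_server():
    # The UI shows `message`, then polls /health until the server responds again.
    return {"message": "Server restart initiated."}
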
utils.py ADDED
@@ -0,0 +1,146 @@
1
+ # utils.py
2
+ # Utility functions for the Dia TTS server
3
+
4
+ import logging
5
+ import time
6
+ import os
7
+ import io
8
+ import numpy as np
9
+ import soundfile as sf
10
+ from typing import Optional, Tuple
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+ # --- Audio Processing ---
15
+
16
+
17
+ def encode_audio(
18
+ audio_array: np.ndarray, sample_rate: int, output_format: str = "opus"
19
+ ) -> Optional[bytes]:
20
+ """
21
+ Encodes a NumPy audio array into the specified format in memory.
22
+
23
+ Args:
24
+ audio_array: NumPy array containing audio data (float32, range [-1, 1]).
25
+ sample_rate: Sample rate of the audio data.
26
+ output_format: Desired output format ('opus' or 'wav').
27
+
28
+ Returns:
29
+ Bytes object containing the encoded audio, or None on failure.
30
+ """
31
+ if audio_array is None or audio_array.size == 0:
32
+ logger.warning("encode_audio received empty or None audio array.")
33
+ return None
34
+
35
+ start_time = time.time()
36
+ output_buffer = io.BytesIO()
37
+
38
+ try:
39
+ if output_format == "opus":
40
+ # Soundfile expects int16 for Opus usually, but let's try float32 first
41
+ # It might convert internally or require specific subtypes.
42
+ # If this fails, we might need to convert to int16 first:
43
+ # audio_int16 = (audio_array * 32767).astype(np.int16)
44
+ # sf.write(output_buffer, audio_int16, sample_rate, format='ogg', subtype='opus')
45
+ sf.write(
46
+ output_buffer, audio_array, sample_rate, format="ogg", subtype="opus"
47
+ )
48
+ content_type = "audio/ogg; codecs=opus"
49
+ elif output_format == "wav":
50
+ # WAV typically uses int16
51
+ audio_int16 = (np.clip(audio_array, -1.0, 1.0) * 32767).astype(np.int16)  # clip to avoid int16 wraparound
52
+ sf.write(
53
+ output_buffer, audio_int16, sample_rate, format="wav", subtype="pcm_16"
54
+ )
55
+ content_type = "audio/wav"
56
+ else:
57
+ logger.error(f"Unsupported output format requested: {output_format}")
58
+ return None
59
+
60
+ encoded_bytes = output_buffer.getvalue()
61
+ end_time = time.time()
62
+ logger.info(
63
+ f"Encoded {len(encoded_bytes)} bytes to {output_format} in {end_time - start_time:.3f} seconds."
64
+ )
65
+ return encoded_bytes
66
+
67
+ except ImportError:
68
+ logger.critical(
69
+ "`soundfile` or its dependency `libsndfile` not found/installed correctly. Cannot encode audio."
70
+ )
71
+ raise # Re-raise critical error
72
+ except Exception as e:
73
+ logger.error(f"Error encoding audio to {output_format}: {e}", exc_info=True)
74
+ return None
75
+
76
+
77
+ def save_audio_to_file(
78
+ audio_array: np.ndarray, sample_rate: int, file_path: str
79
+ ) -> bool:
80
+ """
81
+ Saves a NumPy audio array to a WAV file.
82
+
83
+ Args:
84
+ audio_array: NumPy array containing audio data (float32, range [-1, 1]).
85
+ sample_rate: Sample rate of the audio data.
86
+ file_path: Path to save the WAV file.
87
+
88
+ Returns:
89
+ True if saving was successful, False otherwise.
90
+ """
91
+ if audio_array is None or audio_array.size == 0:
92
+ logger.warning("save_audio_to_file received empty or None audio array.")
93
+ return False
94
+ if not file_path.lower().endswith(".wav"):
95
+ logger.warning(
96
+ f"File path '{file_path}' does not end with .wav. Saving as WAV anyway."
97
+ )
98
+ # Optionally change the extension: file_path += ".wav"
99
+
100
+ start_time = time.time()
101
+ try:
102
+ # Ensure output directory exists
103
+ output_dir = os.path.dirname(file_path)
+ if output_dir:  # dirname is "" for bare filenames; os.makedirs("") would raise
+     os.makedirs(output_dir, exist_ok=True)
104
+
105
+ # WAV typically uses int16
106
+ audio_int16 = (np.clip(audio_array, -1.0, 1.0) * 32767).astype(np.int16)  # clip to avoid int16 wraparound
107
+ sf.write(file_path, audio_int16, sample_rate, format="wav", subtype="pcm_16")
108
+
109
+ end_time = time.time()
110
+ logger.info(
111
+ f"Saved WAV file to {file_path} in {end_time - start_time:.3f} seconds."
112
+ )
113
+ return True
114
+ except ImportError:
115
+ logger.critical(
116
+ "`soundfile` or its dependency `libsndfile` not found/installed correctly. Cannot save audio."
117
+ )
118
+ return False # Indicate failure
119
+ except Exception as e:
120
+ logger.error(f"Error saving WAV file to {file_path}: {e}", exc_info=True)
121
+ return False
122
+
123
+
124
+ # --- Other Utilities (Optional) ---
125
+
126
+
127
+ class PerformanceMonitor:
128
+ """Simple performance monitoring."""
129
+
130
+ def __init__(self):
131
+ self.start_time = time.time()
132
+ self.events = []
133
+
134
+ def record(self, event_name: str):
135
+ self.events.append((event_name, time.time()))
136
+
137
+ def report(self) -> str:
138
+ report_lines = ["Performance Report:"]
139
+ last_time = self.start_time
140
+ total_duration = time.time() - self.start_time
141
+ for name, timestamp in self.events:
142
+ duration = timestamp - last_time
143
+ report_lines.append(f" - {name}: {duration:.3f}s")
144
+ last_time = timestamp
145
+ report_lines.append(f"Total Duration: {total_duration:.3f}s")
146
+ return "\n".join(report_lines)