File size: 15,011 Bytes
ac5de5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
# engine.py
# Core Dia TTS model loading and generation logic

import logging
import time
import os
import torch
import numpy as np
from typing import Optional, Tuple
from huggingface_hub import hf_hub_download  # Import downloader

# Import Dia model class and config
try:
    from dia.model import Dia
    from dia.config import DiaConfig
except ImportError as import_error:
    # The core model package could not be imported. Record the reason at
    # CRITICAL level, then install stand-in classes so that importing this
    # module still succeeds; any real use of the stand-ins raises instead.
    logging.critical(
        f"Failed to import Dia model components: {import_error}. Ensure the 'dia' package exists and is importable.",
        exc_info=True,
    )

    class Dia:
        """Stand-in for ``dia.model.Dia`` used when the package is missing.

        Both entry points raise ``RuntimeError`` so failures surface at the
        point of use rather than at import time.
        """

        @staticmethod
        def load_model_from_files(*args, **kwargs):
            raise RuntimeError("Dia model package not available or failed to import.")

        def generate(*args, **kwargs):
            # No explicit ``self``: when called on an instance it lands in *args.
            raise RuntimeError("Dia model package not available or failed to import.")

    class DiaConfig:
        """Empty stand-in for ``dia.config.DiaConfig`` (keeps type hints importable)."""


# Import configuration getters from our project's config.py
from config import (
    get_model_repo_id,
    get_model_cache_path,
    get_reference_audio_path,
    get_model_config_filename,
    get_model_weights_filename,
)

logger = logging.getLogger(__name__)  # Module-level logger named after this module

# --- Global Variables ---
# Module-level model state. load_model() writes these; generate_speech() reads
# MODEL_LOADED and dia_model before generating.
dia_model: Optional[Dia] = None
# Config object taken from the loaded model (dia_model.config) after a
# successful load_model(); None until then.
model_config_instance: Optional[DiaConfig] = None
# Device selected by get_device() at the start of load_model().
model_device: Optional[torch.device] = None
# True only after load_model() has completed successfully.
MODEL_LOADED = False
EXPECTED_SAMPLE_RATE = 44100  # Sample rate reported by generate_speech(); Dia/DAC typically operate at 44.1kHz

# --- Model Loading ---


def get_device() -> torch.device:
    """Choose the torch device to run the model on.

    Preference order is CUDA, then Apple Silicon MPS, then CPU. The selected
    option is logged at INFO level.

    Returns:
        The chosen ``torch.device``.
    """
    if torch.cuda.is_available():
        logger.info("CUDA is available, using GPU.")
        return torch.device("cuda")

    # The hasattr guard covers torch builds that do not expose torch.backends.mps.
    mps_usable = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    if mps_usable:
        logger.info("MPS is available, using Apple Silicon GPU.")
        return torch.device("mps")

    logger.info("CUDA and MPS not available, using CPU.")
    return torch.device("cpu")


def load_model() -> bool:
    """
    Load the Dia TTS model, fetching its files from the Hugging Face Hub if needed.

    The repo ID, config/weights filenames and cache directory come from the
    project's ``config`` getters. ``hf_hub_download`` reuses files already in
    the cache directory and downloads any that are missing. Building the model
    from those files is delegated to ``Dia.load_model_from_files`` (any
    weight-format handling, presumably .pth/.safetensors, lives there, not here).

    Module globals written:
        model_device: always set to the device chosen by ``get_device()``.
        dia_model, model_config_instance, MODEL_LOADED: published together only
            after every loading step succeeds; reset to ``None``, ``None`` and
            ``False`` on any download/load failure, so a failed call never
            leaves a half-populated state behind.

    Returns:
        True if the model is loaded (by this call or an earlier one), False if
        downloading or loading failed. Download/load errors are logged rather
        than raised.
    """
    global dia_model, model_config_instance, model_device, MODEL_LOADED

    if MODEL_LOADED:
        logger.info("Dia model already loaded.")
        return True

    # Get configuration values
    repo_id = get_model_repo_id()
    config_filename = get_model_config_filename()
    weights_filename = get_model_weights_filename()
    cache_path = get_model_cache_path()  # Already absolute path
    model_device = get_device()

    logger.info("Attempting to load Dia model:")
    logger.info(f"  Repo ID: {repo_id}")
    logger.info(f"  Config File: {config_filename}")
    logger.info(f"  Weights File: {weights_filename}")
    logger.info(f"  Cache Directory: {cache_path}")
    logger.info(f"  Target Device: {model_device}")

    # Best effort: if the cache directory cannot be created, log it and carry on;
    # hf_hub_download below will fail with its own error if the location is unusable.
    try:
        os.makedirs(cache_path, exist_ok=True)
    except OSError as e:
        logger.error(
            f"Failed to create cache directory '{cache_path}': {e}", exc_info=True
        )

    try:
        start_time = time.perf_counter()

        # --- Download Model Files ---
        logger.info(
            f"Downloading/finding configuration file '{config_filename}' from repo '{repo_id}'..."
        )
        local_config_path = hf_hub_download(
            repo_id=repo_id,
            filename=config_filename,
            cache_dir=cache_path,
        )
        logger.info(f"Configuration file path: {local_config_path}")

        logger.info(
            f"Downloading/finding weights file '{weights_filename}' from repo '{repo_id}'..."
        )
        local_weights_path = hf_hub_download(
            repo_id=repo_id,
            filename=weights_filename,
            cache_dir=cache_path,
        )
        logger.info(f"Weights file path: {local_weights_path}")

        # --- Load Model using the class method ---
        # Build into locals first; the globals are published only once every
        # step (including reading .config) has succeeded.
        loaded_model = Dia.load_model_from_files(
            config_path=local_config_path,
            weights_path=local_weights_path,
            device=model_device,
        )
        loaded_config = loaded_model.config

    except FileNotFoundError as e:
        logger.error(
            f"Model loading failed: Required file not found. {e}", exc_info=True
        )
    except ImportError:
        # Raised if the 'dia' package or one of its dependencies is missing at load time
        logger.critical(
            "Failed to load model: Dia package or its core dependencies not found.",
            exc_info=True,
        )
    except Exception as e:
        # Any other download or loading error
        logger.error(
            f"Error loading Dia model from repo '{repo_id}': {e}", exc_info=True
        )
    else:
        dia_model = loaded_model
        model_config_instance = loaded_config
        MODEL_LOADED = True
        logger.info(
            f"Dia model loaded successfully in {time.perf_counter() - start_time:.2f} seconds."
        )
        return True

    # Reached only through one of the except branches: clear all model state
    # uniformly so no failure path leaves stale or partial globals.
    dia_model = None
    model_config_instance = None
    MODEL_LOADED = False
    return False


# --- Speech Generation ---


def _resolve_clone_reference(clone_reference_filename: str) -> Optional[str]:
    """Return the path of a reference audio file inside the reference directory.

    The filename is joined onto ``get_reference_audio_path()`` and normalised.
    Because the filename may come from an API/UI caller, names that resolve
    outside that directory (absolute paths, ``..`` segments) are rejected.
    Symlinks are not resolved, so links placed inside the directory still work.

    Returns:
        The normalised absolute path when it names an existing regular file
        inside the reference directory, otherwise ``None`` (after logging why).
    """
    ref_base_path = os.path.abspath(get_reference_audio_path())
    potential_path = os.path.abspath(
        os.path.join(ref_base_path, clone_reference_filename)
    )
    try:
        inside_base = (
            os.path.commonpath([ref_base_path, potential_path]) == ref_base_path
        )
    except ValueError:
        # commonpath raises when the paths share no root (e.g. different Windows drives).
        inside_base = False
    if not inside_base:
        logger.error(
            f"Reference audio filename '{clone_reference_filename}' resolves outside the reference audio directory; refusing to use it."
        )
        return None
    if not os.path.isfile(potential_path):
        logger.error(f"Reference audio file not found: {potential_path}")
        return None
    logger.info(f"Using audio prompt for cloning: {potential_path}")
    return potential_path


def _prepare_text_for_mode(text: str, voice_mode: str) -> str:
    """Log mode-specific guidance for ``text`` and return the text to synthesize.

    Every mode currently passes ``text`` through unchanged: speaker tags such
    as ``[S1]``/``[S2]`` are NOT inserted automatically, and the caller is
    responsible for them. Unknown modes are logged as errors and handled like
    'single_s1', which is also a pass-through.
    """
    has_speaker_tags = "[S1]" in text or "[S2]" in text

    if voice_mode == "clone":
        # Dia needs the reference audio's transcript prepended to the target
        # text; building that combined text is the caller's job.
        logger.warning(
            "Clone mode active. Ensure the 'text' input includes the transcript of the reference audio for best results (e.g., '[S1] Reference transcript. [S1] Target text...')."
        )
    elif voice_mode == "dialogue":
        logger.info("Using dialogue mode. Expecting [S1]/[S2] tags in input text.")
        if not has_speaker_tags:
            logger.warning(
                "Dialogue mode selected, but no [S1] or [S2] tags found in the input text."
            )
    elif voice_mode == "single_s1":
        logger.info("Using single voice mode (S1).")
        # NOTE: untagged text is passed as-is; whether to force an "[S1] "
        # prefix is an open question for this mode.
        if has_speaker_tags:
            logger.warning(
                "Input text contains dialogue tags ([S1]/[S2]), but 'single_s1' mode was selected. Model behavior might be unexpected."
            )
    elif voice_mode == "single_s2":
        logger.info("Using single voice mode (S2).")
        # NOTE: as with S1, no "[S2] " prefix is added automatically.
        if has_speaker_tags:
            logger.warning(
                "Input text contains dialogue tags ([S1]/[S2]), but 'single_s2' mode was selected."
            )
    else:
        logger.error(
            f"Unsupported voice_mode: {voice_mode}. Defaulting to 'single_s1'."
        )
    return text


def _apply_speed_factor(audio: np.ndarray, speed_factor: float) -> np.ndarray:
    """Change playback speed by linearly resampling ``audio`` to a new length.

    The factor is clamped to [0.5, 2.0] and the output length is
    ``int(len(audio) / speed_factor)``, so factors below 1.0 lengthen (slow
    down) the clip and factors above 1.0 shorten (speed up) it. The result is
    reported at the same sample rate, so pitch shifts along with tempo; this
    is not a pitch-preserving time stretch.

    Assumes ``audio`` is a 1-D sample array (``np.interp`` needs 1-D input);
    TODO confirm Dia never returns multi-channel output.

    Returns:
        The resampled float32 array, or ``audio`` itself when the factor is
        exactly 1.0, the length would not change, or the target length would
        be zero.
    """
    if speed_factor == 1.0:
        logger.info("Speed factor is 1.0, no speed adjustment needed.")
        return audio

    logger.info(f"Applying speed factor: {speed_factor}")
    original_len = len(audio)
    # Clamp to avoid extreme distortion from very large or small factors.
    speed_factor = max(0.5, min(speed_factor, 2.0))
    target_len = int(original_len / speed_factor)

    if target_len <= 0 or target_len == original_len:
        logger.warning(
            f"Skipping speed adjustment (factor: {speed_factor:.2f}). Target length invalid ({target_len}) or no change needed."
        )
        return audio

    logger.debug(f"Resampling audio from {original_len} to {target_len} samples.")
    # Sample positions for the original signal and the new, evenly spaced grid
    # spanning the same range; np.interp linearly interpolates between samples.
    x_original = np.linspace(0, original_len - 1, original_len)
    x_resampled = np.linspace(0, original_len - 1, target_len)
    resampled = np.interp(x_resampled, x_original, audio)
    logger.info(f"Audio resampled for {speed_factor:.2f}x speed.")
    return resampled.astype(np.float32)


def generate_speech(
    text: str,
    voice_mode: str = "single_s1",
    clone_reference_filename: Optional[str] = None,
    max_tokens: Optional[int] = None,
    cfg_scale: float = 3.0,
    temperature: float = 1.3,
    top_p: float = 0.95,
    speed_factor: float = 0.94,  # Keep speed factor separate from model generation params
    cfg_filter_top_k: int = 35,
) -> Optional[Tuple[np.ndarray, int]]:
    """
    Generate speech with the loaded Dia model, handling voice modes and speed adjustment.

    Args:
        text: Text to synthesize. In 'clone' mode it must already include the
            reference audio's transcript. Speaker tags are never added here.
        voice_mode: 'dialogue', 'single_s1', 'single_s2' or 'clone'. Unknown
            values are logged and treated like 'single_s1'.
        clone_reference_filename: Filename of the reference audio for 'clone'
            mode, relative to the reference audio directory. Names that
            resolve outside that directory are rejected.
        max_tokens: Max generation tokens for the model's generate method
            (``None`` lets the model use its own default).
        cfg_scale: CFG scale for the model's generate method.
        temperature: Sampling temperature for the model's generate method.
        top_p: Nucleus sampling p for the model's generate method.
        speed_factor: Playback-speed factor applied *after* generation by
            resampling (e.g. 0.9 = slower, 1.1 = faster). It is clamped to
            [0.5, 2.0] and also shifts pitch.
        cfg_filter_top_k: CFG filter top K for the model's generate method.

    Returns:
        Tuple of (float32 numpy audio array, EXPECTED_SAMPLE_RATE), or None if
        the model is not loaded, the clone reference is missing or invalid,
        the model returns no audio, or generation/post-processing raises
        (the error is logged).
    """
    if not MODEL_LOADED or dia_model is None:
        logger.error("Dia model is not loaded. Cannot generate speech.")
        return None

    logger.info(f"Generating speech with mode: {voice_mode}")
    logger.debug(f"Input text (start): '{text[:100]}...'")
    logger.debug(
        f"Model Params: max_tokens={max_tokens}, cfg={cfg_scale}, temp={temperature}, top_p={top_p}, top_k={cfg_filter_top_k}"
    )
    logger.debug(f"Post-processing Params: speed_factor={speed_factor}")

    # --- Handle Voice Mode ---
    audio_prompt_path = None
    if voice_mode == "clone":
        if not clone_reference_filename:
            logger.error("Clone mode selected but no reference filename provided.")
            return None
        audio_prompt_path = _resolve_clone_reference(clone_reference_filename)
        if audio_prompt_path is None:
            return None  # Fail generation if the reference file is missing or disallowed

    processed_text = _prepare_text_for_mode(text, voice_mode)

    # --- Call Dia Generate ---
    try:
        start_time = time.perf_counter()
        logger.info("Calling Dia model generate method...")

        generated_audio_np = dia_model.generate(
            text=processed_text,
            audio_prompt_path=audio_prompt_path,
            max_tokens=max_tokens,  # None lets Dia use its default
            cfg_scale=cfg_scale,
            temperature=temperature,
            top_p=top_p,
            use_cfg_filter=True,  # Default from Dia's app.py
            cfg_filter_top_k=cfg_filter_top_k,
            use_torch_compile=False,  # Kept off for stability unless specifically tested/enabled
        )
        logger.info(
            f"Dia model generation finished in {time.perf_counter() - start_time:.2f} seconds."
        )

        if generated_audio_np is None or generated_audio_np.size == 0:
            logger.warning("Dia model returned None or empty audio array.")
            return None

        # --- Apply Speed Factor (Post-processing, mirrors Dia's app.py) ---
        final_audio_np = _apply_speed_factor(generated_audio_np, speed_factor)

        # Guarantee a float32 result even when no resampling took place.
        if final_audio_np.dtype != np.float32:
            logger.warning(
                f"Generated audio was not float32 ({final_audio_np.dtype}), converting."
            )
            final_audio_np = final_audio_np.astype(np.float32)

        logger.info(
            f"Final audio ready. Shape: {final_audio_np.shape}, dtype: {final_audio_np.dtype}"
        )
        return final_audio_np, EXPECTED_SAMPLE_RATE

    except Exception as e:
        # Boundary handler: generation failures are logged and reported as None.
        logger.error(
            f"Error during Dia generation or post-processing: {e}", exc_info=True
        )
        return None