# engine.py # Core Dia TTS model loading and generation logic import logging import time import os import torch import numpy as np from typing import Optional, Tuple from huggingface_hub import hf_hub_download # Import downloader # Import Dia model class and config try: from dia.model import Dia from dia.config import DiaConfig except ImportError as e: # Log critical error if core components are missing logging.critical( f"Failed to import Dia model components: {e}. Ensure the 'dia' package exists and is importable.", exc_info=True, ) # Define dummy classes/functions to prevent server crash on import, # but generation will fail later if these are used. class Dia: @staticmethod def load_model_from_files(*args, **kwargs): raise RuntimeError("Dia model package not available or failed to import.") def generate(*args, **kwargs): raise RuntimeError("Dia model package not available or failed to import.") class DiaConfig: pass # Import configuration getters from our project's config.py from config import ( get_model_repo_id, get_model_cache_path, get_reference_audio_path, get_model_config_filename, get_model_weights_filename, ) logger = logging.getLogger(__name__) # Use standard logger name # --- Global Variables --- dia_model: Optional[Dia] = None # model_config is now loaded within Dia.load_model_from_files, maybe remove global? # Let's keep it for now if needed elsewhere, but populate it after loading. model_config_instance: Optional[DiaConfig] = None model_device: Optional[torch.device] = None MODEL_LOADED = False EXPECTED_SAMPLE_RATE = 44100 # Dia model and DAC typically operate at 44.1kHz # --- Model Loading --- def get_device() -> torch.device: """Determines the optimal torch device (CUDA > MPS > CPU).""" if torch.cuda.is_available(): logger.info("CUDA is available, using GPU.") return torch.device("cuda") # Add MPS check for Apple Silicon GPUs elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): # Basic check is usually sufficient logger.info("MPS is available, using Apple Silicon GPU.") return torch.device("mps") else: logger.info("CUDA and MPS not available, using CPU.") return torch.device("cpu") def load_model(): """ Loads the Dia TTS model and associated DAC model. Downloads model files based on configuration if they don't exist locally. Handles both .pth and .safetensors formats. """ global dia_model, model_config_instance, model_device, MODEL_LOADED if MODEL_LOADED: logger.info("Dia model already loaded.") return True # Get configuration values repo_id = get_model_repo_id() config_filename = get_model_config_filename() weights_filename = get_model_weights_filename() cache_path = get_model_cache_path() # Already absolute path model_device = get_device() logger.info(f"Attempting to load Dia model:") logger.info(f" Repo ID: {repo_id}") logger.info(f" Config File: {config_filename}") logger.info(f" Weights File: {weights_filename}") logger.info(f" Cache Directory: {cache_path}") logger.info(f" Target Device: {model_device}") # Ensure cache directory exists try: os.makedirs(cache_path, exist_ok=True) except OSError as e: logger.error( f"Failed to create cache directory '{cache_path}': {e}", exc_info=True ) # Depending on severity, might want to return False here # return False pass # Continue and let hf_hub_download handle potential issues try: start_time = time.time() # --- Download Model Files --- logger.info( f"Downloading/finding configuration file '{config_filename}' from repo '{repo_id}'..." ) local_config_path = hf_hub_download( repo_id=repo_id, filename=config_filename, cache_dir=cache_path, # force_download=False, # Default: only download if missing or outdated # resume_download=True, # Default: resume interrupted downloads ) logger.info(f"Configuration file path: {local_config_path}") logger.info( f"Downloading/finding weights file '{weights_filename}' from repo '{repo_id}'..." ) local_weights_path = hf_hub_download( repo_id=repo_id, filename=weights_filename, cache_dir=cache_path, ) logger.info(f"Weights file path: {local_weights_path}") # --- Load Model using the class method --- # The Dia class method now handles config loading, instantiation, weight loading, etc. dia_model = Dia.load_model_from_files( config_path=local_config_path, weights_path=local_weights_path, device=model_device, ) # Store the config instance if needed globally (optional) model_config_instance = dia_model.config end_time = time.time() logger.info( f"Dia model loaded successfully in {end_time - start_time:.2f} seconds." ) MODEL_LOADED = True return True except FileNotFoundError as e: logger.error( f"Model loading failed: Required file not found. {e}", exc_info=True ) MODEL_LOADED = False return False except ImportError: # This catches if the 'dia' package itself is missing logger.critical( "Failed to load model: Dia package or its core dependencies not found.", exc_info=True, ) MODEL_LOADED = False return False except Exception as e: # Catch other potential errors during download or loading logger.error( f"Error loading Dia model from repo '{repo_id}': {e}", exc_info=True ) dia_model = None model_config_instance = None MODEL_LOADED = False return False # --- Speech Generation --- def generate_speech( text: str, voice_mode: str = "single_s1", clone_reference_filename: Optional[str] = None, max_tokens: Optional[int] = None, cfg_scale: float = 3.0, temperature: float = 1.3, top_p: float = 0.95, speed_factor: float = 0.94, # Keep speed factor separate from model generation params cfg_filter_top_k: int = 35, ) -> Optional[Tuple[np.ndarray, int]]: """ Generates speech using the loaded Dia model, handling voice modes and speed adjustment. Args: text: Text to synthesize. voice_mode: 'dialogue', 'single_s1', 'single_s2', 'clone'. clone_reference_filename: Filename for voice cloning (if mode is 'clone'). Located in reference audio path. max_tokens: Max generation tokens for the model's generate method. cfg_scale: CFG scale for the model's generate method. temperature: Sampling temperature for the model's generate method. top_p: Nucleus sampling p for the model's generate method. speed_factor: Factor to adjust the playback speed *after* generation (e.g., 0.9 = slower, 1.1 = faster). cfg_filter_top_k: CFG filter top K for the model's generate method. Returns: Tuple of (numpy_audio_array, sample_rate), or None on failure. """ if not MODEL_LOADED or dia_model is None: logger.error("Dia model is not loaded. Cannot generate speech.") return None logger.info(f"Generating speech with mode: {voice_mode}") logger.debug(f"Input text (start): '{text[:100]}...'") # Log model generation parameters logger.debug( f"Model Params: max_tokens={max_tokens}, cfg={cfg_scale}, temp={temperature}, top_p={top_p}, top_k={cfg_filter_top_k}" ) # Log post-processing parameters logger.debug(f"Post-processing Params: speed_factor={speed_factor}") audio_prompt_path = None processed_text = text # Start with original text # --- Handle Voice Mode --- if voice_mode == "clone": if not clone_reference_filename: logger.error("Clone mode selected but no reference filename provided.") return None ref_base_path = get_reference_audio_path() # Gets absolute path potential_path = os.path.join(ref_base_path, clone_reference_filename) if os.path.isfile(potential_path): audio_prompt_path = potential_path logger.info(f"Using audio prompt for cloning: {audio_prompt_path}") # Dia requires the transcript of the clone audio to be prepended to the target text. # The UI/API caller is responsible for constructing this combined text. logger.warning( "Clone mode active. Ensure the 'text' input includes the transcript of the reference audio for best results (e.g., '[S1] Reference transcript. [S1] Target text...')." ) processed_text = text # Use the combined text provided by the caller else: logger.error(f"Reference audio file not found: {potential_path}") return None # Fail generation if reference file is missing elif voice_mode == "dialogue": # Assume text already contains [S1]/[S2] tags as required by the model logger.info("Using dialogue mode. Expecting [S1]/[S2] tags in input text.") if "[S1]" not in text and "[S2]" not in text: logger.warning( "Dialogue mode selected, but no [S1] or [S2] tags found in the input text." ) processed_text = text # Pass directly elif voice_mode == "single_s1": logger.info("Using single voice mode (S1).") # Check if text *already* contains tags, warn if so, as it might confuse the model if "[S1]" in text or "[S2]" in text: logger.warning( "Input text contains dialogue tags ([S1]/[S2]), but 'single_s1' mode was selected. Model behavior might be unexpected." ) # Dia likely expects tags even for single speaker. Prepending [S1] might be safer. # Let's assume for now the model handles untagged text as S1, but this could be adjusted. # Consider: processed_text = f"[S1] {text}" # Option to enforce S1 tag processed_text = text # Pass directly for now elif voice_mode == "single_s2": logger.info("Using single voice mode (S2).") if "[S1]" in text or "[S2]" in text: logger.warning( "Input text contains dialogue tags ([S1]/[S2]), but 'single_s2' mode was selected." ) # Similar to S1, how to signal S2? Prepending [S2] seems logical if needed. # Consider: processed_text = f"[S2] {text}" # Option to enforce S2 tag processed_text = text # Pass directly for now else: logger.error( f"Unsupported voice_mode: {voice_mode}. Defaulting to 'single_s1'." ) processed_text = text # Fallback # --- Call Dia Generate --- try: start_time = time.time() logger.info("Calling Dia model generate method...") # Call the model's generate method with appropriate parameters generated_audio_np = dia_model.generate( text=processed_text, audio_prompt_path=audio_prompt_path, max_tokens=max_tokens, # Pass None if not specified, Dia uses its default cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, use_cfg_filter=True, # Default from Dia's app.py, seems reasonable cfg_filter_top_k=cfg_filter_top_k, use_torch_compile=False, # Keep False for stability unless specifically tested/enabled ) gen_end_time = time.time() logger.info( f"Dia model generation finished in {gen_end_time - start_time:.2f} seconds." ) if generated_audio_np is None or generated_audio_np.size == 0: logger.warning("Dia model returned None or empty audio array.") return None # --- Apply Speed Factor (Post-processing) --- # This mimics the logic in Dia's original app.py if speed_factor != 1.0: logger.info(f"Applying speed factor: {speed_factor}") original_len = len(generated_audio_np) # Ensure speed_factor is within a reasonable range to avoid extreme distortion # Adjust range based on observed quality (e.g., 0.5 to 2.0) speed_factor = max(0.5, min(speed_factor, 2.0)) target_len = int(original_len / speed_factor) if target_len > 0 and target_len != original_len: logger.debug( f"Resampling audio from {original_len} to {target_len} samples." ) # Create time axes for original and resampled audio x_original = np.linspace(0, original_len - 1, original_len) x_resampled = np.linspace(0, original_len - 1, target_len) # Interpolate using numpy resampled_audio_np = np.interp( x_resampled, x_original, generated_audio_np ) final_audio_np = resampled_audio_np.astype(np.float32) # Ensure float32 logger.info(f"Audio resampled for {speed_factor:.2f}x speed.") else: logger.warning( f"Skipping speed adjustment (factor: {speed_factor:.2f}). Target length invalid ({target_len}) or no change needed." ) final_audio_np = generated_audio_np # Use original audio else: logger.info("Speed factor is 1.0, no speed adjustment needed.") final_audio_np = generated_audio_np # No speed change needed # Ensure output is float32 (DAC output should be, but good practice) if final_audio_np.dtype != np.float32: logger.warning( f"Generated audio was not float32 ({final_audio_np.dtype}), converting." ) final_audio_np = final_audio_np.astype(np.float32) logger.info( f"Final audio ready. Shape: {final_audio_np.shape}, dtype: {final_audio_np.dtype}" ) # Return the processed audio and the expected sample rate return final_audio_np, EXPECTED_SAMPLE_RATE except Exception as e: logger.error( f"Error during Dia generation or post-processing: {e}", exc_info=True ) return None # Return None on failure