Spaces:
Running
Running
# engine.py
# Core Dia TTS model loading and generation logic
import logging | |
import time | |
import os | |
import torch | |
import numpy as np | |
from typing import Optional, Tuple | |
from huggingface_hub import hf_hub_download # Import downloader | |
# Import the Dia model class and its config. If the package cannot be imported,
# placeholder classes are defined instead so this module still imports; any
# attempt to load or run the model then fails with a RuntimeError.
try:
    from dia.model import Dia
    from dia.config import DiaConfig
except ImportError as import_error:
    # The module-level logger is not defined yet at this point, so the root
    # logger is used.
    logging.critical(
        f"Failed to import Dia model components: {import_error}. Ensure the 'dia' package exists and is importable.",
        exc_info=True,
    )

    class Dia:
        """Stand-in for ``dia.model.Dia``; every entry point raises RuntimeError."""

        @staticmethod
        def load_model_from_files(*args, **kwargs):
            raise RuntimeError("Dia model package not available or failed to import.")

        @staticmethod
        def generate(*args, **kwargs):
            raise RuntimeError("Dia model package not available or failed to import.")

    class DiaConfig:
        """Stand-in for ``dia.config.DiaConfig`` when the package is unavailable."""
# Import configuration getters from our project's config.py | |
from config import ( | |
get_model_repo_id, | |
get_model_cache_path, | |
get_reference_audio_path, | |
get_model_config_filename, | |
get_model_weights_filename, | |
) | |
# Module logger, named after this module.
logger = logging.getLogger(__name__)

# --- Global Variables ---
# Shared model state. load_model() fills these in; generate_speech() checks
# MODEL_LOADED / dia_model before generating.
MODEL_LOADED = False  # Set to True only after load_model() succeeds
dia_model: Optional[Dia] = None  # The loaded Dia instance
model_config_instance: Optional[DiaConfig] = None  # dia_model.config after a successful load
model_device: Optional[torch.device] = None  # Device picked by get_device() during loading

# Sample rate reported alongside generated audio; Dia and its DAC codec
# typically operate at 44.1 kHz.
EXPECTED_SAMPLE_RATE = 44100
# --- Model Loading ---
def get_device() -> torch.device:
    """
    Return the preferred torch device: CUDA first, then Apple MPS, then CPU.

    The MPS probe is guarded with ``hasattr`` in case this torch build has no
    ``torch.backends.mps`` attribute. The selected device is logged at INFO.
    """
    if torch.cuda.is_available():
        device_type, reason = "cuda", "CUDA is available, using GPU."
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device_type, reason = "mps", "MPS is available, using Apple Silicon GPU."
    else:
        device_type, reason = "cpu", "CUDA and MPS not available, using CPU."
    logger.info(reason)
    return torch.device(device_type)
def _clear_loaded_model():
    """Reset the model globals so a failed load never leaves a half-initialized model visible."""
    global dia_model, model_config_instance, MODEL_LOADED
    dia_model = None
    model_config_instance = None
    MODEL_LOADED = False


def load_model():
    """
    Load the Dia TTS model, downloading its files from the Hugging Face Hub if needed.

    The repo id, config/weights filenames and cache directory come from the
    project config. Both files are fetched with ``hf_hub_download`` into the
    cache directory, then the model is built by ``Dia.load_model_from_files``
    on the device chosen by ``get_device()``. Handling of the weights format
    (e.g. .pth vs .safetensors) is delegated to that method.

    Side effects:
        Sets the module global ``model_device``. On success also sets
        ``dia_model``, ``model_config_instance`` and ``MODEL_LOADED = True``.
        On any failure those three are reset to ``None``/``None``/``False``.

    Returns:
        True if the model is loaded (including when it already was), False if
        loading failed. Errors are logged rather than raised.
    """
    global dia_model, model_config_instance, model_device, MODEL_LOADED

    if MODEL_LOADED:
        logger.info("Dia model already loaded.")
        return True

    # Get configuration values
    repo_id = get_model_repo_id()
    config_filename = get_model_config_filename()
    weights_filename = get_model_weights_filename()
    cache_path = get_model_cache_path()  # Already absolute path
    model_device = get_device()

    logger.info("Attempting to load Dia model:")
    logger.info(f" Repo ID: {repo_id}")
    logger.info(f" Config File: {config_filename}")
    logger.info(f" Weights File: {weights_filename}")
    logger.info(f" Cache Directory: {cache_path}")
    logger.info(f" Target Device: {model_device}")

    # Best effort: if the directory cannot be created, continue and let
    # hf_hub_download surface any real problem.
    try:
        os.makedirs(cache_path, exist_ok=True)
    except OSError as e:
        logger.error(
            f"Failed to create cache directory '{cache_path}': {e}", exc_info=True
        )

    try:
        start_time = time.perf_counter()

        # --- Download Model Files ---
        # hf_hub_download only downloads when the file is missing or outdated
        # in cache_dir; otherwise it returns the cached path.
        logger.info(
            f"Downloading/finding configuration file '{config_filename}' from repo '{repo_id}'..."
        )
        local_config_path = hf_hub_download(
            repo_id=repo_id,
            filename=config_filename,
            cache_dir=cache_path,
        )
        logger.info(f"Configuration file path: {local_config_path}")

        logger.info(
            f"Downloading/finding weights file '{weights_filename}' from repo '{repo_id}'..."
        )
        local_weights_path = hf_hub_download(
            repo_id=repo_id,
            filename=weights_filename,
            cache_dir=cache_path,
        )
        logger.info(f"Weights file path: {local_weights_path}")

        # --- Load Model using the class method ---
        # Build into locals first so the globals are only published once every
        # step (including reading .config) has succeeded.
        loaded_model = Dia.load_model_from_files(
            config_path=local_config_path,
            weights_path=local_weights_path,
            device=model_device,
        )
        loaded_config = loaded_model.config
    except FileNotFoundError as e:
        logger.error(
            f"Model loading failed: Required file not found. {e}", exc_info=True
        )
        _clear_loaded_model()
        return False
    except ImportError:
        # A dependency imported lazily during download/loading is missing.
        # (A missing 'dia' package itself is handled at module import time by a
        # placeholder that raises RuntimeError, which lands in the branch below.)
        logger.critical(
            "Failed to load model: Dia package or its core dependencies not found.",
            exc_info=True,
        )
        _clear_loaded_model()
        return False
    except Exception as e:
        # Catch other potential errors during download or loading
        logger.error(
            f"Error loading Dia model from repo '{repo_id}': {e}", exc_info=True
        )
        _clear_loaded_model()
        return False

    dia_model = loaded_model
    model_config_instance = loaded_config
    logger.info(
        f"Dia model loaded successfully in {time.perf_counter() - start_time:.2f} seconds."
    )
    MODEL_LOADED = True
    return True
# --- Speech Generation ---
# Speed factors are clamped to this range before resampling; values outside it
# distort the audio too much.
_MIN_SPEED_FACTOR = 0.5
_MAX_SPEED_FACTOR = 2.0


def _resolve_clone_reference(clone_reference_filename: Optional[str]) -> Optional[str]:
    """
    Resolve a clone reference filename to a file inside the reference audio directory.

    Returns the absolute path, or None (after logging an error) when no name
    was given, the name resolves outside the reference directory, or the file
    does not exist. The escape check matters because the filename presumably
    comes from the UI/API request: ``os.path.join`` alone accepts ``../``
    segments and absolute paths, which would let a caller point the model at
    arbitrary files on the server.
    """
    if not clone_reference_filename:
        logger.error("Clone mode selected but no reference filename provided.")
        return None

    ref_base_path = os.path.abspath(get_reference_audio_path())
    potential_path = os.path.abspath(
        os.path.join(ref_base_path, clone_reference_filename)
    )
    try:
        inside_base = (
            os.path.commonpath([ref_base_path, potential_path]) == ref_base_path
        )
    except ValueError:
        # e.g. paths on different Windows drives: certainly not inside the base.
        inside_base = False
    if not inside_base:
        logger.error(
            f"Reference audio filename resolves outside the reference directory: {clone_reference_filename}"
        )
        return None

    if not os.path.isfile(potential_path):
        logger.error(f"Reference audio file not found: {potential_path}")
        return None
    return potential_path


def _apply_speed_factor(audio: np.ndarray, speed_factor: float) -> np.ndarray:
    """
    Change playback speed by linearly resampling ``audio`` to a new length.

    A factor below 1.0 lengthens (slows) the audio; above 1.0 shortens it. The
    factor is clamped to [_MIN_SPEED_FACTOR, _MAX_SPEED_FACTOR]. The sample
    rate is not changed, so this is plain resampling (pitch moves with speed),
    not time-stretching. ``np.interp`` only accepts 1-D data, so a mono 1-D
    waveform is assumed. Returns ``audio`` unchanged when no adjustment applies.
    """
    if speed_factor == 1.0:
        logger.info("Speed factor is 1.0, no speed adjustment needed.")
        return audio

    logger.info(f"Applying speed factor: {speed_factor}")
    # The max/min nesting also maps NaN to the lower bound.
    clamped_factor = max(_MIN_SPEED_FACTOR, min(speed_factor, _MAX_SPEED_FACTOR))
    if clamped_factor != speed_factor:
        logger.warning(
            f"Speed factor {speed_factor} is outside [{_MIN_SPEED_FACTOR}, {_MAX_SPEED_FACTOR}]; clamped to {clamped_factor}."
        )

    original_len = len(audio)
    target_len = int(original_len / clamped_factor)
    if target_len <= 0 or target_len == original_len:
        logger.warning(
            f"Skipping speed adjustment (factor: {clamped_factor:.2f}). Target length invalid ({target_len}) or no change needed."
        )
        return audio

    logger.debug(f"Resampling audio from {original_len} to {target_len} samples.")
    # Output sample positions expressed on the input's index axis.
    x_resampled = np.linspace(0, original_len - 1, target_len)
    resampled = np.interp(x_resampled, np.arange(original_len), audio)
    logger.info(f"Audio resampled for {clamped_factor:.2f}x speed.")
    return resampled.astype(np.float32)


def generate_speech(
    text: str,
    voice_mode: str = "single_s1",
    clone_reference_filename: Optional[str] = None,
    max_tokens: Optional[int] = None,
    cfg_scale: float = 3.0,
    temperature: float = 1.3,
    top_p: float = 0.95,
    speed_factor: float = 0.94,  # Keep speed factor separate from model generation params
    cfg_filter_top_k: int = 35,
) -> Optional[Tuple[np.ndarray, int]]:
    """
    Generate speech with the loaded Dia model, then apply the speed adjustment.

    Args:
        text: Text to synthesize. It is passed to the model unchanged in every
            mode; in 'clone' mode it should begin with the transcript of the
            reference audio (e.g. '[S1] Reference transcript. [S1] Target text...').
        voice_mode: 'dialogue', 'single_s1', 'single_s2' or 'clone'. Unknown
            values are logged as errors and handled like 'single_s1'.
        clone_reference_filename: Reference audio filename for 'clone' mode,
            relative to the configured reference audio directory. Names that
            resolve outside that directory are rejected.
        max_tokens: Max generation tokens for the model's generate method
            (None is passed through as-is).
        cfg_scale: CFG scale for the model's generate method.
        temperature: Sampling temperature for the model's generate method.
        top_p: Nucleus sampling p for the model's generate method.
        speed_factor: Playback-speed factor applied *after* generation
            (e.g. 0.9 = slower, 1.1 = faster), clamped to [0.5, 2.0].
        cfg_filter_top_k: CFG filter top K for the model's generate method.

    Returns:
        ``(audio, EXPECTED_SAMPLE_RATE)`` with float32 audio, or None if the
        model is not loaded, the clone reference is unusable, the model
        returned no audio, or any error occurred during generation or
        post-processing (errors are logged, not raised).
    """
    if not MODEL_LOADED or dia_model is None:
        logger.error("Dia model is not loaded. Cannot generate speech.")
        return None

    logger.info(f"Generating speech with mode: {voice_mode}")
    logger.debug(f"Input text (start): '{text[:100]}...'")
    logger.debug(
        f"Model Params: max_tokens={max_tokens}, cfg={cfg_scale}, temp={temperature}, top_p={top_p}, top_k={cfg_filter_top_k}"
    )
    logger.debug(f"Post-processing Params: speed_factor={speed_factor}")

    # --- Handle Voice Mode ---
    # No mode rewrites the text; the branches validate it, log guidance and,
    # for cloning, locate the audio prompt.
    audio_prompt_path = None
    has_speaker_tags = "[S1]" in text or "[S2]" in text
    if voice_mode == "clone":
        audio_prompt_path = _resolve_clone_reference(clone_reference_filename)
        if audio_prompt_path is None:
            return None  # Fail generation if the reference file is unusable
        logger.info(f"Using audio prompt for cloning: {audio_prompt_path}")
        # Dia needs the reference transcript prepended to the target text; the
        # UI/API caller is responsible for building that combined text.
        logger.warning(
            "Clone mode active. Ensure the 'text' input includes the transcript of the reference audio for best results (e.g., '[S1] Reference transcript. [S1] Target text...')."
        )
    elif voice_mode == "dialogue":
        logger.info("Using dialogue mode. Expecting [S1]/[S2] tags in input text.")
        if not has_speaker_tags:
            logger.warning(
                "Dialogue mode selected, but no [S1] or [S2] tags found in the input text."
            )
    elif voice_mode == "single_s1":
        logger.info("Using single voice mode (S1).")
        if has_speaker_tags:
            logger.warning(
                "Input text contains dialogue tags ([S1]/[S2]), but 'single_s1' mode was selected. Model behavior might be unexpected."
            )
        # Assumes the model voices untagged text as S1 (unverified); enforcing
        # it would mean sending f"[S1] {text}".
    elif voice_mode == "single_s2":
        logger.info("Using single voice mode (S2).")
        if has_speaker_tags:
            logger.warning(
                "Input text contains dialogue tags ([S1]/[S2]), but 'single_s2' mode was selected."
            )
        # NOTE(review): text is sent untagged, exactly as for single_s1, so S2
        # is not actually signalled to the model; f"[S2] {text}" may be needed.
    else:
        logger.error(
            f"Unsupported voice_mode: {voice_mode}. Defaulting to 'single_s1'."
        )
    processed_text = text

    # --- Call Dia Generate ---
    try:
        start_time = time.perf_counter()
        logger.info("Calling Dia model generate method...")
        generated_audio_np = dia_model.generate(
            text=processed_text,
            audio_prompt_path=audio_prompt_path,
            max_tokens=max_tokens,  # None if not specified; Dia then uses its default
            cfg_scale=cfg_scale,
            temperature=temperature,
            top_p=top_p,
            use_cfg_filter=True,  # Default from Dia's app.py
            cfg_filter_top_k=cfg_filter_top_k,
            use_torch_compile=False,  # Kept off for stability unless specifically tested/enabled
        )
        logger.info(
            f"Dia model generation finished in {time.perf_counter() - start_time:.2f} seconds."
        )

        if generated_audio_np is None or generated_audio_np.size == 0:
            logger.warning("Dia model returned None or empty audio array.")
            return None

        # --- Apply Speed Factor (Post-processing) ---
        final_audio_np = _apply_speed_factor(generated_audio_np, speed_factor)

        # DAC output is expected to be float32 already; normalize defensively.
        if final_audio_np.dtype != np.float32:
            logger.warning(
                f"Generated audio was not float32 ({final_audio_np.dtype}), converting."
            )
            final_audio_np = final_audio_np.astype(np.float32)

        logger.info(
            f"Final audio ready. Shape: {final_audio_np.shape}, dtype: {final_audio_np.dtype}"
        )
        return final_audio_np, EXPECTED_SAMPLE_RATE
    except Exception as e:
        logger.error(
            f"Error during Dia generation or post-processing: {e}", exc_info=True
        )
        return None