import logging
from typing import Generator, Optional, Tuple

import numpy as np
import soundfile as sf

from utils.tts_base import TTSBase

# Configure logging
logger = logging.getLogger(__name__)

# Flag to track Dia availability
DIA_AVAILABLE = False

# Dia decodes audio with the 44.1 kHz DAC codec, so its output must be written
# and streamed at 44100 Hz; any other rate changes playback speed and pitch.
DEFAULT_SAMPLE_RATE = 44100


def _is_missing_dac(exc: ImportError) -> bool:
    """Return True if *exc* reports that the ``dac`` package is missing.

    Uses the exception's ``name`` attribute when the import system populated
    it (top-level package of the missing module); otherwise falls back to a
    substring check on the message.
    """
    if exc.name:
        return exc.name.partition(".")[0] == "dac"
    return "dac" in str(exc)


def _log_generation_module_error(exc: ModuleNotFoundError) -> None:
    """Log a ModuleNotFoundError raised while running Dia inference."""
    if _is_missing_dac(exc):
        logger.error("Dia TTS engine failed due to missing 'dac' module")
    else:
        logger.error("Module not found error in Dia TTS: %s", exc)


# Try to import Dia dependencies.
# ModuleNotFoundError must be handled before its base class ImportError,
# otherwise the 'dac'-specific branch is unreachable.
try:
    import torch  # noqa: F401  (imported only to probe availability)
    from dia.model import Dia  # noqa: F401

    DIA_AVAILABLE = True
    logger.info("Dia TTS engine is available")
except ModuleNotFoundError as e:
    if _is_missing_dac(e):
        logger.warning("Dia TTS engine is not available due to missing 'dac' module")
    else:
        logger.warning("Dia TTS engine is not available: %s", e)
except ImportError as e:
    logger.warning("Dia TTS engine is not available: %s", e)


def _get_model():
    """Lazy-load the Dia model.

    Returns:
        Dia or None: The loaded Dia model, or None if Dia is unavailable or
        loading failed (the failure is logged).
    """
    if not DIA_AVAILABLE:
        logger.warning("Dia TTS engine is not available")
        return None

    try:
        from dia.model import Dia

        # Initialize the model with the library's default checkpoint
        model = Dia.from_pretrained()
        logger.info("Dia model successfully loaded")
        return model
    except ImportError as e:
        logger.error("Failed to import Dia dependencies: %s", e)
        return None
    except FileNotFoundError as e:
        logger.error("Failed to load Dia model files: %s", e)
        return None
    except Exception as e:
        logger.error("Failed to initialize Dia model: %s", e)
        return None


class DiaTTS(TTSBase):
    """Dia TTS engine implementation.

    This engine uses the Dia model for TTS generation. The model is loaded
    lazily on first use; ``voice`` and ``speed`` arguments are accepted for
    interface compatibility with other engines but are ignored by Dia.
    """

    def __init__(self, lang_code: str = 'z'):
        """Initialize the Dia TTS engine.

        Args:
            lang_code (str): Language code forwarded to ``TTSBase``.
        """
        super().__init__(lang_code)
        self.model = None

    def _ensure_model(self) -> bool:
        """Ensure the model is loaded.

        Returns:
            bool: True if the model is available, False otherwise.
        """
        if self.model is None:
            self.model = _get_model()
        return self.model is not None

    def _check_ready(self) -> bool:
        """Return True when Dia is importable and its model is loaded.

        Logs the reason and returns False otherwise.
        """
        if not DIA_AVAILABLE:
            logger.error("Dia TTS engine is not available")
            return False
        if not self._ensure_model():
            logger.error("Failed to load Dia model")
            return False
        return True

    def _generate_audio(self, text: str) -> Optional[np.ndarray]:
        """Run Dia inference on *text* with the engine's fixed sampling settings.

        Returns:
            The audio returned by ``Dia.generate`` (None if the model produced
            nothing). Exceptions from the model propagate to the caller.
        """
        import torch

        # inference_mode disables autograd tracking for faster, leaner inference
        with torch.inference_mode():
            return self.model.generate(
                text,
                max_tokens=None,
                cfg_scale=3.0,
                temperature=1.3,
                top_p=0.95,
                cfg_filter_top_k=35,
                use_torch_compile=False,
                verbose=False,
            )

    def generate_speech(self, text: str, voice: str = 'default', speed: float = 1.0) -> Optional[str]:
        """Generate speech using the Dia TTS engine and write it to a file.

        Args:
            text (str): Input text to synthesize.
            voice (str): Voice ID (not used in Dia).
            speed (float): Speech speed multiplier (not used in Dia).

        Returns:
            Optional[str]: Path to the generated audio file, or None if
            generation fails (the failure is logged).
        """
        logger.info("Generating speech with Dia for text length: %d", len(text))

        if not self._check_ready():
            return None

        try:
            # Generate unique output path
            output_path = self._generate_output_path(prefix="dia")

            output_audio_np = self._generate_audio(text)
            if output_audio_np is None:
                logger.error("Dia model returned None for audio output")
                return None

            logger.info("Successfully generated audio with Dia (length: %d)", len(output_audio_np))
            sf.write(output_path, output_audio_np, DEFAULT_SAMPLE_RATE)
            logger.info("Dia audio generation complete: %s", output_path)
            return output_path
        except ModuleNotFoundError as e:
            _log_generation_module_error(e)
            return None
        except Exception as e:
            logger.error("Error generating speech with Dia: %s", e, exc_info=True)
            return None

    def generate_speech_stream(self, text: str, voice: str = 'default', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
        """Generate speech stream using the Dia TTS engine.

        Dia produces the whole utterance at once, so at most one segment is
        yielded.

        Args:
            text (str): Input text to synthesize.
            voice (str): Voice ID (not used in Dia).
            speed (float): Speech speed multiplier (not used in Dia).

        Yields:
            tuple: (sample_rate, audio_data) pairs for each segment.
        """
        logger.info("Generating speech stream with Dia for text length: %d", len(text))

        if not self._check_ready():
            return

        try:
            output_audio_np = self._generate_audio(text)
            if output_audio_np is None:
                logger.error("Dia model returned None for audio output")
                return
            logger.info("Successfully generated audio with Dia (length: %d)", len(output_audio_np))
        except ModuleNotFoundError as e:
            _log_generation_module_error(e)
            return
        except Exception as e:
            logger.error("Error generating speech stream with Dia: %s", e, exc_info=True)
            return

        # Yield outside the try so exceptions raised by the consumer are not
        # misreported as Dia generation failures.
        yield DEFAULT_SAMPLE_RATE, output_audio_np