Spaces:
Build error
Build error
"""Dia TTS provider implementation."""

import io
import logging
from typing import TYPE_CHECKING, Iterator

import numpy as np
import soundfile as sf

from ..base.tts_provider_base import TTSProviderBase
from ...domain.exceptions import SpeechSynthesisException

if TYPE_CHECKING:
    from ...domain.models.speech_synthesis_request import SpeechSynthesisRequest

logger = logging.getLogger(__name__)

# Flag to track Dia availability; updated by _check_dia_dependencies().
DIA_AVAILABLE = False

# Sample rate (Hz) written into the WAV header and returned to callers.
# Upstream Dia saves its generated audio at 44100 Hz; labelling the same
# samples as 24000 Hz made playback slow and pitched down.
DEFAULT_SAMPLE_RATE = 44100
# Try to import Dia dependencies
def _check_dia_dependencies() -> bool:
    """Check if Dia dependencies are available.

    Attempts to import ``torch`` and ``dia.model`` and records the outcome in
    the module-level ``DIA_AVAILABLE`` flag.

    Returns:
        True if both imports succeed, False otherwise. Failures are logged
        rather than raised, so this is safe to call at module import time.
    """
    global DIA_AVAILABLE
    logger.info("π Checking Dia TTS dependencies...")
    try:
        logger.info("Attempting to import torch...")
        import torch  # noqa: F401 -- imported only to prove it is installed
        logger.info("β Successfully imported torch")
        logger.info("Attempting to import dia.model...")
        from dia.model import Dia  # noqa: F401 -- imported only to prove it is installed
        logger.info("β Successfully imported dia.model")
    except ModuleNotFoundError as e:
        # Must precede ``except ImportError``: ModuleNotFoundError is a
        # subclass of ImportError, so in the previous order this branch was
        # unreachable and the install hints were never logged.
        # ``e.name`` is the missing module (e.g. "dac" or "dac.utils"); match
        # its top-level package instead of substring-searching the message.
        missing = (e.name or "").partition(".")[0]
        if missing == "dac":
            logger.warning("β Dia TTS engine is not available due to missing 'dac' module")
            logger.info("Please install descript-audio-codec: pip install descript-audio-codec")
        elif missing == "dia":
            logger.warning("β Dia TTS engine is not available due to missing 'dia' module")
            logger.info("Please install dia: pip install git+https://github.com/nari-labs/dia.git")
        else:
            logger.warning(f"β Dia TTS engine is not available: {str(e)}")
        logger.info(f"ModuleNotFoundError details: {type(e).__name__}: {e}")
    except ImportError as e:
        logger.warning(f"β οΈ Dia TTS engine dependencies not available: {e}")
        logger.info(f"ImportError details: {type(e).__name__}: {e}")
    except Exception as e:
        # A broken install can fail with non-import errors while loading the
        # packages; report it as "unavailable" instead of letting it abort the
        # import of this module (the check below runs at import time).
        logger.warning(f"Dia TTS engine failed to load: {type(e).__name__}: {e}")
    else:
        DIA_AVAILABLE = True
        logger.info("β Dia TTS engine is available")
        return True
    DIA_AVAILABLE = False
    return False


# Initial check, run once at import time to populate DIA_AVAILABLE.
logger.info("π Initializing Dia TTS provider...")
_check_dia_dependencies()
class DiaTTSProvider(TTSProviderBase):
    """Dia TTS provider implementation.

    Wraps the nari-labs Dia model behind the ``TTSProviderBase`` interface.
    The model is loaded lazily, on the first availability check.
    """

    def __init__(self, lang_code: str = 'z'):
        """Initialize the Dia TTS provider.

        Args:
            lang_code: Language code stored on the instance (not read
                elsewhere in this class).
        """
        super().__init__(
            provider_name="Dia",
            supported_languages=['en', 'z'],  # Dia supports English and multilingual
        )
        self.lang_code = lang_code
        # Dia model instance; populated lazily by _ensure_model().
        self.model = None

    def _ensure_model(self) -> bool:
        """Ensure the model is loaded.

        If the dependencies were unavailable, they are probed again first.
        Load failures are logged and leave ``self.model`` as ``None``, so a
        later call will retry.

        Returns:
            True if a model instance is loaded, False otherwise.
        """
        if self.model is None:
            logger.info("π Ensuring Dia model is loaded...")
            if not DIA_AVAILABLE:
                # Dependencies may have become importable since module import.
                # _check_dia_dependencies() updates DIA_AVAILABLE itself.
                logger.info("β οΈ Dia not available, checking dependencies again...")
                if not _check_dia_dependencies():
                    logger.error("β Dependencies still not available")
                    return False
                logger.info("β Dependencies are now available")
            self._load_model()
        is_available = self.model is not None
        logger.info(f"Model availability check result: {is_available}")
        return is_available

    def _load_model(self) -> None:
        """Instantiate the pretrained Dia model into ``self.model``; on failure log and leave it ``None``."""
        try:
            logger.info("π₯ Loading Dia model from pretrained...")
            from dia.model import Dia
            self.model = Dia.from_pretrained()
            logger.info("π Dia model successfully loaded")
        except ImportError as e:
            logger.error(f"β Failed to import Dia dependencies: {str(e)}")
            self.model = None
        except FileNotFoundError as e:
            logger.error(f"β Failed to load Dia model files: {str(e)}")
            logger.info("βΉοΈ This might be the first time loading the model. It will be downloaded automatically.")
            self.model = None
        except Exception as e:
            logger.error(f"β Failed to initialize Dia model: {str(e)}")
            logger.info(f"Model initialization error: {type(e).__name__}: {e}")
            self.model = None

    def is_available(self) -> bool:
        """Check if Dia TTS is available (dependencies importable and model loaded)."""
        logger.info(f"π Checking Dia availability: DIA_AVAILABLE={DIA_AVAILABLE}")
        # NOTE(review): this early return means _ensure_model()'s dependency
        # re-check is never reached from here; confirm whether a runtime
        # re-probe is actually wanted.
        if not DIA_AVAILABLE:
            logger.info("β Dia dependencies not available")
            return False
        model_available = self._ensure_model()
        logger.info(f"π Model availability: {model_available}")
        result = DIA_AVAILABLE and model_available
        logger.info(f"π― Dia TTS availability result: {result}")
        return result

    def get_available_voices(self) -> list[str]:
        """Get available voices for Dia."""
        # Dia typically uses a default voice
        return ['default']

    def _generate_audio(self, request: 'SpeechSynthesisRequest') -> tuple[bytes, int]:
        """Generate audio using Dia TTS.

        Returns:
            ``(wav_bytes, sample_rate)`` for the request's text.

        Raises:
            SpeechSynthesisException: if the engine is unavailable or the
                'dac' module is missing at generation time. Other failures
                are delegated to ``_handle_provider_error``.
        """
        if not self.is_available():
            raise SpeechSynthesisException("Dia TTS engine is not available")
        try:
            audio_bytes = self._synthesize_wav(request.text_content.text)
            return audio_bytes, DEFAULT_SAMPLE_RATE
        except Exception as e:
            self._handle_generation_error(e, "audio generation")

    def _generate_audio_stream(self, request: 'SpeechSynthesisRequest') -> Iterator[tuple[bytes, int, bool]]:
        """Generate audio stream using Dia TTS.

        Yields:
            A single ``(wav_bytes, sample_rate, is_final=True)`` chunk.

        Raises:
            SpeechSynthesisException: same conditions as ``_generate_audio``.
        """
        if not self.is_available():
            raise SpeechSynthesisException("Dia TTS engine is not available")
        try:
            audio_bytes = self._synthesize_wav(request.text_content.text)
        except Exception as e:
            self._handle_generation_error(e, "streaming audio generation")
            return
        # Dia generates complete audio in one go, so the stream is one final
        # chunk. Yield outside the try so errors thrown in by the consumer are
        # not mistaken for generation failures.
        yield audio_bytes, DEFAULT_SAMPLE_RATE, True

    def _synthesize_wav(self, text: str) -> bytes:
        """Run the loaded Dia model on *text* and return the audio as WAV bytes.

        Raises:
            SpeechSynthesisException: if the model returns ``None`` or WAV
                encoding fails.
        """
        import torch

        # inference_mode disables autograd bookkeeping for this forward pass.
        with torch.inference_mode():
            output_audio_np = self.model.generate(
                text,
                max_tokens=None,
                cfg_scale=3.0,
                temperature=1.3,
                top_p=0.95,
                cfg_filter_top_k=35,
                use_torch_compile=False,
                verbose=False,
            )
        if output_audio_np is None:
            raise SpeechSynthesisException("Dia model returned None for audio output")
        return self._numpy_to_bytes(output_audio_np, sample_rate=DEFAULT_SAMPLE_RATE)

    def _handle_generation_error(self, error: Exception, operation: str) -> None:
        """Report a generation failure.

        A missing 'dac' module is raised with a dedicated message; anything
        else is passed to ``_handle_provider_error`` with *operation* as the
        context string.
        """
        if isinstance(error, ModuleNotFoundError) and "dac" in str(error):
            raise SpeechSynthesisException("Dia TTS engine failed due to missing 'dac' module") from error
        self._handle_provider_error(error, operation)

    def _numpy_to_bytes(self, audio_array: np.ndarray, sample_rate: int) -> bytes:
        """Encode a numpy audio array as an in-memory WAV file and return its bytes.

        Raises:
            SpeechSynthesisException: if soundfile cannot encode the array.
        """
        try:
            with io.BytesIO() as buffer:
                sf.write(buffer, audio_array, sample_rate, format='WAV')
                return buffer.getvalue()
        except Exception as e:
            raise SpeechSynthesisException(f"Failed to convert audio to bytes: {str(e)}") from e