Michael Hu
Migrate existing TTS providers to infrastructure layer
1f9c751
raw
history blame
6.2 kB
"""Dia TTS provider implementation."""
import logging
import numpy as np
import soundfile as sf
import io
from typing import Iterator, TYPE_CHECKING
if TYPE_CHECKING:
from ...domain.models.speech_synthesis_request import SpeechSynthesisRequest
from ..base.tts_provider_base import TTSProviderBase
from ...domain.exceptions import SpeechSynthesisException
logger = logging.getLogger(__name__)

# Flag to track Dia availability; set to True only when both `torch` and
# `dia.model.Dia` import cleanly in the probe below.
DIA_AVAILABLE = False
# Sample rate (Hz) this module assumes for audio returned by the Dia model;
# it is used both for WAV encoding and as the rate reported to callers.
DEFAULT_SAMPLE_RATE = 24000

# Try to import Dia dependencies.
# ModuleNotFoundError is a subclass of ImportError, so it must be handled
# first -- otherwise its branch (including the missing-'dac' diagnostic) is
# unreachable and every failure is reported with the generic message.
try:
    import torch
    from dia.model import Dia
    DIA_AVAILABLE = True
    logger.info("Dia TTS engine is available")
except ModuleNotFoundError as e:
    if "dac" in str(e):
        logger.warning("Dia TTS engine is not available due to missing 'dac' module")
    else:
        logger.warning(f"Dia TTS engine is not available: {str(e)}")
    DIA_AVAILABLE = False
except ImportError:
    # Non-"missing module" import failures (e.g. a name missing from an
    # installed package) land here.
    logger.warning("Dia TTS engine is not available")
    DIA_AVAILABLE = False
class DiaTTSProvider(TTSProviderBase):
    """Dia TTS provider implementation.

    Wraps the optional ``dia`` package behind the ``TTSProviderBase``
    interface. The model is loaded lazily on first use and synthesis always
    produces one complete WAV clip at ``DEFAULT_SAMPLE_RATE``.
    """

    def __init__(self, lang_code: str = 'z'):
        """Initialize the Dia TTS provider.

        Args:
            lang_code: Language code stored on the instance (default ``'z'``).
                NOTE(review): not read anywhere in this class -- confirm
                whether the base class or callers rely on it.
        """
        super().__init__(
            provider_name="Dia",
            supported_languages=['en', 'z']  # Dia supports English and multilingual
        )
        self.lang_code = lang_code
        self.model = None

    def _ensure_model(self) -> bool:
        """Load the Dia model on first use.

        Load failures are logged rather than raised and leave ``self.model``
        as ``None``, so the load is attempted again on the next call.

        Returns:
            True if a model instance is loaded, False otherwise.
        """
        if self.model is None and DIA_AVAILABLE:
            try:
                from dia.model import Dia
                self.model = Dia.from_pretrained()
                logger.info("Dia model successfully loaded")
            except ImportError as e:
                logger.error(f"Failed to import Dia dependencies: {str(e)}")
                self.model = None
            except FileNotFoundError as e:
                logger.error(f"Failed to load Dia model files: {str(e)}")
                self.model = None
            except Exception as e:
                logger.error(f"Failed to initialize Dia model: {str(e)}")
                self.model = None
        return self.model is not None

    def is_available(self) -> bool:
        """Check if Dia TTS is available.

        Note that this triggers the (potentially expensive) model load.
        """
        return DIA_AVAILABLE and self._ensure_model()

    def get_available_voices(self) -> list[str]:
        """Get available voices for Dia."""
        # Dia typically uses a default voice
        return ['default']

    def _synthesize_wav(self, request: 'SpeechSynthesisRequest') -> tuple[bytes, int]:
        """Run Dia inference for ``request``; return ``(wav_bytes, sample_rate)``.

        Shared by the one-shot and streaming entry points.

        Raises:
            SpeechSynthesisException: if the model returns ``None`` or the
                result cannot be encoded as WAV.
        """
        import torch

        text = request.text_content.text
        # inference_mode disables autograd tracking for the generation call.
        with torch.inference_mode():
            output_audio_np = self.model.generate(
                text,
                max_tokens=None,
                cfg_scale=3.0,
                temperature=1.3,
                top_p=0.95,
                cfg_filter_top_k=35,
                use_torch_compile=False,
                verbose=False,
            )
        if output_audio_np is None:
            raise SpeechSynthesisException("Dia model returned None for audio output")
        # Assumes the model's output is sampled at DEFAULT_SAMPLE_RATE -- TODO confirm.
        audio_bytes = self._numpy_to_bytes(output_audio_np, sample_rate=DEFAULT_SAMPLE_RATE)
        return audio_bytes, DEFAULT_SAMPLE_RATE

    def _raise_generation_error(self, error: Exception, context: str) -> None:
        """Map an unexpected inference failure to a provider error.

        A missing 'dac' module gets a dedicated message; everything else is
        delegated to the base class's ``_handle_provider_error``.
        """
        if isinstance(error, ModuleNotFoundError) and "dac" in str(error):
            raise SpeechSynthesisException("Dia TTS engine failed due to missing 'dac' module") from error
        self._handle_provider_error(error, context)

    def _generate_audio(self, request: 'SpeechSynthesisRequest') -> tuple[bytes, int]:
        """Generate audio using Dia TTS.

        Returns:
            ``(wav_bytes, sample_rate)`` for the whole utterance.

        Raises:
            SpeechSynthesisException: if the engine is unavailable or synthesis fails.
        """
        if not self.is_available():
            raise SpeechSynthesisException("Dia TTS engine is not available")
        try:
            return self._synthesize_wav(request)
        except SpeechSynthesisException:
            # Already a domain error with a specific message; don't re-wrap it.
            raise
        except Exception as e:
            self._raise_generation_error(e, "audio generation")

    def _generate_audio_stream(self, request: 'SpeechSynthesisRequest') -> Iterator[tuple[bytes, int, bool]]:
        """Generate audio stream using Dia TTS.

        Dia produces the complete clip in one call, so this yields exactly one
        ``(wav_bytes, sample_rate, is_final=True)`` chunk.
        """
        if not self.is_available():
            raise SpeechSynthesisException("Dia TTS engine is not available")
        try:
            audio_bytes, sample_rate = self._synthesize_wav(request)
        except SpeechSynthesisException:
            # Already a domain error with a specific message; don't re-wrap it.
            raise
        except Exception as e:
            self._raise_generation_error(e, "streaming audio generation")
        else:
            # Yield outside the try so only synthesis failures are translated.
            yield audio_bytes, sample_rate, True

    def _numpy_to_bytes(self, audio_array: np.ndarray, sample_rate: int) -> bytes:
        """Encode ``audio_array`` as an in-memory WAV file and return its bytes.

        Raises:
            SpeechSynthesisException: if soundfile cannot encode the array.
        """
        try:
            buffer = io.BytesIO()
            sf.write(buffer, audio_array, sample_rate, format='WAV')
            return buffer.getvalue()
        except Exception as e:
            raise SpeechSynthesisException(f"Failed to convert audio to bytes: {str(e)}") from e