Michael Hu
Implement infrastructure base classes
e3cb97b
raw
history blame
9.22 kB
"""Base class for TTS provider implementations."""
import logging
import os
import time
import tempfile
from abc import ABC, abstractmethod
from typing import Iterator, Optional, TYPE_CHECKING
from pathlib import Path
if TYPE_CHECKING:
from ...domain.models.speech_synthesis_request import SpeechSynthesisRequest
from ...domain.models.audio_content import AudioContent
from ...domain.models.audio_chunk import AudioChunk
from ...domain.interfaces.speech_synthesis import ISpeechSynthesisService
from ...domain.exceptions import SpeechSynthesisException
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class TTSProviderBase(ISpeechSynthesisService, ABC):
    """Abstract base class for TTS provider implementations.

    Subclasses supply the provider-specific pieces (``_generate_audio``,
    ``_generate_audio_stream``, ``is_available`` and
    ``get_available_voices``). This class wraps them with request
    validation, logging, packaging of the result into domain objects and
    conversion of failures into ``SpeechSynthesisException``.
    """

    def __init__(self, provider_name: str, supported_languages: Optional[list[str]] = None):
        """
        Initialize the TTS provider.

        Creating an instance also makes sure the shared output directory
        exists (see ``_ensure_output_directory``).

        Args:
            provider_name: Name of the TTS provider
            supported_languages: List of supported language codes. ``None``
                or an empty list disables the language check performed in
                ``_validate_request``.
        """
        self.provider_name = provider_name
        self.supported_languages = supported_languages or []
        self._output_dir = self._ensure_output_directory()

    def synthesize(self, request: 'SpeechSynthesisRequest') -> 'AudioContent':
        """
        Synthesize speech from text.

        The request is validated, handed to ``_generate_audio`` and the
        returned bytes are wrapped in an ``AudioContent`` labelled as WAV.

        Args:
            request: The speech synthesis request

        Returns:
            AudioContent: The synthesized audio

        Raises:
            SpeechSynthesisException: If synthesis fails. Every exception
                raised along the way (validation failures included) is
                logged and re-raised with the prefix
                ``"TTS synthesis failed: "``.
        """
        try:
            logger.info("Starting synthesis with %s provider", self.provider_name)
            self._validate_request(request)

            # Generate audio using provider-specific implementation
            audio_data, sample_rate = self._generate_audio(request)

            # Imported here rather than at module level; at module level the
            # name is only imported under TYPE_CHECKING.
            from ...domain.models.audio_content import AudioContent

            # NOTE(review): the duration is computed with the defaults of
            # _calculate_duration (mono, 2 bytes per sample) over every byte
            # of audio_data. If providers return a WAV container, the header
            # bytes are counted as samples - confirm this is acceptable.
            audio_content = AudioContent(
                data=audio_data,
                format='wav',  # Most providers output WAV
                sample_rate=sample_rate,
                duration=self._calculate_duration(audio_data, sample_rate),
                filename=f"{self.provider_name}_{int(time.time())}.wav"
            )

            logger.info("Synthesis completed successfully with %s", self.provider_name)
            return audio_content
        except Exception as e:
            logger.error("Synthesis failed with %s: %s", self.provider_name, e)
            raise SpeechSynthesisException(f"TTS synthesis failed: {str(e)}") from e

    def synthesize_stream(self, request: 'SpeechSynthesisRequest') -> Iterator['AudioChunk']:
        """
        Synthesize speech from text as a stream.

        This is a generator function: nothing (including request
        validation) runs until the returned iterator is first advanced.

        Args:
            request: The speech synthesis request

        Returns:
            Iterator[AudioChunk]: Stream of audio chunks, numbered from 0 in
                the order produced by ``_generate_audio_stream``

        Raises:
            SpeechSynthesisException: If synthesis fails. Raised while
                iterating, with the prefix
                ``"TTS streaming synthesis failed: "``.
        """
        try:
            logger.info("Starting streaming synthesis with %s provider", self.provider_name)
            self._validate_request(request)

            # Local import (module level only imports it under TYPE_CHECKING);
            # done once, before the loop, instead of on every chunk.
            from ...domain.models.audio_chunk import AudioChunk

            # Generate audio stream using provider-specific implementation
            stream = self._generate_audio_stream(request)
            for chunk_index, (audio_data, sample_rate, is_final) in enumerate(stream):
                yield AudioChunk(
                    data=audio_data,
                    format='wav',
                    sample_rate=sample_rate,
                    chunk_index=chunk_index,
                    is_final=is_final,
                    timestamp=time.time()
                )

            logger.info("Streaming synthesis completed with %s", self.provider_name)
        except Exception as e:
            logger.error("Streaming synthesis failed with %s: %s", self.provider_name, e)
            raise SpeechSynthesisException(f"TTS streaming synthesis failed: {str(e)}") from e

    @abstractmethod
    def _generate_audio(self, request: 'SpeechSynthesisRequest') -> tuple[bytes, int]:
        """
        Generate audio data from synthesis request.

        Args:
            request: The speech synthesis request

        Returns:
            tuple: (audio_data_bytes, sample_rate)
        """

    @abstractmethod
    def _generate_audio_stream(self, request: 'SpeechSynthesisRequest') -> Iterator[tuple[bytes, int, bool]]:
        """
        Generate audio data stream from synthesis request.

        Args:
            request: The speech synthesis request

        Returns:
            Iterator: (audio_data_bytes, sample_rate, is_final) tuples
        """

    @abstractmethod
    def is_available(self) -> bool:
        """
        Check if the TTS provider is available and ready to use.

        Returns:
            bool: True if provider is available, False otherwise
        """

    @abstractmethod
    def get_available_voices(self) -> list[str]:
        """
        Get list of available voices for this provider.

        Returns:
            list[str]: List of voice identifiers
        """

    def _validate_request(self, request: 'SpeechSynthesisRequest') -> None:
        """
        Validate the synthesis request.

        Checks, in order: the text is not blank, the language is supported
        (skipped when no supported languages are configured) and the voice
        is available (skipped when the provider reports no voices).

        Args:
            request: The synthesis request to validate

        Raises:
            SpeechSynthesisException: If request is invalid
        """
        if not request.text_content.text.strip():
            raise SpeechSynthesisException("Text content cannot be empty")

        if self.supported_languages and request.text_content.language not in self.supported_languages:
            raise SpeechSynthesisException(
                f"Language {request.text_content.language} not supported by {self.provider_name}. "
                f"Supported languages: {self.supported_languages}"
            )

        available_voices = self.get_available_voices()
        if available_voices and request.voice_settings.voice_id not in available_voices:
            raise SpeechSynthesisException(
                f"Voice {request.voice_settings.voice_id} not available for {self.provider_name}. "
                f"Available voices: {available_voices}"
            )

    def _ensure_output_directory(self) -> Path:
        """
        Ensure output directory exists and return its path.

        The directory is ``tts_output`` inside the system temporary
        directory reported by ``tempfile.gettempdir()``.

        Returns:
            Path: Path to the output directory
        """
        output_dir = Path(tempfile.gettempdir()) / "tts_output"
        # parents=True so creation also succeeds when the temporary directory
        # itself has not been created yet.
        output_dir.mkdir(parents=True, exist_ok=True)
        return output_dir

    def _generate_output_path(self, prefix: Optional[str] = None, extension: str = "wav") -> Path:
        """
        Generate a unique output path for audio files.

        The file name has the form
        ``<prefix>_<millisecond timestamp>_<random hex>.<extension>``. The
        random part keeps two calls within the same millisecond from
        producing the same path. No file is created.

        Args:
            prefix: Optional prefix for the filename (defaults to the
                provider name)
            extension: File extension (default: wav)

        Returns:
            Path: Unique file path inside the output directory
        """
        # Function-scope import so the module's import block is unaffected.
        import uuid

        prefix = prefix or self.provider_name
        timestamp = int(time.time() * 1000)
        unique_suffix = uuid.uuid4().hex[:8]
        filename = f"{prefix}_{timestamp}_{unique_suffix}.{extension}"
        return self._output_dir / filename

    def _calculate_duration(self, audio_data: bytes, sample_rate: int, channels: int = 1, sample_width: int = 2) -> float:
        """
        Calculate audio duration from raw audio data.

        Every byte of ``audio_data`` is treated as sample data; a trailing
        partial frame is ignored.

        Args:
            audio_data: Raw audio data in bytes
            sample_rate: Sample rate in Hz
            channels: Number of audio channels (default: 1)
            sample_width: Sample width in bytes (default: 2 for 16-bit)

        Returns:
            float: Duration in seconds, or 0.0 when the data is empty or the
                sample rate, channel count or sample width is not positive
        """
        if not audio_data or sample_rate <= 0:
            return 0.0

        bytes_per_sample = channels * sample_width
        if bytes_per_sample <= 0:
            # Avoid ZeroDivisionError (or a negative duration) on a
            # degenerate frame size.
            return 0.0

        total_samples = len(audio_data) // bytes_per_sample
        return total_samples / sample_rate

    def _cleanup_temp_files(self, max_age_hours: int = 24) -> None:
        """
        Clean up old temporary files.

        Best effort: this method never raises. A file that cannot be
        inspected or removed is logged and skipped so the remaining files
        are still processed.

        Args:
            max_age_hours: Maximum age of files to keep in hours
        """
        try:
            current_time = time.time()
            max_age_seconds = max_age_hours * 3600

            for file_path in self._output_dir.glob("*"):
                try:
                    if not file_path.is_file():
                        continue
                    file_age = current_time - file_path.stat().st_mtime
                    if file_age > max_age_seconds:
                        # missing_ok: the file may have been removed by
                        # someone else since it was listed.
                        file_path.unlink(missing_ok=True)
                        logger.debug("Cleaned up old temp file: %s", file_path)
                except OSError as e:
                    logger.warning("Failed to cleanup temp file %s: %s", file_path, e)
        except Exception as e:
            logger.warning("Failed to cleanup temp files: %s", e)

    def _handle_provider_error(self, error: Exception, context: str = "") -> None:
        """
        Handle provider-specific errors and convert to domain exceptions.

        Logs the error with traceback information and always raises; it
        never returns normally.

        Args:
            error: The original error
            context: Additional context about when the error occurred

        Raises:
            SpeechSynthesisException: Always, chained to ``error``
        """
        error_msg = f"{self.provider_name} error"
        if context:
            error_msg += f" during {context}"
        error_msg += f": {str(error)}"

        logger.error(error_msg, exc_info=True)
        raise SpeechSynthesisException(error_msg) from error