# Revision info captured with this file (not part of the code):
#   Michael Hu - "add more logs" - fdc056d
"""
Base class for TTS provider implementations.
This module provides the abstract base class for all Text-to-Speech provider
implementations in the infrastructure layer. It implements common functionality
and defines the contract that all TTS providers must follow.
The base class handles:
- Common validation logic
- File management and cleanup
- Error handling and logging
- Audio format processing
- Provider lifecycle management
Example implementation:
```python
from src.infrastructure.base.tts_provider_base import TTSProviderBase

class MyTTSProvider(TTSProviderBase):
    def __init__(self):
        super().__init__("my_tts", ["en", "es"])

    def _generate_audio(self, request):
        # Implement TTS-specific logic
        audio_data = my_tts_engine.synthesize(request.text_content.text)
        return audio_data, 22050  # audio_bytes, sample_rate

    def _generate_audio_stream(self, request):
        # Also abstract, so it must be overridden:
        # yield (audio_bytes, sample_rate, is_final) tuples
        audio_data, sample_rate = self._generate_audio(request)
        yield audio_data, sample_rate, True

    def is_available(self):
        return my_tts_engine.is_loaded()

    def get_available_voices(self):
        return ["voice1", "voice2"]
```
"""
import logging
import os
import time
import tempfile
from abc import ABC, abstractmethod
from typing import Iterator, Optional, TYPE_CHECKING
from pathlib import Path

# Annotation-only imports. The model classes that are actually constructed
# (AudioContent, AudioChunk) are imported lazily inside the methods that
# build them, so nothing here is needed at import time.
if TYPE_CHECKING:
    from ...domain.models.speech_synthesis_request import SpeechSynthesisRequest
    from ...domain.models.audio_content import AudioContent
    from ...domain.models.audio_chunk import AudioChunk

# Runtime imports: the interface is a base class of TTSProviderBase and the
# exception is raised by it, so both must exist when the module is loaded.
from ...domain.interfaces.speech_synthesis import ISpeechSynthesisService
from ...domain.exceptions import SpeechSynthesisException

logger = logging.getLogger(__name__)
class TTSProviderBase(ISpeechSynthesisService, ABC):
    """
    Abstract base class for TTS provider implementations.

    This class provides a foundation for implementing Text-to-Speech providers
    in the infrastructure layer. It handles common concerns like validation,
    file management, error handling, and audio processing while allowing
    concrete implementations to focus on provider-specific logic.

    Key features:
    - Automatic validation of synthesis requests
    - Temporary file management with cleanup
    - Consistent error handling and logging
    - Support for both batch and streaming synthesis
    - Audio format standardization
    - Provider availability checking

    Subclasses must implement (all four are abstract):
    - _generate_audio(): Core synthesis logic
    - _generate_audio_stream(): Streaming synthesis
    - is_available(): Provider availability check
    - get_available_voices(): Voice enumeration

    The base class ensures that all providers follow the same patterns
    for error handling, logging, and resource management, making the
    system more maintainable and predictable.
    """

    def __init__(self, provider_name: str, supported_languages: Optional[list[str]] = None):
        """
        Initialize the TTS provider.

        Stores the provider configuration and makes sure the temporary output
        directory exists. This constructor should be called by all subclass
        implementations.

        Args:
            provider_name: Unique identifier for this TTS provider (e.g., "kokoro", "dia").
                Used for logging, error messages, and generated file names.
            supported_languages: List of language codes supported by this provider
                (e.g., ["en", "zh", "es"]). If None or empty, no language
                validation is performed.

        Example:
            ```python
            class MyTTSProvider(TTSProviderBase):
                def __init__(self):
                    super().__init__(
                        provider_name="my_tts",
                        supported_languages=["en", "es", "fr"]
                    )
            ```
        """
        self.provider_name = provider_name
        # Copy so later mutation of the caller's list cannot silently change
        # what this provider validates against.
        self.supported_languages = list(supported_languages) if supported_languages else []
        self._output_dir = self._ensure_output_directory()

    def synthesize(self, request: 'SpeechSynthesisRequest') -> 'AudioContent':
        """
        Synthesize speech from text.

        Validates the request, delegates to ``_generate_audio`` and wraps the
        result in an ``AudioContent`` tagged as WAV.

        Args:
            request: The speech synthesis request

        Returns:
            AudioContent: The synthesized audio

        Raises:
            SpeechSynthesisException: If validation or synthesis fails. Every
                exception raised inside this method (including validation
                errors) is re-raised with a "TTS synthesis failed: " prefix
                and the original exception chained as its cause.
        """
        try:
            logger.info("Starting synthesis with %s provider", self.provider_name)
            self._validate_request(request)

            # Generate audio using provider-specific implementation
            audio_data, sample_rate = self._generate_audio(request)

            # Imported lazily; at module level this name is annotation-only.
            from ...domain.models.audio_content import AudioContent

            audio_content = AudioContent(
                data=audio_data,
                format='wav',  # Most providers output WAV
                sample_rate=sample_rate,
                duration=self._calculate_duration(audio_data, sample_rate),
                filename=f"{self.provider_name}_{int(time.time())}.wav",
            )
            logger.info("Synthesis completed successfully with %s", self.provider_name)
            return audio_content
        except Exception as e:
            logger.error("Synthesis failed with %s: %s", self.provider_name, str(e))
            raise SpeechSynthesisException(f"TTS synthesis failed: {str(e)}") from e

    def synthesize_stream(self, request: 'SpeechSynthesisRequest') -> Iterator['AudioChunk']:
        """
        Synthesize speech from text as a stream.

        This is a generator: nothing (not even request validation) runs until
        the caller starts iterating, so errors surface on the first ``next()``.

        Args:
            request: The speech synthesis request

        Returns:
            Iterator[AudioChunk]: Stream of audio chunks, numbered from 0 in
                the order the provider produced them

        Raises:
            SpeechSynthesisException: If validation or synthesis fails; the
                original exception is chained as the cause.
        """
        try:
            logger.info("Starting streaming synthesis with %s provider", self.provider_name)
            self._validate_request(request)

            # Imported lazily, once, rather than on every loop iteration.
            from ...domain.models.audio_chunk import AudioChunk

            # Generate audio stream using provider-specific implementation
            stream = self._generate_audio_stream(request)
            for chunk_index, (audio_data, sample_rate, is_final) in enumerate(stream):
                yield AudioChunk(
                    data=audio_data,
                    format='wav',
                    sample_rate=sample_rate,
                    chunk_index=chunk_index,
                    is_final=is_final,
                    timestamp=time.time(),
                )
            logger.info("Streaming synthesis completed with %s", self.provider_name)
        except Exception as e:
            logger.error("Streaming synthesis failed with %s: %s", self.provider_name, str(e))
            raise SpeechSynthesisException(f"TTS streaming synthesis failed: {str(e)}") from e

    @abstractmethod
    def _generate_audio(self, request: 'SpeechSynthesisRequest') -> tuple[bytes, int]:
        """
        Generate audio data from synthesis request.

        Args:
            request: The speech synthesis request

        Returns:
            tuple: (audio_data_bytes, sample_rate)
        """
        pass

    @abstractmethod
    def _generate_audio_stream(self, request: 'SpeechSynthesisRequest') -> Iterator[tuple[bytes, int, bool]]:
        """
        Generate audio data stream from synthesis request.

        Args:
            request: The speech synthesis request

        Returns:
            Iterator: (audio_data_bytes, sample_rate, is_final) tuples
        """
        pass

    @abstractmethod
    def is_available(self) -> bool:
        """
        Check if the TTS provider is available and ready to use.

        Returns:
            bool: True if provider is available, False otherwise
        """
        pass

    @abstractmethod
    def get_available_voices(self) -> list[str]:
        """
        Get list of available voices for this provider.

        Returns:
            list[str]: List of voice identifiers
        """
        pass

    def _validate_request(self, request: 'SpeechSynthesisRequest') -> None:
        """
        Validate the synthesis request.

        Checks, in order: non-blank text, language membership (only when the
        provider declared supported languages) and voice membership (only
        when ``get_available_voices()`` returns a non-empty list).

        Args:
            request: The synthesis request to validate

        Raises:
            SpeechSynthesisException: If request is invalid
        """
        text_content = request.text_content
        if not text_content.text.strip():
            raise SpeechSynthesisException("Text content cannot be empty")

        if self.supported_languages and text_content.language not in self.supported_languages:
            raise SpeechSynthesisException(
                f"Language {text_content.language} not supported by {self.provider_name}. "
                f"Supported languages: {self.supported_languages}"
            )

        available_voices = self.get_available_voices()
        voice_id = request.voice_settings.voice_id
        if available_voices and voice_id not in available_voices:
            raise SpeechSynthesisException(
                f"Voice {voice_id} not available for {self.provider_name}. "
                f"Available voices: {available_voices}"
            )

    def _ensure_output_directory(self) -> Path:
        """
        Ensure output directory exists and return its path.

        The directory is ``tts_output`` under the system temp directory, so
        it is the same for every provider instance.

        Returns:
            Path: Path to the output directory
        """
        output_dir = Path(tempfile.gettempdir()) / "tts_output"
        # parents=True keeps this working when the configured temp directory
        # itself has not been created yet.
        output_dir.mkdir(parents=True, exist_ok=True)
        return output_dir

    def _generate_output_path(self, prefix: Optional[str] = None, extension: str = "wav") -> Path:
        """
        Generate a timestamped output path for audio files.

        The name is ``<prefix>_<milliseconds since epoch>.<extension>``; two
        calls with the same prefix within the same millisecond produce the
        same path. No file is created.

        Args:
            prefix: Optional prefix for the filename (defaults to the provider name)
            extension: File extension (default: wav)

        Returns:
            Path: File path inside the output directory
        """
        prefix = prefix or self.provider_name
        timestamp = int(time.time() * 1000)
        filename = f"{prefix}_{timestamp}.{extension}"
        return self._output_dir / filename

    def _calculate_duration(self, audio_data: bytes, sample_rate: int, channels: int = 1, sample_width: int = 2) -> float:
        """
        Calculate audio duration from raw audio data.

        Every byte of ``audio_data`` is counted as sample data.
        NOTE(review): if providers return a complete WAV file, header bytes
        are counted too and the result is slightly high - confirm what
        ``_generate_audio`` implementations return.

        Args:
            audio_data: Raw audio data in bytes
            sample_rate: Sample rate in Hz
            channels: Number of audio channels (default: 1)
            sample_width: Sample width in bytes (default: 2 for 16-bit)

        Returns:
            float: Duration in seconds; 0.0 for empty data or a non-positive
                sample rate, channel count or sample width
        """
        bytes_per_sample = channels * sample_width
        # A non-positive frame size would otherwise divide by zero (or give a
        # negative duration); treat it like the other degenerate inputs.
        if not audio_data or sample_rate <= 0 or bytes_per_sample <= 0:
            return 0.0
        total_samples = len(audio_data) // bytes_per_sample
        return total_samples / sample_rate

    def _cleanup_temp_files(self, max_age_hours: int = 24) -> None:
        """
        Clean up old temporary files.

        Best effort: failures are logged as warnings and never raised. A
        failure on one file does not stop the remaining files from being
        examined.

        Args:
            max_age_hours: Maximum age of files to keep in hours
        """
        try:
            current_time = time.time()
            max_age_seconds = max_age_hours * 3600
            for file_path in self._output_dir.glob("*"):
                try:
                    if not file_path.is_file():
                        continue
                    file_age = current_time - file_path.stat().st_mtime
                    if file_age <= max_age_seconds:
                        continue
                    file_path.unlink()
                except OSError as e:
                    # e.g. the file vanished or is not removable; keep sweeping.
                    logger.warning(f"Failed to cleanup temp files: {str(e)}")
                    continue
                logger.info(f"Cleaned up old temp file: {file_path}")
        except Exception as e:
            logger.warning(f"Failed to cleanup temp files: {str(e)}")

    def _handle_provider_error(self, error: Exception, context: str = "") -> None:
        """
        Handle provider-specific errors and convert to domain exceptions.

        Logs the error with its traceback and always raises; it never
        returns normally.

        Args:
            error: The original error
            context: Additional context about when the error occurred

        Raises:
            SpeechSynthesisException: Always, with ``error`` chained as the cause.
        """
        error_msg = f"{self.provider_name} error"
        if context:
            error_msg += f" during {context}"
        error_msg += f": {str(error)}"
        # The standard logging API takes the exception via ``exc_info``; an
        # ``exception=`` keyword is rejected with TypeError, which would mask
        # the provider error this method is meant to report.
        logger.error(error_msg, exc_info=error)
        raise SpeechSynthesisException(error_msg) from error