# Spaces:
# Build error
# Build error
"""
Base class for TTS provider implementations.

This module provides the abstract base class for all Text-to-Speech provider
implementations in the infrastructure layer. It implements common functionality
and defines the contract that all TTS providers must follow.

The base class handles:
- Common validation logic
- File management and cleanup
- Error handling and logging
- Audio format processing
- Provider lifecycle management

Example implementation:
    ```python
    from src.infrastructure.base.tts_provider_base import TTSProviderBase

    class MyTTSProvider(TTSProviderBase):
        def __init__(self):
            super().__init__("my_tts", ["en", "es"])

        def _generate_audio(self, request):
            # Implement TTS-specific logic
            audio_data = my_tts_engine.synthesize(request.text_content.text)
            return audio_data, 22050  # audio_bytes, sample_rate

        def is_available(self):
            return my_tts_engine.is_loaded()

        def get_available_voices(self):
            return ["voice1", "voice2"]
    ```
"""
import logging
import os
import tempfile
import time
import uuid
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Iterator, Optional, TYPE_CHECKING

if TYPE_CHECKING:
    from ...domain.models.speech_synthesis_request import SpeechSynthesisRequest
    from ...domain.models.audio_content import AudioContent
    from ...domain.models.audio_chunk import AudioChunk

from ...domain.interfaces.speech_synthesis import ISpeechSynthesisService
from ...domain.exceptions import SpeechSynthesisException
# Module-level logger shared by every method of the provider base class.
logger = logging.getLogger(__name__)
class TTSProviderBase(ISpeechSynthesisService, ABC):
    """
    Abstract base class for TTS provider implementations.

    This class provides a foundation for implementing Text-to-Speech providers
    in the infrastructure layer. It handles common concerns like validation,
    file management, error handling, and audio processing while allowing
    concrete implementations to focus on provider-specific logic.

    Key features:
    - Automatic validation of synthesis requests
    - Temporary file management with cleanup
    - Consistent error handling and logging
    - Support for both batch and streaming synthesis
    - Provider availability checking

    Subclasses are expected to override:
    - _generate_audio(): Core synthesis logic (the base version raises
      NotImplementedError)
    - _generate_audio_stream(): Streaming synthesis (optional; the base
      version emits the result of _generate_audio() as one final chunk)
    - is_available(): Provider availability check (base version: False)
    - get_available_voices(): Voice enumeration (base version: empty list,
      which disables voice validation)

    Every failure inside synthesize() / synthesize_stream() is logged and
    re-raised as SpeechSynthesisException, so callers only need to handle
    that one exception type.
    """

    def __init__(self, provider_name: str, supported_languages: Optional[list[str]] = None):
        """
        Initialize the TTS provider.

        Stores the provider configuration and makes sure the directory used
        for temporary audio files exists. This constructor should be called
        by all subclass implementations.

        Args:
            provider_name: Identifier for this TTS provider (e.g., "kokoro", "dia").
                Used in log messages, error messages, and generated filenames.
            supported_languages: Language codes accepted by this provider
                (e.g., ["en", "zh", "es"]). If None or empty, no language
                validation is performed.

        Example:
            ```python
            class MyTTSProvider(TTSProviderBase):
                def __init__(self):
                    super().__init__(
                        provider_name="my_tts",
                        supported_languages=["en", "es", "fr"]
                    )
            ```
        """
        self.provider_name = provider_name
        # Copy so later mutation of the caller's list cannot change validation.
        self.supported_languages = list(supported_languages) if supported_languages else []
        self._output_dir = self._ensure_output_directory()

    def synthesize(self, request: 'SpeechSynthesisRequest') -> 'AudioContent':
        """
        Synthesize speech from text.

        Validates the request, delegates to _generate_audio(), and wraps the
        result in an AudioContent tagged as 'wav'.

        Args:
            request: The speech synthesis request

        Returns:
            AudioContent: The synthesized audio

        Raises:
            SpeechSynthesisException: If validation or synthesis fails. Any
                underlying exception is attached as the cause.
        """
        try:
            logger.info("Starting synthesis with %s provider", self.provider_name)
            self._validate_request(request)

            # Generate audio using provider-specific implementation
            audio_data, sample_rate = self._generate_audio(request)

            # Imported at call time; the module-level import of this name is
            # for type checking only.
            from ...domain.models.audio_content import AudioContent

            audio_content = AudioContent(
                data=audio_data,
                format='wav',  # Most providers output WAV
                sample_rate=sample_rate,
                duration=self._calculate_duration(audio_data, sample_rate),
                filename=f"{self.provider_name}_{int(time.time())}.wav",
            )
            logger.info("Synthesis completed successfully with %s", self.provider_name)
            return audio_content
        except Exception as e:
            logger.error("Synthesis failed with %s: %s", self.provider_name, e)
            raise SpeechSynthesisException(f"TTS synthesis failed: {str(e)}") from e

    def synthesize_stream(self, request: 'SpeechSynthesisRequest') -> Iterator['AudioChunk']:
        """
        Synthesize speech from text as a stream.

        This is a generator: validation and synthesis only start when the
        returned iterator is first advanced.

        Args:
            request: The speech synthesis request

        Returns:
            Iterator[AudioChunk]: Stream of audio chunks, numbered from 0

        Raises:
            SpeechSynthesisException: If validation or synthesis fails while
                the stream is being consumed.
        """
        try:
            logger.info("Starting streaming synthesis with %s provider", self.provider_name)
            self._validate_request(request)

            # Imported once, before the loop, rather than on every chunk.
            from ...domain.models.audio_chunk import AudioChunk

            # Generate audio stream using provider-specific implementation
            stream = self._generate_audio_stream(request)
            for chunk_index, (audio_data, sample_rate, is_final) in enumerate(stream):
                yield AudioChunk(
                    data=audio_data,
                    format='wav',
                    sample_rate=sample_rate,
                    chunk_index=chunk_index,
                    is_final=is_final,
                    timestamp=time.time(),
                )
            logger.info("Streaming synthesis completed with %s", self.provider_name)
        except Exception as e:
            logger.error("Streaming synthesis failed with %s: %s", self.provider_name, e)
            raise SpeechSynthesisException(f"TTS streaming synthesis failed: {str(e)}") from e

    def _generate_audio(self, request: 'SpeechSynthesisRequest') -> tuple[bytes, int]:
        """
        Generate audio data from synthesis request.

        Provider-specific hook; concrete providers override this.

        Args:
            request: The speech synthesis request

        Returns:
            tuple: (audio_data_bytes, sample_rate)

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        # Fail loudly instead of returning None, which previously surfaced as
        # an opaque "cannot unpack" error inside synthesize().
        raise NotImplementedError(
            f"{type(self).__name__} must implement _generate_audio()"
        )

    def _generate_audio_stream(self, request: 'SpeechSynthesisRequest') -> Iterator[tuple[bytes, int, bool]]:
        """
        Generate audio data stream from synthesis request.

        Optional provider-specific hook. The base implementation falls back
        to batch synthesis and emits the whole result as a single chunk
        marked final, so providers without native streaming still work with
        synthesize_stream().

        Args:
            request: The speech synthesis request

        Returns:
            Iterator: (audio_data_bytes, sample_rate, is_final) tuples
        """
        audio_data, sample_rate = self._generate_audio(request)
        yield audio_data, sample_rate, True

    def is_available(self) -> bool:
        """
        Check if the TTS provider is available and ready to use.

        The base implementation reports the provider as unavailable;
        concrete providers override this with a real check.

        Returns:
            bool: True if provider is available, False otherwise
        """
        return False

    def get_available_voices(self) -> list[str]:
        """
        Get list of available voices for this provider.

        The base implementation returns an empty list, which makes
        _validate_request() skip voice validation.

        Returns:
            list[str]: List of voice identifiers
        """
        return []

    def _validate_request(self, request: 'SpeechSynthesisRequest') -> None:
        """
        Validate the synthesis request.

        Checks, in order: non-blank text, language membership (only when
        supported languages were configured), and voice membership (only
        when the provider reports any voices).

        Args:
            request: The synthesis request to validate

        Raises:
            SpeechSynthesisException: If request is invalid
        """
        text_content = request.text_content
        if not text_content.text.strip():
            raise SpeechSynthesisException("Text content cannot be empty")

        if self.supported_languages and text_content.language not in self.supported_languages:
            raise SpeechSynthesisException(
                f"Language {text_content.language} not supported by {self.provider_name}. "
                f"Supported languages: {self.supported_languages}"
            )

        available_voices = self.get_available_voices()
        if available_voices and request.voice_settings.voice_id not in available_voices:
            raise SpeechSynthesisException(
                f"Voice {request.voice_settings.voice_id} not available for {self.provider_name}. "
                f"Available voices: {available_voices}"
            )

    def _ensure_output_directory(self) -> Path:
        """
        Ensure output directory exists and return its path.

        The directory is "tts_output" under the system temp directory and
        does not depend on the provider name.

        Returns:
            Path: Path to the output directory
        """
        output_dir = Path(tempfile.gettempdir()) / "tts_output"
        # parents=True tolerates a temp directory that does not exist yet.
        output_dir.mkdir(parents=True, exist_ok=True)
        return output_dir

    def _generate_output_path(self, prefix: Optional[str] = None, extension: str = "wav") -> Path:
        """
        Generate a unique output path for audio files.

        The filename is "<prefix>_<millisecond timestamp>_<random hex>.<extension>".
        The random part keeps paths distinct even when several are requested
        within the same millisecond.

        Args:
            prefix: Optional prefix for the filename (defaults to the provider name)
            extension: File extension, with or without a leading dot (default: wav)

        Returns:
            Path: Unique file path inside the output directory
        """
        prefix = prefix or self.provider_name
        timestamp = int(time.time() * 1000)
        unique_suffix = uuid.uuid4().hex[:8]
        extension = extension.lstrip(".")
        return self._output_dir / f"{prefix}_{timestamp}_{unique_suffix}.{extension}"

    def _calculate_duration(self, audio_data: bytes, sample_rate: int, channels: int = 1, sample_width: int = 2) -> float:
        """
        Calculate audio duration from raw audio data.

        The whole buffer is treated as sample frames.
        NOTE(review): if providers return a WAV container, its header bytes
        are counted as audio here - confirm what _generate_audio() returns.

        Args:
            audio_data: Raw audio data in bytes
            sample_rate: Sample rate in Hz
            channels: Number of audio channels (default: 1)
            sample_width: Sample width in bytes (default: 2 for 16-bit)

        Returns:
            float: Duration in seconds; 0.0 for empty data or a non-positive
                sample rate or frame size
        """
        bytes_per_frame = channels * sample_width
        # Guard every divisor so malformed parameters yield 0.0, not ZeroDivisionError.
        if not audio_data or sample_rate <= 0 or bytes_per_frame <= 0:
            return 0.0
        total_frames = len(audio_data) // bytes_per_frame
        return total_frames / sample_rate

    def _cleanup_temp_files(self, max_age_hours: int = 24) -> None:
        """
        Clean up old temporary files.

        Best effort: failures are logged as warnings and never raised. A
        failure on one file does not stop the remaining files from being
        processed. Every regular file in the output directory is considered,
        whichever provider wrote it.

        Args:
            max_age_hours: Maximum age of files to keep in hours
        """
        try:
            # Files last modified before this instant are deleted.
            cutoff = time.time() - max_age_hours * 3600
            for file_path in self._output_dir.glob("*"):
                try:
                    if file_path.is_file() and file_path.stat().st_mtime < cutoff:
                        file_path.unlink()
                        logger.info("Cleaned up old temp file: %s", file_path)
                except OSError as e:
                    # e.g. the file vanished or is locked; keep sweeping.
                    logger.warning("Failed to cleanup temp file %s: %s", file_path, e)
        except Exception as e:
            logger.warning("Failed to cleanup temp files: %s", e)

    def _handle_provider_error(self, error: Exception, context: str = "") -> None:
        """
        Handle provider-specific errors and convert to domain exceptions.

        Logs the error together with its traceback, then raises. This method
        never returns normally.

        Args:
            error: The original error
            context: Additional context about when the error occurred

        Raises:
            SpeechSynthesisException: Always, with `error` as its cause.
        """
        error_msg = f"{self.provider_name} error"
        if context:
            error_msg += f" during {context}"
        error_msg += f": {str(error)}"
        # exc_info is the keyword logging accepts for attaching an exception;
        # an unknown keyword would raise TypeError before the domain
        # exception below could be raised.
        logger.error(error_msg, exc_info=error)
        raise SpeechSynthesisException(error_msg) from error