Spaces:
Build error
Build error
"""Base class for TTS provider implementations."""

import logging
import os
import time
import tempfile
from abc import ABC, abstractmethod
from typing import Iterator, Optional, TYPE_CHECKING
from pathlib import Path

# These model types are only needed for annotations; the methods that build
# them import them lazily at call time.
if TYPE_CHECKING:
    from ...domain.models.speech_synthesis_request import SpeechSynthesisRequest
    from ...domain.models.audio_content import AudioContent
    from ...domain.models.audio_chunk import AudioChunk

# Needed at runtime: the base interface of TTSProviderBase and the domain
# exception it raises.
from ...domain.interfaces.speech_synthesis import ISpeechSynthesisService
from ...domain.exceptions import SpeechSynthesisException

logger = logging.getLogger(__name__)
class TTSProviderBase(ISpeechSynthesisService, ABC):
    """Abstract base class for TTS provider implementations.

    Subclasses supply the provider-specific hooks ``_generate_audio``,
    ``_generate_audio_stream``, ``is_available`` and ``get_available_voices``.
    This base class validates requests, turns the hooks' raw output into
    ``AudioContent`` / ``AudioChunk`` objects, and wraps any failure in
    ``SpeechSynthesisException``.
    """

    def __init__(self, provider_name: str, supported_languages: Optional[list[str]] = None):
        """
        Initialize the TTS provider.

        Args:
            provider_name: Name of the TTS provider. It is used in log messages,
                error messages and generated file names.
            supported_languages: List of supported language codes. ``None`` or
                an empty list disables the language check in
                ``_validate_request``.

        Note:
            Creates the temp output directory as a side effect.
        """
        self.provider_name = provider_name
        self.supported_languages = supported_languages or []
        self._output_dir = self._ensure_output_directory()

    def synthesize(self, request: 'SpeechSynthesisRequest') -> 'AudioContent':
        """
        Synthesize speech from text.

        Args:
            request: The speech synthesis request.

        Returns:
            AudioContent: The synthesized audio, labelled as WAV.

        Raises:
            SpeechSynthesisException: If validation or synthesis fails. The
                original error is chained as ``__cause__``.
        """
        try:
            logger.info("Starting synthesis with %s provider", self.provider_name)
            self._validate_request(request)

            # Generate audio using the provider-specific implementation.
            audio_data, sample_rate = self._generate_audio(request)

            # Imported lazily: at module level this name exists only for type checkers.
            from ...domain.models.audio_content import AudioContent

            audio_content = AudioContent(
                data=audio_data,
                format='wav',  # Most providers output WAV
                sample_rate=sample_rate,
                duration=self._calculate_duration(audio_data, sample_rate),
                filename=f"{self.provider_name}_{int(time.time())}.wav",
            )

            logger.info("Synthesis completed successfully with %s", self.provider_name)
            return audio_content
        except Exception as e:
            logger.error("Synthesis failed with %s: %s", self.provider_name, e)
            raise SpeechSynthesisException(f"TTS synthesis failed: {e}") from e

    def synthesize_stream(self, request: 'SpeechSynthesisRequest') -> Iterator['AudioChunk']:
        """
        Synthesize speech from text as a stream.

        This is a generator. Validation and synthesis only start when the
        caller begins iterating, so errors surface from iteration rather than
        from the call itself.

        Args:
            request: The speech synthesis request.

        Yields:
            AudioChunk: Audio chunks numbered from 0, in provider order.

        Raises:
            SpeechSynthesisException: If validation or synthesis fails. The
                original error is chained as ``__cause__``.
        """
        try:
            logger.info("Starting streaming synthesis with %s provider", self.provider_name)
            self._validate_request(request)

            # Imported once, outside the loop, instead of once per chunk.
            from ...domain.models.audio_chunk import AudioChunk

            chunks = self._generate_audio_stream(request)
            for chunk_index, (audio_data, sample_rate, is_final) in enumerate(chunks):
                yield AudioChunk(
                    data=audio_data,
                    format='wav',
                    sample_rate=sample_rate,
                    chunk_index=chunk_index,
                    is_final=is_final,
                    timestamp=time.time(),
                )

            logger.info("Streaming synthesis completed with %s", self.provider_name)
        except Exception as e:
            logger.error("Streaming synthesis failed with %s: %s", self.provider_name, e)
            raise SpeechSynthesisException(f"TTS streaming synthesis failed: {e}") from e

    @abstractmethod
    def _generate_audio(self, request: 'SpeechSynthesisRequest') -> tuple[bytes, int]:
        """
        Generate audio data from a synthesis request (provider hook).

        Args:
            request: The speech synthesis request (already validated).

        Returns:
            tuple: ``(audio_data_bytes, sample_rate)``.
        """

    @abstractmethod
    def _generate_audio_stream(self, request: 'SpeechSynthesisRequest') -> Iterator[tuple[bytes, int, bool]]:
        """
        Generate an audio data stream from a synthesis request (provider hook).

        Args:
            request: The speech synthesis request (already validated).

        Returns:
            Iterator: ``(audio_data_bytes, sample_rate, is_final)`` tuples.
        """

    @abstractmethod
    def is_available(self) -> bool:
        """
        Check whether the TTS provider is available and ready to use.

        Returns:
            bool: True if the provider is available, False otherwise.
        """

    @abstractmethod
    def get_available_voices(self) -> list[str]:
        """
        Get the list of available voices for this provider.

        Returns:
            list[str]: Voice identifiers. An empty list disables the voice
            check in ``_validate_request``.
        """

    def _validate_request(self, request: 'SpeechSynthesisRequest') -> None:
        """
        Validate the synthesis request.

        Args:
            request: The synthesis request to validate.

        Raises:
            SpeechSynthesisException: If the text is empty, whitespace-only or
                None, or if the language or voice is not supported by this
                provider.
        """
        text = request.text_content.text
        # Treat None like empty text instead of failing with AttributeError.
        if not text or not text.strip():
            raise SpeechSynthesisException("Text content cannot be empty")

        language = request.text_content.language
        if self.supported_languages and language not in self.supported_languages:
            raise SpeechSynthesisException(
                f"Language {language} not supported by {self.provider_name}. "
                f"Supported languages: {self.supported_languages}"
            )

        available_voices = self.get_available_voices()
        voice_id = request.voice_settings.voice_id
        if available_voices and voice_id not in available_voices:
            raise SpeechSynthesisException(
                f"Voice {voice_id} not available for {self.provider_name}. "
                f"Available voices: {available_voices}"
            )

    def _ensure_output_directory(self) -> Path:
        """
        Ensure the output directory exists and return its path.

        The path is fixed (``<system temp dir>/tts_output``). Every provider
        instance therefore writes into the same directory.

        Returns:
            Path: Path to the output directory.
        """
        output_dir = Path(tempfile.gettempdir()) / "tts_output"
        output_dir.mkdir(parents=True, exist_ok=True)
        return output_dir

    def _generate_output_path(self, prefix: Optional[str] = None, extension: str = "wav") -> Path:
        """
        Generate a unique output path for audio files.

        Args:
            prefix: Optional filename prefix. Defaults to the provider name.
            extension: File extension, with or without a leading dot
                (default: wav).

        Returns:
            Path: A path inside the output directory. The file itself is not
            created.
        """
        import uuid

        prefix = prefix or self.provider_name
        extension = extension.lstrip(".")
        timestamp = int(time.time() * 1000)
        # A millisecond timestamp alone collides when two paths are requested
        # in the same millisecond; a short random suffix keeps them distinct.
        filename = f"{prefix}_{timestamp}_{uuid.uuid4().hex[:8]}.{extension}"
        return self._output_dir / filename

    def _calculate_duration(self, audio_data: bytes, sample_rate: int, channels: int = 1, sample_width: int = 2) -> float:
        """
        Calculate the audio duration in seconds.

        If ``audio_data`` is a complete RIFF/WAVE file (``synthesize`` labels
        its data as ``'wav'``), the duration is read from the header's frame
        count and frame rate. This means header bytes are not counted as
        samples, and non-16-bit or multi-channel files are handled correctly.
        Otherwise the data is treated as raw PCM described by the remaining
        arguments.

        Args:
            audio_data: Audio bytes, either a WAV file or raw PCM.
            sample_rate: Sample rate in Hz, used for raw PCM.
            channels: Number of audio channels for raw PCM (default: 1).
            sample_width: Sample width in bytes for raw PCM (default: 2 for 16-bit).

        Returns:
            float: Duration in seconds. Returns 0.0 for empty data or for a
            non-positive rate or frame size.
        """
        import io
        import wave

        if not audio_data:
            return 0.0

        if audio_data[:4] == b"RIFF" and audio_data[8:12] == b"WAVE":
            try:
                with wave.open(io.BytesIO(audio_data), "rb") as wav_file:
                    frame_rate = wav_file.getframerate()
                    frame_size = wav_file.getnchannels() * wav_file.getsampwidth()
                    if frame_rate > 0 and frame_size > 0:
                        # Streaming encoders may write a placeholder data size,
                        # so never count more frames than the bytes present.
                        frames = min(wav_file.getnframes(), len(audio_data) // frame_size)
                        return frames / frame_rate
            except (wave.Error, EOFError):
                pass  # Unreadable header: fall back to the raw-PCM estimate.

        bytes_per_frame = channels * sample_width
        if sample_rate <= 0 or bytes_per_frame <= 0:
            return 0.0
        return (len(audio_data) // bytes_per_frame) / sample_rate

    def _cleanup_temp_files(self, max_age_hours: int = 24) -> None:
        """
        Best-effort removal of old files from the output directory.

        Failures are logged as warnings and never raised. An error on one file
        does not stop the sweep over the others.

        Args:
            max_age_hours: Files whose modification time is older than this
                many hours are deleted.
        """
        try:
            cutoff = time.time() - max_age_hours * 3600
            candidates = list(self._output_dir.glob("*"))
        except Exception as e:
            logger.warning("Failed to cleanup temp files: %s", e)
            return

        for file_path in candidates:
            try:
                if file_path.is_file() and file_path.stat().st_mtime < cutoff:
                    file_path.unlink()
                    logger.debug("Cleaned up old temp file: %s", file_path)
            except OSError as e:
                # The file may have vanished or be locked; keep sweeping.
                logger.warning("Failed to cleanup temp files: %s", e)

    def _handle_provider_error(self, error: Exception, context: str = "") -> None:
        """
        Log a provider-specific error and re-raise it as a domain exception.

        This method never returns normally.

        Args:
            error: The original error.
            context: Optional description of what was happening, rendered as
                " during <context>" in the message.

        Raises:
            SpeechSynthesisException: Always, chained to ``error``.
        """
        error_msg = f"{self.provider_name} error"
        if context:
            error_msg += f" during {context}"
        error_msg += f": {error}"
        # Pass the exception itself, so its traceback is logged even when this
        # helper is called outside an ``except`` block (exc_info=True would
        # then log "NoneType: None").
        logger.error(error_msg, exc_info=error)
        raise SpeechSynthesisException(error_msg) from error