"""
Base class for TTS provider implementations.

This module provides the abstract base class for all Text-to-Speech provider
implementations in the infrastructure layer. It implements common functionality
and defines the contract that all TTS providers must follow.

The base class handles:
- Common validation logic
- File management and cleanup
- Error handling and logging
- Audio format processing
- Provider lifecycle management

Example implementation:
    ```python
    from src.infrastructure.base.tts_provider_base import TTSProviderBase

    class MyTTSProvider(TTSProviderBase):
        def __init__(self):
            super().__init__("my_tts", ["en", "es"])

        def _generate_audio(self, request):
            # Implement TTS-specific logic
            audio_data = my_tts_engine.synthesize(request.text_content.text)
            return audio_data, 22050  # audio_bytes, sample_rate

        def _generate_audio_stream(self, request):
            # Abstract in the base class, so it must be defined; a provider
            # without native streaming can emit a single final chunk.
            audio_data, sample_rate = self._generate_audio(request)
            yield audio_data, sample_rate, True

        def is_available(self):
            return my_tts_engine.is_loaded()

        def get_available_voices(self):
            return ["voice1", "voice2"]
    ```
"""

import logging
import os
import time
import tempfile
from abc import ABC, abstractmethod
from typing import Iterator, Optional, TYPE_CHECKING
from pathlib import Path

if TYPE_CHECKING:
    # Needed only for annotations; the runtime imports happen lazily inside
    # the methods that construct these objects.
    from ...domain.models.speech_synthesis_request import SpeechSynthesisRequest
    from ...domain.models.audio_content import AudioContent
    from ...domain.models.audio_chunk import AudioChunk

# Required at runtime: the class inherits from the interface and raises the
# exception, so these cannot live under TYPE_CHECKING.
from ...domain.interfaces.speech_synthesis import ISpeechSynthesisService
from ...domain.exceptions import SpeechSynthesisException

logger = logging.getLogger(__name__)


class TTSProviderBase(ISpeechSynthesisService, ABC):
    """
    Abstract base class for TTS provider implementations.

    This class provides a foundation for implementing Text-to-Speech providers
    in the infrastructure layer. It handles common concerns like validation,
    file management, error handling, and audio processing while allowing
    concrete implementations to focus on provider-specific logic.

    Key features:
    - Automatic validation of synthesis requests
    - Temporary file management with cleanup
    - Consistent error handling and logging
    - Support for both batch and streaming synthesis
    - Audio format standardization
    - Provider availability checking

    Subclasses must implement (all four are abstract):
    - _generate_audio(): Core synthesis logic
    - _generate_audio_stream(): Streaming synthesis
    - is_available(): Provider availability check
    - get_available_voices(): Voice enumeration

    The base class ensures that all providers follow the same patterns for
    error handling, logging, and resource management, making the system more
    maintainable and predictable.
    """

    def __init__(self, provider_name: str, supported_languages: Optional[list[str]] = None):
        """
        Initialize the TTS provider.

        Sets up the provider with basic configuration and creates the
        directory used for temporary file storage. This constructor should be
        called by all subclass implementations.

        Args:
            provider_name: Unique identifier for this TTS provider
                (e.g., "kokoro", "dia"). Used in log messages, error messages
                and generated filenames.
            supported_languages: List of language codes supported by this
                provider (e.g., ["en", "zh", "es"]). If None or empty, no
                language validation is performed.

        Example:
            ```python
            class MyTTSProvider(TTSProviderBase):
                def __init__(self):
                    super().__init__(
                        provider_name="my_tts",
                        supported_languages=["en", "es", "fr"]
                    )
            ```
        """
        self.provider_name = provider_name
        self.supported_languages = supported_languages or []
        self._output_dir = self._ensure_output_directory()

    def synthesize(self, request: 'SpeechSynthesisRequest') -> 'AudioContent':
        """
        Synthesize speech from text.

        Args:
            request: The speech synthesis request

        Returns:
            AudioContent: The synthesized audio

        Raises:
            SpeechSynthesisException: If validation or synthesis fails. Any
                exception raised along the way is wrapped and chained.
        """
        try:
            logger.info(f"Starting synthesis with {self.provider_name} provider")
            self._validate_request(request)

            # Generate audio using provider-specific implementation
            audio_data, sample_rate = self._generate_audio(request)

            # Create AudioContent from the generated data
            from ...domain.models.audio_content import AudioContent

            audio_content = AudioContent(
                data=audio_data,
                format='wav',  # Most providers output WAV
                sample_rate=sample_rate,
                duration=self._calculate_duration(audio_data, sample_rate),
                filename=f"{self.provider_name}_{int(time.time())}.wav"
            )

            logger.info(f"Synthesis completed successfully with {self.provider_name}")
            return audio_content

        except Exception as e:
            logger.error(f"Synthesis failed with {self.provider_name}: {str(e)}")
            raise SpeechSynthesisException(f"TTS synthesis failed: {str(e)}") from e

    def synthesize_stream(self, request: 'SpeechSynthesisRequest') -> Iterator['AudioChunk']:
        """
        Synthesize speech from text as a stream.

        This is a generator: validation and synthesis only begin once the
        caller starts iterating, so errors surface during iteration rather
        than at call time.

        Args:
            request: The speech synthesis request

        Returns:
            Iterator[AudioChunk]: Stream of audio chunks, indexed from 0

        Raises:
            SpeechSynthesisException: If validation or synthesis fails. Any
                exception raised along the way is wrapped and chained.
        """
        try:
            logger.info(f"Starting streaming synthesis with {self.provider_name} provider")
            self._validate_request(request)

            # Imported once up front rather than on every chunk.
            from ...domain.models.audio_chunk import AudioChunk

            # Generate audio stream using provider-specific implementation
            chunk_index = 0
            for audio_data, sample_rate, is_final in self._generate_audio_stream(request):
                chunk = AudioChunk(
                    data=audio_data,
                    format='wav',
                    sample_rate=sample_rate,
                    chunk_index=chunk_index,
                    is_final=is_final,
                    timestamp=time.time()
                )
                yield chunk
                chunk_index += 1

            logger.info(f"Streaming synthesis completed with {self.provider_name}")

        except Exception as e:
            logger.error(f"Streaming synthesis failed with {self.provider_name}: {str(e)}")
            raise SpeechSynthesisException(f"TTS streaming synthesis failed: {str(e)}") from e

    @abstractmethod
    def _generate_audio(self, request: 'SpeechSynthesisRequest') -> tuple[bytes, int]:
        """
        Generate audio data from synthesis request.

        Args:
            request: The speech synthesis request

        Returns:
            tuple: (audio_data_bytes, sample_rate)
        """
        pass

    @abstractmethod
    def _generate_audio_stream(self, request: 'SpeechSynthesisRequest') -> Iterator[tuple[bytes, int, bool]]:
        """
        Generate audio data stream from synthesis request.

        Args:
            request: The speech synthesis request

        Returns:
            Iterator: (audio_data_bytes, sample_rate, is_final) tuples
        """
        pass

    @abstractmethod
    def is_available(self) -> bool:
        """
        Check if the TTS provider is available and ready to use.

        Returns:
            bool: True if provider is available, False otherwise
        """
        pass

    @abstractmethod
    def get_available_voices(self) -> list[str]:
        """
        Get list of available voices for this provider.

        Returns:
            list[str]: List of voice identifiers
        """
        pass

    def _validate_request(self, request: 'SpeechSynthesisRequest') -> None:
        """
        Validate the synthesis request.

        Checks, in order: non-blank text, language support (skipped when no
        supported languages are configured) and voice availability (skipped
        when the provider reports no voices).

        Args:
            request: The synthesis request to validate

        Raises:
            SpeechSynthesisException: If request is invalid
        """
        if not request.text_content.text.strip():
            raise SpeechSynthesisException("Text content cannot be empty")

        if self.supported_languages and request.text_content.language not in self.supported_languages:
            raise SpeechSynthesisException(
                f"Language {request.text_content.language} not supported by {self.provider_name}. "
                f"Supported languages: {self.supported_languages}"
            )

        available_voices = self.get_available_voices()
        if available_voices and request.voice_settings.voice_id not in available_voices:
            raise SpeechSynthesisException(
                f"Voice {request.voice_settings.voice_id} not available for {self.provider_name}. "
                f"Available voices: {available_voices}"
            )

    def _ensure_output_directory(self) -> Path:
        """
        Ensure output directory exists and return its path.

        The directory is "tts_output" under the system temporary directory.

        Returns:
            Path: Path to the output directory
        """
        output_dir = Path(tempfile.gettempdir()) / "tts_output"
        # parents=True keeps this working even if the temp root is missing.
        output_dir.mkdir(parents=True, exist_ok=True)
        return output_dir

    def _generate_output_path(self, prefix: Optional[str] = None, extension: str = "wav") -> Path:
        """
        Generate an output path for audio files.

        The filename embeds a millisecond timestamp, so two calls with the
        same prefix within the same millisecond yield the same path.

        Args:
            prefix: Optional prefix for the filename (defaults to the
                provider name)
            extension: File extension (default: wav)

        Returns:
            Path: File path inside the provider's output directory
        """
        prefix = prefix or self.provider_name
        timestamp = int(time.time() * 1000)
        filename = f"{prefix}_{timestamp}.{extension}"
        return self._output_dir / filename

    def _calculate_duration(self, audio_data: bytes, sample_rate: int, channels: int = 1, sample_width: int = 2) -> float:
        """
        Calculate audio duration from raw audio data.

        Every byte of ``audio_data`` is counted as sample data; no container
        header is parsed or skipped.

        Args:
            audio_data: Raw audio data in bytes
            sample_rate: Sample rate in Hz
            channels: Number of audio channels (default: 1)
            sample_width: Sample width in bytes (default: 2 for 16-bit)

        Returns:
            float: Duration in seconds, or 0.0 when the data is empty or the
                sample rate / frame size is not positive
        """
        if not audio_data or sample_rate <= 0:
            return 0.0

        bytes_per_sample = channels * sample_width
        if bytes_per_sample <= 0:
            # Guard against division by zero on degenerate channel/width values.
            return 0.0

        total_samples = len(audio_data) // bytes_per_sample
        return total_samples / sample_rate

    def _cleanup_temp_files(self, max_age_hours: int = 24) -> None:
        """
        Clean up old temporary files.

        Best-effort: failures are logged as warnings and never raised. A
        failure on one file does not prevent the remaining files from being
        processed.

        Args:
            max_age_hours: Maximum age of files to keep in hours
        """
        try:
            current_time = time.time()
            max_age_seconds = max_age_hours * 3600

            for file_path in self._output_dir.glob("*"):
                try:
                    if not file_path.is_file():
                        continue
                    file_age = current_time - file_path.stat().st_mtime
                    if file_age > max_age_seconds:
                        file_path.unlink()
                        logger.info(f"Cleaned up old temp file: {file_path}")
                except OSError as e:
                    # e.g. the file vanished or is locked; keep going.
                    logger.warning(f"Failed to cleanup temp file {file_path}: {str(e)}")

        except Exception as e:
            logger.warning(f"Failed to cleanup temp files: {str(e)}")

    def _handle_provider_error(self, error: Exception, context: str = "") -> None:
        """
        Handle provider-specific errors and convert to domain exceptions.

        Logs the error with its traceback and always raises; it never
        returns normally.

        Args:
            error: The original error
            context: Additional context about when the error occurred

        Raises:
            SpeechSynthesisException: Always, chained from ``error``
        """
        error_msg = f"{self.provider_name} error"
        if context:
            error_msg += f" during {context}"
        error_msg += f": {str(error)}"

        # `exc_info` is the keyword the stdlib logger accepts for attaching an
        # exception; an unknown keyword would raise TypeError here.
        logger.error(error_msg, exc_info=error)
        raise SpeechSynthesisException(error_msg) from error