"""Base class for STT provider implementations."""

import logging
import os
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, NoReturn, Optional

if TYPE_CHECKING:
    from ...domain.models.audio_content import AudioContent
    from ...domain.models.text_content import TextContent

from ...domain.interfaces.speech_recognition import ISpeechRecognitionService
from ...domain.exceptions import SpeechRecognitionException

logger = logging.getLogger(__name__)


class STTProviderBase(ISpeechRecognitionService, ABC):
    """Abstract base class for STT provider implementations.

    Handles the provider-independent part of a transcription: validating the
    audio, writing it to a temporary file, optionally converting it to
    16 kHz mono WAV with pydub, cleaning up temporary files and wrapping
    failures in ``SpeechRecognitionException``.

    Subclasses implement ``_perform_transcription``, ``is_available``,
    ``get_available_models`` and ``get_default_model``.
    """

    # Formats passed to pydub with an explicit format hint; any other format
    # is loaded through pydub's auto-detection.
    _PYDUB_FORMAT_HINTS = frozenset({'mp3', 'wav', 'flac', 'ogg'})

    # Common English function words used by the placeholder language heuristic.
    _ENGLISH_INDICATORS = frozenset(
        {'the', 'and', 'is', 'in', 'to', 'of', 'a', 'that', 'it', 'with'}
    )

    def __init__(self, provider_name: str, supported_languages: Optional[list[str]] = None):
        """
        Initialize the STT provider.

        Args:
            provider_name: Name of the STT provider (used in log and error messages)
            supported_languages: List of supported language codes; an empty
                list is used when omitted

        Side effects:
            Creates the shared ``stt_temp`` directory under the system temp
            directory if it does not exist yet.
        """
        self.provider_name = provider_name
        self.supported_languages = supported_languages or []
        self._temp_dir = self._ensure_temp_directory()

    def transcribe(self, audio: 'AudioContent', model: str) -> 'TextContent':
        """
        Transcribe audio content to text.

        Args:
            audio: The audio content to transcribe
            model: The STT model to use for transcription

        Returns:
            TextContent: The transcribed text, tagged with the detected
            language (``'en'`` when detection yields nothing) and ``utf-8``
            encoding

        Raises:
            SpeechRecognitionException: If validation, preprocessing or
                transcription fails. Domain exceptions raised by the helpers
                are propagated unchanged; any other error is wrapped.
        """
        try:
            logger.info(
                "Starting transcription with %s provider using model %s",
                self.provider_name, model,
            )
            self._validate_audio(audio)

            # Write (and possibly convert) the audio to a temporary file
            processed_audio_path = self._preprocess_audio(audio)

            try:
                transcribed_text = self._perform_transcription(processed_audio_path, model)

                # Imported lazily: at module level it is only needed for typing.
                from ...domain.models.text_content import TextContent

                detected_language = self._detect_language(transcribed_text) or 'en'

                text_content = TextContent(
                    text=transcribed_text,
                    language=detected_language,
                    encoding='utf-8'
                )

                logger.info("Transcription completed successfully with %s", self.provider_name)
                return text_content

            finally:
                # Always remove the temp audio file, even if transcription failed
                self._cleanup_temp_file(processed_audio_path)

        except SpeechRecognitionException as e:
            # Already a domain error with a meaningful message: do not re-wrap it.
            logger.error("Transcription failed with %s: %s", self.provider_name, e)
            raise
        except Exception as e:
            logger.error("Transcription failed with %s: %s", self.provider_name, e)
            raise SpeechRecognitionException(f"STT transcription failed: {str(e)}") from e

    @abstractmethod
    def _perform_transcription(self, audio_path: Path, model: str) -> str:
        """
        Perform the actual transcription using provider-specific implementation.

        Args:
            audio_path: Path to the preprocessed audio file
            model: The STT model to use

        Returns:
            str: The transcribed text
        """
        pass

    @abstractmethod
    def is_available(self) -> bool:
        """
        Check if the STT provider is available and ready to use.

        Returns:
            bool: True if provider is available, False otherwise
        """
        pass

    @abstractmethod
    def get_available_models(self) -> list[str]:
        """
        Get list of available models for this provider.

        Returns:
            list[str]: List of model identifiers
        """
        pass

    @abstractmethod
    def get_default_model(self) -> str:
        """
        Get the default model for this provider.

        Returns:
            str: Default model name
        """
        pass

    def _preprocess_audio(self, audio: 'AudioContent') -> Path:
        """
        Write the audio to a temporary file and convert it for transcription.

        Args:
            audio: The audio content to preprocess

        Returns:
            Path: Path to the file to transcribe. The caller owns it and is
            responsible for deleting it. Any intermediate raw copy is removed
            here.

        Raises:
            SpeechRecognitionException: If the audio cannot be written
        """
        raw_path: Optional[Path] = None
        try:
            # mkstemp creates a unique file exclusively, so concurrent or
            # back-to-back calls never collide (unlike names derived from
            # id(audio), which can be reused after garbage collection).
            fd, raw_name = tempfile.mkstemp(prefix="audio_", suffix=".wav", dir=self._temp_dir)
            raw_path = Path(raw_name)
            with os.fdopen(fd, 'wb') as f:
                f.write(audio.data)

            processed_file = self._convert_audio_format(raw_path, audio)
            if processed_file != raw_path:
                # Conversion produced a new file. The raw copy is no longer
                # needed, and the caller only cleans up the returned path.
                self._cleanup_temp_file(raw_path)

            logger.info("Audio preprocessed and saved to: %s", processed_file)
            return processed_file

        except Exception as e:
            logger.error("Audio preprocessing failed: %s", e)
            if raw_path is not None:
                self._cleanup_temp_file(raw_path)
            raise SpeechRecognitionException(f"Audio preprocessing failed: {str(e)}") from e

    def _convert_audio_format(self, audio_path: Path, audio: 'AudioContent') -> Path:
        """
        Convert audio to 16 kHz mono WAV for transcription (best effort).

        Args:
            audio_path: Path to the original audio file
            audio: The audio content metadata (its ``format`` selects the loader)

        Returns:
            Path: Path to the converted WAV file, or ``audio_path`` unchanged
            when pydub is not installed or conversion fails
        """
        output_path: Optional[Path] = None
        try:
            # Optional dependency: imported lazily so the provider still works without it
            from pydub import AudioSegment

            audio_format = audio.format.lower()
            if audio_format in self._PYDUB_FORMAT_HINTS:
                # Equivalent to pydub's from_mp3 / from_wav / from_ogg helpers
                audio_segment = AudioSegment.from_file(audio_path, format=audio_format)
            else:
                # Let pydub/ffmpeg detect the container format
                audio_segment = AudioSegment.from_file(audio_path)

            # Standard STT input: 16 kHz, mono, WAV
            standardized_audio = audio_segment.set_frame_rate(16000).set_channels(1)

            output_path = audio_path.with_suffix('.wav')
            if output_path == audio_path:
                # Never overwrite the source file while converting it
                output_path = audio_path.with_name(f"converted_{audio_path.name}")

            standardized_audio.export(output_path, format="wav")

            logger.info("Audio converted from %s to WAV: %s", audio.format, output_path)
            return output_path

        except ImportError:
            logger.warning("pydub not available, using original audio file")
            return audio_path
        except Exception as e:
            logger.warning("Audio conversion failed, using original file: %s", e)
            # Remove a partially exported output so it does not linger in the temp dir
            if output_path is not None and output_path != audio_path:
                self._cleanup_temp_file(output_path)
            return audio_path

    def _validate_audio(self, audio: 'AudioContent') -> None:
        """
        Validate the audio content for transcription.

        Args:
            audio: The audio content to validate; ``duration`` is compared
                against limits of 0.1 and 3600 (seconds, per the error messages)

        Raises:
            SpeechRecognitionException: If the data is empty, the duration is
                out of range, or the format is not valid
        """
        if not audio.data:
            raise SpeechRecognitionException("Audio data cannot be empty")

        if audio.duration > 3600:  # 1 hour limit
            raise SpeechRecognitionException("Audio duration exceeds maximum limit of 1 hour")

        if audio.duration < 0.1:  # Minimum 100ms
            raise SpeechRecognitionException("Audio duration too short (minimum 100ms)")

        if not audio.is_valid_format:
            raise SpeechRecognitionException(f"Unsupported audio format: {audio.format}")

    def _detect_language(self, text: str) -> Optional[str]:
        """
        Detect the language of transcribed text.

        This is a placeholder heuristic. Any text that can be tokenized
        currently yields ``'en'``, both when English function words are found
        and as the fallback. A real detector (e.g. langdetect) would replace
        the fallback.

        Args:
            text: The transcribed text

        Returns:
            Optional[str]: ``'en'``, or None if the text could not be processed
        """
        try:
            # Match whole words: substring checks ('a' in text) match almost anything
            words = set(text.lower().split())
            english_hits = len(words & self._ENGLISH_INDICATORS)
        except Exception as e:
            logger.warning("Language detection failed: %s", e)
            return None

        if english_hits >= 2:
            return 'en'

        # Default to English if uncertain
        return 'en'

    def _ensure_temp_directory(self) -> Path:
        """
        Ensure the shared temporary directory exists and return its path.

        Returns:
            Path: ``<system temp dir>/stt_temp``
        """
        temp_dir = Path(tempfile.gettempdir()) / "stt_temp"
        temp_dir.mkdir(parents=True, exist_ok=True)
        return temp_dir

    def _cleanup_temp_file(self, file_path: Path) -> None:
        """
        Delete a temporary file (best effort; failures are logged, not raised).

        Args:
            file_path: Path to the file to clean up; a missing file is ignored
        """
        try:
            file_path.unlink()
        except FileNotFoundError:
            return
        except Exception as e:
            logger.warning("Failed to cleanup temp file %s: %s", file_path, e)
            return
        logger.info("Cleaned up temp file: %s", file_path)

    def _cleanup_old_temp_files(self, max_age_hours: int = 24) -> None:
        """
        Delete files in the temp directory older than ``max_age_hours``.

        Best effort. A failure on one file is logged and the sweep continues
        with the remaining files.

        Args:
            max_age_hours: Maximum age of files to keep in hours (by mtime)
        """
        import time

        try:
            cutoff = time.time() - max_age_hours * 3600
            candidates = [p for p in self._temp_dir.glob("*") if p.is_file()]
        except Exception as e:
            logger.warning("Failed to cleanup old temp files: %s", e)
            return

        for file_path in candidates:
            try:
                if file_path.stat().st_mtime < cutoff:
                    file_path.unlink()
                    logger.info("Cleaned up old temp file: %s", file_path)
            except Exception as e:
                # e.g. the file was removed concurrently; keep sweeping the rest
                logger.warning("Failed to cleanup temp file %s: %s", file_path, e)

    def _handle_provider_error(self, error: Exception, context: str = "") -> NoReturn:
        """
        Log a provider-specific error and re-raise it as a domain exception.

        Args:
            error: The original error
            context: Additional context about when the error occurred

        Raises:
            SpeechRecognitionException: Always, chained from ``error``
        """
        error_msg = f"{self.provider_name} error"
        if context:
            error_msg += f" during {context}"
        error_msg += f": {str(error)}"

        # `exc_info` attaches the original traceback. Logger.error has no
        # `exception=` keyword; passing one raised TypeError before the domain
        # exception below could be raised.
        logger.error(error_msg, exc_info=error)
        raise SpeechRecognitionException(error_msg) from error