Spaces:
Build error
Build error
| """Base class for STT provider implementations.""" | |
| import logging | |
| import os | |
| import tempfile | |
| from abc import ABC, abstractmethod | |
| from pathlib import Path | |
| from typing import Optional, TYPE_CHECKING | |
| if TYPE_CHECKING: | |
| from ...domain.models.audio_content import AudioContent | |
| from ...domain.models.text_content import TextContent | |
| from ...domain.interfaces.speech_recognition import ISpeechRecognitionService | |
| from ...domain.exceptions import SpeechRecognitionException | |
| logger = logging.getLogger(__name__) | |
class STTProviderBase(ISpeechRecognitionService, ABC):
    """Abstract base class for STT provider implementations.

    Implements the shared transcription workflow (validate -> write the audio
    to a temporary file -> optionally convert it to 16 kHz mono WAV ->
    delegate to the provider -> wrap the result in ``TextContent``) and the
    temporary-file housekeeping around it.

    Concrete providers must implement ``_perform_transcription``,
    ``is_available``, ``get_available_models`` and ``get_default_model``.
    """

    def __init__(self, provider_name: str, supported_languages: Optional[list[str]] = None):
        """
        Initialize the STT provider.

        Args:
            provider_name: Name of the STT provider
            supported_languages: List of supported language codes; ``None``
                (or an empty list) is stored as an empty list
        """
        self.provider_name = provider_name
        self.supported_languages = supported_languages or []
        self._temp_dir = self._ensure_temp_directory()

    def transcribe(self, audio: 'AudioContent', model: str) -> 'TextContent':
        """
        Transcribe audio content to text.

        Args:
            audio: The audio content to transcribe
            model: The STT model to use for transcription

        Returns:
            TextContent: The transcribed text

        Raises:
            SpeechRecognitionException: If validation, preprocessing or
                transcription fails (the original error is chained as the cause)
        """
        try:
            logger.info(f"Starting transcription with {self.provider_name} provider using model {model}")
            self._validate_audio(audio)

            # Preprocess audio if needed
            processed_audio_path = self._preprocess_audio(audio)
            try:
                # Perform transcription using provider-specific implementation
                transcribed_text = self._perform_transcription(processed_audio_path, model)

                # Imported lazily: at module level this name is only imported
                # under TYPE_CHECKING.
                from ...domain.models.text_content import TextContent

                # Detect language if not specified (default to English)
                detected_language = self._detect_language(transcribed_text) or 'en'
                text_content = TextContent(
                    text=transcribed_text,
                    language=detected_language,
                    encoding='utf-8'
                )
                logger.info(f"Transcription completed successfully with {self.provider_name}")
                return text_content
            finally:
                # Clean up temporary audio file, whether or not transcription succeeded
                self._cleanup_temp_file(processed_audio_path)
        except Exception as e:
            logger.error(f"Transcription failed with {self.provider_name}: {str(e)}")
            raise SpeechRecognitionException(f"STT transcription failed: {str(e)}") from e

    @abstractmethod
    def _perform_transcription(self, audio_path: Path, model: str) -> str:
        """
        Perform the actual transcription using provider-specific implementation.

        Args:
            audio_path: Path to the preprocessed audio file
            model: The STT model to use

        Returns:
            str: The transcribed text
        """

    @abstractmethod
    def is_available(self) -> bool:
        """
        Check if the STT provider is available and ready to use.

        Returns:
            bool: True if provider is available, False otherwise
        """

    @abstractmethod
    def get_available_models(self) -> list[str]:
        """
        Get list of available models for this provider.

        Returns:
            list[str]: List of model identifiers
        """

    @abstractmethod
    def get_default_model(self) -> str:
        """
        Get the default model for this provider.

        Returns:
            str: Default model name
        """

    def _preprocess_audio(self, audio: 'AudioContent') -> Path:
        """
        Preprocess audio content for transcription.

        Writes ``audio.data`` to a uniquely named file in the provider's
        temporary directory and passes it through ``_convert_audio_format``.
        If conversion produced a different file, the raw file is deleted here
        so that only the returned path is left for the caller to clean up.

        Args:
            audio: The audio content to preprocess

        Returns:
            Path: Path to the preprocessed audio file

        Raises:
            SpeechRecognitionException: If the audio cannot be written or processed
        """
        raw_file: Optional[Path] = None
        try:
            # The directory may have been removed since __init__ ran.
            self._temp_dir.mkdir(parents=True, exist_ok=True)

            # mkstemp creates the file exclusively under a unique name, so
            # concurrent requests or processes sharing the directory cannot
            # overwrite each other's audio (id()-based names can repeat).
            fd, raw_name = tempfile.mkstemp(prefix="audio_", suffix=".wav", dir=self._temp_dir)
            raw_file = Path(raw_name)

            # Write audio data to temporary file
            with os.fdopen(fd, 'wb') as f:
                f.write(audio.data)

            # Convert to required format if needed
            processed_file = self._convert_audio_format(raw_file, audio)

            # The converter returns the raw path unchanged when it cannot
            # convert; only drop the raw file if a separate output exists.
            if processed_file != raw_file:
                self._cleanup_temp_file(raw_file)

            logger.debug(f"Audio preprocessed and saved to: {processed_file}")
            return processed_file
        except Exception as e:
            # Do not leave the raw file behind when preprocessing fails.
            if raw_file is not None:
                self._cleanup_temp_file(raw_file)
            logger.error(f"Audio preprocessing failed: {str(e)}")
            raise SpeechRecognitionException(f"Audio preprocessing failed: {str(e)}") from e

    def _convert_audio_format(self, audio_path: Path, audio: 'AudioContent') -> Path:
        """
        Convert audio to the required format for transcription.

        Conversion is best-effort: if pydub is not installed or the
        conversion fails, a warning is logged and the original path is
        returned unchanged.

        Args:
            audio_path: Path to the original audio file
            audio: The audio content metadata

        Returns:
            Path: Path to the converted audio file, or ``audio_path`` itself
                when no conversion took place
        """
        try:
            # Import audio processing library (optional dependency)
            from pydub import AudioSegment
        except ImportError:
            logger.warning("pydub not available, using original audio file")
            return audio_path

        # Create output path; never write over the input file
        output_path = audio_path.with_suffix('.wav')
        if output_path == audio_path:
            output_path = audio_path.with_name(f"converted_{audio_path.name}")

        try:
            source_format = audio.format.lower()
            if source_format in ('mp3', 'wav', 'flac', 'ogg'):
                # from_mp3 / from_wav / from_ogg are shorthands for
                # from_file(..., format=...)
                audio_segment = AudioSegment.from_file(audio_path, format=source_format)
            else:
                # Try to load as generic audio file
                audio_segment = AudioSegment.from_file(audio_path)

            # Convert to standard format for STT (16kHz, mono, WAV)
            standardized_audio = audio_segment.set_frame_rate(16000).set_channels(1)

            # Export converted audio
            standardized_audio.export(output_path, format="wav")
            logger.debug(f"Audio converted from {audio.format} to WAV: {output_path}")
            return output_path
        except Exception as e:
            logger.warning(f"Audio conversion failed, using original file: {str(e)}")
            # A failed export can leave a partial output file behind.
            self._cleanup_temp_file(output_path)
            return audio_path

    def _validate_audio(self, audio: 'AudioContent') -> None:
        """
        Validate the audio content for transcription.

        Args:
            audio: The audio content to validate

        Raises:
            SpeechRecognitionException: If audio is empty, longer than one
                hour, shorter than 100 ms, or in an unsupported format
        """
        if not audio.data:
            raise SpeechRecognitionException("Audio data cannot be empty")
        if audio.duration > 3600:  # 1 hour limit
            raise SpeechRecognitionException("Audio duration exceeds maximum limit of 1 hour")
        if audio.duration < 0.1:  # Minimum 100ms
            raise SpeechRecognitionException("Audio duration too short (minimum 100ms)")
        if not audio.is_valid_format:
            raise SpeechRecognitionException(f"Unsupported audio format: {audio.format}")

    def _detect_language(self, text: str) -> Optional[str]:
        """
        Detect the language of transcribed text.

        This is a placeholder heuristic: it looks for common English words
        and reports English either way, so the result is currently always
        ``'en'`` unless the check itself raises.

        Args:
            text: The transcribed text

        Returns:
            Optional[str]: Detected language code or None if detection fails
        """
        try:
            import re

            # Simple heuristic-based language detection
            # This is a basic implementation - in production, you might use langdetect or similar
            english_indicators = {'the', 'and', 'is', 'in', 'to', 'of', 'a', 'that', 'it', 'with'}

            # Compare whole words: substring matching would count "a" or
            # "is" inside almost any text.
            words = set(re.findall(r"[a-z']+", text.lower()))
            if len(english_indicators & words) < 2:
                logger.debug("Language detection inconclusive, defaulting to 'en'")

            # Default to English if uncertain
            return 'en'
        except Exception as e:
            logger.warning(f"Language detection failed: {str(e)}")
            return None

    def _ensure_temp_directory(self) -> Path:
        """
        Ensure temporary directory exists and return its path.

        Returns:
            Path: Path to the ``stt_temp`` directory under the system temp directory
        """
        temp_dir = Path(tempfile.gettempdir()) / "stt_temp"
        temp_dir.mkdir(parents=True, exist_ok=True)
        return temp_dir

    def _cleanup_temp_file(self, file_path: Path) -> None:
        """
        Clean up a temporary file.

        Never raises: a missing file is ignored and any other failure is
        logged as a warning.

        Args:
            file_path: Path to the file to clean up
        """
        try:
            file_path.unlink()
        except FileNotFoundError:
            # Already gone; attempting the unlink directly avoids the
            # exists()/unlink() race.
            return
        except Exception as e:
            logger.warning(f"Failed to cleanup temp file {file_path}: {str(e)}")
        else:
            logger.debug(f"Cleaned up temp file: {file_path}")

    def _cleanup_old_temp_files(self, max_age_hours: int = 24) -> None:
        """
        Clean up old temporary files.

        Files are handled one at a time so that a single file that vanished
        or cannot be removed does not stop the rest of the sweep.

        Args:
            max_age_hours: Maximum age of files to keep in hours
        """
        import time

        try:
            # Files last modified before this timestamp are considered stale.
            cutoff = time.time() - max_age_hours * 3600
            candidates = list(self._temp_dir.glob("*"))
        except Exception as e:
            logger.warning(f"Failed to cleanup old temp files: {str(e)}")
            return

        for file_path in candidates:
            try:
                if file_path.is_file() and file_path.stat().st_mtime < cutoff:
                    file_path.unlink()
                    logger.debug(f"Cleaned up old temp file: {file_path}")
            except FileNotFoundError:
                # Removed by someone else between listing and deletion.
                continue
            except OSError as e:
                logger.warning(f"Failed to cleanup old temp file {file_path}: {str(e)}")

    def _handle_provider_error(self, error: Exception, context: str = "") -> None:
        """
        Handle provider-specific errors and convert to domain exceptions.

        Logs the error with its traceback and always raises; it never
        returns normally.

        Args:
            error: The original error
            context: Additional context about when the error occurred

        Raises:
            SpeechRecognitionException: Always, chained from ``error``
        """
        error_msg = f"{self.provider_name} error"
        if context:
            error_msg += f" during {context}"
        error_msg += f": {str(error)}"
        logger.error(error_msg, exc_info=True)
        raise SpeechRecognitionException(error_msg) from error