Spaces:
Build error
Build error
"""Base class for STT provider implementations.""" | |
import logging | |
import os | |
import tempfile | |
from abc import ABC, abstractmethod | |
from pathlib import Path | |
from typing import Optional, TYPE_CHECKING | |
if TYPE_CHECKING: | |
from ...domain.models.audio_content import AudioContent | |
from ...domain.models.text_content import TextContent | |
from ...domain.interfaces.speech_recognition import ISpeechRecognitionService | |
from ...domain.exceptions import SpeechRecognitionException | |
logger = logging.getLogger(__name__) | |
class STTProviderBase(ISpeechRecognitionService, ABC):
    """Abstract base class for STT provider implementations.

    The class implements the provider-independent part of a transcription
    run (validation, writing the audio to a temporary file, optional
    conversion to 16 kHz mono WAV via pydub, cleanup and error wrapping).
    Concrete providers supply ``_perform_transcription``, ``is_available``,
    ``get_available_models`` and ``get_default_model``.
    """

    # Format handed to the provider after conversion (16 kHz, mono, WAV).
    TARGET_SAMPLE_RATE = 16000
    TARGET_CHANNELS = 1

    # Audio formats that are passed explicitly to pydub when loading.
    _EXPLICIT_FORMATS = frozenset({'mp3', 'wav', 'flac', 'ogg'})

    # Common English words used by the language heuristic.
    _ENGLISH_INDICATORS = frozenset(
        {'the', 'and', 'is', 'in', 'to', 'of', 'a', 'that', 'it', 'with'}
    )

    def __init__(self, provider_name: str, supported_languages: Optional[list[str]] = None):
        """
        Initialize the STT provider.

        Args:
            provider_name: Name of the STT provider
            supported_languages: List of supported language codes; ``None``
                (or an empty list) results in an empty list
        """
        self.provider_name = provider_name
        # Copy so later changes to the caller's list do not leak into the provider.
        self.supported_languages = list(supported_languages) if supported_languages else []
        self._temp_dir = self._ensure_temp_directory()

    def transcribe(self, audio: 'AudioContent', model: str) -> 'TextContent':
        """
        Transcribe audio content to text.

        Args:
            audio: The audio content to transcribe
            model: The STT model to use for transcription

        Returns:
            TextContent: The transcribed text

        Raises:
            SpeechRecognitionException: If validation, preprocessing or
                transcription fails. Every failure, including domain
                exceptions raised by the helper steps, is wrapped with the
                "STT transcription failed" prefix.
        """
        try:
            logger.info(
                "Starting transcription with %s provider using model %s",
                self.provider_name,
                model,
            )
            self._validate_audio(audio)

            # Preprocess audio if needed
            processed_audio_path = self._preprocess_audio(audio)
            try:
                # Perform transcription using provider-specific implementation
                transcribed_text = self._perform_transcription(processed_audio_path, model)

                # Imported lazily; at module level the name is only available
                # to type checkers.
                from ...domain.models.text_content import TextContent

                # Detect language if not specified (default to English)
                detected_language = self._detect_language(transcribed_text) or 'en'
                text_content = TextContent(
                    text=transcribed_text,
                    language=detected_language,
                    encoding='utf-8'
                )
                logger.info("Transcription completed successfully with %s", self.provider_name)
                return text_content
            finally:
                # Clean up temporary audio file
                self._cleanup_temp_file(processed_audio_path)
        except Exception as e:
            logger.error("Transcription failed with %s: %s", self.provider_name, e)
            raise SpeechRecognitionException(f"STT transcription failed: {str(e)}") from e

    @abstractmethod
    def _perform_transcription(self, audio_path: Path, model: str) -> str:
        """
        Perform the actual transcription using provider-specific implementation.

        Args:
            audio_path: Path to the preprocessed audio file
            model: The STT model to use

        Returns:
            str: The transcribed text
        """

    @abstractmethod
    def is_available(self) -> bool:
        """
        Check if the STT provider is available and ready to use.

        Returns:
            bool: True if provider is available, False otherwise
        """

    @abstractmethod
    def get_available_models(self) -> list[str]:
        """
        Get list of available models for this provider.

        Returns:
            list[str]: List of model identifiers
        """

    @abstractmethod
    def get_default_model(self) -> str:
        """
        Get the default model for this provider.

        Returns:
            str: Default model name
        """

    def _preprocess_audio(self, audio: 'AudioContent') -> Path:
        """
        Preprocess audio content for transcription.

        The raw bytes are written to a uniquely named file in the provider's
        temp directory and then passed to ``_convert_audio_format``. When the
        conversion produces a new file, the raw file is deleted so that only
        the returned path is left for the caller to clean up.

        Args:
            audio: The audio content to preprocess

        Returns:
            Path: Path to the preprocessed audio file

        Raises:
            SpeechRecognitionException: If the audio cannot be written or
                preprocessed
        """
        raw_file: Optional[Path] = None
        try:
            # mkstemp guarantees a name that is unique across threads and
            # processes sharing the temp directory.
            fd, raw_name = tempfile.mkstemp(prefix="audio_", suffix=".wav", dir=self._temp_dir)
            raw_file = Path(raw_name)

            # Write audio data to temporary file
            with os.fdopen(fd, 'wb') as f:
                f.write(audio.data)

            # Convert to required format if needed
            processed_file = self._convert_audio_format(raw_file, audio)
        except Exception as e:
            if raw_file is not None:
                self._cleanup_temp_file(raw_file)
            logger.error("Audio preprocessing failed: %s", e)
            raise SpeechRecognitionException(f"Audio preprocessing failed: {str(e)}") from e

        # The caller only deletes the returned path, so drop the raw file
        # here once a separate converted file exists.
        if processed_file != raw_file:
            self._cleanup_temp_file(raw_file)

        logger.info("Audio preprocessed and saved to: %s", processed_file)
        return processed_file

    def _convert_audio_format(self, audio_path: Path, audio: 'AudioContent') -> Path:
        """
        Convert audio to the required format for transcription.

        Conversion is best-effort: if pydub is not installed or the
        conversion fails, the original file path is returned unchanged.

        Args:
            audio_path: Path to the original audio file
            audio: The audio content metadata

        Returns:
            Path: Path to the converted audio file, or ``audio_path`` when no
                conversion took place
        """
        output_path: Optional[Path] = None
        try:
            # Import audio processing library
            from pydub import AudioSegment

            # Load audio file; known formats are passed explicitly, anything
            # else is left to pydub's generic loader.
            source_format = audio.format.lower()
            if source_format in self._EXPLICIT_FORMATS:
                audio_segment = AudioSegment.from_file(audio_path, format=source_format)
            else:
                audio_segment = AudioSegment.from_file(audio_path)

            # Convert to standard format for STT (16kHz, mono, WAV)
            standardized_audio = (
                audio_segment
                .set_frame_rate(self.TARGET_SAMPLE_RATE)
                .set_channels(self.TARGET_CHANNELS)
            )

            # Create output path; never overwrite the input file in place.
            output_path = audio_path.with_suffix('.wav')
            if output_path == audio_path:
                output_path = audio_path.with_name(f"converted_{audio_path.name}")

            # Export converted audio
            standardized_audio.export(output_path, format="wav")
            logger.info("Audio converted from %s to WAV: %s", audio.format, output_path)
            return output_path
        except ImportError:
            logger.warning("pydub not available, using original audio file")
            return audio_path
        except Exception as e:
            logger.warning("Audio conversion failed, using original file: %s", e)
            # Do not leave a half-written export behind.
            if output_path is not None and output_path != audio_path:
                self._cleanup_temp_file(output_path)
            return audio_path

    def _validate_audio(self, audio: 'AudioContent') -> None:
        """
        Validate the audio content for transcription.

        Args:
            audio: The audio content to validate

        Raises:
            SpeechRecognitionException: If the audio is empty, longer than one
                hour, shorter than 100 ms, or in an unsupported format
        """
        if not audio.data:
            raise SpeechRecognitionException("Audio data cannot be empty")

        duration = audio.duration
        if duration > 3600:  # 1 hour limit
            raise SpeechRecognitionException("Audio duration exceeds maximum limit of 1 hour")
        if duration < 0.1:  # Minimum 100ms
            raise SpeechRecognitionException("Audio duration too short (minimum 100ms)")

        if not audio.is_valid_format:
            raise SpeechRecognitionException(f"Unsupported audio format: {audio.format}")

    def _detect_language(self, text: str) -> Optional[str]:
        """
        Detect the language of transcribed text.

        This is a basic heuristic that only knows about English: it looks for
        common English words, and also falls back to English when it is
        uncertain. In production a dedicated library (langdetect or similar)
        would be a better fit.

        Args:
            text: The transcribed text

        Returns:
            Optional[str]: ``'en'``, or None if the text cannot be inspected
        """
        try:
            # Compare whole words so that e.g. "a" does not match every text
            # that merely contains the letter.
            words = set(text.lower().split())
            english_count = len(words & self._ENGLISH_INDICATORS)
        except Exception as e:
            logger.warning("Language detection failed: %s", e)
            return None

        if english_count >= 2:
            return 'en'
        # Default to English if uncertain
        return 'en'

    def _ensure_temp_directory(self) -> Path:
        """
        Ensure temporary directory exists and return its path.

        Returns:
            Path: Path to the ``stt_temp`` directory inside the system temp
                directory
        """
        temp_dir = Path(tempfile.gettempdir()) / "stt_temp"
        temp_dir.mkdir(parents=True, exist_ok=True)
        return temp_dir

    def _cleanup_temp_file(self, file_path: Path) -> None:
        """
        Clean up a temporary file.

        Never raises: the method is called from ``finally`` blocks and error
        handlers, so failures are only logged.

        Args:
            file_path: Path to the file to clean up
        """
        try:
            file_path.unlink()
        except FileNotFoundError:
            # Nothing to remove.
            return
        except Exception as e:
            logger.warning("Failed to cleanup temp file %s: %s", file_path, e)
            return
        logger.info("Cleaned up temp file: %s", file_path)

    def _cleanup_old_temp_files(self, max_age_hours: int = 24) -> None:
        """
        Clean up old temporary files.

        Errors are logged and never raised; a file that cannot be removed
        does not stop the remaining files from being processed.

        Args:
            max_age_hours: Maximum age of files to keep in hours
        """
        try:
            import time

            # Files last modified before this timestamp are considered stale.
            cutoff = time.time() - max_age_hours * 3600
            candidates = list(self._temp_dir.glob("*"))
        except Exception as e:
            logger.warning("Failed to cleanup old temp files: %s", e)
            return

        for file_path in candidates:
            try:
                if file_path.is_file() and file_path.stat().st_mtime < cutoff:
                    file_path.unlink()
                    logger.info("Cleaned up old temp file: %s", file_path)
            except OSError as e:
                logger.warning("Failed to cleanup old temp files: %s", e)

    def _handle_provider_error(self, error: Exception, context: str = "") -> None:
        """
        Handle provider-specific errors and convert to domain exceptions.

        Args:
            error: The original error
            context: Additional context about when the error occurred

        Raises:
            SpeechRecognitionException: Always; chained to ``error``
        """
        error_msg = f"{self.provider_name} error"
        if context:
            error_msg += f" during {context}"
        error_msg += f": {str(error)}"

        # exc_info attaches the original traceback to the log record.
        logger.error(error_msg, exc_info=error)
        raise SpeechRecognitionException(error_msg) from error