# Michael Hu
# add more logs
# fdc056d
"""Base class for STT provider implementations."""
import logging
import os
import tempfile
import time
from abc import ABC, abstractmethod
from pathlib import Path
from typing import NoReturn, Optional, TYPE_CHECKING

from ...domain.exceptions import SpeechRecognitionException
from ...domain.interfaces.speech_recognition import ISpeechRecognitionService

if TYPE_CHECKING:
    from ...domain.models.audio_content import AudioContent
    from ...domain.models.text_content import TextContent
logger = logging.getLogger(__name__)
class STTProviderBase(ISpeechRecognitionService, ABC):
    """Abstract base class for STT provider implementations.

    Implements the shared transcription pipeline (validate -> write the audio
    to a temporary file -> optionally convert it -> transcribe -> clean up).
    Concrete providers supply ``_perform_transcription``, ``is_available``,
    ``get_available_models`` and ``get_default_model``.
    """

    def __init__(self, provider_name: str, supported_languages: Optional[list[str]] = None):
        """
        Initialize the STT provider.

        Args:
            provider_name: Name of the STT provider
            supported_languages: List of supported language codes; an empty
                list is used when omitted
        """
        self.provider_name = provider_name
        self.supported_languages = supported_languages or []
        self._temp_dir = self._ensure_temp_directory()

    def transcribe(self, audio: 'AudioContent', model: str) -> 'TextContent':
        """
        Transcribe audio content to text.

        Args:
            audio: The audio content to transcribe
            model: The STT model to use for transcription

        Returns:
            TextContent: The transcribed text

        Raises:
            SpeechRecognitionException: If validation, preprocessing or
                transcription fails (any underlying error is wrapped)
        """
        try:
            logger.info(f"Starting transcription with {self.provider_name} provider using model {model}")
            self._validate_audio(audio)

            # Preprocess audio if needed
            processed_audio_path = self._preprocess_audio(audio)
            try:
                # Perform transcription using provider-specific implementation
                transcribed_text = self._perform_transcription(processed_audio_path, model)

                # Imported here rather than at module level; at module level
                # the name is only imported under TYPE_CHECKING.
                from ...domain.models.text_content import TextContent

                # Detect language if not specified (default to English)
                detected_language = self._detect_language(transcribed_text) or 'en'
                text_content = TextContent(
                    text=transcribed_text,
                    language=detected_language,
                    encoding='utf-8'
                )
                logger.info(f"Transcription completed successfully with {self.provider_name}")
                return text_content
            finally:
                # Clean up temporary audio file, whether or not transcription succeeded
                self._cleanup_temp_file(processed_audio_path)
        except Exception as e:
            logger.error(f"Transcription failed with {self.provider_name}: {str(e)}")
            raise SpeechRecognitionException(f"STT transcription failed: {str(e)}") from e

    @abstractmethod
    def _perform_transcription(self, audio_path: Path, model: str) -> str:
        """
        Perform the actual transcription using provider-specific implementation.

        Args:
            audio_path: Path to the preprocessed audio file
            model: The STT model to use

        Returns:
            str: The transcribed text
        """

    @abstractmethod
    def is_available(self) -> bool:
        """
        Check if the STT provider is available and ready to use.

        Returns:
            bool: True if provider is available, False otherwise
        """

    @abstractmethod
    def get_available_models(self) -> list[str]:
        """
        Get list of available models for this provider.

        Returns:
            list[str]: List of model identifiers
        """

    @abstractmethod
    def get_default_model(self) -> str:
        """
        Get the default model for this provider.

        Returns:
            str: Default model name
        """

    def _preprocess_audio(self, audio: 'AudioContent') -> Path:
        """
        Write the audio bytes to a temporary file and convert it for transcription.

        The raw temporary file is removed once a separate converted file has
        been produced, and also when preprocessing fails, so only the returned
        file is left for the caller to clean up.

        Args:
            audio: The audio content to preprocess

        Returns:
            Path: Path to the preprocessed audio file

        Raises:
            SpeechRecognitionException: If the audio cannot be written or processed
        """
        temp_file: Optional[Path] = None
        try:
            # mkstemp guarantees a unique name; id(audio) is only unique among
            # live objects of one process, while the temp directory is shared.
            fd, raw_name = tempfile.mkstemp(prefix="audio_", suffix=".wav", dir=self._temp_dir)
            temp_file = Path(raw_name)

            # Write audio data to temporary file
            with os.fdopen(fd, 'wb') as f:
                f.write(audio.data)

            # Convert to required format if needed
            processed_file = self._convert_audio_format(temp_file, audio)

            # Conversion writes to a different path; drop the raw copy so it
            # does not accumulate in the temp directory.
            if Path(processed_file) != temp_file:
                self._cleanup_temp_file(temp_file)

            logger.info(f"Audio preprocessed and saved to: {processed_file}")
            return processed_file
        except Exception as e:
            if temp_file is not None:
                self._cleanup_temp_file(temp_file)
            logger.error(f"Audio preprocessing failed: {str(e)}")
            raise SpeechRecognitionException(f"Audio preprocessing failed: {str(e)}") from e

    def _convert_audio_format(self, audio_path: Path, audio: 'AudioContent') -> Path:
        """
        Convert audio to the required format for transcription.

        Best effort: when pydub cannot be imported or the conversion fails,
        the original path is returned unchanged and a warning is logged.

        Args:
            audio_path: Path to the original audio file
            audio: The audio content metadata

        Returns:
            Path: Path to the converted audio file, or ``audio_path`` when no
            conversion took place
        """
        output_path: Optional[Path] = None
        try:
            # Import audio processing library (optional dependency)
            from pydub import AudioSegment

            # Load audio file using the loader matching the declared format
            source_format = audio.format.lower()
            if source_format == 'mp3':
                audio_segment = AudioSegment.from_mp3(audio_path)
            elif source_format == 'wav':
                audio_segment = AudioSegment.from_wav(audio_path)
            elif source_format == 'flac':
                audio_segment = AudioSegment.from_file(audio_path, format='flac')
            elif source_format == 'ogg':
                audio_segment = AudioSegment.from_ogg(audio_path)
            else:
                # Try to load as generic audio file
                audio_segment = AudioSegment.from_file(audio_path)

            # Convert to standard format for STT (16kHz, mono, WAV)
            standardized_audio = audio_segment.set_frame_rate(16000).set_channels(1)

            # Create output path; never overwrite the input file in place
            output_path = audio_path.with_suffix('.wav')
            if output_path == audio_path:
                output_path = audio_path.with_name(f"converted_{audio_path.name}")

            # Export converted audio
            standardized_audio.export(output_path, format="wav")
            logger.info(f"Audio converted from {audio.format} to WAV: {output_path}")
            return output_path
        except ImportError:
            logger.warning("pydub not available, using original audio file")
            return audio_path
        except Exception as e:
            logger.warning(f"Audio conversion failed, using original file: {str(e)}")
            # A failed export may leave a partial output file behind.
            if output_path is not None:
                self._cleanup_temp_file(output_path)
            return audio_path

    def _validate_audio(self, audio: 'AudioContent') -> None:
        """
        Validate the audio content for transcription.

        Args:
            audio: The audio content to validate

        Raises:
            SpeechRecognitionException: If the data is empty, the duration is
                outside 0.1 s .. 3600 s, or the format is not valid
        """
        if not audio.data:
            raise SpeechRecognitionException("Audio data cannot be empty")
        if audio.duration > 3600:  # 1 hour limit
            raise SpeechRecognitionException("Audio duration exceeds maximum limit of 1 hour")
        if audio.duration < 0.1:  # Minimum 100ms
            raise SpeechRecognitionException("Audio duration too short (minimum 100ms)")
        if not audio.is_valid_format:
            raise SpeechRecognitionException(f"Unsupported audio format: {audio.format}")

    def _detect_language(self, text: str) -> Optional[str]:
        """
        Detect the language of transcribed text.

        This is a placeholder heuristic: it counts common English words but
        returns ``'en'`` whether or not enough of them are found. In
        production, you might use langdetect or similar.

        Args:
            text: The transcribed text

        Returns:
            Optional[str]: ``'en'``, or None if detection fails
        """
        try:
            english_indicators = {'the', 'and', 'is', 'in', 'to', 'of', 'a', 'that', 'it', 'with'}
            # Match whole words; a substring test would find 'a' or 'is'
            # inside almost any text.
            words = {token.strip('.,;:!?"\'()[]') for token in text.lower().split()}
            english_count = len(english_indicators & words)
            if english_count >= 2:
                return 'en'
            # Default to English if uncertain
            return 'en'
        except Exception as e:
            logger.warning(f"Language detection failed: {str(e)}")
            return None

    def _ensure_temp_directory(self) -> Path:
        """
        Ensure temporary directory exists and return its path.

        Returns:
            Path: Path to the ``stt_temp`` directory under the system temp dir
        """
        temp_dir = Path(tempfile.gettempdir()) / "stt_temp"
        temp_dir.mkdir(parents=True, exist_ok=True)
        return temp_dir

    def _cleanup_temp_file(self, file_path: Path) -> None:
        """
        Clean up a temporary file.

        Never raises: failures are logged as warnings.

        Args:
            file_path: Path to the file to clean up
        """
        try:
            if file_path.exists():
                file_path.unlink()
                logger.info(f"Cleaned up temp file: {file_path}")
        except Exception as e:
            logger.warning(f"Failed to cleanup temp file {file_path}: {str(e)}")

    def _cleanup_old_temp_files(self, max_age_hours: int = 24) -> None:
        """
        Clean up old temporary files.

        Never raises: failures are logged as warnings, and a failure on one
        file does not stop the remaining files from being processed.

        Args:
            max_age_hours: Maximum age of files to keep in hours
        """
        try:
            # Files last modified before this timestamp are considered stale.
            cutoff = time.time() - max_age_hours * 3600
            for file_path in self._temp_dir.glob("*"):
                try:
                    if file_path.is_file() and file_path.stat().st_mtime < cutoff:
                        file_path.unlink()
                        logger.info(f"Cleaned up old temp file: {file_path}")
                except OSError as e:
                    # e.g. the file was removed by someone else in the meantime
                    logger.warning(f"Failed to cleanup temp file {file_path}: {str(e)}")
        except Exception as e:
            logger.warning(f"Failed to cleanup old temp files: {str(e)}")

    def _handle_provider_error(self, error: Exception, context: str = "") -> NoReturn:
        """
        Handle provider-specific errors and convert to domain exceptions.

        Logs the error with its traceback and always raises.

        Args:
            error: The original error
            context: Additional context about when the error occurred

        Raises:
            SpeechRecognitionException: Always, chained from ``error``
        """
        error_msg = f"{self.provider_name} error"
        if context:
            error_msg += f" during {context}"
        error_msg += f": {str(error)}"
        # exc_info is the supported way to attach an exception to a log record.
        logger.error(error_msg, exc_info=error)
        raise SpeechRecognitionException(error_msg) from error