Spaces:
Build error
Build error
File size: 10,811 Bytes
e3cb97b 1be582a e3cb97b fdc056d e3cb97b fdc056d e3cb97b fdc056d e3cb97b fdc056d e3cb97b 6514731 e3cb97b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 |
"""Base class for STT provider implementations."""
import logging
import os
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, TYPE_CHECKING
if TYPE_CHECKING:
from ...domain.models.audio_content import AudioContent
from ...domain.models.text_content import TextContent
from ...domain.interfaces.speech_recognition import ISpeechRecognitionService
from ...domain.exceptions import SpeechRecognitionException
logger = logging.getLogger(__name__)
class STTProviderBase(ISpeechRecognitionService, ABC):
    """Abstract base class for STT provider implementations.

    ``transcribe`` implements the common pipeline:

    1. validate the audio,
    2. write it to a temporary file and, when pydub is importable, convert it
       to 16 kHz mono WAV,
    3. delegate to the provider-specific ``_perform_transcription``,
    4. wrap the text in a ``TextContent``.

    Temporary files are removed after transcription. Subclasses must implement
    ``_perform_transcription``, ``is_available``, ``get_available_models`` and
    ``get_default_model``.
    """

    # Common English function words used by the _detect_language heuristic.
    _ENGLISH_INDICATORS = frozenset(
        {'the', 'and', 'is', 'in', 'to', 'of', 'a', 'that', 'it', 'with'}
    )

    def __init__(self, provider_name: str, supported_languages: Optional[list[str]] = None):
        """
        Initialize the STT provider.

        Args:
            provider_name: Name of the STT provider (used in log and error messages)
            supported_languages: List of supported language codes; an empty list
                is stored when omitted

        Side effects:
            Creates the ``stt_temp`` directory under the system temp directory
            if it does not already exist.
        """
        self.provider_name = provider_name
        self.supported_languages = supported_languages or []
        self._temp_dir = self._ensure_temp_directory()

    def transcribe(self, audio: 'AudioContent', model: str) -> 'TextContent':
        """
        Transcribe audio content to text.

        Args:
            audio: The audio content to transcribe
            model: The STT model to use for transcription

        Returns:
            TextContent: The transcribed text (language from ``_detect_language``,
            falling back to 'en'; encoding 'utf-8')

        Raises:
            SpeechRecognitionException: If validation, preprocessing or
                transcription fails (the original error is chained as ``__cause__``)
        """
        try:
            logger.info(f"Starting transcription with {self.provider_name} provider using model {model}")

            self._validate_audio(audio)

            # Preprocess audio if needed (writes a temp file we must remove afterwards)
            processed_audio_path = self._preprocess_audio(audio)

            try:
                # Perform transcription using provider-specific implementation
                transcribed_text = self._perform_transcription(processed_audio_path, model)

                # Imported lazily: at module level TextContent is only imported
                # under TYPE_CHECKING.
                from ...domain.models.text_content import TextContent

                # Detect language if not specified (default to English)
                detected_language = self._detect_language(transcribed_text) or 'en'

                text_content = TextContent(
                    text=transcribed_text,
                    language=detected_language,
                    encoding='utf-8'
                )

                logger.info(f"Transcription completed successfully with {self.provider_name}")
                return text_content

            finally:
                # Clean up temporary audio file whether or not transcription succeeded
                self._cleanup_temp_file(processed_audio_path)

        except Exception as e:
            logger.error(f"Transcription failed with {self.provider_name}: {str(e)}")
            raise SpeechRecognitionException(f"STT transcription failed: {str(e)}") from e

    @abstractmethod
    def _perform_transcription(self, audio_path: Path, model: str) -> str:
        """
        Perform the actual transcription using provider-specific implementation.

        Args:
            audio_path: Path to the preprocessed audio file
            model: The STT model to use

        Returns:
            str: The transcribed text
        """
        pass

    @abstractmethod
    def is_available(self) -> bool:
        """
        Check if the STT provider is available and ready to use.

        Returns:
            bool: True if provider is available, False otherwise
        """
        pass

    @abstractmethod
    def get_available_models(self) -> list[str]:
        """
        Get list of available models for this provider.

        Returns:
            list[str]: List of model identifiers
        """
        pass

    @abstractmethod
    def get_default_model(self) -> str:
        """
        Get the default model for this provider.

        Returns:
            str: Default model name
        """
        pass

    def _preprocess_audio(self, audio: 'AudioContent') -> Path:
        """
        Write audio content to a temporary file and convert it for transcription.

        Only the returned path is left on disk; if conversion produced a new
        file, the raw intermediate copy is deleted here. The caller is
        responsible for deleting the returned file.

        Args:
            audio: The audio content to preprocess

        Returns:
            Path: Path to the preprocessed audio file

        Raises:
            SpeechRecognitionException: If writing or converting the audio fails
        """
        temp_file: Optional[Path] = None
        try:
            # mkstemp atomically creates a uniquely named file. The previous
            # id(audio)-based name could collide between requests, because ids
            # are reused and the temp directory is shared.
            fd, temp_name = tempfile.mkstemp(prefix="audio_", suffix=".wav", dir=self._temp_dir)
            temp_file = Path(temp_name)

            # Write audio data to temporary file
            with os.fdopen(fd, 'wb') as f:
                f.write(audio.data)

            # Convert to required format if needed
            processed_file = self._convert_audio_format(temp_file, audio)

            if processed_file != temp_file:
                # transcribe() only cleans up the returned path, so remove the
                # raw copy now instead of leaking one file per request.
                self._cleanup_temp_file(temp_file)

            logger.info(f"Audio preprocessed and saved to: {processed_file}")
            return processed_file

        except Exception as e:
            if temp_file is not None:
                self._cleanup_temp_file(temp_file)
            logger.error(f"Audio preprocessing failed: {str(e)}")
            raise SpeechRecognitionException(f"Audio preprocessing failed: {str(e)}") from e

    def _convert_audio_format(self, audio_path: Path, audio: 'AudioContent') -> Path:
        """
        Convert audio to 16 kHz mono WAV for transcription (best effort).

        Args:
            audio_path: Path to the original audio file
            audio: The audio content metadata (``format`` selects the decoder)

        Returns:
            Path: Path to the converted audio file, or ``audio_path`` unchanged
            if pydub is not installed or conversion fails
        """
        output_path: Optional[Path] = None
        try:
            # Import audio processing library (optional dependency)
            from pydub import AudioSegment

            audio_format = audio.format.lower()

            # Load audio file with a decoder matching the declared format
            if audio_format == 'mp3':
                audio_segment = AudioSegment.from_mp3(audio_path)
            elif audio_format == 'wav':
                audio_segment = AudioSegment.from_wav(audio_path)
            elif audio_format == 'flac':
                audio_segment = AudioSegment.from_file(audio_path, format='flac')
            elif audio_format == 'ogg':
                audio_segment = AudioSegment.from_ogg(audio_path)
            else:
                # Try to load as generic audio file
                audio_segment = AudioSegment.from_file(audio_path)

            # Convert to standard format for STT (16kHz, mono, WAV)
            standardized_audio = audio_segment.set_frame_rate(16000).set_channels(1)

            # Create output path; never overwrite the input file
            output_path = audio_path.with_suffix('.wav')
            if output_path == audio_path:
                output_path = audio_path.with_name(f"converted_{audio_path.name}")

            # Export converted audio. pydub returns the file handle it opened
            # for the path; close it so the descriptor is not leaked.
            exported = standardized_audio.export(output_path, format="wav")
            if hasattr(exported, "close"):
                exported.close()

            logger.info(f"Audio converted from {audio.format} to WAV: {output_path}")
            return output_path

        except ImportError:
            logger.warning("pydub not available, using original audio file")
            return audio_path
        except Exception as e:
            # Don't leave a partially written conversion output behind.
            if output_path is not None and output_path != audio_path:
                self._cleanup_temp_file(output_path)
            logger.warning(f"Audio conversion failed, using original file: {str(e)}")
            return audio_path

    def _validate_audio(self, audio: 'AudioContent') -> None:
        """
        Validate the audio content for transcription.

        Args:
            audio: The audio content to validate

        Raises:
            SpeechRecognitionException: If the data is empty, the duration is
                outside [0.1, 3600] seconds, or the format is not valid
        """
        if not audio.data:
            raise SpeechRecognitionException("Audio data cannot be empty")

        if audio.duration > 3600:  # 1 hour limit
            raise SpeechRecognitionException("Audio duration exceeds maximum limit of 1 hour")

        if audio.duration < 0.1:  # Minimum 100ms
            raise SpeechRecognitionException("Audio duration too short (minimum 100ms)")

        if not audio.is_valid_format:
            raise SpeechRecognitionException(f"Unsupported audio format: {audio.format}")

    def _detect_language(self, text: str, default: Optional[str] = 'en') -> Optional[str]:
        """
        Detect the language of transcribed text with a simple heuristic.

        Text is classified as English when at least two distinct common English
        function words appear as whole words. Anything else returns ``default``.
        This is a basic implementation; a real detector (e.g. langdetect) would
        be more accurate.

        Args:
            text: The transcribed text
            default: Language code returned when the heuristic is not confident
                (defaults to 'en', preserving the previous behaviour)

        Returns:
            Optional[str]: Detected language code, ``default`` when uncertain,
            or None if detection fails with an error
        """
        try:
            import re

            # Whole-word matching. Substring checks matched e.g. 'a' or 'in'
            # inside almost any text.
            words = set(re.findall(r"\w+", text.lower()))
            if len(words & self._ENGLISH_INDICATORS) >= 2:
                return 'en'

            return default

        except Exception as e:
            logger.warning(f"Language detection failed: {str(e)}")
            return None

    def _ensure_temp_directory(self) -> Path:
        """
        Ensure the shared temporary directory exists and return its path.

        Returns:
            Path: ``<system temp dir>/stt_temp``
        """
        temp_dir = Path(tempfile.gettempdir()) / "stt_temp"
        temp_dir.mkdir(exist_ok=True)
        return temp_dir

    def _cleanup_temp_file(self, file_path: Path) -> None:
        """
        Delete a temporary file (best effort; failures are logged, not raised).

        Args:
            file_path: Path to the file to clean up
        """
        try:
            if file_path.exists():
                file_path.unlink()
                logger.info(f"Cleaned up temp file: {file_path}")
        except Exception as e:
            logger.warning(f"Failed to cleanup temp file {file_path}: {str(e)}")

    def _cleanup_old_temp_files(self, max_age_hours: int = 24) -> None:
        """
        Delete files in the temp directory whose mtime is older than the cutoff.

        Best effort: a failure on one file is logged and does not stop the
        sweep over the remaining files.

        Args:
            max_age_hours: Maximum age of files to keep in hours
        """
        import time

        cutoff = time.time() - max_age_hours * 3600

        try:
            candidates = [p for p in self._temp_dir.glob("*") if p.is_file()]
        except Exception as e:
            logger.warning(f"Failed to cleanup old temp files: {str(e)}")
            return

        for file_path in candidates:
            try:
                if file_path.stat().st_mtime < cutoff:
                    file_path.unlink()
                    logger.info(f"Cleaned up old temp file: {file_path}")
            except FileNotFoundError:
                # Already removed, e.g. by a concurrent cleanup.
                continue
            except Exception as e:
                logger.warning(f"Failed to cleanup old temp file {file_path}: {str(e)}")

    def _handle_provider_error(self, error: Exception, context: str = "") -> None:
        """
        Log a provider-specific error and re-raise it as a domain exception.

        This method never returns normally.

        Args:
            error: The original error
            context: Additional context about when the error occurred

        Raises:
            SpeechRecognitionException: Always, chained from ``error``
        """
        error_msg = f"{self.provider_name} error"
        if context:
            error_msg += f" during {context}"
        error_msg += f": {str(error)}"

        # `exc_info` is the logging keyword that attaches the traceback. The
        # previous `exception=` kwarg made logger.error() raise TypeError.
        logger.error(error_msg, exc_info=error)
        raise SpeechRecognitionException(error_msg) from error