# NOTE(review): the three lines below appear to be build-tool residue that was
# pasted above the module, not source text — confirm and remove:
# Spaces:
# Build error
# Build error
"""Whisper STT provider implementation.""" | |
import logging | |
from pathlib import Path | |
from typing import TYPE_CHECKING | |
if TYPE_CHECKING: | |
from ...domain.models.audio_content import AudioContent | |
from ...domain.models.text_content import TextContent | |
from ..base.stt_provider_base import STTProviderBase | |
from ...domain.exceptions import SpeechRecognitionException | |
# Module-level logger, following the standard `logging.getLogger(__name__)` convention.
logger = logging.getLogger(__name__)
class WhisperSTTProvider(STTProviderBase):
    """Whisper STT provider using the faster-whisper implementation."""

    def __init__(self):
        """Initialize the Whisper STT provider and detect device settings."""
        super().__init__(
            provider_name="Whisper",
            supported_languages=["en", "es", "fr", "de", "it", "pt", "ru", "ja", "ko", "zh"]
        )
        self.model = None
        # Name of the model currently held in self.model. Used to decide when a
        # reload is needed; the previous implementation probed a non-public
        # attribute of the third-party WhisperModel object, which forced a
        # reload on every transcription call.
        self._loaded_model_name = None
        self._device = None
        self._compute_type = None
        self._initialize_device_settings()

    def _initialize_device_settings(self):
        """Choose the inference device and compute type.

        Uses CUDA with float16 when torch reports an available GPU; otherwise
        falls back to CPU with int8 quantization.
        """
        try:
            import torch
            self._device = "cuda" if torch.cuda.is_available() else "cpu"
        except ImportError:
            # torch is optional; without it we cannot detect a GPU.
            self._device = "cpu"
        self._compute_type = "float16" if self._device == "cuda" else "int8"
        logger.info(
            "Whisper provider initialized with device: %s, compute_type: %s",
            self._device, self._compute_type,
        )

    def _perform_transcription(self, audio_path: Path, model: str, language: str = "en") -> str:
        """
        Perform transcription using faster-whisper.

        Args:
            audio_path: Path to the preprocessed audio file.
            model: The Whisper model to use (e.g., 'large-v3', 'medium', 'small').
            language: ISO language code to transcribe in (default "en",
                preserving the previously hard-coded behavior).

        Returns:
            str: The transcribed text.
        """
        try:
            # Load the model only when none is loaded yet or a different one
            # is requested (tracked via self._loaded_model_name).
            if self.model is None or self._loaded_model_name != model:
                self._load_model(model)
            logger.info("Starting Whisper transcription with model %s", model)
            segments, info = self.model.transcribe(
                str(audio_path),
                beam_size=5,
                language=language,
                task="transcribe"
            )
            logger.info(
                "Detected language '%s' with probability %s",
                info.language, info.language_probability,
            )
            # `segments` is a generator; collect the pieces and join once
            # instead of quadratic string concatenation. Joining with a single
            # space and stripping yields the same result as the old
            # `+= text + " "` loop followed by strip().
            pieces = []
            for segment in segments:
                pieces.append(segment.text)
                logger.info("[%.2fs -> %.2fs] %s", segment.start, segment.end, segment.text)
            result = " ".join(pieces).strip()
            logger.info("Whisper transcription completed successfully")
            return result
        except Exception as e:
            # Delegate to the base-class handler, which is presumably expected
            # to raise a provider-specific exception (TODO confirm: if it can
            # return normally, this method would implicitly return None).
            self._handle_provider_error(e, "transcription")

    def _load_model(self, model_name: str):
        """
        Load the faster-whisper model onto the configured device.

        Args:
            model_name: Name of the model to load (see get_available_models()).

        Raises:
            SpeechRecognitionException: If faster-whisper is not installed or
                the model fails to load.
        """
        try:
            from faster_whisper import WhisperModel as FasterWhisperModel
            logger.info("Loading Whisper model: %s", model_name)
            logger.info("Using device: %s, compute_type: %s", self._device, self._compute_type)
            self.model = FasterWhisperModel(
                model_name,
                device=self._device,
                compute_type=self._compute_type
            )
            # Record which model is loaded so _perform_transcription can skip
            # redundant reloads. Set only after a successful construction.
            self._loaded_model_name = model_name
            logger.info("Whisper model %s loaded successfully", model_name)
        except ImportError as e:
            raise SpeechRecognitionException(
                "faster-whisper not available. Please install with: pip install faster-whisper"
            ) from e
        except Exception as e:
            raise SpeechRecognitionException(f"Failed to load Whisper model {model_name}: {str(e)}") from e

    def is_available(self) -> bool:
        """
        Check if the Whisper provider is available.

        Returns:
            bool: True if faster-whisper can be imported, False otherwise.
        """
        try:
            import faster_whisper  # noqa: F401 -- availability probe only
            return True
        except ImportError:
            logger.warning("faster-whisper not available")
            return False

    def get_available_models(self) -> list[str]:
        """
        Get the list of Whisper model names this provider accepts.

        Returns:
            list[str]: Available model names, smallest to largest.
        """
        return [
            "tiny",
            "tiny.en",
            "base",
            "base.en",
            "small",
            "small.en",
            "medium",
            "medium.en",
            "large-v1",
            "large-v2",
            "large-v3"
        ]

    def get_default_model(self) -> str:
        """
        Get the default model for this provider.

        Returns:
            str: Default model name ("large-v3").
        """
        return "large-v3"