# Provenance: commit fdc056d ("add more logs") by Michael Hu — metadata from the
# original hosting UI, preserved here as a comment so the module stays importable.
"""Whisper STT provider implementation."""
import logging
from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from ...domain.models.audio_content import AudioContent
from ...domain.models.text_content import TextContent
from ..base.stt_provider_base import STTProviderBase
from ...domain.exceptions import SpeechRecognitionException
logger = logging.getLogger(__name__)
class WhisperSTTProvider(STTProviderBase):
    """Whisper STT provider backed by the faster-whisper library.

    The underlying ``WhisperModel`` is loaded lazily on the first
    transcription and cached until a *different* model name is requested.
    """

    def __init__(self):
        """Initialize the Whisper STT provider and pick device settings."""
        super().__init__(
            provider_name="Whisper",
            supported_languages=["en", "es", "fr", "de", "it", "pt", "ru", "ja", "ko", "zh"]
        )
        self.model = None               # lazily-loaded faster_whisper.WhisperModel
        self._loaded_model_name = None  # name of the model currently held in self.model
        self._device = None
        self._compute_type = None
        self._initialize_device_settings()

    def _initialize_device_settings(self):
        """Choose the inference device and the matching compute type.

        Uses CUDA when torch reports it as available, otherwise CPU.
        float16 on GPU / int8 on CPU are the usual faster-whisper pairings.
        """
        try:
            import torch
            self._device = "cuda" if torch.cuda.is_available() else "cpu"
        except ImportError:
            # torch is optional; without it we cannot probe for CUDA.
            self._device = "cpu"
        self._compute_type = "float16" if self._device == "cuda" else "int8"
        # Lazy %-style args: no string formatting when the log level is disabled.
        logger.info(
            "Whisper provider initialized with device: %s, compute_type: %s",
            self._device, self._compute_type,
        )

    def _perform_transcription(self, audio_path: Path, model: str) -> str:
        """
        Perform transcription using Faster Whisper.

        Args:
            audio_path: Path to the preprocessed audio file
            model: The Whisper model to use (e.g., 'large-v3', 'medium', 'small')

        Returns:
            str: The transcribed text
        """
        try:
            # BUGFIX: the previous check read a non-existent attribute off the
            # WhisperModel object (getattr default None never equals `model`),
            # so the model was reloaded from disk on EVERY call. We now track
            # the loaded model name ourselves in _load_model().
            if self.model is None or self._loaded_model_name != model:
                self._load_model(model)
            logger.info("Starting Whisper transcription with model %s", model)
            # Perform transcription; `segments` is a lazy generator.
            segments, info = self.model.transcribe(
                str(audio_path),
                beam_size=5,
                language="en",  # TODO: make configurable; supported_languages lists more
                task="transcribe"
            )
            logger.info(
                "Detected language '%s' with probability %s",
                info.language, info.language_probability,
            )
            # Collect segment texts and join once (avoids quadratic `+=`).
            pieces = []
            for segment in segments:
                pieces.append(segment.text)
                logger.info("[%.2fs -> %.2fs] %s", segment.start, segment.end, segment.text)
            result = " ".join(pieces).strip()
            logger.info("Whisper transcription completed successfully")
            return result
        except Exception as e:
            # Delegates to the base class; presumably re-raises as a
            # provider-specific error — confirm against STTProviderBase.
            self._handle_provider_error(e, "transcription")

    def _load_model(self, model_name: str):
        """
        Load (or reload) the faster-whisper model.

        Args:
            model_name: Name of the model to load

        Raises:
            SpeechRecognitionException: If faster-whisper is not installed
                or the model fails to load.
        """
        try:
            from faster_whisper import WhisperModel as FasterWhisperModel
            logger.info("Loading Whisper model: %s", model_name)
            logger.info("Using device: %s, compute_type: %s", self._device, self._compute_type)
            self.model = FasterWhisperModel(
                model_name,
                device=self._device,
                compute_type=self._compute_type
            )
            # Remember what we loaded so _perform_transcription can skip reloads.
            self._loaded_model_name = model_name
            logger.info("Whisper model %s loaded successfully", model_name)
        except ImportError as e:
            raise SpeechRecognitionException(
                "faster-whisper not available. Please install with: pip install faster-whisper"
            ) from e
        except Exception as e:
            raise SpeechRecognitionException(f"Failed to load Whisper model {model_name}: {str(e)}") from e

    def is_available(self) -> bool:
        """
        Check if the Whisper provider is available.

        Returns:
            bool: True if faster-whisper is available, False otherwise
        """
        try:
            import faster_whisper  # noqa: F401 -- import is the availability probe
            return True
        except ImportError:
            logger.warning("faster-whisper not available")
            return False

    def get_available_models(self) -> list[str]:
        """
        Get list of available Whisper models.

        Returns:
            list[str]: List of available model names
        """
        return [
            "tiny",
            "tiny.en",
            "base",
            "base.en",
            "small",
            "small.en",
            "medium",
            "medium.en",
            "large-v1",
            "large-v2",
            "large-v3"
        ]

    def get_default_model(self) -> str:
        """
        Get the default model for this provider.

        Returns:
            str: Default model name
        """
        return "large-v3"