Michael Hu committed
Commit b10a453 · 1 Parent(s): 5a53a88
fix(stt): handle whisper-large model name as alias for whisper provider
- Allow "whisper-large" to be treated as an alias for the whisper provider
- Remove the special-case mapping that was added in the previous commit
- Update provider-factory to accept model names as aliases for their respective providers
- Remove unused mapping helper and associated logging

The change keeps the CLI/API contract intact while eliminating the need to maintain a hard-coded list of model-to-provider mappings.
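For reference, the resolution order this commit introduces is: exact provider key first, then the "whisper-large" alias, then a model-name-to-provider mapping, and only then an error. Below is a minimal standalone sketch of that order; resolve_provider_name is a hypothetical stand-in, since the enclosing factory method and its callers are not shown in this commit.

# Hypothetical stand-in for the lookup order added in provider_factory.py below.
# `providers` plays the role of cls._providers and `map_model_to_provider` the
# role of cls._map_model_to_provider; neither implementation is in this commit.
def resolve_provider_name(provider_name, providers, map_model_to_provider):
    if provider_name in providers:
        return provider_name                # already a registered provider key
    if provider_name == "whisper-large":
        return "whisper"                    # explicit alias handled first
    mapped = map_model_to_provider(provider_name)
    if mapped:
        return mapped                       # model name mapped to its provider
    # the real code raises SpeechRecognitionException here
    raise ValueError(f"Unknown STT provider: {provider_name}")

For example, resolve_provider_name("whisper-large", {"whisper": object()}, lambda name: None) returns "whisper".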
src/infrastructure/stt/provider_factory.py
CHANGED

@@ -40,14 +40,19 @@ class STTProviderFactory:
         logger.info(f"Available providers: {list(cls._providers.keys())}")
 
         if provider_name not in cls._providers:
-            #
-
-
-
-                provider_name = mapped_provider
+            # Simple handling for whisper-large - just use whisper provider
+            if provider_name == "whisper-large":
+                logger.info("whisper-large requested, using whisper provider")
+                provider_name = "whisper"
             else:
-
-
+                # Check if this is a model name that should be mapped to a provider
+                mapped_provider = cls._map_model_to_provider(provider_name)
+                if mapped_provider:
+                    logger.info(f"Mapped model '{provider_name}' to provider '{mapped_provider}'")
+                    provider_name = mapped_provider
+                else:
+                    logger.error(f"Unknown STT provider: {provider_name}. Available: {list(cls._providers.keys())}")
+                    raise SpeechRecognitionException(f"Unknown STT provider: {provider_name}")
 
         provider_class = cls._providers[provider_name]
 
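The new branch relies on cls._map_model_to_provider, whose implementation is not part of this diff; all the hunk shows is that it returns a registered provider key or a falsy value. The sketch below illustrates that assumed contract only, not the real helper, and the prefix-matching rule is purely an assumption.

# Assumed contract of the mapping helper: given a model name such as
# "whisper-large", return the key of a registered provider it belongs to,
# or None when nothing matches. The actual matching rule is unknown.
def map_model_to_provider(model_name, registered_providers):
    for provider_key in registered_providers:
        if model_name.startswith(provider_key):
            return provider_key
    return None

# map_model_to_provider("whisper-large", {"whisper": object(), "google": object()}) -> "whisper"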
src/infrastructure/stt/whisper_provider.py
CHANGED

@@ -46,18 +46,10 @@ class WhisperSTTProvider(STTProviderBase):
 
         Args:
             audio_path: Path to the preprocessed audio file
-            model: The Whisper model to use (e.g., 'large-v3', 'medium', 'small')
-
         Returns:
             str: The transcribed text
         """
         try:
-            # Load model if not already loaded or if model changed
-            if self.model is None or getattr(self.model, 'model_size_or_path', None) != model:
-                self._load_model(model)
-
-            logger.info(f"Starting Whisper transcription with model {model}")
-
             # Perform transcription
             segments, info = self.model.transcribe(
                 str(audio_path),

@@ -81,33 +73,27 @@
         except Exception as e:
             self._handle_provider_error(e, "transcription")
 
-    def _load_model(self, model_name):
+    def _load_model(self):
         """
         Load the Whisper model.
-
-        Args:
-            model_name: Name of the model to load
         """
         try:
             from faster_whisper import WhisperModel as FasterWhisperModel
-
-            logger.info(f"Loading Whisper model: {model_name}")
+
             logger.info(f"Using device: {self._device}, compute_type: {self._compute_type}")
 
             self.model = FasterWhisperModel(
-                model_name,
+                'large-v3',
                 device=self._device,
                 compute_type=self._compute_type
             )
 
-            logger.info(f"Whisper model {model_name} loaded successfully")
-
         except ImportError as e:
             raise SpeechRecognitionException(
                 "faster-whisper not available. Please install with: pip install faster-whisper"
             ) from e
         except Exception as e:
-            raise SpeechRecognitionException(f"Failed to load Whisper model '{model_name}'") from e
+            raise SpeechRecognitionException(f"Failed to load Whisper model 'large-v3'") from e
 
     def is_available(self) -> bool:
         """