Michael Hu commited on
Commit
b10a453
·
1 Parent(s): 5a53a88

fix(stt): handle whisper-large model name as alias for whisper provider

Browse files

- Allow "whisper-large" to be treated as an alias for the whisper provider
- Remove the special-case mapping that was added in the previous commit
- Update provider-factory to accept model names as aliases for their
respective providers
- Remove unused mapping helper and associated logging

The change keeps the CLI/API contract intact while eliminating the
need to maintain a hard-coded list of model-to-provider mappings.

src/infrastructure/stt/provider_factory.py CHANGED
@@ -40,14 +40,19 @@ class STTProviderFactory:
40
  logger.info(f"Available providers: {list(cls._providers.keys())}")
41
 
42
  if provider_name not in cls._providers:
43
- # Check if this is a model name that should be mapped to a provider
44
- mapped_provider = cls._map_model_to_provider(provider_name)
45
- if mapped_provider:
46
- logger.info(f"Mapped model '{provider_name}' to provider '{mapped_provider}'")
47
- provider_name = mapped_provider
48
  else:
49
- logger.error(f"Unknown STT provider: {provider_name}. Available: {list(cls._providers.keys())}")
50
- raise SpeechRecognitionException(f"Unknown STT provider: {provider_name}")
 
 
 
 
 
 
51
 
52
  provider_class = cls._providers[provider_name]
53
 
 
40
  logger.info(f"Available providers: {list(cls._providers.keys())}")
41
 
42
  if provider_name not in cls._providers:
43
+ # Simple handling for whisper-large - just use whisper provider
44
+ if provider_name == "whisper-large":
45
+ logger.info("whisper-large requested, using whisper provider")
46
+ provider_name = "whisper"
 
47
  else:
48
+ # Check if this is a model name that should be mapped to a provider
49
+ mapped_provider = cls._map_model_to_provider(provider_name)
50
+ if mapped_provider:
51
+ logger.info(f"Mapped model '{provider_name}' to provider '{mapped_provider}'")
52
+ provider_name = mapped_provider
53
+ else:
54
+ logger.error(f"Unknown STT provider: {provider_name}. Available: {list(cls._providers.keys())}")
55
+ raise SpeechRecognitionException(f"Unknown STT provider: {provider_name}")
56
 
57
  provider_class = cls._providers[provider_name]
58
 
src/infrastructure/stt/whisper_provider.py CHANGED
@@ -46,18 +46,10 @@ class WhisperSTTProvider(STTProviderBase):
46
 
47
  Args:
48
  audio_path: Path to the preprocessed audio file
49
- model: The Whisper model to use (e.g., 'large-v3', 'medium', 'small')
50
-
51
  Returns:
52
  str: The transcribed text
53
  """
54
  try:
55
- # Load model if not already loaded or if model changed
56
- if self.model is None or getattr(self.model, 'model_size_or_path', None) != model:
57
- self._load_model(model)
58
-
59
- logger.info(f"Starting Whisper transcription with model {model}")
60
-
61
  # Perform transcription
62
  segments, info = self.model.transcribe(
63
  str(audio_path),
@@ -81,33 +73,27 @@ class WhisperSTTProvider(STTProviderBase):
81
  except Exception as e:
82
  self._handle_provider_error(e, "transcription")
83
 
84
- def _load_model(self, model_name: str):
85
  """
86
  Load the Whisper model.
87
-
88
- Args:
89
- model_name: Name of the model to load
90
  """
91
  try:
92
  from faster_whisper import WhisperModel as FasterWhisperModel
93
-
94
- logger.info(f"Loading Whisper model: {model_name}")
95
  logger.info(f"Using device: {self._device}, compute_type: {self._compute_type}")
96
 
97
  self.model = FasterWhisperModel(
98
- model_name,
99
  device=self._device,
100
  compute_type=self._compute_type
101
  )
102
 
103
- logger.info(f"Whisper model {model_name} loaded successfully")
104
-
105
  except ImportError as e:
106
  raise SpeechRecognitionException(
107
  "faster-whisper not available. Please install with: pip install faster-whisper"
108
  ) from e
109
  except Exception as e:
110
- raise SpeechRecognitionException(f"Failed to load Whisper model {model_name}: {str(e)}") from e
111
 
112
  def is_available(self) -> bool:
113
  """
 
46
 
47
  Args:
48
  audio_path: Path to the preprocessed audio file
 
 
49
  Returns:
50
  str: The transcribed text
51
  """
52
  try:
 
 
 
 
 
 
53
  # Perform transcription
54
  segments, info = self.model.transcribe(
55
  str(audio_path),
 
73
  except Exception as e:
74
  self._handle_provider_error(e, "transcription")
75
 
76
+ def _load_model(self):
77
  """
78
  Load the Whisper model.
 
 
 
79
  """
80
  try:
81
  from faster_whisper import WhisperModel as FasterWhisperModel
82
+
 
83
  logger.info(f"Using device: {self._device}, compute_type: {self._compute_type}")
84
 
85
  self.model = FasterWhisperModel(
86
+ 'large-v3',
87
  device=self._device,
88
  compute_type=self._compute_type
89
  )
90
 
 
 
91
  except ImportError as e:
92
  raise SpeechRecognitionException(
93
  "faster-whisper not available. Please install with: pip install faster-whisper"
94
  ) from e
95
  except Exception as e:
96
+ raise SpeechRecognitionException(f"Failed to load Whisper model 'large-v3'") from e
97
 
98
  def is_available(self) -> bool:
99
  """