qfuxa committed
Commit 8dcebd9 · 1 Parent(s): 87cab7c

add translate_model_name function

Files changed (1)
  1. whisper_online.py +51 -7
whisper_online.py CHANGED
@@ -160,27 +160,71 @@ class MLXWhisper(ASRBase):
    """
    Uses MLX Whisper library as the backend, optimized for Apple Silicon.
    Models available: https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc
-   Significantly faster than faster-whisper (without CUDA) on Apple M1. Model used by default: mlx-community/whisper-large-v3-mlx
+   Significantly faster than faster-whisper (without CUDA) on Apple M1.
    """

    sep = " "

-   def load_model(self, modelsize=None, model_dir=None):
+   def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
+       """
+       Loads the MLX-compatible Whisper model.
+
+       Args:
+           modelsize (str, optional): The size or name of the Whisper model to load.
+               If provided, it will be translated to an MLX-compatible model path using the `translate_model_name` method.
+               Example: "large-v3-turbo" -> "mlx-community/whisper-large-v3-turbo".
+           cache_dir (str, optional): Path to the directory for caching models.
+               **Note**: This is not supported by MLX Whisper and will be ignored.
+           model_dir (str, optional): Direct path to a custom model directory.
+               If specified, it overrides the `modelsize` parameter.
+       """
        from mlx_whisper import transcribe

        if model_dir is not None:
            logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.")
            model_size_or_path = model_dir
        elif modelsize is not None:
-           logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so make sure you use a mlx-compatible model.")
-           model_size_or_path = modelsize
-       elif modelsize == None:
-           logger.debug("No model size or path specified. Using mlx-community/whisper-large-v3-mlx.")
-           model_size_or_path = "mlx-community/whisper-large-v3-mlx"
+           model_size_or_path = self.translate_model_name(modelsize)
+           logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used.")

        self.model_size_or_path = model_size_or_path
        return transcribe

+   def translate_model_name(self, model_name):
+       """
+       Translates a given model name to its corresponding MLX-compatible model path.
+
+       Args:
+           model_name (str): The name of the model to translate.
+
+       Returns:
+           str: The MLX-compatible model path.
+       """
+       # Dictionary mapping model names to MLX-compatible paths
+       model_mapping = {
+           "tiny.en": "mlx-community/whisper-tiny.en-mlx",
+           "tiny": "mlx-community/whisper-tiny-mlx",
+           "base.en": "mlx-community/whisper-base.en-mlx",
+           "base": "mlx-community/whisper-base-mlx",
+           "small.en": "mlx-community/whisper-small.en-mlx",
+           "small": "mlx-community/whisper-small-mlx",
+           "medium.en": "mlx-community/whisper-medium.en-mlx",
+           "medium": "mlx-community/whisper-medium-mlx",
+           "large-v1": "mlx-community/whisper-large-v1-mlx",
+           "large-v2": "mlx-community/whisper-large-v2-mlx",
+           "large-v3": "mlx-community/whisper-large-v3-mlx",
+           "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
+           "large": "mlx-community/whisper-large-mlx"
+       }
+
+       # Retrieve the corresponding MLX model path
+       mlx_model_path = model_mapping.get(model_name)
+
+       if mlx_model_path:
+           return mlx_model_path
+       else:
+           raise ValueError(f"Model name '{model_name}' is not recognized or not supported.")
+
    def transcribe(self, audio, init_prompt=""):
        segments = self.model(
            audio,
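
For reference, a minimal standalone sketch of the lookup this commit introduces (the mapping is abbreviated to two entries here; nothing else from whisper_online.py is needed to run it):

# Abbreviated sketch of the name-translation step added in this commit.
model_mapping = {
    "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
    "large-v3": "mlx-community/whisper-large-v3-mlx",
}

def translate_model_name(model_name):
    # Look up the MLX-compatible repo path for a short Whisper model name.
    mlx_model_path = model_mapping.get(model_name)
    if mlx_model_path is None:
        # Unknown names fail loudly instead of being passed through.
        raise ValueError(f"Model name '{model_name}' is not recognized or not supported.")
    return mlx_model_path

print(translate_model_name("large-v3-turbo"))  # mlx-community/whisper-large-v3-turbo

One design consequence visible in the diff: before this commit, any modelsize string was forwarded to mlx_whisper unchanged; after it, a name outside the mapping raises ValueError.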