add translate_model_name function
Browse files- whisper_online.py +51 -7
whisper_online.py
CHANGED
@@ -160,27 +160,71 @@ class MLXWhisper(ASRBase):
|
|
160 |
"""
|
161 |
Uses MPX Whisper library as the backend, optimized for Apple Silicon.
|
162 |
Models available: https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc
|
163 |
-
Significantly faster than faster-whisper (without CUDA) on Apple M1.
|
164 |
"""
|
165 |
|
166 |
sep = " "
|
167 |
|
168 |
-
def load_model(self, modelsize=None, model_dir=None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
from mlx_whisper import transcribe
|
170 |
|
171 |
if model_dir is not None:
|
172 |
logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.")
|
173 |
model_size_or_path = model_dir
|
174 |
elif modelsize is not None:
|
175 |
-
|
176 |
-
model_size_or_path
|
177 |
-
elif modelsize == None:
|
178 |
-
logger.debug("No model size or path specified. Using mlx-community/whisper-large-v3-mlx.")
|
179 |
-
model_size_or_path = "mlx-community/whisper-large-v3-mlx"
|
180 |
|
181 |
self.model_size_or_path = model_size_or_path
|
182 |
return transcribe
|
183 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
184 |
def transcribe(self, audio, init_prompt=""):
|
185 |
segments = self.model(
|
186 |
audio,
|
|
|
160 |
"""
|
161 |
Uses MPX Whisper library as the backend, optimized for Apple Silicon.
|
162 |
Models available: https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc
|
163 |
+
Significantly faster than faster-whisper (without CUDA) on Apple M1.
|
164 |
"""
|
165 |
|
166 |
sep = " "
|
167 |
|
168 |
+
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
|
169 |
+
"""
|
170 |
+
Loads the MLX-compatible Whisper model.
|
171 |
+
|
172 |
+
Args:
|
173 |
+
modelsize (str, optional): The size or name of the Whisper model to load.
|
174 |
+
If provided, it will be translated to an MLX-compatible model path using the `translate_model_name` method.
|
175 |
+
Example: "large-v3-turbo" -> "mlx-community/whisper-large-v3-turbo".
|
176 |
+
cache_dir (str, optional): Path to the directory for caching models.
|
177 |
+
**Note**: This is not supported by MLX Whisper and will be ignored.
|
178 |
+
model_dir (str, optional): Direct path to a custom model directory.
|
179 |
+
If specified, it overrides the `modelsize` parameter.
|
180 |
+
"""
|
181 |
from mlx_whisper import transcribe
|
182 |
|
183 |
if model_dir is not None:
|
184 |
logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.")
|
185 |
model_size_or_path = model_dir
|
186 |
elif modelsize is not None:
|
187 |
+
model_size_or_path = self.translate_model_name(modelsize)
|
188 |
+
logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used.")
|
|
|
|
|
|
|
189 |
|
190 |
self.model_size_or_path = model_size_or_path
|
191 |
return transcribe
|
192 |
|
193 |
+
def translate_model_name(self, model_name):
|
194 |
+
"""
|
195 |
+
Translates a given model name to its corresponding MLX-compatible model path.
|
196 |
+
|
197 |
+
Args:
|
198 |
+
model_name (str): The name of the model to translate.
|
199 |
+
|
200 |
+
Returns:
|
201 |
+
str: The MLX-compatible model path.
|
202 |
+
"""
|
203 |
+
# Dictionary mapping model names to MLX-compatible paths
|
204 |
+
model_mapping = {
|
205 |
+
"tiny.en": "mlx-community/whisper-tiny.en-mlx",
|
206 |
+
"tiny": "mlx-community/whisper-tiny-mlx",
|
207 |
+
"base.en": "mlx-community/whisper-base.en-mlx",
|
208 |
+
"base": "mlx-community/whisper-base-mlx",
|
209 |
+
"small.en": "mlx-community/whisper-small.en-mlx",
|
210 |
+
"small": "mlx-community/whisper-small-mlx",
|
211 |
+
"medium.en": "mlx-community/whisper-medium.en-mlx",
|
212 |
+
"medium": "mlx-community/whisper-medium-mlx",
|
213 |
+
"large-v1": "mlx-community/whisper-large-v1-mlx",
|
214 |
+
"large-v2": "mlx-community/whisper-large-v2-mlx",
|
215 |
+
"large-v3": "mlx-community/whisper-large-v3-mlx",
|
216 |
+
"large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
|
217 |
+
"large": "mlx-community/whisper-large-mlx"
|
218 |
+
}
|
219 |
+
|
220 |
+
# Retrieve the corresponding MLX model path
|
221 |
+
mlx_model_path = model_mapping.get(model_name)
|
222 |
+
|
223 |
+
if mlx_model_path:
|
224 |
+
return mlx_model_path
|
225 |
+
else:
|
226 |
+
raise ValueError(f"Model name '{model_name}' is not recognized or not supported.")
|
227 |
+
|
228 |
def transcribe(self, audio, init_prompt=""):
|
229 |
segments = self.model(
|
230 |
audio,
|