warning when transcribe_kargs are used with MLX Whisper
Browse files- whisper_online.py +16 -7
whisper_online.py
CHANGED
@@ -201,7 +201,8 @@ class MLXWhisper(ASRBase):
|
|
201 |
model_dir (str, optional): Direct path to a custom model directory.
|
202 |
If specified, it overrides the `modelsize` parameter.
|
203 |
"""
|
204 |
-
from mlx_whisper import transcribe
|
|
|
205 |
|
206 |
if model_dir is not None:
|
207 |
logger.debug(
|
@@ -215,6 +216,12 @@ class MLXWhisper(ASRBase):
|
|
215 |
)
|
216 |
|
217 |
self.model_size_or_path = model_size_or_path
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
return transcribe
|
219 |
|
220 |
def translate_model_name(self, model_name):
|
@@ -255,6 +262,8 @@ class MLXWhisper(ASRBase):
|
|
255 |
)
|
256 |
|
257 |
def transcribe(self, audio, init_prompt=""):
|
|
|
|
|
258 |
segments = self.model(
|
259 |
audio,
|
260 |
language=self.original_language,
|
@@ -262,7 +271,6 @@ class MLXWhisper(ASRBase):
|
|
262 |
word_timestamps=True,
|
263 |
condition_on_previous_text=True,
|
264 |
path_or_hf_repo=self.model_size_or_path,
|
265 |
-
**self.transcribe_kargs,
|
266 |
)
|
267 |
return segments.get("segments", [])
|
268 |
|
@@ -844,7 +852,7 @@ def add_shared_args(parser):
|
|
844 |
parser.add_argument(
|
845 |
"--model",
|
846 |
type=str,
|
847 |
-
default="
|
848 |
choices="tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo".split(
|
849 |
","
|
850 |
),
|
@@ -879,14 +887,14 @@ def add_shared_args(parser):
|
|
879 |
parser.add_argument(
|
880 |
"--backend",
|
881 |
type=str,
|
882 |
-
default="
|
883 |
choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],
|
884 |
help="Load only this backend for Whisper processing.",
|
885 |
)
|
886 |
parser.add_argument(
|
887 |
"--vac",
|
888 |
action="store_true",
|
889 |
-
default=
|
890 |
help="Use VAC = voice activity controller. Recommended. Requires torch.",
|
891 |
)
|
892 |
parser.add_argument(
|
@@ -895,7 +903,7 @@ def add_shared_args(parser):
|
|
895 |
parser.add_argument(
|
896 |
"--vad",
|
897 |
action="store_true",
|
898 |
-
default=
|
899 |
help="Use VAD = voice activity detection, with the default parameters.",
|
900 |
)
|
901 |
parser.add_argument(
|
@@ -1006,8 +1014,9 @@ if __name__ == "__main__":
|
|
1006 |
|
1007 |
parser = argparse.ArgumentParser()
|
1008 |
parser.add_argument(
|
1009 |
-
"audio_path",
|
1010 |
type=str,
|
|
|
1011 |
help="Filename of 16kHz mono channel wav, on which live streaming is simulated.",
|
1012 |
)
|
1013 |
add_shared_args(parser)
|
|
|
201 |
model_dir (str, optional): Direct path to a custom model directory.
|
202 |
If specified, it overrides the `modelsize` parameter.
|
203 |
"""
|
204 |
+
from mlx_whisper.transcribe import ModelHolder, transcribe
|
205 |
+
import mlx.core as mx
|
206 |
|
207 |
if model_dir is not None:
|
208 |
logger.debug(
|
|
|
216 |
)
|
217 |
|
218 |
self.model_size_or_path = model_size_or_path
|
219 |
+
|
220 |
+
# In mlx_whisper.transcribe, dtype is defined as:
|
221 |
+
# dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
|
222 |
+
# Since we do not use decode_options in self.transcribe, we will set dtype to mx.float16
|
223 |
+
dtype = mx.float16
|
224 |
+
ModelHolder.get_model(model_size_or_path, dtype)
|
225 |
return transcribe
|
226 |
|
227 |
def translate_model_name(self, model_name):
|
|
|
262 |
)
|
263 |
|
264 |
def transcribe(self, audio, init_prompt=""):
|
265 |
+
if self.transcribe_kargs:
|
266 |
+
logger.warning("Transcribe kwargs (vad, task) are not compatible with MLX Whisper and will be ignored.")
|
267 |
segments = self.model(
|
268 |
audio,
|
269 |
language=self.original_language,
|
|
|
271 |
word_timestamps=True,
|
272 |
condition_on_previous_text=True,
|
273 |
path_or_hf_repo=self.model_size_or_path,
|
|
|
274 |
)
|
275 |
return segments.get("segments", [])
|
276 |
|
|
|
852 |
parser.add_argument(
|
853 |
"--model",
|
854 |
type=str,
|
855 |
+
default="tiny",
|
856 |
choices="tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo".split(
|
857 |
","
|
858 |
),
|
|
|
887 |
parser.add_argument(
|
888 |
"--backend",
|
889 |
type=str,
|
890 |
+
default="mlx-whisper",
|
891 |
choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],
|
892 |
help="Load only this backend for Whisper processing.",
|
893 |
)
|
894 |
parser.add_argument(
|
895 |
"--vac",
|
896 |
action="store_true",
|
897 |
+
default=True,
|
898 |
help="Use VAC = voice activity controller. Recommended. Requires torch.",
|
899 |
)
|
900 |
parser.add_argument(
|
|
|
903 |
parser.add_argument(
|
904 |
"--vad",
|
905 |
action="store_true",
|
906 |
+
default=True,
|
907 |
help="Use VAD = voice activity detection, with the default parameters.",
|
908 |
)
|
909 |
parser.add_argument(
|
|
|
1014 |
|
1015 |
parser = argparse.ArgumentParser()
|
1016 |
parser.add_argument(
|
1017 |
+
"--audio_path",
|
1018 |
type=str,
|
1019 |
+
default='samples_jfk.wav',
|
1020 |
help="Filename of 16kHz mono channel wav, on which live streaming is simulated.",
|
1021 |
)
|
1022 |
add_shared_args(parser)
|