qfuxa committed
Commit f884d11 · 1 Parent(s): 0ff6067

warning when transcribe_kargs are used with MLX Whisper

Files changed (1): whisper_online.py (+16 -7)
whisper_online.py CHANGED
@@ -201,7 +201,8 @@ class MLXWhisper(ASRBase):
             model_dir (str, optional): Direct path to a custom model directory.
                 If specified, it overrides the `modelsize` parameter.
         """
-        from mlx_whisper import transcribe
+        from mlx_whisper.transcribe import ModelHolder, transcribe
+        import mlx.core as mx

         if model_dir is not None:
             logger.debug(
@@ -215,6 +216,12 @@ class MLXWhisper(ASRBase):
         )

         self.model_size_or_path = model_size_or_path
+
+        # In mlx_whisper.transcribe, dtype is defined as:
+        # dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
+        # Since we do not use decode_options in self.transcribe, we will set dtype to mx.float16
+        dtype = mx.float16
+        ModelHolder.get_model(model_size_or_path, dtype)
         return transcribe

     def translate_model_name(self, model_name):
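Aside: the new comment quotes mlx_whisper's own dtype resolution, and the eager ModelHolder.get_model call warms the same cache that transcribe() reads from, so the first streamed chunk does not pay the model-load cost. A minimal sketch of both points, assuming current mlx_whisper caching behavior (the repo id below is illustrative, not taken from this commit):

    import mlx.core as mx
    from mlx_whisper.transcribe import ModelHolder

    # Same resolution mlx_whisper performs internally; since this backend
    # passes no decode_options, fp16 is always the effective dtype.
    decode_options = {}
    dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
    assert dtype == mx.float16

    # ModelHolder keeps the last loaded model; a repeated call with the same
    # path returns the cached instance instead of reloading the weights.
    model = ModelHolder.get_model("mlx-community/whisper-tiny-mlx", dtype)  # illustrative repo id
    cached = ModelHolder.get_model("mlx-community/whisper-tiny-mlx", dtype)
    assert model is cached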
@@ -255,6 +262,8 @@ class MLXWhisper(ASRBase):
         )

     def transcribe(self, audio, init_prompt=""):
+        if self.transcribe_kargs:
+            logger.warning("Transcribe kwargs (vad, task) are not compatible with MLX Whisper and will be ignored.")
         segments = self.model(
             audio,
             language=self.original_language,
@@ -262,7 +271,6 @@ class MLXWhisper(ASRBase):
             word_timestamps=True,
             condition_on_previous_text=True,
             path_or_hf_repo=self.model_size_or_path,
-            **self.transcribe_kargs,
         )
         return segments.get("segments", [])

@@ -844,7 +852,7 @@ def add_shared_args(parser):
     parser.add_argument(
         "--model",
         type=str,
-        default="large-v2",
+        default="tiny",
         choices="tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo".split(
             ","
         ),
@@ -879,14 +887,14 @@
     parser.add_argument(
         "--backend",
         type=str,
-        default="faster-whisper",
+        default="mlx-whisper",
         choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],
         help="Load only this backend for Whisper processing.",
     )
     parser.add_argument(
         "--vac",
         action="store_true",
-        default=False,
+        default=True,
         help="Use VAC = voice activity controller. Recommended. Requires torch.",
     )
     parser.add_argument(
@@ -895,7 +903,7 @@
     parser.add_argument(
         "--vad",
         action="store_true",
-        default=False,
+        default=True,
         help="Use VAD = voice activity detection, with the default parameters.",
     )
     parser.add_argument(
@@ -1006,8 +1014,9 @@ if __name__ == "__main__":

     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "audio_path",
+        "--audio_path",
         type=str,
+        default='samples_jfk.wav',
         help="Filename of 16kHz mono channel wav, on which live streaming is simulated.",
     )
     add_shared_args(parser)
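Taken together, the new defaults let the simulation run out of the box on Apple silicon: audio_path becomes an optional flag with a bundled sample as fallback, so the two invocations below should be equivalent (assuming samples_jfk.wav sits next to the script):

    python whisper_online.py
    python whisper_online.py --audio_path samples_jfk.wav --model tiny --backend mlx-whisper

One side effect worth noting: with action="store_true" combined with default=True, the --vac and --vad flags can no longer be switched off from the command line; passing them is now a no-op.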
 