qfuxa committed
Commit 87cab7c · 1 Parent(s): e6648e4

add whisper mlx backend

Files changed (1):
  1. whisper_online.py (+60 -1)
whisper_online.py CHANGED
@@ -156,6 +156,63 @@ class FasterWhisperASR(ASRBase):
     def set_translate_task(self):
         self.transcribe_kargs["task"] = "translate"
 
+class MLXWhisper(ASRBase):
+    """
+    Uses the MLX Whisper library as the backend, optimized for Apple Silicon.
+    Models available: https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc
+    Significantly faster than faster-whisper (without CUDA) on Apple M1. Default model: mlx-community/whisper-large-v3-mlx.
+    """
+
+    sep = " "
+
+    def load_model(self, modelsize=None, cache_dir=None, model_dir=None):  # cache_dir accepted for ASRBase compatibility, unused here
+        from mlx_whisper import transcribe
+
+        if model_dir is not None:
+            logger.debug(f"Loading whisper model from model_dir {model_dir}. The modelsize parameter is not used.")
+            model_size_or_path = model_dir
+        elif modelsize is not None:
+            logger.debug(f"Loading whisper model {modelsize}. You are using mlx-whisper, so make sure the model is MLX-compatible.")
+            model_size_or_path = modelsize
+        else:
+            logger.debug("No model size or path specified. Using mlx-community/whisper-large-v3-mlx.")
+            model_size_or_path = "mlx-community/whisper-large-v3-mlx"
+
+        self.model_size_or_path = model_size_or_path
+        return transcribe
+
+    def transcribe(self, audio, init_prompt=""):
+        segments = self.model(
+            audio,
+            language=self.original_language,
+            initial_prompt=init_prompt,
+            word_timestamps=True,
+            condition_on_previous_text=True,
+            path_or_hf_repo=self.model_size_or_path,
+            **self.transcribe_kargs
+        )
+        return segments.get("segments", [])
+
+
+    def ts_words(self, segments):
+        """
+        Extracts timestamped words from transcription segments, skipping words in segments with a high no-speech probability.
+        """
+        return [
+            (word["start"], word["end"], word["word"])
+            for segment in segments
+            for word in segment.get("words", [])
+            if segment.get("no_speech_prob", 0) <= 0.9
+        ]
+
+    def segments_end_ts(self, res):
+        return [s['end'] for s in res]
+
+    def use_vad(self):
+        self.transcribe_kargs["vad_filter"] = True
+
+    def set_translate_task(self):
+        self.transcribe_kargs["task"] = "translate"
 
 class OpenaiApiASR(ASRBase):
     """Uses OpenAI's Whisper API for audio transcription."""
@@ -660,7 +717,7 @@ def add_shared_args(parser):
     parser.add_argument('--model_dir', type=str, default=None, help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.")
     parser.add_argument('--lan', '--language', type=str, default='auto', help="Source language code, e.g. en,de,cs, or 'auto' for language detection.")
     parser.add_argument('--task', type=str, default='transcribe', choices=["transcribe","translate"],help="Transcribe or translate.")
-    parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped", "openai-api"],help='Load only this backend for Whisper processing.')
+    parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],help='Load only this backend for Whisper processing.')
     parser.add_argument('--vac', action="store_true", default=False, help='Use VAC = voice activity controller. Recommended. Requires torch.')
     parser.add_argument('--vac-chunk-size', type=float, default=0.04, help='VAC sample size in seconds.')
     parser.add_argument('--vad', action="store_true", default=False, help='Use VAD = voice activity detection, with the default parameters.')
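
With this change, the MLX backend is selectable from the command line, e.g. "python whisper_online.py audio.wav --backend mlx-whisper --lan en" (the positional audio argument is assumed from the script's existing interface). A quick check that the new choice round-trips through the parser, assuming whisper_online.py and its dependencies are importable:

    import argparse
    from whisper_online import add_shared_args

    parser = argparse.ArgumentParser()
    add_shared_args(parser)  # registers --backend with the new "mlx-whisper" choice
    args = parser.parse_args(["--backend", "mlx-whisper"])
    assert args.backend == "mlx-whisper"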
@@ -679,6 +736,8 @@ def asr_factory(args, logfile=sys.stderr):
     else:
         if backend == "faster-whisper":
             asr_cls = FasterWhisperASR
+        elif backend == "mlx-whisper":
+            asr_cls = MLXWhisper
         else:
             asr_cls = WhisperTimestampedASR
 
743