Dominik Macháček committed
Commit 88dc796 · 1 Parent(s): b1878ce

model_dir, vad and other updates

Files changed (1):
  1. whisper_online.py +141 -103
whisper_online.py CHANGED
@@ -22,18 +22,23 @@ def load_audio_chunk(fname, beg, end):

class ASRBase:

    sep = " "

-    def __init__(self, modelsize, lan, cache_dir):
        self.original_language = lan

-        self.model = self.load_model(modelsize, cache_dir)

    def load_model(self, modelsize, cache_dir):
-        raise NotImplemented("mus be implemented in the child class")

    def transcribe(self, audio, init_prompt=""):
-        raise NotImplemented("mus be implemented in the child class")


## requires imports:
@@ -49,7 +54,9 @@ class WhisperTimestampedASR(ASRBase):
    import whisper_timestamped
    """

-    def load_model(self, modelsize, cache_dir):
        return whisper.load_model(modelsize, download_root=cache_dir)

    def transcribe(self, audio, init_prompt=""):
@@ -68,6 +75,9 @@ class WhisperTimestampedASR(ASRBase):
    def segments_end_ts(self, res):
        return [s["end"] for s in res["segments"]]


class FasterWhisperASR(ASRBase):
    """Uses faster-whisper library as the backend. Works much faster, appx 4-times (in offline mode). For GPU, it requires installation with a specific CUDNN version.
@@ -78,11 +88,19 @@ class FasterWhisperASR(ASRBase):

    sep = ""

-    def load_model(self, modelsize, cache_dir):
-        # cache_dir is not set, it seemed not working. Default ~/.cache/huggingface/hub is used.

        # this worked fast and reliably on NVIDIA L40
-        model = WhisperModel(modelsize, device="cuda", compute_type="float16")

        # or run on GPU with INT8
        # tested: the transcripts were different, probably worse than with FP16, and it was slightly (appx 20%) slower
@@ -90,12 +108,12 @@ class FasterWhisperASR(ASRBase):

        # or run on CPU with INT8
        # tested: works, but slow, appx 10-times than cuda FP16
-        #model = WhisperModel(model_size, device="cpu", compute_type="int8") #, download_root="faster-disk-cache-dir/")
        return model

    def transcribe(self, audio, init_prompt=""):
-        wt = False
-        segments, info = self.model.transcribe(audio, language=self.original_language, initial_prompt=init_prompt, beam_size=5, word_timestamps=True, condition_on_previous_text=True)
        return list(segments)

    def ts_words(self, segments):
@@ -111,6 +129,12 @@ class FasterWhisperASR(ASRBase):
    def segments_end_ts(self, res):
        return [s.end for s in res]


class HypothesisBuffer:
@@ -225,7 +249,7 @@ class OnlineASRProcessor:
            l += len(x)+1
            prompt.append(x)
        non_prompt = self.commited[k:]
-        return " ".join(prompt[::-1]), " ".join(t for _,_,t in non_prompt)

    def process_iter(self):
        """Runs on the current audio buffer.
@@ -398,92 +422,88 @@ class OnlineASRProcessor:

## main:

-import argparse
-parser = argparse.ArgumentParser()
-parser.add_argument('audio_path', type=str, help="Filename of 16kHz mono channel wav, on which live streaming is simulated.")
-parser.add_argument('--min-chunk-size', type=float, default=1.0, help='Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.')
-parser.add_argument('--model', type=str, default='large-v2', help="name of the Whisper model to use (default: large-v2, options: {tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large}")
-parser.add_argument('--model_dir', type=str, default='disk-cache-dir', help="the path where Whisper models are saved (or downloaded to). Default: ./disk-cache-dir")
-parser.add_argument('--lan', '--language', type=str, default='en', help="Language code for transcription, e.g. en,de,cs.")
-parser.add_argument('--start_at', type=float, default=0.0, help='Start processing audio at this time.')
-parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped"],help='Load only this backend for Whisper processing.')
-parser.add_argument('--offline', action="store_true", default=False, help='Offline mode.')
-args = parser.parse_args()
-
-audio_path = args.audio_path
-
-SAMPLING_RATE = 16000
-duration = len(load_audio(audio_path))/SAMPLING_RATE
-print("Audio duration is: %2.2f seconds" % duration, file=sys.stderr)
-
-size = args.model
-language = args.lan
-
-t = time.time()
-print(f"Loading Whisper {size} model for {language}...",file=sys.stderr,end=" ",flush=True)
-#asr = WhisperASR(lan=language, modelsize=size)
-
-if args.backend == "faster-whisper":
-    from faster_whisper import WhisperModel
-    asr_cls = FasterWhisperASR
-else:
-    import whisper
-    import whisper_timestamped
-    # from whisper_timestamped_model import WhisperTimestampedASR
-    asr_cls = WhisperTimestampedASR
-
-asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_dir)
-e = time.time()
-print(f"done. It took {round(e-t,2)} seconds.",file=sys.stderr)
-
-
-min_chunk = args.min_chunk_size
-online = OnlineASRProcessor(language,asr,min_chunk)
-
-
-# load the audio into the LRU cache before we start the timer
-a = load_audio_chunk(audio_path,0,1)
-
-# warm up the ASR, because the very first transcribe takes much more time than the other
-asr.transcribe(a)
-
-beg = args.start_at
-start = time.time()-beg
-
-def output_transcript(o):
-    # output format in stdout is like:
-    # 4186.3606 0 1720 Takhle to je
-    # - the first three words are:
-    # - emission time from beginning of processing, in milliseconds
-    # - beg and end timestamp of the text segment, as estimated by Whisper model. The timestamps are not accurate, but they're useful anyway
-    # - the next words: segment transcript
-    now = time.time()-start
-    if o[0] is not None:
-        print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),flush=True)
-    else:
-        print(o,file=sys.stderr,flush=True)
-
-if args.offline: ## offline mode processing (for testing/debugging)
-    a = load_audio(audio_path)
-    online.insert_audio_chunk(a)
-    try:
-        o = online.process_iter()
-    except AssertionError:
-        print("assertion error",file=sys.stderr)
-        pass
    else:
-        output_transcript(o)
-else: # online = simultaneous mode
-    end = 0
-    while True:
-        now = time.time() - start
-        if now < end+min_chunk:
-            time.sleep(min_chunk+end-now)
-        end = time.time() - start
-        a = load_audio_chunk(audio_path,beg,end)
-        beg = end
-        online.insert_audio_chunk(a)

        try:
            o = online.process_iter()
        except AssertionError:
@@ -491,13 +511,31 @@ else: # online = simultaneous mode
            pass
        else:
            output_transcript(o)
-        now = time.time() - start
-        print(f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}",file=sys.stderr)

-        print(file=sys.stderr,flush=True)

-        if end >= duration:
-            break

-o = online.finish()
-output_transcript(o)
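
The output_transcript comments in the hunk above define the stdout line format: the emission time from the start of processing in milliseconds, then the begin and end timestamps of the committed segment (also in milliseconds, as estimated by the Whisper model), then the transcript text. A minimal parsing sketch; the helper name parse_output_line is illustrative and not part of whisper_online.py:

# Illustrative helper (not in whisper_online.py): split one stdout line such as
# "4186.3606 0 1720 Takhle to je" into its four components.
def parse_output_line(line):
    emission_ms, beg_ms, end_ms, text = line.split(" ", 3)
    return float(emission_ms), float(beg_ms), float(end_ms), text

# parse_output_line("4186.3606 0 1720 Takhle to je")
# -> (4186.3606, 0.0, 1720.0, 'Takhle to je')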
 
@@ -22,18 +22,23 @@ def load_audio_chunk(fname, beg, end):

class ASRBase:

+    # join transcribe words with this character (" " for whisper_timestamped, "" for faster-whisper because it emits the spaces when needed)
    sep = " "

+    def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None):
+        self.transcribe_kargs = {}
        self.original_language = lan

+        self.model = self.load_model(modelsize, cache_dir, model_dir)

    def load_model(self, modelsize, cache_dir):
+        raise NotImplemented("must be implemented in the child class")

    def transcribe(self, audio, init_prompt=""):
+        raise NotImplemented("must be implemented in the child class")
+
+    def use_vad(self):
+        raise NotImplemented("must be implemented in the child class")


## requires imports:
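
The additions above make ASRBase a small backend interface: __init__ stores the language and an empty transcribe_kargs dict and delegates model loading, while load_model, transcribe and use_vad are meant to be overridden. The sep attribute is the string used to join transcribed words, " " for whisper_timestamped and "" for faster-whisper, whose words already carry their leading spaces. A minimal sketch of what a custom backend would override; the class name DummyASR and its internals are illustrative only and assume the ASRBase definition above:

class DummyASR(ASRBase):
    # words returned by this sketch backend are bare, so join them with spaces,
    # e.g. " ".join(["Hello,", "world"]); faster-whisper words instead look like
    # " Hello," and " world", hence its sep is "".
    sep = " "

    def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
        return None  # a real backend would load and return a Whisper model here

    def transcribe(self, audio, init_prompt=""):
        # a real backend would run the model here; self.transcribe_kargs carries
        # extra options such as vad_filter=True or task="translate"
        return []

    def use_vad(self):
        self.transcribe_kargs["vad_filter"] = True  # only if the backend supports it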
 
@@ -49,7 +54,9 @@ class WhisperTimestampedASR(ASRBase):
    import whisper_timestamped
    """

+    def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
+        if model_dir is not None:
+            print("ignoring model_dir, not implemented",file=sys.stderr)
        return whisper.load_model(modelsize, download_root=cache_dir)

    def transcribe(self, audio, init_prompt=""):
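
For the whisper_timestamped backend, load_model keeps resolving models by name: cache_dir is passed through as whisper.load_model's download_root, and an explicit model_dir is only warned about and ignored. A hedged construction example, assuming `import whisper` and `import whisper_timestamped` have been done at module level as the class docstring requires; the cache path is a placeholder:

# Illustrative only: the model is downloaded (or reused) under cache_dir,
# and transcription runs in the language given by lan.
asr = WhisperTimestampedASR(lan="en", modelsize="large-v2",
                            cache_dir="disk-cache-dir")  # placeholder path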
 
@@ -68,6 +75,9 @@ class WhisperTimestampedASR(ASRBase):
    def segments_end_ts(self, res):
        return [s["end"] for s in res["segments"]]

+    def use_vad(self):
+        raise NotImplemented("Feature use_vad is not implemented for whisper_timestamped backend.")
+

class FasterWhisperASR(ASRBase):
    """Uses faster-whisper library as the backend. Works much faster, appx 4-times (in offline mode). For GPU, it requires installation with a specific CUDNN version.
 
@@ -78,11 +88,19 @@ class FasterWhisperASR(ASRBase):

    sep = ""

+    def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
+
+        if model_dir is not None:
+            print(f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used.",file=sys.stderr)
+            model_size_or_path = model_dir
+        elif modelsize is not None:
+            model_size_or_path = modelsize
+        else:
+            raise ValueError("modelsize or model_dir parameter must be set")
+

        # this worked fast and reliably on NVIDIA L40
+        model = WhisperModel(model_size_or_path, device="cuda", compute_type="float16", download_root=cache_dir)

        # or run on GPU with INT8
        # tested: the transcripts were different, probably worse than with FP16, and it was slightly (appx 20%) slower
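
The new load_model for the faster-whisper backend resolves the model in two ways: an explicit model_dir (a local directory holding model.bin and the other converted files) takes precedence and is passed to WhisperModel directly, otherwise modelsize names a model to download, with cache_dir forwarded as download_root; if neither is given, it raises ValueError. A hedged sketch of both paths, assuming `from faster_whisper import WhisperModel` has been executed at module level as the __main__ block below does; the paths are placeholders:

# Illustrative only: download by name, caching under a custom directory.
asr = FasterWhisperASR(lan="en", modelsize="large-v2", cache_dir="disk-cache-dir")

# Illustrative only: load a locally converted model; modelsize and cache_dir
# are then ignored, as the warning printed by load_model says.
asr = FasterWhisperASR(lan="en", model_dir="path/to/whisper-large-v2-ct2")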
 
@@ -90,12 +108,12 @@ class FasterWhisperASR(ASRBase):

        # or run on CPU with INT8
        # tested: works, but slow, appx 10-times than cuda FP16
+        # model = WhisperModel(modelsize, device="cpu", compute_type="int8") #, download_root="faster-disk-cache-dir/")
        return model

    def transcribe(self, audio, init_prompt=""):
+        # tested: beam_size=5 is faster and better than 1 (on one 200 second document from En ESIC, min chunk 0.01)
+        segments, info = self.model.transcribe(audio, language=self.original_language, initial_prompt=init_prompt, beam_size=5, word_timestamps=True, condition_on_previous_text=True, **self.transcribe_kargs)
        return list(segments)

    def ts_words(self, segments):
 
@@ -111,6 +129,12 @@ class FasterWhisperASR(ASRBase):
    def segments_end_ts(self, res):
        return [s.end for s in res]

+    def use_vad(self):
+        self.transcribe_kargs["vad_filter"] = True
+
+    def set_translate_task(self):
+        self.transcribe_kargs["task"] = "translate"
+


class HypothesisBuffer:
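
use_vad and set_translate_task are the commit's two new switches for the faster-whisper backend: they only record vad_filter=True and task="translate" in self.transcribe_kargs, and transcribe forwards both to model.transcribe via **self.transcribe_kargs. A minimal usage sketch, assuming the classes above; "audio.wav" is a placeholder file name:

# Illustrative only.
asr = FasterWhisperASR(lan="en", modelsize="large-v2")
asr.use_vad()              # transcribe_kargs["vad_filter"] = True
asr.set_translate_task()   # transcribe_kargs["task"] = "translate"

audio = load_audio_chunk("audio.wav", 0, 10)   # placeholder file, first 10 seconds
# The next call effectively runs
# self.model.transcribe(audio, ..., vad_filter=True, task="translate")
segments = asr.transcribe(audio, init_prompt="")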
 
@@ -225,7 +249,7 @@ class OnlineASRProcessor:
            l += len(x)+1
            prompt.append(x)
        non_prompt = self.commited[k:]
+        return self.asr.sep.join(prompt[::-1]), self.asr.sep.join(t for _,_,t in non_prompt)

    def process_iter(self):
        """Runs on the current audio buffer.
 
@@ -398,92 +422,88 @@ class OnlineASRProcessor:

## main:

+if __name__ == "__main__":
+
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument('audio_path', type=str, help="Filename of 16kHz mono channel wav, on which live streaming is simulated.")
+    parser.add_argument('--min-chunk-size', type=float, default=1.0, help='Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.')
+    parser.add_argument('--model', type=str, default='large-v2', choices="tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large".split(","), help="Name size of the Whisper model to use (default: large-v2). The model is automatically downloaded from the model hub if not present in model cache dir.")
+    parser.add_argument('--model_cache_dir', type=str, default=None, help="Overriding the default model cache dir where models downloaded from the hub are saved")
+    parser.add_argument('--model_dir', type=str, default=None, help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.")
+    parser.add_argument('--lan', '--language', type=str, default='en', help="Language code for transcription, e.g. en,de,cs.")
+    parser.add_argument('--task', type=str, default='transcribe', choices=["transcribe","translate"], help="Transcribe or translate.")
+    parser.add_argument('--start_at', type=float, default=0.0, help='Start processing audio at this time.')
+    parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped"], help='Load only this backend for Whisper processing.')
+    parser.add_argument('--offline', action="store_true", default=False, help='Offline mode.')
+    parser.add_argument('--vad', action="store_true", default=False, help='Use VAD = voice activity detection, with the default parameters.')
+    args = parser.parse_args()
+
+    audio_path = args.audio_path
+
+    SAMPLING_RATE = 16000
+    duration = len(load_audio(audio_path))/SAMPLING_RATE
+    print("Audio duration is: %2.2f seconds" % duration, file=sys.stderr)
+
+    size = args.model
+    language = args.lan
+
+    t = time.time()
+    print(f"Loading Whisper {size} model for {language}...",file=sys.stderr,end=" ",flush=True)
+    #asr = WhisperASR(lan=language, modelsize=size)
+
+    if args.backend == "faster-whisper":
+        from faster_whisper import WhisperModel
+        asr_cls = FasterWhisperASR
    else:
+        import whisper
+        import whisper_timestamped
+        # from whisper_timestamped_model import WhisperTimestampedASR
+        asr_cls = WhisperTimestampedASR
+
+    asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
+
+    if args.task == "translate":
+        asr.set_translate_task()
+
+
+    e = time.time()
+    print(f"done. It took {round(e-t,2)} seconds.",file=sys.stderr)
+
+    if args.vad:
+        print("setting VAD filter",file=sys.stderr)
+        asr.use_vad()
+
+    min_chunk = args.min_chunk_size
+    online = OnlineASRProcessor(language,asr,min_chunk)
+

+    # load the audio into the LRU cache before we start the timer
+    a = load_audio_chunk(audio_path,0,1)
+
+    # warm up the ASR, because the very first transcribe takes much more time than the other
+    asr.transcribe(a)
+
+    beg = args.start_at
+    start = time.time()-beg
+
+    def output_transcript(o):
+        # output format in stdout is like:
+        # 4186.3606 0 1720 Takhle to je
+        # - the first three words are:
+        # - emission time from beginning of processing, in milliseconds
+        # - beg and end timestamp of the text segment, as estimated by Whisper model. The timestamps are not accurate, but they're useful anyway
+        # - the next words: segment transcript
+        now = time.time()-start
+        if o[0] is not None:
+            print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),file=sys.stderr,flush=True)
+            print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),flush=True)
+        else:
+            print(o,file=sys.stderr,flush=True)
+
+    if args.offline: ## offline mode processing (for testing/debugging)
+        a = load_audio(audio_path)
+        online.insert_audio_chunk(a)
        try:
            o = online.process_iter()
        except AssertionError:
 
@@ -491,13 +511,31 @@ else: # online = simultaneous mode
            pass
        else:
            output_transcript(o)
+    else: # online = simultaneous mode
+        end = 0
+        while True:
+            now = time.time() - start
+            if now < end+min_chunk:
+                time.sleep(min_chunk+end-now)
+            end = time.time() - start
+            a = load_audio_chunk(audio_path,beg,end)
+            beg = end
+            online.insert_audio_chunk(a)
+
+            try:
+                o = online.process_iter()
+            except AssertionError:
+                print("assertion error",file=sys.stderr)
+                pass
+            else:
+                output_transcript(o)
+            now = time.time() - start
+            print(f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}",file=sys.stderr)

+            print(file=sys.stderr,flush=True)

+            if end >= duration:
+                break

+    o = online.finish()
+    output_transcript(o)
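
The __main__ block above simulates live streaming from a wav file, but the same objects can be driven directly from any audio source. A minimal hedged sketch, assuming the classes defined above; audio_chunks() is a hypothetical generator of 16 kHz mono float32 chunks, and the first element of the returned tuple may be None when nothing has been committed yet:

# Illustrative only: using OnlineASRProcessor as a library.
asr = FasterWhisperASR(lan="en", modelsize="large-v2")
online = OnlineASRProcessor("en", asr, 1.0)   # language, backend, min chunk in seconds

for chunk in audio_chunks():                  # hypothetical audio source
    online.insert_audio_chunk(chunk)
    o = online.process_iter()                 # may raise AssertionError, as the loop above guards against
    if o[0] is not None:
        print("%1.0f %1.0f %s" % (o[0]*1000, o[1]*1000, o[2]))

o = online.finish()                           # flush whatever is still buffered
if o[0] is not None:
    print("%1.0f %1.0f %s" % (o[0]*1000, o[1]*1000, o[2]))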