Dominik Macháček committed on
Commit
8f32dea
·
1 Parent(s): bd0d848

logfile reviewed, whisper_timestamped loading module and vad

Browse files
Files changed (1) hide show
  1. whisper_online.py +33 -20
whisper_online.py CHANGED
@@ -26,12 +26,15 @@ class ASRBase:
26
  sep = " " # join transcribe words with this character (" " for whisper_timestamped,
27
  # "" for faster-whisper because it emits the spaces when neeeded)
28
 
29
- def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None):
 
 
30
  self.transcribe_kargs = {}
31
  self.original_language = lan
32
 
33
  self.model = self.load_model(modelsize, cache_dir, model_dir)
34
 
 
35
  def load_model(self, modelsize, cache_dir):
36
  raise NotImplemented("must be implemented in the child class")
37
 
@@ -50,15 +53,18 @@ class WhisperTimestampedASR(ASRBase):
50
  sep = " "
51
 
52
  def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
53
- global whisper_timestamped # has to be global as it is used at each `transcribe` call
54
  import whisper
55
- import whisper_timestamped
 
56
  if model_dir is not None:
57
  print("ignoring model_dir, not implemented",file=self.logfile)
58
  return whisper.load_model(modelsize, download_root=cache_dir)
59
 
60
  def transcribe(self, audio, init_prompt=""):
61
- result = whisper_timestamped.transcribe_timestamped(self.model, audio, language=self.original_language, initial_prompt=init_prompt, verbose=None, condition_on_previous_text=True)
 
 
 
62
  return result
63
 
64
  def ts_words(self,r):
@@ -74,7 +80,12 @@ class WhisperTimestampedASR(ASRBase):
74
  return [s["end"] for s in res["segments"]]
75
 
76
  def use_vad(self):
77
- raise NotImplemented("Feature use_vad is not implemented for whisper_timestamped backend.")
 
 
 
 
 
78
 
79
 
80
  class FasterWhisperASR(ASRBase):
@@ -135,7 +146,6 @@ class FasterWhisperASR(ASRBase):
135
  class HypothesisBuffer:
136
 
137
  def __init__(self, logfile=sys.stderr):
138
- """output: where to store the log. Leave it unchanged to print to terminal."""
139
  self.commited_in_buffer = []
140
  self.buffer = []
141
  self.new = []
@@ -205,7 +215,7 @@ class OnlineASRProcessor:
205
  def __init__(self, asr, tokenizer, logfile=sys.stderr):
206
  """asr: WhisperASR object
207
  tokenizer: sentence tokenizer object for the target language. Must have a method *split* that behaves like the one of MosesTokenizer.
208
- output: where to store the log. Leave it unchanged to print to terminal.
209
  """
210
  self.asr = asr
211
  self.tokenizer = tokenizer
@@ -468,21 +478,24 @@ if __name__ == "__main__":
468
  parser.add_argument('--vad', action="store_true", default=False, help='Use VAD = voice activity detection, with the default parameters.')
469
  args = parser.parse_args()
470
 
 
 
 
471
  if args.offline and args.comp_unaware:
472
- print("No or one option from --offline and --comp_unaware are available, not both. Exiting.",file=sys.stderr)
473
  sys.exit(1)
474
 
475
  audio_path = args.audio_path
476
 
477
  SAMPLING_RATE = 16000
478
  duration = len(load_audio(audio_path))/SAMPLING_RATE
479
- print("Audio duration is: %2.2f seconds" % duration, file=sys.stderr)
480
 
481
  size = args.model
482
  language = args.lan
483
 
484
  t = time.time()
485
- print(f"Loading Whisper {size} model for {language}...",file=sys.stderr,end=" ",flush=True)
486
 
487
  if args.backend == "faster-whisper":
488
  asr_cls = FasterWhisperASR
@@ -499,15 +512,15 @@ if __name__ == "__main__":
499
 
500
 
501
  e = time.time()
502
- print(f"done. It took {round(e-t,2)} seconds.",file=sys.stderr)
503
 
504
  if args.vad:
505
- print("setting VAD filter",file=sys.stderr)
506
  asr.use_vad()
507
 
508
 
509
  min_chunk = args.min_chunk_size
510
- online = OnlineASRProcessor(asr,create_tokenizer(tgt_language))
511
 
512
 
513
  # load the audio into the LRU cache before we start the timer
@@ -529,10 +542,10 @@ if __name__ == "__main__":
529
  if now is None:
530
  now = time.time()-start
531
  if o[0] is not None:
532
- print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),file=sys.stderr,flush=True)
533
  print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),flush=True)
534
  else:
535
- print(o,file=sys.stderr,flush=True)
536
 
537
  if args.offline: ## offline mode processing (for testing/debugging)
538
  a = load_audio(audio_path)
@@ -540,7 +553,7 @@ if __name__ == "__main__":
540
  try:
541
  o = online.process_iter()
542
  except AssertionError:
543
- print("assertion error",file=sys.stderr)
544
  pass
545
  else:
546
  output_transcript(o)
@@ -553,12 +566,12 @@ if __name__ == "__main__":
553
  try:
554
  o = online.process_iter()
555
  except AssertionError:
556
- print("assertion error",file=sys.stderr)
557
  pass
558
  else:
559
  output_transcript(o, now=end)
560
 
561
- print(f"## last processed {end:.2f}s",file=sys.stderr,flush=True)
562
 
563
  beg = end
564
  end += min_chunk
@@ -580,12 +593,12 @@ if __name__ == "__main__":
580
  try:
581
  o = online.process_iter()
582
  except AssertionError:
583
- print("assertion error",file=sys.stderr)
584
  pass
585
  else:
586
  output_transcript(o)
587
  now = time.time() - start
588
- print(f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}",file=sys.stderr,flush=True)
589
 
590
  if end >= duration:
591
  break
 
26
  sep = " " # join transcribe words with this character (" " for whisper_timestamped,
27
  # "" for faster-whisper because it emits the spaces when neeeded)
28
 
29
+ def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
30
+ self.logfile = logfile
31
+
32
  self.transcribe_kargs = {}
33
  self.original_language = lan
34
 
35
  self.model = self.load_model(modelsize, cache_dir, model_dir)
36
 
37
+
38
  def load_model(self, modelsize, cache_dir):
39
  raise NotImplemented("must be implemented in the child class")
40
 
 
53
  sep = " "
54
 
55
  def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
 
56
  import whisper
57
+ from whisper_timestamped import transcribe_timestamped
58
+ self.transcribe_timestamped = transcribe_timestamped
59
  if model_dir is not None:
60
  print("ignoring model_dir, not implemented",file=self.logfile)
61
  return whisper.load_model(modelsize, download_root=cache_dir)
62
 
63
  def transcribe(self, audio, init_prompt=""):
64
+ result = self.transcribe_timestamped(self.model,
65
+ audio, language=self.original_language,
66
+ initial_prompt=init_prompt, verbose=None,
67
+ condition_on_previous_text=True, **self.transcribe_kargs)
68
  return result
69
 
70
  def ts_words(self,r):
 
80
  return [s["end"] for s in res["segments"]]
81
 
82
  def use_vad(self):
83
+ self.transcribe_kargs["vad"] = True
84
+
85
+ def set_translate_task(self):
86
+ self.transcribe_kargs["task"] = "translate"
87
+
88
+
89
 
90
 
91
  class FasterWhisperASR(ASRBase):
 
146
  class HypothesisBuffer:
147
 
148
  def __init__(self, logfile=sys.stderr):
 
149
  self.commited_in_buffer = []
150
  self.buffer = []
151
  self.new = []
 
215
  def __init__(self, asr, tokenizer, logfile=sys.stderr):
216
  """asr: WhisperASR object
217
  tokenizer: sentence tokenizer object for the target language. Must have a method *split* that behaves like the one of MosesTokenizer.
218
+ logfile: where to store the log.
219
  """
220
  self.asr = asr
221
  self.tokenizer = tokenizer
 
478
  parser.add_argument('--vad', action="store_true", default=False, help='Use VAD = voice activity detection, with the default parameters.')
479
  args = parser.parse_args()
480
 
481
+ # reset to store stderr to different file stream, e.g. open(os.devnull,"w")
482
+ logfile = sys.stderr
483
+
484
  if args.offline and args.comp_unaware:
485
+ print("No or one option from --offline and --comp_unaware are available, not both. Exiting.",file=logfile)
486
  sys.exit(1)
487
 
488
  audio_path = args.audio_path
489
 
490
  SAMPLING_RATE = 16000
491
  duration = len(load_audio(audio_path))/SAMPLING_RATE
492
+ print("Audio duration is: %2.2f seconds" % duration, file=logfile)
493
 
494
  size = args.model
495
  language = args.lan
496
 
497
  t = time.time()
498
+ print(f"Loading Whisper {size} model for {language}...",file=logfile,end=" ",flush=True)
499
 
500
  if args.backend == "faster-whisper":
501
  asr_cls = FasterWhisperASR
 
512
 
513
 
514
  e = time.time()
515
+ print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
516
 
517
  if args.vad:
518
+ print("setting VAD filter",file=logfile)
519
  asr.use_vad()
520
 
521
 
522
  min_chunk = args.min_chunk_size
523
+ online = OnlineASRProcessor(asr,create_tokenizer(tgt_language),logfile=logfile)
524
 
525
 
526
  # load the audio into the LRU cache before we start the timer
 
542
  if now is None:
543
  now = time.time()-start
544
  if o[0] is not None:
545
+ print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),file=logfile,flush=True)
546
  print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),flush=True)
547
  else:
548
+ print(o,file=logfile,flush=True)
549
 
550
  if args.offline: ## offline mode processing (for testing/debugging)
551
  a = load_audio(audio_path)
 
553
  try:
554
  o = online.process_iter()
555
  except AssertionError:
556
+ print("assertion error",file=logfile)
557
  pass
558
  else:
559
  output_transcript(o)
 
566
  try:
567
  o = online.process_iter()
568
  except AssertionError:
569
+ print("assertion error",file=logfile)
570
  pass
571
  else:
572
  output_transcript(o, now=end)
573
 
574
+ print(f"## last processed {end:.2f}s",file=logfile,flush=True)
575
 
576
  beg = end
577
  end += min_chunk
 
593
  try:
594
  o = online.process_iter()
595
  except AssertionError:
596
+ print("assertion error",file=logfile)
597
  pass
598
  else:
599
  output_transcript(o)
600
  now = time.time() - start
601
+ print(f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}",file=logfile,flush=True)
602
 
603
  if end >= duration:
604
  break