Dominik Macháček committed on
Commit
b1878ce
·
1 Parent(s): 8116b21

offline option

Browse files
Files changed (1) hide show
  1. whisper_online.py +54 -35
whisper_online.py CHANGED
@@ -22,6 +22,8 @@ def load_audio_chunk(fname, beg, end):
22
 
23
  class ASRBase:
24
 
 
 
25
  def __init__(self, modelsize, lan, cache_dir):
26
  self.original_language = lan
27
 
@@ -74,6 +76,8 @@ class FasterWhisperASR(ASRBase):
74
  import faster_whisper
75
  """
76
 
 
 
77
  def load_model(self, modelsize, cache_dir):
78
  # cache_dir is not set, it seemed not working. Default ~/.cache/huggingface/hub is used.
79
 
@@ -98,8 +102,8 @@ class FasterWhisperASR(ASRBase):
98
  o = []
99
  for segment in segments:
100
  for word in segment.words:
101
- # stripping the spaces
102
- w = word.word.strip()
103
  t = (word.start, word.end, w)
104
  o.append(t)
105
  return o
@@ -109,19 +113,6 @@ class FasterWhisperASR(ASRBase):
109
 
110
 
111
 
112
- def to_flush(sents, offset=0):
113
- # concatenates the timestamped words or sentences into one sequence that is flushed in one line
114
- # sents: [(beg1, end1, "sentence1"), ...] or [] if empty
115
- # return: (beg1,end-of-last-sentence,"concatenation of sentences") or (None, None, "") if empty
116
- t = " ".join(s[2] for s in sents)
117
- if len(sents) == 0:
118
- b = None
119
- e = None
120
- else:
121
- b = offset + sents[0][0]
122
- e = offset + sents[-1][1]
123
- return (b,e,t)
124
-
125
  class HypothesisBuffer:
126
 
127
  def __init__(self):
@@ -254,8 +245,8 @@ class OnlineASRProcessor:
254
  self.transcript_buffer.insert(tsw, self.buffer_time_offset)
255
  o = self.transcript_buffer.flush()
256
  self.commited.extend(o)
257
- print(">>>>COMPLETE NOW:",to_flush(o),file=sys.stderr,flush=True)
258
- print("INCOMPLETE:",to_flush(self.transcript_buffer.complete()),file=sys.stderr,flush=True)
259
 
260
  # there is a newly confirmed text
261
  if o:
@@ -301,7 +292,7 @@ class OnlineASRProcessor:
301
  #self.chunk_at(t)
302
 
303
  print(f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}",file=sys.stderr)
304
- return to_flush(o)
305
 
306
  def chunk_completed_sentence(self):
307
  if self.commited == []: return
@@ -383,11 +374,26 @@ class OnlineASRProcessor:
383
  Returns: the same format as self.process_iter()
384
  """
385
  o = self.transcript_buffer.complete()
386
- f = to_flush(o)
387
  print("last, noncommited:",f,file=sys.stderr)
388
  return f
389
 
390
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
391
 
392
 
393
  ## main:
@@ -401,6 +407,7 @@ parser.add_argument('--model_dir', type=str, default='disk-cache-dir', help="the
401
  parser.add_argument('--lan', '--language', type=str, default='en', help="Language code for transcription, e.g. en,de,cs.")
402
  parser.add_argument('--start_at', type=float, default=0.0, help='Start processing audio at this time.')
403
  parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped"],help='Load only this backend for Whisper processing.')
 
404
  args = parser.parse_args()
405
 
406
  audio_path = args.audio_path
@@ -440,6 +447,9 @@ a = load_audio_chunk(audio_path,0,1)
440
  # warm up the ASR, because the very first transcribe takes much more time than the other
441
  asr.transcribe(a)
442
 
 
 
 
443
  def output_transcript(o):
444
  # output format in stdout is like:
445
  # 4186.3606 0 1720 Takhle to je
@@ -453,18 +463,9 @@ def output_transcript(o):
453
  else:
454
  print(o,file=sys.stderr,flush=True)
455
 
456
- beg = args.start_at
457
- end = 0
458
- start = time.time()-beg
459
- while True:
460
- now = time.time() - start
461
- if now < end+min_chunk:
462
- time.sleep(min_chunk+end-now)
463
- end = time.time() - start
464
- a = load_audio_chunk(audio_path,beg,end)
465
- beg = end
466
  online.insert_audio_chunk(a)
467
-
468
  try:
469
  o = online.process_iter()
470
  except AssertionError:
@@ -472,13 +473,31 @@ while True:
472
  pass
473
  else:
474
  output_transcript(o)
475
- now = time.time() - start
476
- print(f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}",file=sys.stderr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
477
 
478
- print(file=sys.stderr,flush=True)
479
 
480
- if end >= duration:
481
- break
482
 
483
  o = online.finish()
484
  output_transcript(o)
 
22
 
23
  class ASRBase:
24
 
25
+ sep = " "
26
+
27
  def __init__(self, modelsize, lan, cache_dir):
28
  self.original_language = lan
29
 
 
76
  import faster_whisper
77
  """
78
 
79
+ sep = ""
80
+
81
  def load_model(self, modelsize, cache_dir):
82
  # cache_dir is not set, it seemed not working. Default ~/.cache/huggingface/hub is used.
83
 
 
102
  o = []
103
  for segment in segments:
104
  for word in segment.words:
105
+ # not stripping the spaces -- should not be merged with them!
106
+ w = word.word
107
  t = (word.start, word.end, w)
108
  o.append(t)
109
  return o
 
113
 
114
 
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  class HypothesisBuffer:
117
 
118
  def __init__(self):
 
245
  self.transcript_buffer.insert(tsw, self.buffer_time_offset)
246
  o = self.transcript_buffer.flush()
247
  self.commited.extend(o)
248
+ print(">>>>COMPLETE NOW:",self.to_flush(o),file=sys.stderr,flush=True)
249
+ print("INCOMPLETE:",self.to_flush(self.transcript_buffer.complete()),file=sys.stderr,flush=True)
250
 
251
  # there is a newly confirmed text
252
  if o:
 
292
  #self.chunk_at(t)
293
 
294
  print(f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}",file=sys.stderr)
295
+ return self.to_flush(o)
296
 
297
  def chunk_completed_sentence(self):
298
  if self.commited == []: return
 
374
  Returns: the same format as self.process_iter()
375
  """
376
  o = self.transcript_buffer.complete()
377
+ f = self.to_flush(o)
378
  print("last, noncommited:",f,file=sys.stderr)
379
  return f
380
 
381
 
382
+ def to_flush(self, sents, sep=None, offset=0, ):
383
+ # concatenates the timestamped words or sentences into one sequence that is flushed in one line
384
+ # sents: [(beg1, end1, "sentence1"), ...] or [] if empty
385
+ # return: (beg1,end-of-last-sentence,"concatenation of sentences") or (None, None, "") if empty
386
+ if sep is None:
387
+ sep = self.asr.sep
388
+ t = sep.join(s[2] for s in sents)
389
+ if len(sents) == 0:
390
+ b = None
391
+ e = None
392
+ else:
393
+ b = offset + sents[0][0]
394
+ e = offset + sents[-1][1]
395
+ return (b,e,t)
396
+
397
 
398
 
399
  ## main:
 
407
  parser.add_argument('--lan', '--language', type=str, default='en', help="Language code for transcription, e.g. en,de,cs.")
408
  parser.add_argument('--start_at', type=float, default=0.0, help='Start processing audio at this time.')
409
  parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped"],help='Load only this backend for Whisper processing.')
410
+ parser.add_argument('--offline', action="store_true", default=False, help='Offline mode.')
411
  args = parser.parse_args()
412
 
413
  audio_path = args.audio_path
 
447
  # warm up the ASR, because the very first transcribe takes much more time than the other
448
  asr.transcribe(a)
449
 
450
+ beg = args.start_at
451
+ start = time.time()-beg
452
+
453
  def output_transcript(o):
454
  # output format in stdout is like:
455
  # 4186.3606 0 1720 Takhle to je
 
463
  else:
464
  print(o,file=sys.stderr,flush=True)
465
 
466
+ if args.offline: ## offline mode processing (for testing/debugging)
467
+ a = load_audio(audio_path)
 
 
 
 
 
 
 
 
468
  online.insert_audio_chunk(a)
 
469
  try:
470
  o = online.process_iter()
471
  except AssertionError:
 
473
  pass
474
  else:
475
  output_transcript(o)
476
+ else: # online = simultaneous mode
477
+ end = 0
478
+ while True:
479
+ now = time.time() - start
480
+ if now < end+min_chunk:
481
+ time.sleep(min_chunk+end-now)
482
+ end = time.time() - start
483
+ a = load_audio_chunk(audio_path,beg,end)
484
+ beg = end
485
+ online.insert_audio_chunk(a)
486
+
487
+ try:
488
+ o = online.process_iter()
489
+ except AssertionError:
490
+ print("assertion error",file=sys.stderr)
491
+ pass
492
+ else:
493
+ output_transcript(o)
494
+ now = time.time() - start
495
+ print(f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}",file=sys.stderr)
496
 
497
+ print(file=sys.stderr,flush=True)
498
 
499
+ if end >= duration:
500
+ break
501
 
502
  o = online.finish()
503
  output_transcript(o)