Dominik Macháček
committed
Commit · 8f32dea
1 Parent(s): bd0d848

logfile reviewed, whisper_timestamped loading module and vad

whisper_online.py: +33 −20
whisper_online.py CHANGED
@@ -26,12 +26,15 @@ class ASRBase:
     sep = " "  # join transcribe words with this character (" " for whisper_timestamped,
                # "" for faster-whisper because it emits the spaces when needed)
 
-    def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None):
+    def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
+        self.logfile = logfile
+
         self.transcribe_kargs = {}
         self.original_language = lan
 
         self.model = self.load_model(modelsize, cache_dir, model_dir)
 
+
     def load_model(self, modelsize, cache_dir):
         raise NotImplemented("must be implemented in the child class")
 
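The new logfile keyword defaults to sys.stderr, so existing callers are unaffected; any writable file object reroutes the diagnostics. A minimal usage sketch (hypothetical, not part of the commit; it assumes a subclass such as FasterWhisperASR inherits this constructor unchanged and that the module is importable as whisper_online):

    import os
    from whisper_online import FasterWhisperASR  # assumed module layout

    # Discard all diagnostic prints instead of writing them to the terminal:
    asr = FasterWhisperASR("en", modelsize="large-v2",
                           logfile=open(os.devnull, "w"))

As an aside, the context line raise NotImplemented(...) is itself buggy: NotImplemented is a constant, not callable, so executing it raises a TypeError; NotImplementedError is the conventional exception for an abstract method.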
@@ -50,15 +53,18 @@ class WhisperTimestampedASR(ASRBase):
     sep = " "
 
     def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
-        global whisper_timestamped # has to be global as it is used at each `transcribe` call
         import whisper
-        import whisper_timestamped
+        from whisper_timestamped import transcribe_timestamped
+        self.transcribe_timestamped = transcribe_timestamped
         if model_dir is not None:
             print("ignoring model_dir, not implemented",file=self.logfile)
         return whisper.load_model(modelsize, download_root=cache_dir)
 
     def transcribe(self, audio, init_prompt=""):
-        result = whisper_timestamped.transcribe_timestamped(self.model, audio, language=self.original_language, initial_prompt=init_prompt, verbose=None, condition_on_previous_text=True)
+        result = self.transcribe_timestamped(self.model,
+                audio, language=self.original_language,
+                initial_prompt=init_prompt, verbose=None,
+                condition_on_previous_text=True, **self.transcribe_kargs)
         return result
 
     def ts_words(self,r):
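Binding the imported transcribe_timestamped callable onto the instance replaces the earlier global whisper_timestamped workaround: transcribe() reaches the function through self, the import stays deferred until load_model() runs, and the added **self.transcribe_kargs lets later option setters extend the call without editing transcribe(). The same pattern in isolation (a sketch, not code from the commit):

    class LazyBackend:
        def load_model(self):
            # deferred import: only pay the import cost when this backend is chosen
            from whisper_timestamped import transcribe_timestamped
            self.transcribe_timestamped = transcribe_timestamped

        def transcribe(self, model, audio, **extra):
            # instance attribute instead of a module-level global reference
            return self.transcribe_timestamped(model, audio, **extra)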
@@ -74,7 +80,12 @@ class WhisperTimestampedASR(ASRBase):
         return [s["end"] for s in res["segments"]]
 
     def use_vad(self):
-
+        self.transcribe_kargs["vad"] = True
+
+    def set_translate_task(self):
+        self.transcribe_kargs["task"] = "translate"
+
+
 
 
 class FasterWhisperASR(ASRBase):
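Because use_vad() and set_translate_task() only mutate self.transcribe_kargs, the options take effect on the next transcribe() call, where the dict is splatted into transcribe_timestamped. A hypothetical call sequence (a sketch; the model size and file name are illustrative):

    asr = WhisperTimestampedASR("de", modelsize="large-v2")
    asr.use_vad()              # next transcribe() passes vad=True
    asr.set_translate_task()   # ...and task="translate"
    audio = load_audio("sample.wav")  # helper used elsewhere in this file
    result = asr.transcribe(audio)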
@@ -135,7 +146,6 @@ class FasterWhisperASR(ASRBase):
 class HypothesisBuffer:
 
     def __init__(self, logfile=sys.stderr):
-        """output: where to store the log. Leave it unchanged to print to terminal."""
         self.commited_in_buffer = []
         self.buffer = []
         self.new = []
@@ -205,7 +215,7 @@ class OnlineASRProcessor:
     def __init__(self, asr, tokenizer, logfile=sys.stderr):
         """asr: WhisperASR object
         tokenizer: sentence tokenizer object for the target language. Must have a method *split* that behaves like the one of MosesTokenizer.
-
+        logfile: where to store the log.
         """
         self.asr = asr
         self.tokenizer = tokenizer
@@ -468,21 +478,24 @@ if __name__ == "__main__":
     parser.add_argument('--vad', action="store_true", default=False, help='Use VAD = voice activity detection, with the default parameters.')
     args = parser.parse_args()
 
+    # reset to store stderr to different file stream, e.g. open(os.devnull,"w")
+    logfile = sys.stderr
+
     if args.offline and args.comp_unaware:
-        print("No or one option from --offline and --comp_unaware are available, not both. Exiting.",file=sys.stderr)
+        print("No or one option from --offline and --comp_unaware are available, not both. Exiting.",file=logfile)
         sys.exit(1)
 
     audio_path = args.audio_path
 
     SAMPLING_RATE = 16000
     duration = len(load_audio(audio_path))/SAMPLING_RATE
-    print("Audio duration is: %2.2f seconds" % duration, file=sys.stderr)
+    print("Audio duration is: %2.2f seconds" % duration, file=logfile)
 
     size = args.model
     language = args.lan
 
     t = time.time()
-    print(f"Loading Whisper {size} model for {language}...",file=sys.stderr)
+    print(f"Loading Whisper {size} model for {language}...",file=logfile,end=" ",flush=True)
 
     if args.backend == "faster-whisper":
         asr_cls = FasterWhisperASR
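The new logfile variable is the single switch for every diagnostic print in the __main__ block. Following the comment introduced above, redirecting is one line (a sketch; note it needs import os, which the snippet adds explicitly):

    import os
    logfile = open(os.devnull, "w")  # silence diagnostics; transcripts still go to stdout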
@@ -499,15 +512,15 @@ if __name__ == "__main__":
 
 
     e = time.time()
-    print(f"done. It took {round(e-t,2)} seconds.",file=sys.stderr)
+    print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
 
     if args.vad:
-        print("setting VAD filter",file=sys.stderr)
+        print("setting VAD filter",file=logfile)
         asr.use_vad()
 
 
     min_chunk = args.min_chunk_size
-    online = OnlineASRProcessor(asr,create_tokenizer(tgt_language))
+    online = OnlineASRProcessor(asr,create_tokenizer(tgt_language),logfile=logfile)
 
 
     # load the audio into the LRU cache before we start the timer
@@ -529,10 +542,10 @@ if __name__ == "__main__":
         if now is None:
             now = time.time()-start
         if o[0] is not None:
-            print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),file=sys.stderr)
+            print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),file=logfile,flush=True)
             print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),flush=True)
         else:
-            print(o,file=sys.stderr)
+            print(o,file=logfile,flush=True)
 
     if args.offline: ## offline mode processing (for testing/debugging)
         a = load_audio(audio_path)
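The two prints emit the same transcript line to the log and to stdout: emission time, then segment start and end, all in milliseconds, then the text. A worked example with hypothetical values:

    now, o = 2.5, (0.0, 1.8, "hello")
    print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000, o[1]*1000, o[2]))
    # prints: 2500.0000 0 1800 hello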
@@ -540,7 +553,7 @@ if __name__ == "__main__":
         try:
             o = online.process_iter()
         except AssertionError:
-            print("assertion error",file=sys.stderr)
+            print("assertion error",file=logfile)
             pass
         else:
             output_transcript(o)
@@ -553,12 +566,12 @@ if __name__ == "__main__":
         try:
             o = online.process_iter()
         except AssertionError:
-            print("assertion error",file=sys.stderr)
+            print("assertion error",file=logfile)
             pass
         else:
             output_transcript(o, now=end)
 
-        print(f"## last processed {end:.2f}s",file=sys.stderr)
+        print(f"## last processed {end:.2f}s",file=logfile,flush=True)
 
         beg = end
         end += min_chunk
@@ -580,12 +593,12 @@ if __name__ == "__main__":
         try:
             o = online.process_iter()
         except AssertionError:
-            print("assertion error",file=sys.stderr)
+            print("assertion error",file=logfile)
             pass
         else:
             output_transcript(o)
         now = time.time() - start
-        print(f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}",file=sys.stderr)
+        print(f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}",file=logfile,flush=True)
 
         if end >= duration:
             break
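Putting the options touched by this commit together, a run might look like the following (hypothetical invocation; flag spellings such as --min-chunk-size, --model, and --lan are inferred from the args.* attributes above, so check the script's --help):

    python3 whisper_online.py audio.wav --backend faster-whisper --lan en --model large-v2 --vad --min-chunk-size 1.0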
|