Dominik Macháček committed
Commit · 50f1b94 · 1 parent: ab27bfb

missing features in openai-api, PR #52

whisper_online.py CHANGED (+56 -32)
@@ -6,8 +6,7 @@ from functools import lru_cache
 import time
 import io
 import soundfile as sf
-
-
+import math
 
 @lru_cache
 def load_audio(fname):
@@ -147,24 +146,34 @@ class FasterWhisperASR(ASRBase):
 class OpenaiApiASR(ASRBase):
     """Uses OpenAI's Whisper API for audio transcription."""
 
-    def __init__(self,
-        self.
+    def __init__(self, lan=None, response_format="verbose_json", temperature=0, logfile=sys.stderr):
+        self.logfile = logfile
+
+        self.modelname = "whisper-1"
         self.language = lan # ISO-639-1 language code
         self.response_format = response_format
         self.temperature = temperature
-
+
+        self.load_model()
+
+        self.use_vad = False
+
+        # reset the task in set_translate_task
+        self.task = "transcribe"
 
     def load_model(self, *args, **kwargs):
         from openai import OpenAI
         self.client = OpenAI()
-
-
+
+        self.transcribed_seconds = 0 # for logging how many seconds were processed by API, to know the cost
+
 
     def ts_words(self, segments):
         o = []
         for segment in segments:
-            #
-
+            # If VAD on, skip segments containing no speech.
+            # TODO: threshold can be set from outside
+            if self.use_vad and segment["no_speech_prob"] > 0.8:
                 continue
 
             # Splitting the text into words and filtering out empty strings
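The new check in ts_words drops a segment only when VAD filtering was requested and the API marks the segment as probable non-speech. A minimal sketch of that gate, run on hypothetical segment dicts shaped like the API's verbose_json output (the 0.8 threshold is the hard-coded value from the diff):

```python
# Hypothetical verbose_json-style segments; only the fields used by the gate.
segments = [
    {"text": "hello world", "no_speech_prob": 0.02},
    {"text": "", "no_speech_prob": 0.95},  # probably silence
]

use_vad = True  # mirrors the flag the diff stores on the ASR object
kept = [s for s in segments if not (use_vad and s["no_speech_prob"] > 0.8)]
print([s["text"] for s in kept])  # -> ['hello world']
```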
@@ -197,23 +206,39 @@ class OpenaiApiASR(ASRBase):
         sf.write(buffer, audio_data, samplerate=16000, format='WAV', subtype='PCM_16')
         buffer.seek(0)  # Reset buffer's position to the beginning
 
-        #
-        transcription_params = {
+        self.transcribed_seconds += math.ceil(len(audio_data)/16000) # it rounds up to the whole seconds
+
+        params = {
             "model": self.modelname,
             "file": buffer,
             "response_format": self.response_format,
             "temperature": self.temperature
         }
-        if self.language:
+        if self.task != "translate" and self.language:
             transcription_params["language"] = self.language
         if prompt:
             transcription_params["prompt"] = prompt
 
-
-
+        if self.task == "translate":
+            proc = self.client.audio.translations
+        else:
+            proc = self.client.audio.transcriptions
+
+        # Process transcription/translation
+
+        transcript = proc.create(**params)
+        print(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds ",file=self.logfile)
 
         return transcript.segments
 
+    def use_vad(self):
+        self.use_vad = True
+
+    def set_translate_task(self):
+        self.task = "translate"
+
+
+
 
 class HypothesisBuffer:
 
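Two additions drive the updated transcribe(): each chunk's length is added to a running seconds counter (rounded up to whole seconds, per the code's cost-logging comment), and the language hint is attached only for transcription, since the translations endpoint always produces English and takes no source-language parameter. A sketch of both pieces of logic in isolation, with made-up chunk sizes and parameter values:

```python
import math

# Accumulate whole seconds per chunk of 16 kHz audio, as the diff does.
transcribed_seconds = 0
for n_samples in (8000, 24000, 16000):  # 0.5 s, 1.5 s and 1.0 s chunks
    transcribed_seconds += math.ceil(n_samples / 16000)
print(transcribed_seconds)  # 1 + 2 + 1 = 4

# Attach "language" only when transcribing; translations are always English.
task, language = "translate", "de"
params = {"model": "whisper-1", "response_format": "verbose_json", "temperature": 0}
if task != "translate" and language:
    params["language"] = language
print("language" in params)  # False: a translation request carries no language
```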
@@ -557,20 +582,27 @@ if __name__ == "__main__":
     duration = len(load_audio(audio_path))/SAMPLING_RATE
     print("Audio duration is: %2.2f seconds" % duration, file=logfile)
 
-    size = args.model
     language = args.lan
 
-    t = time.time()
-    print(f"Loading Whisper {size} model for {language}...",file=logfile,end=" ",flush=True)
-
-    if args.backend == "faster-whisper":
-        asr_cls = FasterWhisperASR
-    elif args.backend == "openai-api":
-        asr_cls = OpenaiApiASR
+    if args.backend == "openai-api":
+        print("Using OpenAI API.",file=logfile)
+        asr = OpenaiApiASR(lan=language)
     else:
-        asr_cls = WhisperTimestampedASR
+        if args.backend == "faster-whisper":
+            asr_cls = FasterWhisperASR
+        else:
+            asr_cls = WhisperTimestampedASR
+
+        size = args.model
+        t = time.time()
+        print(f"Loading Whisper {size} model for {language}...",file=logfile,end=" ",flush=True)
+        asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
+        e = time.time()
+        print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
 
-    asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
+    if args.vad:
+        print("setting VAD filter",file=logfile)
+        asr.use_vad()
 
     if args.task == "translate":
         asr.set_translate_task()
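With the OpenAI backend, __main__ now skips the local model-loading path entirely: no model size, cache directory, or load timing applies. A hypothetical driver mirroring the new branch; it assumes whisper_online.py is importable, the openai package is installed, and OPENAI_API_KEY is set in the environment ("sample.wav" is a placeholder path):

```python
from whisper_online import OpenaiApiASR, load_audio

# No local weights to download or load; the constructor only sets up the client.
asr = OpenaiApiASR(lan="de")
asr.set_translate_task()  # route requests to the translations endpoint instead
segments = asr.transcribe(load_audio("sample.wav"))
```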
@@ -578,14 +610,6 @@ if __name__ == "__main__":
     else:
         tgt_language = language # Whisper transcribes in this language
 
-
-    e = time.time()
-    print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
-
-    if args.vad:
-        print("setting VAD filter",file=logfile)
-        asr.use_vad()
-
 
     min_chunk = args.min_chunk_size
     if args.buffer_trimming == "sentence":
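One practical note on the new counter: because the streaming loop repeatedly re-sends a growing audio buffer, the accumulated seconds printed after each request can exceed the real duration of the input, which is exactly why it is worth logging. A back-of-the-envelope cost estimate built on it, with the per-minute price left as an explicit assumption:

```python
# Assumed whisper-1 price per minute of audio; check the current pricing page.
USD_PER_MINUTE = 0.006

transcribed_seconds = 3600  # e.g. asr.transcribed_seconds after a session
print(f"~${transcribed_seconds / 60 * USD_PER_MINUTE:.2f}")  # ~$0.36
```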