Dominik Macháček committed
Commit 50f1b94 · Parent: ab27bfb

missing features in openai-api, PR #52

Files changed (1):
  1. whisper_online.py +56 -32
whisper_online.py CHANGED

@@ -6,8 +6,7 @@ from functools import lru_cache
 import time
 import io
 import soundfile as sf
-
-
+import math
 
 @lru_cache
 def load_audio(fname):
@@ -147,24 +146,34 @@ class FasterWhisperASR(ASRBase):
 class OpenaiApiASR(ASRBase):
     """Uses OpenAI's Whisper API for audio transcription."""
 
-    def __init__(self, modelsize=None, lan=None, cache_dir=None, model_dir=None, response_format="verbose_json", temperature=0):
-        self.modelname = "whisper-1"  # modelsize is not used but kept for interface consistency
+    def __init__(self, lan=None, response_format="verbose_json", temperature=0, logfile=sys.stderr):
+        self.logfile = logfile
+
+        self.modelname = "whisper-1"
         self.language = lan  # ISO-639-1 language code
         self.response_format = response_format
         self.temperature = temperature
-        self.model = self.load_model(modelsize, cache_dir, model_dir)
+
+        self.load_model()
+
+        self.use_vad = False
+
+        # reset the task in set_translate_task
+        self.task = "transcribe"
 
     def load_model(self, *args, **kwargs):
         from openai import OpenAI
         self.client = OpenAI()
-        # Since we're using the OpenAI API, there's no model to load locally.
-        print("Model configuration is set to use the OpenAI Whisper API.")
+
+        self.transcribed_seconds = 0  # for logging how many seconds were processed by API, to know the cost
+
 
     def ts_words(self, segments):
         o = []
         for segment in segments:
-            # Skip segments containing no speech
-            if segment["no_speech_prob"] > 0.8:
+            # If VAD on, skip segments containing no speech.
+            # TODO: threshold can be set from outside
+            if self.use_vad and segment["no_speech_prob"] > 0.8:
                 continue
 
             # Splitting the text into words and filtering out empty strings
@@ -197,23 +206,39 @@ class OpenaiApiASR(ASRBase):
         sf.write(buffer, audio_data, samplerate=16000, format='WAV', subtype='PCM_16')
         buffer.seek(0)  # Reset buffer's position to the beginning
 
-        # Prepare transcription parameters
-        transcription_params = {
+        self.transcribed_seconds += math.ceil(len(audio_data)/16000)  # it rounds up to the whole seconds
+
+        params = {
             "model": self.modelname,
             "file": buffer,
             "response_format": self.response_format,
             "temperature": self.temperature
         }
-        if self.language:
+        if self.task != "translate" and self.language:
             transcription_params["language"] = self.language
         if prompt:
             transcription_params["prompt"] = prompt
 
-        # Perform the transcription
-        transcript = self.client.audio.transcriptions.create(**transcription_params)
+        if self.task == "translate":
+            proc = self.client.audio.translations
+        else:
+            proc = self.client.audio.transcriptions
+
+        # Process transcription/translation
+
+        transcript = proc.create(**params)
+        print(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds ",file=self.logfile)
 
         return transcript.segments
 
+    def use_vad(self):
+        self.use_vad = True
+
+    def set_translate_task(self):
+        self.task = "translate"
+
+
+
 
 class HypothesisBuffer:
 
@@ -557,20 +582,27 @@ if __name__ == "__main__":
     duration = len(load_audio(audio_path))/SAMPLING_RATE
     print("Audio duration is: %2.2f seconds" % duration, file=logfile)
 
-    size = args.model
     language = args.lan
 
-    t = time.time()
-    print(f"Loading Whisper {size} model for {language}...",file=logfile,end=" ",flush=True)
-
-    if args.backend == "faster-whisper":
-        asr_cls = FasterWhisperASR
-    elif args.backend == "openai-api":
-        asr_cls = OpenaiApiASR
+    if args.backend == "openai-api":
+        print("Using OpenAI API.",file=logfile)
+        asr = OpenaiApiASR(lan=language)
     else:
-        asr_cls = WhisperTimestampedASR
+        if args.backend == "faster-whisper":
+            asr_cls = FasterWhisperASR
+        else:
+            asr_cls = WhisperTimestampedASR
+
+        size = args.model
+        t = time.time()
+        print(f"Loading Whisper {size} model for {language}...",file=logfile,end=" ",flush=True)
+        asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
+        e = time.time()
+        print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
 
-    asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
+    if args.vad:
+        print("setting VAD filter",file=logfile)
+        asr.use_vad()
 
     if args.task == "translate":
         asr.set_translate_task()
@@ -578,14 +610,6 @@ if __name__ == "__main__":
     else:
         tgt_language = language  # Whisper transcribes in this language
 
-
-    e = time.time()
-    print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
-
-    if args.vad:
-        print("setting VAD filter",file=logfile)
-        asr.use_vad()
-
 
     min_chunk = args.min_chunk_size
     if args.buffer_trimming == "sentence":
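
For orientation, a minimal usage sketch of the new backend follows; it is not part of the commit. It assumes whisper_online.py is importable, OPENAI_API_KEY is set in the environment, that the method containing the second hunk is the backend's transcribe(audio, prompt=None) entry point shared with the other ASRBase backends, and a hypothetical 16 kHz sample.wav input. It passes no language hint and no prompt, since in this revision the request dict is built as params while the unchanged language/prompt lines still assign into transcription_params, and it leaves VAD filtering off, since the self.use_vad = False attribute set in __init__ shares its name with the use_vad() method.

# Hedged sketch only -- not part of commit 50f1b94.
# Assumptions: whisper_online.py on PYTHONPATH, OPENAI_API_KEY exported,
# and sample.wav is a hypothetical 16 kHz mono file readable by load_audio().
import sys
from whisper_online import OpenaiApiASR, load_audio

asr = OpenaiApiASR(logfile=sys.stderr)   # always targets the "whisper-1" API model
audio = load_audio("sample.wav")         # cached by @lru_cache in whisper_online.py

# One API round trip; transcribed_seconds grows by math.ceil(len(audio)/16000),
# and the accumulated seconds are printed to logfile after each call.
segments = asr.transcribe(audio)
# ts_words() drops segments with no_speech_prob > 0.8 only when VAD is enabled,
# then splits each remaining segment's text into timestamped words.
words = asr.ts_words(segments)
print(words, file=sys.stderr)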