Tijs Zwinkels commited on
Commit
f0a24cd
·
1 Parent(s): 3696fef

Make --vad work with --backend openai-api

Browse files
Files changed (1) hide show
  1. whisper_online.py +22 -16
whisper_online.py CHANGED
@@ -162,7 +162,7 @@ class OpenaiApiASR(ASRBase):
162
 
163
  self.load_model()
164
 
165
- self.use_vad = False
166
 
167
  # reset the task in set_translate_task
168
  self.task = "transcribe"
@@ -175,21 +175,27 @@ class OpenaiApiASR(ASRBase):
175
 
176
 
177
  def ts_words(self, segments):
178
- o = []
179
- # If VAD on, skip segments containing no speech.
180
- # TODO: threshold can be set from outside
181
- # TODO: Make VAD work again with word-level timestamps
182
- #if self.use_vad and segment["no_speech_prob"] > 0.8:
183
- # continue
184
 
185
- for word in segments:
186
- o.append((word.get("start"), word.get("end"), word.get("word")))
 
 
 
 
 
 
187
 
188
  return o
189
 
190
 
191
  def segments_end_ts(self, res):
192
- return [s["end"] for s in res]
193
 
194
  def transcribe(self, audio_data, prompt=None, *args, **kwargs):
195
  # Write the audio data to a buffer
@@ -205,7 +211,7 @@ class OpenaiApiASR(ASRBase):
205
  "file": buffer,
206
  "response_format": self.response_format,
207
  "temperature": self.temperature,
208
- "timestamp_granularities": ["word"]
209
  }
210
  if self.task != "translate" and self.language:
211
  params["language"] = self.language
@@ -221,10 +227,10 @@ class OpenaiApiASR(ASRBase):
221
  transcript = proc.create(**params)
222
  print(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds",file=self.logfile)
223
 
224
- return transcript.words
225
 
226
  def use_vad(self):
227
- self.use_vad = True
228
 
229
  def set_translate_task(self):
230
  self.task = "translate"
@@ -592,9 +598,9 @@ if __name__ == "__main__":
592
  e = time.time()
593
  print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
594
 
595
- if args.vad:
596
- print("setting VAD filter",file=logfile)
597
- asr.use_vad()
598
 
599
  if args.task == "translate":
600
  asr.set_translate_task()
 
162
 
163
  self.load_model()
164
 
165
+ self.use_vad_opt = False
166
 
167
  # reset the task in set_translate_task
168
  self.task = "transcribe"
 
175
 
176
 
177
  def ts_words(self, segments):
178
+ no_speech_segments = []
179
+ if self.use_vad_opt:
180
+ for segment in segments.segments:
181
+ # TODO: threshold can be set from outside
182
+ if segment["no_speech_prob"] > 0.8:
183
+ no_speech_segments.append((segment.get("start"), segment.get("end")))
184
 
185
+ o = []
186
+ for word in segments.words:
187
+ start = word.get("start")
188
+ end = word.get("end")
189
+ if any(s[0] <= start <= s[1] for s in no_speech_segments):
190
+ # print("Skipping word", word.get("word"), "because it's in a no-speech segment")
191
+ continue
192
+ o.append((start, end, word.get("word")))
193
 
194
  return o
195
 
196
 
197
  def segments_end_ts(self, res):
198
+ return [s["end"] for s in res.words]
199
 
200
  def transcribe(self, audio_data, prompt=None, *args, **kwargs):
201
  # Write the audio data to a buffer
 
211
  "file": buffer,
212
  "response_format": self.response_format,
213
  "temperature": self.temperature,
214
+ "timestamp_granularities": ["word", "segment"]
215
  }
216
  if self.task != "translate" and self.language:
217
  params["language"] = self.language
 
227
  transcript = proc.create(**params)
228
  print(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds",file=self.logfile)
229
 
230
+ return transcript
231
 
232
  def use_vad(self):
233
+ self.use_vad_opt = True
234
 
235
  def set_translate_task(self):
236
  self.task = "translate"
 
598
  e = time.time()
599
  print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
600
 
601
+ if args.vad:
602
+ print("setting VAD filter",file=logfile)
603
+ asr.use_vad()
604
 
605
  if args.task == "translate":
606
  asr.set_translate_task()