Dominik Macháček committed
Commit 949304a
Parents: d65fd8a, 9fcd403

Merge branch 'opeanai-api2' into opeanai-api

Files changed (2)
  1. README.md +1 -1
  2. whisper_online.py +35 -38
README.md CHANGED
@@ -91,7 +91,7 @@ options:
   --model_dir MODEL_DIR
                         Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.
   --lan LAN, --language LAN
-                        Language code for transcription, e.g. en,de,cs.
+                        Source language code, e.g. en,de,cs, or 'auto' for language detection.
   --task {transcribe,translate}
                         Transcribe or translate.
   --backend {faster-whisper,whisper_timestamped,openai-api}
whisper_online.py CHANGED
@@ -31,7 +31,10 @@ class ASRBase:
         self.logfile = logfile
 
         self.transcribe_kargs = {}
-        self.original_language = lan
+        if lan == "auto":
+            self.original_language = None
+        else:
+            self.original_language = lan
 
         self.model = self.load_model(modelsize, cache_dir, model_dir)
 
@@ -119,8 +122,11 @@ class FasterWhisperASR(ASRBase):
         return model
 
     def transcribe(self, audio, init_prompt=""):
+
        # tested: beam_size=5 is faster and better than 1 (on one 200 second document from En ESIC, min chunk 0.01)
        segments, info = self.model.transcribe(audio, language=self.original_language, initial_prompt=init_prompt, beam_size=5, word_timestamps=True, condition_on_previous_text=True, **self.transcribe_kargs)
+        #print(info) # info contains language detection result
+
        return list(segments)
 
     def ts_words(self, segments):
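Aside, not part of the commit: the info object that the new commented-out "#print(info)" line refers to carries faster-whisper's language detection result when language is None, which is the value original_language now gets for --lan auto. A minimal sketch of reading it; the model size and the audio file name are placeholders, not from the commit:

    # Sketch only: inspect faster-whisper's language detection result.
    from faster_whisper import WhisperModel

    model = WhisperModel("large-v2")
    # language=None (what --lan auto maps to) enables auto-detection
    segments, info = model.transcribe("sample.wav", language=None)
    print(info.language, info.language_probability)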
@@ -146,17 +152,17 @@ class FasterWhisperASR(ASRBase):
 class OpenaiApiASR(ASRBase):
     """Uses OpenAI's Whisper API for audio transcription."""
 
-    def __init__(self, lan=None, response_format="verbose_json", temperature=0, logfile=sys.stderr):
+    def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
         self.logfile = logfile
 
         self.modelname = "whisper-1"
-        self.language = lan # ISO-639-1 language code
-        self.response_format = response_format
+        self.original_language = None if lan == "auto" else lan # ISO-639-1 language code
+        self.response_format = "verbose_json"
         self.temperature = temperature
 
         self.load_model()
 
-        self.use_vad = False
+        self.use_vad_opt = False
 
         # reset the task in set_translate_task
         self.task = "transcribe"
@@ -169,35 +175,26 @@ class OpenaiApiASR(ASRBase):
 
 
     def ts_words(self, segments):
-        o = []
-        for segment in segments:
-            # If VAD on, skip segments containing no speech.
-            # TODO: threshold can be set from outside
-            if self.use_vad and segment["no_speech_prob"] > 0.8:
-                continue
+        no_speech_segments = []
+        if self.use_vad_opt:
+            for segment in segments.segments:
+                # TODO: threshold can be set from outside
+                if segment["no_speech_prob"] > 0.8:
+                    no_speech_segments.append((segment.get("start"), segment.get("end")))
 
-            # Splitting the text into words and filtering out empty strings
-            words = [word.strip() for word in segment["text"].split() if word.strip()]
-
-            if not words:
+        o = []
+        for word in segments.words:
+            start = word.get("start")
+            end = word.get("end")
+            if any(s[0] <= start <= s[1] for s in no_speech_segments):
+                # print("Skipping word", word.get("word"), "because it's in a no-speech segment")
                 continue
-
-            # Assign start and end times for each word
-            # We only have timestamps per segment, so interpolating start and end-times
-            # assuming equal duration per word
-            segment_duration = segment["end"] - segment["start"]
-            duration_per_word = segment_duration / len(words)
-            start_time = segment["start"]
-            for word in words:
-                end_time = start_time + duration_per_word
-                o.append((start_time, end_time, word))
-                start_time = end_time
-
+            o.append((start, end, word.get("word")))
         return o
 
 
     def segments_end_ts(self, res):
-        return [s["end"] for s in res]
+        return [s["end"] for s in res.words]
 
     def transcribe(self, audio_data, prompt=None, *args, **kwargs):
         # Write the audio data to a buffer
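The rewritten ts_words replaces the old equal-duration interpolation with the API's real word-level timestamps, and drops any word whose start falls inside a segment the VAD flagged as non-speech. A standalone sketch of that filtering rule, with made-up tuples standing in for the response objects:

    # Sketch only: keep (start, end, word) triples whose start lies
    # outside every no-speech segment, as the new ts_words does.
    def filter_no_speech(words, no_speech_segments):
        return [(s, e, w) for (s, e, w) in words
                if not any(a <= s <= b for (a, b) in no_speech_segments)]

    words = [(0.2, 0.5, "hello"), (1.2, 1.4, "uh"), (2.3, 2.7, "world")]
    print(filter_no_speech(words, [(1.0, 2.0)]))
    # -> [(0.2, 0.5, 'hello'), (2.3, 2.7, 'world')]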
@@ -212,10 +209,11 @@ class OpenaiApiASR(ASRBase):
             "model": self.modelname,
             "file": buffer,
             "response_format": self.response_format,
-            "temperature": self.temperature
+            "temperature": self.temperature,
+            "timestamp_granularities": ["word", "segment"]
         }
-        if self.task != "translate" and self.language:
-            params["language"] = self.language
+        if self.task != "translate" and self.original_language:
+            params["language"] = self.original_language
         if prompt:
             params["prompt"] = prompt
 
@@ -225,14 +223,13 @@ class OpenaiApiASR(ASRBase):
         proc = self.client.audio.transcriptions
 
         # Process transcription/translation
-
         transcript = proc.create(**params)
         print(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds",file=self.logfile)
 
-        return transcript.segments
+        return transcript
 
     def use_vad(self):
-        self.use_vad = True
+        self.use_vad_opt = True
 
     def set_translate_task(self):
         self.task = "translate"
@@ -548,7 +545,7 @@ def add_shared_args(parser):
     parser.add_argument('--model', type=str, default='large-v2', choices="tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large".split(","),help="Name size of the Whisper model to use (default: large-v2). The model is automatically downloaded from the model hub if not present in model cache dir.")
     parser.add_argument('--model_cache_dir', type=str, default=None, help="Overriding the default model cache dir where models downloaded from the hub are saved")
     parser.add_argument('--model_dir', type=str, default=None, help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.")
-    parser.add_argument('--lan', '--language', type=str, default='en', help="Language code for transcription, e.g. en,de,cs.")
+    parser.add_argument('--lan', '--language', type=str, default='auto', help="Source language code, e.g. en,de,cs, or 'auto' for language detection.")
     parser.add_argument('--task', type=str, default='transcribe', choices=["transcribe","translate"],help="Transcribe or translate.")
     parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped", "openai-api"],help='Load only this backend for Whisper processing.')
     parser.add_argument('--vad', action="store_true", default=False, help='Use VAD = voice activity detection, with the default parameters.')
@@ -600,9 +597,9 @@ if __name__ == "__main__":
     e = time.time()
     print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
 
-    if args.vad:
-        print("setting VAD filter",file=logfile)
-        asr.use_vad()
+    if args.vad:
+        print("setting VAD filter",file=logfile)
+        asr.use_vad()
 
     if args.task == "translate":
         asr.set_translate_task()
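For context, the timestamp_granularities entry added to params is what makes the verbose_json response carry the per-word timing that the new ts_words consumes. A hedged sketch of the equivalent direct call with the openai>=1.x client; the file name is a placeholder:

    # Sketch only: request word- and segment-level timestamps directly.
    from openai import OpenAI

    client = OpenAI()
    with open("sample.wav", "rb") as f:
        transcript = client.audio.transcriptions.create(
            model="whisper-1",
            file=f,
            response_format="verbose_json",  # required for timestamp granularities
            timestamp_granularities=["word", "segment"],
        )
    # transcript.words gives word timing; transcript.segments keeps
    # per-segment fields such as no_speech_prob, which use_vad() filters on.
    print(transcript.words[0], transcript.segments[0])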
 