qfuxa committed
Commit b0d49ce · 1 Parent(s): 400ff66

// execution for diarization and transcription

Files changed (1):
  1. whisper_fastapi_online_server.py +238 -71
whisper_fastapi_online_server.py CHANGED
@@ -70,6 +70,78 @@ BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE
 MAX_BYTES_PER_SEC = 32000 * 5 # 5 seconds of audio at 32 kHz


+class SharedState:
+    def __init__(self):
+        self.tokens = []
+        self.buffer_transcription = ""
+        self.buffer_diarization = ""
+        self.full_transcription = ""
+        self.end_buffer = 0
+        self.end_attributed_speaker = 0
+        self.lock = asyncio.Lock()
+        self.beg_loop = time()
+        self.sep = " " # Default separator
+
+    async def update_transcription(self, new_tokens, buffer, end_buffer, full_transcription, sep):
+        async with self.lock:
+            self.tokens.extend(new_tokens)
+            self.buffer_transcription = buffer
+            self.end_buffer = end_buffer
+            self.full_transcription = full_transcription
+            self.sep = sep
+
+    async def update_diarization(self, end_attributed_speaker, buffer_diarization=""):
+        async with self.lock:
+            self.end_attributed_speaker = end_attributed_speaker
+            if buffer_diarization:
+                self.buffer_diarization = buffer_diarization
+
+    async def add_dummy_token(self):
+        async with self.lock:
+            current_time = time() - self.beg_loop
+            dummy_token = ASRToken(
+                start=current_time,
+                end=current_time + 0.5,
+                text="",
+                speaker=-1
+            )
+            self.tokens.append(dummy_token)
+
+    async def get_current_state(self):
+        async with self.lock:
+            current_time = time()
+            remaining_time_transcription = 0
+            remaining_time_diarization = 0
+
+            # Calculate remaining time for transcription buffer
+            if self.end_buffer > 0:
+                remaining_time_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 2))
+
+            # Calculate remaining time for diarization
+            if self.end_attributed_speaker > 0:
+                remaining_time_diarization = max(0, round(current_time - self.beg_loop - self.end_attributed_speaker, 2))
+
+            return {
+                "tokens": self.tokens.copy(),
+                "buffer_transcription": self.buffer_transcription,
+                "buffer_diarization": self.buffer_diarization,
+                "end_buffer": self.end_buffer,
+                "end_attributed_speaker": self.end_attributed_speaker,
+                "sep": self.sep,
+                "remaining_time_transcription": remaining_time_transcription,
+                "remaining_time_diarization": remaining_time_diarization
+            }
+
+    async def reset(self):
+        """Reset the state."""
+        async with self.lock:
+            self.tokens = []
+            self.buffer_transcription = ""
+            self.buffer_diarization = ""
+            self.end_buffer = 0
+            self.end_attributed_speaker = 0
+            self.full_transcription = ""
+            self.beg_loop = time()

 ##### LOAD APP #####

@@ -120,6 +192,133 @@ async def start_ffmpeg_decoder():
     )
     return process

+async def transcription_processor(shared_state, pcm_queue, online):
+    full_transcription = ""
+    sep = online.asr.sep
+
+    while True:
+        try:
+            pcm_array = await pcm_queue.get()
+
+            logger.info(f"{len(online.audio_buffer) / online.SAMPLING_RATE} seconds of audio will be processed by the model.")
+
+            # Process transcription
+            online.insert_audio_chunk(pcm_array)
+            new_tokens = online.process_iter()
+
+            if new_tokens:
+                full_transcription += sep.join([t.text for t in new_tokens])
+
+            _buffer = online.get_buffer()
+            buffer = _buffer.text
+            end_buffer = _buffer.end if _buffer.end else (new_tokens[-1].end if new_tokens else 0)
+
+            if buffer in full_transcription:
+                buffer = ""
+
+            await shared_state.update_transcription(
+                new_tokens, buffer, end_buffer, full_transcription, sep)
+
+        except Exception as e:
+            logger.warning(f"Exception in transcription_processor: {e}")
+        finally:
+            pcm_queue.task_done()
+
+async def diarization_processor(shared_state, pcm_queue, diarization_obj):
+    buffer_diarization = ""
+
+    while True:
+        try:
+            pcm_array = await pcm_queue.get()
+
+            # Process diarization
+            await diarization_obj.diarize(pcm_array)
+
+            # Get current state
+            state = await shared_state.get_current_state()
+            tokens = state["tokens"]
+            end_attributed_speaker = state["end_attributed_speaker"]
+
+            # Update speaker information
+            new_end_attributed_speaker = diarization_obj.assign_speakers_to_tokens(
+                end_attributed_speaker, tokens)
+
+            await shared_state.update_diarization(new_end_attributed_speaker, buffer_diarization)
+
+        except Exception as e:
+            logger.warning(f"Exception in diarization_processor: {e}")
+        finally:
+            pcm_queue.task_done()
+
+async def results_formatter(shared_state, websocket):
+    while True:
+        try:
+            # Get the current state
+            state = await shared_state.get_current_state()
+            tokens = state["tokens"]
+            buffer_transcription = state["buffer_transcription"]
+            buffer_diarization = state["buffer_diarization"]
+            end_attributed_speaker = state["end_attributed_speaker"]
+            remaining_time_transcription = state["remaining_time_transcription"]
+            remaining_time_diarization = state["remaining_time_diarization"]
+            sep = state["sep"]
+
+            # If diarization is enabled but no transcription, add dummy tokens periodically
+            if not tokens and not args.transcription and args.diarization:
+                await shared_state.add_dummy_token()
+                # Re-fetch tokens after adding dummy
+                state = await shared_state.get_current_state()
+                tokens = state["tokens"]
+
+            # Process tokens to create response
+            previous_speaker = -10
+            lines = []
+            last_end_diarized = 0
+
+            for token in tokens:
+                speaker = token.speaker
+                if args.diarization:
+                    if speaker == -1 or speaker == 0:
+                        if token.end < end_attributed_speaker:
+                            speaker = previous_speaker
+                        else:
+                            speaker = 0
+                    else:
+                        last_end_diarized = max(token.end, last_end_diarized)
+
+                if speaker != previous_speaker:
+                    lines.append(
+                        {
+                            "speaker": speaker,
+                            "text": token.text,
+                            "beg": format_time(token.start),
+                            "end": format_time(token.end),
+                            "diff": round(token.end - last_end_diarized, 2)
+                        }
+                    )
+                    previous_speaker = speaker
+                elif token.text: # Only append if text isn't empty
+                    lines[-1]["text"] += sep + token.text
+                    lines[-1]["end"] = format_time(token.end)
+                    lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
+
+            # Prepare response object
+            response = {
+                "lines": lines,
+                "buffer_transcription": buffer_transcription,
+                "buffer_diarization": buffer_diarization,
+                "remaining_time_transcription": remaining_time_transcription,
+                "remaining_time_diarization": remaining_time_diarization
+            }
+
+            await websocket.send_json(response)
+
+            # Add a small delay to avoid overwhelming the client
+            await asyncio.sleep(0.1)
+
+        except Exception as e:
+            logger.warning(f"Exception in results_formatter: {e}")
+            await asyncio.sleep(0.5) # Back off on error

 ##### ENDPOINTS #####

@@ -134,8 +333,12 @@ async def websocket_endpoint(websocket: WebSocket):

     ffmpeg_process = None
     pcm_buffer = bytearray()
-    online = online_factory(args, asr, tokenizer) if args.transcription else None
+    shared_state = SharedState()

+    transcription_queue = asyncio.Queue() if args.transcription else None
+    diarization_queue = asyncio.Queue() if args.diarization else None
+
+    online = None

     async def restart_ffmpeg():
         nonlocal ffmpeg_process, online, pcm_buffer
@@ -147,20 +350,29 @@
                 logger.warning(f"Error killing FFmpeg process: {e}")
         ffmpeg_process = await start_ffmpeg_decoder()
         pcm_buffer = bytearray()
-        online = online_factory(args, asr, tokenizer) if args.transcription else None
+
+        if args.transcription:
+            online = online_factory(args, asr, tokenizer)
+
+        await shared_state.reset()
         logger.info("FFmpeg process started.")

     await restart_ffmpeg()

+    tasks = []
+    if args.transcription and online:
+        tasks.append(asyncio.create_task(
+            transcription_processor(shared_state, transcription_queue, online)))
+    if args.diarization and diarization:
+        tasks.append(asyncio.create_task(
+            diarization_processor(shared_state, diarization_queue, diarization)))
+    formatter_task = asyncio.create_task(results_formatter(shared_state, websocket))
+    tasks.append(formatter_task)
+
     async def ffmpeg_stdout_reader():
-        nonlocal ffmpeg_process, online, pcm_buffer
+        nonlocal ffmpeg_process, pcm_buffer
         loop = asyncio.get_event_loop()
-        full_transcription = ""
         beg = time()
-        beg_loop = time()
-        tokens = []
-        end_attributed_speaker = 0
-        sep = online.asr.sep

         while True:
             try:
@@ -179,7 +391,6 @@
                 except asyncio.TimeoutError:
                     logger.warning("FFmpeg read timeout. Restarting...")
                     await restart_ffmpeg()
-                    full_transcription = ""
                     beg = time()
                     continue # Skip processing and read from new process

@@ -200,62 +411,14 @@
                     )
                     pcm_buffer = pcm_buffer[MAX_BYTES_PER_SEC:]

-                    if args.transcription:
-                        logger.info(f"{len(online.audio_buffer) / online.SAMPLING_RATE} seconds of audio will be processed by the model.")
-                        online.insert_audio_chunk(pcm_array)
-                        new_tokens = online.process_iter()
-                        tokens.extend(new_tokens)
-                        full_transcription += sep.join([t.text for t in new_tokens])
-                        _buffer = online.get_buffer()
-                        buffer = _buffer.text
-                        end_buffer = _buffer.end if _buffer.end else tokens[-1].end if tokens else 0
-                        if buffer in full_transcription: # With VAC, the buffer is not updated until the next chunk is processed
-                            buffer = ""
-                    else:
-                        tokens.append(
-                            ASRToken(
-                                start = time() - beg_loop,
-                                end = time() - beg_loop + 0.5))
-                        sleep(0.5)
-                        buffer = ''
-
-                    if args.diarization:
-                        await diarization.diarize(pcm_array)
-                        end_attributed_speaker = diarization.assign_speakers_to_tokens(end_attributed_speaker, tokens)
+                    if args.transcription and transcription_queue:
+                        await transcription_queue.put(pcm_array.copy())

-                    previous_speaker = -10
-                    lines = []
-                    last_end_diarized = 0
-                    for token in tokens:
-                        speaker = token.speaker
-                        if args.diarization:
-                            if speaker == -1 or speaker == 0:
-                                if token.end < end_attributed_speaker:
-                                    speaker = previous_speaker
-                                else:
-                                    speaker = 0
-                            else:
-                                last_end_diarized = max(token.end, last_end_diarized)
-
-                        if speaker != previous_speaker:
-                            lines.append(
-                                {
-                                    "speaker": speaker,
-                                    "text": token.text,
-                                    "beg": format_time(token.start),
-                                    "end": format_time(token.end),
-                                    "diff": round(token.end - last_end_diarized, 2)
-                                }
-                            )
-                            previous_speaker = speaker
-                        else:
-                            lines[-1]["text"] += sep + token.text
-                            lines[-1]["end"] = format_time(token.end)
-                            lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
-
-                    response = {"lines": lines, "buffer": buffer}
-                    # response = {"lines": lines, "buffer": buffer, "time_buffer_transcription": time() + beg_loop - end_buffer, "time_buffer_diarization": time() + beg_loop - end_attributed_speaker}
-                    await websocket.send_json(response)
+                    if args.diarization and diarization_queue:
+                        await diarization_queue.put(pcm_array.copy())
+
+                    if not args.transcription and not args.diarization:
+                        await asyncio.sleep(0.1)

             except Exception as e:
                 logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
@@ -264,7 +427,7 @@
         logger.info("Exiting ffmpeg_stdout_reader...")

     stdout_reader_task = asyncio.create_task(ffmpeg_stdout_reader())
-
+    tasks.append(stdout_reader_task)
     try:
         while True:
             # Receive incoming WebM audio chunks from the client
@@ -280,16 +443,20 @@
     except WebSocketDisconnect:
         logger.warning("WebSocket disconnected.")
     finally:
-        stdout_reader_task.cancel()
+        for task in tasks:
+            task.cancel()
+
         try:
+            await asyncio.gather(*tasks, return_exceptions=True)
             ffmpeg_process.stdin.close()
             ffmpeg_process.wait()
-        except:
-            pass
-        if args.diarization:
+        except Exception as e:
+            logger.warning(f"Error during cleanup: {e}")
+
+        if args.diarization and diarization:
             diarization.close()
-
-
+
+        logger.info("WebSocket endpoint cleaned up.")

 if __name__ == "__main__":
     import uvicorn
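
Taken together, the change moves the endpoint to a producer/consumer layout: the FFmpeg reader only pushes PCM chunks onto per-task asyncio queues, the transcription and diarization workers fold their results into a lock-protected SharedState, and results_formatter periodically snapshots that state and sends it over the WebSocket. The stand-alone sketch below distills that pattern under simplified assumptions; it reuses names from the diff, but the audio processing and the WebSocket send are stubbed out, so it is an illustration of the concurrency shape rather than the project's actual wiring.

import asyncio
from time import time

class SharedState:
    # Minimal stand-in: results shared between tasks, guarded by an asyncio.Lock.
    def __init__(self):
        self.tokens = []
        self.end_buffer = 0
        self.lock = asyncio.Lock()
        self.beg_loop = time()

    async def update_transcription(self, new_tokens, end_buffer):
        async with self.lock:
            self.tokens.extend(new_tokens)
            self.end_buffer = end_buffer

    async def get_current_state(self):
        async with self.lock:
            return {"tokens": list(self.tokens), "end_buffer": self.end_buffer}

async def transcription_processor(shared_state, pcm_queue):
    # Worker: drain the queue and fold results into the shared state.
    while True:
        chunk = await pcm_queue.get()
        try:
            # Stand-in for online.insert_audio_chunk(...) / online.process_iter()
            await shared_state.update_transcription([f"token-{chunk}"], end_buffer=chunk)
        finally:
            pcm_queue.task_done()

async def results_formatter(shared_state):
    # Formatter: periodically snapshot the state; print() stands in for websocket.send_json().
    for _ in range(3):
        state = await shared_state.get_current_state()
        print(state)
        await asyncio.sleep(0.1)

async def main():
    shared_state = SharedState()
    transcription_queue = asyncio.Queue()
    tasks = [asyncio.create_task(transcription_processor(shared_state, transcription_queue))]

    # Producer: stand-in for ffmpeg_stdout_reader pushing PCM chunks.
    for chunk in range(5):
        await transcription_queue.put(chunk)
    await transcription_queue.join()

    await results_formatter(shared_state)

    # Cleanup mirrors the endpoint's finally block: cancel, then gather.
    for task in tasks:
        task.cancel()
    await asyncio.gather(*tasks, return_exceptions=True)

asyncio.run(main())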