qfuxa committed on
Commit
1ebc0b3
·
1 Parent(s): 6a8a8bf

Enhance diarization logic to improve speaker attribution: corrects several bugs

Browse files
Files changed (1) hide show
  1. whisper_fastapi_online_server.py +19 -10
whisper_fastapi_online_server.py CHANGED
@@ -208,6 +208,7 @@ async def websocket_endpoint(websocket: WebSocket):
208
  "beg": transcription.start,
209
  "end": transcription.end,
210
  "text": transcription.text,
 
211
  })
212
  full_transcription += transcription.text if transcription else ""
213
  buffer = online.get_buffer()
@@ -218,23 +219,32 @@ async def websocket_endpoint(websocket: WebSocket):
218
  "beg": time() - beg_loop,
219
  "end": time() - beg_loop + 1,
220
  "text": '',
 
221
  })
222
  sleep(1)
223
  buffer = ''
224
 
225
  if args.diarization:
226
  await diarization.diarize(pcm_array)
227
- diarization.assign_speakers_to_chunks(chunk_history)
228
 
229
 
230
- current_speaker = 0
231
  lines = []
232
  last_end_diarized = 0
 
233
  for ind, ch in enumerate(chunk_history):
234
- speaker = ch.get("speaker", -3)
235
- if speaker == -1 and ind < len(chunk_history) - 1:
236
- continue
237
- elif speaker != current_speaker:
 
 
 
 
 
 
 
238
  lines.append(
239
  {
240
  "speaker": speaker,
@@ -245,12 +255,11 @@ async def websocket_endpoint(websocket: WebSocket):
245
  }
246
  )
247
  current_speaker = speaker
248
- elif speaker != -1:
249
  lines[-1]["text"] += ch['text']
250
  lines[-1]["end"] = format_time(ch['end'])
251
- if speaker != -1:
252
- last_end_diarized = max(ch['end'], last_end_diarized)
253
-
254
  response = {"lines": lines, "buffer": buffer}
255
  await websocket.send_json(response)
256
 
 
208
  "beg": transcription.start,
209
  "end": transcription.end,
210
  "text": transcription.text,
211
+ "speaker": -1
212
  })
213
  full_transcription += transcription.text if transcription else ""
214
  buffer = online.get_buffer()
 
219
  "beg": time() - beg_loop,
220
  "end": time() - beg_loop + 1,
221
  "text": '',
222
+ "speaker": -1
223
  })
224
  sleep(1)
225
  buffer = ''
226
 
227
  if args.diarization:
228
  await diarization.diarize(pcm_array)
229
+ end_attributed_speaker = diarization.assign_speakers_to_chunks(chunk_history)
230
 
231
 
232
+ current_speaker = -10
233
  lines = []
234
  last_end_diarized = 0
235
+ previous_speaker = -1
236
  for ind, ch in enumerate(chunk_history):
237
+ speaker = ch.get("speaker")
238
+ if args.diarization:
239
+ if speaker == -1 or speaker == 0:
240
+ if ch['end'] < end_attributed_speaker:
241
+ speaker = previous_speaker
242
+ else:
243
+ speaker = 0
244
+ else:
245
+ last_end_diarized = max(ch['end'], last_end_diarized)
246
+
247
+ if speaker != current_speaker:
248
  lines.append(
249
  {
250
  "speaker": speaker,
 
255
  }
256
  )
257
  current_speaker = speaker
258
+ else:
259
  lines[-1]["text"] += ch['text']
260
  lines[-1]["end"] = format_time(ch['end'])
261
+ lines[-1]["diff"] = round(ch['end'] - last_end_diarized, 2)
262
+
 
263
  response = {"lines": lines, "buffer": buffer}
264
  await websocket.send_json(response)
265