qfuxa committed
Commit 6143582
1 Parent(s): 33573ca

diarization now works at word - not chunk - level!

src/diarization/diarization_online.py CHANGED
@@ -81,11 +81,10 @@ class DiartDiarization:
     def close(self):
         self.source.close()
 
-    def assign_speakers_to_chunks(self, chunks: list) -> list:
-        end_attributed_speaker = 0
-        for chunk in chunks:
+    def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> list:
+        for token in tokens:
             for segment in self.segment_speakers:
-                if not (segment["end"] <= chunk["beg"] or segment["beg"] >= chunk["end"]):
-                    chunk["speaker"] = extract_number(segment["speaker"]) + 1
-                    end_attributed_speaker = chunk["end"]
+                if not (segment["end"] <= token.start or segment["beg"] >= token.end):
+                    token.speaker = extract_number(segment["speaker"]) + 1
+                    end_attributed_speaker = max(token.end, end_attributed_speaker)
         return end_attributed_speaker
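
The new method walks every ASR token and attaches a speaker whenever the token's [start, end) interval overlaps a diarized segment, advancing the end_attributed_speaker watermark as it goes. A minimal standalone sketch of that overlap test, using a hypothetical Token stand-in for ASRToken and hard-coded segments in the segment_speakers dict format (the real method reads them from DiartDiarization state):

from dataclasses import dataclass

@dataclass
class Token:                     # stand-in for ASRToken: start/end in seconds, speaker id
    start: float
    end: float
    speaker: int = -1

# hypothetical diarization output, mirroring the segment_speakers format
segments = [{"beg": 0.0, "end": 2.0, "speaker": "speaker0"},
            {"beg": 2.0, "end": 4.0, "speaker": "speaker1"}]

def assign_speakers(end_attributed_speaker, tokens):
    # a token gets a speaker when its [start, end) overlaps a diarized segment
    for token in tokens:
        for seg in segments:
            if not (seg["end"] <= token.start or seg["beg"] >= token.end):
                token.speaker = int(seg["speaker"][-1]) + 1   # stand-in for extract_number(...) + 1
                end_attributed_speaker = max(token.end, end_attributed_speaker)
    return end_attributed_speaker

tokens = [Token(0.5, 1.0), Token(2.5, 3.0)]
print(assign_speakers(0, tokens), [t.speaker for t in tokens])   # -> 3.0 [1, 2]

Because the assignment mutates the tokens in place, a caller only needs to keep passing the same list plus the returned watermark on each iteration, which is what the server loop below does.
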
src/whisper_streaming/online_asr.py CHANGED
@@ -202,7 +202,7 @@ class OnlineASRProcessor:
         logger.debug(
             f"Length of audio buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds"
         )
-        return self.concatenate_tokens(committed_tokens)
+        return committed_tokens
 
     def chunk_completed_sentence(self):
         """
src/whisper_streaming/timed_objects.py CHANGED
@@ -5,7 +5,8 @@ from typing import Optional
 class TimedText:
     start: Optional[float]
     end: Optional[float]
-    text: str
+    text: Optional[str] = ''
+    speaker: Optional[int] = -1
 
 @dataclass
 class ASRToken(TimedText):
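
TimedText now defaults text to '' and carries a speaker field defaulting to -1 (unassigned), so a placeholder token needs only its timestamps. A quick sketch of the resulting dataclass behaviour, assuming ASRToken keeps inheriting from TimedText as shown (the empty subclass body here is just for the sketch):

from dataclasses import dataclass
from typing import Optional

@dataclass
class TimedText:
    start: Optional[float]
    end: Optional[float]
    text: Optional[str] = ''
    speaker: Optional[int] = -1

@dataclass
class ASRToken(TimedText):
    pass                          # the real class may add fields; kept empty for the sketch

placeholder = ASRToken(start=12.0, end=12.5)   # no text, no speaker yet
print(placeholder)                # ASRToken(start=12.0, end=12.5, text='', speaker=-1)
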
whisper_fastapi_online_server.py CHANGED
@@ -11,6 +11,7 @@ from fastapi.responses import HTMLResponse
 from fastapi.middleware.cors import CORSMiddleware
 
 from src.whisper_streaming.whisper_online import backend_factory, online_factory, add_shared_args
+from src.whisper_streaming.timed_objects import ASRToken
 
 import math
 import logging
@@ -47,7 +48,7 @@ parser.add_argument(
 parser.add_argument(
     "--diarization",
     type=bool,
-    default=False,
+    default=True,
     help="Whether to enable speaker diarization.",
 )
 
@@ -157,7 +158,9 @@ async def websocket_endpoint(websocket: WebSocket):
     full_transcription = ""
     beg = time()
     beg_loop = time()
-    chunk_history = [] # Will store dicts: {beg, end, text, speaker}
+    tokens = []
+    end_attributed_speaker = 0
+    sep = online.asr.sep
 
     while True:
         try:
@@ -177,7 +180,6 @@ async def websocket_endpoint(websocket: WebSocket):
            logger.warning("FFmpeg read timeout. Restarting...")
            await restart_ffmpeg()
            full_transcription = ""
-            chunk_history = []
            beg = time()
            continue # Skip processing and read from new process
 
@@ -202,63 +204,53 @@ async def websocket_endpoint(websocket: WebSocket):
            if args.transcription:
                logger.info(f"{len(online.audio_buffer) / online.SAMPLING_RATE} seconds of audio will be processed by the model.")
                online.insert_audio_chunk(pcm_array)
-                transcription = online.process_iter()
-                if transcription.start:
-                    chunk_history.append({
-                        "beg": transcription.start,
-                        "end": transcription.end,
-                        "text": transcription.text,
-                        "speaker": -1
-                    })
-                full_transcription += transcription.text if transcription else ""
+                new_tokens = online.process_iter()
+                tokens.extend(new_tokens)
+                full_transcription += sep.join([t.text for t in new_tokens])
                buffer = online.get_buffer()
                if buffer in full_transcription: # With VAC, the buffer is not updated until the next chunk is processed
                    buffer = ""
            else:
-                chunk_history.append({
-                    "beg": time() - beg_loop,
-                    "end": time() - beg_loop + 1,
-                    "text": '',
-                    "speaker": -1
-                })
-                sleep(1)
+                tokens.append(
+                    ASRToken(
+                        start = time() - beg_loop,
+                        end = time() - beg_loop + 0.5))
+                sleep(0.5)
                buffer = ''
 
            if args.diarization:
                await diarization.diarize(pcm_array)
-                end_attributed_speaker = diarization.assign_speakers_to_chunks(chunk_history)
-
+                end_attributed_speaker = diarization.assign_speakers_to_tokens(end_attributed_speaker, tokens)
 
-            current_speaker = -10
+            previous_speaker = -10
            lines = []
            last_end_diarized = 0
-            previous_speaker = -1
-            for ind, ch in enumerate(chunk_history):
-                speaker = ch.get("speaker")
+            for token in tokens:
+                speaker = token.speaker
                if args.diarization:
                    if speaker == -1 or speaker == 0:
-                        if ch['end'] < end_attributed_speaker:
+                        if token.end < end_attributed_speaker:
                            speaker = previous_speaker
                        else:
                            speaker = 0
                    else:
-                        last_end_diarized = max(ch['end'], last_end_diarized)
+                        last_end_diarized = max(token.end, last_end_diarized)
 
-                if speaker != current_speaker:
+                if speaker != previous_speaker:
                    lines.append(
                        {
                            "speaker": speaker,
-                            "text": ch['text'],
-                            "beg": format_time(ch['beg']),
-                            "end": format_time(ch['end']),
-                            "diff": round(ch['end'] - last_end_diarized, 2)
+                            "text": token.text,
+                            "beg": format_time(token.start),
+                            "end": format_time(token.end),
+                            "diff": round(token.end - last_end_diarized, 2)
                        }
                    )
-                    current_speaker = speaker
+                    previous_speaker = speaker
                else:
-                    lines[-1]["text"] += ch['text']
-                    lines[-1]["end"] = format_time(ch['end'])
-                    lines[-1]["diff"] = round(ch['end'] - last_end_diarized, 2)
 
            response = {"lines": lines, "buffer": buffer}
            await websocket.send_json(response)
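
The response-building loop now works per token: a token whose speaker differs from the previous one starts a new line, otherwise its text is appended to the last line with sep and the line's end time is pushed forward. Below is a standalone sketch of just that grouping step, with the diarization fallback (speaker -1/0), format_time and the diff field stripped out for brevity; field names follow the diff, everything else is illustrative:

from dataclasses import dataclass

@dataclass
class ASRToken:                  # simplified stand-in for the real ASRToken
    start: float
    end: float
    text: str = ''
    speaker: int = -1

def group_into_lines(tokens, sep=" "):
    # merge consecutive same-speaker tokens into one line, as the handler above does
    lines, previous_speaker = [], -10
    for token in tokens:
        if token.speaker != previous_speaker:
            lines.append({"speaker": token.speaker, "text": token.text,
                          "beg": token.start, "end": token.end})
            previous_speaker = token.speaker
        else:
            lines[-1]["text"] += sep + token.text
            lines[-1]["end"] = token.end
    return lines

tokens = [ASRToken(0.0, 0.4, "hello", 1), ASRToken(0.4, 0.9, "there", 1),
          ASRToken(1.2, 1.6, "hi", 2)]
print(group_into_lines(tokens))
# -> [{'speaker': 1, 'text': 'hello there', 'beg': 0.0, 'end': 0.9},
#     {'speaker': 2, 'text': 'hi', 'beg': 1.2, 'end': 1.6}]
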