qfuxa committed on
Commit 0cf8b89 · 1 Parent(s): fe0200a

update import paths

diarization/diarization_online.py ADDED
@@ -0,0 +1,90 @@
+import asyncio
+import re
+import threading
+import numpy as np
+
+from diart import SpeakerDiarization
+from diart.inference import StreamingInference
+from diart.sources import AudioSource
+from timed_objects import SpeakerSegment
+
+def extract_number(s: str) -> int:
+    m = re.search(r'\d+', s)
+    return int(m.group()) if m else None
+
+
+class WebSocketAudioSource(AudioSource):
+    """
+    Custom AudioSource that blocks in read() until close() is called.
+    Use push_audio() to inject PCM chunks.
+    """
+    def __init__(self, uri: str = "websocket", sample_rate: int = 16000):
+        super().__init__(uri, sample_rate)
+        self._closed = False
+        self._close_event = threading.Event()
+
+    def read(self):
+        self._close_event.wait()
+
+    def close(self):
+        if not self._closed:
+            self._closed = True
+            self.stream.on_completed()
+            self._close_event.set()
+
+    def push_audio(self, chunk: np.ndarray):
+        if not self._closed:
+            self.stream.on_next(np.expand_dims(chunk, axis=0))
+
+
+class DiartDiarization:
+    def __init__(self, sample_rate: int):
+        self.processed_time = 0
+        self.segment_speakers = []
+        self.speakers_queue = asyncio.Queue()
+        self.pipeline = SpeakerDiarization()
+        self.source = WebSocketAudioSource(uri="websocket_source", sample_rate=sample_rate)
+        self.inference = StreamingInference(
+            pipeline=self.pipeline,
+            source=self.source,
+            do_plot=False,
+            show_progress=False,
+        )
+        # Attach the hook function and start inference in the background.
+        self.inference.attach_hooks(self._diar_hook)
+        asyncio.get_event_loop().run_in_executor(None, self.inference)
+
+    def _diar_hook(self, result):
+        annotation, audio = result
+        if annotation._labels:
+            for speaker, label in annotation._labels.items():
+                start = label.segments_boundaries_[0]
+                end = label.segments_boundaries_[-1]
+                if end > self.processed_time:
+                    self.processed_time = end
+                asyncio.create_task(self.speakers_queue.put(SpeakerSegment(
+                    speaker=speaker,
+                    start=start,
+                    end=end,
+                )))
+        else:
+            dur = audio.extent.end
+            if dur > self.processed_time:
+                self.processed_time = dur
+
+    async def diarize(self, pcm_array: np.ndarray):
+        self.source.push_audio(pcm_array)
+        self.segment_speakers.clear()
+        while not self.speakers_queue.empty():
+            self.segment_speakers.append(await self.speakers_queue.get())
+
+    def close(self):
+        self.source.close()
+
+    def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> list:
+        for token in tokens:
+            for segment in self.segment_speakers:
+                if not (segment.end <= token.start or segment.start >= token.end):
+                    token.speaker = extract_number(segment.speaker) + 1
+                    end_attributed_speaker = max(token.end, end_attributed_speaker)
+        return end_attributed_speaker
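
For orientation, here is a minimal driver sketch for the new DiartDiarization class. It is hypothetical and not part of the commit: it assumes diart and its pretrained models are installed, 16 kHz mono float32 PCM, and that audio is pushed from inside a running event loop (as the FastAPI server does).

import asyncio
import numpy as np

from diarization.diarization_online import DiartDiarization

SAMPLE_RATE = 16000  # must match the rate the audio source was built with

async def main():
    diarization = DiartDiarization(SAMPLE_RATE)
    # Push 0.5 s chunks of mono float32 PCM; zeros stand in for real
    # audio that would normally arrive over the websocket.
    for _ in range(20):
        chunk = np.zeros(SAMPLE_RATE // 2, dtype=np.float32)
        await diarization.diarize(chunk)
    # diarize() drains the queue into segment_speakers as results arrive.
    for seg in diarization.segment_speakers:
        print(f"{seg.speaker}: {seg.start:.2f}s - {seg.end:.2f}s")
    diarization.close()

asyncio.run(main())

One caveat: _diar_hook runs on diart's worker thread, where asyncio.create_task has no running event loop, so pushing results onto the asyncio.Queue may need loop.call_soon_threadsafe in practice.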
whisper_fastapi_online_server.py CHANGED
@@ -10,8 +10,8 @@ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 from fastapi.responses import HTMLResponse
 from fastapi.middleware.cors import CORSMiddleware
 
-from src.whisper_streaming.whisper_online import backend_factory, online_factory, add_shared_args
-from src.whisper_streaming.timed_objects import ASRToken
+from whisper_streaming_custom.whisper_online import backend_factory, online_factory, add_shared_args
+from timed_objects import ASRToken
 
 import math
 import logging
@@ -49,7 +49,7 @@ parser.add_argument(
 parser.add_argument(
     "--diarization",
     type=bool,
-    default=False,
+    default=True,
     help="Whether to enable speaker diarization.",
 )
 
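
An aside on this flag, not part of the diff: argparse's type=bool simply calls bool() on the raw string, so --diarization False is still truthy (any non-empty string is). If a real on/off toggle is wanted, the usual sketch is an explicit converter, here a hypothetical str2bool helper:

import argparse

def str2bool(v: str) -> bool:
    # Interpret common textual spellings; plain bool() would not.
    if v.lower() in ("yes", "true", "t", "1"):
        return True
    if v.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError(f"boolean value expected, got {v!r}")

parser = argparse.ArgumentParser()
parser.add_argument("--diarization", type=str2bool, default=True,
                    help="Whether to enable speaker diarization.")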
@@ -157,7 +157,7 @@ async def lifespan(app: FastAPI):
     asr, tokenizer = None, None
 
     if args.diarization:
-        from src.diarization.diarization_online import DiartDiarization
+        from diarization.diarization_online import DiartDiarization
         diarization = DiartDiarization(SAMPLE_RATE)
     else :
         diarization = None
@@ -174,7 +174,7 @@ app.add_middleware(
 
 
 # Load demo HTML for the root endpoint
-with open("src/web/live_transcription.html", "r", encoding="utf-8") as f:
+with open("web/live_transcription.html", "r", encoding="utf-8") as f:
     html = f.read()
 
 async def start_ffmpeg_decoder():
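
Note that with the src/ prefix gone, this open() call (like the new flat imports) assumes the server is launched from the repository root. A defensive variant, sketched here and not what the commit does, resolves the page relative to the module itself:

from pathlib import Path

# Locate the demo page next to this file so the working directory no longer matters.
html = (Path(__file__).parent / "web" / "live_transcription.html").read_text(encoding="utf-8")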
@@ -277,24 +277,18 @@ async def results_formatter(shared_state, websocket):
 
         # Process tokens to create response
         previous_speaker = -10
-        lines = [
-        ]
+        lines = []
         last_end_diarized = 0
         undiarized_text = []
 
         for token in tokens:
             speaker = token.speaker
-            # Handle diarization differently if diarization is enabled
             if args.diarization:
-                # If token is not yet processed by diarization
                 if (speaker == -1 or speaker == 0) and token.end >= end_attributed_speaker:
-                    # Add this token's text to undiarized buffer instead of creating a new line
                     undiarized_text.append(token.text)
                     continue
-                # If speaker isn't assigned yet but should be (based on timestamp)
                 elif (speaker == -1 or speaker == 0) and token.end < end_attributed_speaker:
                     speaker = previous_speaker
-                # Track last diarized token end time
                 if speaker not in [-1, 0]:
                     last_end_diarized = max(token.end, last_end_diarized)
@@ -314,7 +308,6 @@ async def results_formatter(shared_state, websocket):
                 lines[-1]["end"] = format_time(token.end)
                 lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
 
-        # Update buffer_diarization with undiarized text
         if undiarized_text:
             combined_buffer_diarization = sep.join(undiarized_text)
             if buffer_transcription:
@@ -322,7 +315,6 @@ async def results_formatter(shared_state, websocket):
             await shared_state.update_diarization(end_attributed_speaker, combined_buffer_diarization)
             buffer_diarization = combined_buffer_diarization
 
-        # Prepare response object
         if lines:
             response = {
                 "lines": lines,
@@ -350,7 +342,6 @@ async def results_formatter(shared_state, websocket):
         response_content = ' '.join([str(line['speaker']) + ' ' + line["text"] for line in lines]) + ' | ' + buffer_transcription + ' | ' + buffer_diarization
 
         if response_content != shared_state.last_response_content:
-            # Only send if there's actual content to send
             if lines or buffer_transcription or buffer_diarization:
                 await websocket.send_json(response)
                 shared_state.last_response_content = response_content
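
Taken together, the updated import paths imply that the modules now live at the repository root rather than under src/ (layout inferred from the imports; the diff itself does not show it):

diarization/diarization_online.py
timed_objects.py
web/live_transcription.html
whisper_streaming_custom/whisper_online.py
whisper_fastapi_online_server.py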