qfuxa commited on
Commit
6933483
·
1 Parent(s): cc68f3b

add diarization (beta). Disabled by default

Browse files
src/diarization/diarization_online.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diart import SpeakerDiarization
2
+ from diart.inference import StreamingInference
3
+ from diart.sources import AudioSource
4
+ from rx.subject import Subject
5
+ import threading
6
+ import numpy as np
7
+ import asyncio
8
+
9
+ class WebSocketAudioSource(AudioSource):
10
+ """
11
+ Simple custom AudioSource that blocks in read()
12
+ until close() is called.
13
+ push_audio() is used to inject new PCM chunks.
14
+ """
15
+ def __init__(self, uri: str = "websocket", sample_rate: int = 16000):
16
+ super().__init__(uri, sample_rate)
17
+ self._close_event = threading.Event()
18
+ self._closed = False
19
+
20
+ def read(self):
21
+ self._close_event.wait()
22
+
23
+ def close(self):
24
+ if not self._closed:
25
+ self._closed = True
26
+ self.stream.on_completed()
27
+ self._close_event.set()
28
+
29
+ def push_audio(self, chunk: np.ndarray):
30
+ chunk = np.expand_dims(chunk, axis=0)
31
+ if not self._closed:
32
+ self.stream.on_next(chunk)
33
+
34
+
35
+ def create_pipeline(SAMPLE_RATE):
36
+ diar_pipeline = SpeakerDiarization()
37
+ ws_source = WebSocketAudioSource(uri="websocket_source", sample_rate=SAMPLE_RATE)
38
+ inference = StreamingInference(
39
+ pipeline=diar_pipeline,
40
+ source=ws_source,
41
+ do_plot=False,
42
+ show_progress=False,
43
+ )
44
+ return inference, ws_source
45
+
46
+
47
+ def init_diart(SAMPLE_RATE):
48
+ inference, ws_source = create_pipeline(SAMPLE_RATE)
49
+
50
+ def diar_hook(result):
51
+ """
52
+ Hook called each time Diart processes a chunk.
53
+ result is (annotation, audio).
54
+ We store the label of the last segment in 'current_speaker'.
55
+ """
56
+ global l_speakers
57
+ l_speakers = []
58
+ annotation, audio = result
59
+ for speaker in annotation._labels:
60
+ segment = annotation._labels[speaker].__str__()
61
+ asyncio.create_task(
62
+ l_speakers_queue.put({"speaker": speaker, "segment": segment})
63
+ )
64
+
65
+ l_speakers_queue = asyncio.Queue()
66
+ inference.attach_hooks(diar_hook)
67
+
68
+ # Launch Diart in a background thread
69
+ loop = asyncio.get_event_loop()
70
+ diar_future = loop.run_in_executor(None, inference)
71
+ return inference, l_speakers_queue, ws_source
72
+
73
+
74
+ class DiartDiarization():
75
+ def __init__(self, SAMPLE_RATE):
76
+ self.inference, self.l_speakers_queue, self.ws_source = init_diart(SAMPLE_RATE)
77
+
78
+ async def get_speakers(self, pcm_array):
79
+ self.ws_source.push_audio(pcm_array)
80
+ speakers = []
81
+ while not self.l_speakers_queue.empty():
82
+ speakers.append(await self.l_speakers_queue.get())
83
+ return speakers
84
+
85
+ def close(self):
86
+ self.ws_source.close()
whisper_fastapi_online_server.py CHANGED
@@ -10,7 +10,6 @@ from fastapi.responses import HTMLResponse
10
  from fastapi.middleware.cors import CORSMiddleware
11
 
12
  from whisper_online import backend_factory, online_factory, add_shared_args
13
-
14
  app = FastAPI()
15
  app.add_middleware(
16
  CORSMiddleware,
@@ -37,11 +36,24 @@ parser.add_argument(
37
  dest="warmup_file",
38
  help="The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. It can be e.g. https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav .",
39
  )
 
 
 
 
 
 
 
 
 
40
  add_shared_args(parser)
41
  args = parser.parse_args()
42
 
43
  asr, tokenizer = backend_factory(args)
44
 
 
 
 
 
45
  # Load demo HTML for the root endpoint
46
  with open("src/web/live_transcription.html", "r", encoding="utf-8") as f:
47
  html = f.read()
@@ -89,6 +101,9 @@ async def websocket_endpoint(websocket: WebSocket):
89
  online = online_factory(args, asr, tokenizer)
90
  print("Online loaded.")
91
 
 
 
 
92
  # Continuously read decoded PCM from ffmpeg stdout in a background task
93
  async def ffmpeg_stdout_reader():
94
  nonlocal pcm_buffer
@@ -136,9 +151,13 @@ async def websocket_endpoint(websocket: WebSocket):
136
  buffer in full_transcription
137
  ): # With VAC, the buffer is not updated until the next chunk is processed
138
  buffer = ""
139
- await websocket.send_json(
140
- {"transcription": transcription, "buffer": buffer}
141
- )
 
 
 
 
142
  except Exception as e:
143
  print(f"Exception in ffmpeg_stdout_reader: {e}")
144
  break
@@ -174,6 +193,11 @@ async def websocket_endpoint(websocket: WebSocket):
174
 
175
  ffmpeg_process.wait()
176
  del online
 
 
 
 
 
177
 
178
 
179
 
 
10
  from fastapi.middleware.cors import CORSMiddleware
11
 
12
  from whisper_online import backend_factory, online_factory, add_shared_args
 
13
  app = FastAPI()
14
  app.add_middleware(
15
  CORSMiddleware,
 
36
  dest="warmup_file",
37
  help="The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. It can be e.g. https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav .",
38
  )
39
+
40
+ parser.add_argument(
41
+ "--diarization",
42
+ type=bool,
43
+ default=False,
44
+ help="Whether to enable speaker diarization.",
45
+ )
46
+
47
+
48
  add_shared_args(parser)
49
  args = parser.parse_args()
50
 
51
  asr, tokenizer = backend_factory(args)
52
 
53
+ if args.diarization:
54
+ from src.diarization.diarization_online import DiartDiarization
55
+
56
+
57
  # Load demo HTML for the root endpoint
58
  with open("src/web/live_transcription.html", "r", encoding="utf-8") as f:
59
  html = f.read()
 
101
  online = online_factory(args, asr, tokenizer)
102
  print("Online loaded.")
103
 
104
+ if args.diarization:
105
+ diarization = DiartDiarization(SAMPLE_RATE)
106
+
107
  # Continuously read decoded PCM from ffmpeg stdout in a background task
108
  async def ffmpeg_stdout_reader():
109
  nonlocal pcm_buffer
 
151
  buffer in full_transcription
152
  ): # With VAC, the buffer is not updated until the next chunk is processed
153
  buffer = ""
154
+ response = {"transcription": transcription, "buffer": buffer}
155
+ if args.diarization:
156
+ speakers = await diarization.get_speakers(pcm_array)
157
+ response["speakers"] = speakers
158
+
159
+ await websocket.send_json(response)
160
+
161
  except Exception as e:
162
  print(f"Exception in ffmpeg_stdout_reader: {e}")
163
  break
 
193
 
194
  ffmpeg_process.wait()
195
  del online
196
+
197
+ if args.diarization:
198
+ # Stop Diart
199
+ diarization.close()
200
+
201
 
202
 
203