qfuxa committed
Commit 02f90cf
2 Parent(s): b7c5736 1cb2d95

Merge branch 'main' into fix-sentencesegmenter
README.md CHANGED
@@ -12,6 +12,8 @@ This project extends the [Whisper Streaming](https://github.com/ufal/whisper_streaming) project
 
 5. **MLX Whisper backend**: Integrates the alternative backend option MLX Whisper, optimized for efficient speech recognition on Apple silicon.
 
+6. **Diarization (beta)**: Adds real-time speaker labeling alongside the transcription, using the [Diart](https://github.com/juanmc2005/diart) library. Each transcription segment is tagged with a speaker.
+
 ![Demo Screenshot](src/web/demo.png)
 
 ## Code Origins
@@ -64,6 +66,9 @@ This project reuses and extends code from the original Whisper Streaming repository
 
 # If you want to run the server using uvicorn (recommended)
 uvicorn
+
+# If you want to use diarization
+diart
 ```
 
 
@@ -76,6 +81,8 @@ This project reuses and extends code from the original Whisper Streaming repository
 - `--host` and `--port` let you specify the server’s IP/port.
 - `-min-chunk-size` sets the minimum chunk size for audio processing. Make sure this value aligns with the chunk size selected in the frontend. If not aligned, the system will work but may unnecessarily over-process audio data.
 - For a full list of configurable options, run `python whisper_fastapi_online_server.py -h`
+- `--diarization` (default: `False`) lets you choose whether to run speaker diarization in parallel with transcription.
+- For other parameters, see the [whisper streaming](https://github.com/ufal/whisper_streaming) README.
 
 4. **Open the Provided HTML**:
 
src/diarization/diarization_online.py ADDED
@@ -0,0 +1,110 @@
+from diart import SpeakerDiarization
+from diart.inference import StreamingInference
+from diart.sources import AudioSource
+from rx.subject import Subject
+import threading
+import numpy as np
+import asyncio
+
+class WebSocketAudioSource(AudioSource):
+    """
+    Simple custom AudioSource that blocks in read()
+    until close() is called.
+    push_audio() is used to inject new PCM chunks.
+    """
+    def __init__(self, uri: str = "websocket", sample_rate: int = 16000):
+        super().__init__(uri, sample_rate)
+        self._close_event = threading.Event()
+        self._closed = False
+
+    def read(self):
+        self._close_event.wait()
+
+    def close(self):
+        if not self._closed:
+            self._closed = True
+            self.stream.on_completed()
+            self._close_event.set()
+
+    def push_audio(self, chunk: np.ndarray):
+        chunk = np.expand_dims(chunk, axis=0)
+        if not self._closed:
+            self.stream.on_next(chunk)
+
+
+def create_pipeline(SAMPLE_RATE):
+    diar_pipeline = SpeakerDiarization()
+    ws_source = WebSocketAudioSource(uri="websocket_source", sample_rate=SAMPLE_RATE)
+    inference = StreamingInference(
+        pipeline=diar_pipeline,
+        source=ws_source,
+        do_plot=False,
+        show_progress=False,
+    )
+    return inference, ws_source
+
+
+def init_diart(SAMPLE_RATE):
+    inference, ws_source = create_pipeline(SAMPLE_RATE)
+
+    def diar_hook(result):
+        """
+        Hook called each time Diart processes a chunk.
+        result is (annotation, audio).
+        We store the label of the last segment in 'current_speaker'.
+        """
+        global l_speakers
+        l_speakers = []
+        annotation, audio = result
+        for speaker in annotation._labels:
+            segments_beg = annotation._labels[speaker].segments_boundaries_[0]
+            segments_end = annotation._labels[speaker].segments_boundaries_[-1]
+            asyncio.create_task(
+                l_speakers_queue.put({"speaker": speaker, "beg": segments_beg, "end": segments_end})
+            )
+
+    l_speakers_queue = asyncio.Queue()
+    inference.attach_hooks(diar_hook)
+
+    # Launch Diart in a background thread
+    loop = asyncio.get_event_loop()
+    diar_future = loop.run_in_executor(None, inference)
+    return inference, l_speakers_queue, ws_source
+
+
+class DiartDiarization():
+    def __init__(self, SAMPLE_RATE):
+        self.inference, self.l_speakers_queue, self.ws_source = init_diart(SAMPLE_RATE)
+        self.segment_speakers = []
+
+    async def diarize(self, pcm_array):
+        self.ws_source.push_audio(pcm_array)
+        self.segment_speakers = []
+        while not self.l_speakers_queue.empty():
+            self.segment_speakers.append(await self.l_speakers_queue.get())
+
+    def close(self):
+        self.ws_source.close()
+
+
+    def assign_speakers_to_chunks(self, chunks):
+        """
+        Go through each chunk and see which speaker(s) overlap
+        that chunk's time range in the Diart annotation.
+        Then store the speaker label(s) (or choose the most overlapping).
+        This modifies `chunks` in-place or returns a new list with assigned speakers.
+        """
+        if not self.segment_speakers:
+            return chunks
+
+        for segment in self.segment_speakers:
+            seg_beg = segment["beg"]
+            seg_end = segment["end"]
+            speaker = segment["speaker"]
+            for ch in chunks:
+                if seg_end <= ch["beg"] or seg_beg >= ch["end"]:
+                    continue
+                # We have overlap. Let's just pick the speaker (could be more precise in a more complex implementation)
+                ch["speaker"] = speaker
+
+        return chunks
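For readers of the diff, here is a minimal, hypothetical driver for the class added above; it is not part of the commit. It assumes `diart` and its pyannote.audio models are installed and that audio arrives as 16 kHz mono float32 PCM chunks, which is how the server below feeds it.

```python
# Hypothetical usage sketch of DiartDiarization (not in the commit).
# Assumes diart and its pyannote.audio models are installed, and 16 kHz
# mono float32 PCM chunks, as produced by the server's ffmpeg pipeline.
import asyncio
import numpy as np

from src.diarization.diarization_online import DiartDiarization


async def main():
    diarizer = DiartDiarization(SAMPLE_RATE=16000)

    # Transcript chunks as the ASR loop produces them (times in seconds).
    chunk_history = [{"beg": 0.0, "end": 1.0, "text": "hello", "speaker": "0"}]

    pcm = np.zeros(16000, dtype=np.float32)  # one second of silence as a stand-in
    await diarizer.diarize(pcm)              # push audio and drain the speaker queue
    diarizer.assign_speakers_to_chunks(chunk_history)  # overlay labels in place

    print(chunk_history)
    diarizer.close()  # completes the Diart audio source


asyncio.run(main())
```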
src/web/live_transcription.html CHANGED
@@ -7,8 +7,8 @@
     <style>
         body {
             font-family: 'Inter', sans-serif;
-            text-align: center;
             margin: 20px;
+            text-align: center;
         }
         #recordButton {
             width: 80px;
@@ -28,18 +28,10 @@
         #recordButton:active {
             transform: scale(0.95);
         }
-        #transcriptions {
+        #status {
             margin-top: 20px;
-            font-size: 18px;
-            text-align: left;
-        }
-        .transcription {
-            display: inline;
-            color: black;
-        }
-        .buffer {
-            display: inline;
-            color: rgb(197, 197, 197);
+            font-size: 16px;
+            color: #333;
         }
         .settings-container {
             display: flex;
@@ -73,9 +65,29 @@
         label {
             font-size: 14px;
         }
+        /* Speaker-labeled transcript area */
+        #linesTranscript {
+            margin: 20px auto;
+            max-width: 600px;
+            text-align: left;
+            font-size: 16px;
+        }
+        #linesTranscript p {
+            margin: 5px 0;
+        }
+        #linesTranscript strong {
+            color: #333;
+        }
+        /* Grey buffer styling */
+        .buffer {
+            color: rgb(180, 180, 180);
+            font-style: italic;
+            margin-left: 4px;
+        }
     </style>
 </head>
 <body>
+
     <div class="settings-container">
         <button id="recordButton">🎙️</button>
         <div class="settings">
@@ -96,9 +108,11 @@
             </div>
         </div>
     </div>
+
     <p id="status"></p>
 
-    <div id="transcriptions"></div>
+    <!-- Speaker-labeled transcript -->
+    <div id="linesTranscript"></div>
 
     <script>
         let isRecording = false;
@@ -106,89 +120,97 @@
         let recorder = null;
         let chunkDuration = 1000;
         let websocketUrl = "ws://localhost:8000/asr";
-
-        // Tracks whether the user voluntarily closed the WebSocket
         let userClosing = false;
 
         const statusText = document.getElementById("status");
         const recordButton = document.getElementById("recordButton");
         const chunkSelector = document.getElementById("chunkSelector");
         const websocketInput = document.getElementById("websocketInput");
-        const transcriptionsDiv = document.getElementById("transcriptions");
+        const linesTranscriptDiv = document.getElementById("linesTranscript");
 
-        let fullTranscription = ""; // Store confirmed transcription
-
-        // Update chunk duration based on the selector
         chunkSelector.addEventListener("change", () => {
             chunkDuration = parseInt(chunkSelector.value);
         });
 
-        // Update WebSocket URL dynamically, with some basic checks
         websocketInput.addEventListener("change", () => {
             const urlValue = websocketInput.value.trim();
-
-            // Quick check to see if it starts with ws:// or wss://
             if (!urlValue.startsWith("ws://") && !urlValue.startsWith("wss://")) {
-                statusText.textContent =
-                    "Invalid WebSocket URL. It should start with ws:// or wss://";
+                statusText.textContent = "Invalid WebSocket URL (must start with ws:// or wss://)";
                 return;
             }
             websocketUrl = urlValue;
             statusText.textContent = "WebSocket URL updated. Ready to connect.";
         });
 
-        /**
-         * Opens webSocket connection.
-         * returns a Promise that resolves when the connection is open.
-         * rejects if there was an error.
-         */
         function setupWebSocket() {
             return new Promise((resolve, reject) => {
                 try {
                     websocket = new WebSocket(websocketUrl);
                 } catch (error) {
-                    statusText.textContent =
-                        "Invalid WebSocket URL. Please check the URL and try again.";
+                    statusText.textContent = "Invalid WebSocket URL. Please check and try again.";
                    reject(error);
                    return;
                }
 
                websocket.onopen = () => {
-                    statusText.textContent = "Connected to server";
+                    statusText.textContent = "Connected to server.";
                    resolve();
                };
 
-                websocket.onclose = (event) => {
-                    // If we manually closed it, we say so
+                websocket.onclose = () => {
                    if (userClosing) {
                        statusText.textContent = "WebSocket closed by user.";
                    } else {
-                        statusText.textContent = "Disconnected from the websocket server. If this is the first launch, the model may be downloading in the backend. Check the API logs for more information.";
+                        statusText.textContent =
+                            "Disconnected from the WebSocket server. (Check logs if model is loading.)";
                    }
                    userClosing = false;
                };
 
                websocket.onerror = () => {
-                    statusText.textContent = "Error connecting to WebSocket";
+                    statusText.textContent = "Error connecting to WebSocket.";
                    reject(new Error("Error connecting to WebSocket"));
                };
 
+                // Handle messages from server
                websocket.onmessage = (event) => {
                    const data = JSON.parse(event.data);
-                    const { transcription, buffer } = data;
-
-                    // Update confirmed transcription
-                    fullTranscription += transcription;
-
-                    // Update the transcription display
-                    transcriptionsDiv.innerHTML = `
-                        <span class="transcription">${fullTranscription}</span>
-                        <span class="buffer">${buffer}</span>
-                    `;
+                    /*
+                    The server might send:
+                    {
+                        "lines": [
+                            {"speaker": 0, "text": "Hello."},
+                            {"speaker": 1, "text": "Bonjour."},
+                            ...
+                        ],
+                        "buffer": "..."
+                    }
+                    */
+                    const { lines = [], buffer = "" } = data;
+                    renderLinesWithBuffer(lines, buffer);
                };
            });
        }
 
+        function renderLinesWithBuffer(lines, buffer) {
+            // Clears if no lines
+            if (!Array.isArray(lines) || lines.length === 0) {
+                linesTranscriptDiv.innerHTML = "";
+                return;
+            }
+            // Build the HTML
+            // The buffer is appended to the last line if it's non-empty
+            const linesHtml = lines.map((item, idx) => {
+                let textContent = item.text;
+                if (idx === lines.length - 1 && buffer) {
+                    textContent += `<span class="buffer">${buffer}</span>`;
+                }
+                return `<p><strong>Speaker ${item.speaker}:</strong> ${textContent}</p>`;
+            }).join("");
+
+            linesTranscriptDiv.innerHTML = linesHtml;
+        }
+
         async function startRecording() {
             try {
                 const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
@@ -202,22 +224,18 @@
                 isRecording = true;
                 updateUI();
             } catch (err) {
-                statusText.textContent =
-                    "Error accessing microphone. Please allow microphone access.";
+                statusText.textContent = "Error accessing microphone. Please allow microphone access.";
             }
         }
 
         function stopRecording() {
             userClosing = true;
-
-            // Stop the recorder if it exists
             if (recorder) {
                 recorder.stop();
                 recorder = null;
             }
             isRecording = false;
 
-            // Close the websocket if it exists
             if (websocket) {
                 websocket.close();
                 websocket = null;
@@ -228,15 +246,12 @@
 
         async function toggleRecording() {
             if (!isRecording) {
-                fullTranscription = "";
-                transcriptionsDiv.innerHTML = "";
-
+                linesTranscriptDiv.innerHTML = "";
                 try {
                     await setupWebSocket();
                     await startRecording();
                 } catch (err) {
-                    statusText.textContent =
-                        "Could not connect to WebSocket or access mic. Recording aborted.";
+                    statusText.textContent = "Could not connect to WebSocket or access mic. Aborted.";
                 }
             } else {
                 stopRecording();
@@ -245,9 +260,7 @@
 
         function updateUI() {
             recordButton.classList.toggle("recording", isRecording);
-            statusText.textContent = isRecording
-                ? "Recording..."
-                : "Click to start transcription";
+            statusText.textContent = isRecording ? "Recording..." : "Click to start transcription";
         }
 
         recordButton.addEventListener("click", toggleRecording);
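The reworked page no longer consumes the old `transcription`/`buffer` pair; it expects the payload shape documented in the `onmessage` comment above. A small illustrative sketch in Python (names are mine, not part of the commit):

```python
# Illustrative payload matching the contract in the onmessage comment above.
payload = {
    "lines": [
        {"speaker": "0", "text": "Hello, and welcome."},
        {"speaker": "1", "text": "Bonjour."},
    ],
    "buffer": " words still pending confirmation",
}

# On the FastAPI side this is what goes over the /asr WebSocket:
#     await websocket.send_json(payload)
# renderLinesWithBuffer() then renders one "Speaker N:" paragraph per entry
# and appends the grey, italic buffer to the last line.
```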
src/whisper_streaming/online_asr.py CHANGED
@@ -215,21 +215,14 @@ class OnlineASRProcessor:
         # self.chunk_at(t)
 
 
-
-
-
-
-
-
-
         return completed
 
-    def chunk_completed_sentence(self, commited_text):
-        if commited_text == []:
-            return
-
-        sents = self.words_to_sentences(commited_text)
-
+    def chunk_completed_sentence(self):
+        if self.commited == []:
+            return
+        raw_text = self.asr.sep.join([s[2] for s in self.commited])
+        logger.debug(f"COMPLETED SENTENCE: {raw_text}")
+        sents = self.words_to_sentences(self.commited)
 
 
         if len(sents) < 2:
@@ -322,7 +315,7 @@
         """
         o = self.transcript_buffer.complete()
         f = self.concatenate_tsw(o)
-        logger.debug(f"last, noncommited: {f[0]*1000:.0f}-{f[1]*1000:.0f}: {f[2]}")
+        logger.debug(f"last, noncommited: {f[0]*1000:.0f}-{f[1]*1000:.0f}: {f[2][0]*1000:.0f}-{f[1]*1000:.0f}: {f[2]}")
         self.buffer_time_offset += len(self.audio_buffer) / 16000
         return f
 
@@ -365,7 +358,7 @@ class VACOnlineASRProcessor(OnlineASRProcessor):
         import torch
 
         model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
-        from silero_vad_iterator import FixedVADIterator
+        from src.whisper_streaming.silero_vad_iterator import FixedVADIterator
 
         self.vac = FixedVADIterator(
             model
silero_vad_iterator.py → src/whisper_streaming/silero_vad_iterator.py RENAMED
File without changes
whisper_online.py → src/whisper_streaming/whisper_online.py RENAMED
File without changes
whisper_fastapi_online_server.py CHANGED
@@ -9,7 +9,7 @@ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 from fastapi.responses import HTMLResponse
 from fastapi.middleware.cors import CORSMiddleware
 
-from whisper_online import backend_factory, online_factory, add_shared_args
+from src.whisper_streaming.whisper_online import backend_factory, online_factory, add_shared_args
 
 app = FastAPI()
 app.add_middleware(
@@ -37,11 +37,24 @@ parser.add_argument(
     dest="warmup_file",
     help="The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. It can be e.g. https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav .",
 )
+
+parser.add_argument(
+    "--diarization",
+    type=bool,
+    default=False,
+    help="Whether to enable speaker diarization.",
+)
+
+
 add_shared_args(parser)
 args = parser.parse_args()
 
 asr, tokenizer = backend_factory(args)
 
+if args.diarization:
+    from src.diarization.diarization_online import DiartDiarization
+
+
 # Load demo HTML for the root endpoint
 with open("src/web/live_transcription.html", "r", encoding="utf-8") as f:
     html = f.read()
@@ -78,6 +91,7 @@ async def start_ffmpeg_decoder():
     return process
 
 
+
 @app.websocket("/asr")
 async def websocket_endpoint(websocket: WebSocket):
     await websocket.accept()
@@ -89,12 +103,18 @@ async def websocket_endpoint(websocket: WebSocket):
     online = online_factory(args, asr, tokenizer)
     print("Online loaded.")
 
+    if args.diarization:
+        diarization = DiartDiarization(SAMPLE_RATE)
+
     # Continuously read decoded PCM from ffmpeg stdout in a background task
     async def ffmpeg_stdout_reader():
         nonlocal pcm_buffer
         loop = asyncio.get_event_loop()
         full_transcription = ""
         beg = time()
+
+        chunk_history = []  # Will store dicts: {beg, end, text, speaker}
+
         while True:
             try:
                 elapsed_time = int(time() - beg)
@@ -122,8 +142,17 @@ async def websocket_endpoint(websocket: WebSocket):
                     )
                     pcm_buffer = bytearray()
                     online.insert_audio_chunk(pcm_array)
-                    transcription = online.process_iter()[2]
-                    full_transcription += transcription
+                    beg_trans, end_trans, trans = online.process_iter()
+
+                    if trans:
+                        chunk_history.append({
+                            "beg": beg_trans,
+                            "end": end_trans,
+                            "text": trans,
+                            "speaker": "0"
+                        })
+
+                    full_transcription += trans
                     if args.vac:
                         buffer = online.online.concatenate_tsw(
                             online.online.transcript_buffer.buffer
@@ -136,9 +165,32 @@ async def websocket_endpoint(websocket: WebSocket):
                             buffer in full_transcription
                         ):  # With VAC, the buffer is not updated until the next chunk is processed
                             buffer = ""
-                    await websocket.send_json(
-                        {"transcription": transcription, "buffer": buffer}
-                    )
+
+                    lines = [
+                        {
+                            "speaker": "0",
+                            "text": "",
+                        }
+                    ]
+
+                    if args.diarization:
+                        await diarization.diarize(pcm_array)
+                        diarization.assign_speakers_to_chunks(chunk_history)
+
+                    for ch in chunk_history:
+                        if args.diarization and ch["speaker"] and ch["speaker"][-1] != lines[-1]["speaker"]:
+                            lines.append(
+                                {
+                                    "speaker": ch["speaker"][-1],
+                                    "text": ch['text'],
+                                }
+                            )
+                        else:
+                            lines[-1]["text"] += ch['text']
+
+                    response = {"lines": lines, "buffer": buffer}
+                    await websocket.send_json(response)
+
             except Exception as e:
                 print(f"Exception in ffmpeg_stdout_reader: {e}")
                 break
@@ -174,6 +226,11 @@ async def websocket_endpoint(websocket: WebSocket):
 
     ffmpeg_process.wait()
     del online
+
+    if args.diarization:
+        # Stop Diart
+        diarization.close()
+
 
 
 
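The per-chunk grouping added to `ffmpeg_stdout_reader` can be read as the standalone helper below (the function name is mine, not part of the commit): consecutive chunks from the same speaker are merged into one line, and the speaker label sent to the frontend is the last character of the Diart label, exactly as in the loop above.

```python
# Standalone restatement of the speaker-grouping loop above (helper name is
# hypothetical). Consecutive chunks with the same speaker collapse into one
# line; the label is the last character of the Diart label ("speaker1" -> "1").
def group_chunks_into_lines(chunk_history, diarization_enabled=True):
    lines = [{"speaker": "0", "text": ""}]
    for ch in chunk_history:
        if diarization_enabled and ch["speaker"] and ch["speaker"][-1] != lines[-1]["speaker"]:
            lines.append({"speaker": ch["speaker"][-1], "text": ch["text"]})
        else:
            lines[-1]["text"] += ch["text"]
    return lines


# Example: two chunks from "speaker0", then one from "speaker1".
history = [
    {"beg": 0.0, "end": 1.2, "text": "Hello ", "speaker": "speaker0"},
    {"beg": 1.2, "end": 2.0, "text": "there.", "speaker": "speaker0"},
    {"beg": 2.0, "end": 3.1, "text": " Bonjour.", "speaker": "speaker1"},
]
print(group_chunks_into_lines(history))
# [{'speaker': '0', 'text': 'Hello there.'}, {'speaker': '1', 'text': ' Bonjour.'}]
```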