refacto 0
Browse files- audio.py +306 -0
- diarization/diarization_online.py +3 -3
- formatters.py +91 -0
- parse_args.py +52 -0
- state.py +96 -0
- whisper_fastapi_online_server.py +23 -400
- whisper_streaming_custom/whisper_online.py +1 -1
audio.py
ADDED
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import argparse
|
3 |
+
import asyncio
|
4 |
+
import numpy as np
|
5 |
+
import ffmpeg
|
6 |
+
from time import time, sleep
|
7 |
+
from contextlib import asynccontextmanager
|
8 |
+
|
9 |
+
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
10 |
+
from fastapi.responses import HTMLResponse
|
11 |
+
from fastapi.middleware.cors import CORSMiddleware
|
12 |
+
|
13 |
+
from whisper_streaming_custom.whisper_online import backend_factory, online_factory, add_shared_args, warmup_asr
|
14 |
+
from timed_objects import ASRToken
|
15 |
+
|
16 |
+
import math
|
17 |
+
import logging
|
18 |
+
from datetime import timedelta
|
19 |
+
import traceback
|
20 |
+
from state import SharedState
|
21 |
+
from formatters import format_time
|
22 |
+
from parse_args import parse_args
|
23 |
+
|
24 |
+
|
25 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
26 |
+
logging.getLogger().setLevel(logging.WARNING)
|
27 |
+
logger = logging.getLogger(__name__)
|
28 |
+
logger.setLevel(logging.DEBUG)
|
29 |
+
|
30 |
+
|
31 |
+
class AudioProcessor:
|
32 |
+
|
33 |
+
def __init__(self, args, asr, tokenizer):
|
34 |
+
self.args = args
|
35 |
+
self.sample_rate = 16000
|
36 |
+
self.channels = 1
|
37 |
+
self.samples_per_sec = int(self.sample_rate * args.min_chunk_size)
|
38 |
+
self.bytes_per_sample = 2
|
39 |
+
self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
|
40 |
+
self.max_bytes_per_sec = 32000 * 5 # 5 seconds of audio at 32 kHz
|
41 |
+
self.shared_state = SharedState()
|
42 |
+
self.asr = asr
|
43 |
+
self.tokenizer = tokenizer
|
44 |
+
|
45 |
+
def convert_pcm_to_float(self, pcm_buffer):
|
46 |
+
"""
|
47 |
+
Converts a PCM buffer in s16le format to a normalized NumPy array.
|
48 |
+
Arg: pcm_buffer. PCM buffer containing raw audio data in s16le format
|
49 |
+
Returns: np.ndarray. NumPy array of float32 type normalized between -1.0 and 1.0
|
50 |
+
"""
|
51 |
+
pcm_array = (np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32)
|
52 |
+
/ 32768.0)
|
53 |
+
return pcm_array
|
54 |
+
|
55 |
+
async def start_ffmpeg_decoder(self):
|
56 |
+
"""
|
57 |
+
Start an FFmpeg process in async streaming mode that reads WebM from stdin
|
58 |
+
and outputs raw s16le PCM on stdout. Returns the process object.
|
59 |
+
"""
|
60 |
+
process = (
|
61 |
+
ffmpeg.input("pipe:0", format="webm")
|
62 |
+
.output(
|
63 |
+
"pipe:1",
|
64 |
+
format="s16le",
|
65 |
+
acodec="pcm_s16le",
|
66 |
+
ac=self.channels,
|
67 |
+
ar=str(self.sample_rate),
|
68 |
+
)
|
69 |
+
.run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True)
|
70 |
+
)
|
71 |
+
return process
|
72 |
+
|
73 |
+
async def restart_ffmpeg(self, ffmpeg_process, online, pcm_buffer):
|
74 |
+
if ffmpeg_process:
|
75 |
+
try:
|
76 |
+
ffmpeg_process.kill()
|
77 |
+
await asyncio.get_event_loop().run_in_executor(None, ffmpeg_process.wait)
|
78 |
+
except Exception as e:
|
79 |
+
logger.warning(f"Error killing FFmpeg process: {e}")
|
80 |
+
ffmpeg_process = await self.start_ffmpeg_decoder()
|
81 |
+
pcm_buffer = bytearray()
|
82 |
+
|
83 |
+
if self.args.transcription:
|
84 |
+
online = online_factory(self.args, self.asr, self.tokenizer)
|
85 |
+
|
86 |
+
await self.shared_state.reset()
|
87 |
+
logger.info("FFmpeg process started.")
|
88 |
+
return ffmpeg_process, online, pcm_buffer
|
89 |
+
|
90 |
+
|
91 |
+
|
92 |
+
async def ffmpeg_stdout_reader(self, ffmpeg_process, pcm_buffer, diarization_queue, transcription_queue):
|
93 |
+
loop = asyncio.get_event_loop()
|
94 |
+
beg = time()
|
95 |
+
|
96 |
+
while True:
|
97 |
+
try:
|
98 |
+
elapsed_time = math.floor((time() - beg) * 10) / 10 # Round to 0.1 sec
|
99 |
+
ffmpeg_buffer_from_duration = max(int(32000 * elapsed_time), 4096)
|
100 |
+
beg = time()
|
101 |
+
|
102 |
+
# Read chunk with timeout
|
103 |
+
try:
|
104 |
+
chunk = await asyncio.wait_for(
|
105 |
+
loop.run_in_executor(
|
106 |
+
None, ffmpeg_process.stdout.read, ffmpeg_buffer_from_duration
|
107 |
+
),
|
108 |
+
timeout=15.0
|
109 |
+
)
|
110 |
+
except asyncio.TimeoutError:
|
111 |
+
logger.warning("FFmpeg read timeout. Restarting...")
|
112 |
+
ffmpeg_process, online, pcm_buffer = await self.restart_ffmpeg(ffmpeg_process, online, pcm_buffer)
|
113 |
+
beg = time()
|
114 |
+
continue # Skip processing and read from new process
|
115 |
+
|
116 |
+
if not chunk:
|
117 |
+
logger.info("FFmpeg stdout closed.")
|
118 |
+
break
|
119 |
+
pcm_buffer.extend(chunk)
|
120 |
+
|
121 |
+
if self.args.diarization and diarization_queue:
|
122 |
+
await diarization_queue.put(self.convert_pcm_to_float(pcm_buffer).copy())
|
123 |
+
|
124 |
+
if len(pcm_buffer) >= self.bytes_per_sec:
|
125 |
+
if len(pcm_buffer) > self.max_bytes_per_sec:
|
126 |
+
logger.warning(
|
127 |
+
f"""Audio buffer is too large: {len(pcm_buffer) / self.bytes_per_sec:.2f} seconds.
|
128 |
+
The model probably struggles to keep up. Consider using a smaller model.
|
129 |
+
""")
|
130 |
+
|
131 |
+
pcm_array = self.convert_pcm_to_float(pcm_buffer[:self.max_bytes_per_sec])
|
132 |
+
pcm_buffer = pcm_buffer[self.max_bytes_per_sec:]
|
133 |
+
|
134 |
+
if self.args.transcription and transcription_queue:
|
135 |
+
await transcription_queue.put(pcm_array.copy())
|
136 |
+
|
137 |
+
|
138 |
+
if not self.args.transcription and not self.args.diarization:
|
139 |
+
await asyncio.sleep(0.1)
|
140 |
+
|
141 |
+
except Exception as e:
|
142 |
+
logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
|
143 |
+
logger.warning(f"Traceback: {traceback.format_exc()}")
|
144 |
+
break
|
145 |
+
logger.info("Exiting ffmpeg_stdout_reader...")
|
146 |
+
|
147 |
+
|
148 |
+
|
149 |
+
|
150 |
+
async def transcription_processor(self, pcm_queue, online):
|
151 |
+
full_transcription = ""
|
152 |
+
sep = online.asr.sep
|
153 |
+
|
154 |
+
while True:
|
155 |
+
try:
|
156 |
+
pcm_array = await pcm_queue.get()
|
157 |
+
|
158 |
+
logger.info(f"{len(online.audio_buffer) / online.SAMPLING_RATE} seconds of audio will be processed by the model.")
|
159 |
+
|
160 |
+
# Process transcription
|
161 |
+
online.insert_audio_chunk(pcm_array)
|
162 |
+
new_tokens = online.process_iter()
|
163 |
+
|
164 |
+
if new_tokens:
|
165 |
+
full_transcription += sep.join([t.text for t in new_tokens])
|
166 |
+
|
167 |
+
_buffer = online.get_buffer()
|
168 |
+
buffer = _buffer.text
|
169 |
+
end_buffer = _buffer.end if _buffer.end else (new_tokens[-1].end if new_tokens else 0)
|
170 |
+
|
171 |
+
if buffer in full_transcription:
|
172 |
+
buffer = ""
|
173 |
+
|
174 |
+
await self.shared_state.update_transcription(
|
175 |
+
new_tokens, buffer, end_buffer, full_transcription, sep)
|
176 |
+
|
177 |
+
except Exception as e:
|
178 |
+
logger.warning(f"Exception in transcription_processor: {e}")
|
179 |
+
logger.warning(f"Traceback: {traceback.format_exc()}")
|
180 |
+
finally:
|
181 |
+
pcm_queue.task_done()
|
182 |
+
|
183 |
+
async def diarization_processor(self, pcm_queue, diarization_obj):
|
184 |
+
buffer_diarization = ""
|
185 |
+
|
186 |
+
while True:
|
187 |
+
try:
|
188 |
+
pcm_array = await pcm_queue.get()
|
189 |
+
|
190 |
+
# Process diarization
|
191 |
+
await diarization_obj.diarize(pcm_array)
|
192 |
+
|
193 |
+
# Get current state
|
194 |
+
state = await self.shared_state.get_current_state()
|
195 |
+
tokens = state["tokens"]
|
196 |
+
end_attributed_speaker = state["end_attributed_speaker"]
|
197 |
+
|
198 |
+
# Update speaker information
|
199 |
+
new_end_attributed_speaker = diarization_obj.assign_speakers_to_tokens(
|
200 |
+
end_attributed_speaker, tokens)
|
201 |
+
|
202 |
+
await self.shared_state.update_diarization(new_end_attributed_speaker, buffer_diarization)
|
203 |
+
|
204 |
+
except Exception as e:
|
205 |
+
logger.warning(f"Exception in diarization_processor: {e}")
|
206 |
+
logger.warning(f"Traceback: {traceback.format_exc()}")
|
207 |
+
finally:
|
208 |
+
pcm_queue.task_done()
|
209 |
+
|
210 |
+
async def results_formatter(self, websocket):
|
211 |
+
while True:
|
212 |
+
try:
|
213 |
+
# Get the current state
|
214 |
+
state = await self.shared_state.get_current_state()
|
215 |
+
tokens = state["tokens"]
|
216 |
+
buffer_transcription = state["buffer_transcription"]
|
217 |
+
buffer_diarization = state["buffer_diarization"]
|
218 |
+
end_attributed_speaker = state["end_attributed_speaker"]
|
219 |
+
remaining_time_transcription = state["remaining_time_transcription"]
|
220 |
+
remaining_time_diarization = state["remaining_time_diarization"]
|
221 |
+
sep = state["sep"]
|
222 |
+
|
223 |
+
# If diarization is enabled but no transcription, add dummy tokens periodically
|
224 |
+
if (not tokens or tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization:
|
225 |
+
await self.shared_state.add_dummy_token()
|
226 |
+
sleep(0.5)
|
227 |
+
state = await self.shared_state.get_current_state()
|
228 |
+
tokens = state["tokens"]
|
229 |
+
# Process tokens to create response
|
230 |
+
previous_speaker = -1
|
231 |
+
lines = []
|
232 |
+
last_end_diarized = 0
|
233 |
+
undiarized_text = []
|
234 |
+
|
235 |
+
for token in tokens:
|
236 |
+
speaker = token.speaker
|
237 |
+
if self.args.diarization:
|
238 |
+
if (speaker == -1 or speaker == 0) and token.end >= end_attributed_speaker:
|
239 |
+
undiarized_text.append(token.text)
|
240 |
+
continue
|
241 |
+
elif (speaker == -1 or speaker == 0) and token.end < end_attributed_speaker:
|
242 |
+
speaker = previous_speaker
|
243 |
+
if speaker not in [-1, 0]:
|
244 |
+
last_end_diarized = max(token.end, last_end_diarized)
|
245 |
+
|
246 |
+
if speaker != previous_speaker or not lines:
|
247 |
+
lines.append(
|
248 |
+
{
|
249 |
+
"speaker": speaker,
|
250 |
+
"text": token.text,
|
251 |
+
"beg": format_time(token.start),
|
252 |
+
"end": format_time(token.end),
|
253 |
+
"diff": round(token.end - last_end_diarized, 2)
|
254 |
+
}
|
255 |
+
)
|
256 |
+
previous_speaker = speaker
|
257 |
+
elif token.text: # Only append if text isn't empty
|
258 |
+
lines[-1]["text"] += sep + token.text
|
259 |
+
lines[-1]["end"] = format_time(token.end)
|
260 |
+
lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
|
261 |
+
|
262 |
+
if undiarized_text:
|
263 |
+
combined_buffer_diarization = sep.join(undiarized_text)
|
264 |
+
if buffer_transcription:
|
265 |
+
combined_buffer_diarization += sep
|
266 |
+
await self.shared_state.update_diarization(end_attributed_speaker, combined_buffer_diarization)
|
267 |
+
buffer_diarization = combined_buffer_diarization
|
268 |
+
|
269 |
+
if lines:
|
270 |
+
response = {
|
271 |
+
"lines": lines,
|
272 |
+
"buffer_transcription": buffer_transcription,
|
273 |
+
"buffer_diarization": buffer_diarization,
|
274 |
+
"remaining_time_transcription": remaining_time_transcription,
|
275 |
+
"remaining_time_diarization": remaining_time_diarization
|
276 |
+
}
|
277 |
+
else:
|
278 |
+
response = {
|
279 |
+
"lines": [{
|
280 |
+
"speaker": 1,
|
281 |
+
"text": "",
|
282 |
+
"beg": format_time(0),
|
283 |
+
"end": format_time(tokens[-1].end) if tokens else format_time(0),
|
284 |
+
"diff": 0
|
285 |
+
}],
|
286 |
+
"buffer_transcription": buffer_transcription,
|
287 |
+
"buffer_diarization": buffer_diarization,
|
288 |
+
"remaining_time_transcription": remaining_time_transcription,
|
289 |
+
"remaining_time_diarization": remaining_time_diarization
|
290 |
+
|
291 |
+
}
|
292 |
+
|
293 |
+
response_content = ' '.join([str(line['speaker']) + ' ' + line["text"] for line in lines]) + ' | ' + buffer_transcription + ' | ' + buffer_diarization
|
294 |
+
|
295 |
+
if response_content != self.shared_state.last_response_content:
|
296 |
+
if lines or buffer_transcription or buffer_diarization:
|
297 |
+
await websocket.send_json(response)
|
298 |
+
self.shared_state.last_response_content = response_content
|
299 |
+
|
300 |
+
# Add a small delay to avoid overwhelming the client
|
301 |
+
await asyncio.sleep(0.1)
|
302 |
+
|
303 |
+
except Exception as e:
|
304 |
+
logger.warning(f"Exception in results_formatter: {e}")
|
305 |
+
logger.warning(f"Traceback: {traceback.format_exc()}")
|
306 |
+
await asyncio.sleep(0.5) # Back off on error
|
diarization/diarization_online.py
CHANGED
@@ -5,7 +5,7 @@ import numpy as np
|
|
5 |
import logging
|
6 |
|
7 |
|
8 |
-
from diart import SpeakerDiarization
|
9 |
from diart.inference import StreamingInference
|
10 |
from diart.sources import AudioSource
|
11 |
from timed_objects import SpeakerSegment
|
@@ -103,8 +103,8 @@ class WebSocketAudioSource(AudioSource):
|
|
103 |
|
104 |
|
105 |
class DiartDiarization:
|
106 |
-
def __init__(self, sample_rate: int, use_microphone: bool = False):
|
107 |
-
self.pipeline = SpeakerDiarization()
|
108 |
self.observer = DiarizationObserver()
|
109 |
|
110 |
if use_microphone:
|
|
|
5 |
import logging
|
6 |
|
7 |
|
8 |
+
from diart import SpeakerDiarization, SpeakerDiarizationConfig
|
9 |
from diart.inference import StreamingInference
|
10 |
from diart.sources import AudioSource
|
11 |
from timed_objects import SpeakerSegment
|
|
|
103 |
|
104 |
|
105 |
class DiartDiarization:
|
106 |
+
def __init__(self, sample_rate: int, config : SpeakerDiarizationConfig = None, use_microphone: bool = False):
|
107 |
+
self.pipeline = SpeakerDiarization(config=config)
|
108 |
self.observer = DiarizationObserver()
|
109 |
|
110 |
if use_microphone:
|
formatters.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Any, List
|
2 |
+
from datetime import timedelta
|
3 |
+
|
4 |
+
def format_time(seconds: float) -> str:
|
5 |
+
"""Format seconds as HH:MM:SS."""
|
6 |
+
return str(timedelta(seconds=int(seconds)))
|
7 |
+
|
8 |
+
def format_response(state: Dict[str, Any], with_diarization: bool = False) -> Dict[str, Any]:
|
9 |
+
"""
|
10 |
+
Format the shared state into a client-friendly response.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
state: Current shared state dictionary
|
14 |
+
with_diarization: Whether to include diarization formatting
|
15 |
+
|
16 |
+
Returns:
|
17 |
+
Formatted response dictionary ready to send to client
|
18 |
+
"""
|
19 |
+
tokens = state["tokens"]
|
20 |
+
buffer_transcription = state["buffer_transcription"]
|
21 |
+
buffer_diarization = state["buffer_diarization"]
|
22 |
+
end_attributed_speaker = state["end_attributed_speaker"]
|
23 |
+
remaining_time_transcription = state["remaining_time_transcription"]
|
24 |
+
remaining_time_diarization = state["remaining_time_diarization"]
|
25 |
+
sep = state["sep"]
|
26 |
+
|
27 |
+
# Default response for empty state
|
28 |
+
if not tokens:
|
29 |
+
return {
|
30 |
+
"lines": [{
|
31 |
+
"speaker": 1,
|
32 |
+
"text": "",
|
33 |
+
"beg": format_time(0),
|
34 |
+
"end": format_time(0),
|
35 |
+
"diff": 0
|
36 |
+
}],
|
37 |
+
"buffer_transcription": buffer_transcription,
|
38 |
+
"buffer_diarization": buffer_diarization,
|
39 |
+
"remaining_time_transcription": remaining_time_transcription,
|
40 |
+
"remaining_time_diarization": remaining_time_diarization
|
41 |
+
}
|
42 |
+
|
43 |
+
# Process tokens to create response
|
44 |
+
previous_speaker = -1
|
45 |
+
lines = []
|
46 |
+
last_end_diarized = 0
|
47 |
+
undiarized_text = []
|
48 |
+
|
49 |
+
for token in tokens:
|
50 |
+
speaker = token.speaker
|
51 |
+
|
52 |
+
# Handle diarization logic
|
53 |
+
if with_diarization:
|
54 |
+
if (speaker == -1 or speaker == 0) and token.end >= end_attributed_speaker:
|
55 |
+
undiarized_text.append(token.text)
|
56 |
+
continue
|
57 |
+
elif (speaker == -1 or speaker == 0) and token.end < end_attributed_speaker:
|
58 |
+
speaker = previous_speaker
|
59 |
+
|
60 |
+
if speaker not in [-1, 0]:
|
61 |
+
last_end_diarized = max(token.end, last_end_diarized)
|
62 |
+
|
63 |
+
# Add new line or append to existing line
|
64 |
+
if speaker != previous_speaker or not lines:
|
65 |
+
lines.append({
|
66 |
+
"speaker": speaker,
|
67 |
+
"text": token.text,
|
68 |
+
"beg": format_time(token.start),
|
69 |
+
"end": format_time(token.end),
|
70 |
+
"diff": round(token.end - last_end_diarized, 2)
|
71 |
+
})
|
72 |
+
previous_speaker = speaker
|
73 |
+
elif token.text: # Only append if text isn't empty
|
74 |
+
lines[-1]["text"] += sep + token.text
|
75 |
+
lines[-1]["end"] = format_time(token.end)
|
76 |
+
lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
|
77 |
+
|
78 |
+
# If we have undiarized text, include it in the buffer
|
79 |
+
if undiarized_text:
|
80 |
+
combined_buffer = sep.join(undiarized_text)
|
81 |
+
if buffer_transcription:
|
82 |
+
combined_buffer += sep + buffer_transcription
|
83 |
+
buffer_diarization = combined_buffer
|
84 |
+
|
85 |
+
return {
|
86 |
+
"lines": lines,
|
87 |
+
"buffer_transcription": buffer_transcription,
|
88 |
+
"buffer_diarization": buffer_diarization,
|
89 |
+
"remaining_time_transcription": remaining_time_transcription,
|
90 |
+
"remaining_time_diarization": remaining_time_diarization
|
91 |
+
}
|
parse_args.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import argparse
|
3 |
+
from whisper_streaming_custom.whisper_online import add_shared_args
|
4 |
+
|
5 |
+
|
6 |
+
def parse_args():
|
7 |
+
parser = argparse.ArgumentParser(description="Whisper FastAPI Online Server")
|
8 |
+
parser.add_argument(
|
9 |
+
"--host",
|
10 |
+
type=str,
|
11 |
+
default="localhost",
|
12 |
+
help="The host address to bind the server to.",
|
13 |
+
)
|
14 |
+
parser.add_argument(
|
15 |
+
"--port", type=int, default=8000, help="The port number to bind the server to."
|
16 |
+
)
|
17 |
+
parser.add_argument(
|
18 |
+
"--warmup-file",
|
19 |
+
type=str,
|
20 |
+
default=None,
|
21 |
+
dest="warmup_file",
|
22 |
+
help="""
|
23 |
+
The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast.
|
24 |
+
If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav.
|
25 |
+
If False, no warmup is performed.
|
26 |
+
""",
|
27 |
+
)
|
28 |
+
|
29 |
+
parser.add_argument(
|
30 |
+
"--confidence-validation",
|
31 |
+
type=bool,
|
32 |
+
default=False,
|
33 |
+
help="Accelerates validation of tokens using confidence scores. Transcription will be faster but punctuation might be less accurate.",
|
34 |
+
)
|
35 |
+
|
36 |
+
parser.add_argument(
|
37 |
+
"--diarization",
|
38 |
+
type=bool,
|
39 |
+
default=True,
|
40 |
+
help="Whether to enable speaker diarization.",
|
41 |
+
)
|
42 |
+
|
43 |
+
parser.add_argument(
|
44 |
+
"--transcription",
|
45 |
+
type=bool,
|
46 |
+
default=True,
|
47 |
+
help="To disable to only see live diarization results.",
|
48 |
+
)
|
49 |
+
|
50 |
+
add_shared_args(parser)
|
51 |
+
args = parser.parse_args()
|
52 |
+
return args
|
state.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import logging
|
3 |
+
from time import time
|
4 |
+
from typing import List, Dict, Any, Optional
|
5 |
+
from dataclasses import dataclass, field
|
6 |
+
from timed_objects import ASRToken
|
7 |
+
|
8 |
+
logger = logging.getLogger(__name__)
|
9 |
+
|
10 |
+
|
11 |
+
class SharedState:
|
12 |
+
"""
|
13 |
+
Thread-safe state manager for streaming transcription and diarization.
|
14 |
+
Handles coordination between audio processing, transcription, and diarization.
|
15 |
+
"""
|
16 |
+
|
17 |
+
def __init__(self):
|
18 |
+
self.tokens: List[ASRToken] = []
|
19 |
+
self.buffer_transcription: str = ""
|
20 |
+
self.buffer_diarization: str = ""
|
21 |
+
self.full_transcription: str = ""
|
22 |
+
self.end_buffer: float = 0
|
23 |
+
self.end_attributed_speaker: float = 0
|
24 |
+
self.lock = asyncio.Lock()
|
25 |
+
self.beg_loop: float = time()
|
26 |
+
self.sep: str = " " # Default separator
|
27 |
+
self.last_response_content: str = "" # To track changes in response
|
28 |
+
|
29 |
+
async def update_transcription(self, new_tokens: List[ASRToken], buffer: str,
|
30 |
+
end_buffer: float, full_transcription: str, sep: str) -> None:
|
31 |
+
"""Update the state with new transcription data."""
|
32 |
+
async with self.lock:
|
33 |
+
self.tokens.extend(new_tokens)
|
34 |
+
self.buffer_transcription = buffer
|
35 |
+
self.end_buffer = end_buffer
|
36 |
+
self.full_transcription = full_transcription
|
37 |
+
self.sep = sep
|
38 |
+
|
39 |
+
async def update_diarization(self, end_attributed_speaker: float, buffer_diarization: str = "") -> None:
|
40 |
+
"""Update the state with new diarization data."""
|
41 |
+
async with self.lock:
|
42 |
+
self.end_attributed_speaker = end_attributed_speaker
|
43 |
+
if buffer_diarization:
|
44 |
+
self.buffer_diarization = buffer_diarization
|
45 |
+
|
46 |
+
async def add_dummy_token(self) -> None:
|
47 |
+
"""Add a dummy token to keep the state updated even without transcription."""
|
48 |
+
async with self.lock:
|
49 |
+
current_time = time() - self.beg_loop
|
50 |
+
dummy_token = ASRToken(
|
51 |
+
start=current_time,
|
52 |
+
end=current_time + 1,
|
53 |
+
text=".",
|
54 |
+
speaker=-1,
|
55 |
+
is_dummy=True
|
56 |
+
)
|
57 |
+
self.tokens.append(dummy_token)
|
58 |
+
|
59 |
+
async def get_current_state(self) -> Dict[str, Any]:
|
60 |
+
"""Get the current state with calculated timing information."""
|
61 |
+
async with self.lock:
|
62 |
+
current_time = time()
|
63 |
+
remaining_time_transcription = 0
|
64 |
+
remaining_time_diarization = 0
|
65 |
+
|
66 |
+
# Calculate remaining time for transcription buffer
|
67 |
+
if self.end_buffer > 0:
|
68 |
+
remaining_time_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 2))
|
69 |
+
|
70 |
+
# Calculate remaining time for diarization
|
71 |
+
if self.tokens:
|
72 |
+
latest_end = max(self.end_buffer, self.tokens[-1].end if self.tokens else 0)
|
73 |
+
remaining_time_diarization = max(0, round(latest_end - self.end_attributed_speaker, 2))
|
74 |
+
|
75 |
+
return {
|
76 |
+
"tokens": self.tokens.copy(),
|
77 |
+
"buffer_transcription": self.buffer_transcription,
|
78 |
+
"buffer_diarization": self.buffer_diarization,
|
79 |
+
"end_buffer": self.end_buffer,
|
80 |
+
"end_attributed_speaker": self.end_attributed_speaker,
|
81 |
+
"sep": self.sep,
|
82 |
+
"remaining_time_transcription": remaining_time_transcription,
|
83 |
+
"remaining_time_diarization": remaining_time_diarization
|
84 |
+
}
|
85 |
+
|
86 |
+
async def reset(self) -> None:
|
87 |
+
"""Reset the state to initial values."""
|
88 |
+
async with self.lock:
|
89 |
+
self.tokens = []
|
90 |
+
self.buffer_transcription = ""
|
91 |
+
self.buffer_diarization = ""
|
92 |
+
self.end_buffer = 0
|
93 |
+
self.end_attributed_speaker = 0
|
94 |
+
self.full_transcription = ""
|
95 |
+
self.beg_loop = time()
|
96 |
+
self.last_response_content = ""
|
whisper_fastapi_online_server.py
CHANGED
@@ -17,147 +17,28 @@ import math
|
|
17 |
import logging
|
18 |
from datetime import timedelta
|
19 |
import traceback
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
|
25 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
26 |
logging.getLogger().setLevel(logging.WARNING)
|
27 |
logger = logging.getLogger(__name__)
|
28 |
logger.setLevel(logging.DEBUG)
|
29 |
|
30 |
-
##### LOAD ARGS #####
|
31 |
-
|
32 |
-
parser = argparse.ArgumentParser(description="Whisper FastAPI Online Server")
|
33 |
-
parser.add_argument(
|
34 |
-
"--host",
|
35 |
-
type=str,
|
36 |
-
default="localhost",
|
37 |
-
help="The host address to bind the server to.",
|
38 |
-
)
|
39 |
-
parser.add_argument(
|
40 |
-
"--port", type=int, default=8000, help="The port number to bind the server to."
|
41 |
-
)
|
42 |
-
parser.add_argument(
|
43 |
-
"--warmup-file",
|
44 |
-
type=str,
|
45 |
-
default=None,
|
46 |
-
dest="warmup_file",
|
47 |
-
help="""
|
48 |
-
The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast.
|
49 |
-
If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav.
|
50 |
-
If False, no warmup is performed.
|
51 |
-
""",
|
52 |
-
)
|
53 |
|
54 |
-
parser.add_argument(
|
55 |
-
"--confidence-validation",
|
56 |
-
type=bool,
|
57 |
-
default=False,
|
58 |
-
help="Accelerates validation of tokens using confidence scores. Transcription will be faster but punctuation might be less accurate.",
|
59 |
-
)
|
60 |
-
|
61 |
-
parser.add_argument(
|
62 |
-
"--diarization",
|
63 |
-
type=bool,
|
64 |
-
default=False,
|
65 |
-
help="Whether to enable speaker diarization.",
|
66 |
-
)
|
67 |
-
|
68 |
-
parser.add_argument(
|
69 |
-
"--transcription",
|
70 |
-
type=bool,
|
71 |
-
default=True,
|
72 |
-
help="To disable to only see live diarization results.",
|
73 |
-
)
|
74 |
|
75 |
-
|
76 |
-
args = parser.parse_args()
|
77 |
|
78 |
SAMPLE_RATE = 16000
|
79 |
-
CHANNELS = 1
|
80 |
-
SAMPLES_PER_SEC = SAMPLE_RATE *
|
81 |
-
BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample
|
82 |
-
BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE
|
83 |
-
MAX_BYTES_PER_SEC = 32000 * 5 # 5 seconds of audio at 32 kHz
|
84 |
|
85 |
|
86 |
-
class SharedState:
|
87 |
-
def __init__(self):
|
88 |
-
self.tokens = []
|
89 |
-
self.buffer_transcription = ""
|
90 |
-
self.buffer_diarization = ""
|
91 |
-
self.full_transcription = ""
|
92 |
-
self.end_buffer = 0
|
93 |
-
self.end_attributed_speaker = 0
|
94 |
-
self.lock = asyncio.Lock()
|
95 |
-
self.beg_loop = time()
|
96 |
-
self.sep = " " # Default separator
|
97 |
-
self.last_response_content = "" # To track changes in response
|
98 |
-
|
99 |
-
async def update_transcription(self, new_tokens, buffer, end_buffer, full_transcription, sep):
|
100 |
-
async with self.lock:
|
101 |
-
self.tokens.extend(new_tokens)
|
102 |
-
self.buffer_transcription = buffer
|
103 |
-
self.end_buffer = end_buffer
|
104 |
-
self.full_transcription = full_transcription
|
105 |
-
self.sep = sep
|
106 |
-
|
107 |
-
async def update_diarization(self, end_attributed_speaker, buffer_diarization=""):
|
108 |
-
async with self.lock:
|
109 |
-
self.end_attributed_speaker = end_attributed_speaker
|
110 |
-
if buffer_diarization:
|
111 |
-
self.buffer_diarization = buffer_diarization
|
112 |
-
|
113 |
-
async def add_dummy_token(self):
|
114 |
-
async with self.lock:
|
115 |
-
current_time = time() - self.beg_loop
|
116 |
-
dummy_token = ASRToken(
|
117 |
-
start=current_time,
|
118 |
-
end=current_time + 1,
|
119 |
-
text=".",
|
120 |
-
speaker=-1,
|
121 |
-
is_dummy=True
|
122 |
-
)
|
123 |
-
self.tokens.append(dummy_token)
|
124 |
-
|
125 |
-
async def get_current_state(self):
|
126 |
-
async with self.lock:
|
127 |
-
current_time = time()
|
128 |
-
remaining_time_transcription = 0
|
129 |
-
remaining_time_diarization = 0
|
130 |
-
|
131 |
-
# Calculate remaining time for transcription buffer
|
132 |
-
if self.end_buffer > 0:
|
133 |
-
remaining_time_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 2))
|
134 |
-
|
135 |
-
# Calculate remaining time for diarization
|
136 |
-
remaining_time_diarization = max(0, round(max(self.end_buffer, self.tokens[-1].end if self.tokens else 0) - self.end_attributed_speaker, 2))
|
137 |
-
|
138 |
-
return {
|
139 |
-
"tokens": self.tokens.copy(),
|
140 |
-
"buffer_transcription": self.buffer_transcription,
|
141 |
-
"buffer_diarization": self.buffer_diarization,
|
142 |
-
"end_buffer": self.end_buffer,
|
143 |
-
"end_attributed_speaker": self.end_attributed_speaker,
|
144 |
-
"sep": self.sep,
|
145 |
-
"remaining_time_transcription": remaining_time_transcription,
|
146 |
-
"remaining_time_diarization": remaining_time_diarization
|
147 |
-
}
|
148 |
-
|
149 |
-
async def reset(self):
|
150 |
-
"""Reset the state."""
|
151 |
-
async with self.lock:
|
152 |
-
self.tokens = []
|
153 |
-
self.buffer_transcription = ""
|
154 |
-
self.buffer_diarization = ""
|
155 |
-
self.end_buffer = 0
|
156 |
-
self.end_attributed_speaker = 0
|
157 |
-
self.full_transcription = ""
|
158 |
-
self.beg_loop = time()
|
159 |
-
self.last_response_content = ""
|
160 |
-
|
161 |
##### LOAD APP #####
|
162 |
|
163 |
@asynccontextmanager
|
@@ -190,300 +71,45 @@ app.add_middleware(
|
|
190 |
with open("web/live_transcription.html", "r", encoding="utf-8") as f:
|
191 |
html = f.read()
|
192 |
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
Arg: pcm_buffer. PCM buffer containing raw audio data in s16le format
|
197 |
-
Returns: np.ndarray. NumPy array of float32 type normalized between -1.0 and 1.0
|
198 |
-
"""
|
199 |
-
pcm_array = (np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32)
|
200 |
-
/ 32768.0)
|
201 |
-
return pcm_array
|
202 |
|
203 |
-
async def start_ffmpeg_decoder():
|
204 |
-
"""
|
205 |
-
Start an FFmpeg process in async streaming mode that reads WebM from stdin
|
206 |
-
and outputs raw s16le PCM on stdout. Returns the process object.
|
207 |
-
"""
|
208 |
-
process = (
|
209 |
-
ffmpeg.input("pipe:0", format="webm")
|
210 |
-
.output(
|
211 |
-
"pipe:1",
|
212 |
-
format="s16le",
|
213 |
-
acodec="pcm_s16le",
|
214 |
-
ac=CHANNELS,
|
215 |
-
ar=str(SAMPLE_RATE),
|
216 |
-
)
|
217 |
-
.run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True)
|
218 |
-
)
|
219 |
-
return process
|
220 |
|
221 |
-
async def transcription_processor(shared_state, pcm_queue, online):
|
222 |
-
full_transcription = ""
|
223 |
-
sep = online.asr.sep
|
224 |
-
|
225 |
-
while True:
|
226 |
-
try:
|
227 |
-
pcm_array = await pcm_queue.get()
|
228 |
-
|
229 |
-
logger.info(f"{len(online.audio_buffer) / online.SAMPLING_RATE} seconds of audio will be processed by the model.")
|
230 |
-
|
231 |
-
# Process transcription
|
232 |
-
online.insert_audio_chunk(pcm_array)
|
233 |
-
new_tokens = online.process_iter()
|
234 |
-
|
235 |
-
if new_tokens:
|
236 |
-
full_transcription += sep.join([t.text for t in new_tokens])
|
237 |
-
|
238 |
-
_buffer = online.get_buffer()
|
239 |
-
buffer = _buffer.text
|
240 |
-
end_buffer = _buffer.end if _buffer.end else (new_tokens[-1].end if new_tokens else 0)
|
241 |
-
|
242 |
-
if buffer in full_transcription:
|
243 |
-
buffer = ""
|
244 |
-
|
245 |
-
await shared_state.update_transcription(
|
246 |
-
new_tokens, buffer, end_buffer, full_transcription, sep)
|
247 |
-
|
248 |
-
except Exception as e:
|
249 |
-
logger.warning(f"Exception in transcription_processor: {e}")
|
250 |
-
logger.warning(f"Traceback: {traceback.format_exc()}")
|
251 |
-
finally:
|
252 |
-
pcm_queue.task_done()
|
253 |
|
254 |
-
async def diarization_processor(shared_state, pcm_queue, diarization_obj):
|
255 |
-
buffer_diarization = ""
|
256 |
-
|
257 |
-
while True:
|
258 |
-
try:
|
259 |
-
pcm_array = await pcm_queue.get()
|
260 |
-
|
261 |
-
# Process diarization
|
262 |
-
await diarization_obj.diarize(pcm_array)
|
263 |
-
|
264 |
-
# Get current state
|
265 |
-
state = await shared_state.get_current_state()
|
266 |
-
tokens = state["tokens"]
|
267 |
-
end_attributed_speaker = state["end_attributed_speaker"]
|
268 |
-
|
269 |
-
# Update speaker information
|
270 |
-
new_end_attributed_speaker = diarization_obj.assign_speakers_to_tokens(
|
271 |
-
end_attributed_speaker, tokens)
|
272 |
-
|
273 |
-
await shared_state.update_diarization(new_end_attributed_speaker, buffer_diarization)
|
274 |
-
|
275 |
-
except Exception as e:
|
276 |
-
logger.warning(f"Exception in diarization_processor: {e}")
|
277 |
-
logger.warning(f"Traceback: {traceback.format_exc()}")
|
278 |
-
finally:
|
279 |
-
pcm_queue.task_done()
|
280 |
|
281 |
-
async def results_formatter(shared_state, websocket):
|
282 |
-
while True:
|
283 |
-
try:
|
284 |
-
# Get the current state
|
285 |
-
state = await shared_state.get_current_state()
|
286 |
-
tokens = state["tokens"]
|
287 |
-
buffer_transcription = state["buffer_transcription"]
|
288 |
-
buffer_diarization = state["buffer_diarization"]
|
289 |
-
end_attributed_speaker = state["end_attributed_speaker"]
|
290 |
-
remaining_time_transcription = state["remaining_time_transcription"]
|
291 |
-
remaining_time_diarization = state["remaining_time_diarization"]
|
292 |
-
sep = state["sep"]
|
293 |
-
|
294 |
-
# If diarization is enabled but no transcription, add dummy tokens periodically
|
295 |
-
if (not tokens or tokens[-1].is_dummy) and not args.transcription and args.diarization:
|
296 |
-
await shared_state.add_dummy_token()
|
297 |
-
sleep(0.5)
|
298 |
-
state = await shared_state.get_current_state()
|
299 |
-
tokens = state["tokens"]
|
300 |
-
# Process tokens to create response
|
301 |
-
previous_speaker = -1
|
302 |
-
lines = []
|
303 |
-
last_end_diarized = 0
|
304 |
-
undiarized_text = []
|
305 |
-
|
306 |
-
for token in tokens:
|
307 |
-
speaker = token.speaker
|
308 |
-
if args.diarization:
|
309 |
-
if (speaker == -1 or speaker == 0) and token.end >= end_attributed_speaker:
|
310 |
-
undiarized_text.append(token.text)
|
311 |
-
continue
|
312 |
-
elif (speaker == -1 or speaker == 0) and token.end < end_attributed_speaker:
|
313 |
-
speaker = previous_speaker
|
314 |
-
if speaker not in [-1, 0]:
|
315 |
-
last_end_diarized = max(token.end, last_end_diarized)
|
316 |
-
|
317 |
-
if speaker != previous_speaker or not lines:
|
318 |
-
lines.append(
|
319 |
-
{
|
320 |
-
"speaker": speaker,
|
321 |
-
"text": token.text,
|
322 |
-
"beg": format_time(token.start),
|
323 |
-
"end": format_time(token.end),
|
324 |
-
"diff": round(token.end - last_end_diarized, 2)
|
325 |
-
}
|
326 |
-
)
|
327 |
-
previous_speaker = speaker
|
328 |
-
elif token.text: # Only append if text isn't empty
|
329 |
-
lines[-1]["text"] += sep + token.text
|
330 |
-
lines[-1]["end"] = format_time(token.end)
|
331 |
-
lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
|
332 |
-
|
333 |
-
if undiarized_text:
|
334 |
-
combined_buffer_diarization = sep.join(undiarized_text)
|
335 |
-
if buffer_transcription:
|
336 |
-
combined_buffer_diarization += sep
|
337 |
-
await shared_state.update_diarization(end_attributed_speaker, combined_buffer_diarization)
|
338 |
-
buffer_diarization = combined_buffer_diarization
|
339 |
-
|
340 |
-
if lines:
|
341 |
-
response = {
|
342 |
-
"lines": lines,
|
343 |
-
"buffer_transcription": buffer_transcription,
|
344 |
-
"buffer_diarization": buffer_diarization,
|
345 |
-
"remaining_time_transcription": remaining_time_transcription,
|
346 |
-
"remaining_time_diarization": remaining_time_diarization
|
347 |
-
}
|
348 |
-
else:
|
349 |
-
response = {
|
350 |
-
"lines": [{
|
351 |
-
"speaker": 1,
|
352 |
-
"text": "",
|
353 |
-
"beg": format_time(0),
|
354 |
-
"end": format_time(tokens[-1].end) if tokens else format_time(0),
|
355 |
-
"diff": 0
|
356 |
-
}],
|
357 |
-
"buffer_transcription": buffer_transcription,
|
358 |
-
"buffer_diarization": buffer_diarization,
|
359 |
-
"remaining_time_transcription": remaining_time_transcription,
|
360 |
-
"remaining_time_diarization": remaining_time_diarization
|
361 |
|
362 |
-
}
|
363 |
-
|
364 |
-
response_content = ' '.join([str(line['speaker']) + ' ' + line["text"] for line in lines]) + ' | ' + buffer_transcription + ' | ' + buffer_diarization
|
365 |
-
|
366 |
-
if response_content != shared_state.last_response_content:
|
367 |
-
if lines or buffer_transcription or buffer_diarization:
|
368 |
-
await websocket.send_json(response)
|
369 |
-
shared_state.last_response_content = response_content
|
370 |
-
|
371 |
-
# Add a small delay to avoid overwhelming the client
|
372 |
-
await asyncio.sleep(0.1)
|
373 |
-
|
374 |
-
except Exception as e:
|
375 |
-
logger.warning(f"Exception in results_formatter: {e}")
|
376 |
-
logger.warning(f"Traceback: {traceback.format_exc()}")
|
377 |
-
await asyncio.sleep(0.5) # Back off on error
|
378 |
|
379 |
-
##### ENDPOINTS #####
|
380 |
|
381 |
-
@app.get("/")
|
382 |
-
async def get():
|
383 |
-
return HTMLResponse(html)
|
384 |
|
385 |
@app.websocket("/asr")
|
386 |
async def websocket_endpoint(websocket: WebSocket):
|
|
|
|
|
387 |
await websocket.accept()
|
388 |
logger.info("WebSocket connection opened.")
|
389 |
|
390 |
ffmpeg_process = None
|
391 |
pcm_buffer = bytearray()
|
392 |
-
shared_state = SharedState()
|
393 |
|
394 |
transcription_queue = asyncio.Queue() if args.transcription else None
|
395 |
diarization_queue = asyncio.Queue() if args.diarization else None
|
396 |
|
397 |
online = None
|
398 |
|
399 |
-
|
400 |
-
nonlocal ffmpeg_process, online, pcm_buffer
|
401 |
-
if ffmpeg_process:
|
402 |
-
try:
|
403 |
-
ffmpeg_process.kill()
|
404 |
-
await asyncio.get_event_loop().run_in_executor(None, ffmpeg_process.wait)
|
405 |
-
except Exception as e:
|
406 |
-
logger.warning(f"Error killing FFmpeg process: {e}")
|
407 |
-
ffmpeg_process = await start_ffmpeg_decoder()
|
408 |
-
pcm_buffer = bytearray()
|
409 |
-
|
410 |
-
if args.transcription:
|
411 |
-
online = online_factory(args, asr, tokenizer)
|
412 |
-
|
413 |
-
await shared_state.reset()
|
414 |
-
logger.info("FFmpeg process started.")
|
415 |
-
|
416 |
-
await restart_ffmpeg()
|
417 |
-
|
418 |
tasks = []
|
419 |
if args.transcription and online:
|
420 |
tasks.append(asyncio.create_task(
|
421 |
-
transcription_processor(
|
422 |
if args.diarization and diarization:
|
423 |
tasks.append(asyncio.create_task(
|
424 |
-
diarization_processor(
|
425 |
-
formatter_task = asyncio.create_task(results_formatter(
|
426 |
tasks.append(formatter_task)
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
loop = asyncio.get_event_loop()
|
431 |
-
beg = time()
|
432 |
-
|
433 |
-
while True:
|
434 |
-
try:
|
435 |
-
elapsed_time = math.floor((time() - beg) * 10) / 10 # Round to 0.1 sec
|
436 |
-
ffmpeg_buffer_from_duration = max(int(32000 * elapsed_time), 4096)
|
437 |
-
beg = time()
|
438 |
-
|
439 |
-
# Read chunk with timeout
|
440 |
-
try:
|
441 |
-
chunk = await asyncio.wait_for(
|
442 |
-
loop.run_in_executor(
|
443 |
-
None, ffmpeg_process.stdout.read, ffmpeg_buffer_from_duration
|
444 |
-
),
|
445 |
-
timeout=15.0
|
446 |
-
)
|
447 |
-
except asyncio.TimeoutError:
|
448 |
-
logger.warning("FFmpeg read timeout. Restarting...")
|
449 |
-
await restart_ffmpeg()
|
450 |
-
beg = time()
|
451 |
-
continue # Skip processing and read from new process
|
452 |
-
|
453 |
-
if not chunk:
|
454 |
-
logger.info("FFmpeg stdout closed.")
|
455 |
-
break
|
456 |
-
pcm_buffer.extend(chunk)
|
457 |
-
|
458 |
-
if args.diarization and diarization_queue:
|
459 |
-
await diarization_queue.put(convert_pcm_to_float(pcm_buffer).copy())
|
460 |
-
|
461 |
-
if len(pcm_buffer) >= BYTES_PER_SEC:
|
462 |
-
if len(pcm_buffer) > MAX_BYTES_PER_SEC:
|
463 |
-
logger.warning(
|
464 |
-
f"""Audio buffer is too large: {len(pcm_buffer) / BYTES_PER_SEC:.2f} seconds.
|
465 |
-
The model probably struggles to keep up. Consider using a smaller model.
|
466 |
-
""")
|
467 |
-
|
468 |
-
pcm_array = convert_pcm_to_float(pcm_buffer[:MAX_BYTES_PER_SEC])
|
469 |
-
pcm_buffer = pcm_buffer[MAX_BYTES_PER_SEC:]
|
470 |
-
|
471 |
-
if args.transcription and transcription_queue:
|
472 |
-
await transcription_queue.put(pcm_array.copy())
|
473 |
-
|
474 |
-
|
475 |
-
if not args.transcription and not args.diarization:
|
476 |
-
await asyncio.sleep(0.1)
|
477 |
-
|
478 |
-
except Exception as e:
|
479 |
-
logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
|
480 |
-
logger.warning(f"Traceback: {traceback.format_exc()}")
|
481 |
-
break
|
482 |
-
|
483 |
-
logger.info("Exiting ffmpeg_stdout_reader...")
|
484 |
-
|
485 |
-
stdout_reader_task = asyncio.create_task(ffmpeg_stdout_reader())
|
486 |
-
tasks.append(stdout_reader_task)
|
487 |
try:
|
488 |
while True:
|
489 |
# Receive incoming WebM audio chunks from the client
|
@@ -493,7 +119,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
493 |
ffmpeg_process.stdin.flush()
|
494 |
except (BrokenPipeError, AttributeError) as e:
|
495 |
logger.warning(f"Error writing to FFmpeg: {e}. Restarting...")
|
496 |
-
await restart_ffmpeg()
|
497 |
ffmpeg_process.stdin.write(message)
|
498 |
ffmpeg_process.stdin.flush()
|
499 |
except WebSocketDisconnect:
|
@@ -501,17 +127,14 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
501 |
finally:
|
502 |
for task in tasks:
|
503 |
task.cancel()
|
504 |
-
|
505 |
try:
|
506 |
await asyncio.gather(*tasks, return_exceptions=True)
|
507 |
ffmpeg_process.stdin.close()
|
508 |
ffmpeg_process.wait()
|
509 |
except Exception as e:
|
510 |
logger.warning(f"Error during cleanup: {e}")
|
511 |
-
|
512 |
if args.diarization and diarization:
|
513 |
diarization.close()
|
514 |
-
|
515 |
logger.info("WebSocket endpoint cleaned up.")
|
516 |
|
517 |
if __name__ == "__main__":
|
|
|
17 |
import logging
|
18 |
from datetime import timedelta
|
19 |
import traceback
|
20 |
+
from state import SharedState
|
21 |
+
from formatters import format_time
|
22 |
+
from parse_args import parse_args
|
23 |
+
from audio import AudioProcessor
|
24 |
|
25 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
26 |
logging.getLogger().setLevel(logging.WARNING)
|
27 |
logger = logging.getLogger(__name__)
|
28 |
logger.setLevel(logging.DEBUG)
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
+
args = parse_args()
|
|
|
33 |
|
34 |
SAMPLE_RATE = 16000
|
35 |
+
# CHANNELS = 1
|
36 |
+
# SAMPLES_PER_SEC = int(SAMPLE_RATE * args.min_chunk_size)
|
37 |
+
# BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample
|
38 |
+
# BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE
|
39 |
+
# MAX_BYTES_PER_SEC = 32000 * 5 # 5 seconds of audio at 32 kHz
|
40 |
|
41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
##### LOAD APP #####
|
43 |
|
44 |
@asynccontextmanager
|
|
|
71 |
with open("web/live_transcription.html", "r", encoding="utf-8") as f:
|
72 |
html = f.read()
|
73 |
|
74 |
+
@app.get("/")
|
75 |
+
async def get():
|
76 |
+
return HTMLResponse(html)
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
|
|
|
83 |
|
|
|
|
|
|
|
84 |
|
85 |
@app.websocket("/asr")
|
86 |
async def websocket_endpoint(websocket: WebSocket):
|
87 |
+
audio_processor = AudioProcessor(args, asr, tokenizer)
|
88 |
+
|
89 |
await websocket.accept()
|
90 |
logger.info("WebSocket connection opened.")
|
91 |
|
92 |
ffmpeg_process = None
|
93 |
pcm_buffer = bytearray()
|
|
|
94 |
|
95 |
transcription_queue = asyncio.Queue() if args.transcription else None
|
96 |
diarization_queue = asyncio.Queue() if args.diarization else None
|
97 |
|
98 |
online = None
|
99 |
|
100 |
+
ffmpeg_process, online, pcm_buffer = await audio_processor.restart_ffmpeg(ffmpeg_process, online, pcm_buffer)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
tasks = []
|
102 |
if args.transcription and online:
|
103 |
tasks.append(asyncio.create_task(
|
104 |
+
audio_processor.transcription_processor(transcription_queue, online)))
|
105 |
if args.diarization and diarization:
|
106 |
tasks.append(asyncio.create_task(
|
107 |
+
audio_processor.diarization_processor(diarization_queue, diarization)))
|
108 |
+
formatter_task = asyncio.create_task(audio_processor.results_formatter(websocket))
|
109 |
tasks.append(formatter_task)
|
110 |
+
stdout_reader_task = asyncio.create_task(audio_processor.ffmpeg_stdout_reader(ffmpeg_process, pcm_buffer, diarization_queue, transcription_queue))
|
111 |
+
tasks.append(stdout_reader_task)
|
112 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
try:
|
114 |
while True:
|
115 |
# Receive incoming WebM audio chunks from the client
|
|
|
119 |
ffmpeg_process.stdin.flush()
|
120 |
except (BrokenPipeError, AttributeError) as e:
|
121 |
logger.warning(f"Error writing to FFmpeg: {e}. Restarting...")
|
122 |
+
ffmpeg_process, online, pcm_buffer = await audio_processor.restart_ffmpeg(ffmpeg_process, online, pcm_buffer)
|
123 |
ffmpeg_process.stdin.write(message)
|
124 |
ffmpeg_process.stdin.flush()
|
125 |
except WebSocketDisconnect:
|
|
|
127 |
finally:
|
128 |
for task in tasks:
|
129 |
task.cancel()
|
|
|
130 |
try:
|
131 |
await asyncio.gather(*tasks, return_exceptions=True)
|
132 |
ffmpeg_process.stdin.close()
|
133 |
ffmpeg_process.wait()
|
134 |
except Exception as e:
|
135 |
logger.warning(f"Error during cleanup: {e}")
|
|
|
136 |
if args.diarization and diarization:
|
137 |
diarization.close()
|
|
|
138 |
logger.info("WebSocket endpoint cleaned up.")
|
139 |
|
140 |
if __name__ == "__main__":
|
whisper_streaming_custom/whisper_online.py
CHANGED
@@ -71,7 +71,7 @@ def add_shared_args(parser):
|
|
71 |
parser.add_argument(
|
72 |
"--min-chunk-size",
|
73 |
type=float,
|
74 |
-
default=
|
75 |
help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.",
|
76 |
)
|
77 |
parser.add_argument(
|
|
|
71 |
parser.add_argument(
|
72 |
"--min-chunk-size",
|
73 |
type=float,
|
74 |
+
default=0.5,
|
75 |
help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.",
|
76 |
)
|
77 |
parser.add_argument(
|