Refactor DiartDiarization initialization and streamline WebSocket audio processing
Browse files- audio.py +82 -54
- diarization/diarization_online.py +1 -1
- whisper_fastapi_online_server.py +6 -74
audio.py
CHANGED
@@ -1,25 +1,15 @@
|
|
1 |
-
import io
|
2 |
-
import argparse
|
3 |
import asyncio
|
4 |
import numpy as np
|
5 |
import ffmpeg
|
6 |
from time import time, sleep
|
7 |
-
from contextlib import asynccontextmanager
|
8 |
|
9 |
-
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
10 |
-
from fastapi.responses import HTMLResponse
|
11 |
-
from fastapi.middleware.cors import CORSMiddleware
|
12 |
-
|
13 |
-
from whisper_streaming_custom.whisper_online import backend_factory, online_factory, add_shared_args, warmup_asr
|
14 |
-
from timed_objects import ASRToken
|
15 |
|
|
|
16 |
import math
|
17 |
import logging
|
18 |
-
from datetime import timedelta
|
19 |
import traceback
|
20 |
from state import SharedState
|
21 |
from formatters import format_time
|
22 |
-
from parse_args import parse_args
|
23 |
|
24 |
|
25 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
@@ -27,7 +17,6 @@ logging.getLogger().setLevel(logging.WARNING)
|
|
27 |
logger = logging.getLogger(__name__)
|
28 |
logger.setLevel(logging.DEBUG)
|
29 |
|
30 |
-
|
31 |
class AudioProcessor:
|
32 |
|
33 |
def __init__(self, args, asr, tokenizer):
|
@@ -38,9 +27,22 @@ class AudioProcessor:
|
|
38 |
self.bytes_per_sample = 2
|
39 |
self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
|
40 |
self.max_bytes_per_sec = 32000 * 5 # 5 seconds of audio at 32 kHz
|
|
|
|
|
41 |
self.shared_state = SharedState()
|
42 |
self.asr = asr
|
43 |
self.tokenizer = tokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
def convert_pcm_to_float(self, pcm_buffer):
|
46 |
"""
|
@@ -70,26 +72,17 @@ class AudioProcessor:
|
|
70 |
)
|
71 |
return process
|
72 |
|
73 |
-
async def restart_ffmpeg(self
|
74 |
-
if ffmpeg_process:
|
75 |
try:
|
76 |
-
ffmpeg_process.kill()
|
77 |
-
await asyncio.get_event_loop().run_in_executor(None, ffmpeg_process.wait)
|
78 |
except Exception as e:
|
79 |
logger.warning(f"Error killing FFmpeg process: {e}")
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
if self.args.transcription:
|
84 |
-
online = online_factory(self.args, self.asr, self.tokenizer)
|
85 |
-
|
86 |
-
await self.shared_state.reset()
|
87 |
-
logger.info("FFmpeg process started.")
|
88 |
-
return ffmpeg_process, online, pcm_buffer
|
89 |
-
|
90 |
-
|
91 |
|
92 |
-
async def ffmpeg_stdout_reader(self
|
93 |
loop = asyncio.get_event_loop()
|
94 |
beg = time()
|
95 |
|
@@ -103,36 +96,36 @@ class AudioProcessor:
|
|
103 |
try:
|
104 |
chunk = await asyncio.wait_for(
|
105 |
loop.run_in_executor(
|
106 |
-
None, ffmpeg_process.stdout.read, ffmpeg_buffer_from_duration
|
107 |
),
|
108 |
timeout=15.0
|
109 |
)
|
110 |
except asyncio.TimeoutError:
|
111 |
logger.warning("FFmpeg read timeout. Restarting...")
|
112 |
-
|
113 |
beg = time()
|
114 |
continue # Skip processing and read from new process
|
115 |
|
116 |
if not chunk:
|
117 |
logger.info("FFmpeg stdout closed.")
|
118 |
break
|
119 |
-
pcm_buffer.extend(chunk)
|
120 |
|
121 |
-
if self.args.diarization and diarization_queue:
|
122 |
-
await diarization_queue.put(self.convert_pcm_to_float(pcm_buffer).copy())
|
123 |
|
124 |
-
if len(pcm_buffer) >= self.bytes_per_sec:
|
125 |
-
if len(pcm_buffer) > self.max_bytes_per_sec:
|
126 |
logger.warning(
|
127 |
-
f"""Audio buffer is too large: {len(pcm_buffer) / self.bytes_per_sec:.2f} seconds.
|
128 |
The model probably struggles to keep up. Consider using a smaller model.
|
129 |
""")
|
130 |
|
131 |
-
pcm_array = self.convert_pcm_to_float(pcm_buffer[:self.max_bytes_per_sec])
|
132 |
-
pcm_buffer = pcm_buffer[self.max_bytes_per_sec:]
|
133 |
|
134 |
-
if self.args.transcription and transcription_queue:
|
135 |
-
await transcription_queue.put(pcm_array.copy())
|
136 |
|
137 |
|
138 |
if not self.args.transcription and not self.args.diarization:
|
@@ -144,27 +137,24 @@ class AudioProcessor:
|
|
144 |
break
|
145 |
logger.info("Exiting ffmpeg_stdout_reader...")
|
146 |
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
async def transcription_processor(self, pcm_queue, online):
|
151 |
full_transcription = ""
|
152 |
-
sep = online.asr.sep
|
153 |
|
154 |
while True:
|
155 |
try:
|
156 |
-
pcm_array = await
|
157 |
|
158 |
-
logger.info(f"{len(online.audio_buffer) / online.SAMPLING_RATE} seconds of audio will be processed by the model.")
|
159 |
|
160 |
# Process transcription
|
161 |
-
online.insert_audio_chunk(pcm_array)
|
162 |
-
new_tokens = online.process_iter()
|
163 |
|
164 |
if new_tokens:
|
165 |
full_transcription += sep.join([t.text for t in new_tokens])
|
166 |
|
167 |
-
_buffer = online.get_buffer()
|
168 |
buffer = _buffer.text
|
169 |
end_buffer = _buffer.end if _buffer.end else (new_tokens[-1].end if new_tokens else 0)
|
170 |
|
@@ -178,14 +168,15 @@ class AudioProcessor:
|
|
178 |
logger.warning(f"Exception in transcription_processor: {e}")
|
179 |
logger.warning(f"Traceback: {traceback.format_exc()}")
|
180 |
finally:
|
181 |
-
|
182 |
|
183 |
-
|
|
|
184 |
buffer_diarization = ""
|
185 |
|
186 |
while True:
|
187 |
try:
|
188 |
-
pcm_array = await
|
189 |
|
190 |
# Process diarization
|
191 |
await diarization_obj.diarize(pcm_array)
|
@@ -205,7 +196,7 @@ class AudioProcessor:
|
|
205 |
logger.warning(f"Exception in diarization_processor: {e}")
|
206 |
logger.warning(f"Traceback: {traceback.format_exc()}")
|
207 |
finally:
|
208 |
-
|
209 |
|
210 |
async def results_formatter(self, websocket):
|
211 |
while True:
|
@@ -304,3 +295,40 @@ class AudioProcessor:
|
|
304 |
logger.warning(f"Exception in results_formatter: {e}")
|
305 |
logger.warning(f"Traceback: {traceback.format_exc()}")
|
306 |
await asyncio.sleep(0.5) # Back off on error
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import asyncio
|
2 |
import numpy as np
|
3 |
import ffmpeg
|
4 |
from time import time, sleep
|
|
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
+
from whisper_streaming_custom.whisper_online import online_factory
|
8 |
import math
|
9 |
import logging
|
|
|
10 |
import traceback
|
11 |
from state import SharedState
|
12 |
from formatters import format_time
|
|
|
13 |
|
14 |
|
15 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
|
|
17 |
logger = logging.getLogger(__name__)
|
18 |
logger.setLevel(logging.DEBUG)
|
19 |
|
|
|
20 |
class AudioProcessor:
|
21 |
|
22 |
def __init__(self, args, asr, tokenizer):
|
|
|
27 |
self.bytes_per_sample = 2
|
28 |
self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
|
29 |
self.max_bytes_per_sec = 32000 * 5 # 5 seconds of audio at 32 kHz
|
30 |
+
|
31 |
+
|
32 |
self.shared_state = SharedState()
|
33 |
self.asr = asr
|
34 |
self.tokenizer = tokenizer
|
35 |
+
|
36 |
+
self.ffmpeg_process = self.start_ffmpeg_decoder()
|
37 |
+
|
38 |
+
self.transcription_queue = asyncio.Queue() if self.args.transcription else None
|
39 |
+
self.diarization_queue = asyncio.Queue() if self.args.diarization else None
|
40 |
+
|
41 |
+
self.pcm_buffer = bytearray()
|
42 |
+
if self.args.transcription:
|
43 |
+
self.online = online_factory(self.args, self.asr, self.tokenizer)
|
44 |
+
|
45 |
+
|
46 |
|
47 |
def convert_pcm_to_float(self, pcm_buffer):
|
48 |
"""
|
|
|
72 |
)
|
73 |
return process
|
74 |
|
75 |
+
async def restart_ffmpeg(self):
|
76 |
+
if self.ffmpeg_process:
|
77 |
try:
|
78 |
+
self.ffmpeg_process.kill()
|
79 |
+
await asyncio.get_event_loop().run_in_executor(None, self.ffmpeg_process.wait)
|
80 |
except Exception as e:
|
81 |
logger.warning(f"Error killing FFmpeg process: {e}")
|
82 |
+
self.ffmpeg_process = await self.start_ffmpeg_decoder()
|
83 |
+
self.pcm_buffer = bytearray()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
+
async def ffmpeg_stdout_reader(self):
|
86 |
loop = asyncio.get_event_loop()
|
87 |
beg = time()
|
88 |
|
|
|
96 |
try:
|
97 |
chunk = await asyncio.wait_for(
|
98 |
loop.run_in_executor(
|
99 |
+
None, self.ffmpeg_process.stdout.read, ffmpeg_buffer_from_duration
|
100 |
),
|
101 |
timeout=15.0
|
102 |
)
|
103 |
except asyncio.TimeoutError:
|
104 |
logger.warning("FFmpeg read timeout. Restarting...")
|
105 |
+
await self.restart_ffmpeg()
|
106 |
beg = time()
|
107 |
continue # Skip processing and read from new process
|
108 |
|
109 |
if not chunk:
|
110 |
logger.info("FFmpeg stdout closed.")
|
111 |
break
|
112 |
+
self.pcm_buffer.extend(chunk)
|
113 |
|
114 |
+
if self.args.diarization and self.diarization_queue:
|
115 |
+
await self.diarization_queue.put(self.convert_pcm_to_float(self.pcm_buffer).copy())
|
116 |
|
117 |
+
if len(self.pcm_buffer) >= self.bytes_per_sec:
|
118 |
+
if len(self.pcm_buffer) > self.max_bytes_per_sec:
|
119 |
logger.warning(
|
120 |
+
f"""Audio buffer is too large: {len(self.pcm_buffer) / self.bytes_per_sec:.2f} seconds.
|
121 |
The model probably struggles to keep up. Consider using a smaller model.
|
122 |
""")
|
123 |
|
124 |
+
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec])
|
125 |
+
self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:]
|
126 |
|
127 |
+
if self.args.transcription and self.transcription_queue:
|
128 |
+
await self.transcription_queue.put(pcm_array.copy())
|
129 |
|
130 |
|
131 |
if not self.args.transcription and not self.args.diarization:
|
|
|
137 |
break
|
138 |
logger.info("Exiting ffmpeg_stdout_reader...")
|
139 |
|
140 |
+
async def transcription_processor(self):
|
|
|
|
|
|
|
141 |
full_transcription = ""
|
142 |
+
sep = self.online.asr.sep
|
143 |
|
144 |
while True:
|
145 |
try:
|
146 |
+
pcm_array = await self.transcription_queue.get()
|
147 |
|
148 |
+
logger.info(f"{len(self.online.audio_buffer) / self.online.SAMPLING_RATE} seconds of audio will be processed by the model.")
|
149 |
|
150 |
# Process transcription
|
151 |
+
self.online.insert_audio_chunk(pcm_array)
|
152 |
+
new_tokens = self.online.process_iter()
|
153 |
|
154 |
if new_tokens:
|
155 |
full_transcription += sep.join([t.text for t in new_tokens])
|
156 |
|
157 |
+
_buffer = self.online.get_buffer()
|
158 |
buffer = _buffer.text
|
159 |
end_buffer = _buffer.end if _buffer.end else (new_tokens[-1].end if new_tokens else 0)
|
160 |
|
|
|
168 |
logger.warning(f"Exception in transcription_processor: {e}")
|
169 |
logger.warning(f"Traceback: {traceback.format_exc()}")
|
170 |
finally:
|
171 |
+
self.transcription_queue.task_done()
|
172 |
|
173 |
+
|
174 |
+
async def diarization_processor(self, diarization_obj):
|
175 |
buffer_diarization = ""
|
176 |
|
177 |
while True:
|
178 |
try:
|
179 |
+
pcm_array = await self.diarization_queue.get()
|
180 |
|
181 |
# Process diarization
|
182 |
await diarization_obj.diarize(pcm_array)
|
|
|
196 |
logger.warning(f"Exception in diarization_processor: {e}")
|
197 |
logger.warning(f"Traceback: {traceback.format_exc()}")
|
198 |
finally:
|
199 |
+
self.diarization_queue.task_done()
|
200 |
|
201 |
async def results_formatter(self, websocket):
|
202 |
while True:
|
|
|
295 |
logger.warning(f"Exception in results_formatter: {e}")
|
296 |
logger.warning(f"Traceback: {traceback.format_exc()}")
|
297 |
await asyncio.sleep(0.5) # Back off on error
|
298 |
+
|
299 |
+
async def create_tasks(self, websocket, diarization):
|
300 |
+
tasks = []
|
301 |
+
if self.args.transcription and self.online:
|
302 |
+
tasks.append(asyncio.create_task(self.transcription_processor()))
|
303 |
+
if self.args.diarization and diarization:
|
304 |
+
tasks.append(asyncio.create_task(self.diarization_processor(diarization)))
|
305 |
+
formatter_task = asyncio.create_task(self.results_formatter(websocket))
|
306 |
+
tasks.append(formatter_task)
|
307 |
+
stdout_reader_task = asyncio.create_task(self.ffmpeg_stdout_reader())
|
308 |
+
tasks.append(stdout_reader_task)
|
309 |
+
self.tasks = tasks
|
310 |
+
self.diarization = diarization
|
311 |
+
|
312 |
+
async def cleanup(self):
|
313 |
+
for task in self.tasks:
|
314 |
+
task.cancel()
|
315 |
+
try:
|
316 |
+
await asyncio.gather(*self.tasks, return_exceptions=True)
|
317 |
+
self.ffmpeg_process.stdin.close()
|
318 |
+
self.ffmpeg_process.wait()
|
319 |
+
except Exception as e:
|
320 |
+
logger.warning(f"Error during cleanup: {e}")
|
321 |
+
if self.args.diarization and self.diarization:
|
322 |
+
self.diarization.close()
|
323 |
+
|
324 |
+
async def process_audio(self, message):
|
325 |
+
try:
|
326 |
+
self.ffmpeg_process.stdin.write(message)
|
327 |
+
self.ffmpeg_process.stdin.flush()
|
328 |
+
except (BrokenPipeError, AttributeError) as e:
|
329 |
+
logger.warning(f"Error writing to FFmpeg: {e}. Restarting...")
|
330 |
+
await self.restart_ffmpeg()
|
331 |
+
self.ffmpeg_process.stdin.write(message)
|
332 |
+
self.ffmpeg_process.stdin.flush()
|
333 |
+
|
334 |
+
|
diarization/diarization_online.py
CHANGED
@@ -103,7 +103,7 @@ class WebSocketAudioSource(AudioSource):
|
|
103 |
|
104 |
|
105 |
class DiartDiarization:
|
106 |
-
def __init__(self, sample_rate: int, config : SpeakerDiarizationConfig = None, use_microphone: bool = False):
|
107 |
self.pipeline = SpeakerDiarization(config=config)
|
108 |
self.observer = DiarizationObserver()
|
109 |
|
|
|
103 |
|
104 |
|
105 |
class DiartDiarization:
|
106 |
+
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False):
|
107 |
self.pipeline = SpeakerDiarization(config=config)
|
108 |
self.observer = DiarizationObserver()
|
109 |
|
whisper_fastapi_online_server.py
CHANGED
@@ -1,24 +1,11 @@
|
|
1 |
-
import io
|
2 |
-
import argparse
|
3 |
-
import asyncio
|
4 |
-
import numpy as np
|
5 |
-
import ffmpeg
|
6 |
-
from time import time, sleep
|
7 |
from contextlib import asynccontextmanager
|
8 |
|
9 |
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
10 |
from fastapi.responses import HTMLResponse
|
11 |
from fastapi.middleware.cors import CORSMiddleware
|
12 |
|
13 |
-
from whisper_streaming_custom.whisper_online import backend_factory,
|
14 |
-
from timed_objects import ASRToken
|
15 |
-
|
16 |
-
import math
|
17 |
import logging
|
18 |
-
from datetime import timedelta
|
19 |
-
import traceback
|
20 |
-
from state import SharedState
|
21 |
-
from formatters import format_time
|
22 |
from parse_args import parse_args
|
23 |
from audio import AudioProcessor
|
24 |
|
@@ -27,19 +14,8 @@ logging.getLogger().setLevel(logging.WARNING)
|
|
27 |
logger = logging.getLogger(__name__)
|
28 |
logger.setLevel(logging.DEBUG)
|
29 |
|
30 |
-
|
31 |
-
|
32 |
args = parse_args()
|
33 |
|
34 |
-
SAMPLE_RATE = 16000
|
35 |
-
# CHANNELS = 1
|
36 |
-
# SAMPLES_PER_SEC = int(SAMPLE_RATE * args.min_chunk_size)
|
37 |
-
# BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample
|
38 |
-
# BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE
|
39 |
-
# MAX_BYTES_PER_SEC = 32000 * 5 # 5 seconds of audio at 32 kHz
|
40 |
-
|
41 |
-
|
42 |
-
##### LOAD APP #####
|
43 |
|
44 |
@asynccontextmanager
|
45 |
async def lifespan(app: FastAPI):
|
@@ -52,7 +28,7 @@ async def lifespan(app: FastAPI):
|
|
52 |
|
53 |
if args.diarization:
|
54 |
from diarization.diarization_online import DiartDiarization
|
55 |
-
diarization = DiartDiarization(
|
56 |
else :
|
57 |
diarization = None
|
58 |
yield
|
@@ -75,66 +51,22 @@ with open("web/live_transcription.html", "r", encoding="utf-8") as f:
|
|
75 |
async def get():
|
76 |
return HTMLResponse(html)
|
77 |
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
@app.websocket("/asr")
|
86 |
async def websocket_endpoint(websocket: WebSocket):
|
87 |
audio_processor = AudioProcessor(args, asr, tokenizer)
|
88 |
|
89 |
await websocket.accept()
|
90 |
logger.info("WebSocket connection opened.")
|
91 |
-
|
92 |
-
|
93 |
-
pcm_buffer = bytearray()
|
94 |
-
|
95 |
-
transcription_queue = asyncio.Queue() if args.transcription else None
|
96 |
-
diarization_queue = asyncio.Queue() if args.diarization else None
|
97 |
-
|
98 |
-
online = None
|
99 |
-
|
100 |
-
ffmpeg_process, online, pcm_buffer = await audio_processor.restart_ffmpeg(ffmpeg_process, online, pcm_buffer)
|
101 |
-
tasks = []
|
102 |
-
if args.transcription and online:
|
103 |
-
tasks.append(asyncio.create_task(
|
104 |
-
audio_processor.transcription_processor(transcription_queue, online)))
|
105 |
-
if args.diarization and diarization:
|
106 |
-
tasks.append(asyncio.create_task(
|
107 |
-
audio_processor.diarization_processor(diarization_queue, diarization)))
|
108 |
-
formatter_task = asyncio.create_task(audio_processor.results_formatter(websocket))
|
109 |
-
tasks.append(formatter_task)
|
110 |
-
stdout_reader_task = asyncio.create_task(audio_processor.ffmpeg_stdout_reader(ffmpeg_process, pcm_buffer, diarization_queue, transcription_queue))
|
111 |
-
tasks.append(stdout_reader_task)
|
112 |
-
|
113 |
try:
|
114 |
while True:
|
115 |
-
# Receive incoming WebM audio chunks from the client
|
116 |
message = await websocket.receive_bytes()
|
117 |
-
|
118 |
-
ffmpeg_process.stdin.write(message)
|
119 |
-
ffmpeg_process.stdin.flush()
|
120 |
-
except (BrokenPipeError, AttributeError) as e:
|
121 |
-
logger.warning(f"Error writing to FFmpeg: {e}. Restarting...")
|
122 |
-
ffmpeg_process, online, pcm_buffer = await audio_processor.restart_ffmpeg(ffmpeg_process, online, pcm_buffer)
|
123 |
-
ffmpeg_process.stdin.write(message)
|
124 |
-
ffmpeg_process.stdin.flush()
|
125 |
except WebSocketDisconnect:
|
126 |
logger.warning("WebSocket disconnected.")
|
127 |
finally:
|
128 |
-
|
129 |
-
task.cancel()
|
130 |
-
try:
|
131 |
-
await asyncio.gather(*tasks, return_exceptions=True)
|
132 |
-
ffmpeg_process.stdin.close()
|
133 |
-
ffmpeg_process.wait()
|
134 |
-
except Exception as e:
|
135 |
-
logger.warning(f"Error during cleanup: {e}")
|
136 |
-
if args.diarization and diarization:
|
137 |
-
diarization.close()
|
138 |
logger.info("WebSocket endpoint cleaned up.")
|
139 |
|
140 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from contextlib import asynccontextmanager
|
2 |
|
3 |
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
4 |
from fastapi.responses import HTMLResponse
|
5 |
from fastapi.middleware.cors import CORSMiddleware
|
6 |
|
7 |
+
from whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
|
|
|
|
|
|
|
8 |
import logging
|
|
|
|
|
|
|
|
|
9 |
from parse_args import parse_args
|
10 |
from audio import AudioProcessor
|
11 |
|
|
|
14 |
logger = logging.getLogger(__name__)
|
15 |
logger.setLevel(logging.DEBUG)
|
16 |
|
|
|
|
|
17 |
args = parse_args()
|
18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
@asynccontextmanager
|
21 |
async def lifespan(app: FastAPI):
|
|
|
28 |
|
29 |
if args.diarization:
|
30 |
from diarization.diarization_online import DiartDiarization
|
31 |
+
diarization = DiartDiarization()
|
32 |
else :
|
33 |
diarization = None
|
34 |
yield
|
|
|
51 |
async def get():
|
52 |
return HTMLResponse(html)
|
53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
@app.websocket("/asr")
|
55 |
async def websocket_endpoint(websocket: WebSocket):
|
56 |
audio_processor = AudioProcessor(args, asr, tokenizer)
|
57 |
|
58 |
await websocket.accept()
|
59 |
logger.info("WebSocket connection opened.")
|
60 |
+
|
61 |
+
await audio_processor.create_tasks(websocket, diarization)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
try:
|
63 |
while True:
|
|
|
64 |
message = await websocket.receive_bytes()
|
65 |
+
audio_processor.process_audio(message)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
except WebSocketDisconnect:
|
67 |
logger.warning("WebSocket disconnected.")
|
68 |
finally:
|
69 |
+
audio_processor.cleanup()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
logger.info("WebSocket endpoint cleaned up.")
|
71 |
|
72 |
if __name__ == "__main__":
|