qfuxa committed on
Commit
ff49b3c
·
1 Parent(s): b9f09f7

Refactor DiartDiarization initialization and streamline WebSocket audio processing

Browse files
audio.py CHANGED
@@ -1,25 +1,15 @@
1
- import io
2
- import argparse
3
  import asyncio
4
  import numpy as np
5
  import ffmpeg
6
  from time import time, sleep
7
- from contextlib import asynccontextmanager
8
 
9
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect
10
- from fastapi.responses import HTMLResponse
11
- from fastapi.middleware.cors import CORSMiddleware
12
-
13
- from whisper_streaming_custom.whisper_online import backend_factory, online_factory, add_shared_args, warmup_asr
14
- from timed_objects import ASRToken
15
 
 
16
  import math
17
  import logging
18
- from datetime import timedelta
19
  import traceback
20
  from state import SharedState
21
  from formatters import format_time
22
- from parse_args import parse_args
23
 
24
 
25
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
@@ -27,7 +17,6 @@ logging.getLogger().setLevel(logging.WARNING)
27
  logger = logging.getLogger(__name__)
28
  logger.setLevel(logging.DEBUG)
29
 
30
-
31
  class AudioProcessor:
32
 
33
  def __init__(self, args, asr, tokenizer):
@@ -38,9 +27,22 @@ class AudioProcessor:
38
  self.bytes_per_sample = 2
39
  self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
40
  self.max_bytes_per_sec = 32000 * 5 # 5 seconds of audio at 32 kHz
 
 
41
  self.shared_state = SharedState()
42
  self.asr = asr
43
  self.tokenizer = tokenizer
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  def convert_pcm_to_float(self, pcm_buffer):
46
  """
@@ -70,26 +72,17 @@ class AudioProcessor:
70
  )
71
  return process
72
 
73
- async def restart_ffmpeg(self, ffmpeg_process, online, pcm_buffer):
74
- if ffmpeg_process:
75
  try:
76
- ffmpeg_process.kill()
77
- await asyncio.get_event_loop().run_in_executor(None, ffmpeg_process.wait)
78
  except Exception as e:
79
  logger.warning(f"Error killing FFmpeg process: {e}")
80
- ffmpeg_process = await self.start_ffmpeg_decoder()
81
- pcm_buffer = bytearray()
82
-
83
- if self.args.transcription:
84
- online = online_factory(self.args, self.asr, self.tokenizer)
85
-
86
- await self.shared_state.reset()
87
- logger.info("FFmpeg process started.")
88
- return ffmpeg_process, online, pcm_buffer
89
-
90
-
91
 
92
- async def ffmpeg_stdout_reader(self, ffmpeg_process, pcm_buffer, diarization_queue, transcription_queue):
93
  loop = asyncio.get_event_loop()
94
  beg = time()
95
 
@@ -103,36 +96,36 @@ class AudioProcessor:
103
  try:
104
  chunk = await asyncio.wait_for(
105
  loop.run_in_executor(
106
- None, ffmpeg_process.stdout.read, ffmpeg_buffer_from_duration
107
  ),
108
  timeout=15.0
109
  )
110
  except asyncio.TimeoutError:
111
  logger.warning("FFmpeg read timeout. Restarting...")
112
- ffmpeg_process, online, pcm_buffer = await self.restart_ffmpeg(ffmpeg_process, online, pcm_buffer)
113
  beg = time()
114
  continue # Skip processing and read from new process
115
 
116
  if not chunk:
117
  logger.info("FFmpeg stdout closed.")
118
  break
119
- pcm_buffer.extend(chunk)
120
 
121
- if self.args.diarization and diarization_queue:
122
- await diarization_queue.put(self.convert_pcm_to_float(pcm_buffer).copy())
123
 
124
- if len(pcm_buffer) >= self.bytes_per_sec:
125
- if len(pcm_buffer) > self.max_bytes_per_sec:
126
  logger.warning(
127
- f"""Audio buffer is too large: {len(pcm_buffer) / self.bytes_per_sec:.2f} seconds.
128
  The model probably struggles to keep up. Consider using a smaller model.
129
  """)
130
 
131
- pcm_array = self.convert_pcm_to_float(pcm_buffer[:self.max_bytes_per_sec])
132
- pcm_buffer = pcm_buffer[self.max_bytes_per_sec:]
133
 
134
- if self.args.transcription and transcription_queue:
135
- await transcription_queue.put(pcm_array.copy())
136
 
137
 
138
  if not self.args.transcription and not self.args.diarization:
@@ -144,27 +137,24 @@ class AudioProcessor:
144
  break
145
  logger.info("Exiting ffmpeg_stdout_reader...")
146
 
147
-
148
-
149
-
150
- async def transcription_processor(self, pcm_queue, online):
151
  full_transcription = ""
152
- sep = online.asr.sep
153
 
154
  while True:
155
  try:
156
- pcm_array = await pcm_queue.get()
157
 
158
- logger.info(f"{len(online.audio_buffer) / online.SAMPLING_RATE} seconds of audio will be processed by the model.")
159
 
160
  # Process transcription
161
- online.insert_audio_chunk(pcm_array)
162
- new_tokens = online.process_iter()
163
 
164
  if new_tokens:
165
  full_transcription += sep.join([t.text for t in new_tokens])
166
 
167
- _buffer = online.get_buffer()
168
  buffer = _buffer.text
169
  end_buffer = _buffer.end if _buffer.end else (new_tokens[-1].end if new_tokens else 0)
170
 
@@ -178,14 +168,15 @@ class AudioProcessor:
178
  logger.warning(f"Exception in transcription_processor: {e}")
179
  logger.warning(f"Traceback: {traceback.format_exc()}")
180
  finally:
181
- pcm_queue.task_done()
182
 
183
- async def diarization_processor(self, pcm_queue, diarization_obj):
 
184
  buffer_diarization = ""
185
 
186
  while True:
187
  try:
188
- pcm_array = await pcm_queue.get()
189
 
190
  # Process diarization
191
  await diarization_obj.diarize(pcm_array)
@@ -205,7 +196,7 @@ class AudioProcessor:
205
  logger.warning(f"Exception in diarization_processor: {e}")
206
  logger.warning(f"Traceback: {traceback.format_exc()}")
207
  finally:
208
- pcm_queue.task_done()
209
 
210
  async def results_formatter(self, websocket):
211
  while True:
@@ -304,3 +295,40 @@ class AudioProcessor:
304
  logger.warning(f"Exception in results_formatter: {e}")
305
  logger.warning(f"Traceback: {traceback.format_exc()}")
306
  await asyncio.sleep(0.5) # Back off on error
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import asyncio
2
  import numpy as np
3
  import ffmpeg
4
  from time import time, sleep
 
5
 
 
 
 
 
 
 
6
 
7
+ from whisper_streaming_custom.whisper_online import online_factory
8
  import math
9
  import logging
 
10
  import traceback
11
  from state import SharedState
12
  from formatters import format_time
 
13
 
14
 
15
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 
17
  logger = logging.getLogger(__name__)
18
  logger.setLevel(logging.DEBUG)
19
 
 
20
  class AudioProcessor:
21
 
22
  def __init__(self, args, asr, tokenizer):
 
27
  self.bytes_per_sample = 2
28
  self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
29
  self.max_bytes_per_sec = 32000 * 5 # 5 seconds of audio at 32 kHz
30
+
31
+
32
  self.shared_state = SharedState()
33
  self.asr = asr
34
  self.tokenizer = tokenizer
35
+
36
+ self.ffmpeg_process = self.start_ffmpeg_decoder()
37
+
38
+ self.transcription_queue = asyncio.Queue() if self.args.transcription else None
39
+ self.diarization_queue = asyncio.Queue() if self.args.diarization else None
40
+
41
+ self.pcm_buffer = bytearray()
42
+ if self.args.transcription:
43
+ self.online = online_factory(self.args, self.asr, self.tokenizer)
44
+
45
+
46
 
47
  def convert_pcm_to_float(self, pcm_buffer):
48
  """
 
72
  )
73
  return process
74
 
75
async def restart_ffmpeg(self):
    """Kill the current FFmpeg process (if any) and start a fresh decoder.

    Side effects:
        * ``self.ffmpeg_process`` is replaced by the result of awaiting
          ``self.start_ffmpeg_decoder()``.
        * ``self.pcm_buffer`` is replaced by an empty ``bytearray``.

    Errors raised while killing or reaping the old process are logged and
    swallowed so that a new decoder is always started.
    """
    # NOTE(review): the previous revision of this method also rebuilt the
    # online ASR processor and reset the shared state on restart; confirm
    # that dropping those steps here is intentional.
    if self.ffmpeg_process:
        try:
            self.ffmpeg_process.kill()
            # wait() blocks until the process exits, so run it on the default
            # executor. get_running_loop() is the documented way to obtain
            # the loop from inside a coroutine.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self.ffmpeg_process.wait)
        except Exception as e:
            logger.warning(f"Error killing FFmpeg process: {e}")

    self.ffmpeg_process = await self.start_ffmpeg_decoder()
    self.pcm_buffer = bytearray()
 
 
 
 
 
 
 
 
 
84
 
85
+ async def ffmpeg_stdout_reader(self):
86
  loop = asyncio.get_event_loop()
87
  beg = time()
88
 
 
96
  try:
97
  chunk = await asyncio.wait_for(
98
  loop.run_in_executor(
99
+ None, self.ffmpeg_process.stdout.read, ffmpeg_buffer_from_duration
100
  ),
101
  timeout=15.0
102
  )
103
  except asyncio.TimeoutError:
104
  logger.warning("FFmpeg read timeout. Restarting...")
105
+ await self.restart_ffmpeg()
106
  beg = time()
107
  continue # Skip processing and read from new process
108
 
109
  if not chunk:
110
  logger.info("FFmpeg stdout closed.")
111
  break
112
+ self.pcm_buffer.extend(chunk)
113
 
114
+ if self.args.diarization and self.diarization_queue:
115
+ await self.diarization_queue.put(self.convert_pcm_to_float(self.pcm_buffer).copy())
116
 
117
+ if len(self.pcm_buffer) >= self.bytes_per_sec:
118
+ if len(self.pcm_buffer) > self.max_bytes_per_sec:
119
  logger.warning(
120
+ f"""Audio buffer is too large: {len(self.pcm_buffer) / self.bytes_per_sec:.2f} seconds.
121
  The model probably struggles to keep up. Consider using a smaller model.
122
  """)
123
 
124
+ pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec])
125
+ self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:]
126
 
127
+ if self.args.transcription and self.transcription_queue:
128
+ await self.transcription_queue.put(pcm_array.copy())
129
 
130
 
131
  if not self.args.transcription and not self.args.diarization:
 
137
  break
138
  logger.info("Exiting ffmpeg_stdout_reader...")
139
 
140
+ async def transcription_processor(self):
 
 
 
141
  full_transcription = ""
142
+ sep = self.online.asr.sep
143
 
144
  while True:
145
  try:
146
+ pcm_array = await self.transcription_queue.get()
147
 
148
+ logger.info(f"{len(self.online.audio_buffer) / self.online.SAMPLING_RATE} seconds of audio will be processed by the model.")
149
 
150
  # Process transcription
151
+ self.online.insert_audio_chunk(pcm_array)
152
+ new_tokens = self.online.process_iter()
153
 
154
  if new_tokens:
155
  full_transcription += sep.join([t.text for t in new_tokens])
156
 
157
+ _buffer = self.online.get_buffer()
158
  buffer = _buffer.text
159
  end_buffer = _buffer.end if _buffer.end else (new_tokens[-1].end if new_tokens else 0)
160
 
 
168
  logger.warning(f"Exception in transcription_processor: {e}")
169
  logger.warning(f"Traceback: {traceback.format_exc()}")
170
  finally:
171
+ self.transcription_queue.task_done()
172
 
173
+
174
+ async def diarization_processor(self, diarization_obj):
175
  buffer_diarization = ""
176
 
177
  while True:
178
  try:
179
+ pcm_array = await self.diarization_queue.get()
180
 
181
  # Process diarization
182
  await diarization_obj.diarize(pcm_array)
 
196
  logger.warning(f"Exception in diarization_processor: {e}")
197
  logger.warning(f"Traceback: {traceback.format_exc()}")
198
  finally:
199
+ self.diarization_queue.task_done()
200
 
201
  async def results_formatter(self, websocket):
202
  while True:
 
295
  logger.warning(f"Exception in results_formatter: {e}")
296
  logger.warning(f"Traceback: {traceback.format_exc()}")
297
  await asyncio.sleep(0.5) # Back off on error
298
+
299
async def create_tasks(self, websocket, diarization):
    """Spawn the background asyncio tasks for one connection.

    Args:
        websocket: passed through to ``self.results_formatter``.
        diarization: diarization backend (or a falsy value); passed to
            ``self.diarization_processor`` when diarization is enabled.

    The created tasks are stored in ``self.tasks`` and the backend in
    ``self.diarization`` so that ``cleanup`` can find them later.
    """
    spawned = []

    # Optional workers first, in the same order as they consume audio.
    if self.args.transcription and self.online:
        spawned.append(asyncio.create_task(self.transcription_processor()))
    if self.args.diarization and diarization:
        spawned.append(asyncio.create_task(self.diarization_processor(diarization)))

    # The formatter and the FFmpeg stdout reader are always started.
    spawned.append(asyncio.create_task(self.results_formatter(websocket)))
    spawned.append(asyncio.create_task(self.ffmpeg_stdout_reader()))

    self.tasks = spawned
    self.diarization = diarization
311
+
312
async def cleanup(self):
    """Cancel background tasks, shut FFmpeg down and close diarization.

    Safe to call even if ``create_tasks`` never ran: in that case there are
    no tasks to cancel and no diarization backend to close.

    Errors raised while gathering tasks, closing FFmpeg's stdin or waiting
    for the process are logged and swallowed; the diarization backend is
    still closed afterwards.
    """
    # self.tasks / self.diarization are only assigned by create_tasks, so
    # fall back to "nothing to do" when it was never called.
    tasks = getattr(self, "tasks", [])
    diarization = getattr(self, "diarization", None)

    for task in tasks:
        task.cancel()

    try:
        await asyncio.gather(*tasks, return_exceptions=True)
        self.ffmpeg_process.stdin.close()
        # wait() blocks until FFmpeg exits; keep the event loop responsive
        # by reaping the process on the default executor.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self.ffmpeg_process.wait)
    except Exception as e:
        logger.warning(f"Error during cleanup: {e}")

    if self.args.diarization and diarization:
        diarization.close()
323
+
324
async def process_audio(self, message):
    """Forward one incoming audio message to FFmpeg's stdin.

    Args:
        message: bytes-like payload written verbatim to the decoder.

    If the write fails with ``BrokenPipeError`` or ``AttributeError``, the
    decoder is restarted via ``self.restart_ffmpeg()`` and the same message
    is written once more; a failure of that second write propagates.
    """
    try:
        pipe = self.ffmpeg_process.stdin
        pipe.write(message)
        pipe.flush()
    except (BrokenPipeError, AttributeError) as e:
        logger.warning(f"Error writing to FFmpeg: {e}. Restarting...")
        await self.restart_ffmpeg()
        # restart_ffmpeg replaced self.ffmpeg_process, so fetch stdin again.
        fresh_pipe = self.ffmpeg_process.stdin
        fresh_pipe.write(message)
        fresh_pipe.flush()
333
+
334
+
diarization/diarization_online.py CHANGED
@@ -103,7 +103,7 @@ class WebSocketAudioSource(AudioSource):
103
 
104
 
105
  class DiartDiarization:
106
- def __init__(self, sample_rate: int, config : SpeakerDiarizationConfig = None, use_microphone: bool = False):
107
  self.pipeline = SpeakerDiarization(config=config)
108
  self.observer = DiarizationObserver()
109
 
 
103
 
104
 
105
  class DiartDiarization:
106
+ def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False):
107
  self.pipeline = SpeakerDiarization(config=config)
108
  self.observer = DiarizationObserver()
109
 
whisper_fastapi_online_server.py CHANGED
@@ -1,24 +1,11 @@
1
- import io
2
- import argparse
3
- import asyncio
4
- import numpy as np
5
- import ffmpeg
6
- from time import time, sleep
7
  from contextlib import asynccontextmanager
8
 
9
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
10
  from fastapi.responses import HTMLResponse
11
  from fastapi.middleware.cors import CORSMiddleware
12
 
13
- from whisper_streaming_custom.whisper_online import backend_factory, online_factory, add_shared_args, warmup_asr
14
- from timed_objects import ASRToken
15
-
16
- import math
17
  import logging
18
- from datetime import timedelta
19
- import traceback
20
- from state import SharedState
21
- from formatters import format_time
22
  from parse_args import parse_args
23
  from audio import AudioProcessor
24
 
@@ -27,19 +14,8 @@ logging.getLogger().setLevel(logging.WARNING)
27
  logger = logging.getLogger(__name__)
28
  logger.setLevel(logging.DEBUG)
29
 
30
-
31
-
32
  args = parse_args()
33
 
34
- SAMPLE_RATE = 16000
35
- # CHANNELS = 1
36
- # SAMPLES_PER_SEC = int(SAMPLE_RATE * args.min_chunk_size)
37
- # BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample
38
- # BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE
39
- # MAX_BYTES_PER_SEC = 32000 * 5 # 5 seconds of audio at 32 kHz
40
-
41
-
42
- ##### LOAD APP #####
43
 
44
  @asynccontextmanager
45
  async def lifespan(app: FastAPI):
@@ -52,7 +28,7 @@ async def lifespan(app: FastAPI):
52
 
53
  if args.diarization:
54
  from diarization.diarization_online import DiartDiarization
55
- diarization = DiartDiarization(SAMPLE_RATE)
56
  else :
57
  diarization = None
58
  yield
@@ -75,66 +51,22 @@ with open("web/live_transcription.html", "r", encoding="utf-8") as f:
75
  async def get():
76
  return HTMLResponse(html)
77
 
78
-
79
-
80
-
81
-
82
-
83
-
84
-
85
  @app.websocket("/asr")
86
  async def websocket_endpoint(websocket: WebSocket):
87
  audio_processor = AudioProcessor(args, asr, tokenizer)
88
 
89
  await websocket.accept()
90
  logger.info("WebSocket connection opened.")
91
-
92
- ffmpeg_process = None
93
- pcm_buffer = bytearray()
94
-
95
- transcription_queue = asyncio.Queue() if args.transcription else None
96
- diarization_queue = asyncio.Queue() if args.diarization else None
97
-
98
- online = None
99
-
100
- ffmpeg_process, online, pcm_buffer = await audio_processor.restart_ffmpeg(ffmpeg_process, online, pcm_buffer)
101
- tasks = []
102
- if args.transcription and online:
103
- tasks.append(asyncio.create_task(
104
- audio_processor.transcription_processor(transcription_queue, online)))
105
- if args.diarization and diarization:
106
- tasks.append(asyncio.create_task(
107
- audio_processor.diarization_processor(diarization_queue, diarization)))
108
- formatter_task = asyncio.create_task(audio_processor.results_formatter(websocket))
109
- tasks.append(formatter_task)
110
- stdout_reader_task = asyncio.create_task(audio_processor.ffmpeg_stdout_reader(ffmpeg_process, pcm_buffer, diarization_queue, transcription_queue))
111
- tasks.append(stdout_reader_task)
112
-
113
  try:
114
  while True:
115
- # Receive incoming WebM audio chunks from the client
116
  message = await websocket.receive_bytes()
117
- try:
118
- ffmpeg_process.stdin.write(message)
119
- ffmpeg_process.stdin.flush()
120
- except (BrokenPipeError, AttributeError) as e:
121
- logger.warning(f"Error writing to FFmpeg: {e}. Restarting...")
122
- ffmpeg_process, online, pcm_buffer = await audio_processor.restart_ffmpeg(ffmpeg_process, online, pcm_buffer)
123
- ffmpeg_process.stdin.write(message)
124
- ffmpeg_process.stdin.flush()
125
  except WebSocketDisconnect:
126
  logger.warning("WebSocket disconnected.")
127
  finally:
128
- for task in tasks:
129
- task.cancel()
130
- try:
131
- await asyncio.gather(*tasks, return_exceptions=True)
132
- ffmpeg_process.stdin.close()
133
- ffmpeg_process.wait()
134
- except Exception as e:
135
- logger.warning(f"Error during cleanup: {e}")
136
- if args.diarization and diarization:
137
- diarization.close()
138
  logger.info("WebSocket endpoint cleaned up.")
139
 
140
  if __name__ == "__main__":
 
 
 
 
 
 
 
1
  from contextlib import asynccontextmanager
2
 
3
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
4
  from fastapi.responses import HTMLResponse
5
  from fastapi.middleware.cors import CORSMiddleware
6
 
7
+ from whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
 
 
 
8
  import logging
 
 
 
 
9
  from parse_args import parse_args
10
  from audio import AudioProcessor
11
 
 
14
  logger = logging.getLogger(__name__)
15
  logger.setLevel(logging.DEBUG)
16
 
 
 
17
  args = parse_args()
18
 
 
 
 
 
 
 
 
 
 
19
 
20
  @asynccontextmanager
21
  async def lifespan(app: FastAPI):
 
28
 
29
  if args.diarization:
30
  from diarization.diarization_online import DiartDiarization
31
+ diarization = DiartDiarization()
32
  else :
33
  diarization = None
34
  yield
 
51
  async def get():
52
  return HTMLResponse(html)
53
 
 
 
 
 
 
 
 
54
  @app.websocket("/asr")
55
  async def websocket_endpoint(websocket: WebSocket):
56
  audio_processor = AudioProcessor(args, asr, tokenizer)
57
 
58
  await websocket.accept()
59
  logger.info("WebSocket connection opened.")
60
+
61
+ await audio_processor.create_tasks(websocket, diarization)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  try:
63
  while True:
 
64
  message = await websocket.receive_bytes()
65
+ audio_processor.process_audio(message)
 
 
 
 
 
 
 
66
  except WebSocketDisconnect:
67
  logger.warning("WebSocket disconnected.")
68
  finally:
69
+ audio_processor.cleanup()
 
 
 
 
 
 
 
 
 
70
  logger.info("WebSocket endpoint cleaned up.")
71
 
72
  if __name__ == "__main__":