qfuxa committed on
Commit
d5886b3
·
1 Parent(s): ca4162a

Implement logging for WebSocket events and FFmpeg process management; remove obsolete test file

Browse files
whisper_fastapi_online_server.py CHANGED
@@ -14,8 +14,14 @@ from src.whisper_streaming.whisper_online import backend_factory, online_factory
14
 
15
  import subprocess
16
  import math
 
17
 
18
 
 
 
 
 
 
19
  ##### LOAD ARGS #####
20
 
21
  parser = argparse.ArgumentParser(description="Whisper FastAPI Online Server")
@@ -51,6 +57,7 @@ CHANNELS = 1
51
  SAMPLES_PER_SEC = SAMPLE_RATE * int(args.min_chunk_size)
52
  BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample
53
  BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE
 
54
 
55
  if args.diarization:
56
  from src.diarization.diarization_online import DiartDiarization
@@ -106,7 +113,7 @@ async def get():
106
  @app.websocket("/asr")
107
  async def websocket_endpoint(websocket: WebSocket):
108
  await websocket.accept()
109
- print("WebSocket connection opened.")
110
 
111
  ffmpeg_process = None
112
  pcm_buffer = bytearray()
@@ -120,13 +127,13 @@ async def websocket_endpoint(websocket: WebSocket):
120
  ffmpeg_process.kill()
121
  await asyncio.get_event_loop().run_in_executor(None, ffmpeg_process.wait)
122
  except Exception as e:
123
- print(f"Error killing FFmpeg process: {e}")
124
  ffmpeg_process = await start_ffmpeg_decoder()
125
  pcm_buffer = bytearray()
126
  online = online_factory(args, asr, tokenizer)
127
  if args.diarization:
128
  diarization = DiartDiarization(SAMPLE_RATE)
129
- print("FFmpeg process started.")
130
 
131
  await restart_ffmpeg()
132
 
@@ -153,7 +160,7 @@ async def websocket_endpoint(websocket: WebSocket):
153
  timeout=5.0
154
  )
155
  except asyncio.TimeoutError:
156
- print("FFmpeg read timeout. Restarting...")
157
  await restart_ffmpeg()
158
  full_transcription = ""
159
  chunk_history = []
@@ -161,17 +168,23 @@ async def websocket_endpoint(websocket: WebSocket):
161
  continue # Skip processing and read from new process
162
 
163
  if not chunk:
164
- print("FFmpeg stdout closed.")
165
  break
166
 
167
  pcm_buffer.extend(chunk)
168
  if len(pcm_buffer) >= BYTES_PER_SEC:
 
 
 
 
 
169
  # Convert int16 -> float32
170
  pcm_array = (
171
- np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32)
172
  / 32768.0
173
  )
174
- pcm_buffer = bytearray()
 
175
  online.insert_audio_chunk(pcm_array)
176
  transcription = online.process_iter()
177
 
@@ -215,10 +228,10 @@ async def websocket_endpoint(websocket: WebSocket):
215
  await websocket.send_json(response)
216
 
217
  except Exception as e:
218
- print(f"Exception in ffmpeg_stdout_reader: {e}")
219
  break
220
 
221
- print("Exiting ffmpeg_stdout_reader...")
222
 
223
  stdout_reader_task = asyncio.create_task(ffmpeg_stdout_reader())
224
 
@@ -230,12 +243,12 @@ async def websocket_endpoint(websocket: WebSocket):
230
  ffmpeg_process.stdin.write(message)
231
  ffmpeg_process.stdin.flush()
232
  except (BrokenPipeError, AttributeError) as e:
233
- print(f"Error writing to FFmpeg: {e}. Restarting...")
234
  await restart_ffmpeg()
235
  ffmpeg_process.stdin.write(message)
236
  ffmpeg_process.stdin.flush()
237
  except WebSocketDisconnect:
238
- print("WebSocket disconnected.")
239
  finally:
240
  stdout_reader_task.cancel()
241
  try:
@@ -254,4 +267,4 @@ if __name__ == "__main__":
254
  uvicorn.run(
255
  "whisper_fastapi_online_server:app", host=args.host, port=args.port, reload=True,
256
  log_level="info"
257
- )
 
14
 
15
  import subprocess
16
  import math
17
+ import logging
18
 
19
 
20
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
21
+ logging.getLogger().setLevel(logging.WARNING)
22
+ logger = logging.getLogger(__name__)
23
+ logger.setLevel(logging.DEBUG)
24
+
25
  ##### LOAD ARGS #####
26
 
27
  parser = argparse.ArgumentParser(description="Whisper FastAPI Online Server")
 
57
  SAMPLES_PER_SEC = SAMPLE_RATE * int(args.min_chunk_size)
58
  BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample
59
  BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE
60
+ MAX_BYTES_PER_SEC = 32000 * 5 # 5 seconds of audio at 32 kHz
61
 
62
  if args.diarization:
63
  from src.diarization.diarization_online import DiartDiarization
 
113
  @app.websocket("/asr")
114
  async def websocket_endpoint(websocket: WebSocket):
115
  await websocket.accept()
116
+ logger.info("WebSocket connection opened.")
117
 
118
  ffmpeg_process = None
119
  pcm_buffer = bytearray()
 
127
  ffmpeg_process.kill()
128
  await asyncio.get_event_loop().run_in_executor(None, ffmpeg_process.wait)
129
  except Exception as e:
130
+ logger.warning(f"Error killing FFmpeg process: {e}")
131
  ffmpeg_process = await start_ffmpeg_decoder()
132
  pcm_buffer = bytearray()
133
  online = online_factory(args, asr, tokenizer)
134
  if args.diarization:
135
  diarization = DiartDiarization(SAMPLE_RATE)
136
+ logger.info("FFmpeg process started.")
137
 
138
  await restart_ffmpeg()
139
 
 
160
  timeout=5.0
161
  )
162
  except asyncio.TimeoutError:
163
+ logger.warning("FFmpeg read timeout. Restarting...")
164
  await restart_ffmpeg()
165
  full_transcription = ""
166
  chunk_history = []
 
168
  continue # Skip processing and read from new process
169
 
170
  if not chunk:
171
+ logger.info("FFmpeg stdout closed.")
172
  break
173
 
174
  pcm_buffer.extend(chunk)
175
  if len(pcm_buffer) >= BYTES_PER_SEC:
176
+ if len(pcm_buffer) > MAX_BYTES_PER_SEC:
177
+ logger.warning(
178
+ f"""Audio buffer is too large: {len(pcm_buffer) / BYTES_PER_SEC:.2f} seconds.
179
+ The model probably struggles to keep up. Consider using a smaller model.
180
+ """)
181
  # Convert int16 -> float32
182
  pcm_array = (
183
+ np.frombuffer(pcm_buffer[:MAX_BYTES_PER_SEC], dtype=np.int16).astype(np.float32)
184
  / 32768.0
185
  )
186
+ pcm_buffer = pcm_buffer[MAX_BYTES_PER_SEC:]
187
+ logger.info(f"{len(online.audio_buffer) / online.SAMPLING_RATE} seconds of audio will be processed by the model.")
188
  online.insert_audio_chunk(pcm_array)
189
  transcription = online.process_iter()
190
 
 
228
  await websocket.send_json(response)
229
 
230
  except Exception as e:
231
+ logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
232
  break
233
 
234
+ logger.info("Exiting ffmpeg_stdout_reader...")
235
 
236
  stdout_reader_task = asyncio.create_task(ffmpeg_stdout_reader())
237
 
 
243
  ffmpeg_process.stdin.write(message)
244
  ffmpeg_process.stdin.flush()
245
  except (BrokenPipeError, AttributeError) as e:
246
+ logger.warning(f"Error writing to FFmpeg: {e}. Restarting...")
247
  await restart_ffmpeg()
248
  ffmpeg_process.stdin.write(message)
249
  ffmpeg_process.stdin.flush()
250
  except WebSocketDisconnect:
251
+ logger.warning("WebSocket disconnected.")
252
  finally:
253
  stdout_reader_task.cancel()
254
  try:
 
267
  uvicorn.run(
268
  "whisper_fastapi_online_server:app", host=args.host, port=args.port, reload=True,
269
  log_level="info"
270
+ )
whisper_noserver_test.py DELETED
@@ -1,181 +0,0 @@
1
- #!/usr/bin/env python3
2
- import sys
3
- import numpy as np
4
- import librosa
5
- from functools import lru_cache
6
- import time
7
- import logging
8
-
9
- logger = logging.getLogger(__name__)
10
-
11
- from src.whisper_streaming.whisper_online import *
12
-
13
- @lru_cache(10**6)
14
- def load_audio(fname):
15
- a, _ = librosa.load(fname, sr=16000, dtype=np.float32)
16
- return a
17
-
18
-
19
- def load_audio_chunk(fname, beg, end):
20
- audio = load_audio(fname)
21
- beg_s = int(beg * 16000)
22
- end_s = int(end * 16000)
23
- return audio[beg_s:end_s]
24
-
25
- if __name__ == "__main__":
26
-
27
- import argparse
28
-
29
- parser = argparse.ArgumentParser()
30
- parser.add_argument(
31
- "--audio_path",
32
- type=str,
33
- default='samples_jfk.wav',
34
- help="Filename of 16kHz mono channel wav, on which live streaming is simulated.",
35
- )
36
- add_shared_args(parser)
37
- parser.add_argument(
38
- "--start_at",
39
- type=float,
40
- default=0.0,
41
- help="Start processing audio at this time.",
42
- )
43
- parser.add_argument(
44
- "--offline", action="store_true", default=False, help="Offline mode."
45
- )
46
- parser.add_argument(
47
- "--comp_unaware",
48
- action="store_true",
49
- default=False,
50
- help="Computationally unaware simulation.",
51
- )
52
-
53
- args = parser.parse_args()
54
-
55
- # reset to store stderr to different file stream, e.g. open(os.devnull,"w")
56
- logfile = None # sys.stderr
57
-
58
- if args.offline and args.comp_unaware:
59
- logger.error(
60
- "No or one option from --offline and --comp_unaware are available, not both. Exiting."
61
- )
62
- sys.exit(1)
63
-
64
- # if args.log_level:
65
- # logging.basicConfig(format='whisper-%(levelname)s:%(name)s: %(message)s',
66
- # level=getattr(logging, args.log_level))
67
-
68
- set_logging(args, logger,others=["src.whisper_streaming.online_asr"])
69
-
70
- audio_path = args.audio_path
71
-
72
- SAMPLING_RATE = 16000
73
- duration = len(load_audio(audio_path)) / SAMPLING_RATE
74
- logger.info("Audio duration is: %2.2f seconds" % duration)
75
-
76
- asr, online = asr_factory(args, logfile=logfile)
77
- if args.vac:
78
- min_chunk = args.vac_chunk_size
79
- else:
80
- min_chunk = args.min_chunk_size
81
-
82
- # load the audio into the LRU cache before we start the timer
83
- a = load_audio_chunk(audio_path, 0, 1)
84
-
85
- # warm up the ASR because the very first transcribe takes much more time than the other
86
- asr.transcribe(a)
87
-
88
- beg = args.start_at
89
- start = time.time() - beg
90
-
91
- def output_transcript(o, now=None):
92
- # output format in stdout is like:
93
- # 4186.3606 0 1720 Takhle to je
94
- # - the first three words are:
95
- # - emission time from beginning of processing, in milliseconds
96
- # - beg and end timestamp of the text segment, as estimated by Whisper model. The timestamps are not accurate, but they're useful anyway
97
- # - the next words: segment transcript
98
- if now is None:
99
- now = time.time() - start
100
- if o[0] is not None:
101
- log_string = f"{now*1000:1.0f}, {o[0]*1000:1.0f}-{o[1]*1000:1.0f} ({(now-o[1]):+1.0f}s): {o[2]}"
102
-
103
- logger.debug(
104
- log_string
105
- )
106
-
107
- if logfile is not None:
108
- print(
109
- log_string,
110
- file=logfile,
111
- flush=True,
112
- )
113
- else:
114
- # No text, so no output
115
- pass
116
-
117
- if args.offline: ## offline mode processing (for testing/debugging)
118
- a = load_audio(audio_path)
119
- online.insert_audio_chunk(a)
120
- try:
121
- o = online.process_iter()
122
- except AssertionError as e:
123
- logger.error(f"assertion error: {repr(e)}")
124
- else:
125
- output_transcript(o)
126
- now = None
127
- elif args.comp_unaware: # computational unaware mode
128
- end = beg + min_chunk
129
- while True:
130
- a = load_audio_chunk(audio_path, beg, end)
131
- online.insert_audio_chunk(a)
132
- try:
133
- o = online.process_iter()
134
- except AssertionError as e:
135
- logger.error(f"assertion error: {repr(e)}")
136
- pass
137
- else:
138
- output_transcript(o, now=end)
139
-
140
- logger.debug(f"## last processed {end:.2f}s")
141
-
142
- if end >= duration:
143
- break
144
-
145
- beg = end
146
-
147
- if end + min_chunk > duration:
148
- end = duration
149
- else:
150
- end += min_chunk
151
- now = duration
152
-
153
- else: # online = simultaneous mode
154
- end = 0
155
- while True:
156
- now = time.time() - start
157
- if now < end + min_chunk:
158
- time.sleep(min_chunk + end - now)
159
- end = time.time() - start
160
- a = load_audio_chunk(audio_path, beg, end)
161
- beg = end
162
- online.insert_audio_chunk(a)
163
-
164
- try:
165
- o = online.process_iter()
166
- except AssertionError as e:
167
- logger.error(f"assertion error: {e}")
168
- pass
169
- else:
170
- output_transcript(o)
171
- now = time.time() - start
172
- logger.debug(
173
- f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}"
174
- )
175
-
176
- if end >= duration:
177
- break
178
- now = None
179
-
180
- o = online.finish()
181
- output_transcript(o, now=now)