qfuxa commited on
Commit
69c754e
·
1 Parent(s): 566619b

Refactor import statement for AudioProcessor and update cleanup method to be awaited; remove unused formatters and state management files

Browse files
audio.py → audio_processor.py RENAMED
@@ -2,24 +2,32 @@ import asyncio
2
  import numpy as np
3
  import ffmpeg
4
  from time import time, sleep
5
-
6
-
7
- from whisper_streaming_custom.whisper_online import online_factory
8
  import math
9
  import logging
10
  import traceback
11
- from state import SharedState
12
- from formatters import format_time
13
-
 
14
 
 
15
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
16
- logging.getLogger().setLevel(logging.WARNING)
17
  logger = logging.getLogger(__name__)
18
  logger.setLevel(logging.DEBUG)
19
 
 
 
 
 
20
  class AudioProcessor:
 
 
 
 
21
 
22
  def __init__(self, args, asr, tokenizer):
 
 
23
  self.args = args
24
  self.sample_rate = 16000
25
  self.channels = 1
@@ -28,106 +36,165 @@ class AudioProcessor:
28
  self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
29
  self.max_bytes_per_sec = 32000 * 5 # 5 seconds of audio at 32 kHz
30
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- self.shared_state = SharedState()
33
  self.asr = asr
34
  self.tokenizer = tokenizer
35
-
36
  self.ffmpeg_process = self.start_ffmpeg_decoder()
37
-
38
- self.transcription_queue = asyncio.Queue() if self.args.transcription else None
39
- self.diarization_queue = asyncio.Queue() if self.args.diarization else None
40
-
41
  self.pcm_buffer = bytearray()
42
- if self.args.transcription:
43
- self.online = online_factory(self.args, self.asr, self.tokenizer)
44
-
45
 
 
 
 
46
 
47
  def convert_pcm_to_float(self, pcm_buffer):
48
- """
49
- Converts a PCM buffer in s16le format to a normalized NumPy array.
50
- Arg: pcm_buffer. PCM buffer containing raw audio data in s16le format
51
- Returns: np.ndarray. NumPy array of float32 type normalized between -1.0 and 1.0
52
- """
53
- pcm_array = (np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32)
54
- / 32768.0)
55
- return pcm_array
56
 
57
  def start_ffmpeg_decoder(self):
58
- """
59
- Start an FFmpeg process in async streaming mode that reads WebM from stdin
60
- and outputs raw s16le PCM on stdout. Returns the process object.
61
- """
62
- process = (
63
- ffmpeg.input("pipe:0", format="webm")
64
- .output(
65
- "pipe:1",
66
- format="s16le",
67
- acodec="pcm_s16le",
68
- ac=self.channels,
69
- ar=str(self.sample_rate),
70
- )
71
- .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True)
72
- )
73
- return process
74
 
75
  async def restart_ffmpeg(self):
 
76
  if self.ffmpeg_process:
77
  try:
78
  self.ffmpeg_process.kill()
79
  await asyncio.get_event_loop().run_in_executor(None, self.ffmpeg_process.wait)
80
  except Exception as e:
81
  logger.warning(f"Error killing FFmpeg process: {e}")
82
- self.ffmpeg_process = self.start_ffmpeg_decoder()
83
  self.pcm_buffer = bytearray()
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  async def ffmpeg_stdout_reader(self):
 
86
  loop = asyncio.get_event_loop()
87
  beg = time()
88
 
89
  while True:
90
  try:
91
- elapsed_time = math.floor((time() - beg) * 10) / 10 # Round to 0.1 sec
92
- ffmpeg_buffer_from_duration = max(int(32000 * elapsed_time), 4096)
 
93
  beg = time()
94
 
95
  # Read chunk with timeout
96
  try:
97
  chunk = await asyncio.wait_for(
98
- loop.run_in_executor(
99
- None, self.ffmpeg_process.stdout.read, ffmpeg_buffer_from_duration
100
- ),
101
  timeout=15.0
102
  )
103
  except asyncio.TimeoutError:
104
  logger.warning("FFmpeg read timeout. Restarting...")
105
  await self.restart_ffmpeg()
106
  beg = time()
107
- continue # Skip processing and read from new process
108
 
109
  if not chunk:
110
  logger.info("FFmpeg stdout closed.")
111
  break
 
112
  self.pcm_buffer.extend(chunk)
113
 
 
114
  if self.args.diarization and self.diarization_queue:
115
- await self.diarization_queue.put(self.convert_pcm_to_float(self.pcm_buffer).copy())
 
 
116
 
 
117
  if len(self.pcm_buffer) >= self.bytes_per_sec:
118
  if len(self.pcm_buffer) > self.max_bytes_per_sec:
119
  logger.warning(
120
- f"""Audio buffer is too large: {len(self.pcm_buffer) / self.bytes_per_sec:.2f} seconds.
121
- The model probably struggles to keep up. Consider using a smaller model.
122
- """)
123
 
 
124
  pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec])
125
  self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:]
126
 
 
127
  if self.args.transcription and self.transcription_queue:
128
  await self.transcription_queue.put(pcm_array.copy())
129
 
130
-
131
  if not self.args.transcription and not self.args.diarization:
132
  await asyncio.sleep(0.1)
133
 
@@ -135,34 +202,39 @@ class AudioProcessor:
135
  logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
136
  logger.warning(f"Traceback: {traceback.format_exc()}")
137
  break
138
- logger.info("Exiting ffmpeg_stdout_reader...")
139
 
140
  async def transcription_processor(self):
141
- full_transcription = ""
142
- sep = self.online.asr.sep
 
143
 
144
  while True:
145
  try:
146
  pcm_array = await self.transcription_queue.get()
147
 
148
- logger.info(f"{len(self.online.audio_buffer) / self.online.SAMPLING_RATE} seconds of audio will be processed by the model.")
149
 
150
  # Process transcription
151
  self.online.insert_audio_chunk(pcm_array)
152
  new_tokens = self.online.process_iter()
153
 
154
  if new_tokens:
155
- full_transcription += sep.join([t.text for t in new_tokens])
156
 
 
157
  _buffer = self.online.get_buffer()
158
  buffer = _buffer.text
159
- end_buffer = _buffer.end if _buffer.end else (new_tokens[-1].end if new_tokens else 0)
 
 
160
 
161
- if buffer in full_transcription:
 
162
  buffer = ""
163
 
164
- await self.shared_state.update_transcription(
165
- new_tokens, buffer, end_buffer, full_transcription, sep)
 
166
 
167
  except Exception as e:
168
  logger.warning(f"Exception in transcription_processor: {e}")
@@ -170,8 +242,8 @@ class AudioProcessor:
170
  finally:
171
  self.transcription_queue.task_done()
172
 
173
-
174
  async def diarization_processor(self, diarization_obj):
 
175
  buffer_diarization = ""
176
 
177
  while True:
@@ -181,16 +253,13 @@ class AudioProcessor:
181
  # Process diarization
182
  await diarization_obj.diarize(pcm_array)
183
 
184
- # Get current state
185
- state = await self.shared_state.get_current_state()
186
- tokens = state["tokens"]
187
- end_attributed_speaker = state["end_attributed_speaker"]
188
-
189
- # Update speaker information
190
- new_end_attributed_speaker = diarization_obj.assign_speakers_to_tokens(
191
- end_attributed_speaker, tokens)
192
 
193
- await self.shared_state.update_diarization(new_end_attributed_speaker, buffer_diarization)
194
 
195
  except Exception as e:
196
  logger.warning(f"Exception in diarization_processor: {e}")
@@ -199,94 +268,94 @@ class AudioProcessor:
199
  self.diarization_queue.task_done()
200
 
201
  async def results_formatter(self):
 
202
  while True:
203
  try:
204
- state = await self.shared_state.get_current_state()
 
205
  tokens = state["tokens"]
206
  buffer_transcription = state["buffer_transcription"]
207
  buffer_diarization = state["buffer_diarization"]
208
  end_attributed_speaker = state["end_attributed_speaker"]
209
- remaining_time_transcription = state["remaining_time_transcription"]
210
- remaining_time_diarization = state["remaining_time_diarization"]
211
  sep = state["sep"]
212
 
213
- # If diarization is enabled but no transcription, add dummy tokens periodically
214
  if (not tokens or tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization:
215
- await self.shared_state.add_dummy_token()
216
  sleep(0.5)
217
- state = await self.shared_state.get_current_state()
218
  tokens = state["tokens"]
 
 
219
  previous_speaker = -1
220
  lines = []
221
  last_end_diarized = 0
222
  undiarized_text = []
223
 
 
224
  for token in tokens:
225
  speaker = token.speaker
 
 
226
  if self.args.diarization:
227
- if (speaker == -1 or speaker == 0) and token.end >= end_attributed_speaker:
228
  undiarized_text.append(token.text)
229
  continue
230
- elif (speaker == -1 or speaker == 0) and token.end < end_attributed_speaker:
231
  speaker = previous_speaker
232
  if speaker not in [-1, 0]:
233
  last_end_diarized = max(token.end, last_end_diarized)
234
 
 
235
  if speaker != previous_speaker or not lines:
236
- lines.append(
237
- {
238
- "speaker": speaker,
239
- "text": token.text,
240
- "beg": format_time(token.start),
241
- "end": format_time(token.end),
242
- "diff": round(token.end - last_end_diarized, 2)
243
- }
244
- )
245
  previous_speaker = speaker
246
  elif token.text: # Only append if text isn't empty
247
  lines[-1]["text"] += sep + token.text
248
  lines[-1]["end"] = format_time(token.end)
249
  lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
250
 
 
251
  if undiarized_text:
252
- combined_buffer_diarization = sep.join(undiarized_text)
253
  if buffer_transcription:
254
- combined_buffer_diarization += sep
255
- await self.shared_state.update_diarization(end_attributed_speaker, combined_buffer_diarization)
256
- buffer_diarization = combined_buffer_diarization
257
-
258
- if lines:
259
- response = {
260
- "lines": lines,
261
- "buffer_transcription": buffer_transcription,
262
- "buffer_diarization": buffer_diarization,
263
- "remaining_time_transcription": remaining_time_transcription,
264
- "remaining_time_diarization": remaining_time_diarization
265
- }
266
- else:
267
- response = {
268
- "lines": [{
269
- "speaker": 1,
270
- "text": "",
271
- "beg": format_time(0),
272
- "end": format_time(tokens[-1].end) if tokens else format_time(0),
273
- "diff": 0
274
- }],
275
- "buffer_transcription": buffer_transcription,
276
- "buffer_diarization": buffer_diarization,
277
- "remaining_time_transcription": remaining_time_transcription,
278
- "remaining_time_diarization": remaining_time_diarization
279
- }
280
 
281
- response_content = ' '.join([str(line['speaker']) + ' ' + line["text"] for line in lines]) + ' | ' + buffer_transcription + ' | ' + buffer_diarization
 
 
282
 
283
- if response_content != self.shared_state.last_response_content:
284
- if lines or buffer_transcription or buffer_diarization:
285
- yield response
286
- self.shared_state.last_response_content = response_content
287
 
288
- #small delay to avoid overwhelming the client
289
- await asyncio.sleep(0.1)
290
 
291
  except Exception as e:
292
  logger.warning(f"Exception in results_formatter: {e}")
@@ -294,35 +363,39 @@ class AudioProcessor:
294
  await asyncio.sleep(0.5) # Back off on error
295
 
296
  async def create_tasks(self, diarization=None):
 
297
  if diarization:
298
  self.diarization = diarization
299
 
300
  tasks = []
301
  if self.args.transcription and self.online:
302
- tasks.append(asyncio.create_task(self.transcription_processor()))
 
303
  if self.args.diarization and self.diarization:
304
  tasks.append(asyncio.create_task(self.diarization_processor(self.diarization)))
305
 
306
- stdout_reader_task = asyncio.create_task(self.ffmpeg_stdout_reader())
307
- tasks.append(stdout_reader_task)
308
-
309
  self.tasks = tasks
310
 
311
  return self.results_formatter()
312
 
313
  async def cleanup(self):
 
314
  for task in self.tasks:
315
  task.cancel()
 
316
  try:
317
  await asyncio.gather(*self.tasks, return_exceptions=True)
318
  self.ffmpeg_process.stdin.close()
319
  self.ffmpeg_process.wait()
320
  except Exception as e:
321
  logger.warning(f"Error during cleanup: {e}")
322
- if self.args.diarization and self.diarization:
 
323
  self.diarization.close()
324
 
325
  async def process_audio(self, message):
 
326
  try:
327
  self.ffmpeg_process.stdin.write(message)
328
  self.ffmpeg_process.stdin.flush()
@@ -330,6 +403,4 @@ class AudioProcessor:
330
  logger.warning(f"Error writing to FFmpeg: {e}. Restarting...")
331
  await self.restart_ffmpeg()
332
  self.ffmpeg_process.stdin.write(message)
333
- self.ffmpeg_process.stdin.flush()
334
-
335
-
 
2
  import numpy as np
3
  import ffmpeg
4
  from time import time, sleep
 
 
 
5
  import math
6
  import logging
7
  import traceback
8
+ from datetime import timedelta
9
+ from typing import List, Dict, Any
10
+ from timed_objects import ASRToken
11
+ from whisper_streaming_custom.whisper_online import online_factory
12
 
13
+ # Set up logging once
14
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 
15
  logger = logging.getLogger(__name__)
16
  logger.setLevel(logging.DEBUG)
17
 
18
+ def format_time(seconds: float) -> str:
19
+ """Format seconds as HH:MM:SS."""
20
+ return str(timedelta(seconds=int(seconds)))
21
+
22
  class AudioProcessor:
23
+ """
24
+ Processes audio streams for transcription and diarization.
25
+ Handles audio processing, state management, and result formatting in a single class.
26
+ """
27
 
28
  def __init__(self, args, asr, tokenizer):
29
+ """Initialize the audio processor with configuration, models, and state."""
30
+ # Audio processing settings
31
  self.args = args
32
  self.sample_rate = 16000
33
  self.channels = 1
 
36
  self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
37
  self.max_bytes_per_sec = 32000 * 5 # 5 seconds of audio at 32 kHz
38
 
39
+ # State management
40
+ self.tokens = []
41
+ self.buffer_transcription = ""
42
+ self.buffer_diarization = ""
43
+ self.full_transcription = ""
44
+ self.end_buffer = 0
45
+ self.end_attributed_speaker = 0
46
+ self.lock = asyncio.Lock()
47
+ self.beg_loop = time()
48
+ self.sep = " " # Default separator
49
+ self.last_response_content = ""
50
 
51
+ # Models and processing
52
  self.asr = asr
53
  self.tokenizer = tokenizer
 
54
  self.ffmpeg_process = self.start_ffmpeg_decoder()
55
+ self.transcription_queue = asyncio.Queue() if args.transcription else None
56
+ self.diarization_queue = asyncio.Queue() if args.diarization else None
 
 
57
  self.pcm_buffer = bytearray()
 
 
 
58
 
59
+ # Initialize transcription engine if enabled
60
+ if args.transcription:
61
+ self.online = online_factory(args, asr, tokenizer)
62
 
63
  def convert_pcm_to_float(self, pcm_buffer):
64
+ """Convert PCM buffer in s16le format to normalized NumPy array."""
65
+ return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
 
 
 
 
 
 
66
 
67
  def start_ffmpeg_decoder(self):
68
+ """Start FFmpeg process for WebM to PCM conversion."""
69
+ return (ffmpeg.input("pipe:0", format="webm")
70
+ .output("pipe:1", format="s16le", acodec="pcm_s16le",
71
+ ac=self.channels, ar=str(self.sample_rate))
72
+ .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True))
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  async def restart_ffmpeg(self):
75
+ """Restart the FFmpeg process after failure."""
76
  if self.ffmpeg_process:
77
  try:
78
  self.ffmpeg_process.kill()
79
  await asyncio.get_event_loop().run_in_executor(None, self.ffmpeg_process.wait)
80
  except Exception as e:
81
  logger.warning(f"Error killing FFmpeg process: {e}")
82
+ self.ffmpeg_process = self.start_ffmpeg_decoder()
83
  self.pcm_buffer = bytearray()
84
 
85
+ async def update_transcription(self, new_tokens, buffer, end_buffer, full_transcription, sep):
86
+ """Thread-safe update of transcription with new data."""
87
+ async with self.lock:
88
+ self.tokens.extend(new_tokens)
89
+ self.buffer_transcription = buffer
90
+ self.end_buffer = end_buffer
91
+ self.full_transcription = full_transcription
92
+ self.sep = sep
93
+
94
+ async def update_diarization(self, end_attributed_speaker, buffer_diarization=""):
95
+ """Thread-safe update of diarization with new data."""
96
+ async with self.lock:
97
+ self.end_attributed_speaker = end_attributed_speaker
98
+ if buffer_diarization:
99
+ self.buffer_diarization = buffer_diarization
100
+
101
+ async def add_dummy_token(self):
102
+ """Placeholder token when no transcription is available."""
103
+ async with self.lock:
104
+ current_time = time() - self.beg_loop
105
+ self.tokens.append(ASRToken(
106
+ start=current_time, end=current_time + 1,
107
+ text=".", speaker=-1, is_dummy=True
108
+ ))
109
+
110
+ async def get_current_state(self):
111
+ """Get current state."""
112
+ async with self.lock:
113
+ current_time = time()
114
+
115
+ # Calculate remaining times
116
+ remaining_transcription = 0
117
+ if self.end_buffer > 0:
118
+ remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 2))
119
+
120
+ remaining_diarization = 0
121
+ if self.tokens:
122
+ latest_end = max(self.end_buffer, self.tokens[-1].end if self.tokens else 0)
123
+ remaining_diarization = max(0, round(latest_end - self.end_attributed_speaker, 2))
124
+
125
+ return {
126
+ "tokens": self.tokens.copy(),
127
+ "buffer_transcription": self.buffer_transcription,
128
+ "buffer_diarization": self.buffer_diarization,
129
+ "end_buffer": self.end_buffer,
130
+ "end_attributed_speaker": self.end_attributed_speaker,
131
+ "sep": self.sep,
132
+ "remaining_time_transcription": remaining_transcription,
133
+ "remaining_time_diarization": remaining_diarization
134
+ }
135
+
136
+ async def reset(self):
137
+ """Reset all state variables to initial values."""
138
+ async with self.lock:
139
+ self.tokens = []
140
+ self.buffer_transcription = self.buffer_diarization = ""
141
+ self.end_buffer = self.end_attributed_speaker = 0
142
+ self.full_transcription = self.last_response_content = ""
143
+ self.beg_loop = time()
144
+
145
  async def ffmpeg_stdout_reader(self):
146
+ """Read audio data from FFmpeg stdout and process it."""
147
  loop = asyncio.get_event_loop()
148
  beg = time()
149
 
150
  while True:
151
  try:
152
+ # Calculate buffer size based on elapsed time
153
+ elapsed_time = math.floor((time() - beg) * 10) / 10 # Round to 0.1 sec
154
+ buffer_size = max(int(32000 * elapsed_time), 4096)
155
  beg = time()
156
 
157
  # Read chunk with timeout
158
  try:
159
  chunk = await asyncio.wait_for(
160
+ loop.run_in_executor(None, self.ffmpeg_process.stdout.read, buffer_size),
 
 
161
  timeout=15.0
162
  )
163
  except asyncio.TimeoutError:
164
  logger.warning("FFmpeg read timeout. Restarting...")
165
  await self.restart_ffmpeg()
166
  beg = time()
167
+ continue
168
 
169
  if not chunk:
170
  logger.info("FFmpeg stdout closed.")
171
  break
172
+
173
  self.pcm_buffer.extend(chunk)
174
 
175
+ # Send to diarization if enabled
176
  if self.args.diarization and self.diarization_queue:
177
+ await self.diarization_queue.put(
178
+ self.convert_pcm_to_float(self.pcm_buffer).copy()
179
+ )
180
 
181
+ # Process when we have enough data
182
  if len(self.pcm_buffer) >= self.bytes_per_sec:
183
  if len(self.pcm_buffer) > self.max_bytes_per_sec:
184
  logger.warning(
185
+ f"Audio buffer too large: {len(self.pcm_buffer) / self.bytes_per_sec:.2f}s. "
186
+ f"Consider using a smaller model."
187
+ )
188
 
189
+ # Process audio chunk
190
  pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec])
191
  self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:]
192
 
193
+ # Send to transcription if enabled
194
  if self.args.transcription and self.transcription_queue:
195
  await self.transcription_queue.put(pcm_array.copy())
196
 
197
+ # Sleep if no processing is happening
198
  if not self.args.transcription and not self.args.diarization:
199
  await asyncio.sleep(0.1)
200
 
 
202
  logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
203
  logger.warning(f"Traceback: {traceback.format_exc()}")
204
  break
 
205
 
206
  async def transcription_processor(self):
207
+ """Process audio chunks for transcription."""
208
+ self.full_transcription = ""
209
+ self.sep = self.online.asr.sep
210
 
211
  while True:
212
  try:
213
  pcm_array = await self.transcription_queue.get()
214
 
215
+ logger.info(f"{len(self.online.audio_buffer) / self.online.SAMPLING_RATE} seconds of audio to process.")
216
 
217
  # Process transcription
218
  self.online.insert_audio_chunk(pcm_array)
219
  new_tokens = self.online.process_iter()
220
 
221
  if new_tokens:
222
+ self.full_transcription += self.sep.join([t.text for t in new_tokens])
223
 
224
+ # Get buffer information
225
  _buffer = self.online.get_buffer()
226
  buffer = _buffer.text
227
+ end_buffer = _buffer.end if _buffer.end else (
228
+ new_tokens[-1].end if new_tokens else 0
229
+ )
230
 
231
+ # Avoid duplicating content
232
+ if buffer in self.full_transcription:
233
  buffer = ""
234
 
235
+ await self.update_transcription(
236
+ new_tokens, buffer, end_buffer, self.full_transcription, self.sep
237
+ )
238
 
239
  except Exception as e:
240
  logger.warning(f"Exception in transcription_processor: {e}")
 
242
  finally:
243
  self.transcription_queue.task_done()
244
 
 
245
  async def diarization_processor(self, diarization_obj):
246
+ """Process audio chunks for speaker diarization."""
247
  buffer_diarization = ""
248
 
249
  while True:
 
253
  # Process diarization
254
  await diarization_obj.diarize(pcm_array)
255
 
256
+ # Get current state and update speakers
257
+ state = await self.get_current_state()
258
+ new_end = diarization_obj.assign_speakers_to_tokens(
259
+ state["end_attributed_speaker"], state["tokens"]
260
+ )
 
 
 
261
 
262
+ await self.update_diarization(new_end, buffer_diarization)
263
 
264
  except Exception as e:
265
  logger.warning(f"Exception in diarization_processor: {e}")
 
268
  self.diarization_queue.task_done()
269
 
270
  async def results_formatter(self):
271
+ """Format processing results for output."""
272
  while True:
273
  try:
274
+ # Get current state
275
+ state = await self.get_current_state()
276
  tokens = state["tokens"]
277
  buffer_transcription = state["buffer_transcription"]
278
  buffer_diarization = state["buffer_diarization"]
279
  end_attributed_speaker = state["end_attributed_speaker"]
 
 
280
  sep = state["sep"]
281
 
282
+ # Add dummy tokens if needed
283
  if (not tokens or tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization:
284
+ await self.add_dummy_token()
285
  sleep(0.5)
286
+ state = await self.get_current_state()
287
  tokens = state["tokens"]
288
+
289
+ # Format output
290
  previous_speaker = -1
291
  lines = []
292
  last_end_diarized = 0
293
  undiarized_text = []
294
 
295
+ # Process each token
296
  for token in tokens:
297
  speaker = token.speaker
298
+
299
+ # Handle diarization
300
  if self.args.diarization:
301
+ if (speaker in [-1, 0]) and token.end >= end_attributed_speaker:
302
  undiarized_text.append(token.text)
303
  continue
304
+ elif (speaker in [-1, 0]) and token.end < end_attributed_speaker:
305
  speaker = previous_speaker
306
  if speaker not in [-1, 0]:
307
  last_end_diarized = max(token.end, last_end_diarized)
308
 
309
+ # Group by speaker
310
  if speaker != previous_speaker or not lines:
311
+ lines.append({
312
+ "speaker": speaker,
313
+ "text": token.text,
314
+ "beg": format_time(token.start),
315
+ "end": format_time(token.end),
316
+ "diff": round(token.end - last_end_diarized, 2)
317
+ })
 
 
318
  previous_speaker = speaker
319
  elif token.text: # Only append if text isn't empty
320
  lines[-1]["text"] += sep + token.text
321
  lines[-1]["end"] = format_time(token.end)
322
  lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
323
 
324
+ # Handle undiarized text
325
  if undiarized_text:
326
+ combined = sep.join(undiarized_text)
327
  if buffer_transcription:
328
+ combined += sep
329
+ await self.update_diarization(end_attributed_speaker, combined)
330
+ buffer_diarization = combined
331
+
332
+ # Create response object
333
+ if not lines:
334
+ lines = [{
335
+ "speaker": 1,
336
+ "text": "",
337
+ "beg": format_time(0),
338
+ "end": format_time(tokens[-1].end if tokens else 0),
339
+ "diff": 0
340
+ }]
341
+
342
+ response = {
343
+ "lines": lines,
344
+ "buffer_transcription": buffer_transcription,
345
+ "buffer_diarization": buffer_diarization,
346
+ "remaining_time_transcription": state["remaining_time_transcription"],
347
+ "remaining_time_diarization": state["remaining_time_diarization"]
348
+ }
 
 
 
 
 
349
 
350
+ # Only yield if content has changed
351
+ response_content = ' '.join([f"{line['speaker']} {line['text']}" for line in lines]) + \
352
+ f" | {buffer_transcription} | {buffer_diarization}"
353
 
354
+ if response_content != self.last_response_content and (lines or buffer_transcription or buffer_diarization):
355
+ yield response
356
+ self.last_response_content = response_content
 
357
 
358
+ await asyncio.sleep(0.1) # Avoid overwhelming the client
 
359
 
360
  except Exception as e:
361
  logger.warning(f"Exception in results_formatter: {e}")
 
363
  await asyncio.sleep(0.5) # Back off on error
364
 
365
  async def create_tasks(self, diarization=None):
366
+ """Create and start processing tasks."""
367
  if diarization:
368
  self.diarization = diarization
369
 
370
  tasks = []
371
  if self.args.transcription and self.online:
372
+ tasks.append(asyncio.create_task(self.transcription_processor()))
373
+
374
  if self.args.diarization and self.diarization:
375
  tasks.append(asyncio.create_task(self.diarization_processor(self.diarization)))
376
 
377
+ tasks.append(asyncio.create_task(self.ffmpeg_stdout_reader()))
 
 
378
  self.tasks = tasks
379
 
380
  return self.results_formatter()
381
 
382
  async def cleanup(self):
383
+ """Clean up resources when processing is complete."""
384
  for task in self.tasks:
385
  task.cancel()
386
+
387
  try:
388
  await asyncio.gather(*self.tasks, return_exceptions=True)
389
  self.ffmpeg_process.stdin.close()
390
  self.ffmpeg_process.wait()
391
  except Exception as e:
392
  logger.warning(f"Error during cleanup: {e}")
393
+
394
+ if self.args.diarization and hasattr(self, 'diarization'):
395
  self.diarization.close()
396
 
397
  async def process_audio(self, message):
398
+ """Process incoming audio data."""
399
  try:
400
  self.ffmpeg_process.stdin.write(message)
401
  self.ffmpeg_process.stdin.flush()
 
403
  logger.warning(f"Error writing to FFmpeg: {e}. Restarting...")
404
  await self.restart_ffmpeg()
405
  self.ffmpeg_process.stdin.write(message)
406
+ self.ffmpeg_process.stdin.flush()
 
 
formatters.py DELETED
@@ -1,91 +0,0 @@
1
- from typing import Dict, Any, List
2
- from datetime import timedelta
3
-
4
- def format_time(seconds: float) -> str:
5
- """Format seconds as HH:MM:SS."""
6
- return str(timedelta(seconds=int(seconds)))
7
-
8
- def format_response(state: Dict[str, Any], with_diarization: bool = False) -> Dict[str, Any]:
9
- """
10
- Format the shared state into a client-friendly response.
11
-
12
- Args:
13
- state: Current shared state dictionary
14
- with_diarization: Whether to include diarization formatting
15
-
16
- Returns:
17
- Formatted response dictionary ready to send to client
18
- """
19
- tokens = state["tokens"]
20
- buffer_transcription = state["buffer_transcription"]
21
- buffer_diarization = state["buffer_diarization"]
22
- end_attributed_speaker = state["end_attributed_speaker"]
23
- remaining_time_transcription = state["remaining_time_transcription"]
24
- remaining_time_diarization = state["remaining_time_diarization"]
25
- sep = state["sep"]
26
-
27
- # Default response for empty state
28
- if not tokens:
29
- return {
30
- "lines": [{
31
- "speaker": 1,
32
- "text": "",
33
- "beg": format_time(0),
34
- "end": format_time(0),
35
- "diff": 0
36
- }],
37
- "buffer_transcription": buffer_transcription,
38
- "buffer_diarization": buffer_diarization,
39
- "remaining_time_transcription": remaining_time_transcription,
40
- "remaining_time_diarization": remaining_time_diarization
41
- }
42
-
43
- # Process tokens to create response
44
- previous_speaker = -1
45
- lines = []
46
- last_end_diarized = 0
47
- undiarized_text = []
48
-
49
- for token in tokens:
50
- speaker = token.speaker
51
-
52
- # Handle diarization logic
53
- if with_diarization:
54
- if (speaker == -1 or speaker == 0) and token.end >= end_attributed_speaker:
55
- undiarized_text.append(token.text)
56
- continue
57
- elif (speaker == -1 or speaker == 0) and token.end < end_attributed_speaker:
58
- speaker = previous_speaker
59
-
60
- if speaker not in [-1, 0]:
61
- last_end_diarized = max(token.end, last_end_diarized)
62
-
63
- # Add new line or append to existing line
64
- if speaker != previous_speaker or not lines:
65
- lines.append({
66
- "speaker": speaker,
67
- "text": token.text,
68
- "beg": format_time(token.start),
69
- "end": format_time(token.end),
70
- "diff": round(token.end - last_end_diarized, 2)
71
- })
72
- previous_speaker = speaker
73
- elif token.text: # Only append if text isn't empty
74
- lines[-1]["text"] += sep + token.text
75
- lines[-1]["end"] = format_time(token.end)
76
- lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
77
-
78
- # If we have undiarized text, include it in the buffer
79
- if undiarized_text:
80
- combined_buffer = sep.join(undiarized_text)
81
- if buffer_transcription:
82
- combined_buffer += sep + buffer_transcription
83
- buffer_diarization = combined_buffer
84
-
85
- return {
86
- "lines": lines,
87
- "buffer_transcription": buffer_transcription,
88
- "buffer_diarization": buffer_diarization,
89
- "remaining_time_transcription": remaining_time_transcription,
90
- "remaining_time_diarization": remaining_time_diarization
91
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
state.py DELETED
@@ -1,96 +0,0 @@
1
- import asyncio
2
- import logging
3
- from time import time
4
- from typing import List, Dict, Any, Optional
5
- from dataclasses import dataclass, field
6
- from timed_objects import ASRToken
7
-
8
- logger = logging.getLogger(__name__)
9
-
10
-
11
- class SharedState:
12
- """
13
- Thread-safe state manager for streaming transcription and diarization.
14
- Handles coordination between audio processing, transcription, and diarization.
15
- """
16
-
17
- def __init__(self):
18
- self.tokens: List[ASRToken] = []
19
- self.buffer_transcription: str = ""
20
- self.buffer_diarization: str = ""
21
- self.full_transcription: str = ""
22
- self.end_buffer: float = 0
23
- self.end_attributed_speaker: float = 0
24
- self.lock = asyncio.Lock()
25
- self.beg_loop: float = time()
26
- self.sep: str = " " # Default separator
27
- self.last_response_content: str = "" # To track changes in response
28
-
29
- async def update_transcription(self, new_tokens: List[ASRToken], buffer: str,
30
- end_buffer: float, full_transcription: str, sep: str) -> None:
31
- """Update the state with new transcription data."""
32
- async with self.lock:
33
- self.tokens.extend(new_tokens)
34
- self.buffer_transcription = buffer
35
- self.end_buffer = end_buffer
36
- self.full_transcription = full_transcription
37
- self.sep = sep
38
-
39
- async def update_diarization(self, end_attributed_speaker: float, buffer_diarization: str = "") -> None:
40
- """Update the state with new diarization data."""
41
- async with self.lock:
42
- self.end_attributed_speaker = end_attributed_speaker
43
- if buffer_diarization:
44
- self.buffer_diarization = buffer_diarization
45
-
46
- async def add_dummy_token(self) -> None:
47
- """Add a dummy token to keep the state updated even without transcription."""
48
- async with self.lock:
49
- current_time = time() - self.beg_loop
50
- dummy_token = ASRToken(
51
- start=current_time,
52
- end=current_time + 1,
53
- text=".",
54
- speaker=-1,
55
- is_dummy=True
56
- )
57
- self.tokens.append(dummy_token)
58
-
59
- async def get_current_state(self) -> Dict[str, Any]:
60
- """Get the current state with calculated timing information."""
61
- async with self.lock:
62
- current_time = time()
63
- remaining_time_transcription = 0
64
- remaining_time_diarization = 0
65
-
66
- # Calculate remaining time for transcription buffer
67
- if self.end_buffer > 0:
68
- remaining_time_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 2))
69
-
70
- # Calculate remaining time for diarization
71
- if self.tokens:
72
- latest_end = max(self.end_buffer, self.tokens[-1].end if self.tokens else 0)
73
- remaining_time_diarization = max(0, round(latest_end - self.end_attributed_speaker, 2))
74
-
75
- return {
76
- "tokens": self.tokens.copy(),
77
- "buffer_transcription": self.buffer_transcription,
78
- "buffer_diarization": self.buffer_diarization,
79
- "end_buffer": self.end_buffer,
80
- "end_attributed_speaker": self.end_attributed_speaker,
81
- "sep": self.sep,
82
- "remaining_time_transcription": remaining_time_transcription,
83
- "remaining_time_diarization": remaining_time_diarization
84
- }
85
-
86
- async def reset(self) -> None:
87
- """Reset the state to initial values."""
88
- async with self.lock:
89
- self.tokens = []
90
- self.buffer_transcription = ""
91
- self.buffer_diarization = ""
92
- self.end_buffer = 0
93
- self.end_attributed_speaker = 0
94
- self.full_transcription = ""
95
- self.beg_loop = time()
96
- self.last_response_content = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
whisper_fastapi_online_server.py CHANGED
@@ -8,7 +8,7 @@ from whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
8
  import asyncio
9
  import logging
10
  from parse_args import parse_args
11
- from audio import AudioProcessor
12
 
13
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
14
  logging.getLogger().setLevel(logging.WARNING)
@@ -80,7 +80,7 @@ async def websocket_endpoint(websocket: WebSocket):
80
  logger.warning("WebSocket disconnected.")
81
  finally:
82
  websocket_task.cancel()
83
- audio_processor.cleanup()
84
  logger.info("WebSocket endpoint cleaned up.")
85
 
86
  if __name__ == "__main__":
 
8
  import asyncio
9
  import logging
10
  from parse_args import parse_args
11
+ from audio_processor import AudioProcessor
12
 
13
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
14
  logging.getLogger().setLevel(logging.WARNING)
 
80
  logger.warning("WebSocket disconnected.")
81
  finally:
82
  websocket_task.cancel()
83
+ await audio_processor.cleanup()
84
  logger.info("WebSocket endpoint cleaned up.")
85
 
86
  if __name__ == "__main__":