qfuxa commited on
Commit
d920423
·
1 Parent(s): f6b6b4a

add parameter to disable transcription (only diarization), add time in output

Browse files
Files changed (1) hide show
  1. whisper_fastapi_online_server.py +54 -30
whisper_fastapi_online_server.py CHANGED
@@ -3,7 +3,7 @@ import argparse
3
  import asyncio
4
  import numpy as np
5
  import ffmpeg
6
- from time import time
7
  from contextlib import asynccontextmanager
8
 
9
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
@@ -12,9 +12,12 @@ from fastapi.middleware.cors import CORSMiddleware
12
 
13
  from src.whisper_streaming.whisper_online import backend_factory, online_factory, add_shared_args
14
 
15
- import subprocess
16
  import math
17
  import logging
 
 
 
 
18
 
19
 
20
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
@@ -48,6 +51,12 @@ parser.add_argument(
48
  help="Whether to enable speaker diarization.",
49
  )
50
 
 
 
 
 
 
 
51
 
52
  add_shared_args(parser)
53
  args = parser.parse_args()
@@ -68,7 +77,10 @@ if args.diarization:
68
  @asynccontextmanager
69
  async def lifespan(app: FastAPI):
70
  global asr, tokenizer
71
- asr, tokenizer = backend_factory(args)
 
 
 
72
  yield
73
 
74
  app = FastAPI(lifespan=lifespan)
@@ -117,7 +129,7 @@ async def websocket_endpoint(websocket: WebSocket):
117
 
118
  ffmpeg_process = None
119
  pcm_buffer = bytearray()
120
- online = online_factory(args, asr, tokenizer)
121
  diarization = DiartDiarization(SAMPLE_RATE) if args.diarization else None
122
 
123
  async def restart_ffmpeg():
@@ -130,7 +142,7 @@ async def websocket_endpoint(websocket: WebSocket):
130
  logger.warning(f"Error killing FFmpeg process: {e}")
131
  ffmpeg_process = await start_ffmpeg_decoder()
132
  pcm_buffer = bytearray()
133
- online = online_factory(args, asr, tokenizer)
134
  if args.diarization:
135
  diarization = DiartDiarization(SAMPLE_RATE)
136
  logger.info("FFmpeg process started.")
@@ -142,7 +154,7 @@ async def websocket_endpoint(websocket: WebSocket):
142
  loop = asyncio.get_event_loop()
143
  full_transcription = ""
144
  beg = time()
145
-
146
  chunk_history = [] # Will store dicts: {beg, end, text, speaker}
147
 
148
  while True:
@@ -184,45 +196,57 @@ async def websocket_endpoint(websocket: WebSocket):
184
  / 32768.0
185
  )
186
  pcm_buffer = pcm_buffer[MAX_BYTES_PER_SEC:]
187
- logger.info(f"{len(online.audio_buffer) / online.SAMPLING_RATE} seconds of audio will be processed by the model.")
188
- online.insert_audio_chunk(pcm_array)
189
- transcription = online.process_iter()
190
 
191
- if transcription:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  chunk_history.append({
193
- "beg": transcription.start,
194
- "end": transcription.end,
195
- "text": transcription.text,
196
- "speaker": "0"
197
  })
 
 
198
 
199
- full_transcription += transcription.text if transcription else ""
200
- buffer = online.get_buffer()
201
-
202
- if buffer in full_transcription: # With VAC, the buffer is not updated until the next chunk is processed
203
- buffer = ""
204
-
205
- lines = [
206
- {
207
- "speaker": "0",
208
- "text": "",
209
- }
210
- ]
211
-
212
  if args.diarization:
213
  await diarization.diarize(pcm_array)
214
  diarization.assign_speakers_to_chunks(chunk_history)
215
 
 
 
 
 
 
 
 
 
216
  for ch in chunk_history:
217
- if args.diarization and ch["speaker"] and ch["speaker"][-1] != lines[-1]["speaker"]:
 
218
  lines.append(
219
  {
220
- "speaker": ch["speaker"][-1],
221
- "text": ch['text']
 
 
222
  }
223
  )
 
224
  else:
225
  lines[-1]["text"] += ch['text']
 
226
 
227
  response = {"lines": lines, "buffer": buffer}
228
  await websocket.send_json(response)
 
3
  import asyncio
4
  import numpy as np
5
  import ffmpeg
6
+ from time import time, sleep
7
  from contextlib import asynccontextmanager
8
 
9
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 
12
 
13
  from src.whisper_streaming.whisper_online import backend_factory, online_factory, add_shared_args
14
 
 
15
  import math
16
  import logging
17
+ from datetime import timedelta
18
+
19
+ def format_time(seconds):
20
+ return str(timedelta(seconds=int(seconds)))
21
 
22
 
23
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 
51
  help="Whether to enable speaker diarization.",
52
  )
53
 
54
+ parser.add_argument(
55
+ "--transcription",
56
+ type=bool,
57
+ default=True,
58
+ help="To disable to only see live diarization results.",
59
+ )
60
 
61
  add_shared_args(parser)
62
  args = parser.parse_args()
 
77
  @asynccontextmanager
78
  async def lifespan(app: FastAPI):
79
  global asr, tokenizer
80
+ if args.transcription:
81
+ asr, tokenizer = backend_factory(args)
82
+ else:
83
+ asr, tokenizer = None, None
84
  yield
85
 
86
  app = FastAPI(lifespan=lifespan)
 
129
 
130
  ffmpeg_process = None
131
  pcm_buffer = bytearray()
132
+ online = online_factory(args, asr, tokenizer) if args.transcription else None
133
  diarization = DiartDiarization(SAMPLE_RATE) if args.diarization else None
134
 
135
  async def restart_ffmpeg():
 
142
  logger.warning(f"Error killing FFmpeg process: {e}")
143
  ffmpeg_process = await start_ffmpeg_decoder()
144
  pcm_buffer = bytearray()
145
+ online = online_factory(args, asr, tokenizer) if args.transcription else None
146
  if args.diarization:
147
  diarization = DiartDiarization(SAMPLE_RATE)
148
  logger.info("FFmpeg process started.")
 
154
  loop = asyncio.get_event_loop()
155
  full_transcription = ""
156
  beg = time()
157
+ beg_loop = time()
158
  chunk_history = [] # Will store dicts: {beg, end, text, speaker}
159
 
160
  while True:
 
196
  / 32768.0
197
  )
198
  pcm_buffer = pcm_buffer[MAX_BYTES_PER_SEC:]
 
 
 
199
 
200
+ if args.transcription:
201
+ logger.info(f"{len(online.audio_buffer) / online.SAMPLING_RATE} seconds of audio will be processed by the model.")
202
+ online.insert_audio_chunk(pcm_array)
203
+ transcription = online.process_iter()
204
+ if transcription.start:
205
+ chunk_history.append({
206
+ "beg": transcription.start,
207
+ "end": transcription.end,
208
+ "text": transcription.text,
209
+ })
210
+ full_transcription += transcription.text if transcription else ""
211
+ buffer = online.get_buffer()
212
+ if buffer in full_transcription: # With VAC, the buffer is not updated until the next chunk is processed
213
+ buffer = ""
214
+ else:
215
  chunk_history.append({
216
+ "beg": time() - beg_loop,
217
+ "end": time() - beg_loop + 0.1,
218
+ "text": '',
 
219
  })
220
+ sleep(0.1)
221
+ buffer = ''
222
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  if args.diarization:
224
  await diarization.diarize(pcm_array)
225
  diarization.assign_speakers_to_chunks(chunk_history)
226
 
227
+
228
+ current_speaker = -1
229
+ lines = [{
230
+ "beg": 0,
231
+ "end": 0,
232
+ "speaker": current_speaker,
233
+ "text": ""
234
+ }]
235
  for ch in chunk_history:
236
+ if args.diarization and ch["speaker"] and ch["speaker"] != current_speaker:
237
+ new_speaker = ch["speaker"]
238
  lines.append(
239
  {
240
+ "speaker": new_speaker,
241
+ "text": ch['text'],
242
+ "beg": format_time(ch['beg']),
243
+ "end": format_time(ch['end']),
244
  }
245
  )
246
+ current_speaker = new_speaker
247
  else:
248
  lines[-1]["text"] += ch['text']
249
+ lines[-1]["end"] = format_time(ch['end'])
250
 
251
  response = {"lines": lines, "buffer": buffer}
252
  await websocket.send_json(response)