qfuxa committed on
Commit
4e2c533
·
2 Parent(s): 0912e2c 5f66658

Merge pull request #53 from QuentinFuxa/diart_integration_improvements

Browse files
Files changed (1) hide show
  1. whisper_fastapi_online_server.py +10 -8
whisper_fastapi_online_server.py CHANGED
@@ -68,19 +68,23 @@ BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample
68
  BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE
69
  MAX_BYTES_PER_SEC = 32000 * 5 # 5 seconds of audio at 32 kHz
70
 
71
- if args.diarization:
72
- from src.diarization.diarization_online import DiartDiarization
73
 
74
 
75
  ##### LOAD APP #####
76
 
77
  @asynccontextmanager
78
  async def lifespan(app: FastAPI):
79
- global asr, tokenizer
80
  if args.transcription:
81
  asr, tokenizer = backend_factory(args)
82
  else:
83
  asr, tokenizer = None, None
 
 
 
 
 
 
84
  yield
85
 
86
  app = FastAPI(lifespan=lifespan)
@@ -130,10 +134,10 @@ async def websocket_endpoint(websocket: WebSocket):
130
  ffmpeg_process = None
131
  pcm_buffer = bytearray()
132
  online = online_factory(args, asr, tokenizer) if args.transcription else None
133
- diarization = DiartDiarization(SAMPLE_RATE) if args.diarization else None
134
 
135
  async def restart_ffmpeg():
136
- nonlocal ffmpeg_process, online, diarization, pcm_buffer
137
  if ffmpeg_process:
138
  try:
139
  ffmpeg_process.kill()
@@ -143,14 +147,12 @@ async def websocket_endpoint(websocket: WebSocket):
143
  ffmpeg_process = await start_ffmpeg_decoder()
144
  pcm_buffer = bytearray()
145
  online = online_factory(args, asr, tokenizer) if args.transcription else None
146
- if args.diarization:
147
- diarization = DiartDiarization(SAMPLE_RATE)
148
  logger.info("FFmpeg process started.")
149
 
150
  await restart_ffmpeg()
151
 
152
  async def ffmpeg_stdout_reader():
153
- nonlocal ffmpeg_process, online, diarization, pcm_buffer
154
  loop = asyncio.get_event_loop()
155
  full_transcription = ""
156
  beg = time()
 
68
  BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE
69
  MAX_BYTES_PER_SEC = 32000 * 5 # 5 seconds of audio at 32 kHz
70
 
 
 
71
 
72
 
73
  ##### LOAD APP #####
74
 
75
  @asynccontextmanager
76
  async def lifespan(app: FastAPI):
77
+ global asr, tokenizer, diarization
78
  if args.transcription:
79
  asr, tokenizer = backend_factory(args)
80
  else:
81
  asr, tokenizer = None, None
82
+
83
+ if args.diarization:
84
+ from src.diarization.diarization_online import DiartDiarization
85
+ diarization = DiartDiarization(SAMPLE_RATE)
86
+ else :
87
+ diarization = None
88
  yield
89
 
90
  app = FastAPI(lifespan=lifespan)
 
134
  ffmpeg_process = None
135
  pcm_buffer = bytearray()
136
  online = online_factory(args, asr, tokenizer) if args.transcription else None
137
+
138
 
139
  async def restart_ffmpeg():
140
+ nonlocal ffmpeg_process, online, pcm_buffer
141
  if ffmpeg_process:
142
  try:
143
  ffmpeg_process.kill()
 
147
  ffmpeg_process = await start_ffmpeg_decoder()
148
  pcm_buffer = bytearray()
149
  online = online_factory(args, asr, tokenizer) if args.transcription else None
 
 
150
  logger.info("FFmpeg process started.")
151
 
152
  await restart_ffmpeg()
153
 
154
  async def ffmpeg_stdout_reader():
155
+ nonlocal ffmpeg_process, online, pcm_buffer
156
  loop = asyncio.get_event_loop()
157
  full_transcription = ""
158
  beg = time()