SilasKieser committed on
Commit
fd90ec3
·
1 Parent(s): a163c7b
whisper_fastapi_online_server.py CHANGED
@@ -10,7 +10,7 @@ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
10
  from fastapi.responses import HTMLResponse
11
  from fastapi.middleware.cors import CORSMiddleware
12
 
13
- from whisper_streaming_custom.whisper_online import backend_factory, online_factory, add_shared_args
14
  from timed_objects import ASRToken
15
 
16
  import math
@@ -160,6 +160,7 @@ async def lifespan(app: FastAPI):
160
  global asr, tokenizer, diarization
161
  if args.transcription:
162
  asr, tokenizer = backend_factory(args)
 
163
  else:
164
  asr, tokenizer = None, None
165
 
 
10
  from fastapi.responses import HTMLResponse
11
  from fastapi.middleware.cors import CORSMiddleware
12
 
13
+ from whisper_streaming_custom.whisper_online import backend_factory, online_factory, add_shared_args,warmup_asr
14
  from timed_objects import ASRToken
15
 
16
  import math
 
160
  global asr, tokenizer, diarization
161
  if args.transcription:
162
  asr, tokenizer = backend_factory(args)
163
+ warmup_asr(asr, args.warmup_file)
164
  else:
165
  asr, tokenizer = None, None
166
 
whisper_streaming_custom/whisper_online.py CHANGED
@@ -227,11 +227,34 @@ def asr_factory(args, logfile=sys.stderr):
227
  online = online_factory(args, asr, tokenizer, logfile=logfile)
228
  return asr, online
229
 
230
def set_logging(args, logger, others=()):
    """
    Configure the log format and apply ``args.log_level`` to the given loggers.

    Args:
        args: Object exposing a ``log_level`` attribute; the value is passed
            unchanged to ``Logger.setLevel``.
        logger: Logger instance whose level is set.
        others: Iterable of logger names that receive the same level.
            Defaults to an immutable empty tuple so that no mutable object
            is shared between calls.
    """
    logging.basicConfig(format="%(levelname)s\t%(message)s")
    logger.setLevel(args.log_level)

    for other in others:
        logging.getLogger(other).setLevel(args.log_level)
 
236
 
 
237
 
 
227
  online = online_factory(args, asr, tokenizer, logfile=logfile)
228
  return asr, online
229
 
230
def warmup_asr(asr, warmup_file=None):
    """
    Warm up the ASR model by transcribing a short audio file.

    Args:
        asr: ASR backend; only its ``transcribe(audio)`` method is called.
        warmup_file: Path to an audio file to transcribe. When falsy, a JFK
            sample is downloaded into the system temp directory on first use
            and reused on later calls.

    If no ``warmup_file`` is given and the sample cannot be downloaded, a
    warning is logged and warmup is skipped instead of raising. Errors while
    loading or transcribing the audio still propagate to the caller.
    """
    import os
    import tempfile
    import urllib.request

    if not warmup_file:
        jfk_url = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav"
        temp_dir = tempfile.gettempdir()
        warmup_file = os.path.join(temp_dir, "whisper_warmup_jfk.wav")

        if not os.path.exists(warmup_file):
            logger.debug(f"Downloading warmup file from {jfk_url}")
            # Download next to the final path and move it into place only
            # after the transfer completed, so an interrupted download never
            # leaves a truncated file that later calls would treat as cached.
            fd, partial_path = tempfile.mkstemp(suffix=".wav", dir=temp_dir)
            os.close(fd)
            try:
                urllib.request.urlretrieve(jfk_url, partial_path)
                os.replace(partial_path, warmup_file)
            except OSError as exc:
                # urllib's URLError/HTTPError/ContentTooShortError are all
                # OSError subclasses. Warmup is optional: skip it.
                logger.warning(f"Could not download warmup file, skipping warmup: {exc}")
                return
            finally:
                if os.path.exists(partial_path):
                    os.remove(partial_path)

    # Load the warmup file, resampled to 16 kHz
    audio, sr = librosa.load(warmup_file, sr=16000)

    # Run one transcription; the result is discarded
    asr.transcribe(audio)

    logger.info("Whisper is warmed up")
260