qfuxa committed on
Commit
6cf18f3
·
1 Parent(s): 59ba1f3

Use lifespan to load the model just once

Browse files
Files changed (1) hide show
  1. whisper_fastapi_online_server.py +29 -21
whisper_fastapi_online_server.py CHANGED
@@ -4,6 +4,7 @@ import asyncio
4
  import numpy as np
5
  import ffmpeg
6
  from time import time
 
7
 
8
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
9
  from fastapi.responses import HTMLResponse
@@ -11,15 +12,8 @@ from fastapi.middleware.cors import CORSMiddleware
11
 
12
  from src.whisper_streaming.whisper_online import backend_factory, online_factory, add_shared_args
13
 
14
- app = FastAPI()
15
- app.add_middleware(
16
- CORSMiddleware,
17
- allow_origins=["*"],
18
- allow_credentials=True,
19
- allow_methods=["*"],
20
- allow_headers=["*"],
21
- )
22
 
 
23
 
24
  parser = argparse.ArgumentParser(description="Whisper FastAPI Online Server")
25
  parser.add_argument(
@@ -49,28 +43,37 @@ parser.add_argument(
49
  add_shared_args(parser)
50
  args = parser.parse_args()
51
 
52
- asr, tokenizer = backend_factory(args)
 
 
 
 
53
 
54
  if args.diarization:
55
  from src.diarization.diarization_online import DiartDiarization
56
 
57
 
58
- # Load demo HTML for the root endpoint
59
- with open("src/web/live_transcription.html", "r", encoding="utf-8") as f:
60
- html = f.read()
61
 
 
 
 
 
 
62
 
63
- @app.get("/")
64
- async def get():
65
- return HTMLResponse(html)
66
-
 
 
 
 
67
 
68
- SAMPLE_RATE = 16000
69
- CHANNELS = 1
70
- SAMPLES_PER_SEC = SAMPLE_RATE * int(args.min_chunk_size)
71
- BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample
72
- BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE
73
 
 
 
 
74
 
75
  async def start_ffmpeg_decoder():
76
  """
@@ -91,6 +94,11 @@ async def start_ffmpeg_decoder():
91
  return process
92
 
93
 
 
 
 
 
 
94
 
95
  @app.websocket("/asr")
96
  async def websocket_endpoint(websocket: WebSocket):
 
4
  import numpy as np
5
  import ffmpeg
6
  from time import time
7
+ from contextlib import asynccontextmanager
8
 
9
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
10
  from fastapi.responses import HTMLResponse
 
12
 
13
  from src.whisper_streaming.whisper_online import backend_factory, online_factory, add_shared_args
14
 
 
 
 
 
 
 
 
 
15
 
16
+ ##### LOAD ARGS #####
17
 
18
  parser = argparse.ArgumentParser(description="Whisper FastAPI Online Server")
19
  parser.add_argument(
 
43
  add_shared_args(parser)
44
  args = parser.parse_args()
45
 
46
+ SAMPLE_RATE = 16000
47
+ CHANNELS = 1
48
+ SAMPLES_PER_SEC = SAMPLE_RATE * int(args.min_chunk_size)
49
+ BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample
50
+ BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE
51
 
52
  if args.diarization:
53
  from src.diarization.diarization_online import DiartDiarization
54
 
55
 
56
+ ##### LOAD APP #####
 
 
57
 
58
+ @asynccontextmanager
59
+ async def lifespan(app: FastAPI):
60
+ global asr, tokenizer
61
+ asr, tokenizer = backend_factory(args)
62
+ yield
63
 
64
+ app = FastAPI(lifespan=lifespan)
65
+ app.add_middleware(
66
+ CORSMiddleware,
67
+ allow_origins=["*"],
68
+ allow_credentials=True,
69
+ allow_methods=["*"],
70
+ allow_headers=["*"],
71
+ )
72
 
 
 
 
 
 
73
 
74
+ # Load demo HTML for the root endpoint
75
+ with open("src/web/live_transcription.html", "r", encoding="utf-8") as f:
76
+ html = f.read()
77
 
78
  async def start_ffmpeg_decoder():
79
  """
 
94
  return process
95
 
96
 
97
+ ##### ENDPOINTS #####
98
+
99
+ @app.get("/")
100
+ async def get():
101
+ return HTMLResponse(html)
102
 
103
  @app.websocket("/asr")
104
  async def websocket_endpoint(websocket: WebSocket):