qfuxa commited on
Commit
c64fea5
·
2 Parent(s): 5cf0607 6cf18f3

Merge pull request #43 from QuentinFuxa/load-the-model-just-once

Browse files
Files changed (1) hide show
  1. whisper_fastapi_online_server.py +29 -80
whisper_fastapi_online_server.py CHANGED
@@ -4,6 +4,7 @@ import asyncio
4
  import numpy as np
5
  import ffmpeg
6
  from time import time
 
7
 
8
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
9
  from fastapi.responses import HTMLResponse
@@ -12,73 +13,7 @@ from fastapi.middleware.cors import CORSMiddleware
12
  from src.whisper_streaming.whisper_online import backend_factory, online_factory, add_shared_args
13
 
14
 
15
- import logging
16
- import logging.config
17
-
18
- def setup_logging():
19
- logging_config = {
20
- 'version': 1,
21
- 'disable_existing_loggers': False,
22
- 'formatters': {
23
- 'standard': {
24
- 'format': '%(asctime)s %(levelname)s [%(name)s]: %(message)s',
25
- },
26
- },
27
- 'handlers': {
28
- 'console': {
29
- 'level': 'INFO',
30
- 'class': 'logging.StreamHandler',
31
- 'formatter': 'standard',
32
- },
33
- },
34
- 'root': {
35
- 'handlers': ['console'],
36
- 'level': 'DEBUG',
37
- },
38
- 'loggers': {
39
- 'uvicorn': {
40
- 'handlers': ['console'],
41
- 'level': 'INFO',
42
- 'propagate': False,
43
- },
44
- 'uvicorn.error': {
45
- 'level': 'INFO',
46
- },
47
- 'uvicorn.access': {
48
- 'level': 'INFO',
49
- },
50
- 'src.whisper_streaming.online_asr': { # Add your specific module here
51
- 'handlers': ['console'],
52
- 'level': 'DEBUG',
53
- 'propagate': False,
54
- },
55
- 'src.whisper_streaming.whisper_streaming': { # Add your specific module here
56
- 'handlers': ['console'],
57
- 'level': 'DEBUG',
58
- 'propagate': False,
59
- },
60
- },
61
- }
62
-
63
- logging.config.dictConfig(logging_config)
64
-
65
- setup_logging()
66
- logger = logging.getLogger(__name__)
67
-
68
-
69
-
70
-
71
-
72
-
73
- app = FastAPI()
74
- app.add_middleware(
75
- CORSMiddleware,
76
- allow_origins=["*"],
77
- allow_credentials=True,
78
- allow_methods=["*"],
79
- allow_headers=["*"],
80
- )
81
-
82
 
83
  parser = argparse.ArgumentParser(description="Whisper FastAPI Online Server")
84
  parser.add_argument(
@@ -108,28 +43,37 @@ parser.add_argument(
108
  add_shared_args(parser)
109
  args = parser.parse_args()
110
 
111
- asr, tokenizer = backend_factory(args)
 
 
 
 
112
 
113
  if args.diarization:
114
  from src.diarization.diarization_online import DiartDiarization
115
 
116
 
117
- # Load demo HTML for the root endpoint
118
- with open("src/web/live_transcription.html", "r", encoding="utf-8") as f:
119
- html = f.read()
120
-
121
 
122
- @app.get("/")
123
- async def get():
124
- return HTMLResponse(html)
 
 
125
 
 
 
 
 
 
 
 
 
126
 
127
- SAMPLE_RATE = 16000
128
- CHANNELS = 1
129
- SAMPLES_PER_SEC = SAMPLE_RATE * int(args.min_chunk_size)
130
- BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample
131
- BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE
132
 
 
 
 
133
 
134
  async def start_ffmpeg_decoder():
135
  """
@@ -150,6 +94,11 @@ async def start_ffmpeg_decoder():
150
  return process
151
 
152
 
 
 
 
 
 
153
 
154
  @app.websocket("/asr")
155
  async def websocket_endpoint(websocket: WebSocket):
 
4
  import numpy as np
5
  import ffmpeg
6
  from time import time
7
+ from contextlib import asynccontextmanager
8
 
9
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
10
  from fastapi.responses import HTMLResponse
 
13
  from src.whisper_streaming.whisper_online import backend_factory, online_factory, add_shared_args
14
 
15
 
16
+ ##### LOAD ARGS #####
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  parser = argparse.ArgumentParser(description="Whisper FastAPI Online Server")
19
  parser.add_argument(
 
43
  add_shared_args(parser)
44
  args = parser.parse_args()
45
 
46
+ SAMPLE_RATE = 16000
47
+ CHANNELS = 1
48
+ SAMPLES_PER_SEC = SAMPLE_RATE * int(args.min_chunk_size)
49
+ BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample
50
+ BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE
51
 
52
  if args.diarization:
53
  from src.diarization.diarization_online import DiartDiarization
54
 
55
 
56
+ ##### LOAD APP #####
 
 
 
57
 
58
+ @asynccontextmanager
59
+ async def lifespan(app: FastAPI):
60
+ global asr, tokenizer
61
+ asr, tokenizer = backend_factory(args)
62
+ yield
63
 
64
+ app = FastAPI(lifespan=lifespan)
65
+ app.add_middleware(
66
+ CORSMiddleware,
67
+ allow_origins=["*"],
68
+ allow_credentials=True,
69
+ allow_methods=["*"],
70
+ allow_headers=["*"],
71
+ )
72
 
 
 
 
 
 
73
 
74
+ # Load demo HTML for the root endpoint
75
+ with open("src/web/live_transcription.html", "r", encoding="utf-8") as f:
76
+ html = f.read()
77
 
78
  async def start_ffmpeg_decoder():
79
  """
 
94
  return process
95
 
96
 
97
+ ##### ENDPOINTS #####
98
+
99
+ @app.get("/")
100
+ async def get():
101
+ return HTMLResponse(html)
102
 
103
  @app.websocket("/asr")
104
  async def websocket_endpoint(websocket: WebSocket):