no online conflict when multiple users
Browse files- whisper_fastapi_online_server.py +5 -2
- whisper_online.py +11 -9
whisper_fastapi_online_server.py
CHANGED
@@ -9,7 +9,7 @@ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
|
9 |
from fastapi.responses import HTMLResponse
|
10 |
from fastapi.middleware.cors import CORSMiddleware
|
11 |
|
12 |
-
from whisper_online import
|
13 |
|
14 |
app = FastAPI()
|
15 |
app.add_middleware(
|
@@ -40,7 +40,7 @@ parser.add_argument(
|
|
40 |
add_shared_args(parser)
|
41 |
args = parser.parse_args()
|
42 |
|
43 |
-
asr,
|
44 |
|
45 |
# Load demo HTML for the root endpoint
|
46 |
with open("src/live_transcription.html", "r") as f:
|
@@ -85,6 +85,9 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
85 |
|
86 |
ffmpeg_process = await start_ffmpeg_decoder()
|
87 |
pcm_buffer = bytearray()
|
|
|
|
|
|
|
88 |
|
89 |
# Continuously read decoded PCM from ffmpeg stdout in a background task
|
90 |
async def ffmpeg_stdout_reader():
|
|
|
9 |
from fastapi.responses import HTMLResponse
|
10 |
from fastapi.middleware.cors import CORSMiddleware
|
11 |
|
12 |
+
from whisper_online import backend_factory, online_factory, add_shared_args
|
13 |
|
14 |
app = FastAPI()
|
15 |
app.add_middleware(
|
|
|
40 |
add_shared_args(parser)
|
41 |
args = parser.parse_args()
|
42 |
|
43 |
+
asr, tokenizer = backend_factory(args)
|
44 |
|
45 |
# Load demo HTML for the root endpoint
|
46 |
with open("src/live_transcription.html", "r") as f:
|
|
|
85 |
|
86 |
ffmpeg_process = await start_ffmpeg_decoder()
|
87 |
pcm_buffer = bytearray()
|
88 |
+
print("Loading online.")
|
89 |
+
online = online_factory(args, asr, tokenizer)
|
90 |
+
print("Online loaded.")
|
91 |
|
92 |
# Continuously read decoded PCM from ffmpeg stdout in a background task
|
93 |
async def ffmpeg_stdout_reader():
|
whisper_online.py
CHANGED
@@ -920,11 +920,7 @@ def add_shared_args(parser):
|
|
920 |
default="DEBUG",
|
921 |
)
|
922 |
|
923 |
-
|
924 |
-
def asr_factory(args, logfile=sys.stderr):
|
925 |
-
"""
|
926 |
-
Creates and configures an ASR and ASR Online instance based on the specified backend and arguments.
|
927 |
-
"""
|
928 |
backend = args.backend
|
929 |
if backend == "openai-api":
|
930 |
logger.debug("Using OpenAI API.")
|
@@ -967,10 +963,10 @@ def asr_factory(args, logfile=sys.stderr):
|
|
967 |
tokenizer = create_tokenizer(tgt_language)
|
968 |
else:
|
969 |
tokenizer = None
|
|
|
970 |
|
971 |
-
|
972 |
if args.vac:
|
973 |
-
|
974 |
online = VACOnlineASRProcessor(
|
975 |
args.min_chunk_size,
|
976 |
asr,
|
@@ -985,10 +981,16 @@ def asr_factory(args, logfile=sys.stderr):
|
|
985 |
logfile=logfile,
|
986 |
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
|
987 |
)
|
988 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
989 |
return asr, online
|
990 |
|
991 |
-
|
992 |
def set_logging(args, logger, other="_server"):
|
993 |
logging.basicConfig(format="%(levelname)s\t%(message)s") # format='%(name)s
|
994 |
logger.setLevel(args.log_level)
|
|
|
920 |
default="DEBUG",
|
921 |
)
|
922 |
|
923 |
+
def backend_factory(args):
|
|
|
|
|
|
|
|
|
924 |
backend = args.backend
|
925 |
if backend == "openai-api":
|
926 |
logger.debug("Using OpenAI API.")
|
|
|
963 |
tokenizer = create_tokenizer(tgt_language)
|
964 |
else:
|
965 |
tokenizer = None
|
966 |
+
return asr, tokenizer
|
967 |
|
968 |
+
def online_factory(args, asr, tokenizer, logfile=sys.stderr):
|
969 |
if args.vac:
|
|
|
970 |
online = VACOnlineASRProcessor(
|
971 |
args.min_chunk_size,
|
972 |
asr,
|
|
|
981 |
logfile=logfile,
|
982 |
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
|
983 |
)
|
984 |
+
return online
|
985 |
+
|
986 |
+
def asr_factory(args, logfile=sys.stderr):
|
987 |
+
"""
|
988 |
+
Creates and configures an ASR and ASR Online instance based on the specified backend and arguments.
|
989 |
+
"""
|
990 |
+
asr, tokenizer = backend_factory(args)
|
991 |
+
online = online_factory(args, asr, tokenizer, logfile=logfile)
|
992 |
return asr, online
|
993 |
|
|
|
994 |
def set_logging(args, logger, other="_server"):
|
995 |
logging.basicConfig(format="%(levelname)s\t%(message)s") # format='%(name)s
|
996 |
logger.setLevel(args.log_level)
|