no online conflict when multiple users
Browse files- whisper_fastapi_online_server.py +5 -2
- whisper_online.py +11 -9
whisper_fastapi_online_server.py
CHANGED
|
@@ -9,7 +9,7 @@ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
|
| 9 |
from fastapi.responses import HTMLResponse
|
| 10 |
from fastapi.middleware.cors import CORSMiddleware
|
| 11 |
|
| 12 |
-
from whisper_online import
|
| 13 |
|
| 14 |
app = FastAPI()
|
| 15 |
app.add_middleware(
|
|
@@ -40,7 +40,7 @@ parser.add_argument(
|
|
| 40 |
add_shared_args(parser)
|
| 41 |
args = parser.parse_args()
|
| 42 |
|
| 43 |
-
asr,
|
| 44 |
|
| 45 |
# Load demo HTML for the root endpoint
|
| 46 |
with open("src/live_transcription.html", "r") as f:
|
|
@@ -85,6 +85,9 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 85 |
|
| 86 |
ffmpeg_process = await start_ffmpeg_decoder()
|
| 87 |
pcm_buffer = bytearray()
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
# Continuously read decoded PCM from ffmpeg stdout in a background task
|
| 90 |
async def ffmpeg_stdout_reader():
|
|
|
|
| 9 |
from fastapi.responses import HTMLResponse
|
| 10 |
from fastapi.middleware.cors import CORSMiddleware
|
| 11 |
|
| 12 |
+
from whisper_online import backend_factory, online_factory, add_shared_args
|
| 13 |
|
| 14 |
app = FastAPI()
|
| 15 |
app.add_middleware(
|
|
|
|
| 40 |
add_shared_args(parser)
|
| 41 |
args = parser.parse_args()
|
| 42 |
|
| 43 |
+
asr, tokenizer = backend_factory(args)
|
| 44 |
|
| 45 |
# Load demo HTML for the root endpoint
|
| 46 |
with open("src/live_transcription.html", "r") as f:
|
|
|
|
| 85 |
|
| 86 |
ffmpeg_process = await start_ffmpeg_decoder()
|
| 87 |
pcm_buffer = bytearray()
|
| 88 |
+
print("Loading online.")
|
| 89 |
+
online = online_factory(args, asr, tokenizer)
|
| 90 |
+
print("Online loaded.")
|
| 91 |
|
| 92 |
# Continuously read decoded PCM from ffmpeg stdout in a background task
|
| 93 |
async def ffmpeg_stdout_reader():
|
whisper_online.py
CHANGED
|
@@ -920,11 +920,7 @@ def add_shared_args(parser):
|
|
| 920 |
default="DEBUG",
|
| 921 |
)
|
| 922 |
|
| 923 |
-
|
| 924 |
-
def asr_factory(args, logfile=sys.stderr):
|
| 925 |
-
"""
|
| 926 |
-
Creates and configures an ASR and ASR Online instance based on the specified backend and arguments.
|
| 927 |
-
"""
|
| 928 |
backend = args.backend
|
| 929 |
if backend == "openai-api":
|
| 930 |
logger.debug("Using OpenAI API.")
|
|
@@ -967,10 +963,10 @@ def asr_factory(args, logfile=sys.stderr):
|
|
| 967 |
tokenizer = create_tokenizer(tgt_language)
|
| 968 |
else:
|
| 969 |
tokenizer = None
|
|
|
|
| 970 |
|
| 971 |
-
|
| 972 |
if args.vac:
|
| 973 |
-
|
| 974 |
online = VACOnlineASRProcessor(
|
| 975 |
args.min_chunk_size,
|
| 976 |
asr,
|
|
@@ -985,10 +981,16 @@ def asr_factory(args, logfile=sys.stderr):
|
|
| 985 |
logfile=logfile,
|
| 986 |
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
|
| 987 |
)
|
| 988 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 989 |
return asr, online
|
| 990 |
|
| 991 |
-
|
| 992 |
def set_logging(args, logger, other="_server"):
|
| 993 |
logging.basicConfig(format="%(levelname)s\t%(message)s") # format='%(name)s
|
| 994 |
logger.setLevel(args.log_level)
|
|
|
|
| 920 |
default="DEBUG",
|
| 921 |
)
|
| 922 |
|
| 923 |
+
def backend_factory(args):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 924 |
backend = args.backend
|
| 925 |
if backend == "openai-api":
|
| 926 |
logger.debug("Using OpenAI API.")
|
|
|
|
| 963 |
tokenizer = create_tokenizer(tgt_language)
|
| 964 |
else:
|
| 965 |
tokenizer = None
|
| 966 |
+
return asr, tokenizer
|
| 967 |
|
| 968 |
+
def online_factory(args, asr, tokenizer, logfile=sys.stderr):
|
| 969 |
if args.vac:
|
|
|
|
| 970 |
online = VACOnlineASRProcessor(
|
| 971 |
args.min_chunk_size,
|
| 972 |
asr,
|
|
|
|
| 981 |
logfile=logfile,
|
| 982 |
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
|
| 983 |
)
|
| 984 |
+
return online
|
| 985 |
+
|
| 986 |
+
def asr_factory(args, logfile=sys.stderr):
|
| 987 |
+
"""
|
| 988 |
+
Creates and configures an ASR and ASR Online instance based on the specified backend and arguments.
|
| 989 |
+
"""
|
| 990 |
+
asr, tokenizer = backend_factory(args)
|
| 991 |
+
online = online_factory(args, asr, tokenizer, logfile=logfile)
|
| 992 |
return asr, online
|
| 993 |
|
|
|
|
| 994 |
def set_logging(args, logger, other="_server"):
|
| 995 |
logging.basicConfig(format="%(levelname)s\t%(message)s") # format='%(name)s
|
| 996 |
logger.setLevel(args.log_level)
|