qfuxa commited on
Commit
aa0ba59
·
1 Parent(s): b7a2d23

no online conflict when multiple users

Browse files
whisper_fastapi_online_server.py CHANGED
@@ -9,7 +9,7 @@ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
9
  from fastapi.responses import HTMLResponse
10
  from fastapi.middleware.cors import CORSMiddleware
11
 
12
- from whisper_online import asr_factory, add_shared_args
13
 
14
  app = FastAPI()
15
  app.add_middleware(
@@ -40,7 +40,7 @@ parser.add_argument(
40
  add_shared_args(parser)
41
  args = parser.parse_args()
42
 
43
- asr, online = asr_factory(args)
44
 
45
  # Load demo HTML for the root endpoint
46
  with open("src/live_transcription.html", "r") as f:
@@ -85,6 +85,9 @@ async def websocket_endpoint(websocket: WebSocket):
85
 
86
  ffmpeg_process = await start_ffmpeg_decoder()
87
  pcm_buffer = bytearray()
 
 
 
88
 
89
  # Continuously read decoded PCM from ffmpeg stdout in a background task
90
  async def ffmpeg_stdout_reader():
 
9
  from fastapi.responses import HTMLResponse
10
  from fastapi.middleware.cors import CORSMiddleware
11
 
12
+ from whisper_online import backend_factory, online_factory, add_shared_args
13
 
14
  app = FastAPI()
15
  app.add_middleware(
 
40
  add_shared_args(parser)
41
  args = parser.parse_args()
42
 
43
+ asr, tokenizer = backend_factory(args)
44
 
45
  # Load demo HTML for the root endpoint
46
  with open("src/live_transcription.html", "r") as f:
 
85
 
86
  ffmpeg_process = await start_ffmpeg_decoder()
87
  pcm_buffer = bytearray()
88
+ print("Loading online.")
89
+ online = online_factory(args, asr, tokenizer)
90
+ print("Online loaded.")
91
 
92
  # Continuously read decoded PCM from ffmpeg stdout in a background task
93
  async def ffmpeg_stdout_reader():
whisper_online.py CHANGED
@@ -920,11 +920,7 @@ def add_shared_args(parser):
920
  default="DEBUG",
921
  )
922
 
923
-
924
- def asr_factory(args, logfile=sys.stderr):
925
- """
926
- Creates and configures an ASR and ASR Online instance based on the specified backend and arguments.
927
- """
928
  backend = args.backend
929
  if backend == "openai-api":
930
  logger.debug("Using OpenAI API.")
@@ -967,10 +963,10 @@ def asr_factory(args, logfile=sys.stderr):
967
  tokenizer = create_tokenizer(tgt_language)
968
  else:
969
  tokenizer = None
 
970
 
971
- # Create the OnlineASRProcessor
972
  if args.vac:
973
-
974
  online = VACOnlineASRProcessor(
975
  args.min_chunk_size,
976
  asr,
@@ -985,10 +981,16 @@ def asr_factory(args, logfile=sys.stderr):
985
  logfile=logfile,
986
  buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
987
  )
988
-
 
 
 
 
 
 
 
989
  return asr, online
990
 
991
-
992
  def set_logging(args, logger, other="_server"):
993
  logging.basicConfig(format="%(levelname)s\t%(message)s") # format='%(name)s
994
  logger.setLevel(args.log_level)
 
920
  default="DEBUG",
921
  )
922
 
923
+ def backend_factory(args):
 
 
 
 
924
  backend = args.backend
925
  if backend == "openai-api":
926
  logger.debug("Using OpenAI API.")
 
963
  tokenizer = create_tokenizer(tgt_language)
964
  else:
965
  tokenizer = None
966
+ return asr, tokenizer
967
 
968
+ def online_factory(args, asr, tokenizer, logfile=sys.stderr):
969
  if args.vac:
 
970
  online = VACOnlineASRProcessor(
971
  args.min_chunk_size,
972
  asr,
 
981
  logfile=logfile,
982
  buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
983
  )
984
+ return online
985
+
986
+ def asr_factory(args, logfile=sys.stderr):
987
+ """
988
+ Creates and configures an ASR and ASR Online instance based on the specified backend and arguments.
989
+ """
990
+ asr, tokenizer = backend_factory(args)
991
+ online = online_factory(args, asr, tokenizer, logfile=logfile)
992
  return asr, online
993
 
 
994
  def set_logging(args, logger, other="_server"):
995
  logging.basicConfig(format="%(levelname)s\t%(message)s") # format='%(name)s
996
  logger.setLevel(args.log_level)