Tijs Zwinkels committed
Commit 8896389 · 1 Parent(s): 5929a82

Fix crash when using openai-api with whisper_online_server

Files changed (2)
  1. whisper_online.py +32 -21
  2. whisper_online_server.py +1 -24
whisper_online.py CHANGED
@@ -548,6 +548,37 @@ def add_shared_args(parser):
     parser.add_argument('--buffer_trimming', type=str, default="segment", choices=["sentence", "segment"],help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.')
     parser.add_argument('--buffer_trimming_sec', type=float, default=15, help='Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.')
 
+def asr_factory(args, logfile=sys.stderr):
+    """
+    Creates and configures an ASR instance based on the specified backend and arguments.
+    """
+    backend = args.backend
+    if backend == "openai-api":
+        print("Using OpenAI API.", file=logfile)
+        asr = OpenaiApiASR(lan=args.lan)
+    else:
+        if backend == "faster-whisper":
+            from faster_whisper import FasterWhisperASR
+            asr_cls = FasterWhisperASR
+        else:
+            from whisper_timestamped import WhisperTimestampedASR
+            asr_cls = WhisperTimestampedASR
+
+        # Only for FasterWhisperASR and WhisperTimestampedASR
+        size = args.model
+        t = time.time()
+        print(f"Loading Whisper {size} model for {args.lan}...", file=logfile, end=" ", flush=True)
+        asr = asr_cls(modelsize=size, lan=args.lan, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
+        e = time.time()
+        print(f"done. It took {round(e-t,2)} seconds.", file=logfile)
+
+    # Apply common configurations
+    if getattr(args, 'vad', False):  # Checks if VAD argument is present and True
+        print("Setting VAD filter", file=logfile)
+        asr.use_vad()
+
+    return asr
+
 ## main:
 
 if __name__ == "__main__":
@@ -575,28 +606,8 @@ if __name__ == "__main__":
     duration = len(load_audio(audio_path))/SAMPLING_RATE
     print("Audio duration is: %2.2f seconds" % duration, file=logfile)
 
+    asr = asr_factory(args, logfile=logfile)
     language = args.lan
-
-    if args.backend == "openai-api":
-        print("Using OpenAI API.",file=logfile)
-        asr = OpenaiApiASR(lan=language)
-    else:
-        if args.backend == "faster-whisper":
-            asr_cls = FasterWhisperASR
-        else:
-            asr_cls = WhisperTimestampedASR
-
-        size = args.model
-        t = time.time()
-        print(f"Loading Whisper {size} model for {language}...",file=logfile,end=" ",flush=True)
-        asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
-        e = time.time()
-        print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
-
-        if args.vad:
-            print("setting VAD filter",file=logfile)
-            asr.use_vad()
-
     if args.task == "translate":
         asr.set_translate_task()
         tgt_language = "en" # Whisper translates into English
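
For context, a minimal sketch of driving the new asr_factory directly, outside the two scripts' own argument parsing. The attribute names mirror what the factory reads in the diff above (backend, lan, model, model_cache_dir, model_dir, vad); the Namespace stand-in and the concrete values are illustrative and not part of the commit.

# Sketch only: a stand-in for the parsed CLI arguments, passed to the new factory.
# Assumes whisper_online.py and its dependencies are importable; the openai-api
# backend additionally expects OpenAI credentials (e.g. OPENAI_API_KEY) in the environment.
import sys
from argparse import Namespace
from whisper_online import asr_factory

args = Namespace(
    backend="openai-api",   # the backend that previously crashed whisper_online_server.py
    lan="en",
    model=None,             # only read by the faster-whisper / whisper_timestamped branches
    model_cache_dir=None,
    model_dir=None,
    vad=False,              # read via getattr(), so it may also be omitted
)

asr = asr_factory(args, logfile=sys.stderr)
asr.set_translate_task()    # optional, mirrors the --task translate branch in both scripts
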
whisper_online_server.py CHANGED
@@ -24,36 +24,13 @@ SAMPLING_RATE = 16000
 size = args.model
 language = args.lan
 
-t = time.time()
-print(f"Loading Whisper {size} model for {language}...",file=sys.stderr,end=" ",flush=True)
-
-if args.backend == "faster-whisper":
-    from faster_whisper import WhisperModel
-    asr_cls = FasterWhisperASR
-elif args.backend == "openai-api":
-    asr_cls = OpenaiApiASR
-else:
-    import whisper
-    import whisper_timestamped
-    # from whisper_timestamped_model import WhisperTimestampedASR
-    asr_cls = WhisperTimestampedASR
-
-asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
-
+asr = asr_factory(args)
 if args.task == "translate":
     asr.set_translate_task()
     tgt_language = "en"
 else:
     tgt_language = language
 
-e = time.time()
-print(f"done. It took {round(e-t,2)} seconds.",file=sys.stderr)
-
-if args.vad:
-    print("setting VAD filter",file=sys.stderr)
-    asr.use_vad()
-
-
 min_chunk = args.min_chunk_size
 
 if args.buffer_trimming == "sentence":