bluegiraffe-sc commited on
Commit
c0dd2e2
·
1 Parent(s): 2249846

import backend from __init__

Browse files
Files changed (1) hide show
  1. whisper_online.py +19 -8
whisper_online.py CHANGED
@@ -23,15 +23,19 @@ def load_audio_chunk(fname, beg, end):
23
 
24
  class ASRBase:
25
 
26
- # join transcribe words with this character (" " for whisper_timestamped, "" for faster-whisper because it emits the spaces when neeeded)
27
- sep = " "
28
 
29
  def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None):
30
  self.transcribe_kargs = {}
31
  self.original_language = lan
32
 
 
33
  self.model = self.load_model(modelsize, cache_dir, model_dir)
34
 
 
 
 
35
  def load_model(self, modelsize, cache_dir):
36
  raise NotImplemented("must be implemented in the child class")
37
 
@@ -49,11 +53,14 @@ class ASRBase:
49
  class WhisperTimestampedASR(ASRBase):
50
  """Uses whisper_timestamped library as the backend. Initially, we tested the code on this backend. It worked, but slower than faster-whisper.
51
  On the other hand, the installation for GPU could be easier.
 
52
 
53
- If used, requires imports:
 
 
 
54
  import whisper
55
  import whisper_timestamped
56
- """
57
 
58
  def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
59
  if model_dir is not None:
@@ -89,8 +96,12 @@ class FasterWhisperASR(ASRBase):
89
 
90
  sep = ""
91
 
 
 
 
 
92
  def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
93
- from faster_whisper import WhisperModel
94
 
95
 
96
  if model_dir is not None:
@@ -465,11 +476,11 @@ if __name__ == "__main__":
465
  #asr = WhisperASR(lan=language, modelsize=size)
466
 
467
  if args.backend == "faster-whisper":
468
- from faster_whisper import WhisperModel
469
  asr_cls = FasterWhisperASR
470
  else:
471
- import whisper
472
- import whisper_timestamped
473
  # from whisper_timestamped_model import WhisperTimestampedASR
474
  asr_cls = WhisperTimestampedASR
475
 
 
23
 
24
  class ASRBase:
25
 
26
+ sep = " " # join transcribe words with this character (" " for whisper_timestamped,
27
+ # "" for faster-whisper because it emits the spaces when neeeded)
28
 
29
  def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None):
30
  self.transcribe_kargs = {}
31
  self.original_language = lan
32
 
33
+ self.import_backend()
34
  self.model = self.load_model(modelsize, cache_dir, model_dir)
35
 
36
+ def import_backend(self):
37
+ raise NotImplemented("must be implemented in the child class")
38
+
39
  def load_model(self, modelsize, cache_dir):
40
  raise NotImplemented("must be implemented in the child class")
41
 
 
53
  class WhisperTimestampedASR(ASRBase):
54
  """Uses whisper_timestamped library as the backend. Initially, we tested the code on this backend. It worked, but slower than faster-whisper.
55
  On the other hand, the installation for GPU could be easier.
56
+ """
57
 
58
+ sep = " "
59
+
60
+ def import_backend(self):
61
+ global whisper, whisper_timestamped
62
  import whisper
63
  import whisper_timestamped
 
64
 
65
  def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
66
  if model_dir is not None:
 
96
 
97
  sep = ""
98
 
99
+ def import_backend(self):
100
+ global faster_whisper
101
+ import faster_whisper
102
+
103
  def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
104
+ #from faster_whisper import WhisperModel
105
 
106
 
107
  if model_dir is not None:
 
476
  #asr = WhisperASR(lan=language, modelsize=size)
477
 
478
  if args.backend == "faster-whisper":
479
+ #from faster_whisper import WhisperModel
480
  asr_cls = FasterWhisperASR
481
  else:
482
+ #import whisper
483
+ #import whisper_timestamped
484
  # from whisper_timestamped_model import WhisperTimestampedASR
485
  asr_cls = WhisperTimestampedASR
486