qfuxa committed
Commit 58e48bb · 2 Parent(s): 4cb3660 6a04ddb

Merge pull request #10 from SilasK/main


More flexibility by using a custom tokenize_method + black formatting

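For context, the tokenize_method change lets OnlineASRProcessor accept any sentence-splitting callable instead of a MosesTokenizer-style object. A minimal sketch of the new keyword, assuming the whisper_online module from this repo; the regex splitter is purely illustrative and not part of this commit:

import re
from whisper_online import FasterWhisperASR, OnlineASRProcessor

def naive_sentence_split(text):
    # Any callable mapping a text string to a list of sentences works here,
    # mirroring MosesTokenizer.split(); this crude regex is only a placeholder.
    return [s for s in re.split(r"(?<=[.!?])\s+", text) if s]

asr = FasterWhisperASR("en", "large-v2")
online = OnlineASRProcessor(
    asr,
    tokenize_method=naive_sentence_split,  # keyword introduced by this PR
    buffer_trimming=("sentence", 15),      # the tokenizer is only consulted for "sentence" trimming
)

With the default buffer_trimming=("segment", 15), tokenize_method can stay None, as before.
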
silero_vad_iterator.py CHANGED
@@ -6,15 +6,16 @@ import torch
6
 
7
  # Their licence is MIT, same as ours: https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/LICENSE
8
 
9
- class VADIterator:
10
- def __init__(self,
11
- model,
12
- threshold: float = 0.5,
13
- sampling_rate: int = 16000,
14
- min_silence_duration_ms: int = 500, # makes sense on one recording that I checked
15
- speech_pad_ms: int = 100 # same
16
- ):
17
18
  """
19
  Class for stream imitation
20
 
@@ -41,7 +42,9 @@ class VADIterator:
41
  self.sampling_rate = sampling_rate
42
 
43
  if sampling_rate not in [8000, 16000]:
44
- raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')
 
 
45
 
46
  self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
47
  self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
@@ -80,7 +83,13 @@ class VADIterator:
80
  if (speech_prob >= self.threshold) and not self.triggered:
81
  self.triggered = True
82
  speech_start = self.current_sample - self.speech_pad_samples
83
- return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}
84
 
85
  if (speech_prob < self.threshold - 0.15) and self.triggered:
86
  if not self.temp_end:
@@ -91,26 +100,35 @@ class VADIterator:
91
  speech_end = self.temp_end + self.speech_pad_samples
92
  self.temp_end = 0
93
  self.triggered = False
94
- return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)}
95
 
96
  return None
97
 
 
98
  #######################
99
- # because Silero now requires exactly 512-sized audio chunks
100
 
101
  import numpy as np
 
 
102
  class FixedVADIterator(VADIterator):
103
- '''It fixes VADIterator by allowing to process any audio length, not only exactly 512 frames at once.
104
- If audio to be processed at once is long and multiple voiced segments detected,
105
- then __call__ returns the start of the first segment, and end (or middle, which means no end) of the last segment.
106
- '''
107
 
108
  def reset_states(self):
109
  super().reset_states()
110
- self.buffer = np.array([],dtype=np.float32)
111
 
112
  def __call__(self, x, return_seconds=False):
113
- self.buffer = np.append(self.buffer, x)
114
  ret = None
115
  while len(self.buffer) >= 512:
116
  r = super().__call__(self.buffer[:512], return_seconds=return_seconds)
@@ -118,29 +136,28 @@ class FixedVADIterator(VADIterator):
118
  if ret is None:
119
  ret = r
120
  elif r is not None:
121
- if 'end' in r:
122
- ret['end'] = r['end'] # the latter end
123
- if 'start' in r and 'end' in ret: # there is an earlier start.
124
  # Remove end, merging this segment with the previous one.
125
- del ret['end']
126
  return ret if ret != {} else None
127
 
 
128
  if __name__ == "__main__":
129
  # test/demonstrate the need for FixedVADIterator:
130
 
131
  import torch
132
- model, _ = torch.hub.load(
133
- repo_or_dir='snakers4/silero-vad',
134
- model='silero_vad'
135
- )
136
  vac = FixedVADIterator(model)
137
- # vac = VADIterator(model) # the second case crashes with this
138
 
139
  # this works: for both
140
- audio_buffer = np.array([0]*(512),dtype=np.float32)
141
  vac(audio_buffer)
142
 
143
- # this crashes on the non FixedVADIterator with
144
  # ops.prim.RaiseException("Input audio chunk is too short", "builtins.ValueError")
145
- audio_buffer = np.array([0]*(512-1),dtype=np.float32)
146
  vac(audio_buffer)
 
6
 
7
  # Their licence is MIT, same as ours: https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/LICENSE
8
 
 
 
 
 
 
 
 
 
9
 
10
+ class VADIterator:
11
+ def __init__(
12
+ self,
13
+ model,
14
+ threshold: float = 0.5,
15
+ sampling_rate: int = 16000,
16
+ min_silence_duration_ms: int = 500, # makes sense on one recording that I checked
17
+ speech_pad_ms: int = 100, # same
18
+ ):
19
  """
20
  Class for stream imitation
21
 
 
42
  self.sampling_rate = sampling_rate
43
 
44
  if sampling_rate not in [8000, 16000]:
45
+ raise ValueError(
46
+ "VADIterator does not support sampling rates other than [8000, 16000]"
47
+ )
48
 
49
  self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
50
  self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
 
83
  if (speech_prob >= self.threshold) and not self.triggered:
84
  self.triggered = True
85
  speech_start = self.current_sample - self.speech_pad_samples
86
+ return {
87
+ "start": (
88
+ int(speech_start)
89
+ if not return_seconds
90
+ else round(speech_start / self.sampling_rate, 1)
91
+ )
92
+ }
93
 
94
  if (speech_prob < self.threshold - 0.15) and self.triggered:
95
  if not self.temp_end:
 
100
  speech_end = self.temp_end + self.speech_pad_samples
101
  self.temp_end = 0
102
  self.triggered = False
103
+ return {
104
+ "end": (
105
+ int(speech_end)
106
+ if not return_seconds
107
+ else round(speech_end / self.sampling_rate, 1)
108
+ )
109
+ }
110
 
111
  return None
112
 
113
+
114
  #######################
115
+ # because Silero now requires exactly 512-sized audio chunks
116
 
117
  import numpy as np
118
+
119
+
120
  class FixedVADIterator(VADIterator):
121
+ """It fixes VADIterator by allowing to process any audio length, not only exactly 512 frames at once.
122
+ If audio to be processed at once is long and multiple voiced segments detected,
123
+ then __call__ returns the start of the first segment, and end (or middle, which means no end) of the last segment.
124
+ """
125
 
126
  def reset_states(self):
127
  super().reset_states()
128
+ self.buffer = np.array([], dtype=np.float32)
129
 
130
  def __call__(self, x, return_seconds=False):
131
+ self.buffer = np.append(self.buffer, x)
132
  ret = None
133
  while len(self.buffer) >= 512:
134
  r = super().__call__(self.buffer[:512], return_seconds=return_seconds)
 
136
  if ret is None:
137
  ret = r
138
  elif r is not None:
139
+ if "end" in r:
140
+ ret["end"] = r["end"] # the latter end
141
+ if "start" in r and "end" in ret: # there is an earlier start.
142
  # Remove end, merging this segment with the previous one.
143
+ del ret["end"]
144
  return ret if ret != {} else None
145
 
146
+
147
  if __name__ == "__main__":
148
  # test/demonstrate the need for FixedVADIterator:
149
 
150
  import torch
151
+
152
+ model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
 
 
153
  vac = FixedVADIterator(model)
154
+ # vac = VADIterator(model) # the second case crashes with this
155
 
156
  # this works: for both
157
+ audio_buffer = np.array([0] * (512), dtype=np.float32)
158
  vac(audio_buffer)
159
 
160
+ # this crashes on the non FixedVADIterator with
161
  # ops.prim.RaiseException("Input audio chunk is too short", "builtins.ValueError")
162
+ audio_buffer = np.array([0] * (512 - 1), dtype=np.float32)
163
  vac(audio_buffer)
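As the test above hints, FixedVADIterator exists so callers can push chunks of any size. A short streaming sketch (not part of the commit, and assuming torch.hub can fetch the Silero model):

import numpy as np
import torch
from silero_vad_iterator import FixedVADIterator

model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
vac = FixedVADIterator(model)  # buffers audio internally, runs VAD on 512-sample windows

stream = np.zeros(16000, dtype=np.float32)  # 1 s of silence as a stand-in for live audio
for off in range(0, len(stream), 160):      # 10 ms chunks, much smaller than 512 samples
    event = vac(stream[off:off + 160], return_seconds=True)
    if event is not None:
        print(event)  # {'start': ...} and/or {'end': ...}, in seconds here
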
whisper_fastapi_online_server.py CHANGED
@@ -22,10 +22,21 @@ app.add_middleware(
22
 
23
 
24
  parser = argparse.ArgumentParser(description="Whisper FastAPI Online Server")
25
- parser.add_argument("--host", type=str, default='localhost', help="The host address to bind the server to.")
26
- parser.add_argument("--port", type=int, default=8000, help="The port number to bind the server to.")
27
- parser.add_argument("--warmup-file", type=str, dest="warmup_file",
28
- help="The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. It can be e.g. https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav .")
29
  add_shared_args(parser)
30
  args = parser.parse_args()
31
 
@@ -35,29 +46,38 @@ asr, online = asr_factory(args)
35
  with open("src/live_transcription.html", "r") as f:
36
  html = f.read()
37
 
 
38
  @app.get("/")
39
  async def get():
40
  return HTMLResponse(html)
41
 
 
42
  SAMPLE_RATE = 16000
43
  CHANNELS = 1
44
  SAMPLES_PER_SEC = SAMPLE_RATE * int(args.min_chunk_size)
45
- BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample
46
  BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE
47
 
 
48
  async def start_ffmpeg_decoder():
49
  """
50
  Start an FFmpeg process in async streaming mode that reads WebM from stdin
51
  and outputs raw s16le PCM on stdout. Returns the process object.
52
  """
53
  process = (
54
- ffmpeg
55
- .input('pipe:0', format='webm')
56
- .output('pipe:1', format='s16le', acodec='pcm_s16le', ac=CHANNELS, ar=str(SAMPLE_RATE))
57
  .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True)
58
  )
59
  return process
60
 
 
61
  @app.websocket("/asr")
62
  async def websocket_endpoint(websocket: WebSocket):
63
  await websocket.accept()
@@ -65,6 +85,7 @@ async def websocket_endpoint(websocket: WebSocket):
65
 
66
  ffmpeg_process = await start_ffmpeg_decoder()
67
  pcm_buffer = bytearray()
 
68
  # Continuously read decoded PCM from ffmpeg stdout in a background task
69
  async def ffmpeg_stdout_reader():
70
  nonlocal pcm_buffer
@@ -75,10 +96,16 @@ async def websocket_endpoint(websocket: WebSocket):
75
  try:
76
  elapsed_time = int(time() - beg)
77
  beg = time()
78
- chunk = await loop.run_in_executor(None, ffmpeg_process.stdout.read, 32000*elapsed_time)
79
- if not chunk: # The first chunk will be almost empty, FFmpeg is still starting up
80
- chunk = await loop.run_in_executor(None, ffmpeg_process.stdout.read, 4096)
81
- if not chunk: # FFmpeg might have closed
82
  print("FFmpeg stdout closed.")
83
  break
84
 
@@ -86,21 +113,29 @@ async def websocket_endpoint(websocket: WebSocket):
86
 
87
  if len(pcm_buffer) >= BYTES_PER_SEC:
88
  # Convert int16 -> float32
89
- pcm_array = np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
90
  pcm_buffer = bytearray()
91
  online.insert_audio_chunk(pcm_array)
92
  transcription = online.process_iter()[2]
93
  full_transcription += transcription
94
  if args.vac:
95
- buffer = online.online.to_flush(online.online.transcript_buffer.buffer)[2] # We need to access the underlying online object to get the buffer
96
  else:
97
  buffer = online.to_flush(online.transcript_buffer.buffer)[2]
98
- if buffer in full_transcription: # With VAC, the buffer is not updated until the next chunk is processed
 
 
99
  buffer = ""
100
- await websocket.send_json({
101
- "transcription": transcription,
102
- "buffer": buffer
103
- })
104
  except Exception as e:
105
  print(f"Exception in ffmpeg_stdout_reader: {e}")
106
  break
@@ -135,8 +170,11 @@ async def websocket_endpoint(websocket: WebSocket):
135
  pass
136
 
137
  ffmpeg_process.wait()
138
-
139
-
140
  if __name__ == "__main__":
141
  import uvicorn
142
- uvicorn.run("whisper_fastapi_online_server:app", host=args.host, port=args.port, reload=True)
22
 
23
 
24
  parser = argparse.ArgumentParser(description="Whisper FastAPI Online Server")
25
+ parser.add_argument(
26
+ "--host",
27
+ type=str,
28
+ default="localhost",
29
+ help="The host address to bind the server to.",
30
+ )
31
+ parser.add_argument(
32
+ "--port", type=int, default=8000, help="The port number to bind the server to."
33
+ )
34
+ parser.add_argument(
35
+ "--warmup-file",
36
+ type=str,
37
+ dest="warmup_file",
38
+ help="The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. It can be e.g. https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav .",
39
+ )
40
  add_shared_args(parser)
41
  args = parser.parse_args()
42
 
 
46
  with open("src/live_transcription.html", "r") as f:
47
  html = f.read()
48
 
49
+
50
  @app.get("/")
51
  async def get():
52
  return HTMLResponse(html)
53
 
54
+
55
  SAMPLE_RATE = 16000
56
  CHANNELS = 1
57
  SAMPLES_PER_SEC = SAMPLE_RATE * int(args.min_chunk_size)
58
+ BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample
59
  BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE
60
 
61
+
62
  async def start_ffmpeg_decoder():
63
  """
64
  Start an FFmpeg process in async streaming mode that reads WebM from stdin
65
  and outputs raw s16le PCM on stdout. Returns the process object.
66
  """
67
  process = (
68
+ ffmpeg.input("pipe:0", format="webm")
69
+ .output(
70
+ "pipe:1",
71
+ format="s16le",
72
+ acodec="pcm_s16le",
73
+ ac=CHANNELS,
74
+ ar=str(SAMPLE_RATE),
75
+ )
76
  .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True)
77
  )
78
  return process
79
 
80
+
81
  @app.websocket("/asr")
82
  async def websocket_endpoint(websocket: WebSocket):
83
  await websocket.accept()
 
85
 
86
  ffmpeg_process = await start_ffmpeg_decoder()
87
  pcm_buffer = bytearray()
88
+
89
  # Continuously read decoded PCM from ffmpeg stdout in a background task
90
  async def ffmpeg_stdout_reader():
91
  nonlocal pcm_buffer
 
96
  try:
97
  elapsed_time = int(time() - beg)
98
  beg = time()
99
+ chunk = await loop.run_in_executor(
100
+ None, ffmpeg_process.stdout.read, 32000 * elapsed_time
101
+ )
102
+ if (
103
+ not chunk
104
+ ): # The first chunk will be almost empty, FFmpeg is still starting up
105
+ chunk = await loop.run_in_executor(
106
+ None, ffmpeg_process.stdout.read, 4096
107
+ )
108
+ if not chunk: # FFmpeg might have closed
109
  print("FFmpeg stdout closed.")
110
  break
111
 
 
113
 
114
  if len(pcm_buffer) >= BYTES_PER_SEC:
115
  # Convert int16 -> float32
116
+ pcm_array = (
117
+ np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32)
118
+ / 32768.0
119
+ )
120
  pcm_buffer = bytearray()
121
  online.insert_audio_chunk(pcm_array)
122
  transcription = online.process_iter()[2]
123
  full_transcription += transcription
124
  if args.vac:
125
+ buffer = online.online.to_flush(
126
+ online.online.transcript_buffer.buffer
127
+ )[
128
+ 2
129
+ ] # We need to access the underlying online object to get the buffer
130
  else:
131
  buffer = online.to_flush(online.transcript_buffer.buffer)[2]
132
+ if (
133
+ buffer in full_transcription
134
+ ): # With VAC, the buffer is not updated until the next chunk is processed
135
  buffer = ""
136
+ await websocket.send_json(
137
+ {"transcription": transcription, "buffer": buffer}
138
+ )
 
139
  except Exception as e:
140
  print(f"Exception in ffmpeg_stdout_reader: {e}")
141
  break
 
170
  pass
171
 
172
  ffmpeg_process.wait()
173
+
174
+
175
  if __name__ == "__main__":
176
  import uvicorn
177
+
178
+ uvicorn.run(
179
+ "whisper_fastapi_online_server:app", host=args.host, port=args.port, reload=True
180
+ )
whisper_online.py CHANGED
@@ -12,26 +12,31 @@ import math
12
 
13
  logger = logging.getLogger(__name__)
14
 
 
15
  @lru_cache(10**6)
16
  def load_audio(fname):
17
  a, _ = librosa.load(fname, sr=16000, dtype=np.float32)
18
  return a
19
 
 
20
  def load_audio_chunk(fname, beg, end):
21
  audio = load_audio(fname)
22
- beg_s = int(beg*16000)
23
- end_s = int(end*16000)
24
  return audio[beg_s:end_s]
25
 
26
 
27
  # Whisper backend
28
 
 
29
  class ASRBase:
30
 
31
- sep = " " # join transcribe words with this character (" " for whisper_timestamped,
32
- # "" for faster-whisper because it emits the spaces when neeeded)
33
 
34
- def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
 
 
35
  self.logfile = logfile
36
 
37
  self.transcribe_kargs = {}
@@ -42,7 +47,6 @@ class ASRBase:
42
 
43
  self.model = self.load_model(modelsize, cache_dir, model_dir)
44
 
45
-
46
  def load_model(self, modelsize, cache_dir):
47
  raise NotImplemented("must be implemented in the child class")
48
 
@@ -64,24 +68,30 @@ class WhisperTimestampedASR(ASRBase):
64
  import whisper
65
  import whisper_timestamped
66
  from whisper_timestamped import transcribe_timestamped
 
67
  self.transcribe_timestamped = transcribe_timestamped
68
  if model_dir is not None:
69
  logger.debug("ignoring model_dir, not implemented")
70
  return whisper.load_model(modelsize, download_root=cache_dir)
71
 
72
  def transcribe(self, audio, init_prompt=""):
73
- result = self.transcribe_timestamped(self.model,
74
- audio, language=self.original_language,
75
- initial_prompt=init_prompt, verbose=None,
76
- condition_on_previous_text=True, **self.transcribe_kargs)
77
  return result
78
-
79
- def ts_words(self,r):
80
  # return: transcribe result object to [(beg,end,"word1"), ...]
81
  o = []
82
  for s in r["segments"]:
83
  for w in s["words"]:
84
- t = (w["start"],w["end"],w["text"])
85
  o.append(t)
86
  return o
87
 
@@ -95,43 +105,55 @@ class WhisperTimestampedASR(ASRBase):
95
  self.transcribe_kargs["task"] = "translate"
96
 
97
 
98
-
99
-
100
  class FasterWhisperASR(ASRBase):
101
- """Uses faster-whisper library as the backend. Works much faster, appx 4-times (in offline mode). For GPU, it requires installation with a specific CUDNN version.
102
- """
103
 
104
  sep = ""
105
 
106
  def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
107
  from faster_whisper import WhisperModel
108
- # logging.getLogger("faster_whisper").setLevel(logger.level)
 
109
  if model_dir is not None:
110
- logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used.")
 
 
111
  model_size_or_path = model_dir
112
  elif modelsize is not None:
113
  model_size_or_path = modelsize
114
  else:
115
  raise ValueError("modelsize or model_dir parameter must be set")
116
 
117
-
118
  # this worked fast and reliably on NVIDIA L40
119
- model = WhisperModel(model_size_or_path, device="cuda", compute_type="float16", download_root=cache_dir)
120
 
121
  # or run on GPU with INT8
122
  # tested: the transcripts were different, probably worse than with FP16, and it was slightly (appx 20%) slower
123
- #model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
124
 
125
  # or run on CPU with INT8
126
  # tested: works, but slow, appx 10-times than cuda FP16
127
- # model = WhisperModel(modelsize, device="cpu", compute_type="int8") #, download_root="faster-disk-cache-dir/")
128
  return model
129
 
130
  def transcribe(self, audio, init_prompt=""):
131
 
132
  # tested: beam_size=5 is faster and better than 1 (on one 200 second document from En ESIC, min chunk 0.01)
133
- segments, info = self.model.transcribe(audio, language=self.original_language, initial_prompt=init_prompt, beam_size=5, word_timestamps=True, condition_on_previous_text=True, **self.transcribe_kargs)
134
- #print(info) # info contains language detection result
135
 
136
  return list(segments)
137
 
@@ -156,40 +178,45 @@ class FasterWhisperASR(ASRBase):
156
  def set_translate_task(self):
157
  self.transcribe_kargs["task"] = "translate"
158
 
 
159
  class MLXWhisper(ASRBase):
160
  """
161
  Uses MPX Whisper library as the backend, optimized for Apple Silicon.
162
  Models available: https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc
163
- Significantly faster than faster-whisper (without CUDA) on Apple M1.
164
  """
165
 
166
  sep = " "
167
 
168
  def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
169
  """
170
- Loads the MLX-compatible Whisper model.
171
-
172
- Args:
173
- modelsize (str, optional): The size or name of the Whisper model to load.
174
- If provided, it will be translated to an MLX-compatible model path using the `translate_model_name` method.
175
- Example: "large-v3-turbo" -> "mlx-community/whisper-large-v3-turbo".
176
- cache_dir (str, optional): Path to the directory for caching models.
177
- **Note**: This is not supported by MLX Whisper and will be ignored.
178
- model_dir (str, optional): Direct path to a custom model directory.
179
- If specified, it overrides the `modelsize` parameter.
180
  """
181
  from mlx_whisper import transcribe
182
 
183
  if model_dir is not None:
184
- logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.")
 
 
185
  model_size_or_path = model_dir
186
  elif modelsize is not None:
187
  model_size_or_path = self.translate_model_name(modelsize)
188
- logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used.")
189
-
 
 
190
  self.model_size_or_path = model_size_or_path
191
  return transcribe
192
-
193
  def translate_model_name(self, model_name):
194
  """
195
  Translates a given model name to its corresponding MLX-compatible model path.
@@ -214,7 +241,7 @@ class MLXWhisper(ASRBase):
214
  "large-v2": "mlx-community/whisper-large-v2-mlx",
215
  "large-v3": "mlx-community/whisper-large-v3-mlx",
216
  "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
217
- "large": "mlx-community/whisper-large-mlx"
218
  }
219
 
220
  # Retrieve the corresponding MLX model path
@@ -223,8 +250,10 @@ class MLXWhisper(ASRBase):
223
  if mlx_model_path:
224
  return mlx_model_path
225
  else:
226
- raise ValueError(f"Model name '{model_name}' is not recognized or not supported.")
227
-
 
 
228
  def transcribe(self, audio, init_prompt=""):
229
  segments = self.model(
230
  audio,
@@ -233,11 +262,10 @@ class MLXWhisper(ASRBase):
233
  word_timestamps=True,
234
  condition_on_previous_text=True,
235
  path_or_hf_repo=self.model_size_or_path,
236
- **self.transcribe_kargs
237
  )
238
  return segments.get("segments", [])
239
 
240
-
241
  def ts_words(self, segments):
242
  """
243
  Extract timestamped words from transcription segments and skips words with high no-speech probability.
@@ -248,9 +276,9 @@ class MLXWhisper(ASRBase):
248
  for word in segment.get("words", [])
249
  if segment.get("no_speech_prob", 0) <= 0.9
250
  ]
251
-
252
  def segments_end_ts(self, res):
253
- return [s['end'] for s in res]
254
 
255
  def use_vad(self):
256
  self.transcribe_kargs["vad_filter"] = True
@@ -258,15 +286,18 @@ class MLXWhisper(ASRBase):
258
  def set_translate_task(self):
259
  self.transcribe_kargs["task"] = "translate"
260
 
 
261
  class OpenaiApiASR(ASRBase):
262
  """Uses OpenAI's Whisper API for audio transcription."""
263
 
264
  def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
265
  self.logfile = logfile
266
 
267
- self.modelname = "whisper-1"
268
- self.original_language = None if lan == "auto" else lan # ISO-639-1 language code
269
- self.response_format = "verbose_json"
 
 
270
  self.temperature = temperature
271
 
272
  self.load_model()
@@ -278,10 +309,12 @@ class OpenaiApiASR(ASRBase):
278
 
279
  def load_model(self, *args, **kwargs):
280
  from openai import OpenAI
 
281
  self.client = OpenAI()
282
 
283
- self.transcribed_seconds = 0 # for logging how many seconds were processed by API, to know the cost
284
-
 
285
 
286
  def ts_words(self, segments):
287
  no_speech_segments = []
@@ -289,7 +322,9 @@ class OpenaiApiASR(ASRBase):
289
  for segment in segments.segments:
290
  # TODO: threshold can be set from outside
291
  if segment["no_speech_prob"] > 0.8:
292
- no_speech_segments.append((segment.get("start"), segment.get("end")))
 
 
293
 
294
  o = []
295
  for word in segments.words:
@@ -301,7 +336,6 @@ class OpenaiApiASR(ASRBase):
301
  o.append((start, end, word.word))
302
  return o
303
 
304
-
305
  def segments_end_ts(self, res):
306
  return [s.end for s in res.words]
307
 
@@ -309,17 +343,19 @@ class OpenaiApiASR(ASRBase):
309
  # Write the audio data to a buffer
310
  buffer = io.BytesIO()
311
  buffer.name = "temp.wav"
312
- sf.write(buffer, audio_data, samplerate=16000, format='WAV', subtype='PCM_16')
313
  buffer.seek(0) # Reset buffer's position to the beginning
314
 
315
- self.transcribed_seconds += math.ceil(len(audio_data)/16000) # it rounds up to the whole seconds
 
 
316
 
317
  params = {
318
  "model": self.modelname,
319
  "file": buffer,
320
  "response_format": self.response_format,
321
  "temperature": self.temperature,
322
- "timestamp_granularities": ["word", "segment"]
323
  }
324
  if self.task != "translate" and self.original_language:
325
  params["language"] = self.original_language
@@ -333,7 +369,9 @@ class OpenaiApiASR(ASRBase):
333
 
334
  # Process transcription/translation
335
  transcript = proc.create(**params)
336
- logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
 
 
337
 
338
  return transcript
339
 
@@ -344,8 +382,6 @@ class OpenaiApiASR(ASRBase):
344
  self.task = "translate"
345
 
346
 
347
-
348
-
349
  class HypothesisBuffer:
350
 
351
  def __init__(self, logfile=sys.stderr):
@@ -361,20 +397,24 @@ class HypothesisBuffer:
361
  def insert(self, new, offset):
362
  # compare self.commited_in_buffer and new. It inserts only the words in new that extend the commited_in_buffer, it means they are roughly behind last_commited_time and new in content
363
  # the new tail is added to self.new
364
-
365
- new = [(a+offset,b+offset,t) for a,b,t in new]
366
- self.new = [(a,b,t) for a,b,t in new if a > self.last_commited_time-0.1]
367
 
368
  if len(self.new) >= 1:
369
- a,b,t = self.new[0]
370
  if abs(a - self.last_commited_time) < 1:
371
  if self.commited_in_buffer:
372
  # it's going to search for 1, 2, ..., 5 consecutive words (n-grams) that are identical in commited and new. If they are, they're dropped.
373
  cn = len(self.commited_in_buffer)
374
  nn = len(self.new)
375
- for i in range(1,min(min(cn,nn),5)+1): # 5 is the maximum
376
- c = " ".join([self.commited_in_buffer[-j][2] for j in range(1,i+1)][::-1])
377
- tail = " ".join(self.new[j-1][2] for j in range(1,i+1))
378
  if c == tail:
379
  words = []
380
  for j in range(i):
@@ -384,7 +424,7 @@ class HypothesisBuffer:
384
  break
385
 
386
  def flush(self):
387
- # returns commited chunk = the longest common prefix of 2 last inserts.
388
 
389
  commit = []
390
  while self.new:
@@ -394,7 +434,7 @@ class HypothesisBuffer:
394
  break
395
 
396
  if nt == self.buffer[0][2]:
397
- commit.append((na,nb,nt))
398
  self.last_commited_word = nt
399
  self.last_commited_time = nb
400
  self.buffer.pop(0)
@@ -413,19 +453,26 @@ class HypothesisBuffer:
413
  def complete(self):
414
  return self.buffer
415
 
 
416
  class OnlineASRProcessor:
417
 
418
  SAMPLING_RATE = 16000
419
 
420
- def __init__(self, asr, tokenizer=None, buffer_trimming=("segment", 15), logfile=sys.stderr):
421
  """asr: WhisperASR object
422
- tokenizer: sentence tokenizer object for the target language. Must have a method *split* that behaves like the one of MosesTokenizer. It can be None, if "segment" buffer trimming option is used, then tokenizer is not used at all.
423
  ("segment", 15)
424
  buffer_trimming: a pair of (option, seconds), where option is either "sentence" or "segment", and seconds is a number. Buffer is trimmed if it is longer than "seconds" threshold. Default is the most recommended option.
425
- logfile: where to store the log.
426
  """
427
  self.asr = asr
428
- self.tokenizer = tokenizer
429
  self.logfile = logfile
430
 
431
  self.init()
@@ -434,7 +481,7 @@ class OnlineASRProcessor:
434
 
435
  def init(self, offset=None):
436
  """run this when starting or restarting processing"""
437
- self.audio_buffer = np.array([],dtype=np.float32)
438
  self.transcript_buffer = HypothesisBuffer(logfile=self.logfile)
439
  self.buffer_time_offset = 0
440
  if offset is not None:
@@ -446,34 +493,38 @@ class OnlineASRProcessor:
446
  self.audio_buffer = np.append(self.audio_buffer, audio)
447
 
448
  def prompt(self):
449
- """Returns a tuple: (prompt, context), where "prompt" is a 200-character suffix of commited text that is inside of the scrolled away part of audio buffer.
450
  "context" is the commited text that is inside the audio buffer. It is transcribed again and skipped. It is returned only for debugging and logging reasons.
451
  """
452
- k = max(0,len(self.commited)-1)
453
- while k > 0 and self.commited[k-1][1] > self.buffer_time_offset:
454
  k -= 1
455
 
456
  p = self.commited[:k]
457
- p = [t for _,_,t in p]
458
  prompt = []
459
  l = 0
460
  while p and l < 200: # 200 characters prompt size
461
  x = p.pop(-1)
462
- l += len(x)+1
463
  prompt.append(x)
464
  non_prompt = self.commited[k:]
465
- return self.asr.sep.join(prompt[::-1]), self.asr.sep.join(t for _,_,t in non_prompt)
 
 
466
 
467
  def process_iter(self):
468
  """Runs on the current audio buffer.
469
- Returns: a tuple (beg_timestamp, end_timestamp, "text"), or (None, None, "").
470
  The non-emty text is confirmed (committed) partial transcript.
471
  """
472
 
473
  prompt, non_prompt = self.prompt()
474
  logger.debug(f"PROMPT: {prompt}")
475
  logger.debug(f"CONTEXT: {non_prompt}")
476
- logger.debug(f"transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}")
 
 
477
  res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt)
478
 
479
  # transform to [(beg,end,"word1"), ...]
@@ -483,41 +534,45 @@ class OnlineASRProcessor:
483
  o = self.transcript_buffer.flush()
484
  self.commited.extend(o)
485
  completed = self.to_flush(o)
486
- logger.debug(f">>>>COMPLETE NOW: {completed}")
487
  the_rest = self.to_flush(self.transcript_buffer.complete())
488
- logger.debug(f"INCOMPLETE: {the_rest}")
489
 
490
  # there is a newly confirmed text
491
 
492
  if o and self.buffer_trimming_way == "sentence": # trim the completed sentences
493
- if len(self.audio_buffer)/self.SAMPLING_RATE > self.buffer_trimming_sec: # longer than this
 
 
494
  self.chunk_completed_sentence()
495
 
496
-
497
  if self.buffer_trimming_way == "segment":
498
  s = self.buffer_trimming_sec # trim the completed segments longer than s,
499
  else:
500
- s = 30 # if the audio buffer is longer than 30s, trim it
501
-
502
- if len(self.audio_buffer)/self.SAMPLING_RATE > s:
503
  self.chunk_completed_segment(res)
504
 
505
  # alternative: on any word
506
- #l = self.buffer_time_offset + len(self.audio_buffer)/self.SAMPLING_RATE - 10
507
  # let's find commited word that is less
508
- #k = len(self.commited)-1
509
- #while k>0 and self.commited[k][1] > l:
510
  # k -= 1
511
- #t = self.commited[k][1]
512
  logger.debug("chunking segment")
513
- #self.chunk_at(t)
514
 
515
- logger.debug(f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}")
 
 
516
  return self.to_flush(o)
517
 
518
  def chunk_completed_sentence(self):
519
- if self.commited == []: return
520
- logger.debug(self.commited)
 
521
  sents = self.words_to_sentences(self.commited)
522
  for s in sents:
523
  logger.debug(f"\t\tSENT: {s}")
@@ -532,7 +587,8 @@ class OnlineASRProcessor:
532
  self.chunk_at(chunk_at)
533
 
534
  def chunk_completed_segment(self, res):
535
- if self.commited == []: return
 
536
 
537
  ends = self.asr.segments_end_ts(res)
538
 
@@ -540,10 +596,10 @@ class OnlineASRProcessor:
540
 
541
  if len(ends) > 1:
542
 
543
- e = ends[-2]+self.buffer_time_offset
544
  while len(ends) > 2 and e > t:
545
  ends.pop(-1)
546
- e = ends[-2]+self.buffer_time_offset
547
  if e <= t:
548
  logger.debug(f"--- segment chunked at {e:2.2f}")
549
  self.chunk_at(e)
@@ -552,26 +608,21 @@ class OnlineASRProcessor:
552
  else:
553
  logger.debug(f"--- not enough segments to chunk")
554
 
555
-
556
-
557
-
558
-
559
  def chunk_at(self, time):
560
- """trims the hypothesis and audio buffer at "time"
561
- """
562
  self.transcript_buffer.pop_commited(time)
563
  cut_seconds = time - self.buffer_time_offset
564
- self.audio_buffer = self.audio_buffer[int(cut_seconds*self.SAMPLING_RATE):]
565
  self.buffer_time_offset = time
566
 
567
  def words_to_sentences(self, words):
568
- """Uses self.tokenizer for sentence segmentation of words.
569
  Returns: [(beg,end,"sentence 1"),...]
570
  """
571
-
572
  cwords = [w for w in words]
573
  t = " ".join(o[2] for o in cwords)
574
- s = self.tokenizer.split(t)
575
  out = []
576
  while s:
577
  beg = None
@@ -579,15 +630,15 @@ class OnlineASRProcessor:
579
  sent = s.pop(0).strip()
580
  fsent = sent
581
  while cwords:
582
- b,e,w = cwords.pop(0)
583
  w = w.strip()
584
  if beg is None and sent.startswith(w):
585
  beg = b
586
  elif end is None and sent == w:
587
  end = e
588
- out.append((beg,end,fsent))
589
  break
590
- sent = sent[len(w):].strip()
591
  return out
592
 
593
  def finish(self):
@@ -597,11 +648,15 @@ class OnlineASRProcessor:
597
  o = self.transcript_buffer.complete()
598
  f = self.to_flush(o)
599
  logger.debug(f"last, noncommited: {f}")
600
- self.buffer_time_offset += len(self.audio_buffer)/16000
601
  return f
602
 
603
-
604
- def to_flush(self, sents, sep=None, offset=0, ):
605
  # concatenates the timestamped words or sentences into one sequence that is flushed in one line
606
  # sents: [(beg1, end1, "sentence1"), ...] or [] if empty
607
  # return: (beg1,end-of-last-sentence,"concatenation of sentences") or (None, None, "") if empty
@@ -614,15 +669,16 @@ class OnlineASRProcessor:
614
  else:
615
  b = offset + sents[0][0]
616
  e = offset + sents[-1][1]
617
- return (b,e,t)
 
618
 
619
  class VACOnlineASRProcessor(OnlineASRProcessor):
620
- '''Wraps OnlineASRProcessor with VAC (Voice Activity Controller).
621
 
622
- It works the same way as OnlineASRProcessor: it receives chunks of audio (e.g. 0.04 seconds),
623
- it runs VAD and continuously detects whether there is speech or not.
624
  When it detects end of speech (non-voice for 500ms), it makes OnlineASRProcessor to end the utterance immediately.
625
- '''
626
 
627
  def __init__(self, online_chunk_size, *a, **kw):
628
  self.online_chunk_size = online_chunk_size
@@ -631,12 +687,13 @@ class VACOnlineASRProcessor(OnlineASRProcessor):
631
 
632
  # VAC:
633
  import torch
634
- model, _ = torch.hub.load(
635
- repo_or_dir='snakers4/silero-vad',
636
- model='silero_vad'
637
- )
638
  from silero_vad_iterator import FixedVADIterator
639
- self.vac = FixedVADIterator(model) # we use the default options there: 500ms silence, 100ms padding, etc.
640
 
641
  self.logfile = self.online.logfile
642
  self.init()
@@ -649,60 +706,65 @@ class VACOnlineASRProcessor(OnlineASRProcessor):
649
  self.is_currently_final = False
650
 
651
  self.status = None # or "voice" or "nonvoice"
652
- self.audio_buffer = np.array([],dtype=np.float32)
653
  self.buffer_offset = 0 # in frames
654
 
655
  def clear_buffer(self):
656
  self.buffer_offset += len(self.audio_buffer)
657
- self.audio_buffer = np.array([],dtype=np.float32)
658
-
659
 
660
  def insert_audio_chunk(self, audio):
661
  res = self.vac(audio)
662
  self.audio_buffer = np.append(self.audio_buffer, audio)
663
 
664
  if res is not None:
665
- frame = list(res.values())[0]-self.buffer_offset
666
- if 'start' in res and 'end' not in res:
667
- self.status = 'voice'
668
  send_audio = self.audio_buffer[frame:]
669
- self.online.init(offset=(frame+self.buffer_offset)/self.SAMPLING_RATE)
 
 
670
  self.online.insert_audio_chunk(send_audio)
671
  self.current_online_chunk_buffer_size += len(send_audio)
672
  self.clear_buffer()
673
- elif 'end' in res and 'start' not in res:
674
- self.status = 'nonvoice'
675
  send_audio = self.audio_buffer[:frame]
676
  self.online.insert_audio_chunk(send_audio)
677
  self.current_online_chunk_buffer_size += len(send_audio)
678
  self.is_currently_final = True
679
  self.clear_buffer()
680
  else:
681
- beg = res["start"]-self.buffer_offset
682
- end = res["end"]-self.buffer_offset
683
- self.status = 'nonvoice'
684
  send_audio = self.audio_buffer[beg:end]
685
- self.online.init(offset=(beg+self.buffer_offset)/self.SAMPLING_RATE)
686
  self.online.insert_audio_chunk(send_audio)
687
  self.current_online_chunk_buffer_size += len(send_audio)
688
  self.is_currently_final = True
689
  self.clear_buffer()
690
  else:
691
- if self.status == 'voice':
692
  self.online.insert_audio_chunk(self.audio_buffer)
693
  self.current_online_chunk_buffer_size += len(self.audio_buffer)
694
  self.clear_buffer()
695
  else:
696
  # We keep 1 second because VAD may later find start of voice in it.
697
- # But we trim it to prevent OOM.
698
- self.buffer_offset += max(0,len(self.audio_buffer)-self.SAMPLING_RATE)
699
- self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE:]
700
-
 
701
 
702
  def process_iter(self):
703
  if self.is_currently_final:
704
  return self.finish()
705
- elif self.current_online_chunk_buffer_size > self.SAMPLING_RATE*self.online_chunk_size:
706
  self.current_online_chunk_buffer_size = 0
707
  ret = self.online.process_iter()
708
  return ret
@@ -717,37 +779,55 @@ class VACOnlineASRProcessor(OnlineASRProcessor):
717
  return ret
718
 
719
 
 
 
 
720
 
721
- WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(",")
722
 
723
  def create_tokenizer(lan):
724
  """returns an object that has split function that works like the one of MosesTokenizer"""
725
 
726
- assert lan in WHISPER_LANG_CODES, "language must be Whisper's supported lang code: " + " ".join(WHISPER_LANG_CODES)
 
 
727
 
728
  if lan == "uk":
729
  import tokenize_uk
 
730
  class UkrainianTokenizer:
731
  def split(self, text):
732
  return tokenize_uk.tokenize_sents(text)
 
733
  return UkrainianTokenizer()
734
 
735
  # supported by fast-mosestokenizer
736
- if lan in "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split():
 
 
 
737
  from mosestokenizer import MosesTokenizer
 
738
  return MosesTokenizer(lan)
739
 
740
  # the following languages are in Whisper, but not in wtpsplit:
741
- if lan in "as ba bo br bs fo haw hr ht jw lb ln lo mi nn oc sa sd sn so su sw tk tl tt".split():
742
- logger.debug(f"{lan} code is not supported by wtpsplit. Going to use None lang_code option.")
743
  lan = None
744
 
745
  from wtpsplit import WtP
 
746
  # downloads the model from huggingface on the first use
747
  wtp = WtP("wtp-canine-s-12l-no-adapters")
 
748
  class WtPtok:
749
  def split(self, sent):
750
  return wtp.split(sent, lang_code=lan)
 
751
  return WtPtok()
752
 
753
 
@@ -755,19 +835,91 @@ def add_shared_args(parser):
755
  """shared args for simulation (this entry point) and server
756
  parser: argparse.ArgumentParser object
757
  """
758
- parser.add_argument('--min-chunk-size', type=float, default=1.0, help='Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.')
759
- parser.add_argument('--model', type=str, default='large-v2', choices="tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo".split(","),help="Name size of the Whisper model to use (default: large-v2). The model is automatically downloaded from the model hub if not present in model cache dir.")
760
- parser.add_argument('--model_cache_dir', type=str, default=None, help="Overriding the default model cache dir where models downloaded from the hub are saved")
761
- parser.add_argument('--model_dir', type=str, default=None, help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.")
762
- parser.add_argument('--lan', '--language', type=str, default='auto', help="Source language code, e.g. en,de,cs, or 'auto' for language detection.")
763
- parser.add_argument('--task', type=str, default='transcribe', choices=["transcribe","translate"],help="Transcribe or translate.")
764
- parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],help='Load only this backend for Whisper processing.')
765
- parser.add_argument('--vac', action="store_true", default=False, help='Use VAC = voice activity controller. Recommended. Requires torch.')
766
- parser.add_argument('--vac-chunk-size', type=float, default=0.04, help='VAC sample size in seconds.')
767
- parser.add_argument('--vad', action="store_true", default=False, help='Use VAD = voice activity detection, with the default parameters.')
768
- parser.add_argument('--buffer_trimming', type=str, default="segment", choices=["sentence", "segment"],help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.')
769
- parser.add_argument('--buffer_trimming_sec', type=float, default=15, help='Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.')
770
- parser.add_argument("-l", "--log-level", dest="log_level", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help="Set the log level", default='DEBUG')
771
 
772
  def asr_factory(args, logfile=sys.stderr):
773
  """
@@ -789,12 +941,17 @@ def asr_factory(args, logfile=sys.stderr):
789
  size = args.model
790
  t = time.time()
791
  logger.info(f"Loading Whisper {size} model for {args.lan}...")
792
- asr = asr_cls(modelsize=size, lan=args.lan, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
793
  e = time.time()
794
  logger.info(f"done. It took {round(e-t,2)} seconds.")
795
 
796
  # Apply common configurations
797
- if getattr(args, 'vad', False): # Checks if VAD argument is present and True
798
  logger.info("Setting VAD filter")
799
  asr.use_vad()
800
 
@@ -813,51 +970,82 @@ def asr_factory(args, logfile=sys.stderr):
813
 
814
  # Create the OnlineASRProcessor
815
  if args.vac:
816
-
817
- online = VACOnlineASRProcessor(args.min_chunk_size, asr,tokenizer,logfile=logfile,buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec))
818
  else:
819
- online = OnlineASRProcessor(asr,tokenizer,logfile=logfile,buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec))
820
 
821
  return asr, online
822
 
823
- def set_logging(args,logger,other="_server"):
824
- logging.basicConfig(#format='%(name)s
825
- format='%(levelname)s\t%(message)s')
826
  logger.setLevel(args.log_level)
827
- logging.getLogger("whisper_online"+other).setLevel(args.log_level)
828
- # logging.getLogger("whisper_online_server").setLevel(args.log_level)
829
 
830
 
 
 
831
 
832
  if __name__ == "__main__":
833
 
834
  import argparse
 
835
  parser = argparse.ArgumentParser()
836
- parser.add_argument('audio_path', type=str, help="Filename of 16kHz mono channel wav, on which live streaming is simulated.")
837
  add_shared_args(parser)
838
- parser.add_argument('--start_at', type=float, default=0.0, help='Start processing audio at this time.')
839
- parser.add_argument('--offline', action="store_true", default=False, help='Offline mode.')
840
- parser.add_argument('--comp_unaware', action="store_true", default=False, help='Computationally unaware simulation.')
841
-
842
  args = parser.parse_args()
843
 
844
  # reset to store stderr to different file stream, e.g. open(os.devnull,"w")
845
  logfile = sys.stderr
846
 
847
  if args.offline and args.comp_unaware:
848
- logger.error("No or one option from --offline and --comp_unaware are available, not both. Exiting.")
 
 
849
  sys.exit(1)
850
 
851
- # if args.log_level:
852
- # logging.basicConfig(format='whisper-%(levelname)s:%(name)s: %(message)s',
853
- # level=getattr(logging, args.log_level))
854
 
855
- set_logging(args,logger)
856
 
857
  audio_path = args.audio_path
858
 
859
  SAMPLING_RATE = 16000
860
- duration = len(load_audio(audio_path))/SAMPLING_RATE
861
  logger.info("Audio duration is: %2.2f seconds" % duration)
862
 
863
  asr, online = asr_factory(args, logfile=logfile)
@@ -867,13 +1055,13 @@ if __name__ == "__main__":
867
  min_chunk = args.min_chunk_size
868
 
869
  # load the audio into the LRU cache before we start the timer
870
- a = load_audio_chunk(audio_path,0,1)
871
 
872
  # warm up the ASR because the very first transcribe takes much more time than the other
873
  asr.transcribe(a)
874
 
875
  beg = args.start_at
876
- start = time.time()-beg
877
 
878
  def output_transcript(o, now=None):
879
  # output format in stdout is like:
@@ -883,15 +1071,22 @@ if __name__ == "__main__":
883
  # - beg and end timestamp of the text segment, as estimated by Whisper model. The timestamps are not accurate, but they're useful anyway
884
  # - the next words: segment transcript
885
  if now is None:
886
- now = time.time()-start
887
  if o[0] is not None:
888
- print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),file=logfile,flush=True)
889
- print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),flush=True)
890
  else:
891
  # No text, so no output
892
  pass
893
 
894
- if args.offline: ## offline mode processing (for testing/debugging)
895
  a = load_audio(audio_path)
896
  online.insert_audio_chunk(a)
897
  try:
@@ -901,10 +1096,10 @@ if __name__ == "__main__":
901
  else:
902
  output_transcript(o)
903
  now = None
904
- elif args.comp_unaware: # computational unaware mode
905
  end = beg + min_chunk
906
  while True:
907
- a = load_audio_chunk(audio_path,beg,end)
908
  online.insert_audio_chunk(a)
909
  try:
910
  o = online.process_iter()
@@ -918,23 +1113,23 @@ if __name__ == "__main__":
918
 
919
  if end >= duration:
920
  break
921
-
922
  beg = end
923
-
924
  if end + min_chunk > duration:
925
  end = duration
926
  else:
927
  end += min_chunk
928
  now = duration
929
 
930
- else: # online = simultaneous mode
931
  end = 0
932
  while True:
933
  now = time.time() - start
934
- if now < end+min_chunk:
935
- time.sleep(min_chunk+end-now)
936
  end = time.time() - start
937
- a = load_audio_chunk(audio_path,beg,end)
938
  beg = end
939
  online.insert_audio_chunk(a)
940
 
@@ -946,7 +1141,9 @@ if __name__ == "__main__":
946
  else:
947
  output_transcript(o)
948
  now = time.time() - start
949
- logger.debug(f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}")
 
 
950
 
951
  if end >= duration:
952
  break
 
12
 
13
  logger = logging.getLogger(__name__)
14
 
15
+
16
  @lru_cache(10**6)
17
  def load_audio(fname):
18
  a, _ = librosa.load(fname, sr=16000, dtype=np.float32)
19
  return a
20
 
21
+
22
  def load_audio_chunk(fname, beg, end):
23
  audio = load_audio(fname)
24
+ beg_s = int(beg * 16000)
25
+ end_s = int(end * 16000)
26
  return audio[beg_s:end_s]
27
 
28
 
29
  # Whisper backend
30
 
31
+
32
  class ASRBase:
33
 
34
+ sep = " " # join transcribe words with this character (" " for whisper_timestamped,
35
+ # "" for faster-whisper because it emits the spaces when neeeded)
36
 
37
+ def __init__(
38
+ self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr
39
+ ):
40
  self.logfile = logfile
41
 
42
  self.transcribe_kargs = {}
 
47
 
48
  self.model = self.load_model(modelsize, cache_dir, model_dir)
49
 
 
50
  def load_model(self, modelsize, cache_dir):
51
  raise NotImplemented("must be implemented in the child class")
52
 
 
68
  import whisper
69
  import whisper_timestamped
70
  from whisper_timestamped import transcribe_timestamped
71
+
72
  self.transcribe_timestamped = transcribe_timestamped
73
  if model_dir is not None:
74
  logger.debug("ignoring model_dir, not implemented")
75
  return whisper.load_model(modelsize, download_root=cache_dir)
76
 
77
  def transcribe(self, audio, init_prompt=""):
78
+ result = self.transcribe_timestamped(
79
+ self.model,
80
+ audio,
81
+ language=self.original_language,
82
+ initial_prompt=init_prompt,
83
+ verbose=None,
84
+ condition_on_previous_text=True,
85
+ **self.transcribe_kargs,
86
+ )
87
  return result
88
+
89
+ def ts_words(self, r):
90
  # return: transcribe result object to [(beg,end,"word1"), ...]
91
  o = []
92
  for s in r["segments"]:
93
  for w in s["words"]:
94
+ t = (w["start"], w["end"], w["text"])
95
  o.append(t)
96
  return o
97
 
 
105
  self.transcribe_kargs["task"] = "translate"
106
 
107
 
 
 
108
  class FasterWhisperASR(ASRBase):
109
+ """Uses faster-whisper library as the backend. Works much faster, appx 4-times (in offline mode). For GPU, it requires installation with a specific CUDNN version."""
 
110
 
111
  sep = ""
112
 
113
  def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
114
  from faster_whisper import WhisperModel
115
+
116
+ # logging.getLogger("faster_whisper").setLevel(logger.level)
117
  if model_dir is not None:
118
+ logger.debug(
119
+ f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used."
120
+ )
121
  model_size_or_path = model_dir
122
  elif modelsize is not None:
123
  model_size_or_path = modelsize
124
  else:
125
  raise ValueError("modelsize or model_dir parameter must be set")
126
 
 
127
  # this worked fast and reliably on NVIDIA L40
128
+ model = WhisperModel(
129
+ model_size_or_path,
130
+ device="cuda",
131
+ compute_type="float16",
132
+ download_root=cache_dir,
133
+ )
134
 
135
  # or run on GPU with INT8
136
  # tested: the transcripts were different, probably worse than with FP16, and it was slightly (appx 20%) slower
137
+ # model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
138
 
139
  # or run on CPU with INT8
140
  # tested: works, but slow, appx 10-times than cuda FP16
141
+ # model = WhisperModel(modelsize, device="cpu", compute_type="int8") #, download_root="faster-disk-cache-dir/")
142
  return model
143
 
144
  def transcribe(self, audio, init_prompt=""):
145
 
146
  # tested: beam_size=5 is faster and better than 1 (on one 200 second document from En ESIC, min chunk 0.01)
147
+ segments, info = self.model.transcribe(
148
+ audio,
149
+ language=self.original_language,
150
+ initial_prompt=init_prompt,
151
+ beam_size=5,
152
+ word_timestamps=True,
153
+ condition_on_previous_text=True,
154
+ **self.transcribe_kargs,
155
+ )
156
+ # print(info) # info contains language detection result
157
 
158
  return list(segments)
159
 
 
178
  def set_translate_task(self):
179
  self.transcribe_kargs["task"] = "translate"
180
 
181
+
182
  class MLXWhisper(ASRBase):
183
  """
184
  Uses MPX Whisper library as the backend, optimized for Apple Silicon.
185
  Models available: https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc
186
+ Significantly faster than faster-whisper (without CUDA) on Apple M1.
187
  """
188
 
189
  sep = " "
190
 
191
  def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
192
  """
193
+ Loads the MLX-compatible Whisper model.
194
+
195
+ Args:
196
+ modelsize (str, optional): The size or name of the Whisper model to load.
197
+ If provided, it will be translated to an MLX-compatible model path using the `translate_model_name` method.
198
+ Example: "large-v3-turbo" -> "mlx-community/whisper-large-v3-turbo".
199
+ cache_dir (str, optional): Path to the directory for caching models.
200
+ **Note**: This is not supported by MLX Whisper and will be ignored.
201
+ model_dir (str, optional): Direct path to a custom model directory.
202
+ If specified, it overrides the `modelsize` parameter.
203
  """
204
  from mlx_whisper import transcribe
205
 
206
  if model_dir is not None:
207
+ logger.debug(
208
+ f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used."
209
+ )
210
  model_size_or_path = model_dir
211
  elif modelsize is not None:
212
  model_size_or_path = self.translate_model_name(modelsize)
213
+ logger.debug(
214
+ f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used."
215
+ )
216
+
217
  self.model_size_or_path = model_size_or_path
218
  return transcribe
219
+
220
  def translate_model_name(self, model_name):
221
  """
222
  Translates a given model name to its corresponding MLX-compatible model path.
 
241
  "large-v2": "mlx-community/whisper-large-v2-mlx",
242
  "large-v3": "mlx-community/whisper-large-v3-mlx",
243
  "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
244
+ "large": "mlx-community/whisper-large-mlx",
245
  }
246
 
247
  # Retrieve the corresponding MLX model path
 
250
  if mlx_model_path:
251
  return mlx_model_path
252
  else:
253
+ raise ValueError(
254
+ f"Model name '{model_name}' is not recognized or not supported."
255
+ )
256
+
257
  def transcribe(self, audio, init_prompt=""):
258
  segments = self.model(
259
  audio,
 
262
  word_timestamps=True,
263
  condition_on_previous_text=True,
264
  path_or_hf_repo=self.model_size_or_path,
265
+ **self.transcribe_kargs,
266
  )
267
  return segments.get("segments", [])
268
 
 
269
  def ts_words(self, segments):
270
  """
271
  Extract timestamped words from transcription segments and skips words with high no-speech probability.
 
276
  for word in segment.get("words", [])
277
  if segment.get("no_speech_prob", 0) <= 0.9
278
  ]
279
+
280
  def segments_end_ts(self, res):
281
+ return [s["end"] for s in res]
282
 
283
  def use_vad(self):
284
  self.transcribe_kargs["vad_filter"] = True
 
286
  def set_translate_task(self):
287
  self.transcribe_kargs["task"] = "translate"
288
 
289
+
290
  class OpenaiApiASR(ASRBase):
291
  """Uses OpenAI's Whisper API for audio transcription."""
292
 
293
  def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
294
  self.logfile = logfile
295
 
296
+ self.modelname = "whisper-1"
297
+ self.original_language = (
298
+ None if lan == "auto" else lan
299
+ ) # ISO-639-1 language code
300
+ self.response_format = "verbose_json"
301
  self.temperature = temperature
302
 
303
  self.load_model()
 
309
 
310
  def load_model(self, *args, **kwargs):
311
  from openai import OpenAI
312
+
313
  self.client = OpenAI()
314
 
315
+ self.transcribed_seconds = (
316
+ 0 # for logging how many seconds were processed by API, to know the cost
317
+ )
318
 
319
  def ts_words(self, segments):
320
  no_speech_segments = []
 
322
  for segment in segments.segments:
323
  # TODO: threshold can be set from outside
324
  if segment["no_speech_prob"] > 0.8:
325
+ no_speech_segments.append(
326
+ (segment.get("start"), segment.get("end"))
327
+ )
328
 
329
  o = []
330
  for word in segments.words:
 
336
  o.append((start, end, word.word))
337
  return o
338
 
 
339
  def segments_end_ts(self, res):
340
  return [s.end for s in res.words]
341
 
 
343
  # Write the audio data to a buffer
344
  buffer = io.BytesIO()
345
  buffer.name = "temp.wav"
346
+ sf.write(buffer, audio_data, samplerate=16000, format="WAV", subtype="PCM_16")
347
  buffer.seek(0) # Reset buffer's position to the beginning
348
 
349
+ self.transcribed_seconds += math.ceil(
350
+ len(audio_data) / 16000
351
+ ) # it rounds up to the whole seconds
352
 
353
  params = {
354
  "model": self.modelname,
355
  "file": buffer,
356
  "response_format": self.response_format,
357
  "temperature": self.temperature,
358
+ "timestamp_granularities": ["word", "segment"],
359
  }
360
  if self.task != "translate" and self.original_language:
361
  params["language"] = self.original_language
 
369
 
370
  # Process transcription/translation
371
  transcript = proc.create(**params)
372
+ logger.debug(
373
+ f"OpenAI API processed accumulated {self.transcribed_seconds} seconds"
374
+ )
375
 
376
  return transcript
377
 
 
382
  self.task = "translate"
383
 
384
 
 
 
385
  class HypothesisBuffer:
386
 
387
  def __init__(self, logfile=sys.stderr):
 
397
  def insert(self, new, offset):
398
  # compare self.commited_in_buffer and new. It inserts only the words in new that extend the commited_in_buffer, it means they are roughly behind last_commited_time and new in content
399
  # the new tail is added to self.new
400
+
401
+ new = [(a + offset, b + offset, t) for a, b, t in new]
402
+ self.new = [(a, b, t) for a, b, t in new if a > self.last_commited_time - 0.1]
403
 
404
  if len(self.new) >= 1:
405
+ a, b, t = self.new[0]
406
  if abs(a - self.last_commited_time) < 1:
407
  if self.commited_in_buffer:
408
  # it's going to search for 1, 2, ..., 5 consecutive words (n-grams) that are identical in commited and new. If they are, they're dropped.
409
  cn = len(self.commited_in_buffer)
410
  nn = len(self.new)
411
+ for i in range(1, min(min(cn, nn), 5) + 1): # 5 is the maximum
412
+ c = " ".join(
413
+ [self.commited_in_buffer[-j][2] for j in range(1, i + 1)][
414
+ ::-1
415
+ ]
416
+ )
417
+ tail = " ".join(self.new[j - 1][2] for j in range(1, i + 1))
418
  if c == tail:
419
  words = []
420
  for j in range(i):
 
424
  break
425
 
426
  def flush(self):
427
+ # Returns the committed chunk = the longest common prefix of the 2 last inserts.
428
 
429
  commit = []
430
  while self.new:
 
434
  break
435
 
436
  if nt == self.buffer[0][2]:
437
+ commit.append((na, nb, nt))
438
  self.last_commited_word = nt
439
  self.last_commited_time = nb
440
  self.buffer.pop(0)
 
453
  def complete(self):
454
  return self.buffer
455
 
456
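
A tiny illustration of the insert/flush contract sketched in the comments above (timestamps are made up; it assumes HypothesisBuffer as defined in this file):

h = HypothesisBuffer()

# First hypothesis: there is no previous insert to agree with, so nothing is confirmed.
h.insert([(0.0, 0.4, "hello"), (0.5, 0.9, "word")], offset=0.0)
print(h.flush())   # []

# Second hypothesis agrees on the first word only, so only that word is committed.
h.insert([(0.0, 0.4, "hello"), (0.5, 0.9, "world")], offset=0.0)
print(h.flush())   # [(0.0, 0.4, 'hello')]
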
+
457
  class OnlineASRProcessor:
458
 
459
  SAMPLING_RATE = 16000
460
 
461
+ def __init__(
462
+ self,
463
+ asr,
464
+ tokenize_method=None,
465
+ buffer_trimming=("segment", 15),
466
+ logfile=sys.stderr,
467
+ ):
468
  """asr: WhisperASR object
469
+ tokenize_method: sentence tokenizer function for the target language. It must be a callable that behaves like the split method of MosesTokenizer. It can be None; if the "segment" buffer trimming option is used, the tokenizer is not used at all.
470
  ("segment", 15)
471
  buffer_trimming: a pair of (option, seconds), where option is either "sentence" or "segment", and seconds is a number. The buffer is trimmed if it is longer than the "seconds" threshold. The default, ("segment", 15) as shown above, is the most recommended option.
472
+ logfile: where to store the log.
473
  """
474
  self.asr = asr
475
+ self.tokenize = tokenize_method
476
  self.logfile = logfile
477
 
478
  self.init()
 
481
 
482
  def init(self, offset=None):
483
  """run this when starting or restarting processing"""
484
+ self.audio_buffer = np.array([], dtype=np.float32)
485
  self.transcript_buffer = HypothesisBuffer(logfile=self.logfile)
486
  self.buffer_time_offset = 0
487
  if offset is not None:
 
493
  self.audio_buffer = np.append(self.audio_buffer, audio)
494
 
495
  def prompt(self):
496
+ """Returns a tuple: (prompt, context), where "prompt" is a 200-character suffix of commited text that is inside of the scrolled away part of audio buffer.
497
  "context" is the commited text that is inside the audio buffer. It is transcribed again and skipped. It is returned only for debugging and logging reasons.
498
  """
499
+ k = max(0, len(self.commited) - 1)
500
+ while k > 0 and self.commited[k - 1][1] > self.buffer_time_offset:
501
  k -= 1
502
 
503
  p = self.commited[:k]
504
+ p = [t for _, _, t in p]
505
  prompt = []
506
  l = 0
507
  while p and l < 200: # 200 characters prompt size
508
  x = p.pop(-1)
509
+ l += len(x) + 1
510
  prompt.append(x)
511
  non_prompt = self.commited[k:]
512
+ return self.asr.sep.join(prompt[::-1]), self.asr.sep.join(
513
+ t for _, _, t in non_prompt
514
+ )
515
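
A worked example of the prompt/context split (values are made up, and asr.sep is assumed to be a single space):

# Suppose:
#   self.commited           == [(0.0, 0.5, "Hello"), (0.6, 1.0, "world"), (5.1, 5.5, "again")]
#   self.buffer_time_offset == 5.0
# "Hello" and "world" end before the buffer offset (already scrolled away), while "again"
# is still inside the audio buffer, so prompt() returns:
#   ("Hello world", "again")
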
 
516
  def process_iter(self):
517
  """Runs on the current audio buffer.
518
+ Returns: a tuple (beg_timestamp, end_timestamp, "text"), or (None, None, "").
519
  The non-empty text is the confirmed (committed) partial transcript.
520
  """
521
 
522
  prompt, non_prompt = self.prompt()
523
  logger.debug(f"PROMPT: {prompt}")
524
  logger.debug(f"CONTEXT: {non_prompt}")
525
+ logger.debug(
526
+ f"transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}"
527
+ )
528
  res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt)
529
 
530
  # transform to [(beg,end,"word1"), ...]
 
534
  o = self.transcript_buffer.flush()
535
  self.commited.extend(o)
536
  completed = self.to_flush(o)
537
+ logger.debug(f">>>>COMPLETE NOW: {completed[2]}")
538
  the_rest = self.to_flush(self.transcript_buffer.complete())
539
+ logger.debug(f"INCOMPLETE: {the_rest[2]}")
540
 
541
  # there is a newly confirmed text
542
 
543
  if o and self.buffer_trimming_way == "sentence": # trim the completed sentences
544
+ if (
545
+ len(self.audio_buffer) / self.SAMPLING_RATE > self.buffer_trimming_sec
546
+ ): # longer than this
547
  self.chunk_completed_sentence()
548
 
 
549
  if self.buffer_trimming_way == "segment":
550
  s = self.buffer_trimming_sec # trim the completed segments longer than s,
551
  else:
552
+ s = 30 # if the audio buffer is longer than 30s, trim it
553
+
554
+ if len(self.audio_buffer) / self.SAMPLING_RATE > s:
555
  self.chunk_completed_segment(res)
556
 
557
  # alternative: on any word
558
+ # l = self.buffer_time_offset + len(self.audio_buffer)/self.SAMPLING_RATE - 10
559
  # let's find commited word that is less
560
+ # k = len(self.commited)-1
561
+ # while k>0 and self.commited[k][1] > l:
562
  # k -= 1
563
+ # t = self.commited[k][1]
564
  logger.debug("chunking segment")
565
+ # self.chunk_at(t)
566
 
567
+ logger.debug(
568
+ f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}"
569
+ )
570
  return self.to_flush(o)
571
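
A minimal streaming-loop sketch around process_iter (it assumes an `asr` backend built elsewhere in this file; `audio_chunks`, an iterable of 16 kHz float32 numpy arrays, is illustrative):

online = OnlineASRProcessor(asr, buffer_trimming=("segment", 15))
for chunk in audio_chunks:
    online.insert_audio_chunk(chunk)
    beg, end, text = online.process_iter()
    if text:                      # non-empty text means newly confirmed transcript
        print(f"{beg:.2f}-{end:.2f} {text}")
beg, end, text = online.finish()  # flush whatever is still unconfirmed at the end
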
 
572
  def chunk_completed_sentence(self):
573
+ if self.commited == []:
574
+ return
575
+ logger.debug("COMPLETED SENTENCE: ", [s[2] for s in self.commited])
576
  sents = self.words_to_sentences(self.commited)
577
  for s in sents:
578
  logger.debug(f"\t\tSENT: {s}")
 
587
  self.chunk_at(chunk_at)
588
 
589
  def chunk_completed_segment(self, res):
590
+ if self.commited == []:
591
+ return
592
 
593
  ends = self.asr.segments_end_ts(res)
594
 
 
596
 
597
  if len(ends) > 1:
598
 
599
+ e = ends[-2] + self.buffer_time_offset
600
  while len(ends) > 2 and e > t:
601
  ends.pop(-1)
602
+ e = ends[-2] + self.buffer_time_offset
603
  if e <= t:
604
  logger.debug(f"--- segment chunked at {e:2.2f}")
605
  self.chunk_at(e)
 
608
  else:
609
  logger.debug(f"--- not enough segments to chunk")
610
 
 
 
 
 
611
  def chunk_at(self, time):
612
+ """trims the hypothesis and audio buffer at "time" """
 
613
  self.transcript_buffer.pop_commited(time)
614
  cut_seconds = time - self.buffer_time_offset
615
+ self.audio_buffer = self.audio_buffer[int(cut_seconds * self.SAMPLING_RATE) :]
616
  self.buffer_time_offset = time
617
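
A quick arithmetic illustration of chunk_at (numbers are made up):

# With buffer_time_offset == 10.0, calling chunk_at(12.5) means:
#   cut_seconds = 12.5 - 10.0 = 2.5
#   the first int(2.5 * 16000) = 40000 samples are dropped from audio_buffer,
#   committed words up to 12.5 s are popped from the hypothesis buffer,
#   and buffer_time_offset becomes 12.5.
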
 
618
  def words_to_sentences(self, words):
619
+ """Uses self.tokenize for sentence segmentation of words.
620
  Returns: [(beg,end,"sentence 1"),...]
621
  """
622
+
623
  cwords = [w for w in words]
624
  t = " ".join(o[2] for o in cwords)
625
+ s = self.tokenize(t)
626
  out = []
627
  while s:
628
  beg = None
 
630
  sent = s.pop(0).strip()
631
  fsent = sent
632
  while cwords:
633
+ b, e, w = cwords.pop(0)
634
  w = w.strip()
635
  if beg is None and sent.startswith(w):
636
  beg = b
637
  elif end is None and sent == w:
638
  end = e
639
+ out.append((beg, end, fsent))
640
  break
641
+ sent = sent[len(w) :].strip()
642
  return out
643
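
An illustrative input/output pair for words_to_sentences (it assumes a tokenize_method that splits on sentence-final punctuation):

words = [(0.0, 0.3, "Hello"), (0.4, 0.7, "world."),
         (0.9, 1.2, "How"), (1.3, 1.5, "are"), (1.6, 1.9, "you?")]
# with tokenize("Hello world. How are you?") == ["Hello world.", "How are you?"],
# words_to_sentences(words) returns:
#   [(0.0, 0.7, "Hello world."), (0.9, 1.9, "How are you?")]
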
 
644
  def finish(self):
 
648
  o = self.transcript_buffer.complete()
649
  f = self.to_flush(o)
650
  logger.debug(f"last, noncommited: {f}")
651
+ self.buffer_time_offset += len(self.audio_buffer) / self.SAMPLING_RATE
652
  return f
653
 
654
+ def to_flush(
655
+ self,
656
+ sents,
657
+ sep=None,
658
+ offset=0,
659
+ ):
660
  # concatenates the timestamped words or sentences into one sequence that is flushed in one line
661
  # sents: [(beg1, end1, "sentence1"), ...] or [] if empty
662
  # return: (beg1,end-of-last-sentence,"concatenation of sentences") or (None, None, "") if empty
 
669
  else:
670
  b = offset + sents[0][0]
671
  e = offset + sents[-1][1]
672
+ return (b, e, t)
673
+
674
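
For example (assuming asr.sep == " "):

# to_flush([(2.0, 2.4, "Hello"), (2.5, 3.1, "world.")])  ->  (2.0, 3.1, "Hello world.")
# to_flush([])                                           ->  (None, None, "")
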
 
675
  class VACOnlineASRProcessor(OnlineASRProcessor):
676
+ """Wraps OnlineASRProcessor with VAC (Voice Activity Controller).
677
 
678
+ It works the same way as OnlineASRProcessor: it receives chunks of audio (e.g. 0.04 seconds),
679
+ it runs VAD and continuously detects whether there is speech or not.
680
  When it detects the end of speech (non-voice for 500ms), it makes OnlineASRProcessor end the utterance immediately.
681
+ """
682
 
683
  def __init__(self, online_chunk_size, *a, **kw):
684
  self.online_chunk_size = online_chunk_size
 
687
 
688
  # VAC:
689
  import torch
690
+
691
+ model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
 
 
692
  from silero_vad_iterator import FixedVADIterator
693
+
694
+ self.vac = FixedVADIterator(
695
+ model
696
+ ) # we use the default options there: 500ms silence, 100ms padding, etc.
697
 
698
  self.logfile = self.online.logfile
699
  self.init()
 
706
  self.is_currently_final = False
707
 
708
  self.status = None # or "voice" or "nonvoice"
709
+ self.audio_buffer = np.array([], dtype=np.float32)
710
  self.buffer_offset = 0 # in frames
711
 
712
  def clear_buffer(self):
713
  self.buffer_offset += len(self.audio_buffer)
714
+ self.audio_buffer = np.array([], dtype=np.float32)
 
715
 
716
  def insert_audio_chunk(self, audio):
717
  res = self.vac(audio)
718
  self.audio_buffer = np.append(self.audio_buffer, audio)
719
 
720
  if res is not None:
721
+ frame = list(res.values())[0] - self.buffer_offset
722
+ if "start" in res and "end" not in res:
723
+ self.status = "voice"
724
  send_audio = self.audio_buffer[frame:]
725
+ self.online.init(
726
+ offset=(frame + self.buffer_offset) / self.SAMPLING_RATE
727
+ )
728
  self.online.insert_audio_chunk(send_audio)
729
  self.current_online_chunk_buffer_size += len(send_audio)
730
  self.clear_buffer()
731
+ elif "end" in res and "start" not in res:
732
+ self.status = "nonvoice"
733
  send_audio = self.audio_buffer[:frame]
734
  self.online.insert_audio_chunk(send_audio)
735
  self.current_online_chunk_buffer_size += len(send_audio)
736
  self.is_currently_final = True
737
  self.clear_buffer()
738
  else:
739
+ beg = res["start"] - self.buffer_offset
740
+ end = res["end"] - self.buffer_offset
741
+ self.status = "nonvoice"
742
  send_audio = self.audio_buffer[beg:end]
743
+ self.online.init(offset=(beg + self.buffer_offset) / self.SAMPLING_RATE)
744
  self.online.insert_audio_chunk(send_audio)
745
  self.current_online_chunk_buffer_size += len(send_audio)
746
  self.is_currently_final = True
747
  self.clear_buffer()
748
  else:
749
+ if self.status == "voice":
750
  self.online.insert_audio_chunk(self.audio_buffer)
751
  self.current_online_chunk_buffer_size += len(self.audio_buffer)
752
  self.clear_buffer()
753
  else:
754
  # We keep 1 second because VAD may later find start of voice in it.
755
+ # But we trim it to prevent OOM.
756
+ self.buffer_offset += max(
757
+ 0, len(self.audio_buffer) - self.SAMPLING_RATE
758
+ )
759
+ self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE :]
760
 
761
  def process_iter(self):
762
  if self.is_currently_final:
763
  return self.finish()
764
+ elif (
765
+ self.current_online_chunk_buffer_size
766
+ > self.SAMPLING_RATE * self.online_chunk_size
767
+ ):
768
  self.current_online_chunk_buffer_size = 0
769
  ret = self.online.process_iter()
770
  return ret
 
779
  return ret
780
 
781
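
A minimal sketch of driving the VAC wrapper (it assumes an `asr` backend from this file; `vac_chunks`, an iterable of 0.04 s float32 chunks at 16 kHz, is illustrative):

online = VACOnlineASRProcessor(1.0, asr, None, buffer_trimming=("segment", 15))
for chunk in vac_chunks:
    online.insert_audio_chunk(chunk)        # VAD decides what reaches the inner processor
    beg, end, text = online.process_iter()  # the utterance is finalized when voice ends
    if text:
        print(f"{beg:.2f}-{end:.2f} {text}")
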
 
782
+ WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(
783
+ ","
784
+ )
785
 
 
786
 
787
  def create_tokenizer(lan):
788
  """returns an object that has split function that works like the one of MosesTokenizer"""
789
 
790
+ assert (
791
+ lan in WHISPER_LANG_CODES
792
+ ), "language must be Whisper's supported lang code: " + " ".join(WHISPER_LANG_CODES)
793
 
794
  if lan == "uk":
795
  import tokenize_uk
796
+
797
  class UkrainianTokenizer:
798
  def split(self, text):
799
  return tokenize_uk.tokenize_sents(text)
800
+
801
  return UkrainianTokenizer()
802
 
803
  # supported by fast-mosestokenizer
804
+ if (
805
+ lan
806
+ in "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split()
807
+ ):
808
  from mosestokenizer import MosesTokenizer
809
+
810
  return MosesTokenizer(lan)
811
 
812
  # the following languages are in Whisper, but not in wtpsplit:
813
+ if (
814
+ lan
815
+ in "as ba bo br bs fo haw hr ht jw lb ln lo mi nn oc sa sd sn so su sw tk tl tt".split()
816
+ ):
817
+ logger.debug(
818
+ f"{lan} code is not supported by wtpsplit. Going to use None lang_code option."
819
+ )
820
  lan = None
821
 
822
  from wtpsplit import WtP
823
+
824
  # downloads the model from huggingface on the first use
825
  wtp = WtP("wtp-canine-s-12l-no-adapters")
826
+
827
  class WtPtok:
828
  def split(self, sent):
829
  return wtp.split(sent, lang_code=lan)
830
+
831
  return WtPtok()
832
 
833
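
Since tokenize_method only needs to be a callable that maps a text to a list of sentences, the stock tokenizer can be adapted as below; the lambda alternative is purely illustrative, and `asr` is assumed to be a backend from this file:

# the bundled tokenizer for a Whisper language code ...
tokenizer = create_tokenizer("en")
online = OnlineASRProcessor(asr, tokenize_method=tokenizer.split,
                            buffer_trimming=("sentence", 15))

# ... or any custom sentence splitter
online = OnlineASRProcessor(asr, tokenize_method=lambda text: text.split("\n"),
                            buffer_trimming=("sentence", 15))
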
 
 
835
  """shared args for simulation (this entry point) and server
836
  parser: argparse.ArgumentParser object
837
  """
838
+ parser.add_argument(
839
+ "--min-chunk-size",
840
+ type=float,
841
+ default=1.0,
842
+ help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.",
843
+ )
844
+ parser.add_argument(
845
+ "--model",
846
+ type=str,
847
+ default="large-v2",
848
+ choices="tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo".split(
849
+ ","
850
+ ),
851
+ help="Name size of the Whisper model to use (default: large-v2). The model is automatically downloaded from the model hub if not present in model cache dir.",
852
+ )
853
+ parser.add_argument(
854
+ "--model_cache_dir",
855
+ type=str,
856
+ default=None,
857
+ help="Overriding the default model cache dir where models downloaded from the hub are saved",
858
+ )
859
+ parser.add_argument(
860
+ "--model_dir",
861
+ type=str,
862
+ default=None,
863
+ help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.",
864
+ )
865
+ parser.add_argument(
866
+ "--lan",
867
+ "--language",
868
+ type=str,
869
+ default="auto",
870
+ help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
871
+ )
872
+ parser.add_argument(
873
+ "--task",
874
+ type=str,
875
+ default="transcribe",
876
+ choices=["transcribe", "translate"],
877
+ help="Transcribe or translate.",
878
+ )
879
+ parser.add_argument(
880
+ "--backend",
881
+ type=str,
882
+ default="faster-whisper",
883
+ choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],
884
+ help="Load only this backend for Whisper processing.",
885
+ )
886
+ parser.add_argument(
887
+ "--vac",
888
+ action="store_true",
889
+ default=False,
890
+ help="Use VAC = voice activity controller. Recommended. Requires torch.",
891
+ )
892
+ parser.add_argument(
893
+ "--vac-chunk-size", type=float, default=0.04, help="VAC sample size in seconds."
894
+ )
895
+ parser.add_argument(
896
+ "--vad",
897
+ action="store_true",
898
+ default=False,
899
+ help="Use VAD = voice activity detection, with the default parameters.",
900
+ )
901
+ parser.add_argument(
902
+ "--buffer_trimming",
903
+ type=str,
904
+ default="segment",
905
+ choices=["sentence", "segment"],
906
+ help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.',
907
+ )
908
+ parser.add_argument(
909
+ "--buffer_trimming_sec",
910
+ type=float,
911
+ default=15,
912
+ help="Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.",
913
+ )
914
+ parser.add_argument(
915
+ "-l",
916
+ "--log-level",
917
+ dest="log_level",
918
+ choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
919
+ help="Set the log level",
920
+ default="DEBUG",
921
+ )
922
+
923
 
924
  def asr_factory(args, logfile=sys.stderr):
925
  """
 
941
  size = args.model
942
  t = time.time()
943
  logger.info(f"Loading Whisper {size} model for {args.lan}...")
944
+ asr = asr_cls(
945
+ modelsize=size,
946
+ lan=args.lan,
947
+ cache_dir=args.model_cache_dir,
948
+ model_dir=args.model_dir,
949
+ )
950
  e = time.time()
951
  logger.info(f"done. It took {round(e-t,2)} seconds.")
952
 
953
  # Apply common configurations
954
+ if getattr(args, "vad", False): # Checks if VAD argument is present and True
955
  logger.info("Setting VAD filter")
956
  asr.use_vad()
957
 
 
970
 
971
  # Create the OnlineASRProcessor
972
  if args.vac:
973
+
974
+ online = VACOnlineASRProcessor(
975
+ args.min_chunk_size,
976
+ asr,
977
+ tokenizer,
978
+ logfile=logfile,
979
+ buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
980
+ )
981
  else:
982
+ online = OnlineASRProcessor(
983
+ asr,
984
+ tokenizer,
985
+ logfile=logfile,
986
+ buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
987
+ )
988
 
989
  return asr, online
990
 
991
+
992
+ def set_logging(args, logger, other="_server"):
993
+ logging.basicConfig(format="%(levelname)s\t%(message)s") # format='%(name)s
994
  logger.setLevel(args.log_level)
995
+ logging.getLogger("whisper_online" + other).setLevel(args.log_level)
 
996
 
997
 
998
+ # logging.getLogger("whisper_online_server").setLevel(args.log_level)
999
+
1000
 
1001
  if __name__ == "__main__":
1002
 
1003
  import argparse
1004
+
1005
  parser = argparse.ArgumentParser()
1006
+ parser.add_argument(
1007
+ "audio_path",
1008
+ type=str,
1009
+ help="Filename of 16kHz mono channel wav, on which live streaming is simulated.",
1010
+ )
1011
  add_shared_args(parser)
1012
+ parser.add_argument(
1013
+ "--start_at",
1014
+ type=float,
1015
+ default=0.0,
1016
+ help="Start processing audio at this time.",
1017
+ )
1018
+ parser.add_argument(
1019
+ "--offline", action="store_true", default=False, help="Offline mode."
1020
+ )
1021
+ parser.add_argument(
1022
+ "--comp_unaware",
1023
+ action="store_true",
1024
+ default=False,
1025
+ help="Computationally unaware simulation.",
1026
+ )
1027
+
1028
  args = parser.parse_args()
1029
 
1030
  # reset to store stderr to different file stream, e.g. open(os.devnull,"w")
1031
  logfile = sys.stderr
1032
 
1033
  if args.offline and args.comp_unaware:
1034
+ logger.error(
1035
+ "No or one option from --offline and --comp_unaware are available, not both. Exiting."
1036
+ )
1037
  sys.exit(1)
1038
 
1039
+ # if args.log_level:
1040
+ # logging.basicConfig(format='whisper-%(levelname)s:%(name)s: %(message)s',
1041
+ # level=getattr(logging, args.log_level))
1042
 
1043
+ set_logging(args, logger)
1044
 
1045
  audio_path = args.audio_path
1046
 
1047
  SAMPLING_RATE = 16000
1048
+ duration = len(load_audio(audio_path)) / SAMPLING_RATE
1049
  logger.info("Audio duration is: %2.2f seconds" % duration)
1050
 
1051
  asr, online = asr_factory(args, logfile=logfile)
 
1055
  min_chunk = args.min_chunk_size
1056
 
1057
  # load the audio into the LRU cache before we start the timer
1058
+ a = load_audio_chunk(audio_path, 0, 1)
1059
 
1060
  # warm up the ASR because the very first transcribe takes much more time than the other
1061
  asr.transcribe(a)
1062
 
1063
  beg = args.start_at
1064
+ start = time.time() - beg
1065
 
1066
  def output_transcript(o, now=None):
1067
  # output format in stdout is like:
 
1071
  # - beg and end timestamp of the text segment, as estimated by Whisper model. The timestamps are not accurate, but they're useful anyway
1072
  # - the next words: segment transcript
1073
  if now is None:
1074
+ now = time.time() - start
1075
  if o[0] is not None:
1076
+ print(
1077
+ "%1.4f %1.0f %1.0f %s" % (now * 1000, o[0] * 1000, o[1] * 1000, o[2]),
1078
+ file=logfile,
1079
+ flush=True,
1080
+ )
1081
+ print(
1082
+ "%1.4f %1.0f %1.0f %s" % (now * 1000, o[0] * 1000, o[1] * 1000, o[2]),
1083
+ flush=True,
1084
+ )
1085
  else:
1086
  # No text, so no output
1087
  pass
1088
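
An illustrative stdout line in this format (all numbers are made up):

# 4499.2342 1380 2640 Hello world
# i.e. emitted 4499 ms after processing started; the text "Hello world" spans 1380-2640 ms of the audio.
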
 
1089
+ if args.offline: ## offline mode processing (for testing/debugging)
1090
  a = load_audio(audio_path)
1091
  online.insert_audio_chunk(a)
1092
  try:
 
1096
  else:
1097
  output_transcript(o)
1098
  now = None
1099
+ elif args.comp_unaware: # computationally unaware mode
1100
  end = beg + min_chunk
1101
  while True:
1102
+ a = load_audio_chunk(audio_path, beg, end)
1103
  online.insert_audio_chunk(a)
1104
  try:
1105
  o = online.process_iter()
 
1113
 
1114
  if end >= duration:
1115
  break
1116
+
1117
  beg = end
1118
+
1119
  if end + min_chunk > duration:
1120
  end = duration
1121
  else:
1122
  end += min_chunk
1123
  now = duration
1124
 
1125
+ else: # online = simultaneous mode
1126
  end = 0
1127
  while True:
1128
  now = time.time() - start
1129
+ if now < end + min_chunk:
1130
+ time.sleep(min_chunk + end - now)
1131
  end = time.time() - start
1132
+ a = load_audio_chunk(audio_path, beg, end)
1133
  beg = end
1134
  online.insert_audio_chunk(a)
1135
 
 
1141
  else:
1142
  output_transcript(o)
1143
  now = time.time() - start
1144
+ logger.debug(
1145
+ f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}"
1146
+ )
1147
 
1148
  if end >= duration:
1149
  break