SilasKieser committed
Commit 5fdb08e · 1 Parent(s): 4cb3660

black formatting

silero_vad_iterator.py CHANGED
@@ -6,15 +6,16 @@ import torch
6
 
7
  # Their licence is MIT, same as ours: https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/LICENSE
8
 
9
- class VADIterator:
10
- def __init__(self,
11
- model,
12
- threshold: float = 0.5,
13
- sampling_rate: int = 16000,
14
- min_silence_duration_ms: int = 500, # makes sense on one recording that I checked
15
- speech_pad_ms: int = 100 # same
16
- ):
17
18
  """
19
  Class for stream imitation
20
 
@@ -41,7 +42,9 @@ class VADIterator:
41
  self.sampling_rate = sampling_rate
42
 
43
  if sampling_rate not in [8000, 16000]:
44
- raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')
 
 
45
 
46
  self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
47
  self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
@@ -80,7 +83,13 @@ class VADIterator:
80
  if (speech_prob >= self.threshold) and not self.triggered:
81
  self.triggered = True
82
  speech_start = self.current_sample - self.speech_pad_samples
83
- return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}
84
 
85
  if (speech_prob < self.threshold - 0.15) and self.triggered:
86
  if not self.temp_end:
@@ -91,26 +100,35 @@ class VADIterator:
91
  speech_end = self.temp_end + self.speech_pad_samples
92
  self.temp_end = 0
93
  self.triggered = False
94
- return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)}
95
 
96
  return None
97
 
 
98
  #######################
99
- # because Silero now requires exactly 512-sized audio chunks
100
 
101
  import numpy as np
 
 
102
  class FixedVADIterator(VADIterator):
103
- '''It fixes VADIterator by allowing to process any audio length, not only exactly 512 frames at once.
104
- If audio to be processed at once is long and multiple voiced segments detected,
105
- then __call__ returns the start of the first segment, and end (or middle, which means no end) of the last segment.
106
- '''
107
 
108
  def reset_states(self):
109
  super().reset_states()
110
- self.buffer = np.array([],dtype=np.float32)
111
 
112
  def __call__(self, x, return_seconds=False):
113
- self.buffer = np.append(self.buffer, x)
114
  ret = None
115
  while len(self.buffer) >= 512:
116
  r = super().__call__(self.buffer[:512], return_seconds=return_seconds)
@@ -118,29 +136,28 @@ class FixedVADIterator(VADIterator):
118
  if ret is None:
119
  ret = r
120
  elif r is not None:
121
- if 'end' in r:
122
- ret['end'] = r['end'] # the latter end
123
- if 'start' in r and 'end' in ret: # there is an earlier start.
124
  # Remove end, merging this segment with the previous one.
125
- del ret['end']
126
  return ret if ret != {} else None
127
 
 
128
  if __name__ == "__main__":
129
  # test/demonstrate the need for FixedVADIterator:
130
 
131
  import torch
132
- model, _ = torch.hub.load(
133
- repo_or_dir='snakers4/silero-vad',
134
- model='silero_vad'
135
- )
136
  vac = FixedVADIterator(model)
137
- # vac = VADIterator(model) # the second case crashes with this
138
 
139
  # this works: for both
140
- audio_buffer = np.array([0]*(512),dtype=np.float32)
141
  vac(audio_buffer)
142
 
143
- # this crashes on the non FixedVADIterator with
144
  # ops.prim.RaiseException("Input audio chunk is too short", "builtins.ValueError")
145
- audio_buffer = np.array([0]*(512-1),dtype=np.float32)
146
  vac(audio_buffer)
 
6
 
7
  # Their licence is MIT, same as ours: https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/LICENSE
8
9
 
10
+ class VADIterator:
11
+ def __init__(
12
+ self,
13
+ model,
14
+ threshold: float = 0.5,
15
+ sampling_rate: int = 16000,
16
+ min_silence_duration_ms: int = 500, # makes sense on one recording that I checked
17
+ speech_pad_ms: int = 100, # same
18
+ ):
19
  """
20
  Class for stream imitation
21
 
 
42
  self.sampling_rate = sampling_rate
43
 
44
  if sampling_rate not in [8000, 16000]:
45
+ raise ValueError(
46
+ "VADIterator does not support sampling rates other than [8000, 16000]"
47
+ )
48
 
49
  self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
50
  self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
 
83
  if (speech_prob >= self.threshold) and not self.triggered:
84
  self.triggered = True
85
  speech_start = self.current_sample - self.speech_pad_samples
86
+ return {
87
+ "start": (
88
+ int(speech_start)
89
+ if not return_seconds
90
+ else round(speech_start / self.sampling_rate, 1)
91
+ )
92
+ }
93
 
94
  if (speech_prob < self.threshold - 0.15) and self.triggered:
95
  if not self.temp_end:
 
100
  speech_end = self.temp_end + self.speech_pad_samples
101
  self.temp_end = 0
102
  self.triggered = False
103
+ return {
104
+ "end": (
105
+ int(speech_end)
106
+ if not return_seconds
107
+ else round(speech_end / self.sampling_rate, 1)
108
+ )
109
+ }
110
 
111
  return None
112
 
113
+
114
  #######################
115
+ # because Silero now requires exactly 512-sized audio chunks
116
 
117
  import numpy as np
118
+
119
+
120
  class FixedVADIterator(VADIterator):
121
+ """It fixes VADIterator by allowing to process any audio length, not only exactly 512 frames at once.
122
+ If audio to be processed at once is long and multiple voiced segments detected,
123
+ then __call__ returns the start of the first segment, and end (or middle, which means no end) of the last segment.
124
+ """
125
 
126
  def reset_states(self):
127
  super().reset_states()
128
+ self.buffer = np.array([], dtype=np.float32)
129
 
130
  def __call__(self, x, return_seconds=False):
131
+ self.buffer = np.append(self.buffer, x)
132
  ret = None
133
  while len(self.buffer) >= 512:
134
  r = super().__call__(self.buffer[:512], return_seconds=return_seconds)
 
136
  if ret is None:
137
  ret = r
138
  elif r is not None:
139
+ if "end" in r:
140
+ ret["end"] = r["end"] # the latter end
141
+ if "start" in r and "end" in ret: # there is an earlier start.
142
  # Remove end, merging this segment with the previous one.
143
+ del ret["end"]
144
  return ret if ret != {} else None
145
 
146
+
147
  if __name__ == "__main__":
148
  # test/demonstrate the need for FixedVADIterator:
149
 
150
  import torch
151
+
152
+ model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
 
 
153
  vac = FixedVADIterator(model)
154
+ # vac = VADIterator(model) # the second case crashes with this
155
 
156
  # this works: for both
157
+ audio_buffer = np.array([0] * (512), dtype=np.float32)
158
  vac(audio_buffer)
159
 
160
+ # this crashes on the non FixedVADIterator with
161
  # ops.prim.RaiseException("Input audio chunk is too short", "builtins.ValueError")
162
+ audio_buffer = np.array([0] * (512 - 1), dtype=np.float32)
163
  vac(audio_buffer)
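
A note on the reformatted FixedVADIterator above: it accepts audio of any length by appending incoming samples to a numpy buffer and feeding the wrapped VADIterator exact 512-sample windows, merging the start/end events it gets back. The following is a minimal standalone sketch of just that buffering step (an illustration only, not part of this commit, with hypothetical names):

    import numpy as np

    # Sketch of the chunking FixedVADIterator does internally: arbitrary-length
    # input is buffered and consumed in exact 512-sample windows.
    buffer = np.array([], dtype=np.float32)

    def feed(samples, window=512):
        # Append samples, then yield full windows; the remainder stays buffered.
        global buffer
        buffer = np.append(buffer, samples)
        while len(buffer) >= window:
            yield buffer[:window]
            buffer = buffer[window:]

    # 700 samples give one full 512-sample window; the remaining 188 stay buffered.
    chunks = list(feed(np.zeros(700, dtype=np.float32)))
    print(len(chunks), len(buffer))  # 1 188
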
whisper_fastapi_online_server.py CHANGED
@@ -22,10 +22,21 @@ app.add_middleware(
22
 
23
 
24
  parser = argparse.ArgumentParser(description="Whisper FastAPI Online Server")
25
- parser.add_argument("--host", type=str, default='localhost', help="The host address to bind the server to.")
26
- parser.add_argument("--port", type=int, default=8000, help="The port number to bind the server to.")
27
- parser.add_argument("--warmup-file", type=str, dest="warmup_file",
28
- help="The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. It can be e.g. https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav .")
 
 
 
 
 
 
 
 
 
 
 
29
  add_shared_args(parser)
30
  args = parser.parse_args()
31
 
@@ -35,29 +46,38 @@ asr, online = asr_factory(args)
35
  with open("src/live_transcription.html", "r") as f:
36
  html = f.read()
37
 
 
38
  @app.get("/")
39
  async def get():
40
  return HTMLResponse(html)
41
 
 
42
  SAMPLE_RATE = 16000
43
  CHANNELS = 1
44
  SAMPLES_PER_SEC = SAMPLE_RATE * int(args.min_chunk_size)
45
- BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample
46
  BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE
47
 
 
48
  async def start_ffmpeg_decoder():
49
  """
50
  Start an FFmpeg process in async streaming mode that reads WebM from stdin
51
  and outputs raw s16le PCM on stdout. Returns the process object.
52
  """
53
  process = (
54
- ffmpeg
55
- .input('pipe:0', format='webm')
56
- .output('pipe:1', format='s16le', acodec='pcm_s16le', ac=CHANNELS, ar=str(SAMPLE_RATE))
 
 
 
 
 
57
  .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True)
58
  )
59
  return process
60
 
 
61
  @app.websocket("/asr")
62
  async def websocket_endpoint(websocket: WebSocket):
63
  await websocket.accept()
@@ -65,6 +85,7 @@ async def websocket_endpoint(websocket: WebSocket):
65
 
66
  ffmpeg_process = await start_ffmpeg_decoder()
67
  pcm_buffer = bytearray()
 
68
  # Continuously read decoded PCM from ffmpeg stdout in a background task
69
  async def ffmpeg_stdout_reader():
70
  nonlocal pcm_buffer
@@ -75,10 +96,16 @@ async def websocket_endpoint(websocket: WebSocket):
75
  try:
76
  elapsed_time = int(time() - beg)
77
  beg = time()
78
- chunk = await loop.run_in_executor(None, ffmpeg_process.stdout.read, 32000*elapsed_time)
79
- if not chunk: # The first chunk will be almost empty, FFmpeg is still starting up
80
- chunk = await loop.run_in_executor(None, ffmpeg_process.stdout.read, 4096)
81
- if not chunk: # FFmpeg might have closed
82
  print("FFmpeg stdout closed.")
83
  break
84
 
@@ -86,21 +113,29 @@ async def websocket_endpoint(websocket: WebSocket):
86
 
87
  if len(pcm_buffer) >= BYTES_PER_SEC:
88
  # Convert int16 -> float32
89
- pcm_array = np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
 
 
 
90
  pcm_buffer = bytearray()
91
  online.insert_audio_chunk(pcm_array)
92
  transcription = online.process_iter()[2]
93
  full_transcription += transcription
94
  if args.vac:
95
- buffer = online.online.to_flush(online.online.transcript_buffer.buffer)[2] # We need to access the underlying online object to get the buffer
 
 
 
 
96
  else:
97
  buffer = online.to_flush(online.transcript_buffer.buffer)[2]
98
- if buffer in full_transcription: # With VAC, the buffer is not updated until the next chunk is processed
 
 
99
  buffer = ""
100
- await websocket.send_json({
101
- "transcription": transcription,
102
- "buffer": buffer
103
- })
104
  except Exception as e:
105
  print(f"Exception in ffmpeg_stdout_reader: {e}")
106
  break
@@ -135,8 +170,11 @@ async def websocket_endpoint(websocket: WebSocket):
135
  pass
136
 
137
  ffmpeg_process.wait()
138
-
139
-
140
  if __name__ == "__main__":
141
  import uvicorn
142
- uvicorn.run("whisper_fastapi_online_server:app", host=args.host, port=args.port, reload=True)
 
 
 
 
22
 
23
 
24
  parser = argparse.ArgumentParser(description="Whisper FastAPI Online Server")
25
+ parser.add_argument(
26
+ "--host",
27
+ type=str,
28
+ default="localhost",
29
+ help="The host address to bind the server to.",
30
+ )
31
+ parser.add_argument(
32
+ "--port", type=int, default=8000, help="The port number to bind the server to."
33
+ )
34
+ parser.add_argument(
35
+ "--warmup-file",
36
+ type=str,
37
+ dest="warmup_file",
38
+ help="The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. It can be e.g. https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav .",
39
+ )
40
  add_shared_args(parser)
41
  args = parser.parse_args()
42
 
 
46
  with open("src/live_transcription.html", "r") as f:
47
  html = f.read()
48
 
49
+
50
  @app.get("/")
51
  async def get():
52
  return HTMLResponse(html)
53
 
54
+
55
  SAMPLE_RATE = 16000
56
  CHANNELS = 1
57
  SAMPLES_PER_SEC = SAMPLE_RATE * int(args.min_chunk_size)
58
+ BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample
59
  BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE
60
 
61
+
62
  async def start_ffmpeg_decoder():
63
  """
64
  Start an FFmpeg process in async streaming mode that reads WebM from stdin
65
  and outputs raw s16le PCM on stdout. Returns the process object.
66
  """
67
  process = (
68
+ ffmpeg.input("pipe:0", format="webm")
69
+ .output(
70
+ "pipe:1",
71
+ format="s16le",
72
+ acodec="pcm_s16le",
73
+ ac=CHANNELS,
74
+ ar=str(SAMPLE_RATE),
75
+ )
76
  .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True)
77
  )
78
  return process
79
 
80
+
81
  @app.websocket("/asr")
82
  async def websocket_endpoint(websocket: WebSocket):
83
  await websocket.accept()
 
85
 
86
  ffmpeg_process = await start_ffmpeg_decoder()
87
  pcm_buffer = bytearray()
88
+
89
  # Continuously read decoded PCM from ffmpeg stdout in a background task
90
  async def ffmpeg_stdout_reader():
91
  nonlocal pcm_buffer
 
96
  try:
97
  elapsed_time = int(time() - beg)
98
  beg = time()
99
+ chunk = await loop.run_in_executor(
100
+ None, ffmpeg_process.stdout.read, 32000 * elapsed_time
101
+ )
102
+ if (
103
+ not chunk
104
+ ): # The first chunk will be almost empty, FFmpeg is still starting up
105
+ chunk = await loop.run_in_executor(
106
+ None, ffmpeg_process.stdout.read, 4096
107
+ )
108
+ if not chunk: # FFmpeg might have closed
109
  print("FFmpeg stdout closed.")
110
  break
111
 
 
113
 
114
  if len(pcm_buffer) >= BYTES_PER_SEC:
115
  # Convert int16 -> float32
116
+ pcm_array = (
117
+ np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32)
118
+ / 32768.0
119
+ )
120
  pcm_buffer = bytearray()
121
  online.insert_audio_chunk(pcm_array)
122
  transcription = online.process_iter()[2]
123
  full_transcription += transcription
124
  if args.vac:
125
+ buffer = online.online.to_flush(
126
+ online.online.transcript_buffer.buffer
127
+ )[
128
+ 2
129
+ ] # We need to access the underlying online object to get the buffer
130
  else:
131
  buffer = online.to_flush(online.transcript_buffer.buffer)[2]
132
+ if (
133
+ buffer in full_transcription
134
+ ): # With VAC, the buffer is not updated until the next chunk is processed
135
  buffer = ""
136
+ await websocket.send_json(
137
+ {"transcription": transcription, "buffer": buffer}
138
+ )
 
139
  except Exception as e:
140
  print(f"Exception in ffmpeg_stdout_reader: {e}")
141
  break
 
170
  pass
171
 
172
  ffmpeg_process.wait()
173
+
174
+
175
  if __name__ == "__main__":
176
  import uvicorn
177
+
178
+ uvicorn.run(
179
+ "whisper_fastapi_online_server:app", host=args.host, port=args.port, reload=True
180
+ )
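
To make the buffering arithmetic in the server above concrete: s16le audio is 2 bytes per sample, so one second of 16 kHz mono PCM is 32 000 bytes (the same figure behind the 32000 * elapsed_time read), and the decoded int16 bytes become float32 samples in [-1, 1) by dividing by 32768. A small standalone check; only numpy is assumed, and MIN_CHUNK_SIZE is a hypothetical stand-in for args.min_chunk_size:

    import numpy as np

    SAMPLE_RATE = 16000    # Hz, as in the server
    BYTES_PER_SAMPLE = 2   # s16le = 2 bytes per sample
    MIN_CHUNK_SIZE = 1     # seconds; stand-in for args.min_chunk_size

    SAMPLES_PER_SEC = SAMPLE_RATE * MIN_CHUNK_SIZE
    BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE
    print(BYTES_PER_SEC)   # 32000

    # int16 PCM bytes -> float32 in [-1, 1), the same conversion as in ffmpeg_stdout_reader
    pcm_buffer = np.array([0, 16384, -32768, 32767], dtype=np.int16).tobytes()
    pcm_array = np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
    print(pcm_array)       # approximately [ 0.   0.5  -1.   0.99996948]
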
whisper_online.py CHANGED
@@ -12,26 +12,31 @@ import math
12
 
13
  logger = logging.getLogger(__name__)
14
 
 
15
  @lru_cache(10**6)
16
  def load_audio(fname):
17
  a, _ = librosa.load(fname, sr=16000, dtype=np.float32)
18
  return a
19
 
 
20
  def load_audio_chunk(fname, beg, end):
21
  audio = load_audio(fname)
22
- beg_s = int(beg*16000)
23
- end_s = int(end*16000)
24
  return audio[beg_s:end_s]
25
 
26
 
27
  # Whisper backend
28
 
 
29
  class ASRBase:
30
 
31
- sep = " " # join transcribe words with this character (" " for whisper_timestamped,
32
- # "" for faster-whisper because it emits the spaces when neeeded)
33
 
34
- def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
 
 
35
  self.logfile = logfile
36
 
37
  self.transcribe_kargs = {}
@@ -42,7 +47,6 @@ class ASRBase:
42
 
43
  self.model = self.load_model(modelsize, cache_dir, model_dir)
44
 
45
-
46
  def load_model(self, modelsize, cache_dir):
47
raise NotImplementedError("must be implemented in the child class")
48
 
@@ -64,24 +68,30 @@ class WhisperTimestampedASR(ASRBase):
64
  import whisper
65
  import whisper_timestamped
66
  from whisper_timestamped import transcribe_timestamped
 
67
  self.transcribe_timestamped = transcribe_timestamped
68
  if model_dir is not None:
69
  logger.debug("ignoring model_dir, not implemented")
70
  return whisper.load_model(modelsize, download_root=cache_dir)
71
 
72
  def transcribe(self, audio, init_prompt=""):
73
- result = self.transcribe_timestamped(self.model,
74
- audio, language=self.original_language,
75
- initial_prompt=init_prompt, verbose=None,
76
- condition_on_previous_text=True, **self.transcribe_kargs)
 
 
 
 
 
77
  return result
78
-
79
- def ts_words(self,r):
80
  # return: transcribe result object to [(beg,end,"word1"), ...]
81
  o = []
82
  for s in r["segments"]:
83
  for w in s["words"]:
84
- t = (w["start"],w["end"],w["text"])
85
  o.append(t)
86
  return o
87
 
@@ -95,43 +105,55 @@ class WhisperTimestampedASR(ASRBase):
95
  self.transcribe_kargs["task"] = "translate"
96
 
97
 
98
-
99
-
100
  class FasterWhisperASR(ASRBase):
101
- """Uses faster-whisper library as the backend. Works much faster, appx 4-times (in offline mode). For GPU, it requires installation with a specific CUDNN version.
102
- """
103
 
104
  sep = ""
105
 
106
  def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
107
  from faster_whisper import WhisperModel
108
- # logging.getLogger("faster_whisper").setLevel(logger.level)
 
109
  if model_dir is not None:
110
- logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used.")
 
 
111
  model_size_or_path = model_dir
112
  elif modelsize is not None:
113
  model_size_or_path = modelsize
114
  else:
115
  raise ValueError("modelsize or model_dir parameter must be set")
116
 
117
-
118
  # this worked fast and reliably on NVIDIA L40
119
- model = WhisperModel(model_size_or_path, device="cuda", compute_type="float16", download_root=cache_dir)
 
 
 
 
 
120
 
121
  # or run on GPU with INT8
122
  # tested: the transcripts were different, probably worse than with FP16, and it was slightly (appx 20%) slower
123
- #model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
124
 
125
  # or run on CPU with INT8
126
  # tested: works, but slow, appx 10-times than cuda FP16
127
- # model = WhisperModel(modelsize, device="cpu", compute_type="int8") #, download_root="faster-disk-cache-dir/")
128
  return model
129
 
130
  def transcribe(self, audio, init_prompt=""):
131
 
132
  # tested: beam_size=5 is faster and better than 1 (on one 200 second document from En ESIC, min chunk 0.01)
133
- segments, info = self.model.transcribe(audio, language=self.original_language, initial_prompt=init_prompt, beam_size=5, word_timestamps=True, condition_on_previous_text=True, **self.transcribe_kargs)
134
- #print(info) # info contains language detection result
 
 
 
 
 
 
 
 
135
 
136
  return list(segments)
137
 
@@ -156,40 +178,45 @@ class FasterWhisperASR(ASRBase):
156
  def set_translate_task(self):
157
  self.transcribe_kargs["task"] = "translate"
158
 
 
159
  class MLXWhisper(ASRBase):
160
  """
161
Uses MLX Whisper library as the backend, optimized for Apple Silicon.
162
  Models available: https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc
163
- Significantly faster than faster-whisper (without CUDA) on Apple M1.
164
  """
165
 
166
  sep = " "
167
 
168
  def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
169
  """
170
- Loads the MLX-compatible Whisper model.
171
-
172
- Args:
173
- modelsize (str, optional): The size or name of the Whisper model to load.
174
- If provided, it will be translated to an MLX-compatible model path using the `translate_model_name` method.
175
- Example: "large-v3-turbo" -> "mlx-community/whisper-large-v3-turbo".
176
- cache_dir (str, optional): Path to the directory for caching models.
177
- **Note**: This is not supported by MLX Whisper and will be ignored.
178
- model_dir (str, optional): Direct path to a custom model directory.
179
- If specified, it overrides the `modelsize` parameter.
180
  """
181
  from mlx_whisper import transcribe
182
 
183
  if model_dir is not None:
184
- logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.")
 
 
185
  model_size_or_path = model_dir
186
  elif modelsize is not None:
187
  model_size_or_path = self.translate_model_name(modelsize)
188
- logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used.")
189
-
 
 
190
  self.model_size_or_path = model_size_or_path
191
  return transcribe
192
-
193
  def translate_model_name(self, model_name):
194
  """
195
  Translates a given model name to its corresponding MLX-compatible model path.
@@ -214,7 +241,7 @@ class MLXWhisper(ASRBase):
214
  "large-v2": "mlx-community/whisper-large-v2-mlx",
215
  "large-v3": "mlx-community/whisper-large-v3-mlx",
216
  "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
217
- "large": "mlx-community/whisper-large-mlx"
218
  }
219
 
220
  # Retrieve the corresponding MLX model path
@@ -223,8 +250,10 @@ class MLXWhisper(ASRBase):
223
  if mlx_model_path:
224
  return mlx_model_path
225
  else:
226
- raise ValueError(f"Model name '{model_name}' is not recognized or not supported.")
227
-
 
 
228
  def transcribe(self, audio, init_prompt=""):
229
  segments = self.model(
230
  audio,
@@ -233,11 +262,10 @@ class MLXWhisper(ASRBase):
233
  word_timestamps=True,
234
  condition_on_previous_text=True,
235
  path_or_hf_repo=self.model_size_or_path,
236
- **self.transcribe_kargs
237
  )
238
  return segments.get("segments", [])
239
 
240
-
241
  def ts_words(self, segments):
242
  """
243
  Extract timestamped words from transcription segments and skips words with high no-speech probability.
@@ -248,9 +276,9 @@ class MLXWhisper(ASRBase):
248
  for word in segment.get("words", [])
249
  if segment.get("no_speech_prob", 0) <= 0.9
250
  ]
251
-
252
  def segments_end_ts(self, res):
253
- return [s['end'] for s in res]
254
 
255
  def use_vad(self):
256
  self.transcribe_kargs["vad_filter"] = True
@@ -258,15 +286,18 @@ class MLXWhisper(ASRBase):
258
  def set_translate_task(self):
259
  self.transcribe_kargs["task"] = "translate"
260
 
 
261
  class OpenaiApiASR(ASRBase):
262
  """Uses OpenAI's Whisper API for audio transcription."""
263
 
264
  def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
265
  self.logfile = logfile
266
 
267
- self.modelname = "whisper-1"
268
- self.original_language = None if lan == "auto" else lan # ISO-639-1 language code
269
- self.response_format = "verbose_json"
 
 
270
  self.temperature = temperature
271
 
272
  self.load_model()
@@ -278,10 +309,12 @@ class OpenaiApiASR(ASRBase):
278
 
279
  def load_model(self, *args, **kwargs):
280
  from openai import OpenAI
 
281
  self.client = OpenAI()
282
 
283
- self.transcribed_seconds = 0 # for logging how many seconds were processed by API, to know the cost
284
-
 
285
 
286
  def ts_words(self, segments):
287
  no_speech_segments = []
@@ -289,7 +322,9 @@ class OpenaiApiASR(ASRBase):
289
  for segment in segments.segments:
290
  # TODO: threshold can be set from outside
291
  if segment["no_speech_prob"] > 0.8:
292
- no_speech_segments.append((segment.get("start"), segment.get("end")))
 
 
293
 
294
  o = []
295
  for word in segments.words:
@@ -301,7 +336,6 @@ class OpenaiApiASR(ASRBase):
301
  o.append((start, end, word.word))
302
  return o
303
 
304
-
305
  def segments_end_ts(self, res):
306
  return [s.end for s in res.words]
307
 
@@ -309,17 +343,19 @@ class OpenaiApiASR(ASRBase):
309
  # Write the audio data to a buffer
310
  buffer = io.BytesIO()
311
  buffer.name = "temp.wav"
312
- sf.write(buffer, audio_data, samplerate=16000, format='WAV', subtype='PCM_16')
313
  buffer.seek(0) # Reset buffer's position to the beginning
314
 
315
- self.transcribed_seconds += math.ceil(len(audio_data)/16000) # it rounds up to the whole seconds
 
 
316
 
317
  params = {
318
  "model": self.modelname,
319
  "file": buffer,
320
  "response_format": self.response_format,
321
  "temperature": self.temperature,
322
- "timestamp_granularities": ["word", "segment"]
323
  }
324
  if self.task != "translate" and self.original_language:
325
  params["language"] = self.original_language
@@ -333,7 +369,9 @@ class OpenaiApiASR(ASRBase):
333
 
334
  # Process transcription/translation
335
  transcript = proc.create(**params)
336
- logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
 
 
337
 
338
  return transcript
339
 
@@ -344,8 +382,6 @@ class OpenaiApiASR(ASRBase):
344
  self.task = "translate"
345
 
346
 
347
-
348
-
349
  class HypothesisBuffer:
350
 
351
  def __init__(self, logfile=sys.stderr):
@@ -361,20 +397,24 @@ class HypothesisBuffer:
361
  def insert(self, new, offset):
362
  # compare self.commited_in_buffer and new. It inserts only the words in new that extend the commited_in_buffer, it means they are roughly behind last_commited_time and new in content
363
  # the new tail is added to self.new
364
-
365
- new = [(a+offset,b+offset,t) for a,b,t in new]
366
- self.new = [(a,b,t) for a,b,t in new if a > self.last_commited_time-0.1]
367
 
368
  if len(self.new) >= 1:
369
- a,b,t = self.new[0]
370
  if abs(a - self.last_commited_time) < 1:
371
  if self.commited_in_buffer:
372
  # it's going to search for 1, 2, ..., 5 consecutive words (n-grams) that are identical in commited and new. If they are, they're dropped.
373
  cn = len(self.commited_in_buffer)
374
  nn = len(self.new)
375
- for i in range(1,min(min(cn,nn),5)+1): # 5 is the maximum
376
- c = " ".join([self.commited_in_buffer[-j][2] for j in range(1,i+1)][::-1])
377
- tail = " ".join(self.new[j-1][2] for j in range(1,i+1))
 
 
 
 
378
  if c == tail:
379
  words = []
380
  for j in range(i):
@@ -384,7 +424,7 @@ class HypothesisBuffer:
384
  break
385
 
386
  def flush(self):
387
- # returns commited chunk = the longest common prefix of 2 last inserts.
388
 
389
  commit = []
390
  while self.new:
@@ -394,7 +434,7 @@ class HypothesisBuffer:
394
  break
395
 
396
  if nt == self.buffer[0][2]:
397
- commit.append((na,nb,nt))
398
  self.last_commited_word = nt
399
  self.last_commited_time = nb
400
  self.buffer.pop(0)
@@ -413,16 +453,19 @@ class HypothesisBuffer:
413
  def complete(self):
414
  return self.buffer
415
 
 
416
  class OnlineASRProcessor:
417
 
418
  SAMPLING_RATE = 16000
419
 
420
- def __init__(self, asr, tokenizer=None, buffer_trimming=("segment", 15), logfile=sys.stderr):
 
 
421
  """asr: WhisperASR object
422
  tokenizer: sentence tokenizer object for the target language. Must have a method *split* that behaves like the one of MosesTokenizer. It can be None, if "segment" buffer trimming option is used, then tokenizer is not used at all.
423
  ("segment", 15)
424
  buffer_trimming: a pair of (option, seconds), where option is either "sentence" or "segment", and seconds is a number. Buffer is trimmed if it is longer than "seconds" threshold. Default is the most recommended option.
425
- logfile: where to store the log.
426
  """
427
  self.asr = asr
428
  self.tokenizer = tokenizer
@@ -434,7 +477,7 @@ class OnlineASRProcessor:
434
 
435
  def init(self, offset=None):
436
  """run this when starting or restarting processing"""
437
- self.audio_buffer = np.array([],dtype=np.float32)
438
  self.transcript_buffer = HypothesisBuffer(logfile=self.logfile)
439
  self.buffer_time_offset = 0
440
  if offset is not None:
@@ -446,34 +489,38 @@ class OnlineASRProcessor:
446
  self.audio_buffer = np.append(self.audio_buffer, audio)
447
 
448
  def prompt(self):
449
- """Returns a tuple: (prompt, context), where "prompt" is a 200-character suffix of commited text that is inside of the scrolled away part of audio buffer.
450
  "context" is the commited text that is inside the audio buffer. It is transcribed again and skipped. It is returned only for debugging and logging reasons.
451
  """
452
- k = max(0,len(self.commited)-1)
453
- while k > 0 and self.commited[k-1][1] > self.buffer_time_offset:
454
  k -= 1
455
 
456
  p = self.commited[:k]
457
- p = [t for _,_,t in p]
458
  prompt = []
459
  l = 0
460
  while p and l < 200: # 200 characters prompt size
461
  x = p.pop(-1)
462
- l += len(x)+1
463
  prompt.append(x)
464
  non_prompt = self.commited[k:]
465
- return self.asr.sep.join(prompt[::-1]), self.asr.sep.join(t for _,_,t in non_prompt)
 
 
466
 
467
  def process_iter(self):
468
  """Runs on the current audio buffer.
469
- Returns: a tuple (beg_timestamp, end_timestamp, "text"), or (None, None, "").
470
The non-empty text is confirmed (committed) partial transcript.
471
  """
472
 
473
  prompt, non_prompt = self.prompt()
474
  logger.debug(f"PROMPT: {prompt}")
475
  logger.debug(f"CONTEXT: {non_prompt}")
476
- logger.debug(f"transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}")
 
 
477
  res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt)
478
 
479
  # transform to [(beg,end,"word1"), ...]
@@ -490,33 +537,37 @@ class OnlineASRProcessor:
490
  # there is a newly confirmed text
491
 
492
  if o and self.buffer_trimming_way == "sentence": # trim the completed sentences
493
- if len(self.audio_buffer)/self.SAMPLING_RATE > self.buffer_trimming_sec: # longer than this
 
 
494
  self.chunk_completed_sentence()
495
 
496
-
497
  if self.buffer_trimming_way == "segment":
498
  s = self.buffer_trimming_sec # trim the completed segments longer than s,
499
  else:
500
- s = 30 # if the audio buffer is longer than 30s, trim it
501
-
502
- if len(self.audio_buffer)/self.SAMPLING_RATE > s:
503
  self.chunk_completed_segment(res)
504
 
505
  # alternative: on any word
506
- #l = self.buffer_time_offset + len(self.audio_buffer)/self.SAMPLING_RATE - 10
507
  # let's find commited word that is less
508
- #k = len(self.commited)-1
509
- #while k>0 and self.commited[k][1] > l:
510
  # k -= 1
511
- #t = self.commited[k][1]
512
  logger.debug("chunking segment")
513
- #self.chunk_at(t)
514
 
515
- logger.debug(f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}")
 
 
516
  return self.to_flush(o)
517
 
518
  def chunk_completed_sentence(self):
519
- if self.commited == []: return
 
520
  logger.debug(self.commited)
521
  sents = self.words_to_sentences(self.commited)
522
  for s in sents:
@@ -532,7 +583,8 @@ class OnlineASRProcessor:
532
  self.chunk_at(chunk_at)
533
 
534
  def chunk_completed_segment(self, res):
535
- if self.commited == []: return
 
536
 
537
  ends = self.asr.segments_end_ts(res)
538
 
@@ -540,10 +592,10 @@ class OnlineASRProcessor:
540
 
541
  if len(ends) > 1:
542
 
543
- e = ends[-2]+self.buffer_time_offset
544
  while len(ends) > 2 and e > t:
545
  ends.pop(-1)
546
- e = ends[-2]+self.buffer_time_offset
547
  if e <= t:
548
  logger.debug(f"--- segment chunked at {e:2.2f}")
549
  self.chunk_at(e)
@@ -552,23 +604,18 @@ class OnlineASRProcessor:
552
  else:
553
  logger.debug(f"--- not enough segments to chunk")
554
 
555
-
556
-
557
-
558
-
559
  def chunk_at(self, time):
560
- """trims the hypothesis and audio buffer at "time"
561
- """
562
  self.transcript_buffer.pop_commited(time)
563
  cut_seconds = time - self.buffer_time_offset
564
- self.audio_buffer = self.audio_buffer[int(cut_seconds*self.SAMPLING_RATE):]
565
  self.buffer_time_offset = time
566
 
567
  def words_to_sentences(self, words):
568
  """Uses self.tokenizer for sentence segmentation of words.
569
  Returns: [(beg,end,"sentence 1"),...]
570
  """
571
-
572
  cwords = [w for w in words]
573
  t = " ".join(o[2] for o in cwords)
574
  s = self.tokenizer.split(t)
@@ -579,15 +626,15 @@ class OnlineASRProcessor:
579
  sent = s.pop(0).strip()
580
  fsent = sent
581
  while cwords:
582
- b,e,w = cwords.pop(0)
583
  w = w.strip()
584
  if beg is None and sent.startswith(w):
585
  beg = b
586
  elif end is None and sent == w:
587
  end = e
588
- out.append((beg,end,fsent))
589
  break
590
- sent = sent[len(w):].strip()
591
  return out
592
 
593
  def finish(self):
@@ -597,11 +644,15 @@ class OnlineASRProcessor:
597
  o = self.transcript_buffer.complete()
598
  f = self.to_flush(o)
599
  logger.debug(f"last, noncommited: {f}")
600
- self.buffer_time_offset += len(self.audio_buffer)/16000
601
  return f
602
 
603
-
604
- def to_flush(self, sents, sep=None, offset=0, ):
 
 
 
 
605
  # concatenates the timestamped words or sentences into one sequence that is flushed in one line
606
  # sents: [(beg1, end1, "sentence1"), ...] or [] if empty
607
  # return: (beg1,end-of-last-sentence,"concatenation of sentences") or (None, None, "") if empty
@@ -614,15 +665,16 @@ class OnlineASRProcessor:
614
  else:
615
  b = offset + sents[0][0]
616
  e = offset + sents[-1][1]
617
- return (b,e,t)
 
618
 
619
  class VACOnlineASRProcessor(OnlineASRProcessor):
620
- '''Wraps OnlineASRProcessor with VAC (Voice Activity Controller).
621
 
622
- It works the same way as OnlineASRProcessor: it receives chunks of audio (e.g. 0.04 seconds),
623
- it runs VAD and continuously detects whether there is speech or not.
624
  When it detects end of speech (non-voice for 500ms), it makes OnlineASRProcessor to end the utterance immediately.
625
- '''
626
 
627
  def __init__(self, online_chunk_size, *a, **kw):
628
  self.online_chunk_size = online_chunk_size
@@ -631,12 +683,13 @@ class VACOnlineASRProcessor(OnlineASRProcessor):
631
 
632
  # VAC:
633
  import torch
634
- model, _ = torch.hub.load(
635
- repo_or_dir='snakers4/silero-vad',
636
- model='silero_vad'
637
- )
638
  from silero_vad_iterator import FixedVADIterator
639
- self.vac = FixedVADIterator(model) # we use the default options there: 500ms silence, 100ms padding, etc.
 
 
 
640
 
641
  self.logfile = self.online.logfile
642
  self.init()
@@ -649,60 +702,65 @@ class VACOnlineASRProcessor(OnlineASRProcessor):
649
  self.is_currently_final = False
650
 
651
  self.status = None # or "voice" or "nonvoice"
652
- self.audio_buffer = np.array([],dtype=np.float32)
653
  self.buffer_offset = 0 # in frames
654
 
655
  def clear_buffer(self):
656
  self.buffer_offset += len(self.audio_buffer)
657
- self.audio_buffer = np.array([],dtype=np.float32)
658
-
659
 
660
  def insert_audio_chunk(self, audio):
661
  res = self.vac(audio)
662
  self.audio_buffer = np.append(self.audio_buffer, audio)
663
 
664
  if res is not None:
665
- frame = list(res.values())[0]-self.buffer_offset
666
- if 'start' in res and 'end' not in res:
667
- self.status = 'voice'
668
  send_audio = self.audio_buffer[frame:]
669
- self.online.init(offset=(frame+self.buffer_offset)/self.SAMPLING_RATE)
 
 
670
  self.online.insert_audio_chunk(send_audio)
671
  self.current_online_chunk_buffer_size += len(send_audio)
672
  self.clear_buffer()
673
- elif 'end' in res and 'start' not in res:
674
- self.status = 'nonvoice'
675
  send_audio = self.audio_buffer[:frame]
676
  self.online.insert_audio_chunk(send_audio)
677
  self.current_online_chunk_buffer_size += len(send_audio)
678
  self.is_currently_final = True
679
  self.clear_buffer()
680
  else:
681
- beg = res["start"]-self.buffer_offset
682
- end = res["end"]-self.buffer_offset
683
- self.status = 'nonvoice'
684
  send_audio = self.audio_buffer[beg:end]
685
- self.online.init(offset=(beg+self.buffer_offset)/self.SAMPLING_RATE)
686
  self.online.insert_audio_chunk(send_audio)
687
  self.current_online_chunk_buffer_size += len(send_audio)
688
  self.is_currently_final = True
689
  self.clear_buffer()
690
  else:
691
- if self.status == 'voice':
692
  self.online.insert_audio_chunk(self.audio_buffer)
693
  self.current_online_chunk_buffer_size += len(self.audio_buffer)
694
  self.clear_buffer()
695
  else:
696
  # We keep 1 second because VAD may later find start of voice in it.
697
- # But we trim it to prevent OOM.
698
- self.buffer_offset += max(0,len(self.audio_buffer)-self.SAMPLING_RATE)
699
- self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE:]
700
-
 
701
 
702
  def process_iter(self):
703
  if self.is_currently_final:
704
  return self.finish()
705
- elif self.current_online_chunk_buffer_size > self.SAMPLING_RATE*self.online_chunk_size:
 
 
 
706
  self.current_online_chunk_buffer_size = 0
707
  ret = self.online.process_iter()
708
  return ret
@@ -717,37 +775,55 @@ class VACOnlineASRProcessor(OnlineASRProcessor):
717
  return ret
718
 
719
 
 
 
 
720
 
721
- WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(",")
722
 
723
  def create_tokenizer(lan):
724
  """returns an object that has split function that works like the one of MosesTokenizer"""
725
 
726
- assert lan in WHISPER_LANG_CODES, "language must be Whisper's supported lang code: " + " ".join(WHISPER_LANG_CODES)
 
 
727
 
728
  if lan == "uk":
729
  import tokenize_uk
 
730
  class UkrainianTokenizer:
731
  def split(self, text):
732
  return tokenize_uk.tokenize_sents(text)
 
733
  return UkrainianTokenizer()
734
 
735
  # supported by fast-mosestokenizer
736
- if lan in "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split():
 
 
 
737
  from mosestokenizer import MosesTokenizer
 
738
  return MosesTokenizer(lan)
739
 
740
  # the following languages are in Whisper, but not in wtpsplit:
741
- if lan in "as ba bo br bs fo haw hr ht jw lb ln lo mi nn oc sa sd sn so su sw tk tl tt".split():
742
- logger.debug(f"{lan} code is not supported by wtpsplit. Going to use None lang_code option.")
 
 
 
 
 
743
  lan = None
744
 
745
  from wtpsplit import WtP
 
746
  # downloads the model from huggingface on the first use
747
  wtp = WtP("wtp-canine-s-12l-no-adapters")
 
748
  class WtPtok:
749
  def split(self, sent):
750
  return wtp.split(sent, lang_code=lan)
 
751
  return WtPtok()
752
 
753
 
@@ -755,19 +831,91 @@ def add_shared_args(parser):
755
  """shared args for simulation (this entry point) and server
756
  parser: argparse.ArgumentParser object
757
  """
758
- parser.add_argument('--min-chunk-size', type=float, default=1.0, help='Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.')
759
- parser.add_argument('--model', type=str, default='large-v2', choices="tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo".split(","),help="Name size of the Whisper model to use (default: large-v2). The model is automatically downloaded from the model hub if not present in model cache dir.")
760
- parser.add_argument('--model_cache_dir', type=str, default=None, help="Overriding the default model cache dir where models downloaded from the hub are saved")
761
- parser.add_argument('--model_dir', type=str, default=None, help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.")
762
- parser.add_argument('--lan', '--language', type=str, default='auto', help="Source language code, e.g. en,de,cs, or 'auto' for language detection.")
763
- parser.add_argument('--task', type=str, default='transcribe', choices=["transcribe","translate"],help="Transcribe or translate.")
764
- parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],help='Load only this backend for Whisper processing.')
765
- parser.add_argument('--vac', action="store_true", default=False, help='Use VAC = voice activity controller. Recommended. Requires torch.')
766
- parser.add_argument('--vac-chunk-size', type=float, default=0.04, help='VAC sample size in seconds.')
767
- parser.add_argument('--vad', action="store_true", default=False, help='Use VAD = voice activity detection, with the default parameters.')
768
- parser.add_argument('--buffer_trimming', type=str, default="segment", choices=["sentence", "segment"],help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.')
769
- parser.add_argument('--buffer_trimming_sec', type=float, default=15, help='Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.')
770
- parser.add_argument("-l", "--log-level", dest="log_level", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help="Set the log level", default='DEBUG')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
771
 
772
  def asr_factory(args, logfile=sys.stderr):
773
  """
@@ -789,12 +937,17 @@ def asr_factory(args, logfile=sys.stderr):
789
  size = args.model
790
  t = time.time()
791
  logger.info(f"Loading Whisper {size} model for {args.lan}...")
792
- asr = asr_cls(modelsize=size, lan=args.lan, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
 
 
 
 
 
793
  e = time.time()
794
  logger.info(f"done. It took {round(e-t,2)} seconds.")
795
 
796
  # Apply common configurations
797
- if getattr(args, 'vad', False): # Checks if VAD argument is present and True
798
  logger.info("Setting VAD filter")
799
  asr.use_vad()
800
 
@@ -813,51 +966,82 @@ def asr_factory(args, logfile=sys.stderr):
813
 
814
  # Create the OnlineASRProcessor
815
  if args.vac:
816
-
817
- online = VACOnlineASRProcessor(args.min_chunk_size, asr,tokenizer,logfile=logfile,buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec))
 
 
 
 
 
 
818
  else:
819
- online = OnlineASRProcessor(asr,tokenizer,logfile=logfile,buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec))
 
 
 
 
 
820
 
821
  return asr, online
822
 
823
- def set_logging(args,logger,other="_server"):
824
- logging.basicConfig(#format='%(name)s
825
- format='%(levelname)s\t%(message)s')
826
  logger.setLevel(args.log_level)
827
- logging.getLogger("whisper_online"+other).setLevel(args.log_level)
828
- # logging.getLogger("whisper_online_server").setLevel(args.log_level)
829
 
830
 
 
 
831
 
832
  if __name__ == "__main__":
833
 
834
  import argparse
 
835
  parser = argparse.ArgumentParser()
836
- parser.add_argument('audio_path', type=str, help="Filename of 16kHz mono channel wav, on which live streaming is simulated.")
 
 
 
 
837
  add_shared_args(parser)
838
- parser.add_argument('--start_at', type=float, default=0.0, help='Start processing audio at this time.')
839
- parser.add_argument('--offline', action="store_true", default=False, help='Offline mode.')
840
- parser.add_argument('--comp_unaware', action="store_true", default=False, help='Computationally unaware simulation.')
841
-
 
 
 
 
 
 
 
 
 
 
 
 
842
  args = parser.parse_args()
843
 
844
  # reset to store stderr to different file stream, e.g. open(os.devnull,"w")
845
  logfile = sys.stderr
846
 
847
  if args.offline and args.comp_unaware:
848
- logger.error("No or one option from --offline and --comp_unaware are available, not both. Exiting.")
 
 
849
  sys.exit(1)
850
 
851
- # if args.log_level:
852
- # logging.basicConfig(format='whisper-%(levelname)s:%(name)s: %(message)s',
853
- # level=getattr(logging, args.log_level))
854
 
855
- set_logging(args,logger)
856
 
857
  audio_path = args.audio_path
858
 
859
  SAMPLING_RATE = 16000
860
- duration = len(load_audio(audio_path))/SAMPLING_RATE
861
  logger.info("Audio duration is: %2.2f seconds" % duration)
862
 
863
  asr, online = asr_factory(args, logfile=logfile)
@@ -867,13 +1051,13 @@ if __name__ == "__main__":
867
  min_chunk = args.min_chunk_size
868
 
869
  # load the audio into the LRU cache before we start the timer
870
- a = load_audio_chunk(audio_path,0,1)
871
 
872
  # warm up the ASR because the very first transcribe takes much more time than the other
873
  asr.transcribe(a)
874
 
875
  beg = args.start_at
876
- start = time.time()-beg
877
 
878
  def output_transcript(o, now=None):
879
  # output format in stdout is like:
@@ -883,15 +1067,22 @@ if __name__ == "__main__":
883
  # - beg and end timestamp of the text segment, as estimated by Whisper model. The timestamps are not accurate, but they're useful anyway
884
  # - the next words: segment transcript
885
  if now is None:
886
- now = time.time()-start
887
  if o[0] is not None:
888
- print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),file=logfile,flush=True)
889
- print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),flush=True)
 
 
 
 
 
 
 
890
  else:
891
  # No text, so no output
892
  pass
893
 
894
- if args.offline: ## offline mode processing (for testing/debugging)
895
  a = load_audio(audio_path)
896
  online.insert_audio_chunk(a)
897
  try:
@@ -901,10 +1092,10 @@ if __name__ == "__main__":
901
  else:
902
  output_transcript(o)
903
  now = None
904
- elif args.comp_unaware: # computational unaware mode
905
  end = beg + min_chunk
906
  while True:
907
- a = load_audio_chunk(audio_path,beg,end)
908
  online.insert_audio_chunk(a)
909
  try:
910
  o = online.process_iter()
@@ -918,23 +1109,23 @@ if __name__ == "__main__":
918
 
919
  if end >= duration:
920
  break
921
-
922
  beg = end
923
-
924
  if end + min_chunk > duration:
925
  end = duration
926
  else:
927
  end += min_chunk
928
  now = duration
929
 
930
- else: # online = simultaneous mode
931
  end = 0
932
  while True:
933
  now = time.time() - start
934
- if now < end+min_chunk:
935
- time.sleep(min_chunk+end-now)
936
  end = time.time() - start
937
- a = load_audio_chunk(audio_path,beg,end)
938
  beg = end
939
  online.insert_audio_chunk(a)
940
 
@@ -946,7 +1137,9 @@ if __name__ == "__main__":
946
  else:
947
  output_transcript(o)
948
  now = time.time() - start
949
- logger.debug(f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}")
 
 
950
 
951
  if end >= duration:
952
  break
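
Before the reformatted listing of whisper_online.py below, one detail worth spelling out is the local-agreement check in HypothesisBuffer.insert above: when the new hypothesis begins with a 1-to-5-word n-gram that exactly repeats the tail of the already committed words, that repeated prefix is dropped before the hypothesis is buffered. A standalone sketch of the idea on plain word lists (timestamps omitted, function name hypothetical):

    def drop_repeated_prefix(committed, new, max_n=5):
        # Drop from `new` a 1..max_n word prefix that repeats the tail of `committed`.
        for i in range(1, min(len(committed), len(new), max_n) + 1):
            if committed[-i:] == new[:i]:
                return new[i:]
        return new

    committed = ["this", "is", "a", "test"]
    new_hyp = ["a", "test", "of", "streaming"]
    print(drop_repeated_prefix(committed, new_hyp))  # ['of', 'streaming']
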
 
12
 
13
  logger = logging.getLogger(__name__)
14
 
15
+
16
  @lru_cache(10**6)
17
  def load_audio(fname):
18
  a, _ = librosa.load(fname, sr=16000, dtype=np.float32)
19
  return a
20
 
21
+
22
  def load_audio_chunk(fname, beg, end):
23
  audio = load_audio(fname)
24
+ beg_s = int(beg * 16000)
25
+ end_s = int(end * 16000)
26
  return audio[beg_s:end_s]
27
 
28
 
29
  # Whisper backend
30
 
31
+
32
  class ASRBase:
33
 
34
+ sep = " " # join transcribe words with this character (" " for whisper_timestamped,
35
+ # "" for faster-whisper because it emits the spaces when neeeded)
36
 
37
+ def __init__(
38
+ self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr
39
+ ):
40
  self.logfile = logfile
41
 
42
  self.transcribe_kargs = {}
 
47
 
48
  self.model = self.load_model(modelsize, cache_dir, model_dir)
49
 
 
50
  def load_model(self, modelsize, cache_dir):
51
raise NotImplementedError("must be implemented in the child class")
52
 
 
68
  import whisper
69
  import whisper_timestamped
70
  from whisper_timestamped import transcribe_timestamped
71
+
72
  self.transcribe_timestamped = transcribe_timestamped
73
  if model_dir is not None:
74
  logger.debug("ignoring model_dir, not implemented")
75
  return whisper.load_model(modelsize, download_root=cache_dir)
76
 
77
  def transcribe(self, audio, init_prompt=""):
78
+ result = self.transcribe_timestamped(
79
+ self.model,
80
+ audio,
81
+ language=self.original_language,
82
+ initial_prompt=init_prompt,
83
+ verbose=None,
84
+ condition_on_previous_text=True,
85
+ **self.transcribe_kargs,
86
+ )
87
  return result
88
+
89
+ def ts_words(self, r):
90
  # return: transcribe result object to [(beg,end,"word1"), ...]
91
  o = []
92
  for s in r["segments"]:
93
  for w in s["words"]:
94
+ t = (w["start"], w["end"], w["text"])
95
  o.append(t)
96
  return o
97
 
 
105
  self.transcribe_kargs["task"] = "translate"
106
 
107
 
 
 
108
  class FasterWhisperASR(ASRBase):
109
+ """Uses faster-whisper library as the backend. Works much faster, appx 4-times (in offline mode). For GPU, it requires installation with a specific CUDNN version."""
 
110
 
111
  sep = ""
112
 
113
  def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
114
  from faster_whisper import WhisperModel
115
+
116
+ # logging.getLogger("faster_whisper").setLevel(logger.level)
117
  if model_dir is not None:
118
+ logger.debug(
119
+ f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used."
120
+ )
121
  model_size_or_path = model_dir
122
  elif modelsize is not None:
123
  model_size_or_path = modelsize
124
  else:
125
  raise ValueError("modelsize or model_dir parameter must be set")
126
 
 
127
  # this worked fast and reliably on NVIDIA L40
128
+ model = WhisperModel(
129
+ model_size_or_path,
130
+ device="cuda",
131
+ compute_type="float16",
132
+ download_root=cache_dir,
133
+ )
134
 
135
  # or run on GPU with INT8
136
  # tested: the transcripts were different, probably worse than with FP16, and it was slightly (appx 20%) slower
137
+ # model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
138
 
139
  # or run on CPU with INT8
140
  # tested: works, but slow, appx 10-times than cuda FP16
141
+ # model = WhisperModel(modelsize, device="cpu", compute_type="int8") #, download_root="faster-disk-cache-dir/")
142
  return model
143
 
144
  def transcribe(self, audio, init_prompt=""):
145
 
146
  # tested: beam_size=5 is faster and better than 1 (on one 200 second document from En ESIC, min chunk 0.01)
147
+ segments, info = self.model.transcribe(
148
+ audio,
149
+ language=self.original_language,
150
+ initial_prompt=init_prompt,
151
+ beam_size=5,
152
+ word_timestamps=True,
153
+ condition_on_previous_text=True,
154
+ **self.transcribe_kargs,
155
+ )
156
+ # print(info) # info contains language detection result
157
 
158
  return list(segments)
159
 
 
178
  def set_translate_task(self):
179
  self.transcribe_kargs["task"] = "translate"
180
 
181
+
182
  class MLXWhisper(ASRBase):
183
  """
184
Uses MLX Whisper library as the backend, optimized for Apple Silicon.
185
  Models available: https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc
186
+ Significantly faster than faster-whisper (without CUDA) on Apple M1.
187
  """
188
 
189
  sep = " "
190
 
191
  def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
192
  """
193
+ Loads the MLX-compatible Whisper model.
194
+
195
+ Args:
196
+ modelsize (str, optional): The size or name of the Whisper model to load.
197
+ If provided, it will be translated to an MLX-compatible model path using the `translate_model_name` method.
198
+ Example: "large-v3-turbo" -> "mlx-community/whisper-large-v3-turbo".
199
+ cache_dir (str, optional): Path to the directory for caching models.
200
+ **Note**: This is not supported by MLX Whisper and will be ignored.
201
+ model_dir (str, optional): Direct path to a custom model directory.
202
+ If specified, it overrides the `modelsize` parameter.
203
  """
204
  from mlx_whisper import transcribe
205
 
206
  if model_dir is not None:
207
+ logger.debug(
208
+ f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used."
209
+ )
210
  model_size_or_path = model_dir
211
  elif modelsize is not None:
212
  model_size_or_path = self.translate_model_name(modelsize)
213
+ logger.debug(
214
+ f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used."
215
+ )
216
+
217
  self.model_size_or_path = model_size_or_path
218
  return transcribe
219
+
220
  def translate_model_name(self, model_name):
221
  """
222
  Translates a given model name to its corresponding MLX-compatible model path.
 
241
  "large-v2": "mlx-community/whisper-large-v2-mlx",
242
  "large-v3": "mlx-community/whisper-large-v3-mlx",
243
  "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
244
+ "large": "mlx-community/whisper-large-mlx",
245
  }
246
 
247
  # Retrieve the corresponding MLX model path
 
250
  if mlx_model_path:
251
  return mlx_model_path
252
  else:
253
+ raise ValueError(
254
+ f"Model name '{model_name}' is not recognized or not supported."
255
+ )
256
+
257
  def transcribe(self, audio, init_prompt=""):
258
  segments = self.model(
259
  audio,
 
262
  word_timestamps=True,
263
  condition_on_previous_text=True,
264
  path_or_hf_repo=self.model_size_or_path,
265
+ **self.transcribe_kargs,
266
  )
267
  return segments.get("segments", [])
268
 
 
269
  def ts_words(self, segments):
270
  """
271
  Extract timestamped words from transcription segments and skips words with high no-speech probability.
 
276
  for word in segment.get("words", [])
277
  if segment.get("no_speech_prob", 0) <= 0.9
278
  ]
279
+
280
  def segments_end_ts(self, res):
281
+ return [s["end"] for s in res]
282
 
283
  def use_vad(self):
284
  self.transcribe_kargs["vad_filter"] = True
 
286
  def set_translate_task(self):
287
  self.transcribe_kargs["task"] = "translate"
288
 
289
+
290
  class OpenaiApiASR(ASRBase):
291
  """Uses OpenAI's Whisper API for audio transcription."""
292
 
293
  def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
294
  self.logfile = logfile
295
 
296
+ self.modelname = "whisper-1"
297
+ self.original_language = (
298
+ None if lan == "auto" else lan
299
+ ) # ISO-639-1 language code
300
+ self.response_format = "verbose_json"
301
  self.temperature = temperature
302
 
303
  self.load_model()
 
309
 
310
  def load_model(self, *args, **kwargs):
311
  from openai import OpenAI
312
+
313
  self.client = OpenAI()
314
 
315
+ self.transcribed_seconds = (
316
+ 0 # for logging how many seconds were processed by API, to know the cost
317
+ )
318
 
319
  def ts_words(self, segments):
320
  no_speech_segments = []
 
322
  for segment in segments.segments:
323
  # TODO: threshold can be set from outside
324
  if segment["no_speech_prob"] > 0.8:
325
+ no_speech_segments.append(
326
+ (segment.get("start"), segment.get("end"))
327
+ )
328
 
329
  o = []
330
  for word in segments.words:
 
336
  o.append((start, end, word.word))
337
  return o
338
 
 
339
  def segments_end_ts(self, res):
340
  return [s.end for s in res.words]
341
 
 
343
  # Write the audio data to a buffer
344
  buffer = io.BytesIO()
345
  buffer.name = "temp.wav"
346
+ sf.write(buffer, audio_data, samplerate=16000, format="WAV", subtype="PCM_16")
347
  buffer.seek(0) # Reset buffer's position to the beginning
348
 
349
+ self.transcribed_seconds += math.ceil(
350
+ len(audio_data) / 16000
351
+ ) # it rounds up to the whole seconds
352
 
353
  params = {
354
  "model": self.modelname,
355
  "file": buffer,
356
  "response_format": self.response_format,
357
  "temperature": self.temperature,
358
+ "timestamp_granularities": ["word", "segment"],
359
  }
360
  if self.task != "translate" and self.original_language:
361
  params["language"] = self.original_language
 
369
 
370
  # Process transcription/translation
371
  transcript = proc.create(**params)
372
+ logger.debug(
373
+ f"OpenAI API processed accumulated {self.transcribed_seconds} seconds"
374
+ )
375
 
376
  return transcript
377
 
 
382
  self.task = "translate"
383
 
384
 
 
 
385
  class HypothesisBuffer:
386
 
387
  def __init__(self, logfile=sys.stderr):
 
397
  def insert(self, new, offset):
398
  # compare self.commited_in_buffer and new. It inserts only the words in new that extend the commited_in_buffer, it means they are roughly behind last_commited_time and new in content
399
  # the new tail is added to self.new
400
+
401
+ new = [(a + offset, b + offset, t) for a, b, t in new]
402
+ self.new = [(a, b, t) for a, b, t in new if a > self.last_commited_time - 0.1]
403
 
404
  if len(self.new) >= 1:
405
+ a, b, t = self.new[0]
406
  if abs(a - self.last_commited_time) < 1:
407
  if self.commited_in_buffer:
408
  # it's going to search for 1, 2, ..., 5 consecutive words (n-grams) that are identical in commited and new. If they are, they're dropped.
409
  cn = len(self.commited_in_buffer)
410
  nn = len(self.new)
411
+ for i in range(1, min(min(cn, nn), 5) + 1): # 5 is the maximum
412
+ c = " ".join(
413
+ [self.commited_in_buffer[-j][2] for j in range(1, i + 1)][
414
+ ::-1
415
+ ]
416
+ )
417
+ tail = " ".join(self.new[j - 1][2] for j in range(1, i + 1))
418
  if c == tail:
419
  words = []
420
  for j in range(i):
 
424
  break
425
 
426
  def flush(self):
427
+ # returns commited chunk = the longest common prefix of 2 last inserts.
428
 
429
  commit = []
430
  while self.new:
 
434
  break
435
 
436
  if nt == self.buffer[0][2]:
437
+ commit.append((na, nb, nt))
438
  self.last_commited_word = nt
439
  self.last_commited_time = nb
440
  self.buffer.pop(0)
 
453
  def complete(self):
454
  return self.buffer
455
 
456
+
457
  class OnlineASRProcessor:
458
 
459
  SAMPLING_RATE = 16000
460
 
461
+ def __init__(
462
+ self, asr, tokenizer=None, buffer_trimming=("segment", 15), logfile=sys.stderr
463
+ ):
464
  """asr: WhisperASR object
465
 tokenizer: sentence tokenizer object for the target language. Must have a method *split* that behaves like the one of MosesTokenizer. It can be None if the "segment" buffer trimming option is used; then the tokenizer is not used at all.
466
  ("segment", 15)
467
  buffer_trimming: a pair of (option, seconds), where option is either "sentence" or "segment", and seconds is a number. Buffer is trimmed if it is longer than "seconds" threshold. Default is the most recommended option.
468
+ logfile: where to store the log.
469
  """
470
  self.asr = asr
471
  self.tokenizer = tokenizer
 
477
 
478
  def init(self, offset=None):
479
  """run this when starting or restarting processing"""
480
+ self.audio_buffer = np.array([], dtype=np.float32)
481
  self.transcript_buffer = HypothesisBuffer(logfile=self.logfile)
482
  self.buffer_time_offset = 0
483
  if offset is not None:
 
489
  self.audio_buffer = np.append(self.audio_buffer, audio)
490
 
491
  def prompt(self):
492
+ """Returns a tuple: (prompt, context), where "prompt" is a 200-character suffix of commited text that is inside of the scrolled away part of audio buffer.
493
  "context" is the commited text that is inside the audio buffer. It is transcribed again and skipped. It is returned only for debugging and logging reasons.
494
  """
495
+ k = max(0, len(self.commited) - 1)
496
+ while k > 0 and self.commited[k - 1][1] > self.buffer_time_offset:
497
  k -= 1
498
 
499
  p = self.commited[:k]
500
+ p = [t for _, _, t in p]
501
  prompt = []
502
  l = 0
503
  while p and l < 200: # 200 characters prompt size
504
  x = p.pop(-1)
505
+ l += len(x) + 1
506
  prompt.append(x)
507
  non_prompt = self.commited[k:]
508
+ return self.asr.sep.join(prompt[::-1]), self.asr.sep.join(
509
+ t for _, _, t in non_prompt
510
+ )
511
 
512
  def process_iter(self):
513
  """Runs on the current audio buffer.
514
+ Returns: a tuple (beg_timestamp, end_timestamp, "text"), or (None, None, "").
515
 The non-empty text is the confirmed (committed) partial transcript.
516
  """
517
 
518
  prompt, non_prompt = self.prompt()
519
  logger.debug(f"PROMPT: {prompt}")
520
  logger.debug(f"CONTEXT: {non_prompt}")
521
+ logger.debug(
522
+ f"transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}"
523
+ )
524
  res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt)
525
 
526
  # transform to [(beg,end,"word1"), ...]
 
537
  # there is a newly confirmed text
538
 
539
  if o and self.buffer_trimming_way == "sentence": # trim the completed sentences
540
+ if (
541
+ len(self.audio_buffer) / self.SAMPLING_RATE > self.buffer_trimming_sec
542
+ ): # longer than this
543
  self.chunk_completed_sentence()
544
 
 
545
  if self.buffer_trimming_way == "segment":
546
  s = self.buffer_trimming_sec # trim the completed segments longer than s,
547
  else:
548
+ s = 30 # if the audio buffer is longer than 30s, trim it
549
+
550
+ if len(self.audio_buffer) / self.SAMPLING_RATE > s:
551
  self.chunk_completed_segment(res)
552
 
553
  # alternative: on any word
554
+ # l = self.buffer_time_offset + len(self.audio_buffer)/self.SAMPLING_RATE - 10
555
  # let's find commited word that is less
556
+ # k = len(self.commited)-1
557
+ # while k>0 and self.commited[k][1] > l:
558
  # k -= 1
559
+ # t = self.commited[k][1]
560
  logger.debug("chunking segment")
561
+ # self.chunk_at(t)
562
 
563
+ logger.debug(
564
+ f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}"
565
+ )
566
  return self.to_flush(o)
567
 
568
  def chunk_completed_sentence(self):
569
+ if self.commited == []:
570
+ return
571
  logger.debug(self.commited)
572
  sents = self.words_to_sentences(self.commited)
573
  for s in sents:
 
583
  self.chunk_at(chunk_at)
584
 
585
  def chunk_completed_segment(self, res):
586
+ if self.commited == []:
587
+ return
588
 
589
  ends = self.asr.segments_end_ts(res)
590
 
 
592
 
593
  if len(ends) > 1:
594
 
595
+ e = ends[-2] + self.buffer_time_offset
596
  while len(ends) > 2 and e > t:
597
  ends.pop(-1)
598
+ e = ends[-2] + self.buffer_time_offset
599
  if e <= t:
600
  logger.debug(f"--- segment chunked at {e:2.2f}")
601
  self.chunk_at(e)
 
604
  else:
605
  logger.debug(f"--- not enough segments to chunk")
606
 
 
 
 
 
607
  def chunk_at(self, time):
608
+ """trims the hypothesis and audio buffer at "time" """
 
609
  self.transcript_buffer.pop_commited(time)
610
  cut_seconds = time - self.buffer_time_offset
611
+ self.audio_buffer = self.audio_buffer[int(cut_seconds * self.SAMPLING_RATE) :]
612
  self.buffer_time_offset = time
613
 
614
  def words_to_sentences(self, words):
615
  """Uses self.tokenizer for sentence segmentation of words.
616
  Returns: [(beg,end,"sentence 1"),...]
617
  """
618
+
619
  cwords = [w for w in words]
620
  t = " ".join(o[2] for o in cwords)
621
  s = self.tokenizer.split(t)
 
626
  sent = s.pop(0).strip()
627
  fsent = sent
628
  while cwords:
629
+ b, e, w = cwords.pop(0)
630
  w = w.strip()
631
  if beg is None and sent.startswith(w):
632
  beg = b
633
  elif end is None and sent == w:
634
  end = e
635
+ out.append((beg, end, fsent))
636
  break
637
+ sent = sent[len(w) :].strip()
638
  return out
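 # Illustrative example (with a hypothetical tokenizer that splits the text into
 # ["Hello world.", "Good bye."]): words [(0.0, 0.4, "Hello"), (0.5, 0.9, "world."),
 # (1.0, 1.2, "Good"), (1.3, 1.6, "bye.")] are regrouped into
 # [(0.0, 0.9, "Hello world."), (1.0, 1.6, "Good bye.")].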
639
 
640
  def finish(self):
 
644
  o = self.transcript_buffer.complete()
645
  f = self.to_flush(o)
646
  logger.debug(f"last, noncommited: {f}")
647
+ self.buffer_time_offset += len(self.audio_buffer) / 16000
648
  return f
649
 
650
+ def to_flush(
651
+ self,
652
+ sents,
653
+ sep=None,
654
+ offset=0,
655
+ ):
656
  # concatenates the timestamped words or sentences into one sequence that is flushed in one line
657
  # sents: [(beg1, end1, "sentence1"), ...] or [] if empty
658
  # return: (beg1,end-of-last-sentence,"concatenation of sentences") or (None, None, "") if empty
 
665
  else:
666
  b = offset + sents[0][0]
667
  e = offset + sents[-1][1]
668
+ return (b, e, t)
669
+
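 # Minimal usage sketch (illustrative), assuming `asr` is one of the Whisper backend
 # wrappers created by asr_factory() below and the chunks are 16 kHz float32 numpy arrays:
 #
 #   online = OnlineASRProcessor(asr, buffer_trimming=("segment", 15))
 #   for chunk in audio_chunks:
 #       online.insert_audio_chunk(chunk)
 #       beg, end, text = online.process_iter()
 #       if text:
 #           print(beg, end, text)
 #   print(online.finish())   # flush whatever is still uncommitted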
670
 
671
  class VACOnlineASRProcessor(OnlineASRProcessor):
672
+ """Wraps OnlineASRProcessor with VAC (Voice Activity Controller).
673
 
674
+ It works the same way as OnlineASRProcessor: it receives chunks of audio (e.g. 0.04 seconds),
675
+ runs VAD on them, and continuously detects whether there is speech or not.
676
 When it detects the end of speech (non-voice for 500ms), it makes OnlineASRProcessor end the utterance immediately.
677
+ """
678
 
679
  def __init__(self, online_chunk_size, *a, **kw):
680
  self.online_chunk_size = online_chunk_size
 
683
 
684
  # VAC:
685
  import torch
686
+
687
+ model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
 
 
688
  from silero_vad_iterator import FixedVADIterator
689
+
690
+ self.vac = FixedVADIterator(
691
+ model
692
+ ) # we use the default options there: 500ms silence, 100ms padding, etc.
693
 
694
  self.logfile = self.online.logfile
695
  self.init()
 
702
  self.is_currently_final = False
703
 
704
  self.status = None # or "voice" or "nonvoice"
705
+ self.audio_buffer = np.array([], dtype=np.float32)
706
  self.buffer_offset = 0 # in frames
707
 
708
  def clear_buffer(self):
709
  self.buffer_offset += len(self.audio_buffer)
710
+ self.audio_buffer = np.array([], dtype=np.float32)
 
711
 
712
  def insert_audio_chunk(self, audio):
713
  res = self.vac(audio)
714
  self.audio_buffer = np.append(self.audio_buffer, audio)
715
 
716
  if res is not None:
717
+ frame = list(res.values())[0] - self.buffer_offset
718
+ if "start" in res and "end" not in res:
719
+ self.status = "voice"
720
  send_audio = self.audio_buffer[frame:]
721
+ self.online.init(
722
+ offset=(frame + self.buffer_offset) / self.SAMPLING_RATE
723
+ )
724
  self.online.insert_audio_chunk(send_audio)
725
  self.current_online_chunk_buffer_size += len(send_audio)
726
  self.clear_buffer()
727
+ elif "end" in res and "start" not in res:
728
+ self.status = "nonvoice"
729
  send_audio = self.audio_buffer[:frame]
730
  self.online.insert_audio_chunk(send_audio)
731
  self.current_online_chunk_buffer_size += len(send_audio)
732
  self.is_currently_final = True
733
  self.clear_buffer()
734
  else:
735
+ beg = res["start"] - self.buffer_offset
736
+ end = res["end"] - self.buffer_offset
737
+ self.status = "nonvoice"
738
  send_audio = self.audio_buffer[beg:end]
739
+ self.online.init(offset=(beg + self.buffer_offset) / self.SAMPLING_RATE)
740
  self.online.insert_audio_chunk(send_audio)
741
  self.current_online_chunk_buffer_size += len(send_audio)
742
  self.is_currently_final = True
743
  self.clear_buffer()
744
  else:
745
+ if self.status == "voice":
746
  self.online.insert_audio_chunk(self.audio_buffer)
747
  self.current_online_chunk_buffer_size += len(self.audio_buffer)
748
  self.clear_buffer()
749
  else:
750
  # We keep 1 second because VAD may later find start of voice in it.
751
+ # But we trim it to prevent OOM.
752
+ self.buffer_offset += max(
753
+ 0, len(self.audio_buffer) - self.SAMPLING_RATE
754
+ )
755
+ self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE :]
756
 
757
  def process_iter(self):
758
  if self.is_currently_final:
759
  return self.finish()
760
+ elif (
761
+ self.current_online_chunk_buffer_size
762
+ > self.SAMPLING_RATE * self.online_chunk_size
763
+ ):
764
  self.current_online_chunk_buffer_size = 0
765
  ret = self.online.process_iter()
766
  return ret
 
775
  return ret
776
 
777
 
778
+ WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(
779
+ ","
780
+ )
781
 
 
782
 
783
  def create_tokenizer(lan):
784
  """returns an object that has split function that works like the one of MosesTokenizer"""
785
 
786
+ assert (
787
+ lan in WHISPER_LANG_CODES
788
+ ), "language must be Whisper's supported lang code: " + " ".join(WHISPER_LANG_CODES)
789
 
790
  if lan == "uk":
791
  import tokenize_uk
792
+
793
  class UkrainianTokenizer:
794
  def split(self, text):
795
  return tokenize_uk.tokenize_sents(text)
796
+
797
  return UkrainianTokenizer()
798
 
799
  # supported by fast-mosestokenizer
800
+ if (
801
+ lan
802
+ in "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split()
803
+ ):
804
  from mosestokenizer import MosesTokenizer
805
+
806
  return MosesTokenizer(lan)
807
 
808
  # the following languages are in Whisper, but not in wtpsplit:
809
+ if (
810
+ lan
811
+ in "as ba bo br bs fo haw hr ht jw lb ln lo mi nn oc sa sd sn so su sw tk tl tt".split()
812
+ ):
813
+ logger.debug(
814
+ f"{lan} code is not supported by wtpsplit. Going to use None lang_code option."
815
+ )
816
  lan = None
817
 
818
  from wtpsplit import WtP
819
+
820
  # downloads the model from huggingface on the first use
821
  wtp = WtP("wtp-canine-s-12l-no-adapters")
822
+
823
  class WtPtok:
824
  def split(self, sent):
825
  return wtp.split(sent, lang_code=lan)
826
+
827
  return WtPtok()
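 # Usage sketch (illustrative): the returned object only needs a split(text) method that
 # returns a list of sentences, which words_to_sentences() relies on.
 #
 #   tokenizer = create_tokenizer("en")   # backend chosen from the language code
 #   sentences = tokenizer.split("Hello there. How are you?")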
828
 
829
 
 
831
  """shared args for simulation (this entry point) and server
832
  parser: argparse.ArgumentParser object
833
  """
834
+ parser.add_argument(
835
+ "--min-chunk-size",
836
+ type=float,
837
+ default=1.0,
838
+ help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.",
839
+ )
840
+ parser.add_argument(
841
+ "--model",
842
+ type=str,
843
+ default="large-v2",
844
+ choices="tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo".split(
845
+ ","
846
+ ),
847
+ help="Name size of the Whisper model to use (default: large-v2). The model is automatically downloaded from the model hub if not present in model cache dir.",
848
+ )
849
+ parser.add_argument(
850
+ "--model_cache_dir",
851
+ type=str,
852
+ default=None,
853
+ help="Overriding the default model cache dir where models downloaded from the hub are saved",
854
+ )
855
+ parser.add_argument(
856
+ "--model_dir",
857
+ type=str,
858
+ default=None,
859
+ help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.",
860
+ )
861
+ parser.add_argument(
862
+ "--lan",
863
+ "--language",
864
+ type=str,
865
+ default="auto",
866
+ help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
867
+ )
868
+ parser.add_argument(
869
+ "--task",
870
+ type=str,
871
+ default="transcribe",
872
+ choices=["transcribe", "translate"],
873
+ help="Transcribe or translate.",
874
+ )
875
+ parser.add_argument(
876
+ "--backend",
877
+ type=str,
878
+ default="faster-whisper",
879
+ choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],
880
+ help="Load only this backend for Whisper processing.",
881
+ )
882
+ parser.add_argument(
883
+ "--vac",
884
+ action="store_true",
885
+ default=False,
886
+ help="Use VAC = voice activity controller. Recommended. Requires torch.",
887
+ )
888
+ parser.add_argument(
889
+ "--vac-chunk-size", type=float, default=0.04, help="VAC sample size in seconds."
890
+ )
891
+ parser.add_argument(
892
+ "--vad",
893
+ action="store_true",
894
+ default=False,
895
+ help="Use VAD = voice activity detection, with the default parameters.",
896
+ )
897
+ parser.add_argument(
898
+ "--buffer_trimming",
899
+ type=str,
900
+ default="segment",
901
+ choices=["sentence", "segment"],
902
+ help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.',
903
+ )
904
+ parser.add_argument(
905
+ "--buffer_trimming_sec",
906
+ type=float,
907
+ default=15,
908
+ help="Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.",
909
+ )
910
+ parser.add_argument(
911
+ "-l",
912
+ "--log-level",
913
+ dest="log_level",
914
+ choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
915
+ help="Set the log level",
916
+ default="DEBUG",
917
+ )
918
+
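 # Example invocation of the simulation entry point below (illustrative; assumes this file
 # is whisper_online.py and uses the flags defined above):
 #
 #   python whisper_online.py audio.wav --language en --model large-v2 \
 #       --min-chunk-size 1.0 --buffer_trimming segment --buffer_trimming_sec 15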
919
 
920
  def asr_factory(args, logfile=sys.stderr):
921
  """
 
937
  size = args.model
938
  t = time.time()
939
  logger.info(f"Loading Whisper {size} model for {args.lan}...")
940
+ asr = asr_cls(
941
+ modelsize=size,
942
+ lan=args.lan,
943
+ cache_dir=args.model_cache_dir,
944
+ model_dir=args.model_dir,
945
+ )
946
  e = time.time()
947
  logger.info(f"done. It took {round(e-t,2)} seconds.")
948
 
949
  # Apply common configurations
950
+ if getattr(args, "vad", False): # Checks if VAD argument is present and True
951
  logger.info("Setting VAD filter")
952
  asr.use_vad()
953
 
 
966
 
967
  # Create the OnlineASRProcessor
968
  if args.vac:
969
+
970
+ online = VACOnlineASRProcessor(
971
+ args.min_chunk_size,
972
+ asr,
973
+ tokenizer,
974
+ logfile=logfile,
975
+ buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
976
+ )
977
  else:
978
+ online = OnlineASRProcessor(
979
+ asr,
980
+ tokenizer,
981
+ logfile=logfile,
982
+ buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
983
+ )
984
 
985
  return asr, online
986
 
987
+
988
+ def set_logging(args, logger, other="_server"):
989
+ logging.basicConfig(format="%(levelname)s\t%(message)s") # format='%(name)s
990
  logger.setLevel(args.log_level)
991
+ logging.getLogger("whisper_online" + other).setLevel(args.log_level)
 
992
 
993
 
994
+ # logging.getLogger("whisper_online_server").setLevel(args.log_level)
995
+
996
 
997
  if __name__ == "__main__":
998
 
999
  import argparse
1000
+
1001
  parser = argparse.ArgumentParser()
1002
+ parser.add_argument(
1003
+ "audio_path",
1004
+ type=str,
1005
+ help="Filename of 16kHz mono channel wav, on which live streaming is simulated.",
1006
+ )
1007
  add_shared_args(parser)
1008
+ parser.add_argument(
1009
+ "--start_at",
1010
+ type=float,
1011
+ default=0.0,
1012
+ help="Start processing audio at this time.",
1013
+ )
1014
+ parser.add_argument(
1015
+ "--offline", action="store_true", default=False, help="Offline mode."
1016
+ )
1017
+ parser.add_argument(
1018
+ "--comp_unaware",
1019
+ action="store_true",
1020
+ default=False,
1021
+ help="Computationally unaware simulation.",
1022
+ )
1023
+
1024
  args = parser.parse_args()
1025
 
1026
  # reset to store stderr to different file stream, e.g. open(os.devnull,"w")
1027
  logfile = sys.stderr
1028
 
1029
  if args.offline and args.comp_unaware:
1030
+ logger.error(
1031
+ "No or one option from --offline and --comp_unaware are available, not both. Exiting."
1032
+ )
1033
  sys.exit(1)
1034
 
1035
+ # if args.log_level:
1036
+ # logging.basicConfig(format='whisper-%(levelname)s:%(name)s: %(message)s',
1037
+ # level=getattr(logging, args.log_level))
1038
 
1039
+ set_logging(args, logger)
1040
 
1041
  audio_path = args.audio_path
1042
 
1043
  SAMPLING_RATE = 16000
1044
+ duration = len(load_audio(audio_path)) / SAMPLING_RATE
1045
  logger.info("Audio duration is: %2.2f seconds" % duration)
1046
 
1047
  asr, online = asr_factory(args, logfile=logfile)
 
1051
  min_chunk = args.min_chunk_size
1052
 
1053
  # load the audio into the LRU cache before we start the timer
1054
+ a = load_audio_chunk(audio_path, 0, 1)
1055
 
1056
 # warm up the ASR because the very first transcribe takes much more time than the others
1057
  asr.transcribe(a)
1058
 
1059
  beg = args.start_at
1060
+ start = time.time() - beg
1061
 
1062
  def output_transcript(o, now=None):
1063
  # output format in stdout is like:
 
1067
  # - beg and end timestamp of the text segment, as estimated by Whisper model. The timestamps are not accurate, but they're useful anyway
1068
  # - the next words: segment transcript
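 # Illustrative line (format per the print calls below): "1033.8333 300 1500 Hello there"
 # means the segment spanning 0.3-1.5 s was emitted 1033.8 ms after the start.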
1069
  if now is None:
1070
+ now = time.time() - start
1071
  if o[0] is not None:
1072
+ print(
1073
+ "%1.4f %1.0f %1.0f %s" % (now * 1000, o[0] * 1000, o[1] * 1000, o[2]),
1074
+ file=logfile,
1075
+ flush=True,
1076
+ )
1077
+ print(
1078
+ "%1.4f %1.0f %1.0f %s" % (now * 1000, o[0] * 1000, o[1] * 1000, o[2]),
1079
+ flush=True,
1080
+ )
1081
  else:
1082
  # No text, so no output
1083
  pass
1084
 
1085
+ if args.offline: ## offline mode processing (for testing/debugging)
1086
  a = load_audio(audio_path)
1087
  online.insert_audio_chunk(a)
1088
  try:
 
1092
  else:
1093
  output_transcript(o)
1094
  now = None
1095
+ elif args.comp_unaware: # computationally unaware mode
1096
  end = beg + min_chunk
1097
  while True:
1098
+ a = load_audio_chunk(audio_path, beg, end)
1099
  online.insert_audio_chunk(a)
1100
  try:
1101
  o = online.process_iter()
 
1109
 
1110
  if end >= duration:
1111
  break
1112
+
1113
  beg = end
1114
+
1115
  if end + min_chunk > duration:
1116
  end = duration
1117
  else:
1118
  end += min_chunk
1119
  now = duration
1120
 
1121
+ else: # online = simultaneous mode
1122
  end = 0
1123
  while True:
1124
  now = time.time() - start
1125
+ if now < end + min_chunk:
1126
+ time.sleep(min_chunk + end - now)
1127
  end = time.time() - start
1128
+ a = load_audio_chunk(audio_path, beg, end)
1129
  beg = end
1130
  online.insert_audio_chunk(a)
1131
 
 
1137
  else:
1138
  output_transcript(o)
1139
  now = time.time() - start
1140
+ logger.debug(
1141
+ f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}"
1142
+ )
1143
 
1144
  if end >= duration:
1145
  break