qfuxa committed
Commit cc68f3b · 1 Parent(s): 9cbac96

split whisper_online.py into smaller files
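For orientation, a minimal usage sketch of the layout this commit introduces (the sketch itself is not part of the diff): the backend classes move to src/whisper_streaming/backends.py, the streaming logic to src/whisper_streaming/online_asr.py, and whisper_online.py re-imports both, as the changes below show. It assumes faster-whisper is installed, a CUDA device is available (the backend hard-codes device="cuda"), and the repository root is on PYTHONPATH.

```python
# Hypothetical sketch only -- mirrors the new imports added to whisper_online.py below.
import numpy as np

from src.whisper_streaming.backends import FasterWhisperASR
from src.whisper_streaming.online_asr import OnlineASRProcessor

asr = FasterWhisperASR(lan="en", modelsize="large-v3-turbo")  # loads the model in __init__
online = OnlineASRProcessor(asr)                              # default ("segment", 15) buffer trimming

chunk = np.zeros(16000, dtype=np.float32)                     # 1 s of silence at 16 kHz, stand-in for real audio
online.insert_audio_chunk(chunk)
beg, end, text = online.process_iter()                        # committed partial transcript, or (None, None, "")
final = online.finish()                                       # flush the remaining hypothesis at the end
```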
README.md CHANGED
@@ -12,12 +12,12 @@ This project extends the [Whisper Streaming](https://github.com/ufal/whisper_str
 
  5. **MLX Whisper backend**: Integrates the alternative backend option MLX Whisper, optimized for efficient speech recognition on Apple silicon.
 
- ![Demo Screenshot](src/demo.png)
+ ![Demo Screenshot](src/web/demo.png)
 
  ## Code Origins
 
  This project reuses and extends code from the original Whisper Streaming repository:
- - whisper_online.py: Contains code from whisper_streaming
+ - whisper_online.py, backends.py and online_asr.py: Contains code from whisper_streaming
  - silero_vad_iterator.py: Originally from the Silero VAD repository, included in the whisper_streaming project.
 
  ## Installation
src/demo.png DELETED
Binary file (82.6 kB)
 
src/{live_transcription.html → web/live_transcription.html} RENAMED
File without changes
src/whisper_streaming/backends.py ADDED
@@ -0,0 +1,368 @@
1
+ import sys
2
+ import logging
3
+
4
+ import io
5
+ import soundfile as sf
6
+ import math
7
+
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ class ASRBase:
12
+ sep = " " # join transcribe words with this character (" " for whisper_timestamped,
13
+ # "" for faster-whisper because it emits the spaces when needed)
14
+
15
+ def __init__(
16
+ self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr
17
+ ):
18
+ self.logfile = logfile
19
+
20
+ self.transcribe_kargs = {}
21
+ if lan == "auto":
22
+ self.original_language = None
23
+ else:
24
+ self.original_language = lan
25
+
26
+ self.model = self.load_model(modelsize, cache_dir, model_dir)
27
+
28
+ def load_model(self, modelsize, cache_dir):
29
+ raise NotImplementedError("must be implemented in the child class")
30
+
31
+ def transcribe(self, audio, init_prompt=""):
32
+ raise NotImplementedError("must be implemented in the child class")
33
+
34
+ def use_vad(self):
35
+ raise NotImplementedError("must be implemented in the child class")
36
+
37
+
38
+ class WhisperTimestampedASR(ASRBase):
39
+ """Uses the whisper_timestamped library as the backend. Initially, we tested the code on this backend. It worked, but was slower than faster-whisper.
40
+ On the other hand, the installation for GPU could be easier.
41
+ """
42
+
43
+ sep = " "
44
+
45
+ def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
46
+ import whisper
47
+ import whisper_timestamped
48
+ from whisper_timestamped import transcribe_timestamped
49
+
50
+ self.transcribe_timestamped = transcribe_timestamped
51
+ if model_dir is not None:
52
+ logger.debug("ignoring model_dir, not implemented")
53
+ return whisper.load_model(modelsize, download_root=cache_dir)
54
+
55
+ def transcribe(self, audio, init_prompt=""):
56
+ result = self.transcribe_timestamped(
57
+ self.model,
58
+ audio,
59
+ language=self.original_language,
60
+ initial_prompt=init_prompt,
61
+ verbose=None,
62
+ condition_on_previous_text=True,
63
+ **self.transcribe_kargs,
64
+ )
65
+ return result
66
+
67
+ def ts_words(self, r):
68
+ # return: transcribe result object to [(beg,end,"word1"), ...]
69
+ o = []
70
+ for s in r["segments"]:
71
+ for w in s["words"]:
72
+ t = (w["start"], w["end"], w["text"])
73
+ o.append(t)
74
+ return o
75
+
76
+ def segments_end_ts(self, res):
77
+ return [s["end"] for s in res["segments"]]
78
+
79
+ def use_vad(self):
80
+ self.transcribe_kargs["vad"] = True
81
+
82
+ def set_translate_task(self):
83
+ self.transcribe_kargs["task"] = "translate"
84
+
85
+
86
+ class FasterWhisperASR(ASRBase):
87
+ """Uses faster-whisper library as the backend. Works much faster, appx 4-times (in offline mode). For GPU, it requires installation with a specific CUDNN version."""
88
+
89
+ sep = ""
90
+
91
+ def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
92
+ from faster_whisper import WhisperModel
93
+
94
+ # logging.getLogger("faster_whisper").setLevel(logger.level)
95
+ if model_dir is not None:
96
+ logger.debug(
97
+ f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used."
98
+ )
99
+ model_size_or_path = model_dir
100
+ elif modelsize is not None:
101
+ model_size_or_path = modelsize
102
+ else:
103
+ raise ValueError("modelsize or model_dir parameter must be set")
104
+
105
+ # this worked fast and reliably on NVIDIA L40
106
+ model = WhisperModel(
107
+ model_size_or_path,
108
+ device="cuda",
109
+ compute_type="float16",
110
+ download_root=cache_dir,
111
+ )
112
+
113
+ # or run on GPU with INT8
114
+ # tested: the transcripts were different, probably worse than with FP16, and it was slightly (appx 20%) slower
115
+ # model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
116
+
117
+ # or run on CPU with INT8
118
+ # tested: works, but slow, appx 10-times than cuda FP16
119
+ # model = WhisperModel(modelsize, device="cpu", compute_type="int8") #, download_root="faster-disk-cache-dir/")
120
+ return model
121
+
122
+ def transcribe(self, audio, init_prompt=""):
123
+
124
+ # tested: beam_size=5 is faster and better than 1 (on one 200 second document from En ESIC, min chunk 0.01)
125
+ segments, info = self.model.transcribe(
126
+ audio,
127
+ language=self.original_language,
128
+ initial_prompt=init_prompt,
129
+ beam_size=5,
130
+ word_timestamps=True,
131
+ condition_on_previous_text=True,
132
+ **self.transcribe_kargs,
133
+ )
134
+ # print(info) # info contains language detection result
135
+
136
+ return list(segments)
137
+
138
+ def ts_words(self, segments):
139
+ o = []
140
+ for segment in segments:
141
+ for word in segment.words:
142
+ if segment.no_speech_prob > 0.9:
143
+ continue
144
+ # not stripping the spaces -- should not be merged with them!
145
+ w = word.word
146
+ t = (word.start, word.end, w)
147
+ o.append(t)
148
+ return o
149
+
150
+ def segments_end_ts(self, res):
151
+ return [s.end for s in res]
152
+
153
+ def use_vad(self):
154
+ self.transcribe_kargs["vad_filter"] = True
155
+
156
+ def set_translate_task(self):
157
+ self.transcribe_kargs["task"] = "translate"
158
+
159
+
160
+ class MLXWhisper(ASRBase):
161
+ """
162
+ Uses the MLX Whisper library as the backend, optimized for Apple Silicon.
163
+ Models available: https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc
164
+ Significantly faster than faster-whisper (without CUDA) on Apple M1.
165
+ """
166
+
167
+ sep = " "
168
+
169
+ def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
170
+ """
171
+ Loads the MLX-compatible Whisper model.
172
+
173
+ Args:
174
+ modelsize (str, optional): The size or name of the Whisper model to load.
175
+ If provided, it will be translated to an MLX-compatible model path using the `translate_model_name` method.
176
+ Example: "large-v3-turbo" -> "mlx-community/whisper-large-v3-turbo".
177
+ cache_dir (str, optional): Path to the directory for caching models.
178
+ **Note**: This is not supported by MLX Whisper and will be ignored.
179
+ model_dir (str, optional): Direct path to a custom model directory.
180
+ If specified, it overrides the `modelsize` parameter.
181
+ """
182
+ from mlx_whisper.transcribe import ModelHolder, transcribe
183
+ import mlx.core as mx
184
+
185
+ if model_dir is not None:
186
+ logger.debug(
187
+ f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used."
188
+ )
189
+ model_size_or_path = model_dir
190
+ elif modelsize is not None:
191
+ model_size_or_path = self.translate_model_name(modelsize)
192
+ logger.debug(
193
+ f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used."
194
+ )
195
+
196
+ self.model_size_or_path = model_size_or_path
197
+
198
+ # In mlx_whisper.transcribe, dtype is defined as:
199
+ # dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
200
+ # Since we do not use decode_options in self.transcribe, we will set dtype to mx.float16
201
+ dtype = mx.float16
202
+ ModelHolder.get_model(model_size_or_path, dtype)
203
+ return transcribe
204
+
205
+ def translate_model_name(self, model_name):
206
+ """
207
+ Translates a given model name to its corresponding MLX-compatible model path.
208
+
209
+ Args:
210
+ model_name (str): The name of the model to translate.
211
+
212
+ Returns:
213
+ str: The MLX-compatible model path.
214
+ """
215
+ # Dictionary mapping model names to MLX-compatible paths
216
+ model_mapping = {
217
+ "tiny.en": "mlx-community/whisper-tiny.en-mlx",
218
+ "tiny": "mlx-community/whisper-tiny-mlx",
219
+ "base.en": "mlx-community/whisper-base.en-mlx",
220
+ "base": "mlx-community/whisper-base-mlx",
221
+ "small.en": "mlx-community/whisper-small.en-mlx",
222
+ "small": "mlx-community/whisper-small-mlx",
223
+ "medium.en": "mlx-community/whisper-medium.en-mlx",
224
+ "medium": "mlx-community/whisper-medium-mlx",
225
+ "large-v1": "mlx-community/whisper-large-v1-mlx",
226
+ "large-v2": "mlx-community/whisper-large-v2-mlx",
227
+ "large-v3": "mlx-community/whisper-large-v3-mlx",
228
+ "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
229
+ "large": "mlx-community/whisper-large-mlx",
230
+ }
231
+
232
+ # Retrieve the corresponding MLX model path
233
+ mlx_model_path = model_mapping.get(model_name)
234
+
235
+ if mlx_model_path:
236
+ return mlx_model_path
237
+ else:
238
+ raise ValueError(
239
+ f"Model name '{model_name}' is not recognized or not supported."
240
+ )
241
+
242
+ def transcribe(self, audio, init_prompt=""):
243
+ if self.transcribe_kargs:
244
+ logger.warning("Transcribe kwargs (vad, task) are not compatible with MLX Whisper and will be ignored.")
245
+ segments = self.model(
246
+ audio,
247
+ language=self.original_language,
248
+ initial_prompt=init_prompt,
249
+ word_timestamps=True,
250
+ condition_on_previous_text=True,
251
+ path_or_hf_repo=self.model_size_or_path,
252
+ )
253
+ return segments.get("segments", [])
254
+
255
+ def ts_words(self, segments):
256
+ """
257
+ Extract timestamped words from transcription segments and skips words with high no-speech probability.
258
+ """
259
+ return [
260
+ (word["start"], word["end"], word["word"])
261
+ for segment in segments
262
+ for word in segment.get("words", [])
263
+ if segment.get("no_speech_prob", 0) <= 0.9
264
+ ]
265
+
266
+ def segments_end_ts(self, res):
267
+ return [s["end"] for s in res]
268
+
269
+ def use_vad(self):
270
+ self.transcribe_kargs["vad_filter"] = True
271
+
272
+ def set_translate_task(self):
273
+ self.transcribe_kargs["task"] = "translate"
274
+
275
+
276
+ class OpenaiApiASR(ASRBase):
277
+ """Uses OpenAI's Whisper API for audio transcription."""
278
+
279
+ def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
280
+ self.logfile = logfile
281
+
282
+ self.modelname = "whisper-1"
283
+ self.original_language = (
284
+ None if lan == "auto" else lan
285
+ ) # ISO-639-1 language code
286
+ self.response_format = "verbose_json"
287
+ self.temperature = temperature
288
+
289
+ self.load_model()
290
+
291
+ self.use_vad_opt = False
292
+
293
+ # reset the task in set_translate_task
294
+ self.task = "transcribe"
295
+
296
+ def load_model(self, *args, **kwargs):
297
+ from openai import OpenAI
298
+
299
+ self.client = OpenAI()
300
+
301
+ self.transcribed_seconds = (
302
+ 0 # for logging how many seconds were processed by API, to know the cost
303
+ )
304
+
305
+ def ts_words(self, segments):
306
+ no_speech_segments = []
307
+ if self.use_vad_opt:
308
+ for segment in segments.segments:
309
+ # TODO: threshold can be set from outside
310
+ if segment["no_speech_prob"] > 0.8:
311
+ no_speech_segments.append(
312
+ (segment.get("start"), segment.get("end"))
313
+ )
314
+
315
+ o = []
316
+ for word in segments.words:
317
+ start = word.start
318
+ end = word.end
319
+ if any(s[0] <= start <= s[1] for s in no_speech_segments):
320
+ # print("Skipping word", word.get("word"), "because it's in a no-speech segment")
321
+ continue
322
+ o.append((start, end, word.word))
323
+ return o
324
+
325
+ def segments_end_ts(self, res):
326
+ return [s.end for s in res.words]
327
+
328
+ def transcribe(self, audio_data, prompt=None, *args, **kwargs):
329
+ # Write the audio data to a buffer
330
+ buffer = io.BytesIO()
331
+ buffer.name = "temp.wav"
332
+ sf.write(buffer, audio_data, samplerate=16000, format="WAV", subtype="PCM_16")
333
+ buffer.seek(0) # Reset buffer's position to the beginning
334
+
335
+ self.transcribed_seconds += math.ceil(
336
+ len(audio_data) / 16000
337
+ ) # it rounds up to the whole seconds
338
+
339
+ params = {
340
+ "model": self.modelname,
341
+ "file": buffer,
342
+ "response_format": self.response_format,
343
+ "temperature": self.temperature,
344
+ "timestamp_granularities": ["word", "segment"],
345
+ }
346
+ if self.task != "translate" and self.original_language:
347
+ params["language"] = self.original_language
348
+ if prompt:
349
+ params["prompt"] = prompt
350
+
351
+ if self.task == "translate":
352
+ proc = self.client.audio.translations
353
+ else:
354
+ proc = self.client.audio.transcriptions
355
+
356
+ # Process transcription/translation
357
+ transcript = proc.create(**params)
358
+ logger.debug(
359
+ f"OpenAI API processed accumulated {self.transcribed_seconds} seconds"
360
+ )
361
+
362
+ return transcript
363
+
364
+ def use_vad(self):
365
+ self.use_vad_opt = True
366
+
367
+ def set_translate_task(self):
368
+ self.task = "translate"
src/whisper_streaming/online_asr.py ADDED
@@ -0,0 +1,401 @@
1
+ import sys
2
+ import numpy as np
3
+ import logging
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
+ class HypothesisBuffer:
8
+
9
+ def __init__(self, logfile=sys.stderr):
10
+ self.commited_in_buffer = []
11
+ self.buffer = []
12
+ self.new = []
13
+
14
+ self.last_commited_time = 0
15
+ self.last_commited_word = None
16
+
17
+ self.logfile = logfile
18
+
19
+ def insert(self, new, offset):
20
+ # compare self.commited_in_buffer and new. It inserts only the words in new that extend the commited_in_buffer, it means they are roughly behind last_commited_time and new in content
21
+ # the new tail is added to self.new
22
+
23
+ new = [(a + offset, b + offset, t) for a, b, t in new]
24
+ self.new = [(a, b, t) for a, b, t in new if a > self.last_commited_time - 0.1]
25
+
26
+ if len(self.new) >= 1:
27
+ a, b, t = self.new[0]
28
+ if abs(a - self.last_commited_time) < 1:
29
+ if self.commited_in_buffer:
30
+ # it's going to search for 1, 2, ..., 5 consecutive words (n-grams) that are identical in commited and new. If they are, they're dropped.
31
+ cn = len(self.commited_in_buffer)
32
+ nn = len(self.new)
33
+ for i in range(1, min(min(cn, nn), 5) + 1): # 5 is the maximum
34
+ c = " ".join(
35
+ [self.commited_in_buffer[-j][2] for j in range(1, i + 1)][
36
+ ::-1
37
+ ]
38
+ )
39
+ tail = " ".join(self.new[j - 1][2] for j in range(1, i + 1))
40
+ if c == tail:
41
+ words = []
42
+ for j in range(i):
43
+ words.append(repr(self.new.pop(0)))
44
+ words_msg = " ".join(words)
45
+ logger.debug(f"removing last {i} words: {words_msg}")
46
+ break
47
+
48
+ def flush(self):
49
+ # returns commited chunk = the longest common prefix of 2 last inserts.
50
+
51
+ commit = []
52
+ while self.new:
53
+ na, nb, nt = self.new[0]
54
+
55
+ if len(self.buffer) == 0:
56
+ break
57
+
58
+ if nt == self.buffer[0][2]:
59
+ commit.append((na, nb, nt))
60
+ self.last_commited_word = nt
61
+ self.last_commited_time = nb
62
+ self.buffer.pop(0)
63
+ self.new.pop(0)
64
+ else:
65
+ break
66
+ self.buffer = self.new
67
+ self.new = []
68
+ self.commited_in_buffer.extend(commit)
69
+ return commit
70
+
71
+ def pop_commited(self, time):
72
+ while self.commited_in_buffer and self.commited_in_buffer[0][1] <= time:
73
+ self.commited_in_buffer.pop(0)
74
+
75
+ def complete(self):
76
+ return self.buffer
77
+
78
+
79
+ class OnlineASRProcessor:
80
+
81
+ SAMPLING_RATE = 16000
82
+
83
+ def __init__(
84
+ self,
85
+ asr,
86
+ tokenize_method=None,
87
+ buffer_trimming=("segment", 15),
88
+ logfile=sys.stderr,
89
+ ):
90
+ """asr: WhisperASR object
91
+ tokenize_method: sentence tokenizer function for the target language. Must be a callable that behaves like MosesTokenizer. It can be None if the "segment" buffer trimming option is used; then the tokenizer is not used at all.
92
+ ("segment", 15)
93
+ buffer_trimming: a pair of (option, seconds), where option is either "sentence" or "segment", and seconds is a number. Buffer is trimmed if it is longer than "seconds" threshold. Default is the most recommended option.
94
+ logfile: where to store the log.
95
+ """
96
+ self.asr = asr
97
+ self.tokenize = tokenize_method
98
+ self.logfile = logfile
99
+
100
+ self.init()
101
+
102
+ self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
103
+
104
+ def init(self, offset=None):
105
+ """run this when starting or restarting processing"""
106
+ self.audio_buffer = np.array([], dtype=np.float32)
107
+ self.transcript_buffer = HypothesisBuffer(logfile=self.logfile)
108
+ self.buffer_time_offset = 0
109
+ if offset is not None:
110
+ self.buffer_time_offset = offset
111
+ self.transcript_buffer.last_commited_time = self.buffer_time_offset
112
+ self.commited = []
113
+
114
+ def insert_audio_chunk(self, audio):
115
+ self.audio_buffer = np.append(self.audio_buffer, audio)
116
+
117
+ def prompt(self):
118
+ """Returns a tuple: (prompt, context), where "prompt" is a 200-character suffix of commited text that is inside of the scrolled away part of audio buffer.
119
+ "context" is the commited text that is inside the audio buffer. It is transcribed again and skipped. It is returned only for debugging and logging reasons.
120
+ """
121
+ k = max(0, len(self.commited) - 1)
122
+ while k > 0 and self.commited[k - 1][1] > self.buffer_time_offset:
123
+ k -= 1
124
+
125
+ p = self.commited[:k]
126
+ p = [t for _, _, t in p]
127
+ prompt = []
128
+ l = 0
129
+ while p and l < 200: # 200 characters prompt size
130
+ x = p.pop(-1)
131
+ l += len(x) + 1
132
+ prompt.append(x)
133
+ non_prompt = self.commited[k:]
134
+ return self.asr.sep.join(prompt[::-1]), self.asr.sep.join(
135
+ t for _, _, t in non_prompt
136
+ )
137
+
138
+ def process_iter(self):
139
+ """Runs on the current audio buffer.
140
+ Returns: a tuple (beg_timestamp, end_timestamp, "text"), or (None, None, "").
141
+ The non-empty text is the confirmed (committed) partial transcript.
142
+ """
143
+
144
+ prompt, non_prompt = self.prompt()
145
+ logger.debug(f"PROMPT: {prompt}")
146
+ logger.debug(f"CONTEXT: {non_prompt}")
147
+ logger.debug(
148
+ f"transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}"
149
+ )
150
+ res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt)
151
+
152
+ # transform to [(beg,end,"word1"), ...]
153
+ tsw = self.asr.ts_words(res)
154
+
155
+ self.transcript_buffer.insert(tsw, self.buffer_time_offset)
156
+ o = self.transcript_buffer.flush()
157
+ self.commited.extend(o)
158
+ completed = self.to_flush(o)
159
+ logger.debug(f">>>>COMPLETE NOW: {completed[2]}")
160
+ the_rest = self.to_flush(self.transcript_buffer.complete())
161
+ logger.debug(f"INCOMPLETE: {the_rest[2]}")
162
+
163
+ # there is a newly confirmed text
164
+
165
+ if o and self.buffer_trimming_way == "sentence": # trim the completed sentences
166
+ if (
167
+ len(self.audio_buffer) / self.SAMPLING_RATE > self.buffer_trimming_sec
168
+ ): # longer than this
169
+ self.chunk_completed_sentence()
170
+
171
+ if self.buffer_trimming_way == "segment":
172
+ s = self.buffer_trimming_sec # trim the completed segments longer than s,
173
+ else:
174
+ s = 30 # if the audio buffer is longer than 30s, trim it
175
+
176
+ if len(self.audio_buffer) / self.SAMPLING_RATE > s:
177
+ self.chunk_completed_segment(res)
178
+
179
+ # alternative: on any word
180
+ # l = self.buffer_time_offset + len(self.audio_buffer)/self.SAMPLING_RATE - 10
181
+ # let's find commited word that is less
182
+ # k = len(self.commited)-1
183
+ # while k>0 and self.commited[k][1] > l:
184
+ # k -= 1
185
+ # t = self.commited[k][1]
186
+ logger.debug("chunking segment")
187
+ # self.chunk_at(t)
188
+
189
+ logger.debug(
190
+ f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}"
191
+ )
192
+ return self.to_flush(o)
193
+
194
+ def chunk_completed_sentence(self):
195
+ if self.commited == []:
196
+ return
197
+ logger.debug("COMPLETED SENTENCE: ", [s[2] for s in self.commited])
198
+ sents = self.words_to_sentences(self.commited)
199
+ for s in sents:
200
+ logger.debug(f"\t\tSENT: {s}")
201
+ if len(sents) < 2:
202
+ return
203
+ while len(sents) > 2:
204
+ sents.pop(0)
205
+ # we will continue with audio processing at this timestamp
206
+ chunk_at = sents[-2][1]
207
+
208
+ logger.debug(f"--- sentence chunked at {chunk_at:2.2f}")
209
+ self.chunk_at(chunk_at)
210
+
211
+ def chunk_completed_segment(self, res):
212
+ if self.commited == []:
213
+ return
214
+
215
+ ends = self.asr.segments_end_ts(res)
216
+
217
+ t = self.commited[-1][1]
218
+
219
+ if len(ends) > 1:
220
+
221
+ e = ends[-2] + self.buffer_time_offset
222
+ while len(ends) > 2 and e > t:
223
+ ends.pop(-1)
224
+ e = ends[-2] + self.buffer_time_offset
225
+ if e <= t:
226
+ logger.debug(f"--- segment chunked at {e:2.2f}")
227
+ self.chunk_at(e)
228
+ else:
229
+ logger.debug(f"--- last segment not within commited area")
230
+ else:
231
+ logger.debug(f"--- not enough segments to chunk")
232
+
233
+ def chunk_at(self, time):
234
+ """trims the hypothesis and audio buffer at "time" """
235
+ self.transcript_buffer.pop_commited(time)
236
+ cut_seconds = time - self.buffer_time_offset
237
+ self.audio_buffer = self.audio_buffer[int(cut_seconds * self.SAMPLING_RATE) :]
238
+ self.buffer_time_offset = time
239
+
240
+ def words_to_sentences(self, words):
241
+ """Uses self.tokenize for sentence segmentation of words.
242
+ Returns: [(beg,end,"sentence 1"),...]
243
+ """
244
+
245
+ cwords = [w for w in words]
246
+ t = " ".join(o[2] for o in cwords)
247
+ s = self.tokenize(t)
248
+ out = []
249
+ while s:
250
+ beg = None
251
+ end = None
252
+ sent = s.pop(0).strip()
253
+ fsent = sent
254
+ while cwords:
255
+ b, e, w = cwords.pop(0)
256
+ w = w.strip()
257
+ if beg is None and sent.startswith(w):
258
+ beg = b
259
+ elif end is None and sent == w:
260
+ end = e
261
+ out.append((beg, end, fsent))
262
+ break
263
+ sent = sent[len(w) :].strip()
264
+ return out
265
+
266
+ def finish(self):
267
+ """Flush the incomplete text when the whole processing ends.
268
+ Returns: the same format as self.process_iter()
269
+ """
270
+ o = self.transcript_buffer.complete()
271
+ f = self.to_flush(o)
272
+ logger.debug(f"last, noncommited: {f}")
273
+ self.buffer_time_offset += len(self.audio_buffer) / 16000
274
+ return f
275
+
276
+ def to_flush(
277
+ self,
278
+ sents,
279
+ sep=None,
280
+ offset=0,
281
+ ):
282
+ # concatenates the timestamped words or sentences into one sequence that is flushed in one line
283
+ # sents: [(beg1, end1, "sentence1"), ...] or [] if empty
284
+ # return: (beg1,end-of-last-sentence,"concatenation of sentences") or (None, None, "") if empty
285
+ if sep is None:
286
+ sep = self.asr.sep
287
+ t = sep.join(s[2] for s in sents)
288
+ if len(sents) == 0:
289
+ b = None
290
+ e = None
291
+ else:
292
+ b = offset + sents[0][0]
293
+ e = offset + sents[-1][1]
294
+ return (b, e, t)
295
+
296
+
297
+ class VACOnlineASRProcessor(OnlineASRProcessor):
298
+ """Wraps OnlineASRProcessor with VAC (Voice Activity Controller).
299
+
300
+ It works the same way as OnlineASRProcessor: it receives chunks of audio (e.g. 0.04 seconds),
301
+ it runs VAD and continuously detects whether there is speech or not.
302
+ When it detects the end of speech (non-voice for 500 ms), it makes OnlineASRProcessor end the utterance immediately.
303
+ """
304
+
305
+ def __init__(self, online_chunk_size, *a, **kw):
306
+ self.online_chunk_size = online_chunk_size
307
+
308
+ self.online = OnlineASRProcessor(*a, **kw)
309
+
310
+ # VAC:
311
+ import torch
312
+
313
+ model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
314
+ from silero_vad_iterator import FixedVADIterator
315
+
316
+ self.vac = FixedVADIterator(
317
+ model
318
+ ) # we use the default options there: 500ms silence, 100ms padding, etc.
319
+
320
+ self.logfile = self.online.logfile
321
+ self.init()
322
+
323
+ def init(self):
324
+ self.online.init()
325
+ self.vac.reset_states()
326
+ self.current_online_chunk_buffer_size = 0
327
+
328
+ self.is_currently_final = False
329
+
330
+ self.status = None # or "voice" or "nonvoice"
331
+ self.audio_buffer = np.array([], dtype=np.float32)
332
+ self.buffer_offset = 0 # in frames
333
+
334
+ def clear_buffer(self):
335
+ self.buffer_offset += len(self.audio_buffer)
336
+ self.audio_buffer = np.array([], dtype=np.float32)
337
+
338
+ def insert_audio_chunk(self, audio):
339
+ res = self.vac(audio)
340
+ self.audio_buffer = np.append(self.audio_buffer, audio)
341
+
342
+ if res is not None:
343
+ frame = list(res.values())[0] - self.buffer_offset
344
+ if "start" in res and "end" not in res:
345
+ self.status = "voice"
346
+ send_audio = self.audio_buffer[frame:]
347
+ self.online.init(
348
+ offset=(frame + self.buffer_offset) / self.SAMPLING_RATE
349
+ )
350
+ self.online.insert_audio_chunk(send_audio)
351
+ self.current_online_chunk_buffer_size += len(send_audio)
352
+ self.clear_buffer()
353
+ elif "end" in res and "start" not in res:
354
+ self.status = "nonvoice"
355
+ send_audio = self.audio_buffer[:frame]
356
+ self.online.insert_audio_chunk(send_audio)
357
+ self.current_online_chunk_buffer_size += len(send_audio)
358
+ self.is_currently_final = True
359
+ self.clear_buffer()
360
+ else:
361
+ beg = res["start"] - self.buffer_offset
362
+ end = res["end"] - self.buffer_offset
363
+ self.status = "nonvoice"
364
+ send_audio = self.audio_buffer[beg:end]
365
+ self.online.init(offset=(beg + self.buffer_offset) / self.SAMPLING_RATE)
366
+ self.online.insert_audio_chunk(send_audio)
367
+ self.current_online_chunk_buffer_size += len(send_audio)
368
+ self.is_currently_final = True
369
+ self.clear_buffer()
370
+ else:
371
+ if self.status == "voice":
372
+ self.online.insert_audio_chunk(self.audio_buffer)
373
+ self.current_online_chunk_buffer_size += len(self.audio_buffer)
374
+ self.clear_buffer()
375
+ else:
376
+ # We keep 1 second because VAD may later find start of voice in it.
377
+ # But we trim it to prevent OOM.
378
+ self.buffer_offset += max(
379
+ 0, len(self.audio_buffer) - self.SAMPLING_RATE
380
+ )
381
+ self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE :]
382
+
383
+ def process_iter(self):
384
+ if self.is_currently_final:
385
+ return self.finish()
386
+ elif (
387
+ self.current_online_chunk_buffer_size
388
+ > self.SAMPLING_RATE * self.online_chunk_size
389
+ ):
390
+ self.current_online_chunk_buffer_size = 0
391
+ ret = self.online.process_iter()
392
+ return ret
393
+ else:
394
+ print("no online update, only VAD", self.status, file=self.logfile)
395
+ return (None, None, "")
396
+
397
+ def finish(self):
398
+ ret = self.online.finish()
399
+ self.current_online_chunk_buffer_size = 0
400
+ self.is_currently_final = False
401
+ return ret
whisper_fastapi_online_server.py CHANGED
@@ -43,7 +43,7 @@ args = parser.parse_args()
  asr, tokenizer = backend_factory(args)
 
  # Load demo HTML for the root endpoint
- with open("src/live_transcription.html", "r", encoding="utf-8") as f:
+ with open("src/web/live_transcription.html", "r", encoding="utf-8") as f:
  html = f.read()
 
whisper_online.py CHANGED
@@ -5,10 +5,8 @@ import librosa
5
  from functools import lru_cache
6
  import time
7
  import logging
8
-
9
- import io
10
- import soundfile as sf
11
- import math
12
 
13
  logger = logging.getLogger(__name__)
14
 
@@ -25,768 +23,6 @@ def load_audio_chunk(fname, beg, end):
25
  end_s = int(end * 16000)
26
  return audio[beg_s:end_s]
27
 
28
-
29
- # Whisper backend
30
-
31
-
32
- class ASRBase:
33
-
34
- sep = " " # join transcribe words with this character (" " for whisper_timestamped,
35
- # "" for faster-whisper because it emits the spaces when neeeded)
36
-
37
- def __init__(
38
- self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr
39
- ):
40
- self.logfile = logfile
41
-
42
- self.transcribe_kargs = {}
43
- if lan == "auto":
44
- self.original_language = None
45
- else:
46
- self.original_language = lan
47
-
48
- self.model = self.load_model(modelsize, cache_dir, model_dir)
49
-
50
- def load_model(self, modelsize, cache_dir):
51
- raise NotImplemented("must be implemented in the child class")
52
-
53
- def transcribe(self, audio, init_prompt=""):
54
- raise NotImplemented("must be implemented in the child class")
55
-
56
- def use_vad(self):
57
- raise NotImplemented("must be implemented in the child class")
58
-
59
-
60
- class WhisperTimestampedASR(ASRBase):
61
- """Uses whisper_timestamped library as the backend. Initially, we tested the code on this backend. It worked, but slower than faster-whisper.
62
- On the other hand, the installation for GPU could be easier.
63
- """
64
-
65
- sep = " "
66
-
67
- def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
68
- import whisper
69
- import whisper_timestamped
70
- from whisper_timestamped import transcribe_timestamped
71
-
72
- self.transcribe_timestamped = transcribe_timestamped
73
- if model_dir is not None:
74
- logger.debug("ignoring model_dir, not implemented")
75
- return whisper.load_model(modelsize, download_root=cache_dir)
76
-
77
- def transcribe(self, audio, init_prompt=""):
78
- result = self.transcribe_timestamped(
79
- self.model,
80
- audio,
81
- language=self.original_language,
82
- initial_prompt=init_prompt,
83
- verbose=None,
84
- condition_on_previous_text=True,
85
- **self.transcribe_kargs,
86
- )
87
- return result
88
-
89
- def ts_words(self, r):
90
- # return: transcribe result object to [(beg,end,"word1"), ...]
91
- o = []
92
- for s in r["segments"]:
93
- for w in s["words"]:
94
- t = (w["start"], w["end"], w["text"])
95
- o.append(t)
96
- return o
97
-
98
- def segments_end_ts(self, res):
99
- return [s["end"] for s in res["segments"]]
100
-
101
- def use_vad(self):
102
- self.transcribe_kargs["vad"] = True
103
-
104
- def set_translate_task(self):
105
- self.transcribe_kargs["task"] = "translate"
106
-
107
-
108
- class FasterWhisperASR(ASRBase):
109
- """Uses faster-whisper library as the backend. Works much faster, appx 4-times (in offline mode). For GPU, it requires installation with a specific CUDNN version."""
110
-
111
- sep = ""
112
-
113
- def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
114
- from faster_whisper import WhisperModel
115
-
116
- # logging.getLogger("faster_whisper").setLevel(logger.level)
117
- if model_dir is not None:
118
- logger.debug(
119
- f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used."
120
- )
121
- model_size_or_path = model_dir
122
- elif modelsize is not None:
123
- model_size_or_path = modelsize
124
- else:
125
- raise ValueError("modelsize or model_dir parameter must be set")
126
-
127
- # this worked fast and reliably on NVIDIA L40
128
- model = WhisperModel(
129
- model_size_or_path,
130
- device="cuda",
131
- compute_type="float16",
132
- download_root=cache_dir,
133
- )
134
-
135
- # or run on GPU with INT8
136
- # tested: the transcripts were different, probably worse than with FP16, and it was slightly (appx 20%) slower
137
- # model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
138
-
139
- # or run on CPU with INT8
140
- # tested: works, but slow, appx 10-times than cuda FP16
141
- # model = WhisperModel(modelsize, device="cpu", compute_type="int8") #, download_root="faster-disk-cache-dir/")
142
- return model
143
-
144
- def transcribe(self, audio, init_prompt=""):
145
-
146
- # tested: beam_size=5 is faster and better than 1 (on one 200 second document from En ESIC, min chunk 0.01)
147
- segments, info = self.model.transcribe(
148
- audio,
149
- language=self.original_language,
150
- initial_prompt=init_prompt,
151
- beam_size=5,
152
- word_timestamps=True,
153
- condition_on_previous_text=True,
154
- **self.transcribe_kargs,
155
- )
156
- # print(info) # info contains language detection result
157
-
158
- return list(segments)
159
-
160
- def ts_words(self, segments):
161
- o = []
162
- for segment in segments:
163
- for word in segment.words:
164
- if segment.no_speech_prob > 0.9:
165
- continue
166
- # not stripping the spaces -- should not be merged with them!
167
- w = word.word
168
- t = (word.start, word.end, w)
169
- o.append(t)
170
- return o
171
-
172
- def segments_end_ts(self, res):
173
- return [s.end for s in res]
174
-
175
- def use_vad(self):
176
- self.transcribe_kargs["vad_filter"] = True
177
-
178
- def set_translate_task(self):
179
- self.transcribe_kargs["task"] = "translate"
180
-
181
-
182
- class MLXWhisper(ASRBase):
183
- """
184
- Uses MPX Whisper library as the backend, optimized for Apple Silicon.
185
- Models available: https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc
186
- Significantly faster than faster-whisper (without CUDA) on Apple M1.
187
- """
188
-
189
- sep = " "
190
-
191
- def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
192
- """
193
- Loads the MLX-compatible Whisper model.
194
-
195
- Args:
196
- modelsize (str, optional): The size or name of the Whisper model to load.
197
- If provided, it will be translated to an MLX-compatible model path using the `translate_model_name` method.
198
- Example: "large-v3-turbo" -> "mlx-community/whisper-large-v3-turbo".
199
- cache_dir (str, optional): Path to the directory for caching models.
200
- **Note**: This is not supported by MLX Whisper and will be ignored.
201
- model_dir (str, optional): Direct path to a custom model directory.
202
- If specified, it overrides the `modelsize` parameter.
203
- """
204
- from mlx_whisper.transcribe import ModelHolder, transcribe
205
- import mlx.core as mx
206
-
207
- if model_dir is not None:
208
- logger.debug(
209
- f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used."
210
- )
211
- model_size_or_path = model_dir
212
- elif modelsize is not None:
213
- model_size_or_path = self.translate_model_name(modelsize)
214
- logger.debug(
215
- f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used."
216
- )
217
-
218
- self.model_size_or_path = model_size_or_path
219
-
220
- # In mlx_whisper.transcribe, dtype is defined as:
221
- # dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
222
- # Since we do not use decode_options in self.transcribe, we will set dtype to mx.float16
223
- dtype = mx.float16
224
- ModelHolder.get_model(model_size_or_path, dtype)
225
- return transcribe
226
-
227
- def translate_model_name(self, model_name):
228
- """
229
- Translates a given model name to its corresponding MLX-compatible model path.
230
-
231
- Args:
232
- model_name (str): The name of the model to translate.
233
-
234
- Returns:
235
- str: The MLX-compatible model path.
236
- """
237
- # Dictionary mapping model names to MLX-compatible paths
238
- model_mapping = {
239
- "tiny.en": "mlx-community/whisper-tiny.en-mlx",
240
- "tiny": "mlx-community/whisper-tiny-mlx",
241
- "base.en": "mlx-community/whisper-base.en-mlx",
242
- "base": "mlx-community/whisper-base-mlx",
243
- "small.en": "mlx-community/whisper-small.en-mlx",
244
- "small": "mlx-community/whisper-small-mlx",
245
- "medium.en": "mlx-community/whisper-medium.en-mlx",
246
- "medium": "mlx-community/whisper-medium-mlx",
247
- "large-v1": "mlx-community/whisper-large-v1-mlx",
248
- "large-v2": "mlx-community/whisper-large-v2-mlx",
249
- "large-v3": "mlx-community/whisper-large-v3-mlx",
250
- "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
251
- "large": "mlx-community/whisper-large-mlx",
252
- }
253
-
254
- # Retrieve the corresponding MLX model path
255
- mlx_model_path = model_mapping.get(model_name)
256
-
257
- if mlx_model_path:
258
- return mlx_model_path
259
- else:
260
- raise ValueError(
261
- f"Model name '{model_name}' is not recognized or not supported."
262
- )
263
-
264
- def transcribe(self, audio, init_prompt=""):
265
- if self.transcribe_kargs:
266
- logger.warning("Transcribe kwargs (vad, task) are not compatible with MLX Whisper and will be ignored.")
267
- segments = self.model(
268
- audio,
269
- language=self.original_language,
270
- initial_prompt=init_prompt,
271
- word_timestamps=True,
272
- condition_on_previous_text=True,
273
- path_or_hf_repo=self.model_size_or_path,
274
- )
275
- return segments.get("segments", [])
276
-
277
- def ts_words(self, segments):
278
- """
279
- Extract timestamped words from transcription segments and skips words with high no-speech probability.
280
- """
281
- return [
282
- (word["start"], word["end"], word["word"])
283
- for segment in segments
284
- for word in segment.get("words", [])
285
- if segment.get("no_speech_prob", 0) <= 0.9
286
- ]
287
-
288
- def segments_end_ts(self, res):
289
- return [s["end"] for s in res]
290
-
291
- def use_vad(self):
292
- self.transcribe_kargs["vad_filter"] = True
293
-
294
- def set_translate_task(self):
295
- self.transcribe_kargs["task"] = "translate"
296
-
297
-
298
- class OpenaiApiASR(ASRBase):
299
- """Uses OpenAI's Whisper API for audio transcription."""
300
-
301
- def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
302
- self.logfile = logfile
303
-
304
- self.modelname = "whisper-1"
305
- self.original_language = (
306
- None if lan == "auto" else lan
307
- ) # ISO-639-1 language code
308
- self.response_format = "verbose_json"
309
- self.temperature = temperature
310
-
311
- self.load_model()
312
-
313
- self.use_vad_opt = False
314
-
315
- # reset the task in set_translate_task
316
- self.task = "transcribe"
317
-
318
- def load_model(self, *args, **kwargs):
319
- from openai import OpenAI
320
-
321
- self.client = OpenAI()
322
-
323
- self.transcribed_seconds = (
324
- 0 # for logging how many seconds were processed by API, to know the cost
325
- )
326
-
327
- def ts_words(self, segments):
328
- no_speech_segments = []
329
- if self.use_vad_opt:
330
- for segment in segments.segments:
331
- # TODO: threshold can be set from outside
332
- if segment["no_speech_prob"] > 0.8:
333
- no_speech_segments.append(
334
- (segment.get("start"), segment.get("end"))
335
- )
336
-
337
- o = []
338
- for word in segments.words:
339
- start = word.start
340
- end = word.end
341
- if any(s[0] <= start <= s[1] for s in no_speech_segments):
342
- # print("Skipping word", word.get("word"), "because it's in a no-speech segment")
343
- continue
344
- o.append((start, end, word.word))
345
- return o
346
-
347
- def segments_end_ts(self, res):
348
- return [s.end for s in res.words]
349
-
350
- def transcribe(self, audio_data, prompt=None, *args, **kwargs):
351
- # Write the audio data to a buffer
352
- buffer = io.BytesIO()
353
- buffer.name = "temp.wav"
354
- sf.write(buffer, audio_data, samplerate=16000, format="WAV", subtype="PCM_16")
355
- buffer.seek(0) # Reset buffer's position to the beginning
356
-
357
- self.transcribed_seconds += math.ceil(
358
- len(audio_data) / 16000
359
- ) # it rounds up to the whole seconds
360
-
361
- params = {
362
- "model": self.modelname,
363
- "file": buffer,
364
- "response_format": self.response_format,
365
- "temperature": self.temperature,
366
- "timestamp_granularities": ["word", "segment"],
367
- }
368
- if self.task != "translate" and self.original_language:
369
- params["language"] = self.original_language
370
- if prompt:
371
- params["prompt"] = prompt
372
-
373
- if self.task == "translate":
374
- proc = self.client.audio.translations
375
- else:
376
- proc = self.client.audio.transcriptions
377
-
378
- # Process transcription/translation
379
- transcript = proc.create(**params)
380
- logger.debug(
381
- f"OpenAI API processed accumulated {self.transcribed_seconds} seconds"
382
- )
383
-
384
- return transcript
385
-
386
- def use_vad(self):
387
- self.use_vad_opt = True
388
-
389
- def set_translate_task(self):
390
- self.task = "translate"
391
-
392
-
393
- class HypothesisBuffer:
394
-
395
- def __init__(self, logfile=sys.stderr):
396
- self.commited_in_buffer = []
397
- self.buffer = []
398
- self.new = []
399
-
400
- self.last_commited_time = 0
401
- self.last_commited_word = None
402
-
403
- self.logfile = logfile
404
-
405
- def insert(self, new, offset):
406
- # compare self.commited_in_buffer and new. It inserts only the words in new that extend the commited_in_buffer, it means they are roughly behind last_commited_time and new in content
407
- # the new tail is added to self.new
408
-
409
- new = [(a + offset, b + offset, t) for a, b, t in new]
410
- self.new = [(a, b, t) for a, b, t in new if a > self.last_commited_time - 0.1]
411
-
412
- if len(self.new) >= 1:
413
- a, b, t = self.new[0]
414
- if abs(a - self.last_commited_time) < 1:
415
- if self.commited_in_buffer:
416
- # it's going to search for 1, 2, ..., 5 consecutive words (n-grams) that are identical in commited and new. If they are, they're dropped.
417
- cn = len(self.commited_in_buffer)
418
- nn = len(self.new)
419
- for i in range(1, min(min(cn, nn), 5) + 1): # 5 is the maximum
420
- c = " ".join(
421
- [self.commited_in_buffer[-j][2] for j in range(1, i + 1)][
422
- ::-1
423
- ]
424
- )
425
- tail = " ".join(self.new[j - 1][2] for j in range(1, i + 1))
426
- if c == tail:
427
- words = []
428
- for j in range(i):
429
- words.append(repr(self.new.pop(0)))
430
- words_msg = " ".join(words)
431
- logger.debug(f"removing last {i} words: {words_msg}")
432
- break
433
-
434
- def flush(self):
435
- # returns commited chunk = the longest common prefix of 2 last inserts.
436
-
437
- commit = []
438
- while self.new:
439
- na, nb, nt = self.new[0]
440
-
441
- if len(self.buffer) == 0:
442
- break
443
-
444
- if nt == self.buffer[0][2]:
445
- commit.append((na, nb, nt))
446
- self.last_commited_word = nt
447
- self.last_commited_time = nb
448
- self.buffer.pop(0)
449
- self.new.pop(0)
450
- else:
451
- break
452
- self.buffer = self.new
453
- self.new = []
454
- self.commited_in_buffer.extend(commit)
455
- return commit
456
-
457
- def pop_commited(self, time):
458
- while self.commited_in_buffer and self.commited_in_buffer[0][1] <= time:
459
- self.commited_in_buffer.pop(0)
460
-
461
- def complete(self):
462
- return self.buffer
463
-
464
-
465
- class OnlineASRProcessor:
466
-
467
- SAMPLING_RATE = 16000
468
-
469
- def __init__(
470
- self,
471
- asr,
472
- tokenize_method=None,
473
- buffer_trimming=("segment", 15),
474
- logfile=sys.stderr,
475
- ):
476
- """asr: WhisperASR object
477
- tokenize_method: sentence tokenizer function for the target language. Must be a callable and behaves like the one of MosesTokenizer. It can be None, if "segment" buffer trimming option is used, then tokenizer is not used at all.
478
- ("segment", 15)
479
- buffer_trimming: a pair of (option, seconds), where option is either "sentence" or "segment", and seconds is a number. Buffer is trimmed if it is longer than "seconds" threshold. Default is the most recommended option.
480
- logfile: where to store the log.
481
- """
482
- self.asr = asr
483
- self.tokenize = tokenize_method
484
- self.logfile = logfile
485
-
486
- self.init()
487
-
488
- self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
489
-
490
- def init(self, offset=None):
491
- """run this when starting or restarting processing"""
492
- self.audio_buffer = np.array([], dtype=np.float32)
493
- self.transcript_buffer = HypothesisBuffer(logfile=self.logfile)
494
- self.buffer_time_offset = 0
495
- if offset is not None:
496
- self.buffer_time_offset = offset
497
- self.transcript_buffer.last_commited_time = self.buffer_time_offset
498
- self.commited = []
499
-
500
- def insert_audio_chunk(self, audio):
501
- self.audio_buffer = np.append(self.audio_buffer, audio)
502
-
503
- def prompt(self):
504
- """Returns a tuple: (prompt, context), where "prompt" is a 200-character suffix of commited text that is inside of the scrolled away part of audio buffer.
505
- "context" is the commited text that is inside the audio buffer. It is transcribed again and skipped. It is returned only for debugging and logging reasons.
506
- """
507
- k = max(0, len(self.commited) - 1)
508
- while k > 0 and self.commited[k - 1][1] > self.buffer_time_offset:
509
- k -= 1
510
-
511
- p = self.commited[:k]
512
- p = [t for _, _, t in p]
513
- prompt = []
514
- l = 0
515
- while p and l < 200: # 200 characters prompt size
516
- x = p.pop(-1)
517
- l += len(x) + 1
518
- prompt.append(x)
519
- non_prompt = self.commited[k:]
520
- return self.asr.sep.join(prompt[::-1]), self.asr.sep.join(
521
- t for _, _, t in non_prompt
522
- )
523
-
524
- def process_iter(self):
525
- """Runs on the current audio buffer.
526
- Returns: a tuple (beg_timestamp, end_timestamp, "text"), or (None, None, "").
527
- The non-emty text is confirmed (committed) partial transcript.
528
- """
529
-
530
- prompt, non_prompt = self.prompt()
531
- logger.debug(f"PROMPT: {prompt}")
532
- logger.debug(f"CONTEXT: {non_prompt}")
533
- logger.debug(
534
- f"transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}"
535
- )
536
- res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt)
537
-
538
- # transform to [(beg,end,"word1"), ...]
539
- tsw = self.asr.ts_words(res)
540
-
541
- self.transcript_buffer.insert(tsw, self.buffer_time_offset)
542
- o = self.transcript_buffer.flush()
543
- self.commited.extend(o)
544
- completed = self.to_flush(o)
545
- logger.debug(f">>>>COMPLETE NOW: {completed[2]}")
546
- the_rest = self.to_flush(self.transcript_buffer.complete())
547
- logger.debug(f"INCOMPLETE: {the_rest[2]}")
548
-
549
- # there is a newly confirmed text
550
-
551
- if o and self.buffer_trimming_way == "sentence": # trim the completed sentences
552
- if (
553
- len(self.audio_buffer) / self.SAMPLING_RATE > self.buffer_trimming_sec
554
- ): # longer than this
555
- self.chunk_completed_sentence()
556
-
557
- if self.buffer_trimming_way == "segment":
558
- s = self.buffer_trimming_sec # trim the completed segments longer than s,
559
- else:
560
- s = 30 # if the audio buffer is longer than 30s, trim it
561
-
562
- if len(self.audio_buffer) / self.SAMPLING_RATE > s:
563
- self.chunk_completed_segment(res)
564
-
565
- # alternative: on any word
566
- # l = self.buffer_time_offset + len(self.audio_buffer)/self.SAMPLING_RATE - 10
567
- # let's find commited word that is less
568
- # k = len(self.commited)-1
569
- # while k>0 and self.commited[k][1] > l:
570
- # k -= 1
571
- # t = self.commited[k][1]
572
- logger.debug("chunking segment")
573
- # self.chunk_at(t)
574
-
575
- logger.debug(
576
- f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}"
577
- )
578
- return self.to_flush(o)
579
-
580
- def chunk_completed_sentence(self):
581
- if self.commited == []:
582
- return
583
- logger.debug("COMPLETED SENTENCE: ", [s[2] for s in self.commited])
584
- sents = self.words_to_sentences(self.commited)
585
- for s in sents:
586
- logger.debug(f"\t\tSENT: {s}")
587
- if len(sents) < 2:
588
- return
589
- while len(sents) > 2:
590
- sents.pop(0)
591
- # we will continue with audio processing at this timestamp
592
- chunk_at = sents[-2][1]
593
-
594
- logger.debug(f"--- sentence chunked at {chunk_at:2.2f}")
595
- self.chunk_at(chunk_at)
596
-
597
- def chunk_completed_segment(self, res):
598
- if self.commited == []:
599
- return
600
-
601
- ends = self.asr.segments_end_ts(res)
602
-
603
- t = self.commited[-1][1]
604
-
605
- if len(ends) > 1:
606
-
607
- e = ends[-2] + self.buffer_time_offset
608
- while len(ends) > 2 and e > t:
609
- ends.pop(-1)
610
- e = ends[-2] + self.buffer_time_offset
611
- if e <= t:
612
- logger.debug(f"--- segment chunked at {e:2.2f}")
613
- self.chunk_at(e)
614
- else:
615
- logger.debug(f"--- last segment not within commited area")
616
- else:
617
- logger.debug(f"--- not enough segments to chunk")
618
-
619
- def chunk_at(self, time):
620
- """trims the hypothesis and audio buffer at "time" """
621
- self.transcript_buffer.pop_commited(time)
622
- cut_seconds = time - self.buffer_time_offset
623
- self.audio_buffer = self.audio_buffer[int(cut_seconds * self.SAMPLING_RATE) :]
624
- self.buffer_time_offset = time
625
-
626
- def words_to_sentences(self, words):
627
- """Uses self.tokenize for sentence segmentation of words.
628
- Returns: [(beg,end,"sentence 1"),...]
629
- """
630
-
631
- cwords = [w for w in words]
632
- t = " ".join(o[2] for o in cwords)
633
- s = self.tokenize(t)
634
- out = []
635
- while s:
636
- beg = None
637
- end = None
638
- sent = s.pop(0).strip()
639
- fsent = sent
640
- while cwords:
641
- b, e, w = cwords.pop(0)
642
- w = w.strip()
643
- if beg is None and sent.startswith(w):
644
- beg = b
645
- elif end is None and sent == w:
646
- end = e
647
- out.append((beg, end, fsent))
648
- break
649
- sent = sent[len(w) :].strip()
650
- return out
651
-
652
- def finish(self):
653
- """Flush the incomplete text when the whole processing ends.
654
- Returns: the same format as self.process_iter()
655
- """
656
- o = self.transcript_buffer.complete()
657
- f = self.to_flush(o)
658
- logger.debug(f"last, noncommited: {f}")
659
- self.buffer_time_offset += len(self.audio_buffer) / 16000
660
- return f
661
-
662
- def to_flush(
663
- self,
664
- sents,
665
- sep=None,
666
- offset=0,
667
- ):
668
- # concatenates the timestamped words or sentences into one sequence that is flushed in one line
669
- # sents: [(beg1, end1, "sentence1"), ...] or [] if empty
670
- # return: (beg1,end-of-last-sentence,"concatenation of sentences") or (None, None, "") if empty
671
- if sep is None:
672
- sep = self.asr.sep
673
- t = sep.join(s[2] for s in sents)
674
- if len(sents) == 0:
675
- b = None
676
- e = None
677
- else:
678
- b = offset + sents[0][0]
679
- e = offset + sents[-1][1]
680
- return (b, e, t)
681
-
682
-
683
- class VACOnlineASRProcessor(OnlineASRProcessor):
684
- """Wraps OnlineASRProcessor with VAC (Voice Activity Controller).
685
-
686
- It works the same way as OnlineASRProcessor: it receives chunks of audio (e.g. 0.04 seconds),
687
- it runs VAD and continuously detects whether there is speech or not.
688
- When it detects end of speech (non-voice for 500ms), it makes OnlineASRProcessor to end the utterance immediately.
689
- """
690
-
691
- def __init__(self, online_chunk_size, *a, **kw):
692
- self.online_chunk_size = online_chunk_size
693
-
694
- self.online = OnlineASRProcessor(*a, **kw)
695
-
696
- # VAC:
697
- import torch
698
-
699
- model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
700
- from silero_vad_iterator import FixedVADIterator
701
-
702
- self.vac = FixedVADIterator(
703
- model
704
- ) # we use the default options there: 500ms silence, 100ms padding, etc.
705
-
706
- self.logfile = self.online.logfile
707
- self.init()
708
-
709
- def init(self):
710
- self.online.init()
711
- self.vac.reset_states()
712
- self.current_online_chunk_buffer_size = 0
713
-
714
- self.is_currently_final = False
715
-
716
- self.status = None # or "voice" or "nonvoice"
717
- self.audio_buffer = np.array([], dtype=np.float32)
718
- self.buffer_offset = 0 # in frames
719
-
720
- def clear_buffer(self):
721
- self.buffer_offset += len(self.audio_buffer)
722
- self.audio_buffer = np.array([], dtype=np.float32)
723
-
724
- def insert_audio_chunk(self, audio):
725
- res = self.vac(audio)
726
- self.audio_buffer = np.append(self.audio_buffer, audio)
727
-
728
- if res is not None:
729
- frame = list(res.values())[0] - self.buffer_offset
730
- if "start" in res and "end" not in res:
731
- self.status = "voice"
732
- send_audio = self.audio_buffer[frame:]
733
- self.online.init(
734
- offset=(frame + self.buffer_offset) / self.SAMPLING_RATE
735
- )
736
- self.online.insert_audio_chunk(send_audio)
737
- self.current_online_chunk_buffer_size += len(send_audio)
738
- self.clear_buffer()
739
- elif "end" in res and "start" not in res:
740
- self.status = "nonvoice"
741
- send_audio = self.audio_buffer[:frame]
742
- self.online.insert_audio_chunk(send_audio)
743
- self.current_online_chunk_buffer_size += len(send_audio)
744
- self.is_currently_final = True
745
- self.clear_buffer()
746
- else:
747
- beg = res["start"] - self.buffer_offset
748
- end = res["end"] - self.buffer_offset
749
- self.status = "nonvoice"
750
- send_audio = self.audio_buffer[beg:end]
751
- self.online.init(offset=(beg + self.buffer_offset) / self.SAMPLING_RATE)
752
- self.online.insert_audio_chunk(send_audio)
753
- self.current_online_chunk_buffer_size += len(send_audio)
754
- self.is_currently_final = True
755
- self.clear_buffer()
756
- else:
757
- if self.status == "voice":
758
- self.online.insert_audio_chunk(self.audio_buffer)
759
- self.current_online_chunk_buffer_size += len(self.audio_buffer)
760
- self.clear_buffer()
761
- else:
762
- # We keep 1 second because VAD may later find start of voice in it.
763
- # But we trim it to prevent OOM.
764
- self.buffer_offset += max(
765
- 0, len(self.audio_buffer) - self.SAMPLING_RATE
766
- )
767
- self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE :]
768
-
769
- def process_iter(self):
770
- if self.is_currently_final:
771
- return self.finish()
772
- elif (
773
- self.current_online_chunk_buffer_size
774
- > self.SAMPLING_RATE * self.online_chunk_size
775
- ):
776
- self.current_online_chunk_buffer_size = 0
777
- ret = self.online.process_iter()
778
- return ret
779
- else:
780
- print("no online update, only VAD", self.status, file=self.logfile)
781
- return (None, None, "")
782
-
783
- def finish(self):
784
- ret = self.online.finish()
785
- self.current_online_chunk_buffer_size = 0
786
- self.is_currently_final = False
787
- return ret
788
-
789
-
790
  WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(
791
  ","
792
  )
@@ -852,7 +88,7 @@ def add_shared_args(parser):
852
  parser.add_argument(
853
  "--model",
854
  type=str,
855
- default="tiny",
856
  choices="tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo".split(
857
  ","
858
  ),
@@ -887,14 +123,14 @@ def add_shared_args(parser):
887
  parser.add_argument(
888
  "--backend",
889
  type=str,
890
- default="mlx-whisper",
891
  choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],
892
  help="Load only this backend for Whisper processing.",
893
  )
894
  parser.add_argument(
895
  "--vac",
896
  action="store_true",
897
- default=True,
898
  help="Use VAC = voice activity controller. Recommended. Requires torch.",
899
  )
900
  parser.add_argument(
@@ -903,7 +139,7 @@ def add_shared_args(parser):
903
  parser.add_argument(
904
  "--vad",
905
  action="store_true",
906
- default=True,
907
  help="Use VAD = voice activity detection, with the default parameters.",
908
  )
909
  parser.add_argument(
 
5
  from functools import lru_cache
6
  import time
7
  import logging
8
+ from src.whisper_streaming.backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR
9
+ from src.whisper_streaming.online_asr import OnlineASRProcessor, VACOnlineASRProcessor
 
 
10
 
11
  logger = logging.getLogger(__name__)
12
 
 
23
  end_s = int(end * 16000)
24
  return audio[beg_s:end_s]
25
26
  WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(
27
  ","
28
  )
 
88
  parser.add_argument(
89
  "--model",
90
  type=str,
91
+ default="large-v3-turbo",
92
  choices="tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo".split(
93
  ","
94
  ),
 
123
  parser.add_argument(
124
  "--backend",
125
  type=str,
126
+ default="faster-whisper",
127
  choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],
128
  help="Load only this backend for Whisper processing.",
129
  )
130
  parser.add_argument(
131
  "--vac",
132
  action="store_true",
133
+ default=False,
134
  help="Use VAC = voice activity controller. Recommended. Requires torch.",
135
  )
136
  parser.add_argument(
 
139
  parser.add_argument(
140
  "--vad",
141
  action="store_true",
142
+ default=False,
143
  help="Use VAD = voice activity detection, with the default parameters.",
144
  )
145
  parser.add_argument(
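Usage note, not part of the commit: with the argparse changes above, --model now defaults to large-v3-turbo, --backend to faster-whisper, and --vac / --vad become plain store_true flags that are off by default. Assuming no other required arguments, the previous defaults could be restored explicitly with something like `python whisper_fastapi_online_server.py --backend mlx-whisper --model tiny --vac --vad`.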