qfuxa committed
Commit e33bd9c · 1 Parent(s): 45b3df2

split backends and online asr files

whisper_streaming_custom/backends.py ADDED
@@ -0,0 +1,292 @@
+ import sys
+ import logging
+ import io
+ import soundfile as sf
+ import math
+ import torch
+ from typing import List
+ import numpy as np
+ from timed_objects import ASRToken
+
+ logger = logging.getLogger(__name__)
+
+ class ASRBase:
+     sep = " "  # join transcribe words with this character (" " for whisper_timestamped,
+                # "" for faster-whisper because it emits the spaces when needed)
+
+     def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
+         self.logfile = logfile
+         self.transcribe_kargs = {}
+         if lan == "auto":
+             self.original_language = None
+         else:
+             self.original_language = lan
+         self.model = self.load_model(modelsize, cache_dir, model_dir)
+
+     def with_offset(self, offset: float) -> ASRToken:
+         # This method is kept for compatibility (typically you will use ASRToken.with_offset)
+         return ASRToken(self.start + offset, self.end + offset, self.text)
+
+     def __repr__(self):
+         return f"ASRToken(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})"
+
+     def load_model(self, modelsize, cache_dir, model_dir):
+         raise NotImplementedError("must be implemented in the child class")
+
+     def transcribe(self, audio, init_prompt=""):
+         raise NotImplementedError("must be implemented in the child class")
+
+     def use_vad(self):
+         raise NotImplementedError("must be implemented in the child class")
+
+
+ class WhisperTimestampedASR(ASRBase):
+     """Uses whisper_timestamped as the backend."""
+     sep = " "
+
+     def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
+         import whisper
+         import whisper_timestamped
+         from whisper_timestamped import transcribe_timestamped
+
+         self.transcribe_timestamped = transcribe_timestamped
+         if model_dir is not None:
+             logger.debug("ignoring model_dir, not implemented")
+         return whisper.load_model(modelsize, download_root=cache_dir)
+
+     def transcribe(self, audio, init_prompt=""):
+         result = self.transcribe_timestamped(
+             self.model,
+             audio,
+             language=self.original_language,
+             initial_prompt=init_prompt,
+             verbose=None,
+             condition_on_previous_text=True,
+             **self.transcribe_kargs,
+         )
+         return result
+
+     def ts_words(self, r) -> List[ASRToken]:
+         """
+         Converts the whisper_timestamped result to a list of ASRToken objects.
+         """
+         tokens = []
+         for segment in r["segments"]:
+             for word in segment["words"]:
+                 token = ASRToken(word["start"], word["end"], word["text"])
+                 tokens.append(token)
+         return tokens
+
+     def segments_end_ts(self, res) -> List[float]:
+         return [segment["end"] for segment in res["segments"]]
+
+     def use_vad(self):
+         self.transcribe_kargs["vad"] = True
+
+     def set_translate_task(self):
+         self.transcribe_kargs["task"] = "translate"
+
+
+ class FasterWhisperASR(ASRBase):
+     """Uses faster-whisper as the backend."""
+     sep = ""
+
+     def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
+         from faster_whisper import WhisperModel
+
+         if model_dir is not None:
+             logger.debug(f"Loading whisper model from model_dir {model_dir}. "
+                          f"modelsize and cache_dir parameters are not used.")
+             model_size_or_path = model_dir
+         elif modelsize is not None:
+             model_size_or_path = modelsize
+         else:
+             raise ValueError("Either modelsize or model_dir must be set")
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         compute_type = "float16" if device == "cuda" else "float32"
+
+         model = WhisperModel(
+             model_size_or_path,
+             device=device,
+             compute_type=compute_type,
+             download_root=cache_dir,
+         )
+         return model
+
+     def transcribe(self, audio: np.ndarray, init_prompt: str = "") -> list:
+         segments, info = self.model.transcribe(
+             audio,
+             language=self.original_language,
+             initial_prompt=init_prompt,
+             beam_size=5,
+             word_timestamps=True,
+             condition_on_previous_text=True,
+             **self.transcribe_kargs,
+         )
+         return list(segments)
+
+     def ts_words(self, segments) -> List[ASRToken]:
+         tokens = []
+         for segment in segments:
+             if segment.no_speech_prob > 0.9:
+                 continue
+             for word in segment.words:
+                 token = ASRToken(word.start, word.end, word.word)
+                 tokens.append(token)
+         return tokens
+
+     def segments_end_ts(self, segments) -> List[float]:
+         return [segment.end for segment in segments]
+
+     def use_vad(self):
+         self.transcribe_kargs["vad_filter"] = True
+
+     def set_translate_task(self):
+         self.transcribe_kargs["task"] = "translate"
+
+
+ class MLXWhisper(ASRBase):
+     """
+     Uses MLX Whisper optimized for Apple Silicon.
+     """
+     sep = ""
+
+     def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
+         from mlx_whisper.transcribe import ModelHolder, transcribe
+         import mlx.core as mx
+
+         if model_dir is not None:
+             logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.")
+             model_size_or_path = model_dir
+         elif modelsize is not None:
+             model_size_or_path = self.translate_model_name(modelsize)
+             logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used.")
+         else:
+             raise ValueError("Either modelsize or model_dir must be set")
+
+         self.model_size_or_path = model_size_or_path
+         dtype = mx.float16
+         ModelHolder.get_model(model_size_or_path, dtype)
+         return transcribe
+
+     def translate_model_name(self, model_name):
+         model_mapping = {
+             "tiny.en": "mlx-community/whisper-tiny.en-mlx",
+             "tiny": "mlx-community/whisper-tiny-mlx",
+             "base.en": "mlx-community/whisper-base.en-mlx",
+             "base": "mlx-community/whisper-base-mlx",
+             "small.en": "mlx-community/whisper-small.en-mlx",
+             "small": "mlx-community/whisper-small-mlx",
+             "medium.en": "mlx-community/whisper-medium.en-mlx",
+             "medium": "mlx-community/whisper-medium-mlx",
+             "large-v1": "mlx-community/whisper-large-v1-mlx",
+             "large-v2": "mlx-community/whisper-large-v2-mlx",
+             "large-v3": "mlx-community/whisper-large-v3-mlx",
+             "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
+             "large": "mlx-community/whisper-large-mlx",
+         }
+         mlx_model_path = model_mapping.get(model_name)
+         if mlx_model_path:
+             return mlx_model_path
+         else:
+             raise ValueError(f"Model name '{model_name}' is not recognized or not supported.")
+
+     def transcribe(self, audio, init_prompt=""):
+         if self.transcribe_kargs:
+             logger.warning("Transcribe kwargs (vad, task) are not compatible with MLX Whisper and will be ignored.")
+         segments = self.model(
+             audio,
+             language=self.original_language,
+             initial_prompt=init_prompt,
+             word_timestamps=True,
+             condition_on_previous_text=True,
+             path_or_hf_repo=self.model_size_or_path,
+         )
+         return segments.get("segments", [])
+
+     def ts_words(self, segments) -> List[ASRToken]:
+         tokens = []
+         for segment in segments:
+             if segment.get("no_speech_prob", 0) > 0.9:
+                 continue
+             for word in segment.get("words", []):
+                 token = ASRToken(word["start"], word["end"], word["word"])
+                 tokens.append(token)
+         return tokens
+
+     def segments_end_ts(self, res) -> List[float]:
+         return [s["end"] for s in res]
+
+     def use_vad(self):
+         self.transcribe_kargs["vad_filter"] = True
+
+     def set_translate_task(self):
+         self.transcribe_kargs["task"] = "translate"
+
+
+ class OpenaiApiASR(ASRBase):
+     """Uses OpenAI's Whisper API for transcription."""
+     def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
+         self.logfile = logfile
+         self.modelname = "whisper-1"
+         self.original_language = None if lan == "auto" else lan
+         self.response_format = "verbose_json"
+         self.temperature = temperature
+         self.load_model()
+         self.use_vad_opt = False
+         self.task = "transcribe"
+
+     def load_model(self, *args, **kwargs):
+         from openai import OpenAI
+         self.client = OpenAI()
+         self.transcribed_seconds = 0
+
+     def ts_words(self, segments) -> List[ASRToken]:
+         """
+         Converts OpenAI API response words into ASRToken objects while
+         optionally skipping words that fall into no-speech segments.
+         """
+         no_speech_segments = []
+         if self.use_vad_opt:
+             for segment in segments.segments:
+                 if segment["no_speech_prob"] > 0.8:
+                     no_speech_segments.append((segment.get("start"), segment.get("end")))
+         tokens = []
+         for word in segments.words:
+             start = word.start
+             end = word.end
+             if any(s[0] <= start <= s[1] for s in no_speech_segments):
+                 continue
+             tokens.append(ASRToken(start, end, word.word))
+         return tokens
+
+     def segments_end_ts(self, res) -> List[float]:
+         return [s.end for s in res.words]
+
+     def transcribe(self, audio_data, prompt=None, *args, **kwargs):
+         buffer = io.BytesIO()
+         buffer.name = "temp.wav"
+         sf.write(buffer, audio_data, samplerate=16000, format="WAV", subtype="PCM_16")
+         buffer.seek(0)
+         self.transcribed_seconds += math.ceil(len(audio_data) / 16000)
+         params = {
+             "model": self.modelname,
+             "file": buffer,
+             "response_format": self.response_format,
+             "temperature": self.temperature,
+             "timestamp_granularities": ["word", "segment"],
+         }
+         if self.task != "translate" and self.original_language:
+             params["language"] = self.original_language
+         if prompt:
+             params["prompt"] = prompt
+         proc = self.client.audio.translations if self.task == "translate" else self.client.audio.transcriptions
+         transcript = proc.create(**params)
+         logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
+         return transcript
+
+     def use_vad(self):
+         self.use_vad_opt = True
+
+     def set_translate_task(self):
+         self.task = "translate"
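
For reference, a minimal usage sketch of the backend interface defined above. It is only a sketch: it assumes the repository root is on PYTHONPATH (so that timed_objects resolves), that faster-whisper is installed and can download the illustrative "tiny.en" model, and it feeds placeholder silence instead of real audio.

import numpy as np
from whisper_streaming_custom.backends import FasterWhisperASR

asr = FasterWhisperASR(lan="en", modelsize="tiny.en")   # illustrative model size
asr.use_vad()                                           # enables faster-whisper's vad_filter

audio = np.zeros(16000, dtype=np.float32)               # 1 s of 16 kHz silence as a stand-in
segments = asr.transcribe(audio, init_prompt="")
tokens = asr.ts_words(segments)                         # List[ASRToken] with word-level timestamps
print([(t.start, t.end, t.text) for t in tokens])
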
whisper_streaming_custom/online_asr.py ADDED
@@ -0,0 +1,444 @@
+ import sys
+ import numpy as np
+ import logging
+ from typing import List, Tuple, Optional
+ from timed_objects import ASRToken, Sentence, Transcript
+
+ logger = logging.getLogger(__name__)
+
+
+ class HypothesisBuffer:
+     """
+     Buffer to store and process ASR hypothesis tokens.
+
+     It holds:
+     - committed_in_buffer: tokens that have been confirmed (committed)
+     - buffer: the last hypothesis that is not yet committed
+     - new: new tokens coming from the recognizer
+     """
+     def __init__(self, logfile=sys.stderr):
+         self.committed_in_buffer: List[ASRToken] = []
+         self.buffer: List[ASRToken] = []
+         self.new: List[ASRToken] = []
+         self.last_committed_time = 0.0
+         self.last_committed_word: Optional[str] = None
+         self.logfile = logfile
+
+     def insert(self, new_tokens: List[ASRToken], offset: float):
+         """
+         Insert new tokens (after applying a time offset) and compare them with the
+         already committed tokens. Only tokens that extend the committed hypothesis
+         are added.
+         """
+         # Apply the offset to each token.
+         new_tokens = [token.with_offset(offset) for token in new_tokens]
+         # Only keep tokens that are roughly “new”
+         self.new = [token for token in new_tokens if token.start > self.last_committed_time - 0.1]
+
+         if self.new:
+             first_token = self.new[0]
+             if abs(first_token.start - self.last_committed_time) < 1:
+                 if self.committed_in_buffer:
+                     committed_len = len(self.committed_in_buffer)
+                     new_len = len(self.new)
+                     # Try to match 1 to 5 consecutive tokens
+                     max_ngram = min(min(committed_len, new_len), 5)
+                     for i in range(1, max_ngram + 1):
+                         committed_ngram = " ".join(token.text for token in self.committed_in_buffer[-i:])
+                         new_ngram = " ".join(token.text for token in self.new[:i])
+                         if committed_ngram == new_ngram:
+                             removed = []
+                             for _ in range(i):
+                                 removed_token = self.new.pop(0)
+                                 removed.append(repr(removed_token))
+                             logger.debug(f"Removing last {i} words: {' '.join(removed)}")
+                             break
+
+     def flush(self) -> List[ASRToken]:
+         """
+         Returns the committed chunk, defined as the longest common prefix
+         between the previous hypothesis and the new tokens.
+         """
+         committed: List[ASRToken] = []
+         while self.new:
+             current_new = self.new[0]
+             if not self.buffer:
+                 break
+             if current_new.text == self.buffer[0].text:
+                 committed.append(current_new)
+                 self.last_committed_word = current_new.text
+                 self.last_committed_time = current_new.end
+                 self.buffer.pop(0)
+                 self.new.pop(0)
+             else:
+                 break
+         self.buffer = self.new
+         self.new = []
+         self.committed_in_buffer.extend(committed)
+         return committed
+
+     def pop_committed(self, time: float):
+         """
+         Remove tokens (from the beginning) that have ended before `time`.
+         """
+         while self.committed_in_buffer and self.committed_in_buffer[0].end <= time:
+             self.committed_in_buffer.pop(0)
+
+
+
+ class OnlineASRProcessor:
+     """
+     Processes incoming audio in a streaming fashion, calling the ASR system
+     periodically, and uses a hypothesis buffer to commit and trim recognized text.
+
+     The processor supports two types of buffer trimming:
+     - "sentence": trims at sentence boundaries (using a sentence tokenizer)
+     - "segment": trims at fixed segment durations.
+     """
+     SAMPLING_RATE = 16000
+
+     def __init__(
+         self,
+         asr,
+         tokenize_method: Optional[callable] = None,
+         buffer_trimming: Tuple[str, float] = ("segment", 15),
+         logfile=sys.stderr,
+     ):
+         """
+         asr: An ASR system object (for example, a WhisperASR instance) that
+              provides a `transcribe` method, a `ts_words` method (to extract tokens),
+              a `segments_end_ts` method, and a separator attribute `sep`.
+         tokenize_method: A function that receives text and returns a list of sentence strings.
+         buffer_trimming: A tuple (option, seconds), where option is either "sentence" or "segment".
+         """
+         self.asr = asr
+         self.tokenize = tokenize_method
+         self.logfile = logfile
+
+         self.init()
+
+         self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
+
+         if self.buffer_trimming_way not in ["sentence", "segment"]:
+             raise ValueError("buffer_trimming must be either 'sentence' or 'segment'")
+         if self.buffer_trimming_sec <= 0:
+             raise ValueError("buffer_trimming_sec must be positive")
+         elif self.buffer_trimming_sec > 30:
+             logger.warning(
+                 f"buffer_trimming_sec is set to {self.buffer_trimming_sec}, which is very long. It may cause OOM."
+             )
+
+     def init(self, offset: Optional[float] = None):
+         """Initialize or reset the processing buffers."""
+         self.audio_buffer = np.array([], dtype=np.float32)
+         self.transcript_buffer = HypothesisBuffer(logfile=self.logfile)
+         self.buffer_time_offset = offset if offset is not None else 0.0
+         self.transcript_buffer.last_committed_time = self.buffer_time_offset
+         self.committed: List[ASRToken] = []
+
+     def insert_audio_chunk(self, audio: np.ndarray):
+         """Append an audio chunk (a numpy array) to the current audio buffer."""
+         self.audio_buffer = np.append(self.audio_buffer, audio)
+
+     def prompt(self) -> Tuple[str, str]:
+         """
+         Returns a tuple: (prompt, context), where:
+         - prompt is a 200-character suffix of committed text that falls
+           outside the current audio buffer.
+         - context is the committed text within the current audio buffer.
+         """
+         k = len(self.committed)
+         while k > 0 and self.committed[k - 1].end > self.buffer_time_offset:
+             k -= 1
+
+         prompt_tokens = self.committed[:k]
+         prompt_words = [token.text for token in prompt_tokens]
+         prompt_list = []
+         length_count = 0
+         # Use the last words until reaching 200 characters.
+         while prompt_words and length_count < 200:
+             word = prompt_words.pop(-1)
+             length_count += len(word) + 1
+             prompt_list.append(word)
+         non_prompt_tokens = self.committed[k:]
+         context_text = self.asr.sep.join(token.text for token in non_prompt_tokens)
+         return self.asr.sep.join(prompt_list[::-1]), context_text
+
+     def get_buffer(self):
+         """
+         Get the unvalidated buffer in string format.
+         """
+         return self.concatenate_tokens(self.transcript_buffer.buffer)
+
+
+     def process_iter(self) -> Transcript:
+         """
+         Processes the current audio buffer.
+
+         Returns a Transcript object representing the committed transcript.
+         """
+         prompt_text, _ = self.prompt()
+         logger.debug(
+             f"Transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds from {self.buffer_time_offset:.2f}"
+         )
+         res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt_text)
+         tokens = self.asr.ts_words(res)  # Expecting List[ASRToken]
+         self.transcript_buffer.insert(tokens, self.buffer_time_offset)
+         committed_tokens = self.transcript_buffer.flush()
+         self.committed.extend(committed_tokens)
+         completed = self.concatenate_tokens(committed_tokens)
+         logger.debug(f">>>> COMPLETE NOW: {completed.text}")
+         incomp = self.concatenate_tokens(self.transcript_buffer.buffer)
+         logger.debug(f"INCOMPLETE: {incomp.text}")
+
+         if committed_tokens and self.buffer_trimming_way == "sentence":
+             if len(self.audio_buffer) / self.SAMPLING_RATE > self.buffer_trimming_sec:
+                 self.chunk_completed_sentence()
+
+         s = self.buffer_trimming_sec if self.buffer_trimming_way == "segment" else 30
+         if len(self.audio_buffer) / self.SAMPLING_RATE > s:
+             self.chunk_completed_segment(res)
+             logger.debug("Chunking segment")
+         logger.debug(
+             f"Length of audio buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds"
+         )
+         return committed_tokens
+
+     def chunk_completed_sentence(self):
+         """
+         If the committed tokens form at least two sentences, chunk the audio
+         buffer at the end time of the penultimate sentence.
+         """
+         if not self.committed:
+             return
+         logger.debug("COMPLETED SENTENCE: " + " ".join(token.text for token in self.committed))
+         sentences = self.words_to_sentences(self.committed)
+         for sentence in sentences:
+             logger.debug(f"\tSentence: {sentence.text}")
+         if len(sentences) < 2:
+             return
+         # Keep the last two sentences.
+         while len(sentences) > 2:
+             sentences.pop(0)
+         chunk_time = sentences[-2].end
+         logger.debug(f"--- Sentence chunked at {chunk_time:.2f}")
+         self.chunk_at(chunk_time)
+
+     def chunk_completed_segment(self, res):
+         """
+         Chunk the audio buffer based on segment-end timestamps reported by the ASR.
+         """
+         if not self.committed:
+             return
+         ends = self.asr.segments_end_ts(res)
+         last_committed_time = self.committed[-1].end
+         if len(ends) > 1:
+             e = ends[-2] + self.buffer_time_offset
+             while len(ends) > 2 and e > last_committed_time:
+                 ends.pop(-1)
+                 e = ends[-2] + self.buffer_time_offset
+             if e <= last_committed_time:
+                 logger.debug(f"--- Segment chunked at {e:.2f}")
+                 self.chunk_at(e)
+             else:
+                 logger.debug("--- Last segment not within committed area")
+         else:
+             logger.debug("--- Not enough segments to chunk")
+
+     def chunk_at(self, time: float):
+         """
+         Trim both the hypothesis and audio buffer at the given time.
+         """
+         logger.debug(f"Chunking at {time:.2f}s")
+         logger.debug(
+             f"Audio buffer length before chunking: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f}s"
+         )
+         self.transcript_buffer.pop_committed(time)
+         cut_seconds = time - self.buffer_time_offset
+         self.audio_buffer = self.audio_buffer[int(cut_seconds * self.SAMPLING_RATE):]
+         self.buffer_time_offset = time
+         logger.debug(
+             f"Audio buffer length after chunking: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f}s"
+         )
+
+     def words_to_sentences(self, tokens: List[ASRToken]) -> List[Sentence]:
+         """
+         Converts a list of tokens to a list of Sentence objects using the provided
+         sentence tokenizer.
+         """
+         if not tokens:
+             return []
+
+         full_text = " ".join(token.text for token in tokens)
+
+         if self.tokenize:
+             try:
+                 sentence_texts = self.tokenize(full_text)
+             except Exception as e:
+                 # Some tokenizers (e.g., MosesSentenceSplitter) expect a list input.
+                 try:
+                     sentence_texts = self.tokenize([full_text])
+                 except Exception as e2:
+                     raise ValueError("Tokenization failed") from e2
+         else:
+             sentence_texts = [full_text]
+
+         sentences: List[Sentence] = []
+         token_index = 0
+         for sent_text in sentence_texts:
+             sent_text = sent_text.strip()
+             if not sent_text:
+                 continue
+             sent_tokens = []
+             accumulated = ""
+             # Accumulate tokens until roughly matching the length of the sentence text.
+             while token_index < len(tokens) and len(accumulated) < len(sent_text):
+                 token = tokens[token_index]
+                 accumulated = (accumulated + " " + token.text).strip() if accumulated else token.text
+                 sent_tokens.append(token)
+                 token_index += 1
+             if sent_tokens:
+                 sentence = Sentence(
+                     start=sent_tokens[0].start,
+                     end=sent_tokens[-1].end,
+                     text=" ".join(t.text for t in sent_tokens),
+                 )
+                 sentences.append(sentence)
+         return sentences
+     def finish(self) -> Transcript:
+         """
+         Flush the remaining transcript when processing ends.
+         """
+         remaining_tokens = self.transcript_buffer.buffer
+         final_transcript = self.concatenate_tokens(remaining_tokens)
+         logger.debug(f"Final non-committed transcript: {final_transcript}")
+         self.buffer_time_offset += len(self.audio_buffer) / self.SAMPLING_RATE
+         return final_transcript
+
+     def concatenate_tokens(
+         self,
+         tokens: List[ASRToken],
+         sep: Optional[str] = None,
+         offset: float = 0
+     ) -> Transcript:
+         sep = sep if sep is not None else self.asr.sep
+         text = sep.join(token.text for token in tokens)
+         if tokens:
+             start = offset + tokens[0].start
+             end = offset + tokens[-1].end
+         else:
+             start = None
+             end = None
+         return Transcript(start, end, text)
+
+
+ class VACOnlineASRProcessor:
+     """
+     Wraps an OnlineASRProcessor with a Voice Activity Controller (VAC).
+
+     It receives small chunks of audio, applies VAD (e.g. with Silero),
+     and when the system detects a pause in speech (or end of an utterance)
+     it finalizes the utterance immediately.
+     """
+     SAMPLING_RATE = 16000
+
+     def __init__(self, online_chunk_size: float, *args, **kwargs):
+         self.online_chunk_size = online_chunk_size
+         self.online = OnlineASRProcessor(*args, **kwargs)
+
+         # Load a VAD model (e.g. Silero VAD)
+         import torch
+         model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
+         from silero_vad_iterator import FixedVADIterator
+
+         self.vac = FixedVADIterator(model)
+         self.logfile = self.online.logfile
+         self.init()
+
+     def init(self):
+         self.online.init()
+         self.vac.reset_states()
+         self.current_online_chunk_buffer_size = 0
+         self.is_currently_final = False
+         self.status: Optional[str] = None  # "voice" or "nonvoice"
+         self.audio_buffer = np.array([], dtype=np.float32)
+         self.buffer_offset = 0  # in frames
+
+     def clear_buffer(self):
+         self.buffer_offset += len(self.audio_buffer)
+         self.audio_buffer = np.array([], dtype=np.float32)
+
+     def insert_audio_chunk(self, audio: np.ndarray):
+         """
+         Process an incoming small audio chunk:
+         - run VAD on the chunk,
+         - decide whether to send the audio to the online ASR processor immediately,
+         - and/or to mark the current utterance as finished.
+         """
+         res = self.vac(audio)
+         self.audio_buffer = np.append(self.audio_buffer, audio)
+
+         if res is not None:
+             # VAD returned a result; adjust the frame number
+             frame = list(res.values())[0] - self.buffer_offset
+             if "start" in res and "end" not in res:
+                 self.status = "voice"
+                 send_audio = self.audio_buffer[frame:]
+                 self.online.init(offset=(frame + self.buffer_offset) / self.SAMPLING_RATE)
+                 self.online.insert_audio_chunk(send_audio)
+                 self.current_online_chunk_buffer_size += len(send_audio)
+                 self.clear_buffer()
+             elif "end" in res and "start" not in res:
+                 self.status = "nonvoice"
+                 send_audio = self.audio_buffer[:frame]
+                 self.online.insert_audio_chunk(send_audio)
+                 self.current_online_chunk_buffer_size += len(send_audio)
+                 self.is_currently_final = True
+                 self.clear_buffer()
+             else:
+                 beg = res["start"] - self.buffer_offset
+                 end = res["end"] - self.buffer_offset
+                 self.status = "nonvoice"
+                 send_audio = self.audio_buffer[beg:end]
+                 self.online.init(offset=(beg + self.buffer_offset) / self.SAMPLING_RATE)
+                 self.online.insert_audio_chunk(send_audio)
+                 self.current_online_chunk_buffer_size += len(send_audio)
+                 self.is_currently_final = True
+                 self.clear_buffer()
+         else:
+             if self.status == "voice":
+                 self.online.insert_audio_chunk(self.audio_buffer)
+                 self.current_online_chunk_buffer_size += len(self.audio_buffer)
+                 self.clear_buffer()
+             else:
+                 # Keep 1 second worth of audio in case VAD later detects voice,
+                 # but trim to avoid unbounded memory usage.
+                 self.buffer_offset += max(0, len(self.audio_buffer) - self.SAMPLING_RATE)
+                 self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE:]
+
+     def process_iter(self) -> Transcript:
+         """
+         Depending on the VAD status and the amount of accumulated audio,
+         process the current audio chunk.
+         """
+         if self.is_currently_final:
+             return self.finish()
+         elif self.current_online_chunk_buffer_size > self.SAMPLING_RATE * self.online_chunk_size:
+             self.current_online_chunk_buffer_size = 0
+             return self.online.process_iter()
+         else:
+             logger.debug("No online update, only VAD")
+             return Transcript(None, None, "")
+
+     def finish(self) -> Transcript:
+         """Finish processing by flushing any remaining text."""
+         result = self.online.finish()
+         self.current_online_chunk_buffer_size = 0
+         self.is_currently_final = False
+         return result
+
+     def get_buffer(self):
+         """
+         Get the unvalidated buffer in string format.
+         """
+         return self.online.concatenate_tokens(self.online.transcript_buffer.buffer).text
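
As a rough illustration of how the streaming pieces above fit together, here is a minimal loop sketch. It assumes the repository root is on PYTHONPATH, uses the FasterWhisperASR backend from backends.py, and feeds placeholder silence in place of microphone audio.

import numpy as np
from whisper_streaming_custom.backends import FasterWhisperASR
from whisper_streaming_custom.online_asr import OnlineASRProcessor

asr = FasterWhisperASR(lan="en", modelsize="tiny.en")
online = OnlineASRProcessor(asr, buffer_trimming=("segment", 15))

chunk = np.zeros(16000, dtype=np.float32)        # 1 s of 16 kHz silence
for _ in range(3):                               # pretend three chunks arrive from a live source
    online.insert_audio_chunk(chunk)
    committed = online.process_iter()            # tokens newly committed by the hypothesis buffer
    print("committed:", committed)

final = online.finish()                          # flush whatever is still uncommitted
print("final:", final.text)
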
whisper_streaming_custom/whisper_online.py ADDED
@@ -0,0 +1,235 @@
+ #!/usr/bin/env python3
+ import sys
+ import numpy as np
+ import librosa
+ from functools import lru_cache
+ import time
+ import logging
+ from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR
+ from .online_asr import OnlineASRProcessor, VACOnlineASRProcessor
+
+ logger = logging.getLogger(__name__)
+
+
+
+ WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(
+     ","
+ )
+
+
+ def create_tokenizer(lan):
+     """returns an object that has split function that works like the one of MosesTokenizer"""
+
+     assert (
+         lan in WHISPER_LANG_CODES
+     ), "language must be Whisper's supported lang code: " + " ".join(WHISPER_LANG_CODES)
+
+     if lan == "uk":
+         import tokenize_uk
+
+         class UkrainianTokenizer:
+             def split(self, text):
+                 return tokenize_uk.tokenize_sents(text)
+
+         return UkrainianTokenizer()
+
+     # supported by fast-mosestokenizer
+     if (
+         lan
+         in "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split()
+     ):
+         from mosestokenizer import MosesSentenceSplitter
+
+         return MosesSentenceSplitter(lan)
+
+     # the following languages are in Whisper, but not in wtpsplit:
+     if (
+         lan
+         in "as ba bo br bs fo haw hr ht jw lb ln lo mi nn oc sa sd sn so su sw tk tl tt".split()
+     ):
+         logger.debug(
+             f"{lan} code is not supported by wtpsplit. Going to use None lang_code option."
+         )
+         lan = None
+
+     from wtpsplit import WtP
+
+     # downloads the model from huggingface on the first use
+     wtp = WtP("wtp-canine-s-12l-no-adapters")
+
+     class WtPtok:
+         def split(self, sent):
+             return wtp.split(sent, lang_code=lan)
+
+     return WtPtok()
+
+
+ def add_shared_args(parser):
+     """shared args for simulation (this entry point) and server
+     parser: argparse.ArgumentParser object
+     """
+     parser.add_argument(
+         "--min-chunk-size",
+         type=float,
+         default=1.0,
+         help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.",
+     )
+     parser.add_argument(
+         "--model",
+         type=str,
+         default="tiny.en",
+         choices="tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo".split(
+             ","
+         ),
+         help="Name size of the Whisper model to use (default: tiny.en). The model is automatically downloaded from the model hub if not present in model cache dir.",
+     )
+     parser.add_argument(
+         "--model_cache_dir",
+         type=str,
+         default=None,
+         help="Overriding the default model cache dir where models downloaded from the hub are saved",
+     )
+     parser.add_argument(
+         "--model_dir",
+         type=str,
+         default=None,
+         help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.",
+     )
+     parser.add_argument(
+         "--lan",
+         "--language",
+         type=str,
+         default="auto",
+         help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
+     )
+     parser.add_argument(
+         "--task",
+         type=str,
+         default="transcribe",
+         choices=["transcribe", "translate"],
+         help="Transcribe or translate.",
+     )
+     parser.add_argument(
+         "--backend",
+         type=str,
+         default="faster-whisper",
+         choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],
+         help="Load only this backend for Whisper processing.",
+     )
+     parser.add_argument(
+         "--vac",
+         action="store_true",
+         default=False,
+         help="Use VAC = voice activity controller. Recommended. Requires torch.",
+     )
+     parser.add_argument(
+         "--vac-chunk-size", type=float, default=0.04, help="VAC sample size in seconds."
+     )
+     parser.add_argument(
+         "--vad",
+         action="store_true",
+         default=False,
+         help="Use VAD = voice activity detection, with the default parameters.",
+     )
+     parser.add_argument(
+         "--buffer_trimming",
+         type=str,
+         default="segment",
+         choices=["sentence", "segment"],
+         help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.',
+     )
+     parser.add_argument(
+         "--buffer_trimming_sec",
+         type=float,
+         default=15,
+         help="Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.",
+     )
+     parser.add_argument(
+         "-l",
+         "--log-level",
+         dest="log_level",
+         choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
+         help="Set the log level",
+         default="DEBUG",
+     )
+
+ def backend_factory(args):
+     backend = args.backend
+     if backend == "openai-api":
+         logger.debug("Using OpenAI API.")
+         asr = OpenaiApiASR(lan=args.lan)
+     else:
+         if backend == "faster-whisper":
+             asr_cls = FasterWhisperASR
+         elif backend == "mlx-whisper":
+             asr_cls = MLXWhisper
+         else:
+             asr_cls = WhisperTimestampedASR
+
+         # Only for FasterWhisperASR and WhisperTimestampedASR
+         size = args.model
+         t = time.time()
+         logger.info(f"Loading Whisper {size} model for language {args.lan}...")
+         asr = asr_cls(
+             modelsize=size,
+             lan=args.lan,
+             cache_dir=args.model_cache_dir,
+             model_dir=args.model_dir,
+         )
+         e = time.time()
+         logger.info(f"done. It took {round(e-t,2)} seconds.")
+
+     # Apply common configurations
+     if getattr(args, "vad", False):  # Checks if VAD argument is present and True
+         logger.info("Setting VAD filter")
+         asr.use_vad()
+
+     language = args.lan
+     if args.task == "translate":
+         asr.set_translate_task()
+         tgt_language = "en"  # Whisper translates into English
+     else:
+         tgt_language = language  # Whisper transcribes in this language
+
+     # Create the tokenizer
+     if args.buffer_trimming == "sentence":
+         tokenizer = create_tokenizer(tgt_language)
+     else:
+         tokenizer = None
+     return asr, tokenizer
+
+ def online_factory(args, asr, tokenizer, logfile=sys.stderr):
+     if args.vac:
+         online = VACOnlineASRProcessor(
+             args.min_chunk_size,
+             asr,
+             tokenizer,
+             logfile=logfile,
+             buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
+         )
+     else:
+         online = OnlineASRProcessor(
+             asr,
+             tokenizer,
+             logfile=logfile,
+             buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
+         )
+     return online
+
+ def asr_factory(args, logfile=sys.stderr):
+     """
+     Creates and configures an ASR and ASR Online instance based on the specified backend and arguments.
+     """
+     asr, tokenizer = backend_factory(args)
+     online = online_factory(args, asr, tokenizer, logfile=logfile)
+     return asr, online
+
+ def set_logging(args, logger, others=[]):
+     logging.basicConfig(format="%(levelname)s\t%(message)s")  # format='%(name)s
+     logger.setLevel(args.log_level)
+
+     for other in others:
+         logging.getLogger(other).setLevel(args.log_level)
+
+
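
Finally, a small end-to-end sketch of the factory functions above, driven by the shared argument parser. Argument values are illustrative, and it assumes the whisper_streaming_custom package and its dependencies (including librosa and faster-whisper) are importable.

import argparse
import logging
import sys
import numpy as np
from whisper_streaming_custom.whisper_online import add_shared_args, asr_factory, set_logging

logger = logging.getLogger(__name__)

parser = argparse.ArgumentParser()
add_shared_args(parser)
args = parser.parse_args(["--model", "tiny.en", "--lan", "en", "--backend", "faster-whisper"])

set_logging(args, logger)
asr, online = asr_factory(args, logfile=sys.stderr)    # OnlineASRProcessor unless --vac is given

online.insert_audio_chunk(np.zeros(16000, dtype=np.float32))   # placeholder audio
print(online.process_iter())
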