Dominik Macháček commited on
Commit
37fc0f3
·
1 Parent(s): 2b5f14d

whisper online

Browse files
Files changed (1) hide show
  1. whisper_online.py +402 -0
whisper_online.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import sys
3
+ import numpy as np
4
+ import whisper
5
+ import whisper_timestamped
6
+ import librosa
7
+ from functools import lru_cache
8
+ import torch
9
+ import time
10
+ from mosestokenizer import MosesTokenizer
11
+ import json
12
+
13
+
14
+ @lru_cache
15
+ def load_audio(fname):
16
+ a, _ = librosa.load(fname, sr=16000)
17
+ return a
18
+
19
+ def load_audio_chunk(fname, beg, end):
20
+ audio = load_audio(fname)
21
+ beg_s = int(beg*16000)
22
+ end_s = int(end*16000)
23
+ return audio[beg_s:end_s]
24
+
25
+ class WhisperASR:
26
+ def __init__(self, modelsize="small", lan="en", cache_dir="disk-cache-dir"):
27
+ self.original_language = lan
28
+ self.model = whisper.load_model(modelsize, download_root=cache_dir)
29
+
30
+ def transcribe(self, audio, init_prompt=""):
31
+ result = whisper_timestamped.transcribe_timestamped(self.model, audio, language=self.original_language, initial_prompt=init_prompt, verbose=None, condition_on_previous_text=True)
32
+ return result
33
+
34
+ def ts_words(self,r):
35
+ # return: transcribe result object to [(beg,end,"word1"), ...]
36
+ o = []
37
+ for s in r["segments"]:
38
+ for w in s["words"]:
39
+ t = (w["start"],w["end"],w["text"])
40
+ o.append(t)
41
+ return o
42
+
43
+ def to_flush(sents, offset=0):
44
+ # concatenates the timestamped words or sentences into one sequence that is flushed in one line
45
+ # sents: [(beg1, end1, "sentence1"), ...] or [] if empty
46
+ # return: (beg1,end-of-last-sentence,"concatenation of sentences") or (None, None, "") if empty
47
+ t = " ".join(s[2] for s in sents)
48
+ if len(sents) == 0:
49
+ b = None
50
+ e = None
51
+ else:
52
+ b = offset + sents[0][0]
53
+ e = offset + sents[-1][1]
54
+ return (b,e,t)
55
+
56
+ class HypothesisBuffer:
57
+
58
+ def __init__(self):
59
+ self.commited_in_buffer = []
60
+ self.buffer = []
61
+ self.new = []
62
+
63
+ self.last_commited_time = 0
64
+ self.last_commited_word = None
65
+
66
+ def insert(self, new, offset):
67
+ # compare self.commited_in_buffer and new. It inserts only the words in new that extend the commited_in_buffer, it means they are roughly behind last_commited_time and new in content
68
+ # the new tail is added to self.new
69
+
70
+ new = [(a+offset,b+offset,t) for a,b,t in new]
71
+ self.new = [(a,b,t) for a,b,t in new if a > self.last_commited_time-0.1]
72
+
73
+ if len(self.new) >= 1:
74
+ a,b,t = self.new[0]
75
+ if abs(a - self.last_commited_time) < 1:
76
+ if self.commited_in_buffer:
77
+ # it's going to search for 1, 2 or 3 consecutive words that are identical in commited and new. If they are, they're dropped.
78
+ cn = len(self.commited_in_buffer)
79
+ nn = len(self.new)
80
+ for i in range(1,min(min(cn,nn),5)+1):
81
+ c = " ".join([self.commited_in_buffer[-j][2] for j in range(1,i+1)][::-1])
82
+ tail = " ".join(self.new[j-1][2] for j in range(1,i+1))
83
+ if c == tail:
84
+ print("removing last",i,"words:",file=sys.stderr)
85
+ for j in range(i):
86
+ print("\t",self.new.pop(0),file=sys.stderr)
87
+ break
88
+
89
+ def flush(self):
90
+ # returns commited chunk = the longest common prefix of 2 last inserts.
91
+
92
+ commit = []
93
+ while self.new:
94
+ na, nb, nt = self.new[0]
95
+
96
+ if len(self.buffer) == 0:
97
+ break
98
+
99
+ if nt == self.buffer[0][2]:
100
+ commit.append((na,nb,nt))
101
+ self.last_commited_word = nt
102
+ self.last_commited_time = nb
103
+ self.buffer.pop(0)
104
+ self.new.pop(0)
105
+ else:
106
+ break
107
+ self.buffer = self.new
108
+ self.new = []
109
+ self.commited_in_buffer.extend(commit)
110
+ return commit
111
+
112
+ def pop_commited(self, time):
113
+ while self.commited_in_buffer and self.commited_in_buffer[0][1] <= time:
114
+ self.commited_in_buffer.pop(0)
115
+
116
+ def complete(self):
117
+ return self.buffer
118
+
119
+ class OnlineASRProcessor:
120
+
121
+ SAMPLING_RATE = 16000
122
+
123
+ def __init__(self, language, asr, chunk):
124
+ """language: lang. code
125
+ asr: WhisperASR object
126
+ chunk: number of seconds for intended size of audio interval that is inserted and looped
127
+ """
128
+ self.language = language
129
+ self.asr = asr
130
+ self.tokenizer = MosesTokenizer("en")
131
+
132
+ self.init()
133
+
134
+ self.chunk = chunk
135
+
136
+
137
+ def init(self):
138
+ """run this when starting or restarting processing"""
139
+ self.audio_buffer = np.array([],dtype=np.float32)
140
+ self.buffer_time_offset = 0
141
+
142
+ self.transcript_buffer = HypothesisBuffer()
143
+ self.commited = []
144
+ self.last_chunked_at = 0
145
+
146
+ self.silence_iters = 0
147
+
148
+ def insert_audio_chunk(self, audio):
149
+ self.audio_buffer = np.append(self.audio_buffer, audio)
150
+
151
+ def prompt(self):
152
+ """Returns a tuple: (prompt, context), where "prompt" is a 200-character suffix of commited text that is inside of the scrolled away part of audio buffer.
153
+ "context" is the commited text that is inside the audio buffer. It is transcribed again and skipped. It is returned only for debugging and logging reasons.
154
+ """
155
+ k = max(0,len(self.commited)-1)
156
+ while k > 0 and self.commited[k-1][1] > self.last_chunked_at:
157
+ k -= 1
158
+
159
+ p = self.commited[:k]
160
+ p = [t for _,_,t in p]
161
+ prompt = []
162
+ l = 0
163
+ while p and l < 200: # 200 characters prompt size
164
+ x = p.pop(-1)
165
+ l += len(x)+1
166
+ prompt.append(x)
167
+ non_prompt = self.commited[k:]
168
+ return " ".join(prompt[::-1]), " ".join(t for _,_,t in non_prompt)
169
+
170
+ def process_iter(self):
171
+ """Runs on the current audio buffer.
172
+ Returns: a tuple (beg_timestamp, end_timestamp, "text"), or (None, None, "").
173
+ The non-emty text is confirmed (commited) partial transcript.
174
+ """
175
+
176
+ prompt, non_prompt = self.prompt()
177
+ print("PROMPT:", prompt, file=sys.stderr)
178
+ print("CONTEXT:", non_prompt, file=sys.stderr)
179
+ print(f"transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}",file=sys.stderr)
180
+ res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt)
181
+
182
+ # transform to [(beg,end,"word1"), ...]
183
+ tsw = self.asr.ts_words(res)
184
+
185
+ self.transcript_buffer.insert(tsw, self.buffer_time_offset)
186
+ o = self.transcript_buffer.flush()
187
+ self.commited.extend(o)
188
+ print(">>>>COMPLETE NOW:",to_flush(o),file=sys.stderr,flush=True)
189
+ print("INCOMPLETE:",to_flush(self.transcript_buffer.complete()),file=sys.stderr,flush=True)
190
+
191
+ # there is a newly confirmed text
192
+ if o:
193
+ # we trim all the completed sentences from the audio buffer
194
+ self.chunk_completed_sentence()
195
+
196
+ # ...segments could be considered
197
+ #self.chunk_completed_segment(res)
198
+
199
+ #
200
+ # self.silence_iters = 0
201
+
202
+ # this was an attempt to trim silence/non-linguistic noise detected by the fact that Whisper doesn't transcribe anything for 3-times in a row.
203
+ # It seemed not working better, or needs to be debugged.
204
+
205
+ # elif self.transcript_buffer.complete():
206
+ # self.silence_iters = 0
207
+ # elif not self.transcript_buffer.complete():
208
+ # # print("NOT COMPLETE:",to_flush(self.transcript_buffer.complete()),file=sys.stderr,flush=True)
209
+ # self.silence_iters += 1
210
+ # if self.silence_iters >= 3:
211
+ # n = self.last_chunked_at
212
+ ## self.chunk_completed_sentence()
213
+ ## if n == self.last_chunked_at:
214
+ # self.chunk_at(self.last_chunked_at+self.chunk)
215
+ # print(f"\tCHUNK: 3-times silence! chunk_at {n}+{self.chunk}",file=sys.stderr)
216
+ ## self.silence_iters = 0
217
+
218
+
219
+ # if the audio buffer is longer than 30s, trim it...
220
+ if len(self.audio_buffer)/self.SAMPLING_RATE > 30:
221
+ # ...on the last completed segment (labeled by Whisper)
222
+ self.chunk_completed_segment(res)
223
+
224
+ # alternative: on any word
225
+ #l = self.buffer_time_offset + len(self.audio_buffer)/self.SAMPLING_RATE - 10
226
+ # let's find commited word that is less
227
+ #k = len(self.commited)-1
228
+ #while k>0 and self.commited[k][1] > l:
229
+ # k -= 1
230
+ #t = self.commited[k][1]
231
+ print(f"chunking because of len",file=sys.stderr)
232
+ #self.chunk_at(t)
233
+
234
+ print(f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}",file=sys.stderr)
235
+ return to_flush(o)
236
+
237
+ def chunk_completed_sentence(self):
238
+ if self.commited == []: return
239
+ print(self.commited,file=sys.stderr)
240
+ sents = self.words_to_sentences(self.commited)
241
+ for s in sents:
242
+ print("\t\tSENT:",s,file=sys.stderr)
243
+ if len(sents) < 2:
244
+ return
245
+ while len(sents) > 2:
246
+ sents.pop(0)
247
+ # we will continue with audio processing at this timestamp
248
+ chunk_at = sents[-2][1]
249
+
250
+ print(f"--- sentence chunked at {chunk_at:2.2f}",file=sys.stderr)
251
+ self.chunk_at(chunk_at)
252
+
253
+ def chunk_completed_segment(self, res):
254
+ if self.commited == []: return
255
+
256
+ ends = [s["end"] for s in res["segments"]]
257
+
258
+ t = self.commited[-1][1]
259
+
260
+ if len(ends) > 1:
261
+
262
+ e = ends[-2]+self.buffer_time_offset
263
+ while len(ends) > 2 and e > t:
264
+ ends.pop(-1)
265
+ e = ends[-2]+self.buffer_time_offset
266
+ if e <= t:
267
+ print(f"--- segment chunked at {e:2.2f}",file=sys.stderr)
268
+ self.chunk_at(e)
269
+ else:
270
+ print(f"--- last segment not within commited area",file=sys.stderr)
271
+ else:
272
+ print(f"--- not enough segments to chunk",file=sys.stderr)
273
+
274
+
275
+
276
+
277
+
278
+ def chunk_at(self, time):
279
+ """trims the hypothesis and audio buffer at "time"
280
+ """
281
+ self.transcript_buffer.pop_commited(time)
282
+ cut_seconds = time - self.buffer_time_offset
283
+ self.audio_buffer = self.audio_buffer[int(cut_seconds)*self.SAMPLING_RATE:]
284
+ self.buffer_time_offset = time
285
+ self.last_chunked_at = time
286
+
287
+ def words_to_sentences(self, words):
288
+ """Uses mosestokenizer for sentence segmentation of words.
289
+ Returns: [(beg,end,"sentence 1"),...]
290
+ """
291
+
292
+ cwords = [w for w in words]
293
+ t = " ".join(o[2] for o in cwords)
294
+ s = self.tokenizer.split(t)
295
+ out = []
296
+ while s:
297
+ beg = None
298
+ end = None
299
+ sent = s.pop(0).strip()
300
+ fsent = sent
301
+ while cwords:
302
+ b,e,w = cwords.pop(0)
303
+ if beg is None and sent.startswith(w):
304
+ beg = b
305
+ elif end is None and sent == w:
306
+ end = e
307
+ out.append((beg,end,fsent))
308
+ break
309
+ sent = sent[len(w):].strip()
310
+ return out
311
+
312
+ def finish(self):
313
+ """Flush the incomplete text when the whole processing ends.
314
+ Returns: the same format as self.process_iter()
315
+ """
316
+ o = self.transcript_buffer.complete()
317
+ f = to_flush(o)
318
+ print("last, noncommited:",f,file=sys.stderr)
319
+ return f
320
+
321
+
322
+
323
+ ## main:
324
+
325
+ import argparse
326
+ parser = argparse.ArgumentParser()
327
+ parser.add_argument('audio_path', type=str, help="Filename of 16kHz mono channel wav, on which live streaming is simulated.")
328
+ parser.add_argument('--min-chunk-size', type=float, default=1.0, help='Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.')
329
+ parser.add_argument('--model', type=str, default='large-v2', help="name of the Whisper model to use (default: large-v2, options: {tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large}")
330
+ parser.add_argument('--model_dir', type=str, default='disk-cache-dir', help="the path where Whisper models are saved (or downloaded to). Default: ./disk-cache-dir")
331
+ parser.add_argument('--lan', '--language', type=str, default='en', help="Language code for transcription, e.g. en,de,cs.")
332
+ parser.add_argument('--start_at', type=float, default=0.0, help='Start processing audio at this time.')
333
+ args = parser.parse_args()
334
+
335
+ audio_path = args.audio_path
336
+
337
+ SAMPLING_RATE = 16000
338
+ duration = len(load_audio(audio_path))/SAMPLING_RATE
339
+ print("Audio duration is: %2.2f seconds" % duration, file=sys.stderr)
340
+
341
+ size = args.model
342
+ language = args.lan
343
+
344
+ t = time.time()
345
+ print(f"Loading Whisper {size} model for {language}...",file=sys.stderr,end=" ",flush=True)
346
+ asr = WhisperASR(lan=language, modelsize=size)
347
+ e = time.time()
348
+ print(f"done. It took {round(e-t,2)} seconds.",file=sys.stderr)
349
+
350
+
351
+ min_chunk = args.min_chunk_size
352
+ online = OnlineASRProcessor(language,asr,min_chunk)
353
+
354
+
355
+ # load the audio into the LRU cache before we start the timer
356
+ a = load_audio_chunk(audio_path,0,1)
357
+
358
+ # warm up the ASR, because the very first transcribe takes much more time than the other
359
+ asr.transcribe(a)
360
+
361
+ def output_transcript(o):
362
+ # output format in stdout is like:
363
+ # 4186.3606 0 1720 Takhle to je
364
+ # - the first three words are:
365
+ # - emission time from beginning of processing, in milliseconds
366
+ # - beg and end timestamp of the text segment, as estimated by Whisper model. The timestamps are not accurate, but they're useful anyway
367
+ # - the next words: segment transcript
368
+ now = time.time()-start
369
+ if o[0] is not None:
370
+ print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),flush=True)
371
+ else:
372
+ print(o,file=sys.stderr,flush=True)
373
+
374
+ beg = args.start_at
375
+ end = 0
376
+ start = time.time()-beg
377
+ while True:
378
+ now = time.time() - start
379
+ if now < end+min_chunk:
380
+ time.sleep(min_chunk+end-now)
381
+ end = time.time() - start
382
+ a = load_audio_chunk(audio_path,beg,end)
383
+ beg = end
384
+ online.insert_audio_chunk(a)
385
+
386
+ try:
387
+ o = online.process_iter()
388
+ except AssertionError:
389
+ print("assertion error",file=sys.stderr)
390
+ pass
391
+ else:
392
+ output_transcript(o)
393
+ now = time.time() - start
394
+ print(f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}",file=sys.stderr)
395
+
396
+ print(file=sys.stderr,flush=True)
397
+
398
+ if end >= duration:
399
+ break
400
+
401
+ o = online.finish()
402
+ output_transcript(o)