Huong committed
Commit 265ea18 · 1 Parent(s): 704a4fe

Add application file

Files changed (1)
  1. app.py +428 -0
app.py ADDED
@@ -0,0 +1,428 @@
+ import gradio as gr
+ import re
+ import librosa
+ import torch
+ import numpy as np
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
+ from whisper.normalizers import EnglishTextNormalizer
+ from olmoasr import load_model
+ from bs4 import BeautifulSoup
+
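+ # Model setup: a Hugging Face Whisper-style seq2seq checkpoint driven
+ # through the transformers pipeline, plus a native OLMoASR checkpoint
+ # for sequential decoding.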
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+ hf_model_path = "checkpoints/medium_hf_demo"
+ olmoasr_ckpt = (
+     "checkpoints/eval_latesttrain_00524288_medium_fsdp-train_grad-acc_bfloat16_inf.pt"
+ )
+
+ hf_model = AutoModelForSpeechSeq2Seq.from_pretrained(
+     hf_model_path, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
+ )
+ hf_model.to(device).eval()
+ processor = AutoProcessor.from_pretrained(hf_model_path)
+
+ olmoasr_model = load_model(
+     name=olmoasr_ckpt, device=device, inference=True, in_memory=True
+ )
+ olmoasr_model.to(device).eval()
+
+ normalizer = EnglishTextNormalizer()
+
+
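+ # Downmix multi-channel audio to a single channel before resampling and
+ # feature extraction.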
+ def stereo_to_mono(waveform):
+     # librosa returns shape (channels, samples) for multi-channel audio
+     if waveform.ndim == 2:
+         # Average the channels to convert to mono
+         return np.mean(waveform, axis=0)
+     # Already mono: return as is
+     return waveform
+
+
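+ # Chunked long-form inference: the HF pipeline splits the audio into
+ # 30 s windows, transcribes each window, and merges the timestamped
+ # chunks.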
+ def hf_chunk_transcribe(audio_file, timestamp_text, transcription_text):
+     hf_transcriber = pipeline(
+         "automatic-speech-recognition",
+         model=hf_model,
+         tokenizer=processor.tokenizer,
+         feature_extractor=processor.feature_extractor,
+         torch_dtype=torch_dtype,
+         device=device,
+         chunk_length_s=30,
+     )
+
+     waveform, sample_rate = librosa.load(audio_file, sr=None, mono=False)
+     waveform = stereo_to_mono(waveform)
+     print(waveform.shape)
+
+     if sample_rate != 16000:
+         waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)
+
+     result = hf_transcriber(waveform, return_timestamps=True)
+     print(f"{result['text']=}\n")
+     print(f"{result['chunks']=}\n")
+
+     # text = result["text"].strip().replace("\n", " ")
+     # text = re.sub(r"(foreign|foreign you|you)\s*$", "", text)
+
+     chunks, text = hf_process_chunks(result["chunks"])
+     print(f"{chunks=}\n")
+     print(f"{text=}\n")
+
+     # Edit components
+     transSoup = BeautifulSoup(transcription_text, "html.parser")
+     transText = transSoup.find(id="transcriptionText")
+     if transText:
+         transText.clear()
+         transText.append(BeautifulSoup(text, "html.parser"))
+
+     timeSoup = BeautifulSoup(timestamp_text, "html.parser")
+     timeText = timeSoup.find(id="timestampText")
+     if timeText:
+         timeText.clear()
+         timeText.append(BeautifulSoup(chunks, "html.parser"))
+
+     return str(timeSoup), str(transSoup)
+
+
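+ # Sequential inference with the native OLMoASR decoder: Whisper-style
+ # 30 s windows decoded in order with beam search.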
+ def olmoasr_seq_transcribe(audio_file, timestamp_text, transcription_text):
+     waveform, sample_rate = librosa.load(audio_file, sr=None, mono=False)
+     waveform = stereo_to_mono(waveform)
+     print(waveform.shape)
+
+     if sample_rate != 16000:
+         waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)
+
+     options = dict(
+         task="transcribe",
+         language="en",
+         without_timestamps=False,
+         beam_size=5,
+         best_of=5,
+     )
+     result = olmoasr_model.transcribe(waveform, verbose=False, **options)
+     print(f"{result['text']=}\n")
+     print(f"{result['segments']=}\n")
+
+     # text = result["text"].strip().replace("\n", " ")
+     # text = re.sub(r"(foreign|foreign you|Thank you for watching!|. you)\s*$", "", text)
+
+     chunks, text = olmoasr_process_chunks(result["segments"])
+     print(f"{chunks=}\n")
+     print(f"{text=}\n")
+
+     # Edit components
+     transSoup = BeautifulSoup(transcription_text, "html.parser")
+     transText = transSoup.find(id="transcriptionText")
+     if transText:
+         transText.clear()
+         transText.append(BeautifulSoup(text, "html.parser"))
+
+     timeSoup = BeautifulSoup(timestamp_text, "html.parser")
+     timeText = timeSoup.find(id="timestampText")
+     if timeText:
+         timeText.clear()
+         timeText.append(BeautifulSoup(chunks, "html.parser"))
+
+     return str(timeSoup), str(transSoup)
+
+
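+ # Sequential long-form inference through the HF pipeline: no
+ # chunk_length_s is set, so the model decodes the audio
+ # window-by-window and returns timestamped chunks.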
+ def hf_seq_transcribe(audio_file, timestamp_text, transcription_text):
+     hf_transcriber = pipeline(
+         "automatic-speech-recognition",
+         model=hf_model,
+         tokenizer=processor.tokenizer,
+         feature_extractor=processor.feature_extractor,
+         torch_dtype=torch_dtype,
+         device=device,
+     )
+
+     waveform, sample_rate = librosa.load(audio_file, sr=None, mono=False)
+     waveform = stereo_to_mono(waveform)
+     print(waveform.shape)
+
+     if sample_rate != 16000:
+         waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)
+
+     result = hf_transcriber(
+         waveform,
+         return_timestamps=True,
+     )
+     print(f"{result['text']=}\n")
+     print(f"{result['chunks']=}\n")
+
+     # text = result["text"].strip().replace("\n", " ")
+     # text = re.sub(r"(foreign|foreign you|you)\s*$", "", text)
+
+     chunks, text = hf_seq_process_chunks(result["chunks"])
+     print(f"{text=}\n")
+     print(f"{chunks=}\n")
+
+     # Edit components
+     transSoup = BeautifulSoup(transcription_text, "html.parser")
+     transText = transSoup.find(id="transcriptionText")
+     if transText:
+         transText.clear()
+         transText.append(BeautifulSoup(text, "html.parser"))
+
+     timeSoup = BeautifulSoup(timestamp_text, "html.parser")
+     timeText = timeSoup.find(id="timestampText")
+     if timeText:
+         timeText.clear()
+         timeText.append(BeautifulSoup(chunks, "html.parser"))
+
+     return str(timeSoup), str(transSoup)
+
+
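+ # Route the request to the strategy selected in the dropdown.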
+ def main_transcribe(inference_strategy, audio_file, timestamp_text, transcription_text):
+     if inference_strategy == "HuggingFace Chunking":
+         return hf_chunk_transcribe(audio_file, timestamp_text, transcription_text)
+     elif inference_strategy == "OLMoASR Sequential":
+         return olmoasr_seq_transcribe(audio_file, timestamp_text, transcription_text)
+     elif inference_strategy == "HuggingFace Sequential":
+         return hf_seq_transcribe(audio_file, timestamp_text, transcription_text)
+
+
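+ # Segment post-processing: stop at trailing hallucinated filler
+ # ("foreign", "you", "Thank you for watching!", ...) that the models
+ # tend to emit on silence, and render each kept segment as
+ # "start --> end: text <br>".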
+ def olmoasr_process_chunks(chunks):
+     processed_chunks = []
+     processed_chunks_text = []
+     for chunk in chunks:
+         text = chunk["text"].strip()
+         if not re.match(
+             r"\s*(foreign you|foreign|Thank you for watching!|you there|you)\s*$", text
+         ):
+             if text.strip() == "":
+                 continue
+             start = chunk["start"]
+             end = chunk["end"]
+             pattern = r"\n(?!\d+\.\d+\s*-->)"
+             text = re.sub(pattern, "", text)
+             processed_chunks_text.append(text.strip())
+             processed_chunks.append(f"{start:.2f} --> {end:.2f}: {text} <br>")
+         else:
+             break
+     print(f"{processed_chunks=}\n")
+     print(f"{processed_chunks_text=}\n")
+     print(
+         re.search(r"\s*foreign\s*$", processed_chunks_text[-1])
+         if processed_chunks_text
+         else None
+     )
+
+     if processed_chunks_text and re.search(
+         r"\s*foreign\s*$", processed_chunks_text[-1]
+     ):
+         processed_chunks_text[-1] = re.sub(
+             r"\s*foreign\s*$", "", processed_chunks_text[-1]
+         )
+         processed_chunks[-1] = re.sub(r"foreign\s*<br>", "<br>", processed_chunks[-1])
+     return "\n".join(processed_chunks), " ".join(processed_chunks_text)
+
+
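+ # Same filtering for HF pipeline chunks, which carry a (start, end)
+ # "timestamp" tuple; the final chunk's end timestamp can come back as
+ # None, so it is guarded below.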
+ def hf_process_chunks(chunks):
+     processed_chunks = []
+     processed_chunks_text = []
+     for chunk in chunks:
+         text = chunk["text"].strip()
+         if not re.match(r"(foreign you|foreign|you there|you)\s*$", text):
+             if text.strip() == "":
+                 continue
+             start = chunk["timestamp"][0]
+             end = chunk["timestamp"][1]
+             if end is None:
+                 # The pipeline may leave the last end timestamp unset
+                 end = start
+             pattern = r"\n(?!\d+\.\d+\s*-->)"
+             text = re.sub(pattern, "", text)
+             processed_chunks_text.append(text.strip())
+             processed_chunks.append(f"{start:.2f} --> {end:.2f}: {text.strip()} <br>")
+         else:
+             break
+     print(f"{processed_chunks=}\n")
+     print(f"{processed_chunks_text=}\n")
+     print(
+         re.search(r"\s*foreign\s*$", processed_chunks_text[-1])
+         if processed_chunks_text
+         else None
+     )
+
+     if processed_chunks_text and re.search(
+         r"\s*foreign\s*$", processed_chunks_text[-1]
+     ):
+         processed_chunks_text[-1] = re.sub(
+             r"\s*foreign\s*$", "", processed_chunks_text[-1]
+         )
+         processed_chunks[-1] = re.sub(r"foreign\s*<br>", "<br>", processed_chunks[-1])
+     return "\n".join(processed_chunks), " ".join(processed_chunks_text)
+
+
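+ # Sequential HF decoding restarts timestamps at each 30 s window, so
+ # this variant detects when a start time moves backwards and adds the
+ # accumulated offset to keep the timeline monotonic.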
+ def hf_seq_process_chunks(chunks):
+     processed_chunks = []
+     processed_chunks_text = []
+     delta_time = 0.0
+     global_start = chunks[0]["timestamp"][0]
+     prev_end = -1.0
+     prev_dur = 0.0
+     accumulate_ts = False
+     for chunk in chunks:
+         text = chunk["text"].strip()
+         if not re.match(r"\s*(foreign you|foreign|you there|you)\s*$", text):
+             if text.strip() == "":
+                 continue
+             start = chunk["timestamp"][0]
+             end = chunk["timestamp"][1]
+             # A backwards jump means a new window started: begin offsetting
+             if start < prev_end:
+                 accumulate_ts = True
+                 prev_dur += delta_time
+                 # print(f"{prev_dur=}")
+
+             delta_time = end - global_start
+             # print(f"{delta_time=}")
+
+             prev_end = end
+             # print(f"{prev_end=}")
+             if accumulate_ts:
+                 start += prev_dur
+                 end += prev_dur
+             # print(f"{start=}, {end=}, {prev_dur=}")
+
+             pattern = r"\n(?!\d+\.\d+\s*-->)"
+             text = re.sub(pattern, "", text)
+             processed_chunks_text.append(text.strip())
+             processed_chunks.append(f"{start:.2f} --> {end:.2f}: {text.strip()} <br>")
+         else:
+             break
+     print(f"{processed_chunks=}\n")
+     print(f"{processed_chunks_text=}\n")
+     print(
+         re.search(r"\s*foreign\s*$", processed_chunks_text[-1])
+         if processed_chunks_text
+         else None
+     )
+
+     if processed_chunks_text and re.search(
+         r"\s*foreign\s*$", processed_chunks_text[-1]
+     ):
+         processed_chunks_text[-1] = re.sub(
+             r"\s*foreign\s*$", "", processed_chunks_text[-1]
+         )
+         processed_chunks[-1] = re.sub(r"foreign\s*<br>", "<br>", processed_chunks[-1])
+     return "\n".join(processed_chunks), " ".join(processed_chunks_text)
+
+
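+ # Static HTML shells for the two output panels; transcriptions are
+ # injected into the inner divs by id.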
+ original_timestamp_html = """
+ <div style="background: white; border: 1px solid #d1d5db; border-radius: 8px; padding: 16px; width: 100%; box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1); flex: 1; margin-right: 10px;">
+ <div style="color: #374151; font-size: 14px; font-weight: 500; margin-bottom: 8px;">Timestamp Text</div>
+ <div id="timestampText" style="color: #6b7280; font-size: 14px; line-height: 1.5; min-height: 100px; font-family: system-ui, sans-serif;"></div>
+ </div>
+ """
+
+ original_transcription_html = """
+ <div style="background: white; border: 1px solid #d1d5db; border-radius: 8px; padding: 16px; width: 100%; box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1); flex: 1; margin-right: 10px;">
+ <div style="color: #374151; font-size: 14px; font-weight: 500; margin-bottom: 8px;">Transcription Text</div>
+ <div id="transcriptionText" style="color: #6b7280; font-size: 14px; line-height: 1.5; min-height: 100px; font-family: system-ui, sans-serif;"></div>
+ </div>
+ """
+
+
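+ # Clear both panels back to their empty shells.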
+ def reset():
+     return original_timestamp_html, original_transcription_html
+
+
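+ # Client-side highlighting: poll the player's elapsed-time readout
+ # (this assumes Gradio renders it in an element with id "time") every
+ # 50 ms, find the segment whose start time most recently passed, and
+ # wrap that line in a highlight span.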
+ event_process_js = """
+ <script>
+ function getTime() {
+     lastIndex = -1;
+     setInterval(() => {
+         time = document.getElementById('time');
+         timestampText = document.getElementById('timestampText');
+         if (timestampText && timestampText.innerText != '') {
+             if (time == null) {
+                 // Player was removed: clear both panels
+                 timestampText.innerText = '';
+                 transcriptionText = document.getElementById('transcriptionText');
+                 if (transcriptionText) {
+                     transcriptionText.innerText = '';
+                 }
+                 lastIndex = -1;
+                 return;
+             }
+             timeContent = time.textContent;
+             const parts = timeContent.split(":").map(Number);
+             currTime = parts[0] * 60 + parts[1];
+             currText = timestampText.innerText;
+             const matches = [...currText.matchAll(/([\d.]+)\s*-->/g)];
+             const startTimestamps = matches.map(m => parseFloat(m[1]));
+
+             if (startTimestamps.length != 0) {
+                 correctIndex = 0;
+                 for (let i = 0; i < startTimestamps.length; i++) {
+                     if (startTimestamps[i] <= currTime) {
+                         correctIndex = i;
+                     } else {
+                         break;
+                     }
+                 }
+                 if (lastIndex != correctIndex) {
+                     lastIndex = correctIndex;
+                     lines = currText.split('\\n');
+                     lines[correctIndex] = '<span style="background-color: #ff69b4; padding: 3px 8px; font-weight: 500; border-radius: 4px; color: white; box-shadow: 0 0 10px rgba(255, 105, 180, 0.5);">' + lines[correctIndex] + '</span>';
+                     try {
+                         timestampText.innerHTML = lines.join('<br>');
+                     } catch (e) {
+                         console.log('Not Updating!');
+                     }
+                 }
+             }
+         } else {
+             lastIndex = -1;
+         }
+     }, 50);
+ }
+ setTimeout(getTime, 1000);
+ </script>
+ """
+ demo = gr.Blocks(
+     head=event_process_js,
+     theme=gr.themes.Default(primary_hue="emerald", secondary_hue="green"),
+ )
+ with demo:
+     audio = gr.Audio(sources=["upload", "microphone"], type="filepath")
+     inf_strategy = gr.Dropdown(
+         label="Inference Strategy",
+         choices=[
+             "HuggingFace Chunking",
+             "HuggingFace Sequential",
+             "OLMoASR Sequential",
+         ],
+         value="HuggingFace Chunking",
+         multiselect=False,
+         info="Select the inference strategy for transcription.",
+         elem_id="inf_strategy",
+     )
+     main_transcribe_button = gr.Button(
+         "Transcribe",
+         variant="primary",
+     )
+     with gr.Row():
+         timestampText = gr.HTML(original_timestamp_html)
+         transcriptionText = gr.HTML(original_transcription_html)
+     inf_strategy.change(
+         fn=reset,
+         inputs=[],
+         outputs=[timestampText, transcriptionText],
+     )
+     main_transcribe_button.click(
+         fn=main_transcribe,
+         inputs=[inf_strategy, audio, timestampText, transcriptionText],
+         outputs=[timestampText, transcriptionText],
+     )
+ demo.launch(share=True)