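"""Gradio demo: compare Whisper, Wav2Vec2, and HuBERT transcriptions of
Common Voice ESL clips, with word-level diff highlighting and WER scores."""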
import gradio as gr
import random
import json
import os
from difflib import SequenceMatcher

import torch  # used below to pick GPU vs. CPU at runtime
import torchaudio
from jiwer import wer
from transformers import pipeline

# Load metadata for the validated Common Voice clips
with open("common_voice_en_validated_249_hf_ready.json") as f:
    data = json.load(f)
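
# Each entry is expected to expose at least the fields used below; the shape
# (inferred from the lookups, not guaranteed) looks like:
# {"path": "...", "sentence": "...", "age": "...", "gender": "...", "accent": "..."}
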
# Available filter values
ages = sorted(set(entry["age"] for entry in data))
genders = sorted(set(entry["gender"] for entry in data))
accents = sorted(set(entry["accent"] for entry in data))

# Load ASR pipelines (GPU if available, else CPU; device=-1 means CPU in transformers)
device = 0 if torch.cuda.is_available() else -1
pipe_whisper = pipeline("automatic-speech-recognition", model="openai/whisper-medium", device=device)
pipe_wav2vec2 = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h", device=device)
# The base HuBERT checkpoint has no CTC head or tokenizer, so it cannot transcribe;
# use the LibriSpeech fine-tuned variant instead.
pipe_hubert = pipeline("automatic-speech-recognition", model="facebook/hubert-large-ls960-ft", device=device)
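
# Loading all three checkpoints up front trades startup time for low per-click
# latency; whisper-medium (~769M parameters) is by far the largest of the three.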

def load_audio(file_path):
    # Unused helper: resamples a clip to the 16 kHz mono input the models expect
    # (the pipelines below accept file paths directly).
    waveform, sr = torchaudio.load(file_path)
    return torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)[0].numpy()

def transcribe(pipe, file_path):
    # Normalize to lowercase, stripped text so WER is not inflated by casing
    result = pipe(file_path)
    return result["text"].strip().lower()
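
# Note: Whisper pipelines support long-form input via pipe(path, chunk_length_s=30),
# but Common Voice clips are short enough that no chunking is needed here.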

def highlight_differences(ref, hyp):
    """Mark hypothesis words that differ from the reference in red; words the
    model dropped entirely are shown struck through."""
    ref_words, hyp_words = ref.split(), hyp.split()
    sm = SequenceMatcher(None, ref_words, hyp_words)
    result = []
    for opcode, i1, i2, j1, j2 in sm.get_opcodes():
        if opcode == "equal":
            result.extend(hyp_words[j1:j2])
        elif opcode in ("replace", "insert"):
            result.extend(f"<span style='color:red'>{w}</span>" for w in hyp_words[j1:j2])
        else:  # "delete": reference words with no counterpart in the hypothesis
            result.extend(f"<span style='color:red'><s>{w}</s></span>" for w in ref_words[i1:i2])
    return " ".join(result)

def run_demo(age, gender, accent):
    # Keep only clips whose speaker metadata matches all three filters
    filtered = [
        entry for entry in data
        if entry["age"] == age and entry["gender"] == gender and entry["accent"] == accent
    ]
    if not filtered:
        # Must match the seven outputs wired to btn.click below
        return "No matching sample.", None, "", "", "", "", ""

    sample = random.choice(filtered)
    file_path = os.path.join("common_voice_en_validated_249", sample["path"])
    gold = sample["sentence"].strip().lower()

    whisper_text = transcribe(pipe_whisper, file_path)
    wav2vec_text = transcribe(pipe_wav2vec2, file_path)
    hubert_text = transcribe(pipe_hubert, file_path)

    table = f"""
    <table border="1" style="width:100%">
      <tr><th>Model</th><th>Transcription</th><th>WER</th></tr>
      <tr><td><b>Gold</b></td><td>{gold}</td><td>0.00</td></tr>
      <tr><td>Whisper</td><td>{highlight_differences(gold, whisper_text)}</td><td>{wer(gold, whisper_text):.2f}</td></tr>
      <tr><td>Wav2Vec2</td><td>{highlight_differences(gold, wav2vec_text)}</td><td>{wer(gold, wav2vec_text):.2f}</td></tr>
      <tr><td>HuBERT</td><td>{highlight_differences(gold, hubert_text)}</td><td>{wer(gold, hubert_text):.2f}</td></tr>
    </table>
    """

    # Return order must match the outputs list wired to btn.click below
    # (the original duplicate gold return is dropped to match the seven components)
    return sample["sentence"], file_path, whisper_text, wav2vec_text, hubert_text, table, f"Audio path: {file_path}"

with gr.Blocks() as demo:
    gr.Markdown("# ASR Model Comparison on ESL Audio")
    gr.Markdown("Filter by age, gender, and accent, then generate a random ESL learner's clip to compare how Whisper, Wav2Vec2, and HuBERT transcribe it.")

    with gr.Row():
        age = gr.Dropdown(choices=ages, label="Age")
        gender = gr.Dropdown(choices=genders, label="Gender")
        accent = gr.Dropdown(choices=accents, label="Accent")

    btn = gr.Button("Generate and Transcribe")
    audio = gr.Audio(label="Audio", type="filepath")
    wer_output = gr.HTML()

    # Output components created inline here render below wer_output in the layout
    btn.click(fn=run_demo, inputs=[age, gender, accent], outputs=[
        gr.Textbox(label="Gold (Correct)"),
        audio,
        gr.Textbox(label="Whisper Output"),
        gr.Textbox(label="Wav2Vec2 Output"),
        gr.Textbox(label="HuBERT Output"),
        wer_output,
        gr.Textbox(label="Path"),
    ])

demo.launch()
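
# On Spaces, a bare launch() is sufficient; for local testing,
# demo.launch(share=True) creates a temporary public link.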