Spaces:
Sleeping
Sleeping
import gradio as gr | |
import random | |
import difflib | |
import re | |
import jiwer | |
import torch | |
from transformers import WhisperForConditionalGeneration, WhisperProcessor | |
from indic_transliteration import sanscript | |
from indic_transliteration.sanscript import transliterate | |
import spaces | |
# ---------------- CONFIG ---------------- #
# Inference device for every model: a CUDA GPU when torch can see one, otherwise the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
# Hugging Face checkpoint id loaded for each supported language.
MODEL_CONFIGS = dict(
    English="openai/whisper-large-v2",
    Tamil="vasista22/whisper-tamil-large-v2",
    Malayalam="thennal/whisper-medium-ml",
)
# Whisper language code used to build the decoder prompt for each language.
# The key order here is also the order of the UI language dropdown.
LANG_CODES = dict(
    English="en",
    Tamil="ta",
    Malayalam="ml",
)
# Per-language (weak, strong) text primers handed to transcribe_once as initial_prompt.
# compare_pronunciation uses the weak primer for Pass 1, and the strong primer
# followed by "\nTarget: <intended sentence>" for Pass 2.
# The weak primer asks for output in the language's own script only; the strong
# primer additionally forbids translation and includes an example sentence.
# NOTE(review): the Malayalam weak primer spells "transcript" with ഖ (kha);
# ട്രാൻസ്ക്രിപ്റ്റ് (with ക) looks like the intended spelling — confirm before changing.
LANG_PRIMERS = {
    "English": ("The transcript should be in English only.",
                "Write only in English without translation. Example: This is an English sentence."),
    "Tamil": ("நகல் தமிழ் எழுத்துக்களில் மட்டும் இருக்க வேண்டும்.",
              "தமிழ் எழுத்துக்களில் மட்டும் எழுதவும், மொழிபெயர்ப்பு செய்யக்கூடாது. உதாரணம்: இது ஒரு தமிழ் வாக்கியம்."),
    "Malayalam": ("ട്രാൻസ്ഖ്രിപ്റ്റ് മലയാള ലിപിയിൽ ആയിരിക്കണം.",
                  "മലയാള ലിപിയിൽ മാത്രം എഴുതുക, വിവർത്തനം ചെയ്യരുത്. ഉദാഹരണം: ഇതൊരു മലയാള വാക്യമാണ്. എനിക്ക് മലയാളം അറിയാം.")
}
# Regexes that detect whether a string contains at least one character of each
# language's script; is_script() uses them to decide whether to transliterate.
# The ranges are spelled as \u escapes rather than literal characters so that an
# encoding round-trip cannot silently empty the character class.
SCRIPT_PATTERNS = {
    "Tamil": re.compile(r"[\u0B80-\u0BFF]"),      # Unicode Tamil block
    "Malayalam": re.compile(r"[\u0D00-\u0D7F]"),  # Unicode Malayalam block
    "English": re.compile(r"[A-Za-z]"),           # ASCII Latin letters
}
# Practice sentences per language (keys match LANG_CODES / MODEL_CONFIGS).
# get_random_sentence() picks one of these for the user to read aloud; it is
# also the reference text for WER/CER and the diff highlighting.
SENTENCE_BANK = {
    "English": [
        "The sun sets over the horizon.",
        "Learning languages is fun.",
        "I like to drink coffee in the morning.",
        "Technology helps us communicate better.",
        "Reading books expands our knowledge."
    ],
    "Tamil": [
        "இன்று நல்ல வானிலை உள்ளது.",
        "நான் தமிழ் கற்றுக்கொண்டு இருக்கிறேன்.",
        "எனக்கு புத்தகம் படிக்க விருப்பம்.",
        "தமிழ் மொழி மிகவும் அழகானது.",
        "நான் தினமும் பள்ளிக்கு செல்கிறேன்."
    ],
    "Malayalam": [
        "എനിക്ക് മലയാളം വളരെ ഇഷ്ടമാണ്.",
        "ഇന്ന് മഴപെയ്യുന്നു.",
        "ഞാൻ പുസ്തകം വായിക്കുന്നു.",
        "കേരളം എന്റെ സ്വന്തം നാടാണ്.",
        "ഞാൻ മലയാളം പഠിക്കുന്നു."
    ]
}
# Lazily populated caches keyed by language name (filled by load_model):
# one Whisper model and its matching processor per language.
whisper_models, whisper_processors = {}, {}
def load_model(language_choice):
    """Load and cache the Whisper model and processor for ``language_choice``.

    Does nothing when both are already cached. The checkpoint id comes from
    MODEL_CONFIGS and the model is moved to DEVICE.

    Args:
        language_choice: A key of MODEL_CONFIGS (e.g. "Tamil").

    Raises:
        KeyError: If ``language_choice`` is not a key of MODEL_CONFIGS.
    """
    # Require both entries, so a partially populated cache triggers a reload
    # instead of a KeyError later in transcribe_once.
    if language_choice in whisper_models and language_choice in whisper_processors:
        return
    model_id = MODEL_CONFIGS[language_choice]
    print(f"Loading {language_choice} model: {model_id}")
    # Build both objects before caching either. A failure in one (e.g. a network
    # error while fetching) must not leave the caches half-populated. The processor
    # is loaded first because it is the cheaper download and fails fastest.
    processor = WhisperProcessor.from_pretrained(model_id)
    model = WhisperForConditionalGeneration.from_pretrained(model_id).to(DEVICE)
    whisper_processors[language_choice] = processor
    whisper_models[language_choice] = model
    print(f"{language_choice} model loaded successfully!")
# ---------------- HELPERS ---------------- #
def get_random_sentence(language_choice):
    """Return one practice sentence for ``language_choice``, chosen at random.

    Draws from SENTENCE_BANK using the module-level ``random`` generator; raises
    KeyError for a language that has no entry.
    """
    candidates = SENTENCE_BANK[language_choice]
    return random.choice(candidates)
def is_script(text, lang_name):
    """Return True if ``text`` contains at least one character of ``lang_name``'s script.

    The check uses the regex in SCRIPT_PATTERNS. A language with no pattern is
    treated as always matching.
    """
    regex = SCRIPT_PATTERNS.get(lang_name)
    if regex is None:
        return True
    return regex.search(text) is not None
def transliterate_to_hk(text, lang_choice):
    """Transliterate ``text`` into the Harvard-Kyoto scheme.

    Tamil and Malayalam text is converted with indic_transliteration. English
    text has no source scheme and is returned unchanged. An unsupported
    ``lang_choice`` raises KeyError.
    """
    source_scheme = {
        "Tamil": sanscript.TAMIL,
        "Malayalam": sanscript.MALAYALAM,
        "English": None,
    }[lang_choice]
    if not source_scheme:
        return text
    return transliterate(text, source_scheme, sanscript.HK)
def transcribe_once(audio_path, language_choice, initial_prompt, beam_size, temperature, condition_on_previous_text):
    """Transcribe one audio file with the Whisper checkpoint for ``language_choice``.

    Args:
        audio_path: Path to an audio file readable by librosa; it is resampled to 16 kHz.
        language_choice: A key of MODEL_CONFIGS / LANG_CODES.
        initial_prompt: Text given to Whisper as previous-context prompt tokens
            (``prompt_ids``) to bias decoding. Skipped when empty or None.
        beam_size: Number of beams for ``generate``; coerced to int.
        temperature: Sampling temperature. A value <= 0 disables sampling.
        condition_on_previous_text: Accepted for interface compatibility only.
            Each call decodes a single feature window (the Whisper feature
            extractor pads/truncates to 30 s by default), so there is no earlier
            window to condition on.

    Returns:
        The decoded transcription with surrounding whitespace stripped.
    """
    import librosa  # local import keeps module import light, as in the original

    load_model(language_choice)
    model = whisper_models[language_choice]
    processor = whisper_processors[language_choice]
    lang_code = LANG_CODES[language_choice]

    audio, _ = librosa.load(audio_path, sr=16000)
    input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features.to(DEVICE)

    # Force the language token and the "transcribe" task so the model does not translate.
    forced_decoder_ids = processor.get_decoder_prompt_ids(language=lang_code, task="transcribe")

    generate_kwargs = {
        "forced_decoder_ids": forced_decoder_ids,
        "max_length": 448,
        # Gradio sliders may deliver floats; generate() needs an integer beam count.
        "num_beams": int(beam_size),
        "temperature": temperature if temperature > 0 else None,
        "do_sample": temperature > 0,
    }
    # Feed the prompt through Whisper's previous-context slot so it actually
    # biases decoding (this is what makes Pass 2 "target-biased").
    if initial_prompt:
        generate_kwargs["prompt_ids"] = processor.get_prompt_ids(
            initial_prompt, return_tensors="pt"
        ).to(DEVICE)

    with torch.no_grad():
        predicted_ids = model.generate(input_features, **generate_kwargs)

    # skip_special_tokens=True drops the special tokens and, in WhisperTokenizer,
    # also strips any echoed prompt text from the decoded output.
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    return transcription.strip()
def highlight_differences(ref, hyp):
    """Render a word-level diff of ``hyp`` against ``ref`` as an HTML fragment.

    Both strings are split on whitespace and aligned with difflib.SequenceMatcher:

    * matching words are green;
    * replaced reference words are red, followed by their replacements in orange;
    * deleted reference words are red with a strike-through;
    * inserted hypothesis words are orange.

    Every word is HTML-escaped. The hypothesis is model output, the reference
    comes from the UI, and the result is rendered as raw HTML.

    Args:
        ref: Reference (intended) sentence.
        hyp: Hypothesis (transcribed) sentence.

    Returns:
        The ``<span>`` elements joined by single spaces ("" when both inputs are blank).
    """
    from html import escape

    ref_words, hyp_words = ref.strip().split(), hyp.strip().split()
    # autojunk=False: the "popular element" heuristic would distort long inputs.
    matcher = difflib.SequenceMatcher(None, ref_words, hyp_words, autojunk=False)

    def _spans(words, style):
        # One escaped, styled <span> per word.
        return [f"<span style='{style}'>{escape(w)}</span>" for w in words]

    out_html = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            out_html.extend(_spans(ref_words[i1:i2], "color:green"))
        elif tag == "replace":
            out_html.extend(_spans(ref_words[i1:i2], "color:red"))
            out_html.extend(_spans(hyp_words[j1:j2], "color:orange"))
        elif tag == "delete":
            out_html.extend(_spans(ref_words[i1:i2], "color:red;text-decoration:line-through"))
        elif tag == "insert":
            out_html.extend(_spans(hyp_words[j1:j2], "color:orange"))
    return " ".join(out_html)
def char_level_highlight(ref, hyp):
    """Render a character-level diff of ``hyp`` against ``ref`` as an HTML fragment.

    Characters are Unicode code points aligned with difflib.SequenceMatcher:

    * matching characters are green;
    * reference characters that were replaced or deleted are red and underlined
      (the replacing hypothesis characters are not shown);
    * inserted hypothesis characters are orange.

    Because alignment is per code point, a combining vowel sign in Tamil or
    Malayalam is colored independently of its base consonant. Every character
    is HTML-escaped, since the result is rendered as raw HTML.

    Args:
        ref: Reference (intended) sentence.
        hyp: Hypothesis (transcribed) sentence.

    Returns:
        The concatenated ``<span>`` elements ("" when both inputs are empty).
    """
    from html import escape

    # autojunk=False: the "popular element" heuristic would distort inputs of
    # 200+ characters, where common letters and spaces would be junked.
    matcher = difflib.SequenceMatcher(None, ref, hyp, autojunk=False)
    out = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            style, chars = "color:green", ref[i1:i2]
        elif tag in ("replace", "delete"):
            style, chars = "color:red;text-decoration:underline", ref[i1:i2]
        else:  # "insert"
            style, chars = "color:orange", hyp[j1:j2]
        out.extend(f"<span style='{style}'>{escape(c)}</span>" for c in chars)
    return "".join(out)
# ---------------- MAIN ---------------- #
def compare_pronunciation(audio, language_choice, intended_sentence,
                          pass1_beam, pass1_temp, pass1_condition):
    """Transcribe a recording twice and score it against the intended sentence.

    Pass 1 uses the user-chosen beam size, temperature and conditioning flag,
    with the language's weak primer as ``initial_prompt``. Pass 2 uses fixed
    settings (5 beams, temperature 0, no conditioning), with the strong primer
    plus the intended sentence as ``initial_prompt``. WER, CER, the Harvard-Kyoto
    transliteration and both diff views are all computed from the Pass 1 text.

    Returns:
        An 8-tuple matching the UI outputs: (Pass 1 text, Pass 2 text,
        transliteration, WER as "x.xx", CER as "x.xx", word-diff HTML,
        char-diff HTML, intended sentence). When the audio or the intended
        sentence is missing, the first element is an error message, the next
        six are empty, and the intended sentence is returned unchanged.
    """
    intended_sentence = intended_sentence or ""
    if audio is None or not intended_sentence.strip():
        # The last output is wired back into the sentence textbox. Echo it so a
        # missing recording does not wipe the sentence the user is meant to read.
        return ("No audio or intended sentence.", "", "", "", "", "", "", intended_sentence)

    primer_weak, primer_strong = LANG_PRIMERS[language_choice]

    # Pass 1: raw transcription with user-configured decoding parameters.
    actual_text = transcribe_once(audio, language_choice, primer_weak,
                                  pass1_beam, pass1_temp, pass1_condition)

    # Pass 2: transcription biased toward the intended sentence (fixed decoding params).
    strict_prompt = f"{primer_strong}\nTarget: {intended_sentence}"
    corrected_text = transcribe_once(audio, language_choice, strict_prompt,
                                     beam_size=5, temperature=0.0, condition_on_previous_text=False)

    # Error rates of Pass 1 against the intended sentence.
    wer_val = jiwer.wer(intended_sentence, actual_text)
    cer_val = jiwer.cer(intended_sentence, actual_text)

    # Transliterate only when Pass 1 actually produced text in the expected script.
    if is_script(actual_text, language_choice):
        hk_translit = transliterate_to_hk(actual_text, language_choice)
    else:
        hk_translit = f"[Script mismatch: expected {language_choice}]"

    diff_html = highlight_differences(intended_sentence, actual_text)
    char_html = char_level_highlight(intended_sentence, actual_text)

    return (actual_text, corrected_text, hk_translit, f"{wer_val:.2f}", f"{cer_val:.2f}",
            diff_html, char_html, intended_sentence)
# ---------------- UI ---------------- #
# Gradio page; components appear in the order they are created below.
with gr.Blocks(title="Pronunciation Comparator") as demo:
    gr.Markdown("## 🎙 Pronunciation Comparator - English, Tamil & Malayalam")
    gr.Markdown("Practice pronunciation with specialized Whisper models for each language!")
    # Language picker (choices come from LANG_CODES, default Malayalam) and the
    # button that draws a random practice sentence for the selected language.
    with gr.Row():
        lang_choice = gr.Dropdown(choices=list(LANG_CODES.keys()), value="Malayalam", label="Language")
        gen_btn = gr.Button("🎲 Generate Sentence")
    # Not user-editable: filled by gen_btn and re-written by compare_pronunciation's last output.
    intended_display = gr.Textbox(label="Generated Sentence (Read aloud)", interactive=False)
    with gr.Row():
        # type="filepath" hands compare_pronunciation a path to the recorded/uploaded file.
        audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Record your pronunciation")
        # Decoding controls for Pass 1 only; Pass 2 uses fixed settings.
        with gr.Column():
            gr.Markdown("### Transcription Parameters")
            pass1_beam = gr.Slider(1, 10, value=8, step=1, label="Pass 1 Beam Size")
            pass1_temp = gr.Slider(0.0, 1.0, value=0.4, step=0.1, label="Pass 1 Temperature")
            pass1_condition = gr.Checkbox(value=True, label="Pass 1: Condition on previous text")
    submit_btn = gr.Button("🔍 Analyze Pronunciation", variant="primary")
    # Transcription results.
    with gr.Row():
        pass1_out = gr.Textbox(label="Pass 1: What You Actually Said")
        pass2_out = gr.Textbox(label="Pass 2: Target-Biased Output")
    # Transliteration and error-rate metrics (all derived from Pass 1).
    with gr.Row():
        hk_out = gr.Textbox(label="Harvard-Kyoto Transliteration (Pass 1)")
        wer_out = gr.Textbox(label="Word Error Rate")
        cer_out = gr.Textbox(label="Character Error Rate")
    gr.Markdown("### Visual Feedback")
    diff_html_box = gr.HTML(label="Word Differences Highlighted")
    char_html_box = gr.HTML(label="Character-Level Highlighting (mispronounced = red underline)")
    # Event handlers
    gen_btn.click(fn=get_random_sentence, inputs=[lang_choice], outputs=[intended_display])
    # The outputs list must stay in the same order as the 8-tuple returned by compare_pronunciation.
    submit_btn.click(
        fn=compare_pronunciation,
        inputs=[audio_input, lang_choice, intended_display, pass1_beam, pass1_temp, pass1_condition],
        outputs=[
            pass1_out, pass2_out, hk_out, wer_out, cer_out,
            diff_html_box, char_html_box, intended_display
        ]
    )

if __name__ == "__main__":
    demo.launch()