Spaces:
Running
on
Zero
Running
on
Zero
import difflib
import html
import random
import re

import gradio as gr
import jiwer
from faster_whisper import WhisperModel
from indic_transliteration import sanscript
from indic_transliteration.sanscript import transliterate
from transformers import pipeline  # only pipeline is needed for TTS
# ---------------- CONFIG ---------------- # | |
# Whisper checkpoint name and the device string handed to WhisperModel.
MODEL_NAME = "large-v2"
DEVICE = "cpu"

# UI language name -> language code passed to Whisper's `language` argument.
# Insertion order matters: the keys populate the language dropdown.
LANG_CODES = {
    "English": "en",
    "Tamil": "ta",
    "Malayalam": "ml",
    "Hindi": "hi",
    "Sanskrit": "sa",
}
# Per-language Whisper prompts as (weak primer, strong primer) pairs.
# The weak primer is used for the first transcription pass; the strong primer
# is combined with the target sentence for the second pass.
LANG_PRIMERS = {
    "English": (
        "The transcript should be in English only.",
        "Write only in English without translation. Example: This is an English sentence.",
    ),
    "Tamil": (
        "நகல் தமிழ் எழுத்துக்களில் மட்டும் இருக்க வேண்டும்.",
        "தமிழ் எழுத்துக்களில் மட்டும் எழுதவும், மொழிபெயர்ப்பு செய்யக்கூடாது. உதாரணம்: இது ஒரு தமிழ் வாக்கியம்.",
    ),
    "Malayalam": (
        "ട്രാൻസ്ക്രിപ്റ്റ് മലയാള ലിപിയിൽ ആയിരിക്കണം.",
        "മലയാള ലിപിയിൽ മാത്രം എഴുതുക, വിവർത്തനം ചെയ്യരുത്. ഉദാഹരണം: ഇതൊരു മലയാള വാക്യമാണ്. എനിക്ക് മലയാളം അറിയാം.",
    ),
    "Hindi": (
        "प्रतिलिपि केवल देवनागरी लिपि में होनी चाहिए।",
        "केवल देवनागरी लिपि में लिखें, अनुवाद न करें। उदाहरण: यह एक हिंदी वाक्य है।",
    ),
    "Sanskrit": (
        "प्रतिलिपि केवल देवनागरी लिपि में होनी चाहिए।",
        "केवल देवनागरी लिपि में लिखें, अनुवाद न करें। उदाहरण: अहं संस्कृतं जानामि।",
    ),
}
# One compiled character-class per language, matching a single character from
# that language's Unicode block (or an ASCII letter for English).
SCRIPT_PATTERNS = {
    "Tamil": re.compile(r"[\u0B80-\u0BFF]"),
    "Malayalam": re.compile(r"[\u0D00-\u0D7F]"),
    # Hindi and Sanskrit are both checked against the Devanagari block.
    "Hindi": re.compile(r"[\u0900-\u097F]"),
    "Sanskrit": re.compile(r"[\u0900-\u097F]"),
    "English": re.compile(r"[A-Za-z]"),
}
# Practice sentences offered to the user, keyed by UI language name.
SENTENCE_BANK = {
    "English": [
        "The sun sets over the horizon.",
        "Learning languages is fun.",
        "I like to drink coffee in the morning.",
    ],
    "Tamil": [
        "இன்று நல்ல வானிலை உள்ளது.",
        "நான் தமிழ் கற்றுக்கொண்டு இருக்கிறேன்.",
        "எனக்கு புத்தகம் படிக்க விருப்பம்.",
    ],
    "Malayalam": [
        "എനിക്ക് മലയാളം വളരെ ഇഷ്ടമാണ്.",
        "ഇന്ന് മഴപെയ്യുന്നു.",
        "ഞാൻ പുസ്തകം വായിക്കുന്നു.",
    ],
    "Hindi": [
        "आज मौसम अच्छा है।",
        "मुझे हिंदी बोलना पसंद है।",
        "मैं किताब पढ़ रहा हूँ।",
    ],
    "Sanskrit": [
        "अहं ग्रन्थं पठामि।",
        "अद्य सूर्यः तेजस्वी अस्ति।",
        "मम नाम रामः।",
    ],
}
# Free-text voice description per language, forwarded to the TTS pipeline as
# its "description" parameter.
VOICE_STYLE = {
    "English": "An English female voice with a neutral Indian accent.",
    "Tamil": "A female speaker with a clear Tamil accent.",
    "Malayalam": "A female speaker with a clear Malayali accent.",
    "Hindi": "A female speaker with a neutral Hindi accent.",
    "Sanskrit": "A female speaker reading in classical Sanskrit style.",
}
# ---------------- LOAD MODELS ---------------- # | |
# Both models are created at import time as module-level globals, so the
# helper functions below reuse the same instances on every call.
print("Loading Whisper model...")
whisper_model = WhisperModel(MODEL_NAME, device=DEVICE)
print("Loading IndicParler-TTS via pipeline...")
# NOTE(review): confirm that the generic transformers "text-to-speech"
# pipeline can load this checkpoint and accepts a "description" forward
# parameter (see synthesize_tts) — TODO verify against the model card.
TTS_MODEL_ID = "ai4bharat/indic-parler-tts"
tts_pipe = pipeline("text-to-speech", model=TTS_MODEL_ID)
# ---------------- HELPERS ---------------- # | |
def get_random_sentence(language_choice):
    """Pick one practice sentence at random for the selected language.

    Raises:
        KeyError: if ``language_choice`` has no entry in ``SENTENCE_BANK``.
    """
    candidates = SENTENCE_BANK[language_choice]
    return random.choice(candidates)
def is_script(text, lang_name):
    """Report whether ``text`` contains at least one character of the language's script.

    A language with no entry in ``SCRIPT_PATTERNS`` is always accepted.
    """
    pattern = SCRIPT_PATTERNS.get(lang_name)
    if pattern is None:
        return True
    return pattern.search(text) is not None
def transliterate_to_hk(text, lang_choice):
    """Transliterate ``text`` from the language's native script to Harvard-Kyoto.

    English has no source scheme and is returned unchanged.

    Raises:
        KeyError: if ``lang_choice`` is not one of the supported languages.
    """
    source_schemes = {
        "Tamil": sanscript.TAMIL,
        "Malayalam": sanscript.MALAYALAM,
        "Hindi": sanscript.DEVANAGARI,
        "Sanskrit": sanscript.DEVANAGARI,
        "English": None,
    }
    scheme = source_schemes[lang_choice]
    if not scheme:
        return text
    return transliterate(text, scheme, sanscript.HK)
def transcribe_once(audio_path, lang_code, initial_prompt, beam_size, temperature, condition_on_previous_text):
    """Run a single Whisper transcription and return the text as one string.

    All decoding options are forwarded to ``whisper_model.transcribe``; the
    text of every returned segment is concatenated and stripped of
    surrounding whitespace.
    """
    decode_options = {
        "language": lang_code,
        "task": "transcribe",
        "initial_prompt": initial_prompt,
        "beam_size": beam_size,
        "temperature": temperature,
        "condition_on_previous_text": condition_on_previous_text,
        "word_timestamps": False,
    }
    segments, _info = whisper_model.transcribe(audio_path, **decode_options)
    pieces = [segment.text for segment in segments]
    return "".join(pieces).strip()
def highlight_differences(ref, hyp):
    """Render a word-level diff of ``ref`` against ``hyp`` as coloured HTML.

    Both strings are split on whitespace and aligned with
    ``difflib.SequenceMatcher``. Each word becomes a ``<span>``:

    * green  - word present in both,
    * red    - reference word that was replaced,
    * red with strike-through - reference word missing from ``hyp``,
    * orange - word that only appears in ``hyp``.

    Every word is HTML-escaped before interpolation because the result is
    rendered as HTML; without escaping, markup characters in either string
    would be injected into the page.

    Returns:
        The spans joined by single spaces ("" when both inputs are empty).
    """
    ref_words, hyp_words = ref.split(), hyp.split()

    def spans(words, style):
        # Escape each word so "<", ">", "&" and quotes display literally.
        return [f"<span style='{style}'>{html.escape(w)}</span>" for w in words]

    out_html = []
    matcher = difflib.SequenceMatcher(None, ref_words, hyp_words)
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            out_html += spans(ref_words[i1:i2], "color:green")
        elif tag == "replace":
            # Show what was expected first, then what was heard instead.
            out_html += spans(ref_words[i1:i2], "color:red")
            out_html += spans(hyp_words[j1:j2], "color:orange")
        elif tag == "delete":
            out_html += spans(ref_words[i1:i2], "color:red;text-decoration:line-through")
        elif tag == "insert":
            out_html += spans(hyp_words[j1:j2], "color:orange")
    return " ".join(out_html)
def char_level_highlight(ref, hyp):
    """Render a character-level diff of ``ref`` against ``hyp`` as coloured HTML.

    Characters are aligned with ``difflib.SequenceMatcher`` and each one is
    wrapped in its own ``<span>``:

    * green - character present in both,
    * red underlined - reference character that was replaced or is missing
      (the replacing characters from ``hyp`` are intentionally not shown),
    * orange - character that only appears in ``hyp``.

    Each character is HTML-escaped before interpolation because the result is
    rendered as HTML; otherwise "<", ">" or "&" in either string would be
    interpreted as markup.

    Returns:
        The spans concatenated without separators ("" for two empty inputs).
    """
    def spans(chars, style):
        return [f"<span style='{style}'>{html.escape(c)}</span>" for c in chars]

    out = []
    # SequenceMatcher accepts strings directly; no need to copy into lists.
    matcher = difflib.SequenceMatcher(None, ref, hyp)
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            out += spans(ref[i1:i2], "color:green")
        elif tag in ("replace", "delete"):
            out += spans(ref[i1:i2], "color:red;text-decoration:underline")
        elif tag == "insert":
            out += spans(hyp[j1:j2], "color:orange")
    return "".join(out)
def synthesize_tts(text, lang_choice):
    """Synthesize speech for ``text`` using the voice style of ``lang_choice``.

    Args:
        text: Sentence to speak. ``None``, empty and whitespace-only values
            are treated as "nothing to say".
        lang_choice: UI language name used to look up ``VOICE_STYLE``; an
            unknown language falls back to an empty description.

    Returns:
        ``(sampling_rate, audio)`` taken from the pipeline output, or ``None``
        when there is no text to synthesize.
    """
    # Guard None as well as blank strings so callers never hit AttributeError.
    if not text or not text.strip():
        return None
    prompt_style = VOICE_STYLE.get(lang_choice, "")
    audio_out = tts_pipe(text, forward_params={"description": prompt_style})
    return (audio_out["sampling_rate"], audio_out["audio"])
# ---------------- MAIN ---------------- # | |
def compare_pronunciation(audio, language_choice, intended_sentence,
                          pass1_beam, pass1_temp, pass1_condition):
    """Transcribe a recording twice and compare it with the intended sentence.

    Args:
        audio: Path of the recorded/uploaded audio, or ``None`` if absent.
        language_choice: UI language name (key of ``LANG_CODES``).
        intended_sentence: Sentence the user was asked to read.
        pass1_beam, pass1_temp, pass1_condition: Decoding options for the
            first (lightly primed) transcription pass.

    Returns:
        A 10-tuple, in the order the UI outputs expect:
        pass-1 text, pass-2 text, Harvard-Kyoto transliteration of pass 1,
        WER, CER (both formatted to two decimals), word-level diff HTML,
        TTS of the intended sentence, TTS of the pass-1 text,
        character-level diff HTML, and the intended sentence itself.

        When audio or sentence is missing, the first element carries an error
        message and the remaining fields are empty, except the last one,
        which still echoes the intended sentence so that the sentence shown
        to the user is not wiped by a premature click.

    Raises:
        KeyError: if ``language_choice`` is not a supported language.
    """
    if audio is None or not intended_sentence or not intended_sentence.strip():
        return ("No audio or intended sentence.", "", "", "", "", "",
                None, None, "", intended_sentence or "")

    lang_code = LANG_CODES[language_choice]
    primer_weak, primer_strong = LANG_PRIMERS[language_choice]

    # Pass 1: user-tunable decoding with only the weak script primer, meant
    # to capture what was actually said.
    actual_text = transcribe_once(audio, lang_code, primer_weak,
                                  pass1_beam, pass1_temp, pass1_condition)

    # Pass 2: fixed decoding settings, prompt biased towards the target.
    strict_prompt = f"{primer_strong}\nTarget: {intended_sentence}"
    corrected_text = transcribe_once(audio, lang_code, strict_prompt,
                                     beam_size=5, temperature=0.0,
                                     condition_on_previous_text=False)

    # Error rates of the unbiased transcript against the reference sentence.
    wer_val = jiwer.wer(intended_sentence, actual_text)
    cer_val = jiwer.cer(intended_sentence, actual_text)

    # Only transliterate when the transcript is in the expected script.
    if is_script(actual_text, language_choice):
        hk_translit = transliterate_to_hk(actual_text, language_choice)
    else:
        hk_translit = f"[Script mismatch: expected {language_choice}]"

    diff_html = highlight_differences(intended_sentence, actual_text)
    char_html = char_level_highlight(intended_sentence, actual_text)

    # Spoken versions of the target and of what was heard in pass 1.
    tts_intended = synthesize_tts(intended_sentence, language_choice)
    tts_pass1 = synthesize_tts(actual_text, language_choice)

    return (actual_text, corrected_text, hk_translit,
            f"{wer_val:.2f}", f"{cer_val:.2f}", diff_html,
            tts_intended, tts_pass1, char_html, intended_sentence)
# ---------------- UI ---------------- # | |
# Build the Gradio interface. Components are created in display order and the
# two buttons are wired to the helper functions defined above.
# NOTE(review): the original indentation was lost; which components sit inside
# each gr.Row() below is reconstructed — confirm against the deployed layout.
with gr.Blocks() as demo:
    gr.Markdown("## 🎙 Pronunciation Comparator + IndicParler‑TTS + Error Highlighting")
    # Language selection and sentence generation.
    with gr.Row():
        lang_choice = gr.Dropdown(choices=list(LANG_CODES.keys()), value="Malayalam", label="Language")
        gen_btn = gr.Button("🎲 Generate Sentence")
    # Read-only: the sentence can only come from the generator button.
    intended_display = gr.Textbox(label="Generated Sentence (Read aloud)", interactive=False)
    # Recording input plus the decoding options for transcription pass 1.
    with gr.Row():
        audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath")
        pass1_beam = gr.Slider(1, 10, value=8, step=1, label="Pass 1 Beam Size")
        pass1_temp = gr.Slider(0.0, 1.0, value=0.4, step=0.1, label="Pass 1 Temperature")
        pass1_condition = gr.Checkbox(value=True, label="Pass 1: Condition on previous text")
    # Transcription results.
    with gr.Row():
        pass1_out = gr.Textbox(label="Pass 1: What You Actually Said")
        pass2_out = gr.Textbox(label="Pass 2: Target-Biased Output")
        hk_out = gr.Textbox(label="Harvard-Kyoto Transliteration (Pass 1)")
    # Error metrics.
    with gr.Row():
        wer_out = gr.Textbox(label="Word Error Rate")
        cer_out = gr.Textbox(label="Character Error Rate")
    # Diff views rendered as HTML.
    diff_html_box = gr.HTML(label="Word Differences Highlighted")
    char_html_box = gr.HTML(label="Character-Level Highlighting (mispronounced = red underline)")
    # Synthesized audio for the target sentence and for the pass-1 transcript.
    with gr.Row():
        intended_tts_audio = gr.Audio(label="TTS - Intended Sentence", type="numpy")
        pass1_tts_audio = gr.Audio(label="TTS - Pass1 Output", type="numpy")
    gen_btn.click(fn=get_random_sentence, inputs=[lang_choice], outputs=[intended_display])
    submit_btn = gr.Button("Analyze Pronunciation")
    # The outputs list must stay in the same order as the 10-tuple returned by
    # compare_pronunciation; the last slot writes back to intended_display.
    submit_btn.click(
        fn=compare_pronunciation,
        inputs=[audio_input, lang_choice, intended_display, pass1_beam, pass1_temp, pass1_condition],
        outputs=[
            pass1_out, pass2_out, hk_out, wer_out, cer_out,
            diff_html_box, intended_tts_audio, pass1_tts_audio,
            char_html_box, intended_display
        ]
    )

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()