sudhanm's picture
Update app.py
bc807f8 verified
raw
history blame
10.9 kB
import difflib
import html
import random
import re
import unicodedata

import gradio as gr
import jiwer
from faster_whisper import WhisperModel
from indic_transliteration import sanscript
from indic_transliteration.sanscript import transliterate
from transformers import pipeline  # only pipeline is needed for TTS
# ---------------- CONFIG ---------------- #
# faster-whisper checkpoint name and the device it is loaded on.
MODEL_NAME = "large-v2"
DEVICE = "cpu"

# UI language name -> language code handed to Whisper's `language=` argument.
# The keys also populate the language dropdown.
LANG_CODES = dict(
    English="en",
    Tamil="ta",
    Malayalam="ml",
    Hindi="hi",
    Sanskrit="sa",
)
# Whisper prompts per language, stored as (weak, strong) pairs.
# compare_pronunciation uses the weak primer as the initial_prompt of pass 1,
# and the strong primer followed by "\nTarget: <sentence>" as the
# initial_prompt of pass 2. Each primer is written in the target script and
# (roughly) asks for a transcript in that script only, without translation;
# the strong primer adds an example sentence.
# Hindi and Sanskrit share the same weak primer string.
LANG_PRIMERS = {
    "English": ("The transcript should be in English only.",
                "Write only in English without translation. Example: This is an English sentence."),
    "Tamil": ("நகல் தமிழ் எழுத்துக்களில் மட்டும் இருக்க வேண்டும்.",
              "தமிழ் எழுத்துக்களில் மட்டும் எழுதவும், மொழிபெயர்ப்பு செய்யக்கூடாது. உதாரணம்: இது ஒரு தமிழ் வாக்கியம்."),
    "Malayalam": ("ട്രാൻസ്ക്രിപ്റ്റ് മലയാള ലിപിയിൽ ആയിരിക്കണം.",
                  "മലയാള ലിപിയിൽ മാത്രം എഴുതുക, വിവർത്തനം ചെയ്യരുത്. ഉദാഹരണം: ഇതൊരു മലയാള വാക്യമാണ്. എനിക്ക് മലയാളം അറിയാം."),
    "Hindi": ("प्रतिलिपि केवल देवनागरी लिपि में होनी चाहिए।",
              "केवल देवनागरी लिपि में लिखें, अनुवाद न करें। उदाहरण: यह एक हिंदी वाक्य है।"),
    "Sanskrit": ("प्रतिलिपि केवल देवनागरी लिपि में होनी चाहिए।",
                 "केवल देवनागरी लिपि में लिखें, अनुवाद न करें। उदाहरण: अहं संस्कृतं जानामि।")
}
# Language -> compiled regex matching any single character of its script.
# Hindi and Sanskrit both use the Devanagari block (U+0900-U+097F).
SCRIPT_PATTERNS = {
    language: re.compile(char_class)
    for language, char_class in (
        ("Tamil", r"[\u0B80-\u0BFF]"),
        ("Malayalam", r"[\u0D00-\u0D7F]"),
        ("Hindi", r"[\u0900-\u097F]"),
        ("Sanskrit", r"[\u0900-\u097F]"),
        ("English", r"[A-Za-z]"),
    )
}
# Practice sentences per language; get_random_sentence picks one at random.
# The keys must cover every LANG_CODES key, because the language dropdown is
# built from LANG_CODES and its value is used to index this dict.
# Every sentence ends with punctuation ("." or the Devanagari danda "।").
SENTENCE_BANK = {
    "English": ["The sun sets over the horizon.",
                "Learning languages is fun.",
                "I like to drink coffee in the morning."],
    "Tamil": ["இன்று நல்ல வானிலை உள்ளது.",
              "நான் தமிழ் கற்றுக்கொண்டு இருக்கிறேன்.",
              "எனக்கு புத்தகம் படிக்க விருப்பம்."],
    "Malayalam": ["എനിക്ക് മലയാളം വളരെ ഇഷ്ടമാണ്.",
                  "ഇന്ന് മഴപെയ്യുന്നു.",
                  "ഞാൻ പുസ്തകം വായിക്കുന്നു."],
    "Hindi": ["आज मौसम अच्छा है।",
              "मुझे हिंदी बोलना पसंद है।",
              "मैं किताब पढ़ रहा हूँ।"],
    "Sanskrit": ["अहं ग्रन्थं पठामि।",
                 "अद्य सूर्यः तेजस्वी अस्ति।",
                 "मम नाम रामः।"]
}
# Per-language speaker description; synthesize_tts hands it to the TTS
# pipeline as the "description" forward parameter.
VOICE_STYLE = dict(
    English="An English female voice with a neutral Indian accent.",
    Tamil="A female speaker with a clear Tamil accent.",
    Malayalam="A female speaker with a clear Malayali accent.",
    Hindi="A female speaker with a neutral Hindi accent.",
    Sanskrit="A female speaker reading in classical Sanskrit style.",
)
# ---------------- LOAD MODELS ---------------- #
# Both models are loaded once at import time and shared by every request.
print("Loading Whisper model...")
whisper_model = WhisperModel(MODEL_NAME, device=DEVICE)

# NOTE(review): Indic Parler-TTS is published for use with the separate
# `parler_tts` package; confirm that the generic transformers
# "text-to-speech" pipeline can load this checkpoint and honours the
# `description` forward parameter used in synthesize_tts.
TTS_MODEL_ID = "ai4bharat/indic-parler-tts"
print("Loading IndicParler-TTS via pipeline...")
tts_pipe = pipeline(task="text-to-speech", model=TTS_MODEL_ID)
# ---------------- HELPERS ---------------- #
def get_random_sentence(language_choice):
    """Return one practice sentence for *language_choice*, chosen uniformly at random.

    Raises KeyError when the language has no entry in SENTENCE_BANK.
    """
    candidates = SENTENCE_BANK[language_choice]
    return random.choice(candidates)
def is_script(text, lang_name):
    """Report whether *text* contains a character from *lang_name*'s script.

    A single matching character is enough to return True. Languages with no
    entry in SCRIPT_PATTERNS are accepted unconditionally.
    """
    pattern = SCRIPT_PATTERNS.get(lang_name)
    if pattern is None:
        return True
    return pattern.search(text) is not None
def transliterate_to_hk(text, lang_choice):
    """Romanise *text* from *lang_choice*'s script into Harvard-Kyoto (HK).

    English has no source scheme, so its text is returned unchanged. An
    unsupported language raises KeyError.
    """
    source_scheme = {
        "Tamil": sanscript.TAMIL,
        "Malayalam": sanscript.MALAYALAM,
        "Hindi": sanscript.DEVANAGARI,
        "Sanskrit": sanscript.DEVANAGARI,
        "English": None,
    }[lang_choice]
    if not source_scheme:
        return text
    return transliterate(text, source_scheme, sanscript.HK)
def transcribe_once(audio_path, lang_code, initial_prompt, beam_size, temperature, condition_on_previous_text):
    """Run a single faster-whisper transcription pass and return its text.

    Args:
        audio_path: path of the audio file to transcribe.
        lang_code: Whisper language code (a LANG_CODES value).
        initial_prompt: text used to prime the decoder.
        beam_size: beam width; coerced to int because UI sliders may deliver
            whole numbers as floats.
        temperature: decoding temperature, passed through unchanged.
        condition_on_previous_text: whether each window is conditioned on the
            previous window's output.

    Returns:
        The transcript, with each segment stripped and segments joined by
        single spaces (empty segments are skipped).
    """
    segments, _info = whisper_model.transcribe(
        audio_path,
        language=lang_code,
        task="transcribe",
        initial_prompt=initial_prompt,
        beam_size=int(beam_size),
        temperature=temperature,
        condition_on_previous_text=condition_on_previous_text,
        word_timestamps=False,
    )
    # faster-whisper yields segments lazily (see its README), so decoding
    # happens while they are consumed here. Joining stripped pieces with one
    # space keeps adjacent segments from fusing or leaving doubled spaces.
    pieces = (segment.text.strip() for segment in segments)
    return " ".join(piece for piece in pieces if piece)
def highlight_differences(ref, hyp):
    """Render a word-level diff of *hyp* against *ref* as coloured HTML spans.

    Colour key:
        green                    word matches the reference
        red                      reference word that was replaced
        orange                   hypothesis word (replacement or insertion)
        red + line-through       reference word that was dropped

    Words are HTML-escaped before interpolation, because both texts can come
    from user input or ASR output and are rendered by a gr.HTML component.
    None is treated as an empty string. Spans are joined with single spaces.
    """
    def _span(word, style):
        return f"<span style='{style}'>{html.escape(word)}</span>"

    ref_words = (ref or "").split()
    hyp_words = (hyp or "").split()
    matcher = difflib.SequenceMatcher(None, ref_words, hyp_words)

    pieces = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            pieces += [_span(w, "color:green") for w in ref_words[i1:i2]]
        elif tag == "replace":
            pieces += [_span(w, "color:red") for w in ref_words[i1:i2]]
            pieces += [_span(w, "color:orange") for w in hyp_words[j1:j2]]
        elif tag == "delete":
            pieces += [_span(w, "color:red;text-decoration:line-through") for w in ref_words[i1:i2]]
        elif tag == "insert":
            pieces += [_span(w, "color:orange") for w in hyp_words[j1:j2]]
    return " ".join(pieces)
def _char_clusters(text):
    """Split *text* into approximate user-perceived characters for display.

    Diffing raw code points splits Indic vowel signs and viramas from their
    base consonants, so each ends up in its own <span> and the script no longer
    renders correctly. This helper groups characters instead:
      * combining marks (Unicode category M*) and ZWNJ/ZWJ attach to the
        preceding unit;
      * a letter that follows a virama (canonical combining class 9),
        optionally through ZWNJ/ZWJ, also attaches, keeping conjuncts
        such as क्ष together.
    This approximates, and is not, full UAX #29 grapheme segmentation, which
    the standard library does not provide.
    """
    joiners = "\u200c\u200d"
    clusters = []
    after_virama = False
    for ch in text:
        category = unicodedata.category(ch)
        attach = bool(clusters) and (
            category.startswith("M")
            or ch in joiners
            or (after_virama and category.startswith("L"))
        )
        if attach:
            clusters[-1] += ch
        else:
            clusters.append(ch)
        after_virama = unicodedata.combining(ch) == 9 or (after_virama and ch in joiners)
    return clusters


def char_level_highlight(ref, hyp):
    """Render a character-level diff of *hyp* against *ref* as HTML spans.

    Units are approximate grapheme clusters (see _char_clusters), not raw
    code points. Colour key:
        green                    unit matches the reference
        red + underline          reference unit that was replaced or dropped
        orange                   unit inserted by the hypothesis
    Replacement units from *hyp* are not shown. Units are HTML-escaped and
    concatenated without separators; None is treated as an empty string.
    """
    def _span(unit, style):
        return f"<span style='{style}'>{html.escape(unit)}</span>"

    ref_units = _char_clusters(ref or "")
    hyp_units = _char_clusters(hyp or "")
    matcher = difflib.SequenceMatcher(None, ref_units, hyp_units)

    out = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            out += [_span(u, "color:green") for u in ref_units[i1:i2]]
        elif tag in ("replace", "delete"):
            out += [_span(u, "color:red;text-decoration:underline") for u in ref_units[i1:i2]]
        elif tag == "insert":
            out += [_span(u, "color:orange") for u in hyp_units[j1:j2]]
    return "".join(out)
def synthesize_tts(text, lang_choice):
    """Synthesise speech for *text* using the voice description for *lang_choice*.

    Returns:
        None when *text* is None, empty, or whitespace-only. Otherwise a
        ``(sampling_rate, waveform)`` tuple suitable for
        ``gr.Audio(type="numpy")``.
    """
    if not text or not text.strip():
        return None
    description = VOICE_STYLE.get(lang_choice, "")
    # NOTE(review): confirm the loaded pipeline actually consumes the
    # "description" forward param for this checkpoint.
    result = tts_pipe(text, forward_params={"description": description})
    waveform = result["audio"]
    # The transformers text-to-audio pipeline documents "audio" as
    # (nb_channels, audio_length), whereas gr.Audio expects (samples,) or
    # (samples, channels). Reorient 2-D output so it is not read as a single
    # sample with thousands of channels.
    if getattr(waveform, "ndim", 1) == 2:
        waveform = waveform[0] if waveform.shape[0] == 1 else waveform.T
    return (result["sampling_rate"], waveform)
# ---------------- MAIN ---------------- #
def _comparable_text(text):
    """Normalise *text* for scoring and diffing.

    The following changes are applied:
      * drop every character whose Unicode category is punctuation (P*),
        which covers "." and the Devanagari danda "।" ending the bank
        sentences;
      * case-fold the text;
      * collapse runs of whitespace.
    Punctuation is not pronounced, so it must not count as a word or
    character error.
    """
    kept = "".join(ch for ch in text if not unicodedata.category(ch).startswith("P"))
    return " ".join(kept.casefold().split())


def compare_pronunciation(audio, language_choice, intended_sentence,
                          pass1_beam, pass1_temp, pass1_condition):
    """Transcribe *audio* twice and compare it with *intended_sentence*.

    Args:
        audio: filepath of the recording (gr.Audio with type="filepath"), or None.
        language_choice: key into LANG_CODES and LANG_PRIMERS.
        intended_sentence: the sentence the user was asked to read aloud.
        pass1_beam, pass1_temp, pass1_condition: user-tunable decoding
            settings for pass 1.

    Returns:
        A 10-tuple matching the UI outputs, in this order:
          1. pass-1 text
          2. pass-2 text
          3. Harvard-Kyoto transliteration
          4. WER string
          5. CER string
          6. word-diff HTML
          7. intended-sentence TTS
          8. pass-1 TTS
          9. character-diff HTML
         10. the intended sentence, echoed back so its textbox keeps it
    """
    if audio is None or not (intended_sentence or "").strip():
        # The last output is the sentence textbox. Echo the sentence back
        # instead of "" so a missing recording does not wipe it.
        return ("No audio or intended sentence.", "", "", "", "", "",
                None, None, "", intended_sentence or "")

    lang_code = LANG_CODES[language_choice]
    primer_weak, primer_strong = LANG_PRIMERS[language_choice]

    # Pass 1: lightly primed, user-tuned decoding, to capture what was said.
    actual_text = transcribe_once(audio, lang_code, primer_weak,
                                  pass1_beam, pass1_temp, pass1_condition)

    # Pass 2: fixed settings (beam 5, temperature 0), prompted with the target.
    strict_prompt = f"{primer_strong}\nTarget: {intended_sentence}"
    corrected_text = transcribe_once(audio, lang_code, strict_prompt,
                                     beam_size=5, temperature=0.0,
                                     condition_on_previous_text=False)

    # Score and diff on punctuation-free, case-folded text so that unspoken
    # punctuation and capitalisation are not reported as mispronunciations.
    ref_cmp = _comparable_text(intended_sentence) or intended_sentence.strip()
    hyp_cmp = _comparable_text(actual_text)
    if hyp_cmp:
        wer_val = jiwer.wer(ref_cmp, hyp_cmp)
        cer_val = jiwer.cer(ref_cmp, hyp_cmp)
    else:
        # Nothing recognised: every reference word/character is a deletion.
        wer_val = cer_val = 1.0

    if is_script(actual_text, language_choice):
        hk_translit = transliterate_to_hk(actual_text, language_choice)
    else:
        hk_translit = f"[Script mismatch: expected {language_choice}]"

    diff_html = highlight_differences(ref_cmp, hyp_cmp)
    char_html = char_level_highlight(ref_cmp, hyp_cmp)

    # TTS for the intended sentence and for what pass 1 heard.
    tts_intended = synthesize_tts(intended_sentence, language_choice)
    tts_pass1 = synthesize_tts(actual_text, language_choice)

    return (actual_text, corrected_text, hk_translit,
            f"{wer_val:.2f}", f"{cer_val:.2f}", diff_html,
            tts_intended, tts_pass1, char_html, intended_sentence)
# ---------------- UI ---------------- #
with gr.Blocks() as demo:
    gr.Markdown("## 🎙 Pronunciation Comparator + IndicParler‑TTS + Error Highlighting")
    with gr.Row():
        lang_choice = gr.Dropdown(choices=list(LANG_CODES.keys()), value="Malayalam", label="Language")
        gen_btn = gr.Button("🎲 Generate Sentence")
    intended_display = gr.Textbox(label="Generated Sentence (Read aloud)", interactive=False)

    with gr.Row():
        audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath")
        pass1_beam = gr.Slider(1, 10, value=8, step=1, label="Pass 1 Beam Size")
        pass1_temp = gr.Slider(0.0, 1.0, value=0.4, step=0.1, label="Pass 1 Temperature")
        pass1_condition = gr.Checkbox(value=True, label="Pass 1: Condition on previous text")

    with gr.Row():
        pass1_out = gr.Textbox(label="Pass 1: What You Actually Said")
        pass2_out = gr.Textbox(label="Pass 2: Target-Biased Output")
        hk_out = gr.Textbox(label="Harvard-Kyoto Transliteration (Pass 1)")

    with gr.Row():
        wer_out = gr.Textbox(label="Word Error Rate")
        cer_out = gr.Textbox(label="Character Error Rate")

    diff_html_box = gr.HTML(label="Word Differences Highlighted")
    char_html_box = gr.HTML(label="Character-Level Highlighting (mispronounced = red underline)")

    with gr.Row():
        intended_tts_audio = gr.Audio(label="TTS - Intended Sentence", type="numpy")
        pass1_tts_audio = gr.Audio(label="TTS - Pass1 Output", type="numpy")

    gen_btn.click(fn=get_random_sentence, inputs=[lang_choice], outputs=[intended_display])
    # Keep the sentence in the selected language. Without this, switching the
    # dropdown left a sentence in the previous language in the box, and
    # Analyze then compared text written in two different scripts.
    lang_choice.change(fn=get_random_sentence, inputs=[lang_choice], outputs=[intended_display])
    # Seed a sentence on page load so the read-aloud box is never empty.
    demo.load(fn=get_random_sentence, inputs=[lang_choice], outputs=[intended_display])

    submit_btn = gr.Button("Analyze Pronunciation")
    submit_btn.click(
        fn=compare_pronunciation,
        inputs=[audio_input, lang_choice, intended_display, pass1_beam, pass1_temp, pass1_condition],
        outputs=[
            pass1_out, pass2_out, hk_out, wer_out, cer_out,
            diff_html_box, intended_tts_audio, pass1_tts_audio,
            char_html_box, intended_display
        ]
    )

if __name__ == "__main__":
    demo.launch()