# Hugging Face Space page header (scraped): "Spaces: Running on Zero"
import difflib
import html
import random
import re

import gradio as gr
import jiwer
import spaces
import torch
from indic_transliteration import sanscript
from indic_transliteration.sanscript import transliterate
from transformers import WhisperForConditionalGeneration, WhisperProcessor
# ---------------- CONFIG ---------------- #
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face model id used for each supported language.
MODEL_CONFIGS = {
    "English": "openai/whisper-large-v2",
    "Tamil": "vasista22/whisper-tamil-large-v2",
    "Malayalam": "thennal/whisper-medium-ml",
}

# Whisper language codes used to build the decoder prompt.
LANG_CODES = {
    "English": "en",
    "Tamil": "ta",
    "Malayalam": "ml",
}

# Per-language (weak, strong) text primers, unpacked by compare_pronunciation.
LANG_PRIMERS = {
    "English": ("The transcript should be in English only.",
                "Write only in English without translation. Example: This is an English sentence."),
    "Tamil": ("நகல் தமிழ் எழுத்துக்களில் மட்டும் இருக்க வேண்டும்.",
              "தமிழ் எழுத்துக்களில் மட்டும் எழுதவும், மொழிபெயர்ப்பு செய்யக்கூடாது. உதாரணம்: இது ஒரு தமிழ் வாக்கியம்."),
    "Malayalam": ("ട്രാൻസ്ഖ്രിപ്റ്റ് മലയാള ലിപിയിൽ ആയിരിക്കണം.",
                  "മലയാള ലിപിയിൽ മാത്രം എഴുതുക, വിവർത്തനം ചെയ്യരുത്. ഉദാഹരണം: ഇതൊരു മലയാള വാക്യമാണ്. എനിക്ക് മലയാളം അറിയാം."),
}

# Each pattern matches one character of the language's Unicode block.
# Ranges are spelled as \u escapes. The Tamil block (U+0B80-U+0BFF) both
# starts and ends on unassigned code points, which tools tend to strip from
# a literal range. That is how this pattern degraded to "[-]", which only
# matched a hyphen.
SCRIPT_PATTERNS = {
    "Tamil": re.compile(r"[\u0B80-\u0BFF]"),
    "Malayalam": re.compile(r"[\u0D00-\u0D7F]"),
    "English": re.compile(r"[A-Za-z]"),
}

# Practice sentences offered by the "Generate Sentence" button.
SENTENCE_BANK = {
    "English": [
        "The sun sets over the horizon.",
        "Learning languages is fun.",
        "I like to drink coffee in the morning.",
        "Technology helps us communicate better.",
        "Reading books expands our knowledge.",
    ],
    "Tamil": [
        "இன்று நல்ல வானிலை உள்ளது.",
        "நான் தமிழ் கற்றுக்கொண்டு இருக்கிறேன்.",
        "எனக்கு புத்தகம் படிக்க விருப்பம்.",
        "தமிழ் மொழி மிகவும் அழகானது.",
        "நான் தினமும் பள்ளிக்கு செல்கிறேன்.",
    ],
    "Malayalam": [
        "എനിക്ക് മലയാളം വളരെ ഇഷ്ടമാണ്.",
        "ഇന്ന് മഴപെയ്യുന്നു.",
        "ഞാൻ പുസ്തകം വായിക്കുന്നു.",
        "കേരളം എന്റെ സ്വന്തം നാടാണ്.",
        "ഞാൻ മലയാളം പഠിക്കുന്നു.",
    ],
}

# Currently loaded model state (populated lazily by load_model).
current_model = None
current_processor = None
current_language = None
def clear_gpu_memory():
    """Ask PyTorch to release cached, unused CUDA memory; a no-op without CUDA."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
def load_model(language_choice):
    """Return ``(model, processor)`` for *language_choice*, loading on demand.

    Only one model is kept resident. When a different language is requested,
    the previously loaded model is released before the new one is loaded.

    Args:
        language_choice: A key of ``MODEL_CONFIGS``.

    Returns:
        Tuple ``(WhisperForConditionalGeneration, WhisperProcessor)``.

    Raises:
        KeyError: If *language_choice* is not in ``MODEL_CONFIGS``. This is
            checked before anything is unloaded.
        Exception: Whatever ``from_pretrained`` raises when even the fallback
            load fails. The module globals are then left as "nothing loaded".
    """
    global current_model, current_processor, current_language
    if current_language == language_choice and current_model is not None:
        return current_model, current_processor

    model_id = MODEL_CONFIGS[language_choice]

    if current_model is not None:
        print(f"Unloading previous model for {current_language}")
        # Rebind instead of `del`: the globals must stay defined, otherwise a
        # failed load below would make every later call raise NameError.
        current_model = None
        current_processor = None
        current_language = None
        clear_gpu_memory()

    print(f"Loading {language_choice} model: {model_id}")
    processor = WhisperProcessor.from_pretrained(model_id)
    # float16 saves GPU memory. On CPU keep float32, since many CPU kernels
    # lack half-precision implementations.
    dtype = torch.float16 if DEVICE == "cuda" else torch.float32
    try:
        model = WhisperForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=dtype,
            device_map="auto",
        )
        print(f"{language_choice} model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        # Fallback: plain default load (full precision, no device map).
        model = WhisperForConditionalGeneration.from_pretrained(model_id)

    current_model, current_processor, current_language = model, processor, language_choice
    return current_model, current_processor
# ---------------- HELPERS ---------------- #
def get_random_sentence(language_choice):
    """Pick one practice sentence for *language_choice* from SENTENCE_BANK."""
    candidates = SENTENCE_BANK[language_choice]
    return random.choice(candidates)
def is_script(text, lang_name):
    """Report whether *text* contains at least one character of *lang_name*'s script.

    Languages without an entry in SCRIPT_PATTERNS are treated as matching.
    """
    script_re = SCRIPT_PATTERNS.get(lang_name)
    if not script_re:
        return True
    return script_re.search(text) is not None
def transliterate_to_hk(text, lang_choice):
    """Transliterate Tamil/Malayalam *text* to Harvard-Kyoto; English passes through.

    Raises KeyError for a language that has no entry in the table below.
    """
    source_scheme = {
        "Tamil": sanscript.TAMIL,
        "Malayalam": sanscript.MALAYALAM,
        "English": None,
    }[lang_choice]
    if not source_scheme:
        return text
    return transliterate(text, source_scheme, sanscript.HK)
def transcribe_once(audio_path, language_choice, initial_prompt, beam_size, temperature, condition_on_previous_text):
    """Transcribe one audio file with the Whisper model for *language_choice*.

    Args:
        audio_path: Path to an audio file that ``librosa.load`` can read.
        language_choice: A key of ``LANG_CODES`` / ``MODEL_CONFIGS``.
        initial_prompt: Text primer from the caller.
            NOTE(review): not forwarded to ``generate``, so it currently has
            no effect on the output.
        beam_size: Requested beam width. Coerced to ``int`` and clamped to 1..4
            to bound memory use.
        temperature: Values ``> 0`` enable sampling at that temperature;
            otherwise decoding is deterministic.
        condition_on_previous_text: Accepted for interface compatibility;
            not used.

    Returns:
        The stripped transcription. If anything fails, returns a string
        starting with ``"Error during transcription:"``; callers test for
        the ``"Error"`` prefix rather than catching exceptions.
    """
    try:
        model, processor = load_model(language_choice)
        lang_code = LANG_CODES[language_choice]

        import librosa
        # Resample to 16 kHz, the rate passed to the Whisper feature extractor.
        audio, _ = librosa.load(audio_path, sr=16000)
        input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features
        # Match the model's device AND dtype. The processor yields float32
        # features while the model may be float16, and the CPU-fallback model
        # must not receive CUDA tensors.
        input_features = input_features.to(model.device, dtype=model.dtype)

        forced_decoder_ids = processor.get_decoder_prompt_ids(language=lang_code, task="transcribe")
        do_sample = temperature > 0
        with torch.no_grad():
            predicted_ids = model.generate(
                input_features,
                forced_decoder_ids=forced_decoder_ids,
                max_length=200,  # kept short to save memory
                # The UI slider may deliver a float; generate needs an int.
                num_beams=max(1, min(int(beam_size), 4)),
                temperature=temperature if do_sample else None,
                do_sample=do_sample,
                # NOTE(review): this forbids any repeated token bigram, which can
                # distort transcripts that legitimately repeat a word/sub-word pair.
                no_repeat_ngram_size=2,
                early_stopping=True,
            )
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        return transcription.strip()
    except Exception as e:
        print(f"Transcription error: {e}")
        return f"Error during transcription: {str(e)}"
    finally:
        # Release cached GPU memory whether or not inference succeeded.
        clear_gpu_memory()
def highlight_differences(ref, hyp):
    """Render a word-level diff of *ref* (expected) vs *hyp* (recognised) as HTML.

    Words are whitespace-split and aligned with ``difflib.SequenceMatcher``:

    * matching words are green in both rows;
    * substituted words are red/underlined (expected) and orange/bold (actual);
    * missing reference words are red with a strikethrough;
    * extra hypothesis words are orange/bold with a ``+`` prefix.

    The inputs come from the user (recognised speech and the sentence field),
    so every word is HTML-escaped before it is embedded.

    Returns:
        An HTML fragment with "Expected" and "You said" rows plus a legend.
    """
    ref_words, hyp_words = ref.strip().split(), hyp.strip().split()
    sm = difflib.SequenceMatcher(None, ref_words, hyp_words)

    correct_style = "background-color:#d4edda; color:#155724; padding:2px 4px; margin:1px; border-radius:3px;"
    expected_wrong_style = "background-color:#f8d7da; color:#721c24; padding:2px 4px; margin:1px; border-radius:3px; text-decoration:underline;"
    actual_wrong_style = "background-color:#fff3cd; color:#856404; padding:2px 4px; margin:1px; border-radius:3px; font-weight:bold;"
    missing_style = "background-color:#f8d7da; color:#721c24; padding:2px 4px; margin:1px; border-radius:3px; text-decoration:line-through;"

    def spans(words, style, prefix=""):
        # One styled, escaped <span> per word.
        return [f"<span style='{style}'>{prefix}{html.escape(w)}</span>" for w in words]

    expected_html = []
    actual_html = []
    for tag, i1, i2, j1, j2 in sm.get_opcodes():
        if tag == "equal":
            expected_html.extend(spans(ref_words[i1:i2], correct_style))
            actual_html.extend(spans(hyp_words[j1:j2], correct_style))
        elif tag == "replace":
            expected_html.extend(spans(ref_words[i1:i2], expected_wrong_style))
            actual_html.extend(spans(hyp_words[j1:j2], actual_wrong_style))
        elif tag == "delete":
            expected_html.extend(spans(ref_words[i1:i2], missing_style))
        elif tag == "insert":
            actual_html.extend(spans(hyp_words[j1:j2], actual_wrong_style, prefix="+"))

    comparison_html = f"""
    <div style='font-family: monospace; line-height: 2;'>
        <div style='margin-bottom: 15px;'>
            <strong>📝 Expected:</strong><br>
            <div style='padding: 10px; background-color: #f8f9fa; border-radius: 5px; margin-top: 5px;'>
                {" ".join(expected_html)}
            </div>
        </div>
        <div style='margin-bottom: 15px;'>
            <strong>🎤 You said:</strong><br>
            <div style='padding: 10px; background-color: #f8f9fa; border-radius: 5px; margin-top: 5px;'>
                {" ".join(actual_html)}
            </div>
        </div>
        <div style='font-size: 12px; color: #6c757d; margin-top: 10px;'>
            <span style='background-color:#d4edda; padding:2px 4px; border-radius:3px;'>✓ Correct</span>
            <span style='background-color:#f8d7da; padding:2px 4px; border-radius:3px; margin-left:5px;'>✗ Expected</span>
            <span style='background-color:#fff3cd; padding:2px 4px; border-radius:3px; margin-left:5px;'>+ Extra/Wrong</span>
        </div>
    </div>
    """
    return comparison_html
def char_level_highlight(ref, hyp):
    """Render a character-level diff of *ref* (expected) vs *hyp* (recognised) as HTML.

    Characters are aligned with ``difflib.SequenceMatcher``:

    * matches are green;
    * substitutions are red/underlined (expected) and orange/bold (actual);
    * missing characters are red with a strikethrough;
    * extra characters are orange/bold, with no prefix.

    Each character is HTML-escaped, because both strings are user-derived.

    Returns:
        An HTML fragment with "Expected" and "You said" rows and a caption.
    """
    sm = difflib.SequenceMatcher(None, ref, hyp)

    correct_style = "background-color:#d4edda; color:#155724;"
    expected_wrong_style = "background-color:#f8d7da; color:#721c24; text-decoration:underline;"
    actual_wrong_style = "background-color:#fff3cd; color:#856404; font-weight:bold;"
    missing_style = "background-color:#f8d7da; color:#721c24; text-decoration:line-through;"

    def spans(chars, style):
        # One styled, escaped <span> per character.
        return [f"<span style='{style}'>{html.escape(c)}</span>" for c in chars]

    expected_chars = []
    actual_chars = []
    for tag, i1, i2, j1, j2 in sm.get_opcodes():
        if tag == "equal":
            expected_chars.extend(spans(ref[i1:i2], correct_style))
            actual_chars.extend(spans(hyp[j1:j2], correct_style))
        elif tag == "replace":
            expected_chars.extend(spans(ref[i1:i2], expected_wrong_style))
            actual_chars.extend(spans(hyp[j1:j2], actual_wrong_style))
        elif tag == "delete":
            expected_chars.extend(spans(ref[i1:i2], missing_style))
        elif tag == "insert":
            actual_chars.extend(spans(hyp[j1:j2], actual_wrong_style))

    char_comparison_html = f"""
    <div style='font-family: monospace; line-height: 2; font-size: 16px;'>
        <div style='margin-bottom: 15px;'>
            <strong>📝 Expected (character-level):</strong><br>
            <div style='padding: 10px; background-color: #f8f9fa; border-radius: 5px; margin-top: 5px; word-break: break-all; letter-spacing: 1px;'>
                {"".join(expected_chars)}
            </div>
        </div>
        <div style='margin-bottom: 15px;'>
            <strong>🎤 You said (character-level):</strong><br>
            <div style='padding: 10px; background-color: #f8f9fa; border-radius: 5px; margin-top: 5px; word-break: break-all; letter-spacing: 1px;'>
                {"".join(actual_chars)}
            </div>
        </div>
        <div style='font-size: 12px; color: #6c757d; margin-top: 10px;'>
            Character-level analysis helps identify pronunciation issues within words
        </div>
    </div>
    """
    return char_comparison_html
# ---------------- MAIN ---------------- #
# `spaces` is imported at the top of the file for ZeroGPU hardware. The GPU
# is attached only for the duration of functions decorated with spaces.GPU,
# so the Gradio handler that runs inference is decorated.
@spaces.GPU
def compare_pronunciation(audio, language_choice, intended_sentence,
                          pass1_beam, pass1_temp, pass1_condition):
    """Gradio handler: transcribe *audio* twice and score it against *intended_sentence*.

    Pass 1 uses the user's decoding settings. Pass 2 uses fixed settings
    (beam 3, temperature 0) and a prompt built from the strong primer plus the
    target sentence.
    NOTE(review): transcribe_once does not forward its prompt to the model,
    so pass 2 currently differs from pass 1 only in decoding settings.

    Returns:
        A 9-tuple in the order of the UI outputs: (pass-1 text, pass-2 text,
        Harvard-Kyoto transliteration, WER, CER, word-diff HTML,
        char-diff HTML, intended sentence, status message). On failure the
        first and last fields carry messages and the others are empty strings.
    """
    if audio is None or not (intended_sentence or "").strip():
        return ("No audio or intended sentence.", "", "", "", "", "", "", "", "❌ Please provide audio and sentence")
    try:
        primer_weak, primer_strong = LANG_PRIMERS[language_choice]

        # Pass 1: raw transcription with user-configured decoding parameters.
        actual_text = transcribe_once(audio, language_choice, primer_weak,
                                      pass1_beam, pass1_temp, pass1_condition)
        if actual_text.startswith("Error"):
            return (actual_text, "", "", "", "", "", "", "", "❌ Transcription failed")

        # Pass 2: fixed decoding parameters with a target-sentence prompt.
        strict_prompt = f"{primer_strong}\nTarget: {intended_sentence}"
        corrected_text = transcribe_once(audio, language_choice, strict_prompt,
                                         beam_size=3, temperature=0.0, condition_on_previous_text=False)

        # Scoring failures fall back to the worst score instead of aborting.
        try:
            wer_val = jiwer.wer(intended_sentence, actual_text)
            cer_val = jiwer.cer(intended_sentence, actual_text)
        except Exception as e:
            print(f"WER/CER computation failed: {e}")
            wer_val = 1.0
            cer_val = 1.0

        # Transliterate pass 1 only if it is actually in the expected script.
        if is_script(actual_text, language_choice):
            hk_translit = transliterate_to_hk(actual_text, language_choice)
        else:
            hk_translit = f"[Script mismatch: expected {language_choice}]"

        diff_html = highlight_differences(intended_sentence, actual_text)
        char_html = char_level_highlight(intended_sentence, actual_text)

        status_msg = f"✅ Analysis complete! WER: {wer_val:.2f}"
        return (actual_text, corrected_text, hk_translit, f"{wer_val:.2f}", f"{cer_val:.2f}",
                diff_html, char_html, intended_sentence, status_msg)
    except Exception as e:
        error_msg = f"❌ Error: {str(e)}"
        clear_gpu_memory()
        return ("Error occurred", "", "", "", "", "", "", "", error_msg)
# ---------------- UI ---------------- #
with gr.Blocks(title="Pronunciation Comparator") as demo:
    gr.Markdown("## 🎙 Pronunciation Comparator - English, Tamil & Malayalam")
    gr.Markdown("Practice pronunciation with specialized Whisper models for each language!")
    gr.Markdown("⚠️ **Note**: Models load on-demand to optimize memory usage. First use may take longer.")

    with gr.Row():
        lang_choice = gr.Dropdown(choices=list(LANG_CODES.keys()), value="Malayalam", label="Language")
        gen_btn = gr.Button("🎲 Generate Sentence")

    intended_display = gr.Textbox(label="Generated Sentence (Read aloud)", interactive=False)

    # Status indicator
    status_display = gr.Textbox(label="Status", interactive=False, value="🟢 Ready")

    with gr.Row():
        audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Record your pronunciation")
        with gr.Column():
            gr.Markdown("### ⚙️ Transcription Parameters")
            with gr.Row():
                pass1_beam = gr.Slider(1, 4, value=2, step=1, label="Beam Size (lower = faster)")
                pass1_temp = gr.Slider(0.0, 0.8, value=0.2, step=0.1, label="Temperature")
            pass1_condition = gr.Checkbox(value=False, label="Condition on previous text")

    submit_btn = gr.Button("🔍 Analyze Pronunciation", variant="primary", size="lg")

    gr.Markdown("### 📊 Analysis Results")
    with gr.Row():
        pass1_out = gr.Textbox(label="Pass 1: What You Actually Said")
        pass2_out = gr.Textbox(label="Pass 2: Target-Biased Output")
    with gr.Row():
        hk_out = gr.Textbox(label="Harvard-Kyoto Transliteration (Pass 1)")
        wer_out = gr.Textbox(label="Word Error Rate (WER)")
        cer_out = gr.Textbox(label="Character Error Rate (CER)")

    gr.Markdown("### 🎯 Visual Comparison")
    gr.Markdown("Compare your pronunciation with the expected text to identify areas for improvement")
    diff_html_box = gr.HTML(label="Word-Level Comparison")
    char_html_box = gr.HTML(label="Character-Level Analysis")

    # Event handlers
    gen_btn.click(fn=get_random_sentence, inputs=[lang_choice], outputs=[intended_display])
    # Keep the practice sentence in the selected language. A sentence left
    # over from another language would be scored with the wrong model.
    lang_choice.change(fn=get_random_sentence, inputs=[lang_choice], outputs=[intended_display])
    # Show a sentence on page load so "Analyze" works without pressing
    # "Generate" first (the sentence box is not user-editable).
    demo.load(fn=get_random_sentence, inputs=[lang_choice], outputs=[intended_display])
    submit_btn.click(
        fn=compare_pronunciation,
        inputs=[audio_input, lang_choice, intended_display, pass1_beam, pass1_temp, pass1_condition],
        outputs=[
            pass1_out, pass2_out, hk_out, wer_out, cer_out,
            diff_html_box, char_html_box, intended_display, status_display,
        ],
    )

if __name__ == "__main__":
    demo.launch()