# Hugging Face Spaces status: Running on Zero
import contextlib
import difflib
import html
import random
import re
import unicodedata
import warnings

import gradio as gr
import jiwer
import librosa
import numpy as np
import torch
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, pipeline
# Optional transliteration: romanized views are only offered when the
# indic_transliteration package is importable.
try:
    from indic_transliteration import sanscript
    from indic_transliteration.sanscript import transliterate
    INDIC_OK = True
except ImportError:
    # Only a missing package should disable the feature; a bare `except:`
    # would also swallow KeyboardInterrupt/SystemExit and unrelated bugs.
    INDIC_OK = False
    print("⚠️ indic_transliteration not available. Transliteration features disabled.")
# Optional HF Spaces GPU decorator: spaces.GPU when the `spaces` package is
# available, otherwise a pass-through with the same calling forms.
try:
    import spaces
    GPU_DECORATOR = spaces.GPU
except ImportError:
    def GPU_DECORATOR(func=None, **_options):
        """No-op stand-in for spaces.GPU.

        Supports both bare use (``@GPU_DECORATOR``) and the factory form
        (``@GPU_DECORATOR(duration=...)``); options are ignored.
        """
        if func is None:
            return lambda f: f
        return func
warnings.filterwarnings("ignore")
# ---------------- CONFIG ---------------- #
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Value passed as `device=` to transformers.pipeline: GPU 0, or -1 for CPU.
DEVICE_INDEX = 0 if DEVICE == "cuda" else -1
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32


def amp_ctx():
    """Return a fresh mixed-precision context manager for inference.

    On CUDA this is ``torch.autocast(device_type="cuda", dtype=DTYPE)``, the
    device-agnostic replacement for the deprecated
    ``torch.cuda.amp.autocast``. On CPU it is a no-op
    ``contextlib.nullcontext()``.
    """
    if DEVICE == "cuda":
        return torch.autocast(device_type="cuda", dtype=DTYPE)
    return contextlib.nullcontext()


print(f"🔧 Using device: {DEVICE}")
# Display name -> Whisper language code; the keys also populate the UI dropdown.
LANG_CODES = {
    "English": "en",
    "Tamil": "ta",
    "Malayalam": "ml",
    "Hindi": "hi"
}
# Primary: IndicWhisper
INDICWHISPER_MODEL = "parthiv11/indic_whisper_nodcil"
# Specialized models for better accuracy
SPECIALIZED_MODELS = {
    "English": "openai/whisper-base.en",
    "Tamil": "vasista22/whisper-tamil-large-v2",
    "Malayalam": "thennal/whisper-medium-ml",
    "Hindi": "openai/whisper-large-v2"  # Using general model for Hindi
}
# One character from the language's Unicode block is enough for a match.
# Ranges are spelled as escapes so they survive copy/paste and re-encoding:
# the Tamil class had been reduced to "[-]", which only matches a hyphen.
SCRIPT_PATTERNS = {
    "Tamil": re.compile(r"[\u0B80-\u0BFF]"),      # Tamil block
    "Malayalam": re.compile(r"[\u0D00-\u0D7F]"),  # Malayalam block
    "Hindi": re.compile(r"[\u0900-\u097F]"),      # Devanagari block
    "English": re.compile(r"[A-Za-z]")
}
# Transliteration mappings: language -> indic_transliteration source scheme.
# When the optional package is missing, `sanscript` is undefined, so every
# language maps to None instead; transliterate_text() then returns its input
# unchanged rather than the module failing at import time with a NameError.
if INDIC_OK:
    TRANSLITERATION_SCRIPTS = {
        "Tamil": sanscript.TAMIL,
        "Malayalam": sanscript.MALAYALAM,
        "Hindi": sanscript.DEVANAGARI,
        "English": None
    }
else:
    TRANSLITERATION_SCRIPTS = dict.fromkeys(["Tamil", "Malayalam", "Hindi", "English"])
# Practice sentences offered to the learner, five per supported language.
SENTENCE_BANK = dict(
    English=[
        "The sun sets over the horizon.",
        "Learning languages is fun and rewarding.",
        "I like to drink coffee in the morning.",
        "Technology helps us connect with others.",
        "Reading books expands our knowledge.",
    ],
    Tamil=[
        "இன்று நல்ல வானிலை உள்ளது.",
        "நான் தமிழ் கற்றுக்கொண்டு இருக்கிறேன்.",
        "எனக்கு புத்தகம் படிக்க விருப்பம்.",
        "காலையில் காபி குடிக்க பிடிக்கும்.",
        "நண்பர்களுடன் பேசுவது மகிழ்ச்சி.",
    ],
    Malayalam=[
        "എനിക്ക് മലയാളം വളരെ ഇഷ്ടമാണ്.",
        "ഇന്ന് മഴപെയ്യുന്നു.",
        "ഞാൻ പുസ്തകം വായിക്കുന്നു.",
        "കാലയിൽ ചായ കുടിക്കാൻ ഇഷ്ടമാണ്.",
        "സുഹൃത്തുക്കളോടു സംസാരിക്കുന്നത് സന്തോഷമാണ്.",
    ],
    Hindi=[
        "आज मौसम अच्छा है।",
        "मुझे हिंदी बोलना पसंद है।",
        "मैं किताब पढ़ रहा हूँ।",
        "सुबह चाय पीना अच्छा लगता है।",
        "दोस्तों के साथ बात करना खुशी देता है।",
    ],
)
# Model cache, filled lazily: the primary ASR pipeline, and per-language
# {"processor": ..., "model": ...} entries for the specialized models.
primary_pipeline, specialized_models = None, {}
# ---------------- HELPERS ---------------- # | |
def get_random_sentence(language_choice):
    """Pick one practice sentence at random for the chosen language.

    Raises KeyError if *language_choice* is not a key of SENTENCE_BANK.
    """
    candidates = SENTENCE_BANK[language_choice]
    return random.choice(candidates)
def is_correct_script(text, lang_name):
    """Report whether *text* contains at least one character of the language's script.

    Blank (whitespace-only) text is never accepted. Languages without an
    entry in SCRIPT_PATTERNS are always accepted.
    """
    if text.strip():
        script_re = SCRIPT_PATTERNS.get(lang_name)
        if not script_re:
            return True
        return script_re.search(text) is not None
    return False
def transliterate_text(text, lang_choice, to_romanized=True):
    """Convert *text* between the language's native script and Harvard-Kyoto.

    With ``to_romanized=True`` native script becomes HK romanization;
    otherwise HK becomes native script. The input is returned unchanged when
    transliteration is unavailable, the text is blank, the language has no
    mapped script, or the conversion raises (the error is printed).
    """
    if not (INDIC_OK and text.strip()):
        return text
    native = TRANSLITERATION_SCRIPTS.get(lang_choice)
    if not native:
        return text
    try:
        src, dst = (native, sanscript.HK) if to_romanized else (sanscript.HK, native)
        return transliterate(text, src, dst)
    except Exception as e:
        print(f"⚠️ Transliteration failed: {e}")
        return text
def preprocess_audio(audio_path, target_sr=16000):
    """Load, peak-normalise and silence-trim an audio file for ASR.

    The file is resampled to *target_sr* mono float32 and scaled so its peak
    magnitude is 1. Leading and trailing audio more than 20 dB below the peak
    is trimmed.

    Returns ``(samples, target_sr)``, or ``(None, None)`` when the file is
    empty, shorter than 0.1 s after trimming, or any step raises (the error
    is printed).
    """
    failure = (None, None)
    try:
        samples, _ = librosa.load(audio_path, sr=target_sr, mono=True)
        if samples is None or len(samples) == 0:
            return failure
        samples = samples.astype(np.float32)
        peak = np.max(np.abs(samples))
        if peak > 0:
            samples = samples / peak
        samples, _ = librosa.effects.trim(samples, top_db=20)
        min_samples = int(target_sr * 0.1)
        return failure if len(samples) < min_samples else (samples, target_sr)
    except Exception as e:
        print(f"⚠️ Audio preprocessing failed: {e}")
        return failure
# ---------------- MODEL LOADERS ---------------- # | |
def load_primary_model():
    """Return the shared primary ASR pipeline, building it on the first call.

    INDICWHISPER_MODEL is tried first. If building it raises,
    openai/whisper-large-v2 is loaded instead. The result is cached in the
    module-level ``primary_pipeline``. An error while loading the fallback
    propagates to the caller.
    """
    global primary_pipeline
    if primary_pipeline is None:
        try:
            print(f"🔄 Loading primary model: {INDICWHISPER_MODEL}")
            # NOTE(review): trust_remote_code=True lets the model repo run its
            # own Python code; Whisper checkpoints normally load without it —
            # confirm and consider dropping it.
            primary_pipeline = pipeline(
                "automatic-speech-recognition",
                model=INDICWHISPER_MODEL,
                device=DEVICE_INDEX,
                torch_dtype=DTYPE,
                trust_remote_code=True
            )
            print("✅ Primary model loaded successfully!")
        except Exception as e:
            print(f"⚠️ Primary model failed, using fallback: {e}")
            # Fallback to base Whisper
            primary_pipeline = pipeline(
                "automatic-speech-recognition",
                model="openai/whisper-large-v2",
                device=DEVICE_INDEX,
                torch_dtype=DTYPE
            )
            print("✅ Fallback model loaded!")
    return primary_pipeline
def load_specialized_model(language):
    """Load and cache the language-specific Whisper processor and model.

    Args:
        language: A key of SPECIALIZED_MODELS (KeyError otherwise).

    Returns:
        ``{"processor": ..., "model": ...}`` cached in ``specialized_models``,
        or None if loading failed. Failures are not cached, so the next call
        retries.
    """
    if language in specialized_models:
        return specialized_models[language]
    model_name = SPECIALIZED_MODELS[language]
    print(f"🔄 Loading specialized {language} model: {model_name}")
    try:
        processor = AutoProcessor.from_pretrained(model_name)
        # Place the model with a single explicit .to(DEVICE). Combining
        # device_map="auto" (accelerate dispatch, which also requires the
        # accelerate package) with a later .to() is redundant and can fail for
        # dispatched/offloaded models.
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_name,
            torch_dtype=DTYPE,
        ).to(DEVICE)
        specialized_models[language] = {
            "processor": processor,
            "model": model
        }
        print(f"✅ Specialized {language} model loaded!")
        return specialized_models[language]
    except Exception as e:
        print(f"❌ Failed to load specialized {language} model: {e}")
        return None
# ---------------- TRANSCRIPTION ---------------- # | |
def transcribe_with_primary(audio_path, language):
    """Transcribe *audio_path* with the primary (IndicWhisper or fallback) pipeline.

    Args:
        audio_path: Path to the recording, passed straight to the pipeline.
        language: A key of LANG_CODES; selects the Whisper decoding language.

    Returns:
        The stripped transcription. On any failure, a message starting with
        "Primary transcription error"; callers detect failures by that
        "Primary" prefix.
    """
    try:
        asr = load_primary_model()
        lang_code = LANG_CODES[language]
        # Ask for the language/task per call instead of mutating the shared
        # model.config.forced_decoder_ids, which newer transformers releases
        # deprecate or ignore for Whisper generation.
        forcing = {"language": lang_code, "task": "transcribe"}
        with amp_ctx():
            try:
                result = asr(audio_path, generate_kwargs=forcing)
            except (ValueError, TypeError) as e:
                # Checkpoints with an outdated generation config reject the
                # language argument; keep the original best-effort behaviour
                # and transcribe without forcing.
                print(f"⚠️ Language forcing failed: {e}")
                result = asr(audio_path)
        if isinstance(result, dict):
            return result.get("text", "").strip()
        return str(result).strip()
    except Exception as e:
        return f"Primary transcription error: {str(e)}"
def transcribe_with_specialized(audio_path, language):
    """Transcribe *audio_path* with the language's specialized Whisper model.

    The audio is preprocessed with preprocess_audio, decoded with beam
    search (3 beams, max_length 200), and, for non-English languages, the
    decoder is forced to the target language.

    Returns:
        The stripped transcription. Every failure message begins with
        "Specialized"; compare_pronunciation relies on that prefix to tell
        failures apart from real transcriptions.
    """
    try:
        components = load_specialized_model(language)
        if not components:
            return "Specialized model not available"
        processor = components["processor"]
        model = components["model"]
        audio, sr = preprocess_audio(audio_path)
        if audio is None:
            # Keep the "Specialized" prefix: without it the caller would
            # accept this message as the user's speech and score it.
            return "Specialized audio preprocessing failed"
        inputs = processor(
            audio,
            sampling_rate=sr,
            return_tensors="pt"
        )
        # Match the model's dtype (float16 on CUDA) instead of relying on
        # autocast to reconcile float32 features with half-precision weights.
        input_features = inputs.input_features.to(DEVICE, dtype=model.dtype)
        gen_kwargs = {
            "inputs": input_features,
            "max_length": 200,
            "num_beams": 3,
            "do_sample": False
        }
        # Language forcing for non-English (the English model is English-only).
        if language != "English":
            try:
                forced_ids = processor.tokenizer.get_decoder_prompt_ids(
                    language=LANG_CODES[language],
                    task="transcribe"
                )
                if forced_ids:
                    gen_kwargs["forced_decoder_ids"] = forced_ids
            except Exception as e:
                print(f"⚠️ Specialized language forcing failed: {e}")
        with torch.no_grad(), amp_ctx():
            generated_ids = model.generate(**gen_kwargs)
        transcription = processor.batch_decode(
            generated_ids,
            skip_special_tokens=True
        )[0]
        return transcription.strip()
    except Exception as e:
        return f"Specialized transcription error: {str(e)}"
# ---------------- ANALYSIS ---------------- # | |
def normalize_for_scoring(text):
    """Casefold *text*, turn Unicode punctuation into spaces and collapse whitespace.

    Only characters in Unicode category P* (e.g. ".", ",", "?", "।") are
    removed. Combining marks such as Indic vowel signs and viramas are kept.
    """
    kept = "".join(
        " " if unicodedata.category(ch).startswith("P") else ch
        for ch in text.casefold()
    )
    return " ".join(kept.split())


def compute_metrics(reference, hypothesis):
    """Compute ``(WER, CER)`` of *hypothesis* against *reference* with jiwer.

    Both texts are normalised first, so a missing full stop or danda, or a
    capitalisation difference, is not counted as a pronunciation error.
    Returns ``(1.0, 1.0)`` when either text is empty after normalisation,
    or when jiwer raises (the error is printed).
    """
    try:
        ref_clean = normalize_for_scoring(reference)
        hyp_clean = normalize_for_scoring(hypothesis)
        if not ref_clean or not hyp_clean:
            return 1.0, 1.0
        wer = jiwer.wer(ref_clean, hyp_clean)
        cer = jiwer.cer(ref_clean, hyp_clean)
        return wer, cer
    except Exception as e:
        print(f"⚠️ Metric computation failed: {e}")
        return 1.0, 1.0
def get_pronunciation_score(wer, cer):
    """Map error rates to a ``(headline, feedback, background colour)`` triple.

    WER and CER are blended 70/30. The resulting accuracy is checked against
    descending thresholds, and the first threshold it meets decides the tier.
    """
    accuracy = 1 - ((wer * 0.7) + (cer * 0.3))
    tiers = (
        (0.95, "🏆 Perfect!", "Outstanding pronunciation! Native-like accuracy.", "#d4edda"),
        (0.85, "🎉 Excellent!", "Very good pronunciation with minor variations.", "#d1ecf1"),
        (0.70, "👍 Good!", "Good pronunciation, practice specific sounds.", "#fff3cd"),
        (0.50, "📚 Needs Practice", "Focus on clearer pronunciation and rhythm.", "#f8d7da"),
    )
    for threshold, headline, advice, colour in tiers:
        if accuracy >= threshold:
            return headline, advice, colour
    return "💪 Keep Trying!", "Break down into smaller parts and practice slowly.", "#f5c6cb"
def create_detailed_comparison(intended, actual, lang_choice):
    """Collect native-script and romanized views of the target vs. spoken text.

    Returns a dict with:
      - the stripped texts ("intended_orig" / "actual_orig");
      - their romanizations from transliterate_text ("intended_translit" /
        "actual_translit"), which are unchanged when transliteration is
        unavailable;
      - word-level and character-level HTML diffs of both forms
        ("word_diff_*" / "char_diff_*").
    """
    target = intended.strip()
    spoken = actual.strip()
    target_roman = transliterate_text(target, lang_choice, to_romanized=True)
    spoken_roman = transliterate_text(spoken, lang_choice, to_romanized=True)
    word_native = highlight_word_differences(target, spoken)
    word_roman = highlight_word_differences(target_roman, spoken_roman)
    char_native = highlight_char_differences(target, spoken)
    char_roman = highlight_char_differences(target_roman, spoken_roman)
    return dict(
        intended_orig=target,
        actual_orig=spoken,
        intended_translit=target_roman,
        actual_translit=spoken_roman,
        word_diff_orig=word_native,
        word_diff_translit=word_roman,
        char_diff_orig=char_native,
        char_diff_translit=char_roman,
    )
def highlight_word_differences(reference, hypothesis):
    """Render a word-level diff of *hypothesis* against *reference* as HTML.

    Words are compared after whitespace splitting:
      - matching words are green;
      - reference words that were replaced or dropped are red with a
        strike-through;
      - hypothesis words that replaced them (prefixed "→") or were inserted
        (prefixed "+") are orange.

    Every word is HTML-escaped, because the hypothesis is ASR output and may
    contain characters such as "<" or "&". Spans are joined with single
    spaces.
    """
    base = "padding:2px 4px; margin:1px; border-radius:3px"
    ok_style = f"background-color:#d4edda; color:#155724; {base}"
    wrong_style = f"background-color:#f8d7da; color:#721c24; {base}; text-decoration:line-through"
    extra_style = f"background-color:#fff3cd; color:#856404; {base}"

    def span(style, word, prefix=""):
        return f"<span style='{style}'>{prefix}{html.escape(word)}</span>"

    ref_words = reference.split()
    hyp_words = hypothesis.split()
    matcher = difflib.SequenceMatcher(None, ref_words, hyp_words)
    pieces = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            pieces.extend(span(ok_style, w) for w in ref_words[i1:i2])
        elif tag in ("replace", "delete"):
            # Wrong or missing reference words
            pieces.extend(span(wrong_style, w) for w in ref_words[i1:i2])
        if tag in ("replace", "insert"):
            # Substituted ("→") or extra ("+") hypothesis words follow them
            marker = "→" if tag == "replace" else "+"
            pieces.extend(span(extra_style, w, marker) for w in hyp_words[j1:j2])
    return " ".join(pieces)
def highlight_char_differences(reference, hypothesis):
    """Render a character-level diff of *hypothesis* against *reference* as HTML.

    The texts are compared in grapheme-like clusters. A combining mark
    (Unicode category M*, e.g. an Indic vowel sign or virama) or a ZWJ/ZWNJ
    stays attached to the preceding character, so a matra is never split
    from its consonant into a separately coloured span.

    Colouring:
      - matching clusters are green;
      - reference clusters that were replaced or dropped are red, bold and
        underlined;
      - inserted hypothesis clusters are orange and bold.

    Clusters are HTML-escaped.
    """
    def clusters(text):
        units = []
        for ch in text:
            if units and (unicodedata.category(ch).startswith("M") or ch in "\u200c\u200d"):
                units[-1] += ch
            else:
                units.append(ch)
        return units

    ref_units = clusters(reference)
    hyp_units = clusters(hypothesis)
    # autojunk=False: the popularity heuristic (applied to sequences of 200+
    # items) would treat common letters and spaces as junk and wreck the diff.
    matcher = difflib.SequenceMatcher(None, ref_units, hyp_units, autojunk=False)
    pieces = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            style, units = "color:#28a745", ref_units[i1:i2]
        elif tag in ("replace", "delete"):
            style, units = "color:#dc3545; text-decoration:underline; font-weight:bold", ref_units[i1:i2]
        else:  # insert
            style, units = "color:#fd7e14; font-weight:bold", hyp_units[j1:j2]
        pieces.extend(f"<span style='{style}'>{html.escape(u)}</span>" for u in units)
    return "".join(pieces)
def analyze_pronunciation_errors(intended, actual, lang_choice):
    """Build learner-facing feedback notes plus the detailed comparison.

    Notes cover, in order:
      - missing or extra words (by word count);
      - a transcription without the expected script;
      - WER above 0.5 / 0.3;
      - CER above 0.3.

    Returns ``(notes, comparison)``, where comparison comes from
    create_detailed_comparison.
    """
    comparison = create_detailed_comparison(intended, actual, lang_choice)
    notes = []
    word_gap = len(intended.split()) - len(actual.split())
    if word_gap > 0:
        notes.append(f"🔍 You missed {word_gap} word(s). Try speaking more slowly.")
    elif word_gap < 0:
        notes.append(f"🔍 You added {-word_gap} extra word(s). Focus on the exact sentence.")
    if not is_correct_script(actual, lang_choice):
        notes.append(f"⚠️ The transcription doesn't contain {lang_choice} script. Check your pronunciation.")
    wer, cer = compute_metrics(intended, actual)
    if wer > 0.5:
        notes.append("🎯 Focus on pronouncing each word clearly and separately.")
    elif wer > 0.3:
        notes.append("🎯 Good overall, but some words need clearer pronunciation.")
    if cer > 0.3:
        notes.append("🔤 Pay attention to individual sounds and syllables.")
    return notes, comparison
# ---------------- MAIN FUNCTION ---------------- # | |
@GPU_DECORATOR
def compare_pronunciation(audio, language_choice, intended_sentence):
    """Transcribe the recording with both models and score it against the target.

    GPU_DECORATOR is spaces.GPU when the `spaces` package is importable and a
    pass-through otherwise. Both models are loaded lazily inside this call,
    so it must be decorated for a GPU to be available to them on ZeroGPU
    Spaces.

    Args:
        audio: Path of the recording, or None when nothing was recorded.
        language_choice: A key of LANG_CODES.
        intended_sentence: The practice sentence the user was asked to read.

    Returns:
        An 11-tuple of strings:
          - status message, primary transcription, specialized transcription;
          - WER text, CER text;
          - target text, transcription;
          - romanized target, romanized transcription;
          - word-diff HTML, char-diff HTML.
        On an early failure only the status message is non-empty.
    """
    blanks = ("",) * 10
    if audio is None:
        return ("❌ Please record audio first",) + blanks
    if not intended_sentence.strip():
        return ("❌ Please generate a sentence first",) + blanks
    print(f"🔍 Analyzing pronunciation for {language_choice}...")
    # Get transcriptions from both models (both are shown in the UI)
    primary_result = transcribe_with_primary(audio, language_choice)
    specialized_result = transcribe_with_specialized(audio, language_choice)

    def is_usable(text, error_prefix):
        # The transcribers report failures as text. Besides each model's
        # error prefix, reject the unprefixed preprocessing-failure message
        # so it is never scored as if the user had said it.
        return (
            bool(text.strip())
            and not text.startswith(error_prefix)
            and text.strip() != "Audio preprocessing failed"
        )

    # Choose best result (prefer specialized if successful)
    if is_usable(specialized_result, "Specialized"):
        best_transcription = specialized_result
        best_source = "Specialized Model"
    elif is_usable(primary_result, "Primary"):
        best_transcription = primary_result
        best_source = "Primary Model"
    else:
        return (
            f"❌ Both models failed:\nPrimary: {primary_result}\nSpecialized: {specialized_result}",
        ) + blanks
    error_analysis, comparison = analyze_pronunciation_errors(
        intended_sentence, best_transcription, language_choice
    )
    wer, cer = compute_metrics(intended_sentence, best_transcription)
    score, feedback, color = get_pronunciation_score(wer, cer)
    # WER/CER can exceed 1 (many insertions); don't display negative accuracy.
    word_acc = max(0.0, 1 - wer) * 100
    char_acc = max(0.0, 1 - cer) * 100
    status_msg = f"""✅ Analysis Complete!
{score}
{feedback}
🤖 Best result from: {best_source}
📊 Word Accuracy: {word_acc:.1f}%
📈 Character Accuracy: {char_acc:.1f}%
🔍 Analysis:
""" + "\n".join(error_analysis)
    return (
        status_msg,
        primary_result,
        specialized_result,
        f"{wer:.3f} ({word_acc:.1f}%)",
        f"{cer:.3f} ({char_acc:.1f}%)",
        comparison["intended_orig"],
        comparison["actual_orig"],
        comparison["intended_translit"],
        comparison["actual_translit"],
        comparison["word_diff_orig"],
        comparison["char_diff_orig"]
    )
# ---------------- UI ---------------- # | |
def create_interface():
    """Build the Gradio Blocks UI, wire its events and return the unlaunched app.

    Components are laid out in the order they are constructed inside the
    `with` contexts, so statement order and nesting here define the page.
    """
    with gr.Blocks(title="Enhanced Pronunciation Comparator", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🎙️ Enhanced Pronunciation Comparator
        **Perfect your pronunciation in English, Tamil, Malayalam, and Hindi!**
        This tool uses specialized AI models to give you detailed feedback on your pronunciation,
        including transliteration to help you understand exactly where you need improvement.
        ### How to use:
        1. 🌐 Select your target language
        2. 🎲 Generate a practice sentence
        3. 🎤 Record yourself saying the sentence clearly
        4. 🔍 Get detailed pronunciation analysis with transliteration
        """)
        # Language picker and sentence generator side by side.
        with gr.Row():
            with gr.Column(scale=2):
                language_dropdown = gr.Dropdown(
                    choices=list(LANG_CODES.keys()),
                    value="Tamil",
                    label="🌐 Select Language"
                )
            with gr.Column(scale=1):
                generate_btn = gr.Button("🎲 Generate Practice Sentence", variant="primary")
        # Read-only: filled by get_random_sentence via the events below.
        intended_textbox = gr.Textbox(
            label="📝 Practice Sentence",
            interactive=False,
            lines=2,
            placeholder="Click 'Generate Practice Sentence' to get started..."
        )
        # type="filepath" hands compare_pronunciation a file path, not raw samples.
        audio_input = gr.Audio(
            sources=["microphone", "upload"],
            type="filepath",
            label="🎤 Record Your Pronunciation"
        )
        analyze_btn = gr.Button("🔍 Analyze Pronunciation", variant="secondary", size="lg")
        with gr.Row():
            status_output = gr.Textbox(
                label="📊 Analysis Results",
                interactive=False,
                lines=8
            )
        # Raw outputs of both ASR models, collapsed by default.
        with gr.Accordion("🤖 Model Outputs", open=False):
            with gr.Row():
                primary_output = gr.Textbox(label="Primary Model (IndicWhisper)", interactive=False)
                specialized_output = gr.Textbox(label="Specialized Model", interactive=False)
        with gr.Accordion("📈 Detailed Metrics", open=False):
            with gr.Row():
                wer_output = gr.Textbox(label="Word Error Rate", interactive=False)
                cer_output = gr.Textbox(label="Character Error Rate", interactive=False)
        gr.Markdown("### 🔍 Detailed Comparison")
        # Native-script texts on the left, romanized versions on the right.
        with gr.Row():
            with gr.Column():
                gr.Markdown("#### 📝 Original Script")
                intended_orig = gr.Textbox(label="🎯 Target Text", interactive=False)
                actual_orig = gr.Textbox(label="🗣️ What You Said", interactive=False)
            with gr.Column():
                gr.Markdown("#### 🔤 Romanized (Transliterated)")
                intended_translit = gr.Textbox(label="🎯 Target (Romanized)", interactive=False)
                actual_translit = gr.Textbox(label="🗣️ What You Said (Romanized)", interactive=False)
        gr.Markdown("### 🎨 Visual Comparison")
        gr.Markdown("**Green** = Correct, **Red** = Wrong/Missing, **Orange** = Added/Substituted")
        word_diff_html = gr.HTML(label="🔤 Word-by-Word Comparison")
        char_diff_html = gr.HTML(label="🔍 Character-by-Character Analysis")
        # Event handlers
        generate_btn.click(
            fn=get_random_sentence,
            inputs=[language_dropdown],
            outputs=[intended_textbox]
        )
        # The outputs list must stay in the same order as the 11-tuple
        # returned by compare_pronunciation.
        analyze_btn.click(
            fn=compare_pronunciation,
            inputs=[audio_input, language_dropdown, intended_textbox],
            outputs=[
                status_output, primary_output, specialized_output,
                wer_output, cer_output, intended_orig, actual_orig,
                intended_translit, actual_translit, word_diff_html, char_diff_html
            ]
        )
        # Changing the language immediately draws a sentence in the new language.
        language_dropdown.change(
            fn=get_random_sentence,
            inputs=[language_dropdown],
            outputs=[intended_textbox]
        )
        gr.Markdown("""
        ### 📚 Pro Tips for Better Pronunciation:
        - **Speak slowly and clearly** - Don't rush through the sentence
        - **Pronounce each syllable** - Break down complex words
        - **Check the romanized version** - Use it to understand correct pronunciation
        - **Practice repeatedly** - Use the same sentence multiple times to track improvement
        - **Focus on problem areas** - Pay attention to red-highlighted parts
        - **Record in a quiet environment** - Minimize background noise
        ### 🎯 Understanding the Feedback:
        - **Green highlights** = Perfect pronunciation ✅
        - **Red highlights** = Missing or mispronounced ❌
        - **Orange highlights** = Added or substituted 🔄
        - **Transliteration** = Helps you see pronunciation patterns
        - **Error rates** = Lower is better (0% = perfect)
        """)
    return demo
# ---------------- LAUNCH ---------------- # | |
if __name__ == "__main__":
    # Script entry point: build the UI and serve it on all interfaces, port 7860.
    print("🚀 Starting Enhanced Pronunciation Comparator...")
    demo = create_interface()
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True,
    )
    demo.launch(**launch_options)