import gradio as gr
import random
import difflib
import re
import jiwer
import torch
import torchaudio
import numpy as np
from transformers import (
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    WhisperProcessor,
    WhisperForConditionalGeneration
)
import librosa
import soundfile as sf
from indic_transliteration import sanscript
from indic_transliteration.sanscript import transliterate
import warnings

warnings.filterwarnings("ignore")

# ---------------- CONFIG ---------------- #
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🔧 Using device: {DEVICE}")

LANG_CODES = {
    "English": "en",
    "Tamil": "ta",
    "Malayalam": "ml"
}

# Model configurations using LARGE models for maximum accuracy
ASR_MODELS = {
    "English": "openai/whisper-base.en",
    "Tamil": "ai4bharat/whisper-large-ta",       # LARGE AI4Bharat Tamil model (~1.5GB)
    "Malayalam": "ai4bharat/whisper-large-ml"    # LARGE AI4Bharat Malayalam model (~1.5GB)
}

LANG_PRIMERS = {
    "English": ("Transcribe in English.",
                "Write only in English. Example: This is an English sentence."),
    "Tamil": ("தமிழில் எழுதுக.",
              "தமிழ் எழுத்துக்களில் மட்டும் எழுதவும். உதாரணம்: இது ஒரு தமிழ் வாக்கியம்."),
    "Malayalam": ("മലയാളത്തിൽ എഴുതുക.",
                  "മലയാള ലിപിയിൽ മാത്രം എഴുതുക. ഉദാഹരണം: ഇതൊരു മലയാള വാക്യമാണ്.")
}
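# Each value in LANG_PRIMERS is a (weak primer, strong primer) pair: compare_pronunciation
# uses the weak primer for the raw first pass and the strong primer, together with the
# expected sentence, for the target-biased second pass.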

SCRIPT_PATTERNS = {
    "Tamil": re.compile(r"[\u0B80-\u0BFF]"),   # Tamil Unicode block
    "Malayalam": re.compile(r"[ഀ-ൿ]"),         # Malayalam Unicode block
    "English": re.compile(r"[A-Za-z]")
}

SENTENCE_BANK = {
    "English": [
        "The sun sets over the beautiful horizon.",
        "Learning new languages opens many doors.",
        "I enjoy reading books in the evening.",
        "Technology has changed our daily lives.",
        "Music brings people together across cultures.",
        "Education is the key to a bright future.",
        "The flowers bloom beautifully in spring.",
        "Hard work always pays off in the end."
    ],
    "Tamil": [
        "இன்று நல்ல வானிலை உள்ளது.",
        "நான் தமிழ் கற்றுக்கொண்டு இருக்கிறேன்.",
        "எனக்கு புத்தகம் படிக்க விருப்பம்.",
        "தமிழ் மொழி மிகவும் அழகானது.",
        "குடும்பத்துடன் நேரம் செலவிடுவது முக்கியம்.",
        "கல்வி நமது எதிர்காலத்தின் திறவுகோல்.",
        "பறவைகள் காலையில் இனிமையாக பாடுகின்றன.",
        "உழைப்பு எப்போதும் வெற்றியைத் தரும்."
    ],
    "Malayalam": [
        "എനിക്ക് മലയാളം വളരെ ഇഷ്ടമാണ്.",
        "ഇന്ന് മഴപെയ്യുന്നു.",
        "ഞാൻ പുസ്തകം വായിക്കുന്നു.",
        "കേരളത്തിന്റെ പ്രകൃതി സുന്ദരമാണ്.",
        "വിദ്യാഭ്യാസം ജീവിതത്തിൽ പ്രധാനമാണ്.",
        "സംഗീതം മനസ്സിന് സന്തോഷം നൽകുന്നു.",
        "കുടുംബസമയം വളരെ വിലപ്പെട്ടതാണ്.",
        "കഠിനാധ്വാനം എപ്പോഴും ഫലം നൽകും."
    ]
}

# ---------------- MODEL CACHE ---------------- #
asr_models = {}

def load_asr_model(language):
    """Load ASR model for specific language - PRIMARY MODELS ONLY"""
    if language not in asr_models:
        model_name = ASR_MODELS[language]
        print(f"🔄 Loading LARGE model for {language}: {model_name}")
        try:
            processor = AutoProcessor.from_pretrained(model_name)
            model = AutoModelForSpeechSeq2Seq.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
                low_cpu_mem_usage=True,
                use_safetensors=True
            ).to(DEVICE)
            asr_models[language] = {"processor": processor, "model": model, "model_name": model_name}
            print(f"✅ LARGE model loaded successfully for {language}")
        except Exception as e:
            print(f"❌ Failed to load {model_name}: {e}")
            raise Exception(f"Could not load {language} model. Please check model availability.")
    return asr_models[language]

# ---------------- HELPERS ---------------- #
def get_random_sentence(language_choice):
    """Get a random practice sentence for the chosen language"""
    return random.choice(SENTENCE_BANK[language_choice])

def is_script(text, lang_name):
    """Check if text contains the expected script"""
    pattern = SCRIPT_PATTERNS.get(lang_name)
    if not pattern:
        return True
    return bool(pattern.search(text))

def transliterate_to_hk(text, lang_choice):
    """Transliterate Indic text to Harvard-Kyoto romanization"""
    mapping = {
        "Tamil": sanscript.TAMIL,
        "Malayalam": sanscript.MALAYALAM,
        "English": None
    }
    script = mapping.get(lang_choice)
    if script and is_script(text, lang_choice):
        try:
            return transliterate(text, script, sanscript.HK)
        except Exception as e:
            print(f"Transliteration error: {e}")
            return text
    return text

def preprocess_audio(audio_path, target_sr=16000):
    """Preprocess audio for ASR"""
    try:
        # Load audio at the target sampling rate
        audio, sr = librosa.load(audio_path, sr=target_sr)
        # Normalize amplitude
        if np.max(np.abs(audio)) > 0:
            audio = audio / np.max(np.abs(audio))
        # Remove silence from beginning and end
        audio, _ = librosa.effects.trim(audio, top_db=20)
        # Ensure minimum length
        if len(audio) < target_sr * 0.1:  # less than 0.1 seconds
            return None, None
        return audio, target_sr
    except Exception as e:
        print(f"Audio preprocessing error: {e}")
        return None, None

def transcribe_audio(audio_path, language, initial_prompt="", force_language=True):
    """Transcribe audio using loaded models"""
    try:
        # Load model components
        asr_components = load_asr_model(language)
        processor = asr_components["processor"]
        model = asr_components["model"]
        model_name = asr_components["model_name"]

        # Preprocess audio
        audio, sr = preprocess_audio(audio_path)
        if audio is None:
            return "Error: Audio too short or could not be processed"

        # Prepare inputs
        inputs = processor(
            audio,
            sampling_rate=sr,
            return_tensors="pt",
            padding=True
        )
        # Move to device and match the model dtype (float16 on GPU, float32 on CPU)
        input_features = inputs.input_features.to(DEVICE, dtype=model.dtype)

        # Generate transcription
        with torch.no_grad():
            # Basic generation parameters
            generate_kwargs = {
                "input_features": input_features,
                "max_length": 200,
                "num_beams": 3,  # reduced for better compatibility
                "do_sample": False
            }

            # Try different approaches for language forcing
            if force_language and language != "English":
                lang_code = LANG_CODES.get(language, "en")
                # Method 1: forced_decoder_ids (OpenAI Whisper style)
                try:
                    if hasattr(processor, 'get_decoder_prompt_ids'):
                        forced_decoder_ids = processor.get_decoder_prompt_ids(
                            language=lang_code,
                            task="transcribe"
                        )
                        # Test whether the model accepts this parameter
                        test_kwargs = generate_kwargs.copy()
                        test_kwargs["max_length"] = 10
                        test_kwargs["forced_decoder_ids"] = forced_decoder_ids
                        _ = model.generate(**test_kwargs)  # test run
                        generate_kwargs["forced_decoder_ids"] = forced_decoder_ids
                        print(f"✅ Using forced_decoder_ids for {language}")
                except Exception as e:
                    print(f"⚠️ forced_decoder_ids not supported: {e}")
                    # Method 2: language parameter
                    try:
                        test_kwargs = generate_kwargs.copy()
                        test_kwargs["max_length"] = 10
                        test_kwargs["language"] = lang_code
                        _ = model.generate(**test_kwargs)  # test run
                        generate_kwargs["language"] = lang_code
                        print(f"✅ Using language parameter for {language}")
                    except Exception as e:
                        print(f"⚠️ language parameter not supported: {e}")

            # Generate with whatever parameters work
            predicted_ids = model.generate(**generate_kwargs)

            # Decode
            transcription = processor.batch_decode(
                predicted_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )[0]

            # Post-process transcription
            transcription = transcription.strip()

            # If the transcription is empty, retry with greedy decoding
            if not transcription and generate_kwargs.get("num_beams", 1) > 1:
                print("🔄 Retrying with greedy decoding...")
                simple_kwargs = {
                    "input_features": input_features,
                    "max_length": 200,
                    "do_sample": False
                }
                predicted_ids = model.generate(**simple_kwargs)
                transcription = processor.batch_decode(
                    predicted_ids,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=True
                )[0].strip()

        return transcription or "(No transcription generated)"
    except Exception as e:
        print(f"Transcription error for {language}: {e}")
        return f"Error: {str(e)[:150]}..."

def highlight_differences(ref, hyp):
    """Highlight word-level differences between reference and hypothesis"""
    if not ref.strip() or not hyp.strip():
        return "No text to compare"
    ref_words = ref.strip().split()
    hyp_words = hyp.strip().split()
    sm = difflib.SequenceMatcher(None, ref_words, hyp_words)
    out_html = []
    for tag, i1, i2, j1, j2 in sm.get_opcodes():
        if tag == 'equal':
            out_html.extend([f"<span style='color:green; font-weight:bold; background-color:#e8f5e8; padding:2px 4px; margin:1px; border-radius:3px;'>{w}</span>" for w in ref_words[i1:i2]])
        elif tag == 'replace':
            out_html.extend([f"<span style='color:red; text-decoration:line-through; background-color:#ffe8e8; padding:2px 4px; margin:1px; border-radius:3px;'>{w}</span>" for w in ref_words[i1:i2]])
            out_html.extend([f"<span style='color:orange; font-weight:bold; background-color:#fff3cd; padding:2px 4px; margin:1px; border-radius:3px;'>→{w}</span>" for w in hyp_words[j1:j2]])
        elif tag == 'delete':
            out_html.extend([f"<span style='color:red; text-decoration:line-through; background-color:#ffe8e8; padding:2px 4px; margin:1px; border-radius:3px;'>{w}</span>" for w in ref_words[i1:i2]])
        elif tag == 'insert':
            out_html.extend([f"<span style='color:orange; font-weight:bold; background-color:#fff3cd; padding:2px 4px; margin:1px; border-radius:3px;'>+{w}</span>" for w in hyp_words[j1:j2]])
    return " ".join(out_html)

def char_level_highlight(ref, hyp):
    """Highlight character-level differences between reference and hypothesis"""
    if not ref.strip() or not hyp.strip():
        return "No text to compare"
    sm = difflib.SequenceMatcher(None, list(ref), list(hyp))
    out = []
    for tag, i1, i2, j1, j2 in sm.get_opcodes():
        if tag == 'equal':
            out.extend([f"<span style='color:green; background-color:#e8f5e8;'>{c}</span>" for c in ref[i1:i2]])
        elif tag in ('replace', 'delete'):
            out.extend([f"<span style='color:red; text-decoration:underline; background-color:#ffe8e8; font-weight:bold;'>{c}</span>" for c in ref[i1:i2]])
        elif tag == 'insert':
            out.extend([f"<span style='color:orange; background-color:#fff3cd; font-weight:bold;'>{c}</span>" for c in hyp[j1:j2]])
    return "".join(out)

def get_pronunciation_score(wer_val, cer_val):
    """Calculate pronunciation score and feedback"""
    # Weight WER more heavily than CER
    combined_score = (wer_val * 0.7) + (cer_val * 0.3)
    if combined_score <= 0.1:
        return "🏆 Excellent! (90%+)", "Your pronunciation is outstanding!"
    elif combined_score <= 0.2:
        return "🎉 Very Good! (80-90%)", "Great pronunciation with minor areas for improvement."
    elif combined_score <= 0.4:
        return "👍 Good! (60-80%)", "Good effort! Keep practicing for better accuracy."
    elif combined_score <= 0.6:
        return "📚 Needs Practice (40-60%)", "Focus on clearer pronunciation of highlighted words."
    else:
        return "💪 Keep Trying! (<40%)", "Don't give up! Practice makes perfect."

# ---------------- MAIN FUNCTION ---------------- #
def compare_pronunciation(audio, language_choice, intended_sentence):
    """Main function to compare pronunciation against the intended sentence"""
    print(f"🔍 Starting analysis with language: {language_choice}")
    print(f"📝 Audio file: {audio}")
    print(f"🎯 Intended sentence: {intended_sentence}")
    if audio is None:
        print("❌ No audio provided")
        return ("❌ Please record audio first.", "", "", "", "", "", "", "", "", "", "", "", "")
    if not intended_sentence.strip():
        print("❌ No intended sentence")
        return ("❌ Please generate a practice sentence first.", "", "", "", "", "", "", "", "", "", "", "", "")
    try:
        print(f"🔍 Analyzing pronunciation for {language_choice}...")

        # Pass 1: Raw transcription
        print("🔄 Starting Pass 1 transcription...")
        primer_weak, _ = LANG_PRIMERS[language_choice]
        actual_text = transcribe_audio(audio, language_choice, primer_weak, force_language=True)
        print(f"✅ Pass 1 result: {actual_text}")

        # Pass 2: Target-biased transcription with stronger prompt
        print("🔄 Starting Pass 2 transcription...")
        _, primer_strong = LANG_PRIMERS[language_choice]
        strict_prompt = f"{primer_strong}\nExpected: {intended_sentence}"
        corrected_text = transcribe_audio(audio, language_choice, strict_prompt, force_language=True)
        print(f"✅ Pass 2 result: {corrected_text}")

        # Handle transcription errors
        if actual_text.startswith("Error:"):
            print(f"❌ Transcription error: {actual_text}")
            return (f"❌ {actual_text}", "", "", "", "", "", "", "", "", "", "", "", "")

        # Calculate error metrics
        try:
            print("🔄 Calculating error metrics...")
            wer_val = jiwer.wer(intended_sentence, actual_text)
            cer_val = jiwer.cer(intended_sentence, actual_text)
            print(f"✅ WER: {wer_val:.3f}, CER: {cer_val:.3f}")
        except Exception as e:
            print(f"❌ Error calculating metrics: {e}")
            wer_val, cer_val = 1.0, 1.0

        # Get pronunciation score and feedback
        score_text, feedback = get_pronunciation_score(wer_val, cer_val)
        print(f"✅ Score: {score_text}")

        # Transliterations for both actual and intended text
        print("🔄 Generating transliterations...")
        actual_hk = transliterate_to_hk(actual_text, language_choice)
        target_hk = transliterate_to_hk(intended_sentence, language_choice)

        # Handle script mismatches
        if not is_script(actual_text, language_choice) and language_choice != "English":
            actual_hk = f"⚠️ Expected {language_choice} script, got mixed/other script"

        # Visual feedback
        print("🔄 Generating visual feedback...")
        diff_html = highlight_differences(intended_sentence, actual_text)
        char_html = char_level_highlight(intended_sentence, actual_text)

        # Status message with detailed feedback
        status = f"✅ Analysis Complete - {score_text}\n💬 {feedback}"
        print("✅ Analysis completed successfully")

        return (
            status,
            actual_text or "(No transcription)",
            corrected_text or "(No corrected transcription)",
            f"{wer_val:.3f} ({(1-wer_val)*100:.1f}% word accuracy)",
            f"{cer_val:.3f} ({(1-cer_val)*100:.1f}% character accuracy)",
            actual_text or "(No transcription)",   # actual_text_display
            actual_hk,                             # actual_transliteration
            intended_sentence,                     # target_text_display
            target_hk,                             # target_transliteration
            diff_html,                             # diff_html_box
            char_html,                             # char_html_box
            intended_sentence,                     # intended_display (unchanged)
            f"🎯 Target: {intended_sentence}"      # target_display
        )
    except Exception as e:
        error_msg = f"❌ Analysis Error: {str(e)[:200]}"
        print(f"❌ FATAL ERROR: {e}")
        import traceback
        traceback.print_exc()
        return (error_msg, str(e), "", "", "", "", "", "", "", "", "", "", "")

# ---------------- UI ---------------- #
def create_interface():
    with gr.Blocks(title="🎙️ Multilingual Pronunciation Trainer") as demo:
        gr.Markdown("""
        # 🎙️ Multilingual Pronunciation Trainer
        **Practice pronunciation in Tamil, Malayalam & English** using advanced speech recognition!

        ### 📋 How to Use:
        1. **Select** your target language 🌍
        2. **Generate** a practice sentence 🎲
        3. **Record** yourself reading it aloud 🎤
        4. **Get** detailed feedback with accuracy metrics 📊

        ### 🎯 Features:
        - **Dual-pass analysis** for accurate assessment
        - **Visual highlighting** of pronunciation errors
        - **Romanization** for Indic scripts
        - **Detailed metrics** (Word & Character accuracy)
        """)

        with gr.Row():
            with gr.Column(scale=3):
                lang_choice = gr.Dropdown(
                    choices=list(LANG_CODES.keys()),
                    value="Tamil",
                    label="🌍 Select Language"
                )
            with gr.Column(scale=1):
                gen_btn = gr.Button("🎲 Generate Sentence", variant="primary")

        intended_display = gr.Textbox(
            label="📝 Practice Sentence (Read this aloud)",
            placeholder="Click 'Generate Sentence' to get started...",
            interactive=False,
            lines=3
        )
        audio_input = gr.Audio(
            sources=["microphone", "upload"],
            type="filepath",
            label="🎤 Record Your Pronunciation"
        )
        analyze_btn = gr.Button("🔍 Analyze Pronunciation", variant="primary")
        status_output = gr.Textbox(
            label="📊 Analysis Results",
            interactive=False,
            lines=3
        )

        with gr.Row():
            with gr.Column():
                pass1_out = gr.Textbox(
                    label="🎯 What You Actually Said (Raw Output)",
                    interactive=False,
                    lines=2
                )
                wer_out = gr.Textbox(
                    label="📈 Word Accuracy",
                    interactive=False
                )
            with gr.Column():
                pass2_out = gr.Textbox(
                    label="🔧 Target-Biased Analysis",
                    interactive=False,
                    lines=2
                )
                cer_out = gr.Textbox(
                    label="📊 Character Accuracy",
                    interactive=False
                )

        with gr.Accordion("📝 Detailed Visual Feedback", open=True):
            gr.Markdown("""
            ### 🎨 Color Guide:
            - 🟢 **Green**: Correctly pronounced words/characters
            - 🔴 **Red**: Missing or mispronounced (strikethrough)
            - 🟠 **Orange**: Extra words or substitutions
            """)
            diff_html_box = gr.HTML(
                label="🔍 Word-Level Analysis",
                show_label=True
            )
            char_html_box = gr.HTML(
                label="🔤 Character-Level Analysis",
                show_label=True
            )
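            # Romanization display components (assumed layout), matching the
            # actual_text_display, actual_transliteration, target_text_display and
            # target_transliteration values returned by compare_pronunciation
            actual_text_display = gr.Textbox(
                label="🗣️ Your Speech (Native Script)",
                interactive=False
            )
            actual_transliteration = gr.Textbox(
                label="🔤 Your Speech (Harvard-Kyoto Romanization)",
                interactive=False
            )
            target_text_display = gr.Textbox(
                label="🎯 Target Sentence (Native Script)",
                interactive=False
            )
            target_transliteration = gr.Textbox(
                label="🔤 Target Sentence (Harvard-Kyoto Romanization)",
                interactive=False
            )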

        target_display = gr.Textbox(
            label="🎯 Reference Text",
            interactive=False,
            visible=False
        )

        # Auto-generate sentence on language change
        lang_choice.change(
            fn=get_random_sentence,
            inputs=[lang_choice],
            outputs=[intended_display]
        )
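        # Assumed event wiring for the Generate and Analyze buttons; the Analyze outputs
        # follow the 13-value tuple returned by compare_pronunciation
        gen_btn.click(
            fn=get_random_sentence,
            inputs=[lang_choice],
            outputs=[intended_display]
        )
        analyze_btn.click(
            fn=compare_pronunciation,
            inputs=[audio_input, lang_choice, intended_display],
            outputs=[
                status_output, pass1_out, pass2_out, wer_out, cer_out,
                actual_text_display, actual_transliteration,
                target_text_display, target_transliteration,
                diff_html_box, char_html_box, intended_display, target_display
            ]
        )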

        # Footer
        gr.Markdown("""
        ---
        ### 🔧 Technical Details:
        - **ASR Models**:
            - **Tamil**: AI4Bharat Whisper-LARGE-TA (~1.5GB, maximum accuracy)
            - **Malayalam**: AI4Bharat Whisper-LARGE-ML (~1.5GB, maximum accuracy)
            - **English**: OpenAI Whisper-Base-EN (optimized for English)
        - **Performance**: Using largest available models for best pronunciation assessment
        - **Metrics**: WER (Word Error Rate) and CER (Character Error Rate)
        - **Transliteration**: Harvard-Kyoto system for Indic scripts
        - **Analysis**: Dual-pass approach for comprehensive feedback

        **Note**: Large models provide maximum accuracy but require longer initial loading time.
        **Languages**: English, Tamil, and Malayalam with specialized large models.
        """)

    return demo

# ---------------- LAUNCH ---------------- #
if __name__ == "__main__":
    print("🚀 Starting Multilingual Pronunciation Trainer with LARGE models...")
    print(f"🔧 Device: {DEVICE}")
    print(f"🔧 PyTorch version: {torch.__version__}")
    print("📦 Models will be loaded on-demand for best performance...")
    print("⚡ Using AI4Bharat LARGE models for maximum accuracy!")
    demo = create_interface()
    demo.launch(
        share=True,
        show_error=True,
        server_name="0.0.0.0",
        server_port=7860
    )