# Spaces: Running on Zero
import difflib
import html
import os
import random
import re
import tempfile
import unicodedata
import warnings

import gradio as gr
import jiwer
import librosa
import numpy as np
import soundfile as sf
import torch
import torchaudio
from indic_transliteration import sanscript
from indic_transliteration.sanscript import transliterate
from transformers import (
    AutoModel,
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    AutoTokenizer,
)
from TTS.api import TTS

warnings.filterwarnings("ignore")
# ---------------- CONFIG ---------------- #
# Use the GPU when torch reports one as available, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Language name -> short language code.
LANG_CODES = {
    "English": "en",
    "Tamil": "ta",
    "Malayalam": "ml",
    "Hindi": "hi",
    "Sanskrit": "sa",
}

# AI4Bharat model configurations
# Language name -> speech-recognition checkpoint name.
ASR_MODELS = {
    "English": "openai/whisper-base.en",
    "Tamil": "ai4bharat/whisper-medium-ta",
    "Malayalam": "ai4bharat/whisper-medium-ml",
    "Hindi": "ai4bharat/whisper-medium-hi",
}
# Fallback to Hindi for Sanskrit
ASR_MODELS["Sanskrit"] = ASR_MODELS["Hindi"]

# Language name -> text-to-speech model name.
TTS_MODELS = {
    "English": "tts_models/en/ljspeech/tacotron2-DDC",
    "Tamil": "tts_models/ta/mai/tacotron2-DDC",
    "Malayalam": "tts_models/ml/mai/tacotron2-DDC",
    "Hindi": "tts_models/hi/mai/tacotron2-DDC",
}
# Fallback to Hindi
TTS_MODELS["Sanskrit"] = TTS_MODELS["Hindi"]
# Language name -> (short primer, longer primer with an example sentence).
# compare_pronunciation unpacks the pair as (weak, strong).
LANG_PRIMERS = {
    "English": (
        "Transcribe in English.",
        "Write only in English. Example: This is an English sentence.",
    ),
    "Tamil": (
        "தமிழில் எழுதுக.",
        "தமிழ் எழுத்துக்களில் மட்டும் எழுதவும். உதாரணம்: இது ஒரு தமிழ் வாக்கியம்.",
    ),
    "Malayalam": (
        "മലയാളത്തിൽ എഴുതുക.",
        "മലയാള ലിപിയിൽ മാത്രം എഴുതുക. ഉദാഹരണം: ഇതൊരു മലയാള വാക്യമാണ്.",
    ),
    "Hindi": (
        "हिंदी में लिखें।",
        "केवल देवनागरी लिपि में लिखें। उदाहरण: यह एक हिंदी वाक्य है।",
    ),
    "Sanskrit": (
        "संस्कृते लिखत।",
        "देवनागरी लिपि में लिखें। उदाहरण: अहं संस्कृतं जानामि।",
    ),
}
# Language name -> regex that matches at least one character of the expected
# script. Ranges are written as \u escapes so they cannot be lost to encoding
# damage: the Tamil class had lost both endpoints and was matching only "-".
SCRIPT_PATTERNS = {
    "Tamil": re.compile(r"[\u0B80-\u0BFF]"),      # Tamil block
    "Malayalam": re.compile(r"[\u0D00-\u0D7F]"),  # Malayalam block
    "Hindi": re.compile(r"[\u0900-\u097F]"),      # Devanagari block
    "Sanskrit": re.compile(r"[\u0900-\u097F]"),   # Devanagari block
    "English": re.compile(r"[A-Za-z]"),
}
# Language name -> practice sentences; get_random_sentence draws from these.
SENTENCE_BANK = {
    "English": [
        "The sun sets over the beautiful horizon.",
        "Learning new languages opens many doors.",
        "I enjoy reading books in the evening.",
        "Technology has changed our daily lives.",
        "Music brings people together across cultures.",
    ],
    "Tamil": [
        "இன்று நல்ல வானிலை உள்ளது.",
        "நான் தமிழ் கற்றுக்கொண்டு இருக்கிறேன்.",
        "எனக்கு புத்தகம் படிக்க விருப்பம்.",
        "தமிழ் மொழி மிகவும் அழகானது.",
        "குடும்பத்துடன் நேரம் செலவிடுவது முக்கியம்.",
    ],
    "Malayalam": [
        "എനിക്ക് മലയാളം വളരെ ഇഷ്ടമാണ്.",
        "ഇന്ന് മഴപെയ്യുന്നു.",
        "ഞാൻ പുസ്തകം വായിക്കുന്നു.",
        "കേരളത്തിന്റെ പ്രകൃതി സുന്ദരമാണ്.",
        "വിദ്യാഭ്യാസം ജീവിതത്തിൽ പ്രധാനമാണ്.",
    ],
    "Hindi": [
        "आज मौसम बहुत अच्छा है।",
        "मुझे हिंदी बोलना पसंद है।",
        "मैं रोज किताब पढ़ता हूँ।",
        "भारत की संस्कृति विविधतापूर्ण है।",
        "शिक्षा हमारे भविष्य की कुंजी है।",
    ],
    "Sanskrit": [
        "अहं ग्रन्थं पठामि।",
        "अद्य सूर्यः तेजस्वी अस्ति।",
        "मम नाम रामः।",
        "विद्या सर्वत्र पूज्यते।",
        "सत्यमेव जयते।",
    ],
}
# ---------------- MODEL CACHE ---------------- #
# Filled lazily by load_asr_model / load_tts_model, keyed by language name,
# so each model is loaded at most once per process.
asr_models: dict = {}
tts_models: dict = {}
def load_asr_model(language):
    """Return the cached ASR components for ``language``, loading them on first use.

    Args:
        language: Language name used as the key into ``ASR_MODELS``.

    Returns:
        A dict with the keys ``"processor"`` and ``"model"``. When the
        language-specific model cannot be loaded, the English components are
        cached under ``language`` and returned instead.

    Raises:
        RuntimeError: If the English model itself cannot be loaded, so there
            is nothing left to fall back to.
    """
    cached = asr_models.get(language)
    if cached is not None:
        return cached

    try:
        model_name = ASR_MODELS[language]
        print(f"Loading ASR model for {language}: {model_name}")
        processor = AutoProcessor.from_pretrained(model_name)
        model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name).to(DEVICE)
    except Exception as e:
        print(f"❌ Failed to load ASR for {language}: {e}")
        if language == "English":
            # No further fallback exists. Surface the real cause instead of
            # the KeyError the cache lookup below would otherwise produce.
            raise RuntimeError(f"Could not load the English ASR model: {e}") from e
        # Fallback to English model
        print(f"🔄 Falling back to English ASR for {language}")
        asr_models[language] = load_asr_model("English")
    else:
        asr_models[language] = {"processor": processor, "model": model}
        print(f"✅ ASR model loaded for {language}")
    return asr_models[language]
def load_tts_model(language):
    """Return the cached TTS model for ``language``, loading it on first use.

    Args:
        language: Language name used as the key into ``TTS_MODELS``.

    Returns:
        The ``TTS`` instance for the language. When the language-specific
        model cannot be loaded, the English model is cached under
        ``language`` and returned instead.

    Raises:
        RuntimeError: If the English model itself cannot be loaded, so there
            is nothing left to fall back to.
    """
    cached = tts_models.get(language)
    if cached is not None:
        return cached

    try:
        model_name = TTS_MODELS[language]
        print(f"Loading TTS model for {language}: {model_name}")
        tts = TTS(model_name=model_name).to(DEVICE)
    except Exception as e:
        print(f"❌ Failed to load TTS for {language}: {e}")
        if language == "English":
            # No further fallback exists. Surface the real cause instead of
            # the KeyError the cache lookup below would otherwise produce.
            raise RuntimeError(f"Could not load the English TTS model: {e}") from e
        # Fallback to English
        print(f"🔄 Falling back to English TTS for {language}")
        tts_models[language] = load_tts_model("English")
    else:
        tts_models[language] = tts
        print(f"✅ TTS model loaded for {language}")
    return tts_models[language]
# ---------------- HELPERS ---------------- #
def get_random_sentence(language_choice):
    """Pick one practice sentence at random for the given language.

    Args:
        language_choice: Language name; must be a key of ``SENTENCE_BANK``
            (an unknown name raises ``KeyError``).

    Returns:
        One sentence from that language's list.
    """
    candidates = SENTENCE_BANK[language_choice]
    return random.choice(candidates)
def is_script(text, lang_name):
    """Report whether ``text`` contains at least one character of the expected script.

    Args:
        text: Text to inspect.
        lang_name: Language name looked up in ``SCRIPT_PATTERNS``.

    Returns:
        ``True`` when the language's pattern matches somewhere in ``text``,
        or when no pattern is registered for the language; ``False`` otherwise.
    """
    matcher = SCRIPT_PATTERNS.get(lang_name)
    if not matcher:
        # Nothing to check against, so the text is accepted as-is.
        return True
    return matcher.search(text) is not None
def transliterate_to_hk(text, lang_choice):
    """Transliterate Indic text to the Harvard-Kyoto romanisation.

    Args:
        text: Text to transliterate.
        lang_choice: Language name that selects the source script.

    Returns:
        The Harvard-Kyoto rendering of ``text``. ``text`` is returned
        unchanged for English, for any unlisted language, when it contains no
        character of the expected script, or when transliteration fails.
    """
    # Attribute names on ``sanscript``; resolved only when actually needed so
    # languages without a source scheme never touch the library.
    scheme_names = {
        "Tamil": "TAMIL",
        "Malayalam": "MALAYALAM",
        "Hindi": "DEVANAGARI",
        "Sanskrit": "DEVANAGARI",
    }
    scheme_name = scheme_names.get(lang_choice)
    if scheme_name is None or not is_script(text, lang_choice):
        return text
    try:
        return transliterate(text, getattr(sanscript, scheme_name), sanscript.HK)
    except Exception as e:
        # Best effort: romanisation is optional, so fall back to the raw text,
        # but log the failure rather than hiding it.
        print(f"Transliteration error for {lang_choice}: {e}")
        return text
def preprocess_audio(audio_path, target_sr=16000):
    """Load a recording and prepare it for ASR.

    The clip is loaded at ``target_sr``, scaled so its peak amplitude is 1,
    and trimmed of leading/trailing silence (``top_db=20``).

    Args:
        audio_path: Path of the audio file to load.
        target_sr: Sample rate to load the audio at.

    Returns:
        ``(audio, target_sr)`` on success, or ``(None, None)`` when the file
        cannot be processed or holds no usable signal.
    """
    try:
        # Load audio
        audio, _ = librosa.load(audio_path, sr=target_sr)
        if audio.size == 0:
            print("Audio preprocessing error: recording is empty")
            return None, None
        peak = float(np.max(np.abs(audio)))
        if peak == 0.0:
            # Dividing by a zero peak would fill the signal with NaNs, which
            # would then be fed to the model; report the clip as unusable.
            print("Audio preprocessing error: recording is silent")
            return None, None
        # Normalize audio
        audio = audio / peak
        # Remove silence
        audio, _ = librosa.effects.trim(audio, top_db=20)
        return audio, target_sr
    except Exception as e:
        print(f"Audio preprocessing error: {e}")
        return None, None
def transcribe_with_ai4bharat(audio_path, language, initial_prompt=""):
    """Transcribe a recording with the ASR model registered for ``language``.

    Args:
        audio_path: Path of the audio file to transcribe.
        language: Language name passed to ``load_asr_model``.
        initial_prompt: Accepted for interface compatibility.
            NOTE(review): this value is not used anywhere in the function, so
            it has no effect on the transcription.

    Returns:
        The stripped transcription, or a string starting with ``"Error:"``
        when the audio cannot be processed or transcription fails.
    """
    try:
        # Load model
        components = load_asr_model(language)
        processor, model = components["processor"], components["model"]

        # Preprocess audio
        waveform, sample_rate = preprocess_audio(audio_path)
        if waveform is None:
            return "Error: Could not process audio"

        # Prepare inputs and move every tensor onto the model's device
        features = processor(waveform, sampling_rate=sample_rate, return_tensors="pt")
        features = {name: tensor.to(DEVICE) for name, tensor in features.items()}

        # Generate transcription
        with torch.no_grad():
            token_ids = model.generate(**features, max_length=200)

        # Decode
        decoded = processor.batch_decode(token_ids, skip_special_tokens=True)
        return decoded[0].strip()
    except Exception as exc:
        print(f"Transcription error for {language}: {exc}")
        return f"Error: Transcription failed - {str(exc)}"
def synthesize_with_ai4bharat(text, language):
    """Synthesize speech for ``text`` with the TTS model registered for ``language``.

    Args:
        text: Text to speak. ``None``, empty or whitespace-only text yields
            ``None`` without loading any model.
        language: Language name passed to ``load_tts_model``.

    Returns:
        ``(sample_rate, audio)`` with the audio loaded at 22050 Hz, or
        ``None`` when there is nothing to speak or synthesis fails.
    """
    if not text or not text.strip():
        return None
    try:
        # Load TTS model
        tts = load_tts_model(language)
        # Write into a private scratch directory that is removed afterwards.
        # The previous fixed /tmp name built from hash(text) was predictable,
        # could collide between requests, and was never cleaned up.
        with tempfile.TemporaryDirectory() as scratch_dir:
            audio_path = os.path.join(scratch_dir, "tts_output.wav")
            tts.tts_to_file(text=text, file_path=audio_path)
            # Read the samples back before the directory is deleted.
            audio, sr = librosa.load(audio_path, sr=22050)
        return sr, audio
    except Exception as e:
        print(f"TTS error for {language}: {e}")
        return None
def highlight_differences(ref, hyp):
    """Render a word-level diff of ``hyp`` against ``ref`` as HTML.

    Words are obtained by whitespace splitting and aligned with
    ``difflib.SequenceMatcher``. Matching words are green, reference words
    that were replaced or dropped are red with strike-through, and hypothesis
    words that were substituted (prefixed " → ") or inserted (prefixed "+")
    are orange.

    Args:
        ref: Reference (intended) sentence.
        hyp: Hypothesis (transcribed) sentence.

    Returns:
        Space-joined ``<span>`` elements; an empty string when both inputs
        contain no words.
    """
    matched = "color:green; font-weight:bold"
    removed = "color:red; text-decoration:line-through"
    added = "color:orange; font-weight:bold"

    def span(style, word, prefix=""):
        # Transcribed text is not trusted markup: escape it so "<", ">" and
        # "&" are displayed literally instead of being parsed as HTML.
        return f"<span style='{style}'>{prefix}{html.escape(word, quote=False)}</span>"

    ref_words = ref.strip().split()
    hyp_words = hyp.strip().split()
    sm = difflib.SequenceMatcher(None, ref_words, hyp_words)
    out_html = []
    for tag, i1, i2, j1, j2 in sm.get_opcodes():
        if tag == "equal":
            out_html.extend(span(matched, w) for w in ref_words[i1:i2])
        elif tag == "replace":
            out_html.extend(span(removed, w) for w in ref_words[i1:i2])
            out_html.extend(span(added, w, " → ") for w in hyp_words[j1:j2])
        elif tag == "delete":
            out_html.extend(span(removed, w) for w in ref_words[i1:i2])
        elif tag == "insert":
            out_html.extend(span(added, w, "+") for w in hyp_words[j1:j2])
    return " ".join(out_html)
def _split_graphemes(text):
    """Split text into display units: a base character plus the marks attached to it."""
    units = []
    for ch in text:
        # Combining marks (vowel signs, virama, anusvara, ...) and the
        # zero-width joiners stay with the character they modify.
        attaches = unicodedata.category(ch).startswith("M") or ch in "\u200c\u200d"
        if units and attaches:
            units[-1] += ch
        else:
            units.append(ch)
    return units


def char_level_highlight(ref, hyp):
    """Render a character-level diff of ``hyp`` against ``ref`` as HTML.

    The comparison runs on display units (a base character together with its
    combining marks) rather than on single code points, so an Indic vowel
    sign is never wrapped in a separate ``<span>`` from its consonant.
    Matching units are green, reference units that were replaced or dropped
    are red and underlined, and units only present in ``hyp`` are orange on
    yellow. For a replacement only the reference side is shown.

    Args:
        ref: Reference (intended) text.
        hyp: Hypothesis (transcribed) text.

    Returns:
        Concatenated ``<span>`` elements; an empty string for empty inputs.
    """
    matched = "color:green"
    wrong = "color:red; text-decoration:underline; font-weight:bold"
    inserted = "color:orange; background-color:yellow"

    ref_units = _split_graphemes(ref)
    hyp_units = _split_graphemes(hyp)
    sm = difflib.SequenceMatcher(None, ref_units, hyp_units)
    out = []
    for tag, i1, i2, j1, j2 in sm.get_opcodes():
        if tag == "equal":
            style, units = matched, ref_units[i1:i2]
        elif tag in ("replace", "delete"):
            style, units = wrong, ref_units[i1:i2]
        else:  # "insert"
            style, units = inserted, hyp_units[j1:j2]
        # Escape so transcribed "<", ">" and "&" are shown, not parsed.
        out.extend(
            f"<span style='{style}'>{html.escape(u, quote=False)}</span>" for u in units
        )
    return "".join(out)
# ---------------- MAIN FUNCTION ---------------- #
def _analysis_error(message, intended_sentence=""):
    """Build the 11-value result for a failed analysis, keeping the practice sentence."""
    # Slot order matches the success return of compare_pronunciation:
    # status, transcript, target-biased transcript, romanisation, WER, CER,
    # word diff HTML, reference audio, spoken-text audio, char diff HTML,
    # practice sentence. The two audio slots must be None, not "".
    return (message, "", "", "", "", "", "", None, None, "", intended_sentence or "")


def compare_pronunciation(audio, language_choice, intended_sentence):
    """Transcribe a recording and compare it with the sentence the user meant to read.

    Args:
        audio: Path of the recorded audio file, or ``None`` when nothing was
            recorded.
        language_choice: Language name used to select primers and models.
        intended_sentence: The sentence the user was asked to read.

    Returns:
        A tuple of eleven values, in the order of the ``outputs`` list wired
        to the analyze button: status message, transcription, target-biased
        transcription, Harvard-Kyoto romanisation, WER, CER, word-level diff
        HTML, TTS audio of the intended sentence, TTS audio of the
        transcription, character-level diff HTML, and the intended sentence.
        Every failure path returns the same eleven-slot shape with the audio
        slots set to ``None``.
    """
    if audio is None or not (intended_sentence or "").strip():
        return _analysis_error("❌ No audio or intended sentence provided.", intended_sentence)
    try:
        print(f"Processing audio for {language_choice}")
        primer_weak, primer_strong = LANG_PRIMERS[language_choice]

        # Pass 1: Raw transcription
        actual_text = transcribe_with_ai4bharat(audio, language_choice, primer_weak)
        # The transcriber reports failures as text. Do not score an error
        # message as if the user had spoken it.
        if actual_text.startswith(("Error: Could not process audio", "Error: Transcription failed")):
            return _analysis_error(f"❌ {actual_text}", intended_sentence)

        # Pass 2: Target-biased transcription
        # NOTE(review): transcribe_with_ai4bharat does not use its prompt
        # argument, so this pass is expected to return the same text as pass 1.
        strict_prompt = f"{primer_strong}\nTarget: {intended_sentence}"
        corrected_text = transcribe_with_ai4bharat(audio, language_choice, strict_prompt)

        # Error metrics
        try:
            wer_val = jiwer.wer(intended_sentence, actual_text)
            cer_val = jiwer.cer(intended_sentence, actual_text)
        except Exception:
            # Metrics could not be computed; report the worst score.
            wer_val, cer_val = 1.0, 1.0

        # Transliteration
        if is_script(actual_text, language_choice):
            hk_translit = transliterate_to_hk(actual_text, language_choice)
        else:
            hk_translit = f"⚠️ Script mismatch: expected {language_choice} script"

        # Visual feedback
        diff_html = highlight_differences(intended_sentence, actual_text)
        char_html = char_level_highlight(intended_sentence, actual_text)

        # TTS synthesis
        tts_intended = synthesize_with_ai4bharat(intended_sentence, language_choice)
        tts_actual = synthesize_with_ai4bharat(actual_text, language_choice)

        # Status message
        status = f"✅ Analysis complete for {language_choice}"
        if wer_val < 0.1:
            status += " - Excellent pronunciation! 🎉"
        elif wer_val < 0.3:
            status += " - Good pronunciation! 👍"
        elif wer_val < 0.5:
            status += " - Needs improvement 📚"
        else:
            status += " - Keep practicing! 💪"

        return (
            status,
            actual_text,
            corrected_text,
            hk_translit,
            f"{wer_val:.3f}",
            f"{cer_val:.3f}",
            diff_html,
            tts_intended,
            tts_actual,
            char_html,
            intended_sentence,
        )
    except Exception as e:
        error_msg = f"❌ Error during analysis: {str(e)}"
        print(error_msg)
        return _analysis_error(error_msg, intended_sentence)
# ---------------- UI ---------------- #
def create_interface():
    """Build the Gradio Blocks UI and wire its events.

    The layout is: language dropdown and sentence generator, the practice
    sentence, the audio recorder, the analyze button, then the analysis
    results. Components are created in display order.

    Returns:
        The ``gr.Blocks`` application, not yet launched.
    """
    # The Markdown bodies below are kept flush-left, exactly as they appear
    # in the source listing.
    with gr.Blocks(title="🎙️ AI4Bharat Pronunciation Trainer", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
# 🎙️ AI4Bharat Pronunciation Trainer
Practice pronunciation in **Tamil, Malayalam, Hindi, Sanskrit & English** using state-of-the-art AI4Bharat models!
📋 **How to use:**
1. Select your target language
2. Generate a practice sentence
3. Record yourself reading it aloud
4. Get detailed feedback with error analysis
""")

        # --- Language selection and sentence generation ---
        with gr.Row():
            with gr.Column(scale=2):
                language_dropdown = gr.Dropdown(
                    choices=list(LANG_CODES.keys()),
                    value="Tamil",
                    label="🌍 Select Language",
                )
            with gr.Column(scale=1):
                generate_button = gr.Button("🎲 Generate Practice Sentence", variant="primary")

        sentence_box = gr.Textbox(
            label="📝 Practice Sentence (Read this aloud)",
            placeholder="Click 'Generate Practice Sentence' to get started...",
            interactive=False,
            lines=2,
        )

        # --- Recording ---
        with gr.Row():
            recording = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="🎤 Record Your Pronunciation",
            )

        analyze_button = gr.Button("🔍 Analyze Pronunciation", variant="primary", size="lg")
        status_box = gr.Textbox(label="📊 Analysis Status", interactive=False)

        # --- Results ---
        with gr.Row():
            with gr.Column():
                raw_transcript_box = gr.Textbox(label="🎯 What You Actually Said", interactive=False)
                wer_box = gr.Textbox(label="📈 Word Error Rate (lower = better)", interactive=False)
            with gr.Column():
                biased_transcript_box = gr.Textbox(label="🔧 Target-Biased Output", interactive=False)
                cer_box = gr.Textbox(label="📊 Character Error Rate (lower = better)", interactive=False)

        romanization_box = gr.Textbox(label="🔤 Romanization (Harvard-Kyoto)", interactive=False)

        with gr.Accordion("📝 Detailed Feedback", open=True):
            word_diff_view = gr.HTML(label="🔍 Word-Level Differences")
            char_diff_view = gr.HTML(label="🔤 Character-Level Analysis")

        with gr.Row():
            reference_audio = gr.Audio(label="🔊 Reference Pronunciation", type="numpy")
            spoken_audio = gr.Audio(label="🔊 Your Pronunciation (TTS)", type="numpy")

        gr.Markdown("""
### 🎨 Color Guide:
- 🟢 **Green**: Correctly pronounced
- 🔴 **Red**: Missing or incorrect words
- 🟠 **Orange**: Extra or substituted words
- 🟡 **Yellow background**: Inserted characters
""")

        # Event handlers
        generate_button.click(
            fn=get_random_sentence,
            inputs=[language_dropdown],
            outputs=[sentence_box],
        )

        # Order must match the tuple returned by compare_pronunciation.
        analysis_outputs = [
            status_box,
            raw_transcript_box,
            biased_transcript_box,
            romanization_box,
            wer_box,
            cer_box,
            word_diff_view,
            reference_audio,
            spoken_audio,
            char_diff_view,
            sentence_box,
        ]
        analyze_button.click(
            fn=compare_pronunciation,
            inputs=[recording, language_dropdown, sentence_box],
            outputs=analysis_outputs,
        )

        # Auto-generate sentence on language change
        language_dropdown.change(
            fn=get_random_sentence,
            inputs=[language_dropdown],
            outputs=[sentence_box],
        )
    return demo
# ---------------- LAUNCH ---------------- #
if __name__ == "__main__":
    print("🚀 Starting AI4Bharat Pronunciation Trainer...")

    # Pre-load English models for faster startup. A failure here is only a
    # warning: the models are loaded on demand later.
    print("📦 Pre-loading English models...")
    try:
        load_asr_model("English")
        load_tts_model("English")
        print("✅ English models loaded successfully")
    except Exception as exc:
        print(f"⚠️ Warning: Could not pre-load English models: {exc}")

    demo = create_interface()
    launch_options = {
        "share": True,
        "show_error": True,
        "server_name": "0.0.0.0",  # listen on all interfaces
        "server_port": 7860,
    }
    demo.launch(**launch_options)