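# Page residue from the Hugging Face file viewer, kept as comments so the module parses:
# sudhanm's picture | Update app.py | 5a75be5 verified | raw | history blame | 18.4 kB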
import difflib
import html
import os
import random
import re
import tempfile
import warnings

import gradio as gr
import jiwer
import torch
import torchaudio
import numpy as np
from transformers import (
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    AutoTokenizer,
    AutoModel
)
from TTS.api import TTS
import librosa
import soundfile as sf
from indic_transliteration import sanscript
from indic_transliteration.sanscript import transliterate
warnings.filterwarnings("ignore")
# ---------------- CONFIG ---------------- #
# Run models on the GPU when torch reports that CUDA is available, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Display name -> two-letter language code.
# The keys of this dict populate the language dropdown in the UI.
LANG_CODES = {
    "English": "en",
    "Tamil": "ta",
    "Malayalam": "ml",
    "Hindi": "hi",
    "Sanskrit": "sa"
}
# AI4Bharat model configurations
# Language name -> model identifier that load_asr_model hands to
# AutoProcessor.from_pretrained / AutoModelForSpeechSeq2Seq.from_pretrained.
# NOTE(review): confirm that each identifier is actually published on the model hub.
ASR_MODELS = {
    "English": "openai/whisper-base.en",
    "Tamil": "ai4bharat/whisper-medium-ta",
    "Malayalam": "ai4bharat/whisper-medium-ml",
    "Hindi": "ai4bharat/whisper-medium-hi",
    "Sanskrit": "ai4bharat/whisper-medium-hi"  # Fallback to Hindi for Sanskrit
}
# Language name -> model name that load_tts_model passes to TTS(model_name=...).
# NOTE(review): confirm these names exist in the installed TTS package's model catalogue.
TTS_MODELS = {
    "English": "tts_models/en/ljspeech/tacotron2-DDC",
    "Tamil": "tts_models/ta/mai/tacotron2-DDC",
    "Malayalam": "tts_models/ml/mai/tacotron2-DDC",
    "Hindi": "tts_models/hi/mai/tacotron2-DDC",
    "Sanskrit": "tts_models/hi/mai/tacotron2-DDC"  # Fallback to Hindi
}
# Language name -> (weak primer, strong primer).
# compare_pronunciation passes the weak primer to its first transcription pass and
# embeds the strong primer, followed by the target sentence, in the prompt of its
# second pass.
LANG_PRIMERS = {
    "English": ("Transcribe in English.",
                "Write only in English. Example: This is an English sentence."),
    "Tamil": ("தமிழில் எழுதுக.",
              "தமிழ் எழுத்துக்களில் மட்டும் எழுதவும். உதாரணம்: இது ஒரு தமிழ் வாக்கியம்."),
    "Malayalam": ("മലയാളത്തിൽ എഴുതുക.",
                  "മലയാള ലിപിയിൽ മാത്രം എഴുതുക. ഉദാഹരണം: ഇതൊരു മലയാള വാക്യമാണ്."),
    "Hindi": ("हिंदी में लिखें।",
              "केवल देवनागरी लिपि में लिखें। उदाहरण: यह एक हिंदी वाक्य है।"),
    "Sanskrit": ("संस्कृते लिखत।",
                 "देवनागरी लिपि में लिखें। उदाहरण: अहं संस्कृतं जानामि।")
}
# Language name -> compiled regex matching a single character of that language's script.
# is_script() uses search(), so one matching character anywhere in the text is enough.
# Hindi and Sanskrit share the same (Devanagari) character range.
# NOTE(review): the ranges appear to span the whole Tamil, Malayalam and Devanagari
# Unicode blocks - confirm the literal endpoints survived copy/paste intact.
SCRIPT_PATTERNS = {
    "Tamil": re.compile(r"[஀-௿]"),
    "Malayalam": re.compile(r"[ഀ-ൿ]"),
    "Hindi": re.compile(r"[ऀ-ॿ]"),
    "Sanskrit": re.compile(r"[ऀ-ॿ]"),
    "English": re.compile(r"[A-Za-z]")
}
# Language name -> list of practice sentences.
# get_random_sentence() draws one entry at random from the list of the chosen language.
SENTENCE_BANK = {
    "English": [
        "The sun sets over the beautiful horizon.",
        "Learning new languages opens many doors.",
        "I enjoy reading books in the evening.",
        "Technology has changed our daily lives.",
        "Music brings people together across cultures."
    ],
    "Tamil": [
        "இன்று நல்ல வானிலை உள்ளது.",
        "நான் தமிழ் கற்றுக்கொண்டு இருக்கிறேன்.",
        "எனக்கு புத்தகம் படிக்க விருப்பம்.",
        "தமிழ் மொழி மிகவும் அழகானது.",
        "குடும்பத்துடன் நேரம் செலவிடுவது முக்கியம்."
    ],
    "Malayalam": [
        "എനിക്ക് മലയാളം വളരെ ഇഷ്ടമാണ്.",
        "ഇന്ന് മഴപെയ്യുന്നു.",
        "ഞാൻ പുസ്തകം വായിക്കുന്നു.",
        "കേരളത്തിന്റെ പ്രകൃതി സുന്ദരമാണ്.",
        "വിദ്യാഭ്യാസം ജീവിതത്തിൽ പ്രധാനമാണ്."
    ],
    "Hindi": [
        "आज मौसम बहुत अच्छा है।",
        "मुझे हिंदी बोलना पसंद है।",
        "मैं रोज किताब पढ़ता हूँ।",
        "भारत की संस्कृति विविधतापूर्ण है।",
        "शिक्षा हमारे भविष्य की कुंजी है।"
    ],
    "Sanskrit": [
        "अहं ग्रन्थं पठामि।",
        "अद्य सूर्यः तेजस्वी अस्ति।",
        "मम नाम रामः।",
        "विद्या सर्वत्र पूज्यते।",
        "सत्यमेव जयते।"
    ]
}
# ---------------- MODEL CACHE ---------------- #
# Loaded ASR components keyed by language name; load_asr_model stores a
# {"processor": ..., "model": ...} dict per language here.
asr_models = {}
# Loaded TTS engines keyed by language name; filled in by load_tts_model.
tts_models = {}
def load_asr_model(language):
    """Return the cached ASR components for ``language``, loading them on first use.

    The result is a dict with the keys ``"processor"`` and ``"model"``. If the
    model configured for ``language`` cannot be loaded, the English model is
    loaded instead and cached under ``language`` as well, so the failing load
    is not retried on later calls.

    Args:
        language: A language name used as key in ``ASR_MODELS``.

    Returns:
        dict: ``{"processor": ..., "model": ...}`` for the language (or for
        the English fallback).

    Raises:
        RuntimeError: If the English model itself cannot be loaded, i.e. there
            is nothing left to fall back to. The original error is chained.
    """
    if language in asr_models:
        return asr_models[language]

    try:
        model_name = ASR_MODELS[language]
        print(f"Loading ASR model for {language}: {model_name}")
        processor = AutoProcessor.from_pretrained(model_name)
        model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name).to(DEVICE)
        asr_models[language] = {"processor": processor, "model": model}
        print(f"✅ ASR model loaded for {language}")
    except Exception as e:
        print(f"❌ Failed to load ASR for {language}: {e}")
        if language == "English":
            # No further fallback exists: surface the real cause instead of the
            # KeyError a cache lookup on the missing entry would produce.
            raise RuntimeError(f"Could not load English ASR model: {e}") from e
        # Fallback to English model
        print(f"🔄 Falling back to English ASR for {language}")
        asr_models[language] = load_asr_model("English")
    return asr_models[language]
def load_tts_model(language):
    """Return the cached TTS engine for ``language``, loading it on first use.

    If the model configured for ``language`` cannot be loaded, the English
    engine is loaded instead and cached under ``language`` as well, so the
    failing load is not retried on later calls.

    Args:
        language: A language name used as key in ``TTS_MODELS``.

    Returns:
        The ``TTS`` instance for the language (or the English fallback).

    Raises:
        RuntimeError: If the English engine itself cannot be loaded, i.e.
            there is nothing left to fall back to. The original error is chained.
    """
    if language in tts_models:
        return tts_models[language]

    try:
        model_name = TTS_MODELS[language]
        print(f"Loading TTS model for {language}: {model_name}")
        tts = TTS(model_name=model_name).to(DEVICE)
        tts_models[language] = tts
        print(f"✅ TTS model loaded for {language}")
    except Exception as e:
        print(f"❌ Failed to load TTS for {language}: {e}")
        if language == "English":
            # No further fallback exists: surface the real cause instead of the
            # KeyError a cache lookup on the missing entry would produce.
            raise RuntimeError(f"Could not load English TTS model: {e}") from e
        # Fallback to English
        print(f"🔄 Falling back to English TTS for {language}")
        tts_models[language] = load_tts_model("English")
    return tts_models[language]
# ---------------- HELPERS ---------------- #
def get_random_sentence(language_choice):
    """Pick one practice sentence at random from the bank of ``language_choice``."""
    candidates = SENTENCE_BANK[language_choice]
    return random.choice(candidates)
def is_script(text, lang_name):
    """Tell whether ``text`` contains at least one character of the expected script.

    Languages without an entry in ``SCRIPT_PATTERNS`` are always accepted.
    """
    script_re = SCRIPT_PATTERNS.get(lang_name)
    if not script_re:
        # Nothing to check against: treat the text as valid.
        return True
    return script_re.search(text) is not None
def transliterate_to_hk(text, lang_choice):
    """Romanize Indic text into the Harvard-Kyoto scheme.

    Args:
        text: The text to romanize.
        lang_choice: Language name selecting the source script.

    Returns:
        str: The Harvard-Kyoto transliteration, or ``text`` unchanged when the
        language has no source script (English, unknown names), when ``text``
        does not contain the expected script, or when transliteration fails.
    """
    mapping = {
        "Tamil": sanscript.TAMIL,
        "Malayalam": sanscript.MALAYALAM,
        "Hindi": sanscript.DEVANAGARI,
        "Sanskrit": sanscript.DEVANAGARI,
        "English": None
    }
    script = mapping.get(lang_choice)
    if not script or not is_script(text, lang_choice):
        return text
    try:
        return transliterate(text, script, sanscript.HK)
    except Exception as e:
        # Best effort: romanization is only a display aid, so fall back to the
        # raw text, but do not swallow KeyboardInterrupt/SystemExit as a bare
        # ``except:`` would, and leave a trace of what went wrong.
        print(f"Transliteration error for {lang_choice}: {e}")
        return text
def preprocess_audio(audio_path, target_sr=16000):
    """Load an audio file and prepare it for ASR.

    The audio is resampled to ``target_sr`` by ``librosa.load``, peak-normalized
    and trimmed of leading/trailing silence (``top_db=20``).

    Args:
        audio_path: Path of the audio file to read.
        target_sr: Sampling rate to resample to.

    Returns:
        tuple: ``(audio, target_sr)`` on success, or ``(None, None)`` when the
        file cannot be read, is empty, or contains nothing but silence.
    """
    try:
        # Load audio
        audio, _ = librosa.load(audio_path, sr=target_sr)
        if audio.size == 0:
            print("Audio preprocessing error: audio is empty")
            return None, None
        # Normalize audio; skip it for an all-zero signal, where dividing by
        # the peak would be 0/0 and fill the array with NaNs.
        peak = np.max(np.abs(audio))
        if peak > 0:
            audio = audio / peak
        # Remove silence
        audio, _ = librosa.effects.trim(audio, top_db=20)
        if audio.size == 0:
            print("Audio preprocessing error: no non-silent audio found")
            return None, None
        return audio, target_sr
    except Exception as e:
        print(f"Audio preprocessing error: {e}")
        return None, None
def transcribe_with_ai4bharat(audio_path, language, initial_prompt=""):
    """Run speech recognition on an audio file with the model of ``language``.

    Args:
        audio_path: Path of the recording to transcribe.
        language: Language name used to select the ASR model.
        initial_prompt: Accepted for callers, but not used by this function:
            it is never forwarded to the processor or to ``generate``.

    Returns:
        str: The stripped transcription. Failures are reported in-band as a
        string starting with ``"Error:"`` rather than raised.
    """
    try:
        # Fetch (or lazily load) the processor/model pair for this language.
        components = load_asr_model(language)
        asr_processor, asr_model = components["processor"], components["model"]

        # Turn the file into a normalized, trimmed waveform.
        waveform, sample_rate = preprocess_audio(audio_path)
        if waveform is None:
            return "Error: Could not process audio"

        # Extract model features and move every tensor to the target device.
        features = asr_processor(waveform, sampling_rate=sample_rate, return_tensors="pt")
        model_inputs = {name: tensor.to(DEVICE) for name, tensor in features.items()}

        # Inference only, so gradient tracking is switched off.
        with torch.no_grad():
            token_ids = asr_model.generate(**model_inputs, max_length=200)

        # Only one recording was submitted, so take the first decoded entry.
        decoded = asr_processor.batch_decode(token_ids, skip_special_tokens=True)
        return decoded[0].strip()
    except Exception as err:
        print(f"Transcription error for {language}: {err}")
        return f"Error: Transcription failed - {str(err)}"
def synthesize_with_ai4bharat(text, language):
    """Synthesize ``text`` with the TTS engine of ``language``.

    The engine writes a WAV file, which is read back at 22050 Hz and then
    deleted, so no temporary files accumulate between requests.

    Args:
        text: The text to speak. ``None``, empty or whitespace-only text
            yields ``None``.
        language: Language name used to select the TTS engine.

    Returns:
        tuple | None: ``(sample_rate, audio)`` on success, ``None`` when there
        is nothing to synthesize or synthesis fails.
    """
    if not text or not text.strip():
        return None
    audio_path = None
    try:
        # Load TTS model
        tts = load_tts_model(language)
        # Use a unique temporary file: a name derived from hash(text) is
        # predictable and is shared by concurrent requests for the same text.
        fd, audio_path = tempfile.mkstemp(prefix="tts_output_", suffix=".wav")
        os.close(fd)  # the TTS engine reopens the path itself
        # Generate audio
        tts.tts_to_file(text=text, file_path=audio_path)
        # Load generated audio
        audio, sr = librosa.load(audio_path, sr=22050)
        return sr, audio
    except Exception as e:
        print(f"TTS error for {language}: {e}")
        return None
    finally:
        if audio_path is not None:
            try:
                os.remove(audio_path)
            except OSError:
                # Cleanup is best effort; a leftover temp file must not turn
                # a successful synthesis into a failure.
                pass
def highlight_differences(ref, hyp):
    """Render a word-level diff of ``hyp`` against ``ref`` as HTML.

    Both strings are split on whitespace and aligned with
    ``difflib.SequenceMatcher``. Matching words are green, reference words that
    were replaced or dropped are red and struck through, and words that only
    occur in ``hyp`` are orange (prefixed with `` → `` for a substitution and
    ``+`` for an insertion).

    Args:
        ref: The reference (intended) sentence.
        hyp: The hypothesis (recognized) sentence.

    Returns:
        str: ``<span>`` elements joined by single spaces; ``""`` when both
        inputs contain no words.
    """
    ref_words = ref.strip().split()
    hyp_words = hyp.strip().split()
    sm = difflib.SequenceMatcher(None, ref_words, hyp_words)

    match_style = "color:green; font-weight:bold"
    removed_style = "color:red; text-decoration:line-through"
    added_style = "color:orange; font-weight:bold"

    def span(style, word, prefix=""):
        # The words come from user-editable text and ASR output, so escape
        # them before they are embedded in markup rendered by the browser.
        return f"<span style='{style}'>{prefix}{html.escape(word, quote=False)}</span>"

    out_html = []
    for tag, i1, i2, j1, j2 in sm.get_opcodes():
        if tag == 'equal':
            out_html.extend(span(match_style, w) for w in ref_words[i1:i2])
        elif tag == 'replace':
            out_html.extend(span(removed_style, w) for w in ref_words[i1:i2])
            out_html.extend(span(added_style, w, " → ") for w in hyp_words[j1:j2])
        elif tag == 'delete':
            out_html.extend(span(removed_style, w) for w in ref_words[i1:i2])
        elif tag == 'insert':
            out_html.extend(span(added_style, w, "+") for w in hyp_words[j1:j2])
    return " ".join(out_html)
def char_level_highlight(ref, hyp):
    """Render a character-level diff of ``hyp`` against ``ref`` as HTML.

    Characters are aligned with ``difflib.SequenceMatcher``. Matching
    characters are green, reference characters that were replaced or dropped
    are red, bold and underlined, and characters that only occur in ``hyp``
    are orange on a yellow background. For a replacement only the reference
    side is shown.

    Args:
        ref: The reference (intended) sentence.
        hyp: The hypothesis (recognized) sentence.

    Returns:
        str: The concatenated ``<span>`` elements, one per character.
    """
    sm = difflib.SequenceMatcher(None, list(ref), list(hyp))

    def span(style, char):
        # Escape each character: the text comes from user input and ASR
        # output and ends up in markup rendered by the browser.
        return f"<span style='{style}'>{html.escape(char, quote=False)}</span>"

    out = []
    for tag, i1, i2, j1, j2 in sm.get_opcodes():
        if tag == 'equal':
            out.extend(span("color:green", c) for c in ref[i1:i2])
        elif tag in ('replace', 'delete'):
            out.extend(
                span("color:red; text-decoration:underline; font-weight:bold", c)
                for c in ref[i1:i2]
            )
        elif tag == 'insert':
            out.extend(span("color:orange; background-color:yellow", c) for c in hyp[j1:j2])
    return "".join(out)
# ---------------- MAIN FUNCTION ---------------- #
def _failure_outputs(message, intended_sentence):
    """Build the 11-value result used when no analysis could be produced.

    The positions mirror the success result of ``compare_pronunciation``, so
    every bound UI output receives a value of the kind it expects.
    """
    return (
        message,                   # status
        "",                        # pass 1 transcription
        "",                        # pass 2 transcription
        "",                        # romanization
        "",                        # word error rate
        "",                        # character error rate
        "",                        # word-level diff HTML
        None,                      # reference TTS audio
        None,                      # TTS audio of the recognized text
        "",                        # character-level diff HTML
        intended_sentence or "",   # keep the practice sentence on screen
    )


def compare_pronunciation(audio, language_choice, intended_sentence):
    """Transcribe a recording and compare it with the intended sentence.

    Args:
        audio: Path of the recorded/uploaded audio file, or ``None``.
        language_choice: Language name; selects primers, models and script.
        intended_sentence: The sentence the user was asked to read.

    Returns:
        tuple: Always 11 values, in this order: status message, pass 1
        transcription, pass 2 transcription, romanization (or a script
        mismatch warning), WER, CER (both formatted with three decimals),
        word-level diff HTML, reference TTS audio, TTS audio of the recognized
        text, character-level diff HTML, and the intended sentence. On any
        failure only the status and the intended sentence are filled in.
    """
    if audio is None or not intended_sentence or not intended_sentence.strip():
        return _failure_outputs("❌ No audio or intended sentence provided.", intended_sentence)
    try:
        print(f"Processing audio for {language_choice}")
        primer_weak, primer_strong = LANG_PRIMERS[language_choice]
        # Pass 1: Raw transcription
        actual_text = transcribe_with_ai4bharat(audio, language_choice, primer_weak)
        if actual_text.startswith("Error:"):
            # The transcriber reports failures in-band; scoring or speaking
            # that message as if the user had said it would be misleading.
            return _failure_outputs(f"❌ {actual_text}", intended_sentence)
        # Pass 2: Target-biased transcription
        # NOTE(review): whether the prompt influences decoding depends on how
        # transcribe_with_ai4bharat treats its initial_prompt argument.
        strict_prompt = f"{primer_strong}\nTarget: {intended_sentence}"
        corrected_text = transcribe_with_ai4bharat(audio, language_choice, strict_prompt)
        # Error metrics
        try:
            wer_val = jiwer.wer(intended_sentence, actual_text)
            cer_val = jiwer.cer(intended_sentence, actual_text)
        except Exception:
            # Best effort: if the metrics cannot be computed, report the worst score.
            wer_val, cer_val = 1.0, 1.0
        # Transliteration
        if is_script(actual_text, language_choice):
            hk_translit = transliterate_to_hk(actual_text, language_choice)
        else:
            hk_translit = f"⚠️ Script mismatch: expected {language_choice} script"
        # Visual feedback
        diff_html = highlight_differences(intended_sentence, actual_text)
        char_html = char_level_highlight(intended_sentence, actual_text)
        # TTS synthesis
        tts_intended = synthesize_with_ai4bharat(intended_sentence, language_choice)
        tts_actual = synthesize_with_ai4bharat(actual_text, language_choice)
        # Status message
        status = f"✅ Analysis complete for {language_choice}"
        if wer_val < 0.1:
            status += " - Excellent pronunciation! 🎉"
        elif wer_val < 0.3:
            status += " - Good pronunciation! 👍"
        elif wer_val < 0.5:
            status += " - Needs improvement 📚"
        else:
            status += " - Keep practicing! 💪"
        return (
            status,
            actual_text,
            corrected_text,
            hk_translit,
            f"{wer_val:.3f}",
            f"{cer_val:.3f}",
            diff_html,
            tts_intended,
            tts_actual,
            char_html,
            intended_sentence
        )
    except Exception as e:
        error_msg = f"❌ Error during analysis: {str(e)}"
        print(error_msg)
        return _failure_outputs(error_msg, intended_sentence)
# ---------------- UI ---------------- #
def create_interface() -> gr.Blocks:
    """Build the Gradio Blocks app and wire up its event handlers.

    Returns:
        The assembled ``gr.Blocks`` instance; the caller is responsible for
        launching it.
    """
    with gr.Blocks(title="🎙️ AI4Bharat Pronunciation Trainer", theme=gr.themes.Soft()) as demo:
        # Introductory text. The Markdown body starts at column 0 on purpose:
        # everything between the triple quotes is part of the rendered string.
        gr.Markdown("""
# 🎙️ AI4Bharat Pronunciation Trainer
Practice pronunciation in **Tamil, Malayalam, Hindi, Sanskrit & English** using state-of-the-art AI4Bharat models!
📋 **How to use:**
1. Select your target language
2. Generate a practice sentence
3. Record yourself reading it aloud
4. Get detailed feedback with error analysis
""")
        # Language selector next to the sentence generator button.
        with gr.Row():
            with gr.Column(scale=2):
                lang_choice = gr.Dropdown(
                    choices=list(LANG_CODES.keys()),
                    value="Tamil",
                    label="🌍 Select Language"
                )
            with gr.Column(scale=1):
                gen_btn = gr.Button("🎲 Generate Practice Sentence", variant="primary")
        # Read-only: the sentence is filled in by the handlers below, not typed by the user.
        intended_display = gr.Textbox(
            label="📝 Practice Sentence (Read this aloud)",
            placeholder="Click 'Generate Practice Sentence' to get started...",
            interactive=False,
            lines=2
        )
        with gr.Row():
            # type="filepath": the handler receives a path to the recording.
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="🎤 Record Your Pronunciation"
            )
        analyze_btn = gr.Button("🔍 Analyze Pronunciation", variant="primary", size="lg")
        status_output = gr.Textbox(label="📊 Analysis Status", interactive=False)
        # Two result columns: each pairs a transcription with an error rate.
        with gr.Row():
            with gr.Column():
                pass1_out = gr.Textbox(label="🎯 What You Actually Said", interactive=False)
                wer_out = gr.Textbox(label="📈 Word Error Rate (lower = better)", interactive=False)
            with gr.Column():
                pass2_out = gr.Textbox(label="🔧 Target-Biased Output", interactive=False)
                cer_out = gr.Textbox(label="📊 Character Error Rate (lower = better)", interactive=False)
        hk_out = gr.Textbox(label="🔤 Romanization (Harvard-Kyoto)", interactive=False)
        with gr.Accordion("📝 Detailed Feedback", open=True):
            diff_html_box = gr.HTML(label="🔍 Word-Level Differences")
            char_html_box = gr.HTML(label="🔤 Character-Level Analysis")
        # Synthesized audio of the intended sentence and of the recognized text.
        with gr.Row():
            intended_tts_audio = gr.Audio(label="🔊 Reference Pronunciation", type="numpy")
            actual_tts_audio = gr.Audio(label="🔊 Your Pronunciation (TTS)", type="numpy")
        gr.Markdown("""
### 🎨 Color Guide:
- 🟢 **Green**: Correctly pronounced
- 🔴 **Red**: Missing or incorrect words
- 🟠 **Orange**: Extra or substituted words
- 🟡 **Yellow background**: Inserted characters
""")
        # Event handlers
        gen_btn.click(
            fn=get_random_sentence,
            inputs=[lang_choice],
            outputs=[intended_display]
        )
        # compare_pronunciation has to return one value per component listed
        # in ``outputs``, in this order (11 values).
        analyze_btn.click(
            fn=compare_pronunciation,
            inputs=[audio_input, lang_choice, intended_display],
            outputs=[
                status_output, pass1_out, pass2_out, hk_out,
                wer_out, cer_out, diff_html_box,
                intended_tts_audio, actual_tts_audio,
                char_html_box, intended_display
            ]
        )
        # Auto-generate sentence on language change
        lang_choice.change(
            fn=get_random_sentence,
            inputs=[lang_choice],
            outputs=[intended_display]
        )
    return demo
# ---------------- LAUNCH ---------------- #
if __name__ == "__main__":
    print("🚀 Starting AI4Bharat Pronunciation Trainer...")
    # Warm up the English models before the UI starts; a failure here is only
    # reported, the app is launched regardless.
    print("📦 Pre-loading English models...")
    try:
        for warm_up in (load_asr_model, load_tts_model):
            warm_up("English")
        print("✅ English models loaded successfully")
    except Exception as preload_error:
        print(f"⚠️ Warning: Could not pre-load English models: {preload_error}")

    # Server settings: public share link, errors shown in the UI, listening on
    # all interfaces at port 7860.
    launch_options = {
        "share": True,
        "show_error": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    demo = create_interface()
    demo.launch(**launch_options)