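# NOTE: the following lines appear to be web-page residue captured together
# with this file (a file-viewer header); they are kept as comments so the
# module can be parsed as Python.
# sudhanm's picture
# Update app.py
# a950033 verified
# raw
# history blame
# 23.9 kB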
import gradio as gr
import random
import difflib
import re
import jiwer
import torch
import numpy as np
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
import librosa
import soundfile as sf
from indic_transliteration import sanscript
from indic_transliteration.sanscript import transliterate
import unicodedata
import warnings
import spaces
# Suppress every Python warning for the whole process.
warnings.filterwarnings("ignore")
# ---------------- CONFIG ---------------- #
# Run models on the GPU when PyTorch reports CUDA as available, else on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🔧 Using device: {DEVICE}")
# Value passed as `device=` to transformers.pipeline(): 0 = first CUDA device, -1 = CPU.
DEVICE_INDEX = 0 if DEVICE == "cuda" else -1
# UI language name -> Whisper language code passed to get_decoder_prompt_ids().
LANG_CODES = {
    "English": "en",
    "Tamil": "ta",
    "Malayalam": "ml"
}
# Primary multilingual checkpoint loaded by load_indicwhisper(). Despite the
# variable name, this points at the stock OpenAI Whisper large-v2 model.
INDICWHISPER_MODEL = "openai/whisper-large-v2"
# Per-language checkpoints loaded by load_specialized_model() for the second pass.
SPECIALIZED_MODELS = {
    "English": "openai/whisper-base.en",
    "Tamil": "vasista22/whisper-tamil-large-v2",
    "Malayalam": "thennal/whisper-medium-ml"
}
# Per-language (short, long) instruction/example strings.
# NOTE(review): not referenced anywhere in the visible code (transcribe_audio's
# initial_prompt parameter is also unused) — presumably intended as Whisper
# prompts; confirm before relying on it.
LANG_PRIMERS = {
    "English": ("Transcribe in English.",
                "Write only in English. Example: This is an English sentence."),
    "Tamil": ("தமிழில் எழுதுக.",
              "தமிழ் எழுத்துக்களில் மட்டும் எழுதவும். உதாரணம்: இது ஒரு தமிழ் வாக்கியம்."),
    "Malayalam": ("മലയാളത്തിൽ എഴുതുക.",
                  "മലയാള ലിപിയിൽ മാത്രം എഴുതുക. ഉദാഹരണം: ഇതൊരു മലയാള വാക്യമാണ്.")
}
# Single-character patterns used by is_script() to detect a language's script:
# Tamil block U+0B80-U+0BFF, Malayalam block U+0D00-U+0D7F, ASCII Latin letters.
SCRIPT_PATTERNS = {
    "Tamil": re.compile(r"[஀-௿]"),
    "Malayalam": re.compile(r"[ഀ-ൿ]"),
    "English": re.compile(r"[A-Za-z]")
}
# Practice sentences per UI language; get_random_sentence() picks one at random.
SENTENCE_BANK = {
    "English": [
        "The sun sets over the beautiful horizon.",
        "Learning new languages opens many doors.",
        "I enjoy reading books in the evening.",
        "Technology has changed our daily lives.",
        "Music brings people together across cultures.",
        "Education is the key to a bright future.",
        "The flowers bloom beautifully in spring.",
        "Hard work always pays off in the end."
    ],
    "Tamil": [
        "இன்று நல்ல வானிலை உள்ளது.",
        "நான் தமிழ் கற்றுக்கொண்டு இருக்கிறேன்.",
        "எனக்கு புத்தகம் படிக்க விருப்பம்.",
        "தமிழ் மொழி மிகவும் அழகானது.",
        "குடும்பத்துடன் நேரம் செலவிடுவது முக்கியம்.",
        "கல்வி நமது எதிர்காலத்தின் திறவுகோல்.",
        "பறவைகள் காலையில் இனிமையாக பாடுகின்றன.",
        "உழைப்பு எப்போதும் வெற்றியைத் தரும்."
    ],
    "Malayalam": [
        "എനിക്ക് മലയാളം വളരെ ഇഷ്ടമാണ്.",
        "ഇന്ന് മഴപെയ്യുന്നു.",
        "ഞാൻ പുസ്തകം വായിക്കുന്നു.",
        "കേരളത്തിന്റെ പ്രകൃതി സുന്ദരമാണ്.",
        "വിദ്യാഭ്യാസം ജീവിതത്തിൽ പ്രധാനമാണ്.",
        "സംഗീതം മനസ്സിന് സന്തോഷം നൽകുന്നു.",
        "കുടുംബസമയം വളരെ വിലപ്പെട്ടതാണ്.",
        "കഠിനാധ്വാനം എപ്പോഴും ഫലം നൽകും."
    ]
}
# Controls for stricter script checking and normalization
# Read by is_script(). When True, every non-whitespace character must belong to
# the target script, so even the trailing "." of the sentences above fails.
STRICT_SCRIPT_CHECK = False # set True for strict script-only validation
# Read by normalize_text(): when True, text is NFC-normalized and whitespace is
# collapsed before WER/CER are computed.
NORMALIZE_TEXT_FOR_METRICS = True
# ---------------- MODEL CACHE ---------------- #
# Lazily populated by load_indicwhisper(); None until the first successful load.
indicwhisper_pipeline = None
# language -> {"processor", "model", "model_name"}; filled by load_specialized_model().
fallback_models = {}
# Set by load_indicwhisper() to record whether the JAX backend loaded; not read
# elsewhere in the visible code.
WHISPER_JAX_AVAILABLE = False
def normalize_text(s: str) -> str:
    """Canonicalize *s* before it is scored with WER/CER.

    When NORMALIZE_TEXT_FOR_METRICS is disabled the input is returned as-is.
    Otherwise the text is recomposed to Unicode NFC, every whitespace run is
    collapsed to a single space, and leading/trailing whitespace is removed.
    Punctuation, including language-specific marks, is intentionally kept.
    """
    if NORMALIZE_TEXT_FOR_METRICS:
        composed = unicodedata.normalize("NFC", s)
        return re.sub(r"\s+", " ", composed).strip()
    return s
def get_random_sentence(language_choice):
    """Return one practice sentence for *language_choice*, chosen uniformly at random.

    Raises:
        KeyError: if the language has no entry in SENTENCE_BANK.
    """
    candidates = SENTENCE_BANK[language_choice]
    return random.choice(candidates)
def is_script(text, lang_name):
    """Report whether *text* is written in the script of *lang_name*.

    Languages with no entry in SCRIPT_PATTERNS always pass.  In the default
    lenient mode (STRICT_SCRIPT_CHECK is False) one character of the target
    script is enough.  In strict mode every non-whitespace character must be
    in the target script, so any punctuation or digit makes the check fail
    (and empty text passes vacuously).
    """
    script_re = SCRIPT_PATTERNS.get(lang_name)
    if not script_re:
        return True
    if STRICT_SCRIPT_CHECK:
        return all(ch.isspace() or script_re.match(ch) for ch in text)
    return bool(script_re.search(text))
def transliterate_to_hk(text, lang_choice):
    """Romanize Tamil or Malayalam *text* into the Harvard-Kyoto scheme.

    The input is returned unchanged when the language has no Indic source
    scheme (e.g. English), when is_script() rejects the text for that
    language, or when the transliteration call raises (the error is printed).
    """
    source_scheme = {
        "Tamil": sanscript.TAMIL,
        "Malayalam": sanscript.MALAYALAM,
    }.get(lang_choice)
    # Short-circuit keeps is_script() from running for languages without a scheme.
    if not source_scheme or not is_script(text, lang_choice):
        return text
    try:
        return transliterate(text, source_scheme, sanscript.HK)
    except Exception as e:
        print(f"Transliteration error: {e}")
        return text
def preprocess_audio(audio_path, target_sr=16000):
    """Load an audio file, peak-normalize it and trim leading/trailing silence.

    Args:
        audio_path: Path of the audio file to read.
        target_sr: Sample rate the audio is resampled to on load.

    Returns:
        ``(samples, target_sr)`` on success, or ``(None, None)`` when the
        trimmed clip is shorter than 0.1 s or any step raises (the error is
        printed, not propagated).
    """
    try:
        samples, _ = librosa.load(audio_path, sr=target_sr)
        peak = np.max(np.abs(samples))
        # Skip scaling for an all-zero signal to avoid dividing by zero.
        if peak > 0:
            samples = samples / peak
        samples, _ = librosa.effects.trim(samples, top_db=20)
        min_samples = target_sr * 0.1
        if len(samples) < min_samples:
            return None, None
        return samples, target_sr
    except Exception as e:
        print(f"Audio preprocessing error: {e}")
        return None, None
@spaces.GPU
def load_indicwhisper():
    """Return the process-wide primary ASR pipeline, loading it on first use.

    Loading order:
      1. whisper_jax ``FlaxWhisperPipeline`` (bfloat16, batch size 1), if the
         optional JAX packages import and the pipeline constructs;
      2. otherwise a transformers "automatic-speech-recognition" pipeline on
         DEVICE_INDEX (float16 on CUDA, float32 on CPU).

    The loaded object is cached in the module-level ``indicwhisper_pipeline``;
    ``WHISPER_JAX_AVAILABLE`` records whether the JAX path succeeded.

    Raises:
        Exception: if the transformers fallback cannot be loaded.
    """
    global indicwhisper_pipeline, WHISPER_JAX_AVAILABLE
    if indicwhisper_pipeline is not None:
        return indicwhisper_pipeline
    try:
        # Optional accelerated backend: any failure (missing package, load error)
        # falls through to the transformers pipeline below.
        try:
            from whisper_jax import FlaxWhisperPipeline
            import jax.numpy as jnp
            print(f"🔄 Loading JAX-optimized model: {INDICWHISPER_MODEL}")
            indicwhisper_pipeline = FlaxWhisperPipeline(
                INDICWHISPER_MODEL,
                dtype=jnp.bfloat16,
                batch_size=1,
            )
            WHISPER_JAX_AVAILABLE = True
            print("✅ JAX-optimized model loaded successfully!")
            return indicwhisper_pipeline
        except Exception as jax_err:
            print(f"⚠️ JAX loading failed: {jax_err}")
            WHISPER_JAX_AVAILABLE = False
        print(f"🔄 Loading transformers pipeline: {INDICWHISPER_MODEL}")
        from transformers import pipeline
        use_fp16 = DEVICE == "cuda"
        indicwhisper_pipeline = pipeline(
            "automatic-speech-recognition",
            model=INDICWHISPER_MODEL,
            device=DEVICE_INDEX,
            torch_dtype=torch.float16 if use_fp16 else torch.float32,
        )
        print("✅ High-performance model loaded with transformers!")
    except Exception as e:
        print(f"❌ Failed to load primary model: {e}")
        indicwhisper_pipeline = None
        raise Exception(f"Could not load high-performance model: {str(e)}")
    return indicwhisper_pipeline
@spaces.GPU
def load_specialized_model(language):
    """Return the fine-tuned Whisper components for *language*, loading them once.

    The processor and model for ``SPECIALIZED_MODELS[language]`` are loaded,
    moved to DEVICE (float16 on CUDA, float32 otherwise), switched to eval
    mode and cached in ``fallback_models``.

    Returns:
        dict with keys "processor", "model" and "model_name".

    Raises:
        KeyError: if *language* has no entry in SPECIALIZED_MODELS.
        Exception: if loading fails. The message includes the underlying error
            (callers surface it to the user) and the original exception is
            chained as ``__cause__``.
    """
    if language not in fallback_models:
        model_name = SPECIALIZED_MODELS[language]
        print(f"🔄 Loading specialized model for {language}: {model_name}")
        try:
            processor = AutoProcessor.from_pretrained(model_name)
            # NOTE(review): use_safetensors=True fails for repositories that only
            # publish pytorch_model.bin — confirm both checkpoints ship safetensors.
            model = AutoModelForSpeechSeq2Seq.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
                low_cpu_mem_usage=True,
                use_safetensors=True
            ).to(DEVICE)
            model.eval()
            fallback_models[language] = {"processor": processor, "model": model, "model_name": model_name}
            print(f"✅ Specialized model loaded for {language}")
        except Exception as e:
            print(f"❌ Failed to load specialized {model_name}: {e}")
            # Keep the root cause in the message, consistent with load_indicwhisper().
            raise Exception(f"Could not load specialized {language} model: {e}") from e
    return fallback_models[language]
@spaces.GPU
def transcribe_with_primary_model(audio_path, language):
    """Transcribe *audio_path* with the cached primary (multilingual) pipeline.

    The decoder prompt is forced to the requested language on every call,
    English included. The pipeline is a process-wide singleton, so a prompt
    set for one request would otherwise leak into the next one (an English
    request after a Tamil one would still be decoded as Tamil).

    Returns:
        The stripped transcription text, or
        "Error: Pipeline not properly initialized" when the loaded object is
        not callable.

    Raises:
        Exception: loading/inference errors are printed and re-raised so that
        transcribe_audio() can fall back to the specialized model.
    """
    try:
        pipe = load_indicwhisper()
        if not callable(pipe):
            return "Error: Pipeline not properly initialized"
        lang_code = LANG_CODES.get(language, "en")
        try:
            # Backends without model.config/tokenizer are left untouched.
            # NOTE(review): recent transformers releases read Whisper decoding
            # prompts from model.generation_config; confirm that setting
            # model.config.forced_decoder_ids still takes effect there.
            if (hasattr(pipe, "model") and hasattr(pipe, "tokenizer")
                    and hasattr(pipe.model, "config")):
                pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(
                    language=lang_code, task="transcribe"
                )
        except Exception as e:
            print(f"⚠️ Language forcing failed: {e}")
        result = pipe(audio_path)
        if isinstance(result, dict) and "text" in result:
            return result["text"].strip()
        if isinstance(result, str):
            return result.strip()
        return str(result).strip()
    except Exception as e:
        print(f"Primary model transcription error: {e}")
        raise
@spaces.GPU
def transcribe_with_specialized_model(audio_path, language):
    """Transcribe *audio_path* with the fine-tuned model for *language*.

    Audio is loaded and trimmed by preprocess_audio(), converted to log-mel
    features by the model's processor and decoded with 3-beam search. For
    non-English languages the decoder prompt is forced to the code from
    LANG_CODES.

    Returns:
        The stripped transcription, "(No transcription generated)" when the
        decoded text is empty, or an "Error: ..." string. Errors are reported
        as text and never raised.
    """
    try:
        components = load_specialized_model(language)
        processor = components["processor"]
        model = components["model"]
        audio, sr = preprocess_audio(audio_path)
        if audio is None:
            return "Error: Audio too short or could not be processed"
        # Use the feature extractor's default padding: Whisper encoders expect a
        # fixed 30-second window, and padding=True ("longest") would leave a
        # short single clip unpadded.
        inputs = processor(
            audio,
            sampling_rate=sr,
            return_tensors="pt"
        )
        # Match the weights' dtype: load_specialized_model() loads float16 on
        # CUDA, while the processor emits float32 features.
        input_features = inputs.input_features.to(DEVICE, dtype=model.dtype)
        forced_decoder_ids = None
        # The English entry in SPECIALIZED_MODELS is "whisper-base.en",
        # presumably English-only, so no language token is forced for it.
        if language != "English":
            lang_code = LANG_CODES.get(language, "en")
            try:
                if hasattr(processor, "get_decoder_prompt_ids"):
                    forced_decoder_ids = processor.get_decoder_prompt_ids(
                        language=lang_code,
                        task="transcribe"
                    )
            except Exception as e:
                print(f"⚠️ Language forcing failed: {e}")
        gen_kwargs = {
            "max_length": 200,
            "num_beams": 3,
            "do_sample": False
        }
        if forced_decoder_ids:
            gen_kwargs["forced_decoder_ids"] = forced_decoder_ids
        with torch.no_grad():
            predicted_ids = model.generate(
                input_features,
                **gen_kwargs
            )
        transcription = processor.batch_decode(
            predicted_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )[0]
        return transcription.strip() or "(No transcription generated)"
    except Exception as e:
        print(f"Specialized model transcription error: {e}")
        return f"Error: {str(e)[:150]}..."
@spaces.GPU
def transcribe_audio(audio_path, language, initial_prompt="", use_specialized=False):
    """Transcribe *audio_path* with the primary or the specialized model.

    If the chosen path raises, the primary path retries once with the
    specialized model, and the specialized path returns an
    "Error: All transcription methods failed - ..." string.

    Args:
        audio_path: Path of the audio file.
        language: UI language name ("English", "Tamil", "Malayalam").
        initial_prompt: Accepted and forwarded on retry but otherwise unused.
        use_specialized: Use the per-language fine-tuned model instead of the
            primary pipeline.
    """
    try:
        if not use_specialized:
            print(f"🔄 Using high-performance primary model for {language}")
            return transcribe_with_primary_model(audio_path, language)
        print(f"🔄 Using specialized model for {language}")
        return transcribe_with_specialized_model(audio_path, language)
    except Exception as e:
        print(f"Transcription failed, trying specialized model: {e}")
        if use_specialized:
            return f"Error: All transcription methods failed - {str(e)[:100]}"
        return transcribe_audio(audio_path, language, initial_prompt, use_specialized=True)
def highlight_differences(ref, hyp):
    """Render a word-level diff of *hyp* against *ref* as inline HTML spans.

    Both strings are split on whitespace and aligned with
    ``difflib.SequenceMatcher``:

    * matching words -> green;
    * reference words that were replaced or deleted -> red, struck through;
    * hypothesis words that replaced them -> orange, prefixed with "→";
    * inserted hypothesis words -> orange, prefixed with "+".

    Each word is HTML-escaped before being embedded. The result is rendered
    by a ``gr.HTML`` component, and the text comes from ASR output and the
    request payload, so ``<`` or ``&`` must not be parsed as markup.

    Returns:
        The spans joined by single spaces, or "No text to compare" when either
        input is empty or whitespace-only.
    """
    from html import escape

    if not ref.strip() or not hyp.strip():
        return "No text to compare"

    matched_style = "color:green; font-weight:bold; background-color:#e8f5e8; padding:2px 4px; margin:1px; border-radius:3px;"
    removed_style = "color:red; text-decoration:line-through; background-color:#ffe8e8; padding:2px 4px; margin:1px; border-radius:3px;"
    added_style = "color:orange; font-weight:bold; background-color:#fff3cd; padding:2px 4px; margin:1px; border-radius:3px;"

    def span(style, word, prefix=""):
        # One styled, escaped word.
        return f"<span style='{style}'>{prefix}{escape(word)}</span>"

    # str.split() with no argument already ignores leading/trailing whitespace.
    ref_words = ref.split()
    hyp_words = hyp.split()
    out_html = []
    matcher = difflib.SequenceMatcher(None, ref_words, hyp_words)
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            out_html.extend(span(matched_style, w) for w in ref_words[i1:i2])
        elif tag in ("replace", "delete"):
            out_html.extend(span(removed_style, w) for w in ref_words[i1:i2])
            if tag == "replace":
                out_html.extend(span(added_style, w, "→") for w in hyp_words[j1:j2])
        elif tag == "insert":
            out_html.extend(span(added_style, w, "+") for w in hyp_words[j1:j2])
    return " ".join(out_html)
def char_level_highlight(ref, hyp):
    """Render a character-level diff of *hyp* against *ref* as inline HTML.

    Characters are aligned with ``difflib.SequenceMatcher``. Matching
    characters are green. Reference characters that were replaced or deleted
    are red and underlined; the replacing hypothesis characters are not
    shown. Inserted hypothesis characters are orange. Every character is
    HTML-escaped because the result is rendered as HTML.

    Returns:
        The concatenated spans, or "No text to compare" when either input is
        empty or whitespace-only.
    """
    from html import escape

    if not ref.strip() or not hyp.strip():
        return "No text to compare"

    match_style = "color:green; background-color:#e8f5e8;"
    miss_style = "color:red; text-decoration:underline; background-color:#ffe8e8; font-weight:bold;"
    extra_style = "color:orange; background-color:#fff3cd; font-weight:bold;"

    out = []
    # SequenceMatcher works on strings directly; no list() conversion needed.
    for tag, i1, i2, j1, j2 in difflib.SequenceMatcher(None, ref, hyp).get_opcodes():
        if tag == "equal":
            style, chars = match_style, ref[i1:i2]
        elif tag in ("replace", "delete"):
            style, chars = miss_style, ref[i1:i2]
        else:  # "insert" — the only remaining opcode tag
            style, chars = extra_style, hyp[j1:j2]
        out.extend(f"<span style='{style}'>{escape(c)}</span>" for c in chars)
    return "".join(out)
def get_pronunciation_score(wer_val, cer_val):
    """Map word/character error rates to a ``(headline, feedback)`` pair.

    The error rates are blended 70% WER / 30% CER. The blended error is
    checked against the inclusive upper bounds 0.1, 0.2, 0.4 and 0.6 in
    order, and anything above 0.6 lands in the final tier.
    """
    combined_score = (wer_val * 0.7) + (cer_val * 0.3)
    tiers = (
        (0.1, "🏆 Excellent! (90%+)", "Your pronunciation is outstanding!"),
        (0.2, "🎉 Very Good! (80-90%)", "Great pronunciation with minor areas for improvement."),
        (0.4, "👍 Good! (60-80%)", "Good effort! Keep practicing for better accuracy."),
        (0.6, "📚 Needs Practice (40-60%)", "Focus on clearer pronunciation of highlighted words."),
    )
    for upper_bound, headline, feedback in tiers:
        if combined_score <= upper_bound:
            return headline, feedback
    return "💪 Keep Trying! (<40%)", "Don't give up! Practice makes perfect."
@spaces.GPU
def compare_pronunciation(audio, language_choice, intended_sentence):
    """Transcribe a recording twice and score it against the practice sentence.

    Pass 1 uses the primary pipeline (transcribe_audio falls back to the
    specialized model if it raises); pass 2 always uses the specialized model.
    The first transcription that is not an "Error:" string is scored with
    jiwer WER/CER and rendered as word- and character-level HTML diffs.

    Args:
        audio: Path to the recorded/uploaded audio file, or None.
        language_choice: "English", "Tamil" or "Malayalam".
        intended_sentence: The sentence the user was asked to read (may be
            None or blank when no sentence was generated yet).

    Returns:
        An 8-tuple matching the Gradio outputs: (status, primary text,
        specialized text, WER text, CER text, word-diff HTML, char-diff HTML,
        target text). On failure the first element carries the error message.
    """
    print(f"🔍 Starting advanced analysis with language: {language_choice}")
    print(f"📝 Audio file: {audio}")
    print(f"🎯 Intended sentence: {intended_sentence}")
    if audio is None:
        print("❌ No audio provided")
        return ("❌ Please record audio first.", "", "", "", "", "", "", "")
    # Guard against None as well as blank text (e.g. direct API calls).
    if not intended_sentence or not intended_sentence.strip():
        print("❌ No intended sentence")
        return ("❌ Please generate a practice sentence first.", "", "", "", "", "", "", "")
    try:
        print("🔄 Starting Pass 1: High-performance model transcription...")
        primary_text = transcribe_audio(audio, language_choice, use_specialized=False)
        print(f"✅ Primary model result: {primary_text}")
        print("🔄 Starting Pass 2: Specialized model transcription...")
        specialized_text = transcribe_audio(audio, language_choice, use_specialized=True)
        print(f"✅ Specialized model result: {specialized_text}")
        actual_text = primary_text if not str(primary_text).startswith("Error:") else specialized_text
        if str(actual_text).startswith("Error:"):
            print(f"❌ Transcription error: {actual_text}")
            return (f"❌ {actual_text}", "", "", "", "", "", "", "")
        # Normalize for metrics if enabled
        ref_for_metrics = normalize_text(intended_sentence)
        hyp_for_metrics = normalize_text(actual_text)
        try:
            print("🔄 Calculating error metrics...")
            wer_val = jiwer.wer(ref_for_metrics, hyp_for_metrics)
            cer_val = jiwer.cer(ref_for_metrics, hyp_for_metrics)
            print(f"✅ WER: {wer_val:.3f}, CER: {cer_val:.3f}")
        except Exception as e:
            print(f"❌ Error calculating metrics: {e}")
            wer_val, cer_val = 1.0, 1.0
        score_text, feedback = get_pronunciation_score(wer_val, cer_val)
        print("🔄 Generating transliterations...")
        actual_hk = transliterate_to_hk(actual_text, language_choice)
        target_hk = transliterate_to_hk(intended_sentence, language_choice)
        if not is_script(actual_text, language_choice) and language_choice != "English":
            actual_hk = f"⚠️ Expected {language_choice} script, got mixed/other script"
        print("🔄 Generating visual feedback...")
        diff_html = highlight_differences(intended_sentence, actual_text)
        char_html = char_level_highlight(intended_sentence, actual_text)
        status = f"✅ Advanced Analysis Complete - {score_text}\n💬 {feedback}\n🚀 Powered by High-Performance ASR Models"
        if language_choice != "English":
            # The romanizations (and the wrong-script warning) were previously
            # computed and then discarded; surface them to the user.
            status += f"\n🔤 Target (HK): {target_hk}\n🗣️ Heard (HK): {actual_hk}"
        print("✅ Advanced analysis completed successfully")
        # WER/CER exceed 1.0 when the hypothesis has many insertions; clamp so
        # the displayed accuracy never goes negative.
        word_accuracy = max(0.0, 1 - wer_val) * 100
        char_accuracy = max(0.0, 1 - cer_val) * 100
        return (
            status,
            primary_text or "(No primary transcription)",
            specialized_text or "(No specialized transcription)",
            f"{wer_val:.3f} ({word_accuracy:.1f}% word accuracy)",
            f"{cer_val:.3f} ({char_accuracy:.1f}% character accuracy)",
            diff_html,
            char_html,
            f"🎯 Target: {intended_sentence}"
        )
    except Exception as e:
        error_msg = f"❌ Analysis Error: {str(e)[:200]}"
        print(f"❌ FATAL ERROR: {e}")
        import traceback
        traceback.print_exc()
        return (error_msg, str(e), "", "", "", "", "", "")
def create_interface():
    """Build the Gradio Blocks UI for the pronunciation trainer.

    Layout: a language dropdown with a "Generate Sentence" button, the
    practice sentence, an audio recorder/uploader, an "Analyze" button, and
    output widgets for the status, both model transcriptions, WER/CER and the
    word/character diff HTML.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(title="🎙️ SOTA Multilingual Pronunciation Trainer") as demo:
        gr.Markdown("""
# 🎙️ Advanced Multilingual Pronunciation Trainer
Practice pronunciation in Tamil, Malayalam & English using high-performance ASR models!
### 🏆 Powered by Advanced Models:
- Dual-Model Analysis: Primary + specialized model comparison
- High Accuracy: Language-specific fine-tuned models
- Robust Performance: Automatic fallback for reliability
### 📋 How to Use:
1. Select your target language 🌍
2. Generate a practice sentence 🎲
3. Record yourself reading it aloud 🎤
4. Get detailed feedback with advanced accuracy 📊
### 🎯 Features:
- Dual-pass analysis for comprehensive assessment
- Visual highlighting of pronunciation errors
- Romanization for Indic scripts
- Advanced metrics (Word & Character accuracy)
""")
        # --- Language selection and sentence generation ---
        with gr.Row():
            with gr.Column(scale=3):
                lang_choice = gr.Dropdown(
                    choices=list(LANG_CODES.keys()),
                    value="Tamil",
                    label="🌍 Select Language"
                )
            with gr.Column(scale=1):
                gen_btn = gr.Button("🎲 Generate Sentence", variant="primary")
        # Read-only: filled by get_random_sentence and passed to compare_pronunciation.
        intended_display = gr.Textbox(
            label="📝 Practice Sentence (Read this aloud)",
            placeholder="Click 'Generate Sentence' to get started...",
            interactive=False,
            lines=3
        )
        # type="filepath": the analysis handler receives a file path (or None).
        audio_input = gr.Audio(
            sources=["microphone", "upload"],
            type="filepath",
            label="🎤 Record Your Pronunciation"
        )
        analyze_btn = gr.Button("🔍 Analyze with Advanced Models", variant="primary")
        status_output = gr.Textbox(
            label="📊 Advanced Analysis Results",
            interactive=False,
            lines=4
        )
        # --- Side-by-side transcriptions and metrics ---
        with gr.Row():
            with gr.Column():
                pass1_out = gr.Textbox(
                    label="🏆 Primary Model Output",
                    interactive=False,
                    lines=2
                )
                wer_out = gr.Textbox(
                    label="📈 Word Accuracy",
                    interactive=False
                )
            with gr.Column():
                pass2_out = gr.Textbox(
                    label="🔧 Specialized Model Comparison",
                    interactive=False,
                    lines=2
                )
                cer_out = gr.Textbox(
                    label="📊 Character Accuracy",
                    interactive=False
                )
        # --- Diff visualisations ---
        # NOTE(review): the character-level view marks red characters with an
        # underline, not the strikethrough the color guide below describes.
        with gr.Accordion("📝 Detailed Visual Feedback", open=True):
            gr.Markdown("""
### 🎨 Color Guide:
- 🟢 Green: Correctly pronounced words/characters
- 🔴 Red: Missing or mispronounced (strikethrough)
- 🟠 Orange: Extra words or substitutions
""")
            diff_html_box = gr.HTML(label="🔍 Word-Level Analysis", show_label=True)
            char_html_box = gr.HTML(label="🔤 Character-Level Analysis", show_label=True)
        # Hidden sink for the last value returned by compare_pronunciation.
        target_display = gr.Textbox(
            label="🎯 Reference Text",
            interactive=False,
            visible=False
        )
        # --- Event wiring ---
        # "Generate Sentence" fills the practice textbox.
        gen_btn.click(
            fn=get_random_sentence,
            inputs=[lang_choice],
            outputs=[intended_display]
        )
        # The 8 outputs map positionally onto compare_pronunciation's 8-tuple.
        analyze_btn.click(
            fn=compare_pronunciation,
            inputs=[audio_input, lang_choice, intended_display],
            outputs=[
                status_output,
                pass1_out,
                pass2_out,
                wer_out,
                cer_out,
                diff_html_box,
                char_html_box,
                target_display
            ]
        )
        # Switching language immediately draws a fresh sentence in that language.
        lang_choice.change(
            fn=get_random_sentence,
            inputs=[lang_choice],
            outputs=[intended_display]
        )
        gr.Markdown("""
---
### 🏆 Advanced Technology Stack:
- Primary ASR: OpenAI Whisper Large v2 (High-performance multilingual model)
- Specialized Models:
- Tamil: vasista22/whisper-tamil-large-v2
- Malayalam: thennal/whisper-medium-ml
- English: OpenAI Whisper Base EN
- Dual Analysis and Automatic Fallback
### 🔧 Technical Details:
- Metrics: WER and CER
- Transliteration: Harvard-Kyoto for Indic scripts
- Languages: English, Tamil, Malayalam
""")
    return demo
if __name__ == "__main__":
    # Startup banner, then build the UI and serve it on all interfaces, port 7860.
    _startup_banner = (
        "🚀 Starting Advanced Multilingual Pronunciation Trainer...",
        f"🔧 Device: {DEVICE} (index={DEVICE_INDEX})",
        f"🔧 PyTorch version: {torch.__version__}",
        "🏆 Using High-Performance Dual-Model Approach",
        "⚡ Automatic model selection with specialized fallbacks",
        "📊 Advanced analysis with robust error handling",
        "🎮 GPU functions decorated with @spaces.GPU for HuggingFace Spaces",
    )
    for _banner_line in _startup_banner:
        print(_banner_line)
    _launch_options = dict(
        share=True,
        show_error=True,
        server_name="0.0.0.0",
        server_port=7860,
    )
    demo = create_interface()
    demo.launch(**_launch_options)