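"""Advanced Multilingual Pronunciation Trainer (Gradio app for HuggingFace Spaces).

Generates a practice sentence in English, Tamil, or Malayalam, transcribes the
user's recording with a primary Whisper model plus a language-specialized
fallback, and scores pronunciation with WER/CER and word/character-level
visual diffs.
"""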
import gradio as gr
import random
import difflib
import re
import jiwer
import torch
import numpy as np
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
import librosa
import soundfile as sf
from indic_transliteration import sanscript
from indic_transliteration.sanscript import transliterate
import unicodedata
import warnings
import spaces

warnings.filterwarnings("ignore")
# ---------------- CONFIG ---------------- #
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🔧 Using device: {DEVICE}")
DEVICE_INDEX = 0 if DEVICE == "cuda" else -1

LANG_CODES = {
    "English": "en",
    "Tamil": "ta",
    "Malayalam": "ml"
}

INDICWHISPER_MODEL = "openai/whisper-large-v2"

SPECIALIZED_MODELS = {
    "English": "openai/whisper-base.en",
    "Tamil": "vasista22/whisper-tamil-large-v2",
    "Malayalam": "thennal/whisper-medium-ml"
}

LANG_PRIMERS = {
    "English": ("Transcribe in English.",
                "Write only in English. Example: This is an English sentence."),
    "Tamil": ("தமிழில் எழுதுக.",
              "தமிழ் எழுத்துக்களில் மட்டும் எழுதவும். உதாரணம்: இது ஒரு தமிழ் வாக்கியம்."),
    "Malayalam": ("മലയാളത്തിൽ എഴുതുക.",
                  "മലയാള ലിപിയിൽ മാത്രം എഴുതുക. ഉദാഹരണം: ഇതൊരു മലയാള വാക്യമാണ്.")
}
# Unicode blocks: Tamil U+0B80–U+0BFF, Malayalam U+0D00–U+0D7F
SCRIPT_PATTERNS = {
    "Tamil": re.compile(r"[\u0B80-\u0BFF]"),
    "Malayalam": re.compile(r"[\u0D00-\u0D7F]"),
    "English": re.compile(r"[A-Za-z]")
}
SENTENCE_BANK = {
    "English": [
        "The sun sets over the beautiful horizon.",
        "Learning new languages opens many doors.",
        "I enjoy reading books in the evening.",
        "Technology has changed our daily lives.",
        "Music brings people together across cultures.",
        "Education is the key to a bright future.",
        "The flowers bloom beautifully in spring.",
        "Hard work always pays off in the end."
    ],
    "Tamil": [
        "இன்று நல்ல வானிலை உள்ளது.",
        "நான் தமிழ் கற்றுக்கொண்டு இருக்கிறேன்.",
        "எனக்கு புத்தகம் படிக்க விருப்பம்.",
        "தமிழ் மொழி மிகவும் அழகானது.",
        "குடும்பத்துடன் நேரம் செலவிடுவது முக்கியம்.",
        "கல்வி நமது எதிர்காலத்தின் திறவுகோல்.",
        "பறவைகள் காலையில் இனிமையாக பாடுகின்றன.",
        "உழைப்பு எப்போதும் வெற்றியைத் தரும்."
    ],
    "Malayalam": [
        "എനിക്ക് മലയാളം വളരെ ഇഷ്ടമാണ്.",
        "ഇന്ന് മഴ പെയ്യുന്നു.",
        "ഞാൻ പുസ്തകം വായിക്കുന്നു.",
        "കേരളത്തിന്റെ പ്രകൃതി സുന്ദരമാണ്.",
        "വിദ്യാഭ്യാസം ജീവിതത്തിൽ പ്രധാനമാണ്.",
        "സംഗീതം മനസ്സിന് സന്തോഷം നൽകുന്നു.",
        "കുടുംബസമയം വളരെ വിലപ്പെട്ടതാണ്.",
        "കഠിനാധ്വാനം എപ്പോഴും ഫലം നൽകും."
    ]
}
# Controls for stricter script checking and normalization
STRICT_SCRIPT_CHECK = False  # set True for strict script-only validation
NORMALIZE_TEXT_FOR_METRICS = True

# ---------------- MODEL CACHE ---------------- #
indicwhisper_pipeline = None
fallback_models = {}
WHISPER_JAX_AVAILABLE = False
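# Models are loaded lazily on first use and cached in the module-level
# globals above, so repeated analyses do not reload weights.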
def normalize_text(s: str) -> str:
    if not NORMALIZE_TEXT_FOR_METRICS:
        return s
    # Normalize unicode and collapse whitespace; do not remove language-specific punctuation
    s = unicodedata.normalize("NFC", s)
    s = re.sub(r"\s+", " ", s).strip()
    return s
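# e.g. normalize_text("  இது   ஒரு  சோதனை ") -> "இது ஒரு சோதனை"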
def get_random_sentence(language_choice):
    return random.choice(SENTENCE_BANK[language_choice])
def is_script(text, lang_name):
    pattern = SCRIPT_PATTERNS.get(lang_name)
    if not pattern:
        return True
    if not STRICT_SCRIPT_CHECK:
        # any occurrence of target-script chars counts as a match
        return bool(pattern.search(text))
    # strict: allow only spaces and target-script chars
    for ch in text:
        if ch.isspace():
            continue
        if not pattern.match(ch):
            return False
    return True
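# e.g. in loose mode is_script("இது a test", "Tamil") is True (Tamil chars
# are present); in strict mode the Latin characters make it False.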
def transliterate_to_hk(text, lang_choice):
    mapping = {
        "Tamil": sanscript.TAMIL,
        "Malayalam": sanscript.MALAYALAM,
        "English": None
    }
    script = mapping.get(lang_choice)
    if script and is_script(text, lang_choice):
        try:
            return transliterate(text, script, sanscript.HK)
        except Exception as e:
            print(f"Transliteration error: {e}")
            return text
    return text
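# Audio preprocessing: resample to 16 kHz (the rate Whisper expects),
# peak-normalize, trim leading/trailing silence (top_db=20), and reject
# clips shorter than ~100 ms.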
def preprocess_audio(audio_path, target_sr=16000):
    try:
        audio, sr = librosa.load(audio_path, sr=target_sr)
        if np.max(np.abs(audio)) > 0:
            audio = audio / np.max(np.abs(audio))
        audio, _ = librosa.effects.trim(audio, top_db=20)
        if len(audio) < target_sr * 0.1:
            return None, None
        return audio, target_sr
    except Exception as e:
        print(f"Audio preprocessing error: {e}")
        return None, None
def load_indicwhisper():
    global indicwhisper_pipeline, WHISPER_JAX_AVAILABLE
    if indicwhisper_pipeline is None:
        try:
            # Try the JAX pipeline first
            try:
                from whisper_jax import FlaxWhisperPipeline
                import jax.numpy as jnp
                print(f"🔄 Loading JAX-optimized model: {INDICWHISPER_MODEL}")
                indicwhisper_pipeline = FlaxWhisperPipeline(
                    INDICWHISPER_MODEL,
                    dtype=jnp.bfloat16,
                    batch_size=1
                )
                WHISPER_JAX_AVAILABLE = True
                print("✅ JAX-optimized model loaded successfully!")
                return indicwhisper_pipeline
            except Exception as e:
                print(f"⚠️ JAX loading failed: {e}")
                WHISPER_JAX_AVAILABLE = False
            # Fall back to the transformers pipeline
            print(f"🔄 Loading transformers pipeline: {INDICWHISPER_MODEL}")
            from transformers import pipeline
            indicwhisper_pipeline = pipeline(
                "automatic-speech-recognition",
                model=INDICWHISPER_MODEL,
                device=DEVICE_INDEX,
                torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32
            )
            print("✅ High-performance model loaded with transformers!")
        except Exception as e:
            print(f"❌ Failed to load primary model: {e}")
            indicwhisper_pipeline = None
            raise Exception(f"Could not load high-performance model: {str(e)}")
    return indicwhisper_pipeline
def load_specialized_model(language):
    if language not in fallback_models:
        model_name = SPECIALIZED_MODELS[language]
        print(f"🔄 Loading specialized model for {language}: {model_name}")
        try:
            processor = AutoProcessor.from_pretrained(model_name)
            model = AutoModelForSpeechSeq2Seq.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
                low_cpu_mem_usage=True,
                use_safetensors=True
            ).to(DEVICE)
            model.eval()
            fallback_models[language] = {"processor": processor, "model": model, "model_name": model_name}
            print(f"✅ Specialized model loaded for {language}")
        except Exception as e:
            print(f"❌ Failed to load specialized {model_name}: {e}")
            raise Exception(f"Could not load specialized {language} model")
    return fallback_models[language]
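# Note: the language-forcing branch below only fires when the loaded pipeline
# exposes `model` and `tokenizer` attributes (the transformers path); a
# whisper-jax pipeline is simply called with the raw audio path.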
def transcribe_with_primary_model(audio_path, language):
    try:
        pipe = load_indicwhisper()
        if callable(pipe):
            # Try to set forced decoder ids when available
            if language != "English":
                lang_code = LANG_CODES.get(language, "en")
                try:
                    if hasattr(pipe, "model") and hasattr(pipe, "tokenizer"):
                        if hasattr(pipe.model, "config"):
                            forced_ids = pipe.tokenizer.get_decoder_prompt_ids(
                                language=lang_code, task="transcribe"
                            )
                            pipe.model.config.forced_decoder_ids = forced_ids
                except Exception as e:
                    print(f"⚠️ Language forcing failed: {e}")
            result = pipe(audio_path)
            if isinstance(result, dict) and "text" in result:
                return result["text"].strip()
            elif isinstance(result, str):
                return result.strip()
            else:
                return str(result).strip()
        else:
            return "Error: Pipeline not properly initialized"
    except Exception as e:
        print(f"Primary model transcription error: {e}")
        raise e
def transcribe_with_specialized_model(audio_path, language):
    try:
        components = load_specialized_model(language)
        processor = components["processor"]
        model = components["model"]
        audio, sr = preprocess_audio(audio_path)
        if audio is None:
            return "Error: Audio too short or could not be processed"
        inputs = processor(
            audio,
            sampling_rate=sr,
            return_tensors="pt",
            padding=True
        )
        # Match the model's dtype (fp16 on CUDA) to avoid input/weight mismatch
        input_features = inputs.input_features.to(DEVICE, dtype=model.dtype)
        forced_decoder_ids = None
        if language != "English":
            lang_code = LANG_CODES.get(language, "en")
            try:
                if hasattr(processor, "get_decoder_prompt_ids"):
                    forced_decoder_ids = processor.get_decoder_prompt_ids(
                        language=lang_code,
                        task="transcribe"
                    )
            except Exception as e:
                print(f"⚠️ Language forcing failed: {e}")
        with torch.no_grad():
            gen_kwargs = {
                "max_length": 200,
                "num_beams": 3,
                "do_sample": False
            }
            if forced_decoder_ids:
                gen_kwargs["forced_decoder_ids"] = forced_decoder_ids
            predicted_ids = model.generate(
                input_features,
                **gen_kwargs
            )
        transcription = processor.batch_decode(
            predicted_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )[0]
        return transcription.strip() or "(No transcription generated)"
    except Exception as e:
        print(f"Specialized model transcription error: {e}")
        return f"Error: {str(e)[:150]}..."
def transcribe_audio(audio_path, language, initial_prompt="", use_specialized=False):
    try:
        if use_specialized:
            print(f"🔄 Using specialized model for {language}")
            return transcribe_with_specialized_model(audio_path, language)
        else:
            print(f"🔄 Using high-performance primary model for {language}")
            return transcribe_with_primary_model(audio_path, language)
    except Exception as e:
        print(f"Transcription failed, trying specialized model: {e}")
        if not use_specialized:
            return transcribe_audio(audio_path, language, initial_prompt, use_specialized=True)
        else:
            return f"Error: All transcription methods failed - {str(e)[:100]}"
def highlight_differences(ref, hyp):
    if not ref.strip() or not hyp.strip():
        return "No text to compare"
    ref_words = ref.strip().split()
    hyp_words = hyp.strip().split()
    sm = difflib.SequenceMatcher(None, ref_words, hyp_words)
    out_html = []
    for tag, i1, i2, j1, j2 in sm.get_opcodes():
        if tag == 'equal':
            out_html.extend([f"<span style='color:green; font-weight:bold; background-color:#e8f5e8; padding:2px 4px; margin:1px; border-radius:3px;'>{w}</span>" for w in ref_words[i1:i2]])
        elif tag == 'replace':
            out_html.extend([f"<span style='color:red; text-decoration:line-through; background-color:#ffe8e8; padding:2px 4px; margin:1px; border-radius:3px;'>{w}</span>" for w in ref_words[i1:i2]])
            out_html.extend([f"<span style='color:orange; font-weight:bold; background-color:#fff3cd; padding:2px 4px; margin:1px; border-radius:3px;'>→{w}</span>" for w in hyp_words[j1:j2]])
        elif tag == 'delete':
            out_html.extend([f"<span style='color:red; text-decoration:line-through; background-color:#ffe8e8; padding:2px 4px; margin:1px; border-radius:3px;'>{w}</span>" for w in ref_words[i1:i2]])
        elif tag == 'insert':
            out_html.extend([f"<span style='color:orange; font-weight:bold; background-color:#fff3cd; padding:2px 4px; margin:1px; border-radius:3px;'>+{w}</span>" for w in hyp_words[j1:j2]])
    return " ".join(out_html)
def char_level_highlight(ref, hyp):
    if not ref.strip() or not hyp.strip():
        return "No text to compare"
    sm = difflib.SequenceMatcher(None, list(ref), list(hyp))
    out = []
    for tag, i1, i2, j1, j2 in sm.get_opcodes():
        if tag == 'equal':
            out.extend([f"<span style='color:green; background-color:#e8f5e8;'>{c}</span>" for c in ref[i1:i2]])
        elif tag in ('replace', 'delete'):
            out.extend([f"<span style='color:red; text-decoration:underline; background-color:#ffe8e8; font-weight:bold;'>{c}</span>" for c in ref[i1:i2]])
        elif tag == 'insert':
            out.extend([f"<span style='color:orange; background-color:#fff3cd; font-weight:bold;'>{c}</span>" for c in hyp[j1:j2]])
    return "".join(out)
def get_pronunciation_score(wer_val, cer_val):
    combined_score = (wer_val * 0.7) + (cer_val * 0.3)
    if combined_score <= 0.1:
        return "🏆 Excellent! (90%+)", "Your pronunciation is outstanding!"
    elif combined_score <= 0.2:
        return "🎉 Very Good! (80-90%)", "Great pronunciation with minor areas for improvement."
    elif combined_score <= 0.4:
        return "👍 Good! (60-80%)", "Good effort! Keep practicing for better accuracy."
    elif combined_score <= 0.6:
        return "📚 Needs Practice (40-60%)", "Focus on clearer pronunciation of highlighted words."
    else:
        return "💪 Keep Trying! (<40%)", "Don't give up! Practice makes perfect."
@spaces.GPU  # assumed intent: allocate a GPU on HF Spaces (ZeroGPU), per the spaces import and startup log
def compare_pronunciation(audio, language_choice, intended_sentence):
    print(f"🔍 Starting advanced analysis with language: {language_choice}")
    print(f"📝 Audio file: {audio}")
    print(f"🎯 Intended sentence: {intended_sentence}")
    if audio is None:
        print("❌ No audio provided")
        return ("❌ Please record audio first.", "", "", "", "", "", "", "")
    if not intended_sentence.strip():
        print("❌ No intended sentence")
        return ("❌ Please generate a practice sentence first.", "", "", "", "", "", "", "")
    try:
        print("🔄 Starting Pass 1: High-performance model transcription...")
        primary_text = transcribe_audio(audio, language_choice, use_specialized=False)
        print(f"✅ Primary model result: {primary_text}")
        print("🔄 Starting Pass 2: Specialized model transcription...")
        specialized_text = transcribe_audio(audio, language_choice, use_specialized=True)
        print(f"✅ Specialized model result: {specialized_text}")
        actual_text = primary_text if not str(primary_text).startswith("Error:") else specialized_text
        if str(actual_text).startswith("Error:"):
            print(f"❌ Transcription error: {actual_text}")
            return (f"❌ {actual_text}", "", "", "", "", "", "", "")
        # Normalize for metrics if enabled
        ref_for_metrics = normalize_text(intended_sentence)
        hyp_for_metrics = normalize_text(actual_text)
        try:
            print("🔄 Calculating error metrics...")
            wer_val = jiwer.wer(ref_for_metrics, hyp_for_metrics)
            cer_val = jiwer.cer(ref_for_metrics, hyp_for_metrics)
            print(f"✅ WER: {wer_val:.3f}, CER: {cer_val:.3f}")
        except Exception as e:
            print(f"❌ Error calculating metrics: {e}")
            wer_val, cer_val = 1.0, 1.0
        score_text, feedback = get_pronunciation_score(wer_val, cer_val)
        print("🔄 Generating transliterations...")
        actual_hk = transliterate_to_hk(actual_text, language_choice)
        target_hk = transliterate_to_hk(intended_sentence, language_choice)
        if not is_script(actual_text, language_choice) and language_choice != "English":
            actual_hk = f"⚠️ Expected {language_choice} script, got mixed/other script"
        print("🔄 Generating visual feedback...")
        diff_html = highlight_differences(intended_sentence, actual_text)
        char_html = char_level_highlight(intended_sentence, actual_text)
        status = f"✅ Advanced Analysis Complete - {score_text}\n💬 {feedback}\n🚀 Powered by High-Performance ASR Models"
        print("✅ Advanced analysis completed successfully")
        return (
            status,
            primary_text or "(No primary transcription)",
            specialized_text or "(No specialized transcription)",
            # Clamp at 0 so accuracy never displays negative when WER/CER exceed 1.0
            f"{wer_val:.3f} ({max(0.0, 1 - wer_val) * 100:.1f}% word accuracy)",
            f"{cer_val:.3f} ({max(0.0, 1 - cer_val) * 100:.1f}% character accuracy)",
            diff_html,
            char_html,
            f"🎯 Target: {intended_sentence}"
        )
    except Exception as e:
        error_msg = f"❌ Analysis Error: {str(e)[:200]}"
        print(f"❌ FATAL ERROR: {e}")
        import traceback
        traceback.print_exc()
        return (error_msg, str(e), "", "", "", "", "", "")
def create_interface():
    with gr.Blocks(title="🎙️ SOTA Multilingual Pronunciation Trainer") as demo:
        gr.Markdown("""
        # 🎙️ Advanced Multilingual Pronunciation Trainer
        Practice pronunciation in Tamil, Malayalam & English using high-performance ASR models!

        ### 🏆 Powered by Advanced Models:
        - Dual-Model Analysis: Primary + specialized model comparison
        - High Accuracy: Language-specific fine-tuned models
        - Robust Performance: Automatic fallback for reliability

        ### 📋 How to Use:
        1. Select your target language 🌍
        2. Generate a practice sentence 🎲
        3. Record yourself reading it aloud 🎤
        4. Get detailed feedback with advanced accuracy 📊

        ### 🎯 Features:
        - Dual-pass analysis for comprehensive assessment
        - Visual highlighting of pronunciation errors
        - Romanization for Indic scripts
        - Advanced metrics (word & character accuracy)
        """)
        with gr.Row():
            with gr.Column(scale=3):
                lang_choice = gr.Dropdown(
                    choices=list(LANG_CODES.keys()),
                    value="Tamil",
                    label="🌍 Select Language"
                )
            with gr.Column(scale=1):
                gen_btn = gr.Button("🎲 Generate Sentence", variant="primary")
        intended_display = gr.Textbox(
            label="📝 Practice Sentence (Read this aloud)",
            placeholder="Click 'Generate Sentence' to get started...",
            interactive=False,
            lines=3
        )
        audio_input = gr.Audio(
            sources=["microphone", "upload"],
            type="filepath",
            label="🎤 Record Your Pronunciation"
        )
        analyze_btn = gr.Button("🔍 Analyze with Advanced Models", variant="primary")
        status_output = gr.Textbox(
            label="📊 Advanced Analysis Results",
            interactive=False,
            lines=4
        )
        with gr.Row():
            with gr.Column():
                pass1_out = gr.Textbox(
                    label="🏆 Primary Model Output",
                    interactive=False,
                    lines=2
                )
                wer_out = gr.Textbox(
                    label="📈 Word Accuracy",
                    interactive=False
                )
            with gr.Column():
                pass2_out = gr.Textbox(
                    label="🔧 Specialized Model Comparison",
                    interactive=False,
                    lines=2
                )
                cer_out = gr.Textbox(
                    label="📊 Character Accuracy",
                    interactive=False
                )
        with gr.Accordion("📝 Detailed Visual Feedback", open=True):
            gr.Markdown("""
            ### 🎨 Color Guide:
            - 🟢 Green: Correctly pronounced words/characters
            - 🔴 Red: Missing or mispronounced (strikethrough)
            - 🟠 Orange: Extra words or substitutions
            """)
            diff_html_box = gr.HTML(label="🔍 Word-Level Analysis", show_label=True)
            char_html_box = gr.HTML(label="🔤 Character-Level Analysis", show_label=True)
        target_display = gr.Textbox(
            label="🎯 Reference Text",
            interactive=False,
            visible=False
        )
        gen_btn.click(
            fn=get_random_sentence,
            inputs=[lang_choice],
            outputs=[intended_display]
        )
        analyze_btn.click(
            fn=compare_pronunciation,
            inputs=[audio_input, lang_choice, intended_display],
            outputs=[
                status_output,
                pass1_out,
                pass2_out,
                wer_out,
                cer_out,
                diff_html_box,
                char_html_box,
                target_display
            ]
        )
        lang_choice.change(
            fn=get_random_sentence,
            inputs=[lang_choice],
            outputs=[intended_display]
        )
        gr.Markdown("""
        ---
        ### 🏆 Advanced Technology Stack:
        - Primary ASR: OpenAI Whisper Large v2 (high-performance multilingual model)
        - Specialized Models:
          - Tamil: vasista22/whisper-tamil-large-v2
          - Malayalam: thennal/whisper-medium-ml
          - English: OpenAI Whisper Base EN
        - Dual analysis with automatic fallback

        ### 🔧 Technical Details:
        - Metrics: WER and CER
        - Transliteration: Harvard-Kyoto for Indic scripts
        - Languages: English, Tamil, Malayalam
        """)
    return demo
if __name__ == "__main__":
    print("🚀 Starting Advanced Multilingual Pronunciation Trainer...")
    print(f"🔧 Device: {DEVICE} (index={DEVICE_INDEX})")
    print(f"🔧 PyTorch version: {torch.__version__}")
    print("🏆 Using High-Performance Dual-Model Approach")
    print("⚡ Automatic model selection with specialized fallbacks")
    print("📊 Advanced analysis with robust error handling")
    print("🎮 GPU functions decorated with @spaces.GPU for HuggingFace Spaces")
    demo = create_interface()
    demo.launch(
        share=True,
        show_error=True,
        server_name="0.0.0.0",
        server_port=7860
    )