import gradio as gr import random import difflib import re import jiwer import torch import warnings import contextlib from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, pipeline import librosa import numpy as np # Optional transliteration try: from indic_transliteration import sanscript from indic_transliteration.sanscript import transliterate INDIC_OK = True except: INDIC_OK = False print("⚠️ indic_transliteration not available. Transliteration features disabled.") # Optional HF Spaces GPU decorator try: import spaces GPU_DECORATOR = spaces.GPU except: class _NoOp: def __call__(self, f): return f GPU_DECORATOR = _NoOp() warnings.filterwarnings("ignore") # ---------------- CONFIG ---------------- # DEVICE = "cuda" if torch.cuda.is_available() else "cpu" DEVICE_INDEX = 0 if DEVICE == "cuda" else -1 DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32 amp_ctx = torch.cuda.amp.autocast if DEVICE == "cuda" else contextlib.nullcontext print(f"🔧 Using device: {DEVICE}") LANG_CODES = { "English": "en", "Tamil": "ta", "Malayalam": "ml", "Hindi": "hi" } # Primary: IndicWhisper INDICWHISPER_MODEL = "parthiv11/indic_whisper_nodcil" # Specialized models for better accuracy SPECIALIZED_MODELS = { "English": "openai/whisper-base.en", "Tamil": "vasista22/whisper-tamil-large-v2", "Malayalam": "thennal/whisper-medium-ml", "Hindi": "openai/whisper-large-v2" # Using general model for Hindi } SCRIPT_PATTERNS = { "Tamil": re.compile(r"[஀-௿]"), "Malayalam": re.compile(r"[ഀ-ൿ]"), "Hindi": re.compile(r"[ऀ-ॿ]"), "English": re.compile(r"[A-Za-z]") } # Transliteration mappings TRANSLITERATION_SCRIPTS = { "Tamil": sanscript.TAMIL, "Malayalam": sanscript.MALAYALAM, "Hindi": sanscript.DEVANAGARI, "English": None } SENTENCE_BANK = { "English": [ "The sun sets over the horizon.", "Learning languages is fun and rewarding.", "I like to drink coffee in the morning.", "Technology helps us connect with others.", "Reading books expands our knowledge." ], "Tamil": [ "இன்று நல்ல வானிலை உள்ளது.", "நான் தமிழ் கற்றுக்கொண்டு இருக்கிறேன்.", "எனக்கு புத்தகம் படிக்க விருப்பம்.", "காலையில் காபி குடிக்க பிடிக்கும்.", "நண்பர்களுடன் பேசுவது மகிழ்ச்சி." ], "Malayalam": [ "എനിക്ക് മലയാളം വളരെ ഇഷ്ടമാണ്.", "ഇന്ന് മഴപെയ്യുന്നു.", "ഞാൻ പുസ്തകം വായിക്കുന്നു.", "കാലയിൽ ചായ കുടിക്കാൻ ഇഷ്ടമാണ്.", "സുഹൃത്തുക്കളോടു സംസാരിക്കുന്നത് സന്തോഷമാണ്." ], "Hindi": [ "आज मौसम अच्छा है।", "मुझे हिंदी बोलना पसंद है।", "मैं किताब पढ़ रहा हूँ।", "सुबह चाय पीना अच्छा लगता है।", "दोस्तों के साथ बात करना खुशी देता है।" ] } # Model cache primary_pipeline = None specialized_models = {} # ---------------- HELPERS ---------------- # def get_random_sentence(language_choice): return random.choice(SENTENCE_BANK[language_choice]) def is_correct_script(text, lang_name): """Check if text contains the expected script for the language""" if not text.strip(): return False pattern = SCRIPT_PATTERNS.get(lang_name) if not pattern: return True return bool(pattern.search(text)) def transliterate_text(text, lang_choice, to_romanized=True): """Transliterate text to/from romanized form""" if not INDIC_OK or not text.strip(): return text source_script = TRANSLITERATION_SCRIPTS.get(lang_choice) if not source_script: return text try: if to_romanized: # Convert to Harvard-Kyoto (romanized) return transliterate(text, source_script, sanscript.HK) else: # Convert from romanized to native script (if needed) return transliterate(text, sanscript.HK, source_script) except Exception as e: print(f"⚠️ Transliteration failed: {e}") return text def preprocess_audio(audio_path, target_sr=16000): """Enhanced audio preprocessing""" try: audio, sr = librosa.load(audio_path, sr=target_sr, mono=True) if audio is None or len(audio) == 0: return None, None # Normalize audio audio = audio.astype(np.float32) max_val = np.max(np.abs(audio)) if max_val > 0: audio = audio / max_val # Trim silence audio, _ = librosa.effects.trim(audio, top_db=20) # Check minimum length (0.1 seconds) if len(audio) < int(target_sr * 0.1): return None, None return audio, target_sr except Exception as e: print(f"⚠️ Audio preprocessing failed: {e}") return None, None # ---------------- MODEL LOADERS ---------------- # @GPU_DECORATOR def load_primary_model(): """Load the primary IndicWhisper model""" global primary_pipeline if primary_pipeline is not None: return primary_pipeline try: print(f"🔄 Loading primary model: {INDICWHISPER_MODEL}") # Try direct loading first primary_pipeline = pipeline( "automatic-speech-recognition", model=INDICWHISPER_MODEL, device=DEVICE_INDEX, torch_dtype=DTYPE, trust_remote_code=True ) print("✅ Primary model loaded successfully!") return primary_pipeline except Exception as e: print(f"⚠️ Primary model failed, using fallback: {e}") # Fallback to base Whisper primary_pipeline = pipeline( "automatic-speech-recognition", model="openai/whisper-large-v2", device=DEVICE_INDEX, torch_dtype=DTYPE ) print("✅ Fallback model loaded!") return primary_pipeline @GPU_DECORATOR def load_specialized_model(language): """Load specialized model for specific language""" if language in specialized_models: return specialized_models[language] model_name = SPECIALIZED_MODELS[language] print(f"🔄 Loading specialized {language} model: {model_name}") try: processor = AutoProcessor.from_pretrained(model_name) model = AutoModelForSpeechSeq2Seq.from_pretrained( model_name, torch_dtype=DTYPE, device_map="auto" if DEVICE == "cuda" else None ).to(DEVICE) specialized_models[language] = { "processor": processor, "model": model } print(f"✅ Specialized {language} model loaded!") return specialized_models[language] except Exception as e: print(f"❌ Failed to load specialized {language} model: {e}") return None # ---------------- TRANSCRIPTION ---------------- # @GPU_DECORATOR def transcribe_with_primary(audio_path, language): """Transcribe using primary IndicWhisper model""" try: pipeline_model = load_primary_model() lang_code = LANG_CODES[language] # Set language forcing if possible try: if hasattr(pipeline_model, "model") and hasattr(pipeline_model, "tokenizer"): forced_ids = pipeline_model.tokenizer.get_decoder_prompt_ids( language=lang_code, task="transcribe" ) if forced_ids: pipeline_model.model.config.forced_decoder_ids = forced_ids except Exception as e: print(f"⚠️ Language forcing failed: {e}") with amp_ctx(): result = pipeline_model(audio_path) if isinstance(result, dict): return result.get("text", "").strip() return str(result).strip() except Exception as e: return f"Primary transcription error: {str(e)}" @GPU_DECORATOR def transcribe_with_specialized(audio_path, language): """Transcribe using specialized model""" try: model_components = load_specialized_model(language) if not model_components: return "Specialized model not available" # Preprocess audio audio, sr = preprocess_audio(audio_path) if audio is None: return "Audio preprocessing failed" # Process with specialized model inputs = model_components["processor"]( audio, sampling_rate=sr, return_tensors="pt" ) input_features = inputs.input_features.to(DEVICE) # Generation parameters gen_kwargs = { "inputs": input_features, "max_length": 200, "num_beams": 3, "do_sample": False } # Language forcing for non-English if language != "English": try: forced_ids = model_components["processor"].tokenizer.get_decoder_prompt_ids( language=LANG_CODES[language], task="transcribe" ) if forced_ids: gen_kwargs["forced_decoder_ids"] = forced_ids except Exception as e: print(f"⚠️ Specialized language forcing failed: {e}") # Generate transcription with torch.no_grad(), amp_ctx(): generated_ids = model_components["model"].generate(**gen_kwargs) # Decode result transcription = model_components["processor"].batch_decode( generated_ids, skip_special_tokens=True )[0] return transcription.strip() except Exception as e: return f"Specialized transcription error: {str(e)}" # ---------------- ANALYSIS ---------------- # def compute_metrics(reference, hypothesis): """Compute WER and CER with error handling""" try: # Clean up texts ref_clean = reference.strip() hyp_clean = hypothesis.strip() if not ref_clean or not hyp_clean: return 1.0, 1.0 # Compute WER and CER wer = jiwer.wer(ref_clean, hyp_clean) cer = jiwer.cer(ref_clean, hyp_clean) return wer, cer except Exception as e: print(f"⚠️ Metric computation failed: {e}") return 1.0, 1.0 def get_pronunciation_score(wer, cer): """Convert error rates to intuitive scores and feedback""" # Weighted combination (WER is more important) combined_error = (wer * 0.7) + (cer * 0.3) accuracy = 1 - combined_error if accuracy >= 0.95: return "🏆 Perfect!", "Outstanding pronunciation! Native-like accuracy.", "#d4edda" elif accuracy >= 0.85: return "🎉 Excellent!", "Very good pronunciation with minor variations.", "#d1ecf1" elif accuracy >= 0.70: return "👍 Good!", "Good pronunciation, practice specific sounds.", "#fff3cd" elif accuracy >= 0.50: return "📚 Needs Practice", "Focus on clearer pronunciation and rhythm.", "#f8d7da" else: return "💪 Keep Trying!", "Break down into smaller parts and practice slowly.", "#f5c6cb" def create_detailed_comparison(intended, actual, lang_choice): """Create detailed side-by-side comparison with transliteration""" # Original scripts intended_orig = intended.strip() actual_orig = actual.strip() # Transliterations intended_translit = transliterate_text(intended_orig, lang_choice, to_romanized=True) actual_translit = transliterate_text(actual_orig, lang_choice, to_romanized=True) # Word-level highlighting word_diff_orig = highlight_word_differences(intended_orig, actual_orig) word_diff_translit = highlight_word_differences(intended_translit, actual_translit) # Character-level highlighting char_diff_orig = highlight_char_differences(intended_orig, actual_orig) char_diff_translit = highlight_char_differences(intended_translit, actual_translit) return { "intended_orig": intended_orig, "actual_orig": actual_orig, "intended_translit": intended_translit, "actual_translit": actual_translit, "word_diff_orig": word_diff_orig, "word_diff_translit": word_diff_translit, "char_diff_orig": char_diff_orig, "char_diff_translit": char_diff_translit } def highlight_word_differences(reference, hypothesis): """Highlight word-level differences with colors""" ref_words = reference.split() hyp_words = hypothesis.split() sm = difflib.SequenceMatcher(None, ref_words, hyp_words) html_output = [] for tag, i1, i2, j1, j2 in sm.get_opcodes(): if tag == 'equal': # Correct words - green background html_output.extend([ f"{word}" for word in ref_words[i1:i2] ]) elif tag == 'replace': # Wrong words - red background for reference, orange for hypothesis html_output.extend([ f"{word}" for word in ref_words[i1:i2] ]) html_output.extend([ f"→{word}" for word in hyp_words[j1:j2] ]) elif tag == 'delete': # Missing words - red background html_output.extend([ f"{word}" for word in ref_words[i1:i2] ]) elif tag == 'insert': # Extra words - orange background html_output.extend([ f"+{word}" for word in hyp_words[j1:j2] ]) return " ".join(html_output) def highlight_char_differences(reference, hypothesis): """Highlight character-level differences""" sm = difflib.SequenceMatcher(None, list(reference), list(hypothesis)) html_output = [] for tag, i1, i2, j1, j2 in sm.get_opcodes(): if tag == 'equal': # Correct characters - green html_output.extend([ f"{char}" for char in reference[i1:i2] ]) elif tag in ('replace', 'delete'): # Wrong/missing characters - red with underline html_output.extend([ f"{char}" for char in reference[i1:i2] ]) elif tag == 'insert': # Extra characters - orange html_output.extend([ f"{char}" for char in hypothesis[j1:j2] ]) return "".join(html_output) def analyze_pronunciation_errors(intended, actual, lang_choice): """Provide specific feedback about pronunciation errors""" comparison = create_detailed_comparison(intended, actual, lang_choice) # Analyze error patterns intended_words = intended.split() actual_words = actual.split() error_analysis = [] # Length difference analysis if len(actual_words) < len(intended_words): missing_count = len(intended_words) - len(actual_words) error_analysis.append(f"🔍 You missed {missing_count} word(s). Try speaking more slowly.") elif len(actual_words) > len(intended_words): extra_count = len(actual_words) - len(intended_words) error_analysis.append(f"🔍 You added {extra_count} extra word(s). Focus on the exact sentence.") # Script verification if not is_correct_script(actual, lang_choice): error_analysis.append(f"⚠️ The transcription doesn't contain {lang_choice} script. Check your pronunciation.") # WER/CER based feedback wer, cer = compute_metrics(intended, actual) if wer > 0.5: error_analysis.append("🎯 Focus on pronouncing each word clearly and separately.") elif wer > 0.3: error_analysis.append("🎯 Good overall, but some words need clearer pronunciation.") if cer > 0.3: error_analysis.append("🔤 Pay attention to individual sounds and syllables.") return error_analysis, comparison # ---------------- MAIN FUNCTION ---------------- # @GPU_DECORATOR def compare_pronunciation(audio, language_choice, intended_sentence): """Main function to analyze pronunciation""" if audio is None: return ("❌ Please record audio first", "", "", "", "", "", "", "", "", "", "") if not intended_sentence.strip(): return ("❌ Please generate a sentence first", "", "", "", "", "", "", "", "", "", "") print(f"🔍 Analyzing pronunciation for {language_choice}...") # Get transcriptions from both models primary_result = transcribe_with_primary(audio, language_choice) specialized_result = transcribe_with_specialized(audio, language_choice) # Choose best result (prefer specialized if successful) if not specialized_result.startswith("Specialized") and specialized_result.strip(): best_transcription = specialized_result best_source = "Specialized Model" elif not primary_result.startswith("Primary") and primary_result.strip(): best_transcription = primary_result best_source = "Primary Model" else: return ( f"❌ Both models failed:\nPrimary: {primary_result}\nSpecialized: {specialized_result}", "", "", "", "", "", "", "", "", "", "" ) # Analyze pronunciation error_analysis, comparison = analyze_pronunciation_errors( intended_sentence, best_transcription, language_choice ) # Compute metrics wer, cer = compute_metrics(intended_sentence, best_transcription) score, feedback, color = get_pronunciation_score(wer, cer) # Create status message status_msg = f"""✅ Analysis Complete! {score} {feedback} 🤖 Best result from: {best_source} 📊 Word Accuracy: {(1-wer)*100:.1f}% 📈 Character Accuracy: {(1-cer)*100:.1f}% 🔍 Analysis: """ + "\n".join(error_analysis) return ( status_msg, primary_result, specialized_result, f"{wer:.3f} ({(1-wer)*100:.1f}%)", f"{cer:.3f} ({(1-cer)*100:.1f}%)", comparison["intended_orig"], comparison["actual_orig"], comparison["intended_translit"], comparison["actual_translit"], comparison["word_diff_orig"], comparison["char_diff_orig"] ) # ---------------- UI ---------------- # def create_interface(): with gr.Blocks(title="Enhanced Pronunciation Comparator", theme=gr.themes.Soft()) as demo: gr.Markdown(""" # 🎙️ Enhanced Pronunciation Comparator **Perfect your pronunciation in English, Tamil, Malayalam, and Hindi!** This tool uses specialized AI models to give you detailed feedback on your pronunciation, including transliteration to help you understand exactly where you need improvement. ### How to use: 1. 🌐 Select your target language 2. 🎲 Generate a practice sentence 3. 🎤 Record yourself saying the sentence clearly 4. 🔍 Get detailed pronunciation analysis with transliteration """) with gr.Row(): with gr.Column(scale=2): language_dropdown = gr.Dropdown( choices=list(LANG_CODES.keys()), value="Tamil", label="🌐 Select Language" ) with gr.Column(scale=1): generate_btn = gr.Button("🎲 Generate Practice Sentence", variant="primary") intended_textbox = gr.Textbox( label="📝 Practice Sentence", interactive=False, lines=2, placeholder="Click 'Generate Practice Sentence' to get started..." ) audio_input = gr.Audio( sources=["microphone", "upload"], type="filepath", label="🎤 Record Your Pronunciation" ) analyze_btn = gr.Button("🔍 Analyze Pronunciation", variant="secondary", size="lg") with gr.Row(): status_output = gr.Textbox( label="📊 Analysis Results", interactive=False, lines=8 ) with gr.Accordion("🤖 Model Outputs", open=False): with gr.Row(): primary_output = gr.Textbox(label="Primary Model (IndicWhisper)", interactive=False) specialized_output = gr.Textbox(label="Specialized Model", interactive=False) with gr.Accordion("📈 Detailed Metrics", open=False): with gr.Row(): wer_output = gr.Textbox(label="Word Error Rate", interactive=False) cer_output = gr.Textbox(label="Character Error Rate", interactive=False) gr.Markdown("### 🔍 Detailed Comparison") with gr.Row(): with gr.Column(): gr.Markdown("#### 📝 Original Script") intended_orig = gr.Textbox(label="🎯 Target Text", interactive=False) actual_orig = gr.Textbox(label="🗣️ What You Said", interactive=False) with gr.Column(): gr.Markdown("#### 🔤 Romanized (Transliterated)") intended_translit = gr.Textbox(label="🎯 Target (Romanized)", interactive=False) actual_translit = gr.Textbox(label="🗣️ What You Said (Romanized)", interactive=False) gr.Markdown("### 🎨 Visual Comparison") gr.Markdown("**Green** = Correct, **Red** = Wrong/Missing, **Orange** = Added/Substituted") word_diff_html = gr.HTML(label="🔤 Word-by-Word Comparison") char_diff_html = gr.HTML(label="🔍 Character-by-Character Analysis") # Event handlers generate_btn.click( fn=get_random_sentence, inputs=[language_dropdown], outputs=[intended_textbox] ) analyze_btn.click( fn=compare_pronunciation, inputs=[audio_input, language_dropdown, intended_textbox], outputs=[ status_output, primary_output, specialized_output, wer_output, cer_output, intended_orig, actual_orig, intended_translit, actual_translit, word_diff_html, char_diff_html ] ) language_dropdown.change( fn=get_random_sentence, inputs=[language_dropdown], outputs=[intended_textbox] ) gr.Markdown(""" ### 📚 Pro Tips for Better Pronunciation: - **Speak slowly and clearly** - Don't rush through the sentence - **Pronounce each syllable** - Break down complex words - **Check the romanized version** - Use it to understand correct pronunciation - **Practice repeatedly** - Use the same sentence multiple times to track improvement - **Focus on problem areas** - Pay attention to red-highlighted parts - **Record in a quiet environment** - Minimize background noise ### 🎯 Understanding the Feedback: - **Green highlights** = Perfect pronunciation ✅ - **Red highlights** = Missing or mispronounced ❌ - **Orange highlights** = Added or substituted 🔄 - **Transliteration** = Helps you see pronunciation patterns - **Error rates** = Lower is better (0% = perfect) """) return demo # ---------------- LAUNCH ---------------- # if __name__ == "__main__": print("🚀 Starting Enhanced Pronunciation Comparator...") demo = create_interface() demo.launch( server_name="0.0.0.0", server_port=7860, share=True, show_error=True )