"""AI Pronunciation Coach: Gradio app that transcribes a spoken practice sentence
with Whisper, romanizes Tamil/Malayalam text with Qwen, and reports word-level feedback."""

import difflib
import functools
import gc
import itertools
import random
import re
import string
import unicodedata

import gradio as gr
import jiwer
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)

# ---------------- CONFIG ---------------- #
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Whisper checkpoint used for each practice language.
MODEL_CONFIGS = {
    "English": "openai/whisper-large-v2",
    "Tamil": "vasista22/whisper-tamil-large-v2",
    "Malayalam": "thennal/whisper-medium-ml",
}

# Whisper language codes passed to generation.
LANG_CODES = {
    "English": "en",
    "Tamil": "ta",
    "Malayalam": "ml",
}

# Practice sentences offered per language.
SENTENCE_BANK = {
    "English": [
        "The sun sets over the horizon.",
        "Learning languages is fun.",
        "I like to drink coffee in the morning.",
        "Technology helps us communicate better.",
        "Reading books expands our knowledge.",
    ],
    "Tamil": [
        "இன்று நல்ல வானிலை உள்ளது.",
        "நான் தமிழ் கற்றுக்கொண்டு இருக்கிறேன்.",
        "எனக்கு புத்தகம் படிக்க விருப்பம்.",
        "தமிழ் மொழி மிகவும் அழகானது.",
        "அன்னை தமிழ் எங்கள் தாய்மொழி.",
    ],
    "Malayalam": [
        "എനിക്ക് മലയാളം വളരെ ഇഷ്ടമാണ്.",
        "ഇന്ന് മഴപെയ്യുന്നു.",
        "ഞാൻ പുസ്തകം വായിക്കുന്നു.",
        "കേരളം എന്റെ സ്വന്തം നാടാണ്.",
        "സംഗീതം ജീവിതത്തിന്റെ ഭാഗമാണ്.",
    ],
}

# Whisper checkpoint used when the language-specific one fails to load.
FALLBACK_WHISPER_MODEL = "openai/whisper-base"
# Instruction-tuned LLM used for transliteration.
QWEN_MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"

# Unicode blocks of the native scripts; used to detect when the LLM echoed
# native script instead of romanizing it.
_SCRIPT_RANGES = {
    "Tamil": ("\u0b80", "\u0bff"),
    "Malayalam": ("\u0d00", "\u0d7f"),
}
# Name of the romanized form, used both in prompts and to strip an echoed label.
_ROMANIZED_LABELS = {"Tamil": "Thanglish", "Malayalam": "Manglish"}
_NON_ASCII_RE = re.compile(r"[^\x00-\x7F]+")
_ASCII_PUNCTUATION = frozenset(string.punctuation)

# Word-level romanizations used when the LLM is unavailable. Covers every word
# of SENTENCE_BANK so generated practice sentences always romanize fully.
_FALLBACK_WORD_MAPS = {
    "Malayalam": {
        "കേരളം": "kerala",
        "എന്റെ": "ente",
        "സ്വന്തം": "swantham",
        "നാടാണ്": "naadaan",
        "എനിക്ക്": "enikku",
        "മലയാളം": "malayalam",
        "വളരെ": "valare",
        "ഇഷ്ടമാണ്": "ishtamaan",
        "ഞാൻ": "njan",
        "പുസ്തകം": "pusthakam",
        "വായിക്കുന്നു": "vaayikkunnu",
        "ഇന്ന്": "innu",
        "മഴപെയ്യുന്നു": "mazhapeyyunnu",
        "സംഗീതം": "sangeetham",
        "ജീവിതത്തിന്റെ": "jeevithathinte",
        "ഭാഗമാണ്": "bhaagamaan",
    },
    "Tamil": {
        "அன்னை": "annai",
        "தமிழ்": "tamil",
        "எங்கள்": "engal",
        "தாய்மொழி": "thaaimozhi",
        "நான்": "naan",
        "இன்று": "indru",
        "நல்ல": "nalla",
        "வானிலை": "vaanilai",
        "உள்ளது": "ulladhu",
        "கற்றுக்கொண்டு": "katrukkondu",
        "இருக்கிறேன்": "irukkiren",
        "எனக்கு": "enakku",
        "புத்தகம்": "puthagam",
        "படிக்க": "padikka",
        "விருப்பம்": "viruppam",
        "மொழி": "mozhi",
        "மிகவும்": "migavum",
        "அழகானது": "azhagaanadhu",
    },
}

# ---------------- MODELS ---------------- #
# Single-slot cache: only one Whisper model is kept in memory at a time.
current_whisper_model = {"language": None, "model": None, "processor": None}
qwen_model = {"model": None, "tokenizer": None}
# Set after a failed Qwen load so we don't retry the download for every word.
_qwen_load_failed = False


def _load_whisper_checkpoint(model_id):
    """Load a float32 Whisper model (moved to DEVICE) and its processor."""
    model = WhisperForConditionalGeneration.from_pretrained(
        model_id, torch_dtype=torch.float32
    ).to(DEVICE)
    processor = WhisperProcessor.from_pretrained(model_id)
    return model, processor


def load_whisper_model(language_choice):
    """Return ``(model, processor)`` for *language_choice*, loading it if needed.

    The previously cached model (if any, for another language) is released
    first. If the language-specific checkpoint fails to load, the generic
    ``openai/whisper-base`` checkpoint is loaded instead and cached under
    *language_choice*, so later calls for that language reuse it.
    """
    global current_whisper_model
    if (
        current_whisper_model["language"] == language_choice
        and current_whisper_model["model"] is not None
    ):
        return current_whisper_model["model"], current_whisper_model["processor"]

    if current_whisper_model["model"] is not None:
        # Reset to a well-formed empty cache (instead of deleting keys) so the
        # dict stays valid even if the load below raises.
        current_whisper_model = {"language": None, "model": None, "processor": None}
        gc.collect()
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    model_id = MODEL_CONFIGS[language_choice]
    print(f"Loading Whisper model: {model_id}")
    try:
        model, processor = _load_whisper_checkpoint(model_id)
        print("✓ Whisper model loaded successfully")
    except Exception as e:
        print(f"✗ Error loading Whisper model: {e}")
        model, processor = _load_whisper_checkpoint(FALLBACK_WHISPER_MODEL)

    current_whisper_model = {
        "language": language_choice,
        "model": model,
        "processor": processor,
    }
    return model, processor


def load_qwen_model():
    """Load (once) and return ``(model, tokenizer)`` for Qwen2.5-1.5B-Instruct.

    Returns ``(None, None)`` if loading fails; the failure is remembered so the
    load is not re-attempted on every transliteration call.
    """
    global qwen_model, _qwen_load_failed
    if qwen_model["model"] is not None:
        return qwen_model["model"], qwen_model["tokenizer"]
    if _qwen_load_failed:
        return None, None

    try:
        print(f"Loading Qwen model: {QWEN_MODEL_NAME}")
        tokenizer = AutoTokenizer.from_pretrained(QWEN_MODEL_NAME, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            QWEN_MODEL_NAME,
            trust_remote_code=True,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            device_map="auto" if DEVICE == "cuda" else None,
        )
        if DEVICE == "cpu":
            model = model.to(DEVICE)
        model.eval()
    except Exception as e:
        print(f"✗ Failed to load Qwen model: {e}")
        _qwen_load_failed = True
        return None, None

    qwen_model = {"model": model, "tokenizer": tokenizer}
    print("✓ Qwen model loaded successfully")
    return model, tokenizer


# ---------------- TRANSLITERATION ---------------- #
def _contains_script(text, lang):
    """Return True if *text* contains any character from *lang*'s native script block."""
    bounds = _SCRIPT_RANGES.get(lang)
    if bounds is None:
        return False
    low, high = bounds
    return any(low <= ch <= high for ch in text)


def _build_transliteration_prompts(text, source_lang):
    """Return ``(system_prompt, user_prompt)`` with a one-shot example for *source_lang*."""
    if source_lang == "Tamil":
        system_prompt = "You are a Tamil transliteration expert. Convert Tamil script to English letters (Thanglish) like how Tamil people type on phones."
        user_prompt = f"""Convert this Tamil text to Thanglish using English letters:

Tamil: நான் தமிழ் படிக்கிறேன்
Thanglish: naan tamil padikkiren

Tamil: {text}
Thanglish:"""
    else:  # Malayalam
        system_prompt = "You are a Malayalam transliteration expert. Convert Malayalam script to English letters (Manglish) like how Malayalam people type on phones."
        user_prompt = f"""Convert this Malayalam text to Manglish using English letters:

Malayalam: ഞാൻ മലയാളം പഠിക്കുന്നു
Manglish: njan malayalam padikkunnu

Malayalam: {text}
Manglish:"""
    return system_prompt, user_prompt


def _clean_transliteration(raw, source_lang):
    """Reduce raw LLM output to one ASCII line; return "" if it is unusable.

    Takes the first non-empty line, drops an echoed "Thanglish:"/"Manglish:"
    label, and rejects the answer if it still contains native script (checked
    BEFORE stripping non-ASCII, otherwise the check could never trigger).
    """
    first_line = next((ln.strip() for ln in raw.splitlines() if ln.strip()), "")
    label = _ROMANIZED_LABELS.get(source_lang)
    if label and first_line.lower().startswith(label.lower() + ":"):
        first_line = first_line[len(label) + 1:].strip()
    if _contains_script(first_line, source_lang):
        return ""
    return _NON_ASCII_RE.sub("", first_line).strip()


@functools.lru_cache(maxsize=1024)
def transliterate_with_qwen(text, source_lang):
    """Romanize Tamil/Malayalam *text* with Qwen, falling back to a word map.

    English or blank text is returned unchanged. Decoding is greedy and results
    are memoized, so the same word always gets the same romanization across the
    feedback tables (and repeated words cost one LLM call).
    """
    if source_lang == "English" or not text.strip():
        return text

    model, tokenizer = load_qwen_model()
    if model is None or tokenizer is None:
        return get_simple_transliteration(text, source_lang)

    try:
        system_prompt, user_prompt = _build_transliteration_prompts(text, source_lang)
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        inputs = inputs.to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=100,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.2,
            )

        # Decode only the newly generated tokens. Slicing the decoded text by
        # len(prompt) is wrong because skip_special_tokens removes the chat
        # template's special tokens from the decoded string.
        prompt_length = inputs["input_ids"].shape[1]
        raw_response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        response = _clean_transliteration(raw_response, source_lang)
    except Exception as e:
        print(f"Qwen transliteration error: {e}")
        return get_simple_transliteration(text, source_lang)

    return response if response else get_simple_transliteration(text, source_lang)


def get_simple_transliteration(text, lang_choice):
    """Word-by-word romanization used when Qwen is unavailable or fails.

    Known words come from a fixed map; other words go through
    ``basic_phonetic_convert``. Trailing ``.,!?`` is preserved. Languages
    without a map (e.g. English) are returned unchanged.
    """
    word_map = _FALLBACK_WORD_MAPS.get(lang_choice)
    if word_map is None:
        return text

    result_words = []
    for word in text.split():
        clean_word = word.rstrip(".,!?")
        punct = word[len(clean_word):]
        romanized = word_map.get(clean_word)
        if romanized is None:
            romanized = basic_phonetic_convert(clean_word, lang_choice)
        result_words.append(romanized + punct)
    return " ".join(result_words)


def basic_phonetic_convert(word, lang_choice):
    """Very rough last-resort conversion for words missing from the fallback map.

    Maps a few Malayalam final consonant signs to ASCII and strips all other
    non-ASCII characters; returns "unknown" if nothing is left.
    """
    if lang_choice == "Malayalam":
        result = word.replace("ം", "m").replace("ൺ", "n").replace("ൻ", "n")
        result = _NON_ASCII_RE.sub("", result)
        return result if result else "unknown"
    if lang_choice == "Tamil":
        result = _NON_ASCII_RE.sub("", word)
        return result if result else "unknown"
    return word


# ---------------- SPEECH RECOGNITION ---------------- #
@spaces.GPU
def transcribe_audio(audio_path, language_choice):
    """Transcribe the audio file at *audio_path* with the Whisper model for *language_choice*.

    Audio is resampled to 16 kHz. Generation first uses the ``language``/``task``
    generate arguments, then legacy ``forced_decoder_ids``, and finally plain
    generation if the checkpoint's generation config supports neither.
    """
    import librosa

    model, processor = load_whisper_model(language_choice)
    lang_code = LANG_CODES[language_choice]

    audio, _sample_rate = librosa.load(audio_path, sr=16000)
    input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features
    input_features = input_features.to(DEVICE, dtype=next(model.parameters()).dtype)

    gen_kwargs = {"max_length": 448, "num_beams": 5, "temperature": 0.0}
    attempts = (
        lambda: model.generate(
            input_features, language=lang_code, task="transcribe", **gen_kwargs
        ),
        lambda: model.generate(
            input_features,
            forced_decoder_ids=processor.get_decoder_prompt_ids(
                language=lang_code, task="transcribe"
            ),
            **gen_kwargs,
        ),
    )

    predicted_ids = None
    with torch.no_grad():
        for attempt in attempts:
            try:
                predicted_ids = attempt()
                break
            except Exception as e:
                print(f"Whisper language-forced generation failed, trying fallback: {e}")
        if predicted_ids is None:
            predicted_ids = model.generate(input_features, **gen_kwargs)

    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    return transcription.strip()


# ---------------- FEEDBACK SYSTEM ---------------- #
def normalize_text_for_comparison(text):
    """Drop punctuation (ASCII and Unicode), collapse whitespace, and lowercase."""
    without_punct = "".join(
        ch
        for ch in text
        if ch not in _ASCII_PUNCTUATION and not unicodedata.category(ch).startswith("P")
    )
    return " ".join(without_punct.split()).lower()


def create_feedback(intended, actual, lang_choice):
    """Compare *intended* and *actual* word sequences and build feedback.

    Returns ``(comparison_data, wrong_pronunciations, message, accuracy)``:
    - comparison_data: ``[label, value]`` rows for the overview table;
    - wrong_pronunciations: ``[expected, expected_roman, said, said_roman]``
      rows for every mismatched, missing, or extra word;
    - message: a motivational string chosen by accuracy band;
    - accuracy: ``difflib.SequenceMatcher`` word ratio scaled to 0-100.
    """
    intended_roman = transliterate_with_qwen(intended, lang_choice)
    actual_roman = transliterate_with_qwen(actual, lang_choice)

    intended_words = normalize_text_for_comparison(intended).split()
    actual_words = normalize_text_for_comparison(actual).split()

    sm = difflib.SequenceMatcher(None, intended_words, actual_words)
    accuracy = sm.ratio() * 100

    comparison_data = [
        ["Target Text", intended],
        ["Target (Romanized)", intended_roman],
        ["Your Speech", actual],
        ["Your Speech (Romanized)", actual_roman],
        ["Accuracy Score", f"{accuracy:.1f}%"],
    ]

    wrong_pronunciations = []
    for tag, i1, i2, j1, j2 in sm.get_opcodes():
        if tag == "equal":
            continue
        # Pair words position-by-position; in an unequal-length 'replace' the
        # leftover words are genuinely missing/extra and must be reported too.
        for expected_word, actual_word in itertools.zip_longest(
            intended_words[i1:i2], actual_words[j1:j2]
        ):
            if expected_word is not None and actual_word is not None:
                if expected_word != actual_word:
                    wrong_pronunciations.append([
                        expected_word,
                        transliterate_with_qwen(expected_word, lang_choice),
                        actual_word,
                        transliterate_with_qwen(actual_word, lang_choice),
                    ])
            elif expected_word is not None:
                wrong_pronunciations.append([
                    expected_word,
                    transliterate_with_qwen(expected_word, lang_choice),
                    "(Not spoken)",
                    "",
                ])
            else:
                wrong_pronunciations.append([
                    "(Not expected)",
                    "",
                    actual_word,
                    transliterate_with_qwen(actual_word, lang_choice),
                ])

    if accuracy >= 95:
        message = "🎉 Outstanding! Perfect pronunciation!"
    elif accuracy >= 85:
        message = "🌟 Excellent! Very natural sounding!"
    elif accuracy >= 70:
        message = "👍 Good job! Your pronunciation is improving!"
    elif accuracy >= 50:
        message = "📚 Getting there! Focus on the highlighted sounds!"
    else:
        message = "💪 Keep practicing! Every attempt makes you better!"

    return comparison_data, wrong_pronunciations, message, accuracy


# ---------------- MAIN FUNCTION ---------------- #
@spaces.GPU
def analyze_pronunciation(audio, lang_choice, intended_text):
    """Gradio handler: transcribe *audio* and score it against *intended_text*.

    *intended_text* may contain a "🔤" romanization suffix, which is ignored.
    Returns the six values wired to the UI outputs: transcription, its
    romanization, formatted WER, comparison rows, correction rows, message.
    On failure the first value carries a user-facing warning/error.
    """
    if audio is None or not intended_text or not intended_text.strip():
        return "⚠️ Please record audio and generate a sentence first.", "", "", [], [], ""

    try:
        # Drop the romanization line appended by the sentence generator.
        intended_sentence = intended_text.split("🔤")[0].strip()

        actual_text = transcribe_audio(audio, lang_choice)
        intended_normalized = normalize_text_for_comparison(intended_sentence)
        actual_normalized = normalize_text_for_comparison(actual_text)
        if not actual_normalized:
            return "⚠️ No speech detected. Please try recording again.", "", "", [], [], ""

        # Score on the same normalized text as the accuracy metric so
        # punctuation/case differences are not counted as word errors.
        wer_val = jiwer.wer(intended_normalized, actual_normalized)

        comparison_data, wrong_pronunciations, message, _accuracy = create_feedback(
            intended_sentence, actual_text, lang_choice
        )
        # Memoized: reuses the romanization computed inside create_feedback.
        actual_roman = transliterate_with_qwen(actual_text, lang_choice)

        return actual_text, actual_roman, f"{wer_val:.1%}", comparison_data, wrong_pronunciations, message
    except Exception as e:
        return f"❌ Error: {str(e)}", "", "", [], [], ""


# ---------------- HELPERS ---------------- #
def get_random_sentence_with_transliteration(language_choice):
    """Pick a random practice sentence; for Tamil/Malayalam append "\\n\\n🔤 <romanization>".

    NOTE(review): unlike the analysis handlers this is not decorated with
    ``@spaces.GPU``; if Qwen runs on a ZeroGPU device, generation here may fail
    and fall back to the word map — confirm on the deployment.
    """
    sentence = random.choice(SENTENCE_BANK[language_choice])
    if language_choice not in ("Tamil", "Malayalam"):
        return sentence
    transliteration = transliterate_with_qwen(sentence, language_choice)
    return f"{sentence}\n\n🔤 {transliteration}"


# ---------------- UI ---------------- #
with gr.Blocks(title="AI Pronunciation Coach", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎙️ AI Pronunciation Coach
    ### Practice English, Tamil & Malayalam with AI feedback powered by Qwen2.5-1.5B-Instruct

    **Features:**
    - ✨ **Smart Transliteration**: Natural Thanglish/Manglish using Qwen2.5-1.5B-Instruct
    - 🎯 **Accurate Recognition**: Language-specific Whisper models
    - 📊 **Smart Analysis**: Punctuation-aware comparison with correction tables

    **How to use:**
    1. Select your language
    2. Generate a practice sentence
    3. Record yourself reading it aloud
    4. Get instant feedback with detailed analysis!
    """)

    with gr.Row():
        lang_choice = gr.Dropdown(
            choices=list(LANG_CODES.keys()),
            value="Malayalam",
            label="🌍 Choose Language",
        )
        gen_btn = gr.Button("🎲 Generate Practice Sentence", variant="primary")

    intended_display = gr.Textbox(
        label="📝 Practice Sentence",
        interactive=False,
        placeholder="Click 'Generate Practice Sentence' to get started...",
        lines=3,
    )

    audio_input = gr.Audio(
        sources=["microphone"],
        type="filepath",
        label="🎤 Record Your Pronunciation",
    )

    analyze_btn = gr.Button("🔍 Analyze My Pronunciation", variant="primary", size="lg")

    with gr.Row():
        actual_out = gr.Textbox(label="🗣️ What You Said", interactive=False)
        actual_roman_out = gr.Textbox(label="🔤 Your Pronunciation (Romanized)", interactive=False)
        wer_out = gr.Textbox(label="📊 Word Error Rate", interactive=False)

    # Analysis tables
    gr.Markdown("### 📊 Analysis Results")

    with gr.Row():
        with gr.Column():
            comparison_table = gr.Dataframe(
                headers=["Metric", "Value"],
                label="📋 Overall Comparison",
                interactive=False,
            )
        with gr.Column():
            pronunciation_table = gr.Dataframe(
                headers=["Expected Word", "Expected (Romanized)", "You Said", "You Said (Romanized)"],
                label="❌ Pronunciation Corrections Needed",
                interactive=False,
            )

    feedback_message = gr.Textbox(label="💬 Feedback", interactive=False)

    # Event handlers
    gen_btn.click(
        fn=get_random_sentence_with_transliteration,
        inputs=[lang_choice],
        outputs=[intended_display],
    )

    analyze_btn.click(
        fn=analyze_pronunciation,
        inputs=[audio_input, lang_choice, intended_display],
        outputs=[actual_out, actual_roman_out, wer_out, comparison_table, pronunciation_table, feedback_message],
    )


if __name__ == "__main__":
    demo.launch()