# Hugging Face Spaces page header: Running on Zero
import gradio as gr | |
import random | |
import difflib | |
import jiwer | |
import torch | |
from transformers import ( | |
WhisperForConditionalGeneration, | |
WhisperProcessor, | |
AutoModelForCausalLM, | |
AutoTokenizer | |
) | |
import spaces | |
import gc | |
# ---------------- CONFIG ---------------- #
# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Whisper checkpoint (Hugging Face model id) used for each supported language.
MODEL_CONFIGS = dict(
    English="openai/whisper-large-v2",
    Tamil="vasista22/whisper-tamil-large-v2",
    Malayalam="thennal/whisper-medium-ml",
)

# Short language code handed to the Whisper processor for each language.
# The key order also drives the order of the language dropdown in the UI.
LANG_CODES = dict(
    English="en",
    Tamil="ta",
    Malayalam="ml",
)

# Practice sentences offered to the learner, grouped by language.
SENTENCE_BANK = {
    "English": [
        "The sun sets over the horizon.",
        "Learning languages is fun.",
        "I like to drink coffee in the morning.",
        "Technology helps us communicate better.",
        "Reading books expands our knowledge.",
    ],
    "Tamil": [
        "இன்று நல்ல வானிலை உள்ளது.",
        "நான் தமிழ் கற்றுக்கொண்டு இருக்கிறேன்.",
        "எனக்கு புத்தகம் படிக்க விருப்பம்.",
        "தமிழ் மொழி மிகவும் அழகானது.",
        "அன்னை தமிழ் எங்கள் தாய்மொழி.",
    ],
    "Malayalam": [
        "എനിക്ക് മലയാളം വളരെ ഇഷ്ടമാണ്.",
        "ഇന്ന് മഴപെയ്യുന്നു.",
        "ഞാൻ പുസ്തകം വായിക്കുന്നു.",
        "കേരളം എന്റെ സ്വന്തം നാടാണ്.",
        "സംഗീതം ജീവിതത്തിന്റെ ഭാഗമാണ്.",
    ],
}
# ---------------- MODELS ---------------- #
# Module-level caches so each model is loaded at most once per process.
# Only one Whisper checkpoint is kept at a time, tagged with its language.
current_whisper_model = dict(language=None, model=None, processor=None)
# The Qwen transliteration model and its tokenizer.
qwen_model = dict(model=None, tokenizer=None)
def _load_whisper_checkpoint(model_id):
    """Load one Whisper checkpoint and its processor, moving the model to DEVICE."""
    model = WhisperForConditionalGeneration.from_pretrained(
        model_id, torch_dtype=torch.float32
    ).to(DEVICE)
    processor = WhisperProcessor.from_pretrained(model_id)
    return model, processor


def load_whisper_model(language_choice):
    """Return the Whisper ``(model, processor)`` pair for the selected language.

    The pair is cached in the module-level ``current_whisper_model`` dict, so
    repeated calls for the same language reuse the loaded model. Switching
    language releases the previous model before the new one is loaded.

    Args:
        language_choice: A key of ``MODEL_CONFIGS`` (e.g. ``"Tamil"``).

    Returns:
        Tuple ``(model, processor)``. If the language-specific checkpoint
        cannot be loaded, ``openai/whisper-base`` is loaded instead and cached
        under the requested language.

    Raises:
        KeyError: If ``language_choice`` is not a key of ``MODEL_CONFIGS``.
        Exception: Whatever the fallback load raises, if it fails as well.
    """
    global current_whisper_model

    if (
        current_whisper_model.get("language") == language_choice
        and current_whisper_model.get("model") is not None
    ):
        return current_whisper_model["model"], current_whisper_model["processor"]

    # Release the previous model. The keys are reset to None rather than
    # deleted: if the loads below fail, the cache must still have all of its
    # keys, otherwise every later call would die with KeyError on the lookup
    # above instead of retrying the load.
    if current_whisper_model.get("model") is not None:
        current_whisper_model.update(language=None, model=None, processor=None)
        gc.collect()
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    model_id = MODEL_CONFIGS[language_choice]
    print(f"Loading Whisper model: {model_id}")
    try:
        model, processor = _load_whisper_checkpoint(model_id)
        print("✓ Whisper model loaded successfully")
    except Exception as e:
        # Deliberate best-effort: any failure of the language-specific
        # checkpoint falls back to the small base model.
        print(f"✗ Error loading Whisper model: {e}")
        model, processor = _load_whisper_checkpoint("openai/whisper-base")

    current_whisper_model = {
        "language": language_choice,
        "model": model,
        "processor": processor,
    }
    return model, processor
def load_qwen_model():
    """Return the cached Qwen2.5-1.5B-Instruct ``(model, tokenizer)`` pair.

    The pair is loaded on first use and stored in the module-level
    ``qwen_model`` dict. On CUDA the model is loaded in float16 with
    ``device_map="auto"``; on CPU it is loaded in float32.

    Returns:
        ``(model, tokenizer)`` on success, or ``(None, None)`` if loading
        fails. A failure is not cached, so the next call tries again.
    """
    global qwen_model

    cached_model = qwen_model["model"]
    if cached_model is not None:
        return cached_model, qwen_model["tokenizer"]

    model_name = "Qwen/Qwen2.5-1.5B-Instruct"
    on_gpu = DEVICE == "cuda"
    try:
        print(f"Loading Qwen model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        load_options = {
            "trust_remote_code": True,
            "torch_dtype": torch.float16 if on_gpu else torch.float32,
            "device_map": "auto" if on_gpu else None,
        }
        model = AutoModelForCausalLM.from_pretrained(model_name, **load_options)
        if DEVICE == "cpu":
            model = model.to(DEVICE)
        model.eval()
        qwen_model = {"model": model, "tokenizer": tokenizer}
        print("✓ Qwen model loaded successfully")
        return model, tokenizer
    except Exception as e:
        print(f"✗ Failed to load Qwen model: {e}")
        return None, None
# ---------------- TRANSLITERATION ---------------- #
def transliterate_with_qwen(text, source_lang):
    """Romanize Tamil or Malayalam text with the Qwen chat model.

    Args:
        text: Text to transliterate.
        source_lang: ``"English"``, ``"Tamil"`` or anything else; any value
            other than ``"English"``/``"Tamil"`` is prompted as Malayalam.

    Returns:
        ``text`` unchanged when ``source_lang`` is ``"English"`` or ``text``
        is blank. Otherwise the first line of the model's reply with every
        non-ASCII character removed. If the model is unavailable, raises, or
        yields an empty reply, the result of ``get_simple_transliteration``
        is returned instead.
    """
    import re

    if source_lang == "English" or not text.strip():
        return text

    model, tokenizer = load_qwen_model()
    if model is None or tokenizer is None:
        return get_simple_transliteration(text, source_lang)  # Simple fallback

    try:
        # One worked example per language steers the model towards the
        # informal romanization people actually type.
        if source_lang == "Tamil":
            language, roman_style = "Tamil", "Thanglish"
            example_native = "நான் தமிழ் படிக்கிறேன்"
            example_roman = "naan tamil padikkiren"
        else:  # Malayalam
            language, roman_style = "Malayalam", "Manglish"
            example_native = "ഞാൻ മലയാളം പഠിക്കുന്നു"
            example_roman = "njan malayalam padikkunnu"

        system_prompt = (
            f"You are a {language} transliteration expert. Convert {language} script to "
            f"English letters ({roman_style}) like how {language} people type on phones."
        )
        user_prompt = (
            f"Convert this {language} text to {roman_style} using English letters:\n"
            f"{language}: {example_native}\n"
            f"{roman_style}: {example_roman}\n"
            f"{language}: {text}\n"
            f"{roman_style}:"
        )
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        # Follow the model's own placement instead of assuming DEVICE; the
        # model may have been placed by device_map="auto".
        inputs = inputs.to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=100,
                temperature=0.3,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.2,
            )

        # Decode only the newly generated tokens. Slicing the decoded text by
        # len(prompt) is wrong: the prompt string contains chat-template
        # special tokens, while decoding with skip_special_tokens=True drops
        # them, so the character offsets do not line up and the start of the
        # answer (or all of it) would be cut off.
        prompt_token_count = inputs["input_ids"].shape[1]
        response = tokenizer.decode(
            outputs[0][prompt_token_count:], skip_special_tokens=True
        ).strip()

        # Keep the first line only and drop anything that is not ASCII, so no
        # original-script characters can survive in the result.
        response = response.split('\n')[0].strip()
        response = re.sub(r'[^\x00-\x7F]+', '', response).strip()

        return response if response else get_simple_transliteration(text, source_lang)
    except Exception as e:
        # Best-effort: any model/tokenizer failure degrades to the word map.
        print(f"Qwen transliteration error: {e}")
        return get_simple_transliteration(text, source_lang)
def get_simple_transliteration(text, lang_choice):
    """Romanize text word by word without a model (fallback path).

    Each whitespace-separated word is looked up in a small per-language table
    after trailing ``.,!?`` is set aside; words missing from the table go
    through ``basic_phonetic_convert``. The trailing punctuation is re-attached
    and the words are joined with single spaces. Text in any language other
    than Malayalam or Tamil is returned unchanged.
    """
    if lang_choice not in ("Malayalam", "Tamil"):
        return text

    if lang_choice == "Malayalam":
        known_words = {
            "കേരളം": "kerala",
            "എന്റെ": "ente",
            "സ്വന്തം": "swantham",
            "നാടാണ്": "naadaan",
            "എനിക്ക്": "enikku",
            "മലയാളം": "malayalam",
            "വളരെ": "valare",
            "ഇഷ്ടമാണ്": "ishtamaan",
            "ഞാൻ": "njan",
            "പുസ്തകം": "pusthakam",
            "വായിക്കുന്നു": "vaayikkunnu",
        }
    else:
        known_words = {
            "அன்னை": "annai",
            "தமிழ்": "tamil",
            "எங்கள்": "engal",
            "தாய்மொழி": "thaaimozhi",
            "நான்": "naan",
            "இன்று": "indru",
            "நல்ல": "nalla",
            "வானிலை": "vaanilai",
        }

    def romanize(token):
        # Separate trailing punctuation so it does not defeat the lookup.
        core = token.rstrip('.,!?')
        trailing = token[len(core):]
        if core in known_words:
            return known_words[core] + trailing
        return basic_phonetic_convert(core, lang_choice) + trailing

    return ' '.join(romanize(token) for token in text.split())
def basic_phonetic_convert(word, lang_choice):
    """Crude last-resort romanization of a single word.

    For Malayalam, three characters are mapped to ``m``/``n`` first. For both
    Malayalam and Tamil every remaining non-ASCII character is then removed,
    and ``"unknown"`` is returned if nothing is left. Words in any other
    language are returned unchanged.
    """
    import re

    if lang_choice not in ("Malayalam", "Tamil"):
        return word

    if lang_choice == "Malayalam":
        # Approximate a few common characters before the script is stripped.
        for native, latin in (('ം', 'm'), ('ൺ', 'n'), ('ൻ', 'n')):
            word = word.replace(native, latin)

    ascii_only = re.sub(r'[^\x00-\x7F]+', '', word)
    return ascii_only or "unknown"
# ---------------- SPEECH RECOGNITION ---------------- #
def transcribe_audio(audio_path, language_choice):
    """Transcribe an audio file with the Whisper model for the given language.

    Args:
        audio_path: Path of the recording; it is loaded by librosa at 16 kHz.
        language_choice: A key of ``LANG_CODES`` / ``MODEL_CONFIGS``.

    Returns:
        The decoded transcription with surrounding whitespace stripped.

    Raises:
        KeyError: If ``language_choice`` is not a key of ``LANG_CODES``.
    """
    import librosa

    sample_rate = 16000
    model, processor = load_whisper_model(language_choice)
    lang_code = LANG_CODES[language_choice]

    audio, _ = librosa.load(audio_path, sr=sample_rate)

    input_features = processor(
        audio, sampling_rate=sample_rate, return_tensors="pt"
    ).input_features
    # Match the model's parameter dtype so the features can be fed directly.
    input_features = input_features.to(DEVICE, dtype=next(model.parameters()).dtype)

    generation_kwargs = {"max_length": 448, "num_beams": 5, "temperature": 0.0}
    with torch.no_grad():
        try:
            # Pin the decoder to the target language and the transcribe task.
            forced_decoder_ids = processor.get_decoder_prompt_ids(
                language=lang_code, task="transcribe"
            )
            predicted_ids = model.generate(
                input_features,
                forced_decoder_ids=forced_decoder_ids,
                **generation_kwargs,
            )
        except Exception as e:
            # Best-effort retry without the forced prompt. Narrowed from a
            # bare except so KeyboardInterrupt/SystemExit are not swallowed,
            # and logged so the silent loss of language forcing is visible.
            print(f"Forced-language decoding failed ({e}); retrying without it")
            predicted_ids = model.generate(input_features, **generation_kwargs)

    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    return transcription.strip()
# ---------------- FEEDBACK SYSTEM ---------------- #
def normalize_text_for_comparison(text):
    """Return ``text`` lower-cased, without ASCII punctuation, single-spaced.

    Only the characters in ``string.punctuation`` are removed; runs of
    whitespace collapse to one space and leading/trailing whitespace is
    dropped.
    """
    import string

    punctuation = set(string.punctuation)
    without_punctuation = "".join(ch for ch in text if ch not in punctuation)
    single_spaced = " ".join(without_punctuation.split())
    return single_spaced.lower()
def create_feedback(intended, actual, lang_choice):
    """Compare the target sentence with the transcription and build feedback.

    Both texts are normalized with ``normalize_text_for_comparison`` and
    aligned word by word with ``difflib.SequenceMatcher``.

    Args:
        intended: The sentence the learner was asked to read.
        actual: The transcription of what the learner said.
        lang_choice: Language name passed on to ``transliterate_with_qwen``.

    Returns:
        Tuple ``(comparison_data, wrong_pronunciations, message, accuracy)``:
        ``comparison_data`` is a list of ``[label, value]`` rows,
        ``wrong_pronunciations`` a list of ``[expected, expected romanized,
        spoken, spoken romanized]`` rows, ``message`` an encouragement string
        chosen from the score, and ``accuracy`` the matcher ratio as a
        percentage (0-100).
    """
    from itertools import zip_longest

    romanized = {}

    def romanize(fragment):
        # Each transliteration is a model call, so repeated words are
        # transliterated once and rendered consistently.
        if fragment not in romanized:
            romanized[fragment] = transliterate_with_qwen(fragment, lang_choice)
        return romanized[fragment]

    intended_roman = romanize(intended)
    actual_roman = romanize(actual)

    # Compare without punctuation or case so only the words matter.
    intended_words = normalize_text_for_comparison(intended).split()
    actual_words = normalize_text_for_comparison(actual).split()

    matcher = difflib.SequenceMatcher(None, intended_words, actual_words)
    accuracy = matcher.ratio() * 100

    comparison_data = [
        ["Target Text", intended],
        ["Target (Romanized)", intended_roman],
        ["Your Speech", actual],
        ["Your Speech (Romanized)", actual_roman],
        ["Accuracy Score", f"{accuracy:.1f}%"],
    ]

    wrong_pronunciations = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == 'equal':
            continue
        # Pair the differing words position by position. zip_longest keeps
        # the surplus words when a 'replace' span is longer on one side;
        # those used to be dropped from the table. 'delete' and 'insert'
        # spans have one empty side, so the same loop covers them.
        pairs = zip_longest(intended_words[i1:i2], actual_words[j1:j2])
        for expected_word, spoken_word in pairs:
            if spoken_word is None:
                wrong_pronunciations.append(
                    [expected_word, romanize(expected_word), "(Not spoken)", ""]
                )
            elif expected_word is None:
                wrong_pronunciations.append(
                    ["(Not expected)", "", spoken_word, romanize(spoken_word)]
                )
            elif expected_word != spoken_word:
                wrong_pronunciations.append([
                    expected_word,
                    romanize(expected_word),
                    spoken_word,
                    romanize(spoken_word),
                ])

    if accuracy >= 95:
        message = "🎉 Outstanding! Perfect pronunciation!"
    elif accuracy >= 85:
        message = "🌟 Excellent! Very natural sounding!"
    elif accuracy >= 70:
        message = "👍 Good job! Your pronunciation is improving!"
    elif accuracy >= 50:
        message = "📚 Getting there! Focus on the highlighted sounds!"
    else:
        message = "💪 Keep practicing! Every attempt makes you better!"

    return comparison_data, wrong_pronunciations, message, accuracy
# ---------------- MAIN FUNCTION ---------------- #
def analyze_pronunciation(audio, lang_choice, intended_text):
    """Transcribe a recording and score it against the practice sentence.

    Args:
        audio: Path of the recorded audio, or ``None`` if nothing was recorded.
        lang_choice: Selected language name.
        intended_text: Contents of the practice-sentence box; anything from
            the ``🔤`` marker onwards (the romanization) is ignored.

    Returns:
        A 6-tuple for the UI outputs: ``(transcription, romanized
        transcription, word error rate, comparison rows, correction rows,
        feedback message)``. On missing input, silence, or any error, the
        first element carries the message and the rest are empty.
    """
    # NOTE(review): this file imports `spaces` but no `@spaces.GPU` decorator
    # is visible on any function; confirm whether the hosting hardware needs it.

    # `not intended_text` also covers None, which would otherwise raise on
    # .strip() here, outside the try block below.
    if audio is None or not intended_text or not intended_text.strip():
        return "⚠️ Please record audio and generate a sentence first.", "", "", [], [], ""

    try:
        # Keep only the original sentence; split() returns the whole text
        # when the marker is absent.
        intended_sentence = intended_text.split("🔤")[0].strip()

        actual_text = transcribe_audio(audio, lang_choice)
        if not actual_text.strip():
            return "⚠️ No speech detected. Please try recording again.", "", "", [], [], ""

        # Score WER on the same normalized text the accuracy score uses, so a
        # missing full stop or a capital letter is not counted as a wrong word.
        reference = normalize_text_for_comparison(intended_sentence)
        hypothesis = normalize_text_for_comparison(actual_text)
        if not reference:
            # Punctuation-only target: fall back to the raw strings rather
            # than handing the scorer an empty reference.
            reference, hypothesis = intended_sentence, actual_text
        wer_val = jiwer.wer(reference, hypothesis)

        actual_roman = transliterate_with_qwen(actual_text, lang_choice)

        comparison_data, wrong_pronunciations, message, _accuracy = create_feedback(
            intended_sentence, actual_text, lang_choice
        )
        return (
            actual_text,
            actual_roman,
            f"{wer_val:.1%}",
            comparison_data,
            wrong_pronunciations,
            message,
        )
    except Exception as e:
        # Top-level UI boundary: surface the error in the first output box.
        return f"❌ Error: {str(e)}", "", "", [], [], ""
# ---------------- HELPERS ---------------- #
def get_random_sentence_with_transliteration(language_choice):
    """Pick a random practice sentence for the language.

    For Tamil and Malayalam the sentence is followed by a blank line and a
    ``🔤``-prefixed romanization; for any other language the sentence is
    returned on its own.
    """
    picked = random.choice(SENTENCE_BANK[language_choice])
    if language_choice not in ["Tamil", "Malayalam"]:
        return picked
    romanization = transliterate_with_qwen(picked, language_choice)
    return f"{picked}\n\n🔤 {romanization}"
# ---------------- UI ---------------- #
# Gradio layout: language picker and sentence generator, recorder, result
# boxes, two analysis tables and a feedback line. The intro text names the
# transliteration model this file actually loads (Qwen2.5-1.5B-Instruct); it
# previously advertised a different model.
with gr.Blocks(title="AI Pronunciation Coach", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎙️ AI Pronunciation Coach
    ### Practice English, Tamil & Malayalam with AI feedback powered by Qwen2.5-1.5B-Instruct

    **Features:**
    - ✨ **Smart Transliteration**: Natural Thanglish/Manglish using Qwen2.5-1.5B-Instruct
    - 🎯 **Accurate Recognition**: Language-specific Whisper models
    - 📊 **Smart Analysis**: Punctuation-aware comparison with correction tables

    **How to use:**
    1. Select your language
    2. Generate a practice sentence
    3. Record yourself reading it aloud
    4. Get instant feedback with detailed analysis!
    """)

    with gr.Row():
        lang_choice = gr.Dropdown(
            choices=list(LANG_CODES.keys()),
            value="Malayalam",
            label="🌍 Choose Language",
        )
        gen_btn = gr.Button("🎲 Generate Practice Sentence", variant="primary")

    intended_display = gr.Textbox(
        label="📝 Practice Sentence",
        interactive=False,
        placeholder="Click 'Generate Practice Sentence' to get started...",
        lines=3,
    )
    audio_input = gr.Audio(
        sources=["microphone"],
        type="filepath",
        label="🎤 Record Your Pronunciation",
    )
    analyze_btn = gr.Button("🔍 Analyze My Pronunciation", variant="primary", size="lg")

    with gr.Row():
        actual_out = gr.Textbox(label="🗣️ What You Said", interactive=False)
        actual_roman_out = gr.Textbox(label="🔤 Your Pronunciation (Romanized)", interactive=False)
        wer_out = gr.Textbox(label="📊 Word Error Rate", interactive=False)

    # Analysis tables
    gr.Markdown("### 📊 Analysis Results")
    with gr.Row():
        with gr.Column():
            comparison_table = gr.Dataframe(
                headers=["Metric", "Value"],
                label="📋 Overall Comparison",
                interactive=False,
            )
        with gr.Column():
            pronunciation_table = gr.Dataframe(
                headers=["Expected Word", "Expected (Romanized)", "You Said", "You Said (Romanized)"],
                label="❌ Pronunciation Corrections Needed",
                interactive=False,
            )

    feedback_message = gr.Textbox(label="💬 Feedback", interactive=False)

    # Event handlers
    gen_btn.click(
        fn=get_random_sentence_with_transliteration,
        inputs=[lang_choice],
        outputs=[intended_display],
    )
    # The output order must match the 6-tuple returned by analyze_pronunciation.
    analyze_btn.click(
        fn=analyze_pronunciation,
        inputs=[audio_input, lang_choice, intended_display],
        outputs=[
            actual_out,
            actual_roman_out,
            wer_out,
            comparison_table,
            pronunciation_table,
            feedback_message,
        ],
    )

if __name__ == "__main__":
    demo.launch()