# Author: sudhanm
# Commit 32688f6 (verified): "Update app.py"
import gradio as gr
import random
import difflib
import jiwer
import torch
from transformers import (
WhisperForConditionalGeneration,
WhisperProcessor,
AutoModelForCausalLM,
AutoTokenizer
)
import spaces
import gc
# ---------------- CONFIG ---------------- #
# Torch device shared by every model in the app.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Hugging Face Whisper checkpoint used for each practice language.
MODEL_CONFIGS = dict(
    English="openai/whisper-large-v2",
    Tamil="vasista22/whisper-tamil-large-v2",
    Malayalam="thennal/whisper-medium-ml",
)

# Whisper language codes; the key order is also the order of the language
# dropdown in the UI (it is built from list(LANG_CODES.keys())).
LANG_CODES = dict(English="en", Tamil="ta", Malayalam="ml")

# Practice sentences offered for each language.
SENTENCE_BANK = {
    "English": [
        "The sun sets over the horizon.",
        "Learning languages is fun.",
        "I like to drink coffee in the morning.",
        "Technology helps us communicate better.",
        "Reading books expands our knowledge.",
    ],
    "Tamil": [
        "இன்று நல்ல வானிலை உள்ளது.",
        "நான் தமிழ் கற்றுக்கொண்டு இருக்கிறேன்.",
        "எனக்கு புத்தகம் படிக்க விருப்பம்.",
        "தமிழ் மொழி மிகவும் அழகானது.",
        "அன்னை தமிழ் எங்கள் தாய்மொழி.",
    ],
    "Malayalam": [
        "എനിക്ക് മലയാളം വളരെ ഇഷ്ടമാണ്.",
        "ഇന്ന് മഴപെയ്യുന്നു.",
        "ഞാൻ പുസ്തകം വായിക്കുന്നു.",
        "കേരളം എന്റെ സ്വന്തം നാടാണ്.",
        "സംഗീതം ജീവിതത്തിന്റെ ഭാഗമാണ്.",
    ],
}
# ---------------- MODELS ---------------- #
# Whisper model currently cached, together with the language it was loaded
# for; every entry is None until the first load.
current_whisper_model = dict(language=None, model=None, processor=None)
# Cached Qwen transliteration model and tokenizer; None until first loaded.
qwen_model = dict(model=None, tokenizer=None)
def _load_whisper_checkpoint(model_id):
    """Load the Whisper model (float32, moved to DEVICE) and processor for *model_id*."""
    model = WhisperForConditionalGeneration.from_pretrained(
        model_id, torch_dtype=torch.float32
    ).to(DEVICE)
    processor = WhisperProcessor.from_pretrained(model_id)
    return model, processor


def load_whisper_model(language_choice):
    """Return the Whisper ``(model, processor)`` for *language_choice*, loading it if needed.

    Only one Whisper model is cached at a time (in ``current_whisper_model``):
    switching language releases the cached model before the new checkpoint
    from ``MODEL_CONFIGS`` is loaded. If that checkpoint fails to load,
    ``openai/whisper-base`` is loaded instead and cached under the same
    language.

    Args:
        language_choice: A key of ``MODEL_CONFIGS``.

    Returns:
        Tuple ``(model, processor)``.

    Raises:
        KeyError: if *language_choice* is not a key of ``MODEL_CONFIGS``.
        Exception: whatever ``from_pretrained`` raises if the fallback
            checkpoint also fails to load.
    """
    global current_whisper_model

    if (
        current_whisper_model.get("language") == language_choice
        and current_whisper_model.get("model") is not None
    ):
        return current_whisper_model["model"], current_whisper_model["processor"]

    # Resolve the checkpoint first so an unknown language raises KeyError
    # without evicting the model that is currently cached.
    model_id = MODEL_CONFIGS[language_choice]

    # Release the previous model. The cache is reset to a complete "empty"
    # dict rather than having keys deleted: deleting "model"/"processor" would
    # make every later call raise KeyError on the lookup above if the load
    # below fails.
    if current_whisper_model.get("model") is not None:
        current_whisper_model = {"language": None, "model": None, "processor": None}
        gc.collect()
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    print(f"Loading Whisper model: {model_id}")
    try:
        model, processor = _load_whisper_checkpoint(model_id)
        print("✓ Whisper model loaded successfully")
    except Exception as e:
        print(f"✗ Error loading Whisper model: {e}")
        # Fall back to the base multilingual checkpoint.
        model, processor = _load_whisper_checkpoint("openai/whisper-base")

    current_whisper_model = {
        "language": language_choice,
        "model": model,
        "processor": processor,
    }
    return model, processor
def load_qwen_model():
    """Lazily load Qwen2.5-1.5B-Instruct (used for transliteration) and cache it.

    On GPU the model is loaded in float16 with ``device_map="auto"``; on CPU
    it is loaded in float32 and moved to DEVICE explicitly. The model is put
    in eval mode.

    Returns:
        ``(model, tokenizer)`` on success; the pair is stored in
        ``qwen_model`` and returned directly on later calls. On any loading
        error the error is printed and ``(None, None)`` is returned. Failures
        are not cached, so the next call tries to load again.
    """
    global qwen_model
    cached_model = qwen_model["model"]
    if cached_model is not None:
        return cached_model, qwen_model["tokenizer"]

    on_gpu = DEVICE == "cuda"
    try:
        model_name = "Qwen/Qwen2.5-1.5B-Instruct"
        print(f"Loading Qwen model: {model_name}")
        # NOTE(review): trust_remote_code=True lets the Hub repo execute Python
        # at load time. Qwen2.5 appears to be supported natively by recent
        # transformers releases, so this flag is likely unnecessary; confirm
        # against the pinned transformers version before removing it.
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=torch.float16 if on_gpu else torch.float32,
            device_map="auto" if on_gpu else None,
        )
        if not on_gpu:
            model = model.to(DEVICE)
        model.eval()
        qwen_model = {"model": model, "tokenizer": tokenizer}
        print("✓ Qwen model loaded successfully")
        return model, tokenizer
    except Exception as e:
        print(f"✗ Failed to load Qwen model: {e}")
        return None, None
# ---------------- TRANSLITERATION ---------------- #
def _build_transliteration_messages(text, source_lang):
    """Return the chat messages asking Qwen to romanize *text*.

    Tamil gets a Thanglish prompt; any other language gets the Malayalam
    (Manglish) prompt. Each user prompt contains one worked example.
    """
    if source_lang == "Tamil":
        system_prompt = "You are a Tamil transliteration expert. Convert Tamil script to English letters (Thanglish) like how Tamil people type on phones."
        user_prompt = "\n".join([
            "Convert this Tamil text to Thanglish using English letters:",
            "Tamil: நான் தமிழ் படிக்கிறேன்",
            "Thanglish: naan tamil padikkiren",
            f"Tamil: {text}",
            "Thanglish:",
        ])
    else:  # Malayalam
        system_prompt = "You are a Malayalam transliteration expert. Convert Malayalam script to English letters (Manglish) like how Malayalam people type on phones."
        user_prompt = "\n".join([
            "Convert this Malayalam text to Manglish using English letters:",
            "Malayalam: ഞാൻ മലയാളം പഠിക്കുന്നു",
            "Manglish: njan malayalam padikkunnu",
            f"Malayalam: {text}",
            "Manglish:",
        ])
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]


def _clean_transliteration_output(raw):
    """Reduce generated text to a single ASCII-only line of romanized text.

    Keeps only the first line, removes every non-ASCII character (so no
    Tamil/Malayalam script can remain), and drops a leading
    "Thanglish:"/"Manglish:" label in case the model echoed the prompt format.
    """
    import re

    line = raw.split("\n")[0].strip()
    line = re.sub(r"[^\x00-\x7F]+", "", line).strip()
    return re.sub(r"^(?:thanglish|manglish)\s*:\s*", "", line, flags=re.IGNORECASE).strip()


def transliterate_with_qwen(text, source_lang):
    """Romanize Tamil/Malayalam *text* (Thanglish/Manglish) with Qwen.

    English text and blank text are returned unchanged. When the Qwen model
    cannot be loaded, generation raises, or the cleaned answer is empty, the
    word-level fallback ``get_simple_transliteration`` is used instead.

    Args:
        text: Text in the script of *source_lang*.
        source_lang: "English", "Tamil" or "Malayalam".

    Returns:
        The romanized text as a single ASCII line.
    """
    if source_lang == "English" or not text.strip():
        return text
    model, tokenizer = load_qwen_model()
    if model is None or tokenizer is None:
        return get_simple_transliteration(text, source_lang)  # Simple fallback
    try:
        messages = _build_transliteration_messages(text, source_lang)
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        # Follow the model's device: with device_map="auto" it is chosen by
        # transformers rather than being set from DEVICE directly.
        inputs = inputs.to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=100,
                temperature=0.3,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.2,
            )
        # Decode only the newly generated tokens. Slicing the fully decoded
        # string by len(prompt) is wrong: skip_special_tokens=True removes the
        # chat template's special tokens from the decoded text, so the prompt
        # is shorter there than len(prompt) and the slice ate into the answer.
        prompt_length = inputs["input_ids"].shape[1]
        generated = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        # The cleaner strips all non-ASCII characters, so no script characters
        # can survive; no further script check is needed.
        response = _clean_transliteration_output(generated)
        return response if response else get_simple_transliteration(text, source_lang)
    except Exception as e:
        print(f"Qwen transliteration error: {e}")
        return get_simple_transliteration(text, source_lang)
# Word-level romanizations used by get_simple_transliteration when Qwen is
# unavailable. Built once at import time instead of on every call.
SIMPLE_TRANSLITERATION_WORDS = {
    "Malayalam": {
        "കേരളം": "kerala",
        "എന്റെ": "ente",
        "സ്വന്തം": "swantham",
        "നാടാണ്": "naadaan",
        "എനിക്ക്": "enikku",
        "മലയാളം": "malayalam",
        "വളരെ": "valare",
        "ഇഷ്ടമാണ്": "ishtamaan",
        "ഞാൻ": "njan",
        "പുസ്തകം": "pusthakam",
        "വായിക്കുന്നു": "vaayikkunnu",
    },
    "Tamil": {
        "அன்னை": "annai",
        "தமிழ்": "tamil",
        "எங்கள்": "engal",
        "தாய்மொழி": "thaaimozhi",
        "நான்": "naan",
        "இன்று": "indru",
        "நல்ல": "nalla",
        "வானிலை": "vaanilai",
    },
}


def get_simple_transliteration(text, lang_choice):
    """Romanize *text* word by word without the LLM (fallback path).

    Each whitespace-separated word has trailing ``.,!?`` removed, is looked up
    in ``SIMPLE_TRANSLITERATION_WORDS[lang_choice]``, and otherwise goes
    through ``basic_phonetic_convert``. The removed punctuation is re-attached.
    Tokens made only of that punctuation are kept as-is. Words are re-joined
    with single spaces.

    Args:
        text: Text to romanize.
        lang_choice: "Tamil" or "Malayalam"; any other language returns
            *text* unchanged.

    Returns:
        The romanized string.
    """
    word_map = SIMPLE_TRANSLITERATION_WORDS.get(lang_choice)
    if word_map is None:
        return text

    result_words = []
    for word in text.split():
        # Peel off trailing punctuation for the lookup, re-attach it afterwards.
        clean_word = word.rstrip(".,!?")
        punct = word[len(clean_word):]
        if not clean_word:
            # Punctuation-only token (e.g. a standalone "?"): keep it instead
            # of turning it into "unknown?".
            result_words.append(word)
        elif clean_word in word_map:
            result_words.append(word_map[clean_word] + punct)
        else:
            # Unknown word: fall back to the crude character-level conversion.
            result_words.append(basic_phonetic_convert(clean_word, lang_choice) + punct)
    return " ".join(result_words)
def basic_phonetic_convert(word, lang_choice):
    """Crude last-resort romanization of one word not found in the word map.

    For Malayalam, the anusvara (ം) becomes "m" and the letters ൺ and ൻ become
    "n". For both Malayalam and Tamil, every remaining non-ASCII character is
    then dropped, and "unknown" is returned when nothing is left. Any other
    language returns *word* unchanged.
    """
    if lang_choice not in ("Malayalam", "Tamil"):
        return word
    if lang_choice == "Malayalam":
        for native, latin in (('ം', 'm'), ('ൺ', 'n'), ('ൻ', 'n')):
            word = word.replace(native, latin)
    ascii_only = "".join(ch for ch in word if ch.isascii())
    return ascii_only or "unknown"
# ---------------- SPEECH RECOGNITION ---------------- #
@spaces.GPU
def transcribe_audio(audio_path, language_choice):
    """Transcribe the recording at *audio_path* with the Whisper model for *language_choice*.

    Loading and decoding steps:
    - The audio is loaded with librosa, resampled to 16 kHz, and turned into
      Whisper input features.
    - It is decoded with 5-beam search (max_length 448).
    - Decoding first forces the selected language and the "transcribe" task.
      If that call raises, the error is logged and decoding is retried
      without forcing the language.

    Args:
        audio_path: Path to an audio file readable by librosa.
        language_choice: A key of ``LANG_CODES``.

    Returns:
        The transcription with surrounding whitespace stripped.
    """
    import librosa

    sample_rate = 16000
    model, processor = load_whisper_model(language_choice)
    lang_code = LANG_CODES[language_choice]

    audio, _ = librosa.load(audio_path, sr=sample_rate)
    input_features = processor(audio, sampling_rate=sample_rate, return_tensors="pt").input_features
    # Match the dtype of the model's parameters.
    input_features = input_features.to(DEVICE, dtype=next(model.parameters()).dtype)

    decode_kwargs = {"max_length": 448, "num_beams": 5, "temperature": 0.0}
    with torch.no_grad():
        try:
            forced_decoder_ids = processor.get_decoder_prompt_ids(language=lang_code, task="transcribe")
            predicted_ids = model.generate(
                input_features, forced_decoder_ids=forced_decoder_ids, **decode_kwargs
            )
        except Exception as err:
            # Previously a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit and hid the reason for the retry.
            print(f"Language-forced decoding failed ({err}); retrying without forced language")
            predicted_ids = model.generate(input_features, **decode_kwargs)

    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    return transcription.strip()
# ---------------- FEEDBACK SYSTEM ---------------- #
def normalize_text_for_comparison(text):
    """Remove punctuation, collapse whitespace, and lowercase *text* for comparison.

    Removes every ASCII ``string.punctuation`` character and also any Unicode
    punctuation (general category ``P*``: curly quotes, ellipsis, dandas, ...),
    which ``string.punctuation`` does not cover. Tamil/Malayalam letters and
    vowel signs are not punctuation and are kept.
    """
    import string
    import unicodedata

    ascii_punctuation = set(string.punctuation)
    kept = "".join(
        ch
        for ch in text
        if ch not in ascii_punctuation and not unicodedata.category(ch).startswith("P")
    )
    return " ".join(kept.split()).lower()


def _memoized_transliterator(lang_choice):
    """Return a function that transliterates text via Qwen, caching results per string."""
    cache = {}

    def romanize(text):
        if text not in cache:
            cache[text] = transliterate_with_qwen(text, lang_choice)
        return cache[text]

    return romanize


def _word_corrections(intended_words, actual_words, romanize):
    """List word-level differences as ``[expected, expected_roman, actual, actual_roman]`` rows.

    Missing words get actual = "(Not spoken)"; extra words get
    expected = "(Not expected)". Substituted words are paired positionally. If
    a substitution span is longer on one side, the leftover words are reported
    as missing or extra rather than being dropped.
    """
    import itertools

    rows = []
    matcher = difflib.SequenceMatcher(None, intended_words, actual_words)
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            continue
        # For 'delete' the actual span is empty, for 'insert' the expected
        # span is empty; zip_longest pads the shorter side with None.
        pairs = itertools.zip_longest(intended_words[i1:i2], actual_words[j1:j2])
        for expected_word, actual_word in pairs:
            if actual_word is None:
                rows.append([expected_word, romanize(expected_word), "(Not spoken)", ""])
            elif expected_word is None:
                rows.append(["(Not expected)", "", actual_word, romanize(actual_word)])
            elif expected_word != actual_word:
                rows.append([
                    expected_word,
                    romanize(expected_word),
                    actual_word,
                    romanize(actual_word),
                ])
    return rows


def create_feedback(intended, actual, lang_choice, romanize=None):
    """Compare the intended sentence with what was recognized.

    Args:
        intended: The practice sentence.
        actual: The transcription of the user's speech.
        lang_choice: Language name passed to the transliterator.
        romanize: Optional ``str -> str`` romanizer. When omitted, a memoized
            ``transliterate_with_qwen`` for *lang_choice* is used, so a word
            that appears several times costs only one LLM call.

    Returns:
        ``(comparison_data, wrong_pronunciations, message, accuracy)``:
        - ``comparison_data``: ``[label, value]`` rows.
        - ``wrong_pronunciations``: 4-column correction rows (see
          ``_word_corrections``).
        - ``message``: a motivational message.
        - ``accuracy``: word-level similarity in percent (``SequenceMatcher``
          ratio on the normalized words).
    """
    if romanize is None:
        romanize = _memoized_transliterator(lang_choice)

    intended_roman = romanize(intended)
    actual_roman = romanize(actual)

    # Compare punctuation-free, lowercase word sequences.
    intended_words = normalize_text_for_comparison(intended).split()
    actual_words = normalize_text_for_comparison(actual).split()
    accuracy = difflib.SequenceMatcher(None, intended_words, actual_words).ratio() * 100

    comparison_data = [
        ["Target Text", intended],
        ["Target (Romanized)", intended_roman],
        ["Your Speech", actual],
        ["Your Speech (Romanized)", actual_roman],
        ["Accuracy Score", f"{accuracy:.1f}%"],
    ]
    wrong_pronunciations = _word_corrections(intended_words, actual_words, romanize)

    if accuracy >= 95:
        message = "🎉 Outstanding! Perfect pronunciation!"
    elif accuracy >= 85:
        message = "🌟 Excellent! Very natural sounding!"
    elif accuracy >= 70:
        message = "👍 Good job! Your pronunciation is improving!"
    elif accuracy >= 50:
        message = "📚 Getting there! Focus on the highlighted sounds!"
    else:
        message = "💪 Keep practicing! Every attempt makes you better!"
    return comparison_data, wrong_pronunciations, message, accuracy
# ---------------- MAIN FUNCTION ---------------- #
@spaces.GPU
def analyze_pronunciation(audio, lang_choice, intended_text):
    """Transcribe the recording and score it against the practice sentence.

    Args:
        audio: File path from the microphone component, or ``None``.
        lang_choice: "English", "Tamil" or "Malayalam".
        intended_text: Contents of the practice-sentence box. It may contain a
            "🔤"-prefixed romanization after the sentence, which is ignored.

    Returns:
        A 6-tuple for the UI outputs:
        ``(transcription, romanized transcription, WER string, comparison rows,
        correction rows, feedback message)``. On a problem, the first element
        carries a warning/error text and the rest are empty.
    """
    if audio is None or not (intended_text or "").strip():
        return "⚠️ Please record audio and generate a sentence first.", "", "", [], [], ""
    try:
        # Keep only the sentence part of "<sentence>\n\n🔤 <romanization>".
        intended_sentence = intended_text.split("🔤")[0].strip()
        intended_norm = normalize_text_for_comparison(intended_sentence)
        if not intended_norm:
            return "⚠️ Please record audio and generate a sentence first.", "", "", [], [], ""

        actual_text = transcribe_audio(audio, lang_choice)
        actual_norm = normalize_text_for_comparison(actual_text)
        # A transcription with no words left after normalization (e.g. only
        # punctuation) is treated as no speech.
        if not actual_norm:
            return "⚠️ No speech detected. Please try recording again.", "", "", [], [], ""

        # WER on the same punctuation/case-normalized text used for the
        # accuracy score, so a missing full stop or capital letter is not
        # counted as a wrong word.
        wer_val = jiwer.wer(intended_norm, actual_norm)

        comparison_data, wrong_pronunciations, message, _accuracy = create_feedback(
            intended_sentence, actual_text, lang_choice
        )
        # Reuse the romanization already shown in the comparison table instead
        # of a second (sampled, hence possibly different) LLM call.
        actual_roman = dict(comparison_data).get("Your Speech (Romanized)")
        if actual_roman is None:
            actual_roman = transliterate_with_qwen(actual_text, lang_choice)

        return actual_text, actual_roman, f"{wer_val:.1%}", comparison_data, wrong_pronunciations, message
    except Exception as e:
        print(f"analyze_pronunciation error: {e}")
        return f"❌ Error: {str(e)}", "", "", [], [], ""
# ---------------- HELPERS ---------------- #
def get_random_sentence_with_transliteration(language_choice):
    """Pick a random practice sentence for *language_choice*.

    English sentences are returned as-is. Tamil and Malayalam sentences are
    followed by a blank line and a "🔤"-prefixed romanization;
    analyze_pronunciation splits on that marker to recover the sentence.
    """
    sentence = random.choice(SENTENCE_BANK[language_choice])
    if language_choice not in ("Tamil", "Malayalam"):
        return sentence
    romanized = transliterate_with_qwen(sentence, language_choice)
    return f"{sentence}\n\n🔤 {romanized}"
# ---------------- UI ---------------- #
with gr.Blocks(title="AI Pronunciation Coach", theme=gr.themes.Soft()) as demo:
    # Keep this description in sync with the models the code actually loads:
    # the Whisper checkpoints in MODEL_CONFIGS for recognition and
    # Qwen/Qwen2.5-1.5B-Instruct (load_qwen_model) for transliteration.
    # Blank lines separate the list items from the following bold headings so
    # they are not rendered as continuations of the previous list item.
    gr.Markdown("""
# 🎙️ AI Pronunciation Coach
### Practice English, Tamil & Malayalam with AI feedback powered by Whisper and Qwen2.5-1.5B-Instruct

**Features:**
- ✨ **Smart Transliteration**: Natural Thanglish/Manglish using Qwen2.5-1.5B-Instruct
- 🎯 **Accurate Recognition**: Language-specific Whisper models
- 📊 **Smart Analysis**: Punctuation-aware comparison with correction tables

**How to use:**
1. Select your language
2. Generate a practice sentence
3. Record yourself reading it aloud
4. Get instant feedback with detailed analysis!
""")

    with gr.Row():
        lang_choice = gr.Dropdown(
            choices=list(LANG_CODES.keys()),
            value="Malayalam",
            label="🌍 Choose Language",
        )
        gen_btn = gr.Button("🎲 Generate Practice Sentence", variant="primary")

    intended_display = gr.Textbox(
        label="📝 Practice Sentence",
        interactive=False,
        placeholder="Click 'Generate Practice Sentence' to get started...",
        lines=3,
    )
    audio_input = gr.Audio(
        sources=["microphone"],
        type="filepath",
        label="🎤 Record Your Pronunciation",
    )
    analyze_btn = gr.Button("🔍 Analyze My Pronunciation", variant="primary", size="lg")

    with gr.Row():
        actual_out = gr.Textbox(label="🗣️ What You Said", interactive=False)
        actual_roman_out = gr.Textbox(label="🔤 Your Pronunciation (Romanized)", interactive=False)
        wer_out = gr.Textbox(label="📊 Word Error Rate", interactive=False)

    # Analysis tables
    gr.Markdown("### 📊 Analysis Results")
    with gr.Row():
        with gr.Column():
            comparison_table = gr.Dataframe(
                headers=["Metric", "Value"],
                label="📋 Overall Comparison",
                interactive=False,
            )
        with gr.Column():
            pronunciation_table = gr.Dataframe(
                headers=["Expected Word", "Expected (Romanized)", "You Said", "You Said (Romanized)"],
                label="❌ Pronunciation Corrections Needed",
                interactive=False,
            )
    feedback_message = gr.Textbox(label="💬 Feedback", interactive=False)

    # Event handlers
    gen_btn.click(
        fn=get_random_sentence_with_transliteration,
        inputs=[lang_choice],
        outputs=[intended_display],
    )
    # Output order matches the 6-tuple returned by analyze_pronunciation.
    analyze_btn.click(
        fn=analyze_pronunciation,
        inputs=[audio_input, lang_choice, intended_display],
        outputs=[actual_out, actual_roman_out, wer_out, comparison_table, pronunciation_table, feedback_message],
    )

if __name__ == "__main__":
    demo.launch()