Update app.py
app.py
CHANGED
@@ -29,18 +29,11 @@ LANG_CODES = {
     "Malayalam": "ml"
 }
 
-# Updated model configurations
+# Updated model configurations with LARGE models for maximum accuracy
 ASR_MODELS = {
     "English": "openai/whisper-base.en",
-    "Tamil": "…
-    "Malayalam": "…
-}
-
-# Backup models in case primary ones fail
-FALLBACK_MODELS = {
-    "English": "openai/whisper-base.en",
-    "Tamil": "openai/whisper-small",
-    "Malayalam": "openai/whisper-small"
+    "Tamil": "ai4bharat/whisper-large-ta",    # LARGE AI4Bharat Tamil model (~1.5GB)
+    "Malayalam": "ai4bharat/whisper-large-ml" # LARGE AI4Bharat Malayalam model (~1.5GB)
 }
 
 LANG_PRIMERS = {
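The new ASR_MODELS table points Tamil and Malayalam at AI4Bharat large checkpoints and drops the whisper-small fallbacks. Below is a minimal sketch for smoke-testing any one of these checkpoints outside the app, assuming the transformers ASR pipeline; "sample.wav" is a placeholder, and whether the AI4Bharat repos exist under exactly these IDs should be verified on the Hub:

```python
# Smoke-test one checkpoint from ASR_MODELS outside the app.
# "sample.wav" is a placeholder; verify the AI4Bharat model IDs on the Hub.
import torch
from transformers import pipeline

device = 0 if torch.cuda.is_available() else -1  # GPU index or CPU

asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-base.en",  # or "ai4bharat/whisper-large-ta" / "ai4bharat/whisper-large-ml"
    device=device,
)

print(asr("sample.wav")["text"])
```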
@@ -95,49 +88,26 @@ SENTENCE_BANK = {
 asr_models = {}
 
 def load_asr_model(language):
-    """Load ASR model for specific language…
+    """Load ASR model for specific language - PRIMARY MODELS ONLY"""
     if language not in asr_models:
+        model_name = ASR_MODELS[language]
+        print(f"🔄 Loading LARGE model for {language}: {model_name}")
+
         try:
-            …
-            …
-
-            # Try loading the primary model
-            try:
-                processor = AutoProcessor.from_pretrained(model_name)
-                model = AutoModelForSpeechSeq2Seq.from_pretrained(
-                    model_name,
-                    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
-                    low_cpu_mem_usage=True,
-                    use_safetensors=True
-                ).to(DEVICE)
-
-                asr_models[language] = {"processor": processor, "model": model, "model_name": model_name}
-                print(f"✅ Primary ASR model loaded for {language}")
-                return asr_models[language]
-
-            except Exception as e:
-                print(f"⚠️ Primary model failed for {language}: {e}")
-                print(f"🔄 Trying fallback model...")
-
-                # Try fallback model
-                fallback_name = FALLBACK_MODELS[language]
-                processor = WhisperProcessor.from_pretrained(fallback_name)
-                model = WhisperForConditionalGeneration.from_pretrained(
-                    fallback_name,
-                    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
-                    low_cpu_mem_usage=True
-                ).to(DEVICE)
-
-                asr_models[language] = {"processor": processor, "model": model, "model_name": fallback_name}
-                print(f"✅ Fallback ASR model loaded for {language}")
-
+            processor = AutoProcessor.from_pretrained(model_name)
+            model = AutoModelForSpeechSeq2Seq.from_pretrained(
+                model_name,
+                torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
+                low_cpu_mem_usage=True,
+                use_safetensors=True
+            ).to(DEVICE)
+
+            asr_models[language] = {"processor": processor, "model": model, "model_name": model_name}
+            print(f"✅ LARGE model loaded successfully for {language}")
+
         except Exception as e:
-            print(f"❌ Failed to load …
-
-            if language != "English":
-                print(f"🔄 Using English ASR as final fallback for {language}")
-                load_asr_model("English")
-                asr_models[language] = asr_models["English"]
+            print(f"❌ Failed to load {model_name}: {e}")
+            raise Exception(f"Could not load {language} model. Please check model availability.")
 
     return asr_models[language]
 
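load_asr_model now caches exactly one processor/model pair per language and raises on failure instead of degrading to whisper-small or English. The transcribe_audio function that consumes these objects is not part of this diff; the following is a minimal sketch of that consumption step, with the file name and 16 kHz resampling as illustrative assumptions:

```python
# Minimal sketch of the step that consumes a processor/model pair;
# the app's transcribe_audio() is not shown in this diff, so the file
# name and resampling choice are illustrative assumptions.
import librosa
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

model_name = "openai/whisper-base.en"  # any entry from ASR_MODELS
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name)

speech, _ = librosa.load("sample.wav", sr=16000)  # Whisper expects 16 kHz audio
inputs = processor(speech, sampling_rate=16000, return_tensors="pt")

with torch.no_grad():
    generated_ids = model.generate(inputs.input_features)

print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
```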
@@ -354,40 +324,55 @@ def get_pronunciation_score(wer_val, cer_val):
 # ---------------- MAIN FUNCTION ---------------- #
 def compare_pronunciation(audio, language_choice, intended_sentence):
     """Main function to compare pronunciation"""
+    print(f"🔍 Starting analysis with language: {language_choice}")
+    print(f"📁 Audio file: {audio}")
+    print(f"🎯 Intended sentence: {intended_sentence}")
+
     if audio is None:
+        print("❌ No audio provided")
         return ("❌ Please record audio first.", "", "", "", "", "", "", "", "", "", "", "", "")
 
     if not intended_sentence.strip():
+        print("❌ No intended sentence")
         return ("❌ Please generate a practice sentence first.", "", "", "", "", "", "", "", "", "", "", "", "")
 
     try:
         print(f"🔍 Analyzing pronunciation for {language_choice}...")
 
         # Pass 1: Raw transcription
+        print("🔄 Starting Pass 1 transcription...")
         primer_weak, _ = LANG_PRIMERS[language_choice]
         actual_text = transcribe_audio(audio, language_choice, primer_weak, force_language=True)
+        print(f"✅ Pass 1 result: {actual_text}")
 
         # Pass 2: Target-biased transcription with stronger prompt
+        print("🔄 Starting Pass 2 transcription...")
         _, primer_strong = LANG_PRIMERS[language_choice]
         strict_prompt = f"{primer_strong}\nExpected: {intended_sentence}"
         corrected_text = transcribe_audio(audio, language_choice, strict_prompt, force_language=True)
+        print(f"✅ Pass 2 result: {corrected_text}")
 
         # Handle transcription errors
         if actual_text.startswith("Error:"):
+            print(f"❌ Transcription error: {actual_text}")
             return (f"❌ {actual_text}", "", "", "", "", "", "", "", "", "", "", "", "")
 
         # Calculate error metrics
         try:
+            print("📊 Calculating error metrics...")
             wer_val = jiwer.wer(intended_sentence, actual_text)
             cer_val = jiwer.cer(intended_sentence, actual_text)
+            print(f"✅ WER: {wer_val:.3f}, CER: {cer_val:.3f}")
         except Exception as e:
-            print(f"Error calculating metrics: {e}")
+            print(f"❌ Error calculating metrics: {e}")
             wer_val, cer_val = 1.0, 1.0
 
         # Get pronunciation score and feedback
         score_text, feedback = get_pronunciation_score(wer_val, cer_val)
+        print(f"✅ Score: {score_text}")
 
         # Transliterations for both actual and intended
+        print("🔤 Generating transliterations...")
         actual_hk = transliterate_to_hk(actual_text, language_choice)
         target_hk = transliterate_to_hk(intended_sentence, language_choice)
 
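Both metrics come from jiwer, reference first and hypothesis second; WER counts word-level edits, CER character-level ones. A quick worked example, with a purely hypothetical threshold scorer standing in for get_pronunciation_score, whose body is outside this diff:

```python
import jiwer

reference = "the cat sat on the mat"
hypothesis = "the cat sit on mat"

wer_val = jiwer.wer(reference, hypothesis)  # word edits / reference word count
cer_val = jiwer.cer(reference, hypothesis)  # character edits / reference char count
print(f"WER: {wer_val:.3f}, CER: {cer_val:.3f}")

def score_sketch(wer, cer):
    # Hypothetical thresholds; get_pronunciation_score() is not shown in this diff.
    blended = 0.7 * wer + 0.3 * cer
    if blended < 0.15:
        return "Excellent"
    if blended < 0.40:
        return "Good"
    return "Needs practice"

print(score_sketch(wer_val, cer_val))
```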
@@ -396,11 +381,13 @@ def compare_pronunciation(audio, language_choice, intended_sentence):
             actual_hk = f"⚠️ Expected {language_choice} script, got mixed/other script"
 
         # Visual feedback
+        print("🎨 Generating visual feedback...")
         diff_html = highlight_differences(intended_sentence, actual_text)
         char_html = char_level_highlight(intended_sentence, actual_text)
 
         # Status message with detailed feedback
         status = f"✅ Analysis Complete - {score_text}\n💪 {feedback}"
+        print(f"✅ Analysis completed successfully")
 
         return (
             status,
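highlight_differences and char_level_highlight are called here but defined elsewhere in the file. For orientation, a minimal difflib-based sketch of the character-level idea (illustrative only; the app's real helpers may differ):

```python
# Character-level highlight sketch: green for characters of the intended
# sentence matched in the transcript, red for the rest. Illustrative only.
import difflib
import html

def char_level_highlight_sketch(intended: str, actual: str) -> str:
    out = []
    for op, i1, i2, _, _ in difflib.SequenceMatcher(None, intended, actual).get_opcodes():
        chunk = html.escape(intended[i1:i2])
        if chunk:  # 'insert' opcodes contribute nothing from the intended side
            color = "green" if op == "equal" else "red"
            out.append(f'<span style="color:{color}">{chunk}</span>')
    return "".join(out)

print(char_level_highlight_sketch("hello world", "hallo word"))
```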
@@ -421,8 +408,10 @@ def compare_pronunciation(audio, language_choice, intended_sentence):
 
     except Exception as e:
         error_msg = f"❌ Analysis Error: {str(e)[:200]}"
-        print(f"…
-        …
+        print(f"❌ FATAL ERROR: {e}")
+        import traceback
+        traceback.print_exc()
+        return (error_msg, str(e), "", "", "", "", "", "", "", "", "", "", "")
 
 # ---------------- UI ---------------- #
 def create_interface():
@@ -534,30 +523,28 @@ def create_interface():
         gr.Markdown("""
         ---
         ### 🔧 Technical Details:
-        - **ASR Models**: …
+        - **ASR Models**:
+          - **Tamil**: AI4Bharat Whisper-LARGE-TA (~1.5GB, maximum accuracy)
+          - **Malayalam**: AI4Bharat Whisper-LARGE-ML (~1.5GB, maximum accuracy)
+          - **English**: OpenAI Whisper-Base-EN (optimized for English)
+        - **Performance**: Using largest available models for best pronunciation assessment
         - **Metrics**: WER (Word Error Rate) and CER (Character Error Rate)
         - **Transliteration**: Harvard-Kyoto system for Indic scripts
         - **Analysis**: Dual-pass approach for comprehensive feedback
 
-        **Note**: …
-        **Languages**: …
+        **Note**: Large models provide maximum accuracy but require longer initial loading time.
+        **Languages**: English, Tamil, and Malayalam with specialized large models.
         """)
 
     return demo
 
 # ---------------- LAUNCH ---------------- #
 if __name__ == "__main__":
-    print("🚀 Starting Multilingual Pronunciation Trainer...")
+    print("🚀 Starting Multilingual Pronunciation Trainer with LARGE models...")
     print(f"🔧 Device: {DEVICE}")
    print(f"🔧 PyTorch version: {torch.__version__}")
-
-
-    print("📦 Pre-loading English model...")
-    try:
-        load_asr_model("English")
-        print("✅ English model loaded successfully")
-    except Exception as e:
-        print(f"⚠️ Warning: Could not pre-load English model: {e}")
+    print("📦 Models will be loaded on-demand for best performance...")
+    print("⚡ Using AI4Bharat LARGE models for maximum accuracy!")
 
     demo = create_interface()
     demo.launch(
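The Harvard-Kyoto transliteration credited in the Technical Details lives in transliterate_to_hk, which this commit does not touch. A sketch of how such a helper is commonly built on the indic_transliteration package (an assumed dependency; the app's actual implementation may differ):

```python
# Sketch of Harvard-Kyoto transliteration for the two Indic languages.
# Assumes the indic_transliteration package, which this diff does not confirm.
from indic_transliteration import sanscript

SCHEMES = {"Tamil": sanscript.TAMIL, "Malayalam": sanscript.MALAYALAM}

def transliterate_to_hk_sketch(text: str, language: str) -> str:
    scheme = SCHEMES.get(language)
    if scheme is None:  # e.g. English: nothing to transliterate
        return text
    return sanscript.transliterate(text, scheme, sanscript.HK)

print(transliterate_to_hk_sketch("வணக்கம்", "Tamil"))
```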