sudhanm committed on
Commit
bb3f271
·
verified ·
1 Parent(s): 3a8ecbf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +176 -565
app.py CHANGED
@@ -4,618 +4,229 @@ import difflib
4
  import re
5
  import jiwer
6
  import torch
7
- import warnings
8
- import contextlib
9
- import traceback
10
- import gc
11
- from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, pipeline
12
- import librosa
13
- import numpy as np
14
-
15
- # Optional transliteration
16
- try:
17
- from indic_transliteration import sanscript
18
- from indic_transliteration.sanscript import transliterate
19
- INDIC_OK = True
20
- print("✅ Transliteration available")
21
- except:
22
- INDIC_OK = False
23
- print("⚠️ indic_transliteration not available. Install with: pip install indic-transliteration")
24
-
25
- # Optional HF Spaces GPU decorator
26
- try:
27
- import spaces
28
- GPU_DECORATOR = spaces.GPU
29
- print("✅ HF Spaces GPU decorator available")
30
- except:
31
- class _NoOp:
32
- def __call__(self, f): return f
33
- GPU_DECORATOR = _NoOp()
34
- print("⚠️ HF Spaces not available (normal for local usage)")
35
-
36
- warnings.filterwarnings("ignore")
37
 
38
  # ---------------- CONFIG ---------------- #
39
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
40
- DEVICE_INDEX = 0 if DEVICE == "cuda" else -1
41
- DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
42
- amp_ctx = torch.cuda.amp.autocast if DEVICE == "cuda" else contextlib.nullcontext
43
 
44
- print(f"🔧 Device: {DEVICE}")
45
- print(f"🔧 PyTorch version: {torch.__version__}")
46
- if DEVICE == "cuda":
47
- print(f"🔧 CUDA available: {torch.cuda.is_available()}")
48
- print(f"🔧 GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")
 
49
 
50
  LANG_CODES = {
51
  "English": "en",
52
  "Tamil": "ta",
53
- "Malayalam": "ml",
54
- "Hindi": "hi"
55
  }
56
 
57
- # Simplified model setup - use only reliable models
58
- SPECIALIZED_MODELS = {
59
- "English": "openai/whisper-base.en",
60
- "Tamil": "openai/whisper-large-v2", # More reliable fallback
61
- "Malayalam": "openai/whisper-large-v2", # More reliable fallback
62
- "Hindi": "openai/whisper-large-v2"
 
63
  }
64
 
65
  SCRIPT_PATTERNS = {
66
  "Tamil": re.compile(r"[஀-௿]"),
67
- "Malayalam": re.compile(r"[ഀ-ൿ]"),
68
- "Hindi": re.compile(r"[ऀ-ॿ]"),
69
  "English": re.compile(r"[A-Za-z]")
70
  }
71
 
72
- # Transliteration mappings
73
- TRANSLITERATION_SCRIPTS = {
74
- "Tamil": sanscript.TAMIL if INDIC_OK else None,
75
- "Malayalam": sanscript.MALAYALAM if INDIC_OK else None,
76
- "Hindi": sanscript.DEVANAGARI if INDIC_OK else None,
77
- "English": None
78
- }
79
-
80
  SENTENCE_BANK = {
81
  "English": [
82
  "The sun sets over the horizon.",
83
- "Learning languages is fun and rewarding.",
84
  "I like to drink coffee in the morning.",
85
- "Technology helps us connect with others.",
86
  "Reading books expands our knowledge."
87
  ],
88
  "Tamil": [
89
  "இன்று நல்ல வானிலை உள்ளது.",
90
  "நான் தமிழ் கற்றுக்கொண்டு இருக்கிறேன்.",
91
  "எனக்கு புத்தகம் படிக்க விருப்பம்.",
92
- "காலையில் காபி குடிக்க பிடிக்கும்.",
93
- "நண்பர்களுடன் பேசுவது மகிழ்ச்சி."
94
  ],
95
  "Malayalam": [
96
  "എനിക്ക് മലയാളം വളരെ ഇഷ്ടമാണ്.",
97
  "ഇന്ന് മഴപെയ്യുന്നു.",
98
  "ഞാൻ പുസ്തകം വായിക്കുന്നു.",
99
- "കാലയിൽ ചായ കുടിക്കാൻ ഇഷ്ടമാണ്.",
100
- "സുഹൃത്തുക്കളോടു സംസാരിക്കുന്നത് സന്തോഷമാണ്."
101
- ],
102
- "Hindi": [
103
- "आज मौसम अच्छा है।",
104
- "मुझे हिंदी बोलना पसंद है।",
105
- "मैं किताब पढ़ रहा हूँ।",
106
- "सुबह चाय पीना अच्छा लगता है।",
107
- "दोस्तों के साथ बात करना खुशी देता है।"
108
  ]
109
  }
110
 
111
- # Model cache
112
- models = {}
 
 
 
 
 
 
 
113
 
114
- # ---------------- SAFE HELPERS ---------------- #
115
- def safe_operation(func, *args, **kwargs):
116
- """Wrapper for safe operations with detailed error reporting"""
117
- try:
118
- return func(*args, **kwargs), None
119
- except Exception as e:
120
- error_msg = f"Error in {func.__name__}: {str(e)}"
121
- print(f"❌ {error_msg}")
122
- print(f"🔍 Traceback: {traceback.format_exc()}")
123
- return None, error_msg
124
 
 
125
  def get_random_sentence(language_choice):
126
- try:
127
- return random.choice(SENTENCE_BANK[language_choice])
128
- except Exception as e:
129
- print(f"❌ Error getting sentence: {e}")
130
- return "Error loading sentence"
131
-
132
- def is_correct_script(text, lang_name):
133
- """Check if text contains the expected script for the language"""
134
- try:
135
- if not text or not text.strip():
136
- return False
137
- pattern = SCRIPT_PATTERNS.get(lang_name)
138
- if not pattern:
139
- return True
140
- return bool(pattern.search(text))
141
- except Exception as e:
142
- print(f"❌ Script check error: {e}")
143
- return True # Default to True to avoid blocking
144
-
145
- def transliterate_text(text, lang_choice, to_romanized=True):
146
- """Transliterate text to/from romanized form"""
147
- try:
148
- if not INDIC_OK or not text or not text.strip():
149
- return text
150
-
151
- source_script = TRANSLITERATION_SCRIPTS.get(lang_choice)
152
- if not source_script:
153
- return text
154
-
155
- if to_romanized:
156
- return transliterate(text, source_script, sanscript.HK)
157
- else:
158
- return transliterate(text, sanscript.HK, source_script)
159
- except Exception as e:
160
- print(f"⚠️ Transliteration failed: {e}")
161
- return text
162
-
163
- def preprocess_audio(audio_path, target_sr=16000):
164
- """Enhanced audio preprocessing with better error handling"""
165
- try:
166
- print(f"🔊 Processing audio: {audio_path}")
167
-
168
- if not audio_path:
169
- return None, "No audio file provided"
170
-
171
- # Load audio
172
- audio, sr = librosa.load(audio_path, sr=target_sr, mono=True)
173
- print(f"🔊 Audio loaded: shape={audio.shape}, sr={sr}")
174
-
175
- if audio is None or len(audio) == 0:
176
- return None, "Empty audio file"
177
-
178
- # Normalize audio
179
- audio = audio.astype(np.float32)
180
- max_val = np.max(np.abs(audio))
181
- if max_val > 0:
182
- audio = audio / max_val
183
-
184
- # Trim silence
185
- audio, _ = librosa.effects.trim(audio, top_db=20)
186
-
187
- # Check minimum length (0.5 seconds)
188
- min_length = int(target_sr * 0.5)
189
- if len(audio) < min_length:
190
- return None, f"Audio too short: {len(audio)/target_sr:.2f}s (minimum 0.5s)"
191
-
192
- print(f"✅ Audio processed successfully: {len(audio)/target_sr:.2f}s")
193
- return audio, sr
194
-
195
- except Exception as e:
196
- error_msg = f"Audio preprocessing failed: {str(e)}"
197
- print(f"❌ {error_msg}")
198
- return None, error_msg
199
-
200
- # ---------------- SIMPLIFIED MODEL LOADERS ---------------- #
201
- def load_whisper_model(model_name, language=None):
202
- """Load a single Whisper model with robust error handling"""
203
- model_key = f"{model_name}_{language}"
204
 
205
- if model_key in models:
206
- print(f"✅ Using cached model: {model_key}")
207
- return models[model_key], None
208
 
209
- try:
210
- print(f"🔄 Loading model: {model_name}")
211
-
212
- # Clear GPU memory first
213
- if DEVICE == "cuda":
214
- torch.cuda.empty_cache()
215
- gc.collect()
216
-
217
- # Load with pipeline (simpler and more robust)
218
- pipeline_model = pipeline(
219
- "automatic-speech-recognition",
220
- model=model_name,
221
- device=DEVICE_INDEX,
222
- torch_dtype=DTYPE,
223
- return_timestamps=False
224
- )
225
-
226
- models[model_key] = pipeline_model
227
- print(f"✅ Model loaded successfully: {model_name}")
228
- return pipeline_model, None
229
-
230
- except Exception as e:
231
- error_msg = f"Failed to load {model_name}: {str(e)}"
232
- print(f"❌ {error_msg}")
233
- return None, error_msg
234
-
235
- # ---------------- TRANSCRIPTION ---------------- #
236
- def transcribe_audio_safe(audio_path, model_name, language):
237
- """Safe transcription with comprehensive error handling"""
238
- try:
239
- print(f"🎤 Starting transcription with {model_name} for {language}")
240
-
241
- # Load model
242
- model, model_error = load_whisper_model(model_name, language)
243
- if model is None:
244
- return None, f"Model loading failed: {model_error}"
245
-
246
- # Set language if supported
247
- lang_code = LANG_CODES.get(language, "en")
248
- generate_kwargs = {
249
- "max_new_tokens": 200,
250
- "num_beams": 1, # Faster decoding
251
- "do_sample": False
252
- }
253
-
254
- # Try to set language (some models support this)
255
- try:
256
- if hasattr(model, "model") and hasattr(model, "tokenizer"):
257
- forced_ids = model.tokenizer.get_decoder_prompt_ids(
258
- language=lang_code,
259
- task="transcribe"
260
- )
261
- if forced_ids:
262
- model.model.config.forced_decoder_ids = forced_ids
263
- print(f"🔧 Language set to: {lang_code}")
264
- except Exception as lang_error:
265
- print(f"⚠️ Language forcing failed (continuing anyway): {lang_error}")
266
-
267
- # Transcribe
268
- print(f"🔄 Transcribing...")
269
- with amp_ctx():
270
- result = model(audio_path, generate_kwargs=generate_kwargs)
271
-
272
- if isinstance(result, dict):
273
- text = result.get("text", "").strip()
274
- else:
275
- text = str(result).strip()
276
-
277
- print(f"✅ Transcription complete: '{text[:50]}...'")
278
- return text, None
279
-
280
- except Exception as e:
281
- error_msg = f"Transcription failed: {str(e)}"
282
- print(f"❌ {error_msg}")
283
- print(f"🔍 Full traceback: {traceback.format_exc()}")
284
- return None, error_msg
285
-
286
- # ---------------- ANALYSIS ---------------- #
287
- def compute_metrics_safe(reference, hypothesis):
288
- """Compute WER and CER with error handling"""
289
- try:
290
- ref_clean = reference.strip()
291
- hyp_clean = hypothesis.strip()
292
-
293
- if not ref_clean or not hyp_clean:
294
- return 1.0, 1.0, "Empty text"
295
-
296
- wer = jiwer.wer(ref_clean, hyp_clean)
297
- cer = jiwer.cer(ref_clean, hyp_clean)
298
-
299
- return wer, cer, None
300
- except Exception as e:
301
- print(f"❌ Metric computation failed: {e}")
302
- return 1.0, 1.0, str(e)
303
-
304
- def get_pronunciation_score(wer, cer):
305
- """Convert error rates to intuitive scores and feedback"""
306
- try:
307
- combined_error = (wer * 0.7) + (cer * 0.3)
308
- accuracy = 1 - combined_error
309
-
310
- if accuracy >= 0.95:
311
- return "🏆 Perfect!", "Outstanding pronunciation! Native-like accuracy."
312
- elif accuracy >= 0.85:
313
- return "🎉 Excellent!", "Very good pronunciation with minor variations."
314
- elif accuracy >= 0.70:
315
- return "👍 Good!", "Good pronunciation, practice specific sounds."
316
- elif accuracy >= 0.50:
317
- return "📚 Needs Practice", "Focus on clearer pronunciation and rhythm."
318
- else:
319
- return "💪 Keep Trying!", "Break down into smaller parts and practice slowly."
320
- except Exception as e:
321
- print(f"❌ Score calculation failed: {e}")
322
- return "❓ Unknown", "Could not calculate score"
323
-
324
- def highlight_differences_safe(reference, hypothesis):
325
- """Safe highlighting with error handling"""
326
- try:
327
- if not reference or not hypothesis:
328
- return "No text to compare", "No text to compare"
329
-
330
- # Word-level highlighting
331
- ref_words = reference.split()
332
- hyp_words = hypothesis.split()
333
-
334
- sm = difflib.SequenceMatcher(None, ref_words, hyp_words)
335
- word_html = []
336
-
337
- for tag, i1, i2, j1, j2 in sm.get_opcodes():
338
- if tag == 'equal':
339
- word_html.extend([
340
- f"<span style='background-color:#d4edda; color:#155724; padding:2px 4px; margin:1px; border-radius:3px'>{word}</span>"
341
- for word in ref_words[i1:i2]
342
- ])
343
- elif tag == 'replace':
344
- word_html.extend([
345
- f"<span style='background-color:#f8d7da; color:#721c24; padding:2px 4px; margin:1px; border-radius:3px; text-decoration:line-through'>{word}</span>"
346
- for word in ref_words[i1:i2]
347
- ])
348
- word_html.extend([
349
- f"<span style='background-color:#fff3cd; color:#856404; padding:2px 4px; margin:1px; border-radius:3px'>→{word}</span>"
350
- for word in hyp_words[j1:j2]
351
- ])
352
- elif tag == 'delete':
353
- word_html.extend([
354
- f"<span style='background-color:#f8d7da; color:#721c24; padding:2px 4px; margin:1px; border-radius:3px; text-decoration:line-through'>{word}</span>"
355
- for word in ref_words[i1:i2]
356
- ])
357
- elif tag == 'insert':
358
- word_html.extend([
359
- f"<span style='background-color:#fff3cd; color:#856404; padding:2px 4px; margin:1px; border-radius:3px'>+{word}</span>"
360
- for word in hyp_words[j1:j2]
361
- ])
362
-
363
- # Character-level highlighting
364
- sm_char = difflib.SequenceMatcher(None, list(reference), list(hypothesis))
365
- char_html = []
366
-
367
- for tag, i1, i2, j1, j2 in sm_char.get_opcodes():
368
- if tag == 'equal':
369
- char_html.extend([
370
- f"<span style='color:#28a745'>{char}</span>"
371
- for char in reference[i1:i2]
372
- ])
373
- elif tag in ('replace', 'delete'):
374
- char_html.extend([
375
- f"<span style='color:#dc3545; text-decoration:underline; font-weight:bold'>{char}</span>"
376
- for char in reference[i1:i2]
377
- ])
378
- elif tag == 'insert':
379
- char_html.extend([
380
- f"<span style='color:#fd7e14; font-weight:bold'>{char}</span>"
381
- for char in hypothesis[j1:j2]
382
- ])
383
-
384
- return " ".join(word_html), "".join(char_html)
385
-
386
- except Exception as e:
387
- print(f"❌ Highlighting failed: {e}")
388
- return f"Error highlighting: {str(e)}", f"Error highlighting: {str(e)}"
389
-
390
- # ---------------- MAIN FUNCTION ---------------- #
391
- @GPU_DECORATOR
392
- def compare_pronunciation(audio, language_choice, intended_sentence):
393
- """Main function with comprehensive error handling"""
394
 
395
- print(f"\n🔍 Starting pronunciation analysis...")
396
- print(f"📝 Language: {language_choice}")
397
- print(f"🎯 Target: {intended_sentence[:50]}...")
398
 
399
- try:
400
- # Validate inputs
401
- if audio is None:
402
- return ("❌ Please record audio first", "", "", "", "", "", "", "", "", "", "")
403
-
404
- if not intended_sentence or not intended_sentence.strip():
405
- return ("❌ Please generate a sentence first", "", "", "", "", "", "", "", "", "", "")
406
-
407
- # Preprocess audio
408
- processed_audio, audio_error = preprocess_audio(audio)
409
- if processed_audio is None:
410
- return (f"❌ Audio processing failed: {audio_error}", "", "", "", "", "", "", "", "", "", "")
411
-
412
- # Get models for this language
413
- primary_model = "openai/whisper-large-v2" # Always reliable
414
- specialized_model = SPECIALIZED_MODELS.get(language_choice, primary_model)
415
-
416
- print(f"🤖 Using models: Primary={primary_model}, Specialized={specialized_model}")
417
-
418
- # Try primary transcription
419
- primary_text, primary_error = transcribe_audio_safe(audio, primary_model, language_choice)
420
-
421
- # Try specialized transcription (if different from primary)
422
- if specialized_model != primary_model:
423
- specialized_text, specialized_error = transcribe_audio_safe(audio, specialized_model, language_choice)
424
- else:
425
- specialized_text, specialized_error = primary_text, primary_error
426
-
427
- # Choose best result
428
- best_text = None
429
- best_source = "None"
430
-
431
- if primary_text and not primary_text.startswith("Error"):
432
- best_text = primary_text
433
- best_source = "Primary Model"
434
- elif specialized_text and not specialized_text.startswith("Error"):
435
- best_text = specialized_text
436
- best_source = "Specialized Model"
437
-
438
- if not best_text:
439
- error_details = f"Primary: {primary_error or 'Unknown error'}\nSpecialized: {specialized_error or 'Unknown error'}"
440
- return (f"❌ Both models failed:\n{error_details}",
441
- primary_text or "Failed", specialized_text or "Failed",
442
- "", "", "", "", "", "", "", "")
443
-
444
- print(f"✅ Best transcription from {best_source}: '{best_text}'")
445
-
446
- # Compute metrics
447
- wer, cer, metric_error = compute_metrics_safe(intended_sentence, best_text)
448
- if metric_error:
449
- print(f"⚠️ Metric computation warning: {metric_error}")
450
-
451
- # Get score and feedback
452
- score, feedback = get_pronunciation_score(wer, cer)
453
-
454
- # Create visual comparisons
455
- word_diff, char_diff = highlight_differences_safe(intended_sentence, best_text)
456
-
457
- # Transliterations
458
- intended_translit = transliterate_text(intended_sentence, language_choice, to_romanized=True)
459
- actual_translit = transliterate_text(best_text, language_choice, to_romanized=True)
460
-
461
- # Create status message
462
- status_msg = f"""✅ Analysis Complete!
 
 
 
 
 
 
 
 
 
463
 
464
- {score}
465
- {feedback}
 
 
466
 
467
- 🤖 Best result from: {best_source}
468
- 📊 Word Accuracy: {(1-wer)*100:.1f}%
469
- 📈 Character Accuracy: {(1-cer)*100:.1f}%
470
 
471
- 🔍 Quick Tips:
472
- • Green = Correct pronunciation ✅
473
- • Red = Wrong/missing words ❌
474
- • Orange = Added/substituted words 🔄
475
- """
476
-
477
- return (
478
- status_msg,
479
- primary_text or "Failed",
480
- specialized_text or "Failed",
481
- f"{wer:.3f} ({(1-wer)*100:.1f}%)",
482
- f"{cer:.3f} ({(1-cer)*100:.1f}%)",
483
- intended_sentence,
484
- best_text,
485
- intended_translit,
486
- actual_translit,
487
- word_diff,
488
- char_diff
489
- )
490
-
491
- except Exception as e:
492
- error_msg = f"❌ Unexpected error: {str(e)}"
493
- print(f"{error_msg}")
494
- print(f"🔍 Full traceback: {traceback.format_exc()}")
495
- return (error_msg, "", "", "", "", "", "", "", "", "", "")
496
 
497
- # ---------------- UI ---------------- #
498
- def create_interface():
499
- with gr.Blocks(title="Enhanced Pronunciation Comparator", theme=gr.themes.Soft()) as demo:
500
- gr.Markdown("""
501
- # 🎙️ Enhanced Pronunciation Comparator
502
-
503
- **Perfect your pronunciation in English, Tamil, Malayalam, and Hindi!**
504
-
505
- ### 🚀 How to use:
506
- 1. 🌐 Select your target language
507
- 2. 🎲 Generate a practice sentence
508
- 3. 🎤 Record yourself clearly (at least 0.5 seconds)
509
- 4. 🔍 Get detailed analysis with visual feedback
510
- """)
511
-
512
- with gr.Row():
513
- with gr.Column(scale=2):
514
- language_dropdown = gr.Dropdown(
515
- choices=list(LANG_CODES.keys()),
516
- value="English", # Start with English for reliability
517
- label="🌐 Select Language"
518
- )
519
- with gr.Column(scale=1):
520
- generate_btn = gr.Button("🎲 Generate Practice Sentence", variant="primary")
521
-
522
- intended_textbox = gr.Textbox(
523
- label="📝 Practice Sentence",
524
- interactive=False,
525
- lines=2,
526
- placeholder="Click 'Generate Practice Sentence' to get started..."
527
- )
528
-
529
- audio_input = gr.Audio(
530
- sources=["microphone", "upload"],
531
- type="filepath",
532
- label="🎤 Record Your Pronunciation (speak clearly for at least 0.5 seconds)"
533
- )
534
 
535
- analyze_btn = gr.Button("🔍 Analyze Pronunciation", variant="secondary", size="lg")
536
-
537
- status_output = gr.Textbox(
538
- label="📊 Analysis Results",
539
- interactive=False,
540
- lines=10
541
- )
542
-
543
- with gr.Accordion("🤖 Model Outputs (Debug Info)", open=False):
544
- with gr.Row():
545
- primary_output = gr.Textbox(label="Primary Model Output", interactive=False)
546
- specialized_output = gr.Textbox(label="Specialized Model Output", interactive=False)
547
-
548
- with gr.Accordion("📈 Detailed Metrics", open=False):
549
- with gr.Row():
550
- wer_output = gr.Textbox(label="Word Error Rate", interactive=False)
551
- cer_output = gr.Textbox(label="Character Error Rate", interactive=False)
552
-
553
- gr.Markdown("### 🔍 Side-by-Side Comparison")
554
-
555
- with gr.Row():
556
- with gr.Column():
557
- gr.Markdown("#### 📝 Original Script")
558
- intended_orig = gr.Textbox(label="🎯 Target Text", interactive=False)
559
- actual_orig = gr.Textbox(label="🗣️ What You Said", interactive=False)
560
- with gr.Column():
561
- gr.Markdown("#### 🔤 Romanized (Transliterated)")
562
- intended_translit = gr.Textbox(label="🎯 Target (Romanized)", interactive=False)
563
- actual_translit = gr.Textbox(label="🗣️ What You Said (Romanized)", interactive=False)
564
-
565
- gr.Markdown("### 🎨 Visual Feedback")
566
- gr.Markdown("**🟢 Green** = Correct | **🔴 Red** = Wrong/Missing | **🟠 Orange** = Added/Substituted")
567
-
568
- word_diff_html = gr.HTML(label="🔤 Word-by-Word Comparison")
569
- char_diff_html = gr.HTML(label="🔍 Character-by-Character Analysis")
570
-
571
- # Event handlers
572
- generate_btn.click(
573
- fn=get_random_sentence,
574
- inputs=[language_dropdown],
575
- outputs=[intended_textbox]
576
- )
577
-
578
- analyze_btn.click(
579
- fn=compare_pronunciation,
580
- inputs=[audio_input, language_dropdown, intended_textbox],
581
- outputs=[
582
- status_output, primary_output, specialized_output,
583
- wer_output, cer_output, intended_orig, actual_orig,
584
- intended_translit, actual_translit, word_diff_html, char_diff_html
585
- ]
586
- )
587
-
588
- language_dropdown.change(
589
- fn=get_random_sentence,
590
- inputs=[language_dropdown],
591
- outputs=[intended_textbox]
592
- )
593
-
594
- gr.Markdown("""
595
- ### 🔧 Troubleshooting:
596
 
597
- - **"Audio too short"** → Record for at least 0.5 seconds
598
- - **"Model loading failed"** Try refreshing the page
599
- - **"Empty transcription"** Speak louder and clearer
600
- - **Script mismatch** Make sure you're speaking the correct language
601
- - **General errors** → Check the debug info section for details
602
- """)
603
-
604
- return demo
 
 
 
 
 
 
 
 
 
 
 
 
605
 
606
- # ---------------- LAUNCH ---------------- #
607
  if __name__ == "__main__":
608
- print("🚀 Starting Enhanced Pronunciation Comparator...")
609
- print("🔧 System Check:")
610
- print(f" - PyTorch: {torch.__version__}")
611
- print(f" - Device: {DEVICE}")
612
- print(f" - Transliteration: {'✅' if INDIC_OK else '❌'}")
613
-
614
- demo = create_interface()
615
- demo.launch(
616
- server_name="0.0.0.0",
617
- server_port=7860,
618
- share=True,
619
- show_error=True,
620
- debug=True # Enable debug mode for better error reporting
621
- )
 
4
  import re
5
  import jiwer
6
  import torch
7
+ from transformers import WhisperForConditionalGeneration, WhisperProcessor
8
+ from indic_transliteration import sanscript
9
+ from indic_transliteration.sanscript import transliterate
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  # ---------------- CONFIG ---------------- #
12
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
13
 
14
+ # Updated model configurations for each language
15
+ MODEL_CONFIGS = {
16
+ "English": "openai/whisper-large-v2",
17
+ "Tamil": "vasista22/whisper-tamil-large-v2",
18
+ "Malayalam": "thennal/whisper-medium-ml"
19
+ }
20
 
21
  LANG_CODES = {
22
  "English": "en",
23
  "Tamil": "ta",
24
+ "Malayalam": "ml"
 
25
  }
26
 
27
+ LANG_PRIMERS = {
28
+ "English": ("The transcript should be in English only.",
29
+ "Write only in English without translation. Example: This is an English sentence."),
30
+ "Tamil": ("நகல் தமிழ் எழுத்துக்களில் மட்டும் இருக்க வேண்டும்.",
31
+ "தமிழ் எழுத்துக்களில் மட்டும் எழுதவும், மொழிபெயர்ப்பு செய்யக்கூடாது. உதாரணம்: இது ஒரு தமிழ் வாக்கியம்."),
32
+ "Malayalam": ("ട്രാൻസ്ഖ്രിപ്റ്റ് മലയാള ലിപിയിൽ ആയിരിക്കണം.",
33
+ "മലയാള ലിപിയിൽ മാത്രം എഴുതുക, വിവർത്തനം ചെയ്യരുത്. ഉദാഹരണം: ഇതൊരു മലയാള വാക്യമാണ്. എനിക്ക് മലയാളം അറിയാം.")
34
  }
35
 
36
  SCRIPT_PATTERNS = {
37
  "Tamil": re.compile(r"[஀-௿]"),
38
+ "Malayalam": re.compile(r"[ഀ-ൿ]"),
 
39
  "English": re.compile(r"[A-Za-z]")
40
  }
41
 
 
 
 
 
 
 
 
 
42
  SENTENCE_BANK = {
43
  "English": [
44
  "The sun sets over the horizon.",
45
+ "Learning languages is fun.",
46
  "I like to drink coffee in the morning.",
47
+ "Technology helps us communicate better.",
48
  "Reading books expands our knowledge."
49
  ],
50
  "Tamil": [
51
  "இன்று நல்ல வானிலை உள்ளது.",
52
  "நான் தமிழ் கற்றுக்கொண்டு இருக்கிறேன்.",
53
  "எனக்கு புத்தகம் படிக்க விருப்பம்.",
54
+ "தமிழ் மொழி மிகவும் அழகானது.",
55
+ "நான் தினமும் பள்ளிக்கு செல்கிறேன்."
56
  ],
57
  "Malayalam": [
58
  "എനിക്ക് മലയാളം വളരെ ഇഷ്ടമാണ്.",
59
  "ഇന്ന് മഴപെയ്യുന്നു.",
60
  "ഞാൻ പുസ്തകം വായിക്കുന്നു.",
61
+ "കേരളം എന്റെ സ്വന്തം നാടാണ്.",
62
+ "ഞാൻ മലയാളം പഠിക്കുന്നു."
 
 
 
 
 
 
 
63
  ]
64
  }
65
 
66
+ # ---------------- LOAD MODELS ---------------- #
67
# Eagerly load one Whisper checkpoint and its matching processor per language
# listed in MODEL_CONFIGS. Both registries are keyed by the language name.
print("Loading Whisper models...")
whisper_models = {}
whisper_processors = {}

for lang, model_id in MODEL_CONFIGS.items():
    print(f"Loading {lang} model: {model_id}")
    # Load the weights first, then move them to the selected device.
    loaded_model = WhisperForConditionalGeneration.from_pretrained(model_id)
    whisper_models[lang] = loaded_model.to(DEVICE)
    whisper_processors[lang] = WhisperProcessor.from_pretrained(model_id)

print("All models loaded successfully!")
 
 
 
 
 
 
 
 
 
77
 
78
+ # ---------------- HELPERS ---------------- #
79
def get_random_sentence(language_choice):
    """Return a randomly chosen practice sentence for ``language_choice``.

    ``language_choice`` must be a key of ``SENTENCE_BANK``; an unknown key
    raises ``KeyError``.
    """
    candidates = SENTENCE_BANK[language_choice]
    return random.choice(candidates)
81
+
82
def is_script(text, lang_name):
    """Tell whether ``text`` contains at least one character of the language's script.

    Languages without an entry in ``SCRIPT_PATTERNS`` are always accepted.
    """
    script_re = SCRIPT_PATTERNS.get(lang_name)
    if not script_re:
        # No pattern registered for this language: nothing to check against.
        return True
    return script_re.search(text) is not None
85
+
86
def transliterate_to_hk(text, lang_choice):
    """Romanise ``text`` into the Harvard-Kyoto scheme.

    Tamil and Malayalam are transliterated from their native scripts;
    English text is returned unchanged. An unknown ``lang_choice`` raises
    ``KeyError``.
    """
    source_schemes = {
        "Tamil": sanscript.TAMIL,
        "Malayalam": sanscript.MALAYALAM,
        "English": None,
    }
    source_scheme = source_schemes[lang_choice]
    if source_scheme:
        return transliterate(text, source_scheme, sanscript.HK)
    return text
93
+
94
def transcribe_once(audio_path, language_choice, initial_prompt, beam_size, temperature, condition_on_previous_text):
    """Transcribe one audio file with the Whisper model registered for a language.

    Args:
        audio_path: Path of the audio file to decode.
        language_choice: Key into ``whisper_models``, ``whisper_processors``
            and ``LANG_CODES``.
        initial_prompt: Text used to bias decoding. It is converted to Whisper
            prompt ids and passed to ``generate``; an empty or ``None`` prompt
            disables prompting.
        beam_size: Number of beams for beam search (coerced to ``int`` because
            a UI slider may deliver a float).
        temperature: Sampling temperature. ``0`` selects deterministic
            decoding; any positive value enables sampling.
        condition_on_previous_text: Accepted for interface compatibility but
            currently unused: the audio is decoded with a single ``generate``
            call, so there is no previous segment to condition on.

    Returns:
        The decoded transcription with surrounding whitespace removed.
    """
    # Kept as a function-local import, as in the original block.
    import librosa

    # Pick the model/processor pair that belongs to the requested language.
    model = whisper_models[language_choice]
    processor = whisper_processors[language_choice]
    lang_code = LANG_CODES[language_choice]

    # Whisper processors expect 16 kHz audio.
    audio, _ = librosa.load(audio_path, sr=16000)
    input_features = processor(
        audio, sampling_rate=16000, return_tensors="pt"
    ).input_features.to(DEVICE)

    # Pin the decoder to the target language and the transcription task.
    forced_decoder_ids = processor.get_decoder_prompt_ids(language=lang_code, task="transcribe")

    use_sampling = temperature > 0
    generate_kwargs = {
        "forced_decoder_ids": forced_decoder_ids,
        "max_length": 448,
        "num_beams": int(beam_size),
        "temperature": temperature if use_sampling else None,
        "do_sample": use_sampling,
    }

    # Previously the prompt was silently dropped, which made every "biased"
    # pass identical to an unbiased one. Feed it to the decoder as prompt ids.
    if initial_prompt and initial_prompt.strip():
        generate_kwargs["prompt_ids"] = processor.get_prompt_ids(
            initial_prompt, return_tensors="pt"
        ).to(DEVICE)

    with torch.no_grad():
        predicted_ids = model.generate(input_features, **generate_kwargs)

    # skip_special_tokens drops the control tokens from the decoded text.
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    return transcription.strip()
124
+
125
def highlight_differences(ref, hyp):
    """Render a word-level diff between ``ref`` and ``hyp`` as coloured HTML.

    Words are compared after whitespace splitting:

    * green  - word present in both texts
    * red    - reference word that was replaced
    * red, struck through - reference word missing from ``hyp``
    * orange - word that only appears in ``hyp``

    Every word is HTML-escaped before being embedded, so text containing
    ``<``, ``>`` or ``&`` cannot break or inject markup.

    Returns:
        The ``<span>`` elements joined by single spaces ("" for empty input).
    """
    import html

    ref_words, hyp_words = ref.strip().split(), hyp.strip().split()

    def spans(words, style):
        # One escaped <span> per word, all sharing the same inline style.
        return [f"<span style='{style}'>{html.escape(w)}</span>" for w in words]

    out_html = []
    matcher = difflib.SequenceMatcher(None, ref_words, hyp_words)
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == 'equal':
            out_html.extend(spans(ref_words[i1:i2], "color:green"))
        elif tag == 'replace':
            # Show the expected words first, then what was heard instead.
            out_html.extend(spans(ref_words[i1:i2], "color:red"))
            out_html.extend(spans(hyp_words[j1:j2], "color:orange"))
        elif tag == 'delete':
            out_html.extend(spans(ref_words[i1:i2], "color:red;text-decoration:line-through"))
        elif tag == 'insert':
            out_html.extend(spans(hyp_words[j1:j2], "color:orange"))
    return " ".join(out_html)
140
+
141
def char_level_highlight(ref, hyp):
    """Render a character-level diff between ``ref`` and ``hyp`` as coloured HTML.

    * green - character present in both texts
    * red, underlined - reference character that was replaced or is missing
    * orange - character that only appears in ``hyp``

    Every character is HTML-escaped before being embedded, so ``<``, ``>``
    and ``&`` are displayed literally instead of being parsed as markup.

    Returns:
        The concatenated ``<span>`` elements ("" for empty input).
    """
    import html

    def spans(chars, style):
        # One escaped <span> per character, all sharing the same inline style.
        return [f"<span style='{style}'>{html.escape(c)}</span>" for c in chars]

    out = []
    matcher = difflib.SequenceMatcher(None, list(ref), list(hyp))
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == 'equal':
            out.extend(spans(ref[i1:i2], "color:green"))
        elif tag in ('replace', 'delete'):
            # Only the reference side is shown for mismatched characters.
            out.extend(spans(ref[i1:i2], "color:red;text-decoration:underline"))
        elif tag == 'insert':
            out.extend(spans(hyp[j1:j2], "color:orange"))
    return "".join(out)
152
+
153
+ # ---------------- MAIN ---------------- #
154
def compare_pronunciation(audio, language_choice, intended_sentence,
                          pass1_beam, pass1_temp, pass1_condition):
    """Transcribe a recording twice and score it against the intended sentence.

    Args:
        audio: Path of the recorded/uploaded audio file, or ``None``.
        language_choice: Language name, a key of ``LANG_PRIMERS``.
        intended_sentence: The sentence the speaker was asked to read.
        pass1_beam: Beam size for the first transcription pass.
        pass1_temp: Temperature for the first transcription pass.
        pass1_condition: Forwarded to ``transcribe_once`` as
            ``condition_on_previous_text`` for the first pass.

    Returns:
        An 8-tuple: (pass 1 text, pass 2 text, Harvard-Kyoto transliteration
        of pass 1 or a script-mismatch notice, WER, CER, word-level diff
        HTML, character-level diff HTML, intended sentence). When the audio
        or the sentence is missing, the first element carries a message and
        the rest are empty strings.
    """
    # The sentence may arrive as None (not just ""), so guard before .strip().
    if audio is None or not (intended_sentence or "").strip():
        return ("No audio or intended sentence.", "", "", "", "", "", "", "")

    primer_weak, primer_strong = LANG_PRIMERS[language_choice]

    # Pass 1: raw transcription with user-configured decoding parameters
    actual_text = transcribe_once(audio, language_choice, primer_weak,
                                  pass1_beam, pass1_temp, pass1_condition)

    # Pass 2: strict transcription biased by intended sentence (fixed decoding params)
    strict_prompt = f"{primer_strong}\nTarget: {intended_sentence}"
    corrected_text = transcribe_once(audio, language_choice, strict_prompt,
                                     beam_size=5, temperature=0.0, condition_on_previous_text=False)

    # Error rates are computed on the unbiased pass 1 output.
    wer_val = jiwer.wer(intended_sentence, actual_text)
    cer_val = jiwer.cer(intended_sentence, actual_text)

    # Only transliterate when pass 1 actually produced the expected script.
    if is_script(actual_text, language_choice):
        hk_translit = transliterate_to_hk(actual_text, language_choice)
    else:
        hk_translit = f"[Script mismatch: expected {language_choice}]"

    # Highlight word-level and character-level differences
    diff_html = highlight_differences(intended_sentence, actual_text)
    char_html = char_level_highlight(intended_sentence, actual_text)

    return (actual_text, corrected_text, hk_translit, f"{wer_val:.2f}", f"{cer_val:.2f}",
            diff_html, char_html, intended_sentence)
183
 
184
+ # ---------------- UI ---------------- #
185
# Build the Gradio interface: language picker and sentence generator, audio
# input with pass 1 decoding controls, and the result/feedback widgets.
# NOTE(review): the nesting of the layout containers is inferred; the source
# this was recovered from had lost its indentation - confirm against the app.
with gr.Blocks(title="Pronunciation Comparator") as demo:
    gr.Markdown("## 🎙 Pronunciation Comparator - English, Tamil & Malayalam")
    gr.Markdown("Practice pronunciation with specialized Whisper models for each language!")

    with gr.Row():
        lang_choice = gr.Dropdown(choices=list(LANG_CODES.keys()), value="Malayalam", label="Language")
        gen_btn = gr.Button("🎲 Generate Sentence")

    intended_display = gr.Textbox(label="Generated Sentence (Read aloud)", interactive=False)

    with gr.Row():
        audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Record your pronunciation")

        with gr.Column():
            gr.Markdown("### Transcription Parameters")
            pass1_beam = gr.Slider(1, 10, value=8, step=1, label="Pass 1 Beam Size")
            pass1_temp = gr.Slider(0.0, 1.0, value=0.4, step=0.1, label="Pass 1 Temperature")
            pass1_condition = gr.Checkbox(value=True, label="Pass 1: Condition on previous text")

    submit_btn = gr.Button("🔍 Analyze Pronunciation", variant="primary")

    with gr.Row():
        pass1_out = gr.Textbox(label="Pass 1: What You Actually Said")
        pass2_out = gr.Textbox(label="Pass 2: Target-Biased Output")

    with gr.Row():
        hk_out = gr.Textbox(label="Harvard-Kyoto Transliteration (Pass 1)")
        wer_out = gr.Textbox(label="Word Error Rate")
        cer_out = gr.Textbox(label="Character Error Rate")

    gr.Markdown("### Visual Feedback")
    diff_html_box = gr.HTML(label="Word Differences Highlighted")
    char_html_box = gr.HTML(label="Character-Level Highlighting (mispronounced = red underline)")

    # Event handlers
    gen_btn.click(fn=get_random_sentence, inputs=[lang_choice], outputs=[intended_display])

    # The component lists must reference the widgets created above:
    # `lang_choice` (the dropdown) and `hk_out` (the transliteration box).
    # The previous names `language_choice` / `hk_translit` are not defined at
    # module level and raised NameError while the UI was being built.
    # The output order mirrors the 8-tuple returned by compare_pronunciation.
    submit_btn.click(
        fn=compare_pronunciation,
        inputs=[audio_input, lang_choice, intended_display, pass1_beam, pass1_temp, pass1_condition],
        outputs=[
            pass1_out, pass2_out, hk_out, wer_out, cer_out,
            diff_html_box, char_html_box, intended_display
        ]
    )


if __name__ == "__main__":
    demo.launch()