sudhanm committed on
Commit 3a8ecbf · verified · 1 Parent(s): 05566a8

Update app.py

Files changed (1):
  app.py (+343 / -403)
app.py CHANGED
@@ -6,6 +6,8 @@ import jiwer
 import torch
 import warnings
 import contextlib
+import traceback
+import gc
 from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, pipeline
 import librosa
 import numpy as np
@@ -15,18 +17,21 @@ try:
     from indic_transliteration import sanscript
     from indic_transliteration.sanscript import transliterate
     INDIC_OK = True
+    print("✅ Transliteration available")
 except:
     INDIC_OK = False
-    print("⚠️ indic_transliteration not available. Transliteration features disabled.")
+    print("⚠️ indic_transliteration not available. Install with: pip install indic-transliteration")
 
 # Optional HF Spaces GPU decorator
 try:
     import spaces
     GPU_DECORATOR = spaces.GPU
+    print("✅ HF Spaces GPU decorator available")
 except:
     class _NoOp:
         def __call__(self, f): return f
     GPU_DECORATOR = _NoOp()
+    print("⚠️ HF Spaces not available (normal for local usage)")
 
 warnings.filterwarnings("ignore")
 
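A note on the guard pattern in the hunk above: because GPU_DECORATOR degrades to an identity decorator when the `spaces` package is missing, the same source runs on HF Spaces and locally. A minimal sketch of that behaviour, reusing the names from this diff (the decorated function here is hypothetical, not part of app.py):

    @GPU_DECORATOR              # spaces.GPU on HF Spaces, plain pass-through locally
    def warm_up():
        return "ready"

    warm_up()                   # behaves identically whether or not `spaces` imports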
@@ -35,7 +40,12 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 DEVICE_INDEX = 0 if DEVICE == "cuda" else -1
 DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
 amp_ctx = torch.cuda.amp.autocast if DEVICE == "cuda" else contextlib.nullcontext
-print(f"🔧 Using device: {DEVICE}")
+
+print(f"🔧 Device: {DEVICE}")
+print(f"🔧 PyTorch version: {torch.__version__}")
+if DEVICE == "cuda":
+    print(f"🔧 CUDA available: {torch.cuda.is_available()}")
+    print(f"🔧 GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")
 
 LANG_CODES = {
     "English": "en",
@@ -44,15 +54,12 @@ LANG_CODES = {
     "Hindi": "hi"
 }
 
-# Primary: IndicWhisper
-INDICWHISPER_MODEL = "parthiv11/indic_whisper_nodcil"
-
-# Specialized models for better accuracy
+# Simplified model setup - use only reliable models
 SPECIALIZED_MODELS = {
     "English": "openai/whisper-base.en",
-    "Tamil": "vasista22/whisper-tamil-large-v2",
-    "Malayalam": "thennal/whisper-medium-ml",
-    "Hindi": "openai/whisper-large-v2"  # Using general model for Hindi
+    "Tamil": "openai/whisper-large-v2",      # More reliable fallback
+    "Malayalam": "openai/whisper-large-v2",  # More reliable fallback
+    "Hindi": "openai/whisper-large-v2"
 }
 
 SCRIPT_PATTERNS = {
@@ -64,9 +71,9 @@ SCRIPT_PATTERNS = {
 
 # Transliteration mappings
 TRANSLITERATION_SCRIPTS = {
-    "Tamil": sanscript.TAMIL,
-    "Malayalam": sanscript.MALAYALAM,
-    "Hindi": sanscript.DEVANAGARI,
+    "Tamil": sanscript.TAMIL if INDIC_OK else None,
+    "Malayalam": sanscript.MALAYALAM if INDIC_OK else None,
+    "Hindi": sanscript.DEVANAGARI if INDIC_OK else None,
     "English": None
 }
 
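For context, these script constants are what transliterate_text() hands to indic_transliteration. A minimal sketch of the Harvard-Kyoto round trip the app performs, assuming the package is installed (the sample string is a hypothetical Devanagari example, not from app.py):

    from indic_transliteration import sanscript
    from indic_transliteration.sanscript import transliterate

    native = "नमस्ते"                                                   # Devanagari sample text
    roman = transliterate(native, sanscript.DEVANAGARI, sanscript.HK)   # to romanized (Harvard-Kyoto)
    back = transliterate(roman, sanscript.HK, sanscript.DEVANAGARI)     # back to native script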
@@ -102,48 +109,71 @@ SENTENCE_BANK = {
 }
 
 # Model cache
-primary_pipeline = None
-specialized_models = {}
+models = {}
+
+# ---------------- SAFE HELPERS ---------------- #
+def safe_operation(func, *args, **kwargs):
+    """Wrapper for safe operations with detailed error reporting"""
+    try:
+        return func(*args, **kwargs), None
+    except Exception as e:
+        error_msg = f"Error in {func.__name__}: {str(e)}"
+        print(f"❌ {error_msg}")
+        print(f"🔍 Traceback: {traceback.format_exc()}")
+        return None, error_msg
 
-# ---------------- HELPERS ---------------- #
 def get_random_sentence(language_choice):
-    return random.choice(SENTENCE_BANK[language_choice])
+    try:
+        return random.choice(SENTENCE_BANK[language_choice])
+    except Exception as e:
+        print(f"❌ Error getting sentence: {e}")
+        return "Error loading sentence"
 
 def is_correct_script(text, lang_name):
     """Check if text contains the expected script for the language"""
-    if not text.strip():
-        return False
-    pattern = SCRIPT_PATTERNS.get(lang_name)
-    if not pattern:
-        return True
-    return bool(pattern.search(text))
+    try:
+        if not text or not text.strip():
+            return False
+        pattern = SCRIPT_PATTERNS.get(lang_name)
+        if not pattern:
+            return True
+        return bool(pattern.search(text))
+    except Exception as e:
+        print(f"❌ Script check error: {e}")
+        return True  # Default to True to avoid blocking
 
 def transliterate_text(text, lang_choice, to_romanized=True):
     """Transliterate text to/from romanized form"""
-    if not INDIC_OK or not text.strip():
-        return text
-
-    source_script = TRANSLITERATION_SCRIPTS.get(lang_choice)
-    if not source_script:
-        return text
-
     try:
+        if not INDIC_OK or not text or not text.strip():
+            return text
+
+        source_script = TRANSLITERATION_SCRIPTS.get(lang_choice)
+        if not source_script:
+            return text
+
         if to_romanized:
-            # Convert to Harvard-Kyoto (romanized)
             return transliterate(text, source_script, sanscript.HK)
         else:
-            # Convert from romanized to native script (if needed)
             return transliterate(text, sanscript.HK, source_script)
     except Exception as e:
         print(f"⚠️ Transliteration failed: {e}")
         return text
 
 def preprocess_audio(audio_path, target_sr=16000):
-    """Enhanced audio preprocessing"""
+    """Enhanced audio preprocessing with better error handling"""
     try:
+        print(f"🔊 Processing audio: {audio_path}")
+
+        if not audio_path:
+            return None, "No audio file provided"
+
+        # Load audio
         audio, sr = librosa.load(audio_path, sr=target_sr, mono=True)
+        print(f"🔊 Audio loaded: shape={audio.shape}, sr={sr}")
+
         if audio is None or len(audio) == 0:
-            return None, None
+            return None, "Empty audio file"
 
         # Normalize audio
         audio = audio.astype(np.float32)
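The helpers introduced in the hunk above converge on a (result, error) tuple convention. A short usage sketch of the new safe_operation wrapper under that assumption:

    sentence, err = safe_operation(get_random_sentence, "Tamil")
    if err is None:
        print(f"Practice sentence: {sentence}")
    else:
        print(f"Helper failed, falling back: {err}")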
@@ -154,374 +184,282 @@ def preprocess_audio(audio_path, target_sr=16000):
         # Trim silence
         audio, _ = librosa.effects.trim(audio, top_db=20)
 
-        # Check minimum length (0.1 seconds)
-        if len(audio) < int(target_sr * 0.1):
-            return None, None
-
-        return audio, target_sr
+        # Check minimum length (0.5 seconds)
+        min_length = int(target_sr * 0.5)
+        if len(audio) < min_length:
+            return None, f"Audio too short: {len(audio)/target_sr:.2f}s (minimum 0.5s)"
+
+        print(f"✅ Audio processed successfully: {len(audio)/target_sr:.2f}s")
+        return audio, sr
+
     except Exception as e:
-        print(f"⚠️ Audio preprocessing failed: {e}")
-        return None, None
+        error_msg = f"Audio preprocessing failed: {str(e)}"
+        print(f"❌ {error_msg}")
+        return None, error_msg
 
-# ---------------- MODEL LOADERS ---------------- #
-@GPU_DECORATOR
-def load_primary_model():
-    """Load the primary IndicWhisper model"""
-    global primary_pipeline
-    if primary_pipeline is not None:
-        return primary_pipeline
+# ---------------- SIMPLIFIED MODEL LOADERS ---------------- #
+def load_whisper_model(model_name, language=None):
+    """Load a single Whisper model with robust error handling"""
+    model_key = f"{model_name}_{language}"
+
+    if model_key in models:
+        print(f"✅ Using cached model: {model_key}")
+        return models[model_key], None
 
     try:
-        print(f"🔄 Loading primary model: {INDICWHISPER_MODEL}")
+        print(f"🔄 Loading model: {model_name}")
 
-        # Try direct loading first
-        primary_pipeline = pipeline(
+        # Clear GPU memory first
+        if DEVICE == "cuda":
+            torch.cuda.empty_cache()
+            gc.collect()
+
+        # Load with pipeline (simpler and more robust)
+        pipeline_model = pipeline(
             "automatic-speech-recognition",
-            model=INDICWHISPER_MODEL,
+            model=model_name,
             device=DEVICE_INDEX,
             torch_dtype=DTYPE,
-            trust_remote_code=True
+            return_timestamps=False
         )
-        print("✅ Primary model loaded successfully!")
-        return primary_pipeline
-
-    except Exception as e:
-        print(f"⚠️ Primary model failed, using fallback: {e}")
-        # Fallback to base Whisper
-        primary_pipeline = pipeline(
-            "automatic-speech-recognition",
-            model="openai/whisper-large-v2",
-            device=DEVICE_INDEX,
-            torch_dtype=DTYPE
-        )
-        print("✅ Fallback model loaded!")
-        return primary_pipeline
-
-@GPU_DECORATOR
-def load_specialized_model(language):
-    """Load specialized model for specific language"""
-    if language in specialized_models:
-        return specialized_models[language]
-
-    model_name = SPECIALIZED_MODELS[language]
-    print(f"🔄 Loading specialized {language} model: {model_name}")
-
-    try:
-        processor = AutoProcessor.from_pretrained(model_name)
-        model = AutoModelForSpeechSeq2Seq.from_pretrained(
-            model_name,
-            torch_dtype=DTYPE,
-            device_map="auto" if DEVICE == "cuda" else None
-        ).to(DEVICE)
 
-        specialized_models[language] = {
-            "processor": processor,
-            "model": model
-        }
-        print(f"✅ Specialized {language} model loaded!")
-        return specialized_models[language]
+        models[model_key] = pipeline_model
+        print(f"✅ Model loaded successfully: {model_name}")
+        return pipeline_model, None
 
     except Exception as e:
-        print(f"❌ Failed to load specialized {language} model: {e}")
-        return None
+        error_msg = f"Failed to load {model_name}: {str(e)}"
+        print(f"❌ {error_msg}")
+        return None, error_msg
 
 # ---------------- TRANSCRIPTION ---------------- #
-@GPU_DECORATOR
-def transcribe_with_primary(audio_path, language):
-    """Transcribe using primary IndicWhisper model"""
+def transcribe_audio_safe(audio_path, model_name, language):
+    """Safe transcription with comprehensive error handling"""
     try:
-        pipeline_model = load_primary_model()
-        lang_code = LANG_CODES[language]
+        print(f"🎤 Starting transcription with {model_name} for {language}")
+
+        # Load model
+        model, model_error = load_whisper_model(model_name, language)
+        if model is None:
+            return None, f"Model loading failed: {model_error}"
+
+        # Set language if supported
+        lang_code = LANG_CODES.get(language, "en")
+        generate_kwargs = {
+            "max_new_tokens": 200,
+            "num_beams": 1,  # Faster decoding
+            "do_sample": False
+        }
 
-        # Set language forcing if possible
+        # Try to set language (some models support this)
         try:
-            if hasattr(pipeline_model, "model") and hasattr(pipeline_model, "tokenizer"):
-                forced_ids = pipeline_model.tokenizer.get_decoder_prompt_ids(
+            if hasattr(model, "model") and hasattr(model, "tokenizer"):
+                forced_ids = model.tokenizer.get_decoder_prompt_ids(
                     language=lang_code,
                     task="transcribe"
                 )
                 if forced_ids:
-                    pipeline_model.model.config.forced_decoder_ids = forced_ids
-        except Exception as e:
-            print(f"⚠️ Language forcing failed: {e}")
+                    model.model.config.forced_decoder_ids = forced_ids
+                    print(f"🔧 Language set to: {lang_code}")
+        except Exception as lang_error:
+            print(f"⚠️ Language forcing failed (continuing anyway): {lang_error}")
 
+        # Transcribe
+        print(f"🔄 Transcribing...")
         with amp_ctx():
-            result = pipeline_model(audio_path)
+            result = model(audio_path, generate_kwargs=generate_kwargs)
 
         if isinstance(result, dict):
-            return result.get("text", "").strip()
-        return str(result).strip()
-
-    except Exception as e:
-        return f"Primary transcription error: {str(e)}"
-
-@GPU_DECORATOR
-def transcribe_with_specialized(audio_path, language):
-    """Transcribe using specialized model"""
-    try:
-        model_components = load_specialized_model(language)
-        if not model_components:
-            return "Specialized model not available"
-
-        # Preprocess audio
-        audio, sr = preprocess_audio(audio_path)
-        if audio is None:
-            return "Audio preprocessing failed"
-
-        # Process with specialized model
-        inputs = model_components["processor"](
-            audio,
-            sampling_rate=sr,
-            return_tensors="pt"
-        )
-
-        input_features = inputs.input_features.to(DEVICE)
-
-        # Generation parameters
-        gen_kwargs = {
-            "inputs": input_features,
-            "max_length": 200,
-            "num_beams": 3,
-            "do_sample": False
-        }
-
-        # Language forcing for non-English
-        if language != "English":
-            try:
-                forced_ids = model_components["processor"].tokenizer.get_decoder_prompt_ids(
-                    language=LANG_CODES[language],
-                    task="transcribe"
-                )
-                if forced_ids:
-                    gen_kwargs["forced_decoder_ids"] = forced_ids
-            except Exception as e:
-                print(f"⚠️ Specialized language forcing failed: {e}")
-
-        # Generate transcription
-        with torch.no_grad(), amp_ctx():
-            generated_ids = model_components["model"].generate(**gen_kwargs)
-
-        # Decode result
-        transcription = model_components["processor"].batch_decode(
-            generated_ids,
-            skip_special_tokens=True
-        )[0]
-
-        return transcription.strip()
+            text = result.get("text", "").strip()
+        else:
+            text = str(result).strip()
 
-    except Exception as e:
-        return f"Specialized transcription error: {str(e)}"
+        print(f"✅ Transcription complete: '{text[:50]}...'")
+        return text, None
+
+    except Exception as e:
+        error_msg = f"Transcription failed: {str(e)}"
+        print(f"❌ {error_msg}")
+        print(f"🔍 Full traceback: {traceback.format_exc()}")
+        return None, error_msg
 
 # ---------------- ANALYSIS ---------------- #
-def compute_metrics(reference, hypothesis):
+def compute_metrics_safe(reference, hypothesis):
     """Compute WER and CER with error handling"""
     try:
-        # Clean up texts
         ref_clean = reference.strip()
         hyp_clean = hypothesis.strip()
 
         if not ref_clean or not hyp_clean:
-            return 1.0, 1.0
+            return 1.0, 1.0, "Empty text"
 
-        # Compute WER and CER
         wer = jiwer.wer(ref_clean, hyp_clean)
         cer = jiwer.cer(ref_clean, hyp_clean)
 
-        return wer, cer
+        return wer, cer, None
     except Exception as e:
-        print(f"⚠️ Metric computation failed: {e}")
-        return 1.0, 1.0
+        print(f"❌ Metric computation failed: {e}")
+        return 1.0, 1.0, str(e)
 
 def get_pronunciation_score(wer, cer):
     """Convert error rates to intuitive scores and feedback"""
-    # Weighted combination (WER is more important)
-    combined_error = (wer * 0.7) + (cer * 0.3)
-    accuracy = 1 - combined_error
-
-    if accuracy >= 0.95:
-        return "🏆 Perfect!", "Outstanding pronunciation! Native-like accuracy.", "#d4edda"
-    elif accuracy >= 0.85:
-        return "🎉 Excellent!", "Very good pronunciation with minor variations.", "#d1ecf1"
-    elif accuracy >= 0.70:
-        return "👍 Good!", "Good pronunciation, practice specific sounds.", "#fff3cd"
-    elif accuracy >= 0.50:
-        return "📚 Needs Practice", "Focus on clearer pronunciation and rhythm.", "#f8d7da"
-    else:
-        return "💪 Keep Trying!", "Break down into smaller parts and practice slowly.", "#f5c6cb"
-
-def create_detailed_comparison(intended, actual, lang_choice):
-    """Create detailed side-by-side comparison with transliteration"""
-
-    # Original scripts
-    intended_orig = intended.strip()
-    actual_orig = actual.strip()
-
-    # Transliterations
-    intended_translit = transliterate_text(intended_orig, lang_choice, to_romanized=True)
-    actual_translit = transliterate_text(actual_orig, lang_choice, to_romanized=True)
-
-    # Word-level highlighting
-    word_diff_orig = highlight_word_differences(intended_orig, actual_orig)
-    word_diff_translit = highlight_word_differences(intended_translit, actual_translit)
-
-    # Character-level highlighting
-    char_diff_orig = highlight_char_differences(intended_orig, actual_orig)
-    char_diff_translit = highlight_char_differences(intended_translit, actual_translit)
-
-    return {
-        "intended_orig": intended_orig,
-        "actual_orig": actual_orig,
-        "intended_translit": intended_translit,
-        "actual_translit": actual_translit,
-        "word_diff_orig": word_diff_orig,
-        "word_diff_translit": word_diff_translit,
-        "char_diff_orig": char_diff_orig,
-        "char_diff_translit": char_diff_translit
-    }
-
-def highlight_word_differences(reference, hypothesis):
-    """Highlight word-level differences with colors"""
-    ref_words = reference.split()
-    hyp_words = hypothesis.split()
-
-    sm = difflib.SequenceMatcher(None, ref_words, hyp_words)
-    html_output = []
-
-    for tag, i1, i2, j1, j2 in sm.get_opcodes():
-        if tag == 'equal':
-            # Correct words - green background
-            html_output.extend([
-                f"<span style='background-color:#d4edda; color:#155724; padding:2px 4px; margin:1px; border-radius:3px'>{word}</span>"
-                for word in ref_words[i1:i2]
-            ])
-        elif tag == 'replace':
-            # Wrong words - red background for reference, orange for hypothesis
-            html_output.extend([
-                f"<span style='background-color:#f8d7da; color:#721c24; padding:2px 4px; margin:1px; border-radius:3px; text-decoration:line-through'>{word}</span>"
-                for word in ref_words[i1:i2]
-            ])
-            html_output.extend([
-                f"<span style='background-color:#fff3cd; color:#856404; padding:2px 4px; margin:1px; border-radius:3px'>→{word}</span>"
-                for word in hyp_words[j1:j2]
-            ])
-        elif tag == 'delete':
-            # Missing words - red background
-            html_output.extend([
-                f"<span style='background-color:#f8d7da; color:#721c24; padding:2px 4px; margin:1px; border-radius:3px; text-decoration:line-through'>{word}</span>"
-                for word in ref_words[i1:i2]
-            ])
-        elif tag == 'insert':
-            # Extra words - orange background
-            html_output.extend([
-                f"<span style='background-color:#fff3cd; color:#856404; padding:2px 4px; margin:1px; border-radius:3px'>+{word}</span>"
-                for word in hyp_words[j1:j2]
-            ])
-
-    return " ".join(html_output)
-
-def highlight_char_differences(reference, hypothesis):
-    """Highlight character-level differences"""
-    sm = difflib.SequenceMatcher(None, list(reference), list(hypothesis))
-    html_output = []
-
-    for tag, i1, i2, j1, j2 in sm.get_opcodes():
-        if tag == 'equal':
-            # Correct characters - green
-            html_output.extend([
-                f"<span style='color:#28a745'>{char}</span>"
-                for char in reference[i1:i2]
-            ])
-        elif tag in ('replace', 'delete'):
-            # Wrong/missing characters - red with underline
-            html_output.extend([
-                f"<span style='color:#dc3545; text-decoration:underline; font-weight:bold'>{char}</span>"
-                for char in reference[i1:i2]
-            ])
-        elif tag == 'insert':
-            # Extra characters - orange
-            html_output.extend([
-                f"<span style='color:#fd7e14; font-weight:bold'>{char}</span>"
-                for char in hypothesis[j1:j2]
-            ])
-
-    return "".join(html_output)
-
-def analyze_pronunciation_errors(intended, actual, lang_choice):
-    """Provide specific feedback about pronunciation errors"""
-    comparison = create_detailed_comparison(intended, actual, lang_choice)
-
-    # Analyze error patterns
-    intended_words = intended.split()
-    actual_words = actual.split()
-
-    error_analysis = []
-
-    # Length difference analysis
-    if len(actual_words) < len(intended_words):
-        missing_count = len(intended_words) - len(actual_words)
-        error_analysis.append(f"🔍 You missed {missing_count} word(s). Try speaking more slowly.")
-    elif len(actual_words) > len(intended_words):
-        extra_count = len(actual_words) - len(intended_words)
-        error_analysis.append(f"🔍 You added {extra_count} extra word(s). Focus on the exact sentence.")
-
-    # Script verification
-    if not is_correct_script(actual, lang_choice):
-        error_analysis.append(f"⚠️ The transcription doesn't contain {lang_choice} script. Check your pronunciation.")
-
-    # WER/CER based feedback
-    wer, cer = compute_metrics(intended, actual)
-
-    if wer > 0.5:
-        error_analysis.append("🎯 Focus on pronouncing each word clearly and separately.")
-    elif wer > 0.3:
-        error_analysis.append("🎯 Good overall, but some words need clearer pronunciation.")
-
-    if cer > 0.3:
-        error_analysis.append("🔀 Pay attention to individual sounds and syllables.")
-
-    return error_analysis, comparison
+    try:
+        combined_error = (wer * 0.7) + (cer * 0.3)
+        accuracy = 1 - combined_error
+
+        if accuracy >= 0.95:
+            return "🏆 Perfect!", "Outstanding pronunciation! Native-like accuracy."
+        elif accuracy >= 0.85:
+            return "🎉 Excellent!", "Very good pronunciation with minor variations."
+        elif accuracy >= 0.70:
+            return "👍 Good!", "Good pronunciation, practice specific sounds."
+        elif accuracy >= 0.50:
+            return "📚 Needs Practice", "Focus on clearer pronunciation and rhythm."
+        else:
+            return "💪 Keep Trying!", "Break down into smaller parts and practice slowly."
+    except Exception as e:
+        print(f"❌ Score calculation failed: {e}")
+        return "❓ Unknown", "Could not calculate score"
+
+def highlight_differences_safe(reference, hypothesis):
+    """Safe highlighting with error handling"""
+    try:
+        if not reference or not hypothesis:
+            return "No text to compare", "No text to compare"
+
+        # Word-level highlighting
+        ref_words = reference.split()
+        hyp_words = hypothesis.split()
+
+        sm = difflib.SequenceMatcher(None, ref_words, hyp_words)
+        word_html = []
+
+        for tag, i1, i2, j1, j2 in sm.get_opcodes():
+            if tag == 'equal':
+                word_html.extend([
+                    f"<span style='background-color:#d4edda; color:#155724; padding:2px 4px; margin:1px; border-radius:3px'>{word}</span>"
+                    for word in ref_words[i1:i2]
+                ])
+            elif tag == 'replace':
+                word_html.extend([
+                    f"<span style='background-color:#f8d7da; color:#721c24; padding:2px 4px; margin:1px; border-radius:3px; text-decoration:line-through'>{word}</span>"
+                    for word in ref_words[i1:i2]
+                ])
+                word_html.extend([
+                    f"<span style='background-color:#fff3cd; color:#856404; padding:2px 4px; margin:1px; border-radius:3px'>→{word}</span>"
+                    for word in hyp_words[j1:j2]
+                ])
+            elif tag == 'delete':
+                word_html.extend([
+                    f"<span style='background-color:#f8d7da; color:#721c24; padding:2px 4px; margin:1px; border-radius:3px; text-decoration:line-through'>{word}</span>"
+                    for word in ref_words[i1:i2]
+                ])
+            elif tag == 'insert':
+                word_html.extend([
+                    f"<span style='background-color:#fff3cd; color:#856404; padding:2px 4px; margin:1px; border-radius:3px'>+{word}</span>"
+                    for word in hyp_words[j1:j2]
+                ])
+
+        # Character-level highlighting
+        sm_char = difflib.SequenceMatcher(None, list(reference), list(hypothesis))
+        char_html = []
+
+        for tag, i1, i2, j1, j2 in sm_char.get_opcodes():
+            if tag == 'equal':
+                char_html.extend([
+                    f"<span style='color:#28a745'>{char}</span>"
+                    for char in reference[i1:i2]
+                ])
+            elif tag in ('replace', 'delete'):
+                char_html.extend([
+                    f"<span style='color:#dc3545; text-decoration:underline; font-weight:bold'>{char}</span>"
+                    for char in reference[i1:i2]
+                ])
+            elif tag == 'insert':
+                char_html.extend([
+                    f"<span style='color:#fd7e14; font-weight:bold'>{char}</span>"
+                    for char in hypothesis[j1:j2]
+                ])
+
+        return " ".join(word_html), "".join(char_html)
+
+    except Exception as e:
+        print(f"❌ Highlighting failed: {e}")
+        return f"Error highlighting: {str(e)}", f"Error highlighting: {str(e)}"
 
 # ---------------- MAIN FUNCTION ---------------- #
 @GPU_DECORATOR
 def compare_pronunciation(audio, language_choice, intended_sentence):
-    """Main function to analyze pronunciation"""
-
-    if audio is None:
-        return ("❌ Please record audio first", "", "", "", "", "", "", "", "", "", "")
-
-    if not intended_sentence.strip():
-        return ("❌ Please generate a sentence first", "", "", "", "", "", "", "", "", "", "")
-
-    print(f"🔍 Analyzing pronunciation for {language_choice}...")
-
-    # Get transcriptions from both models
-    primary_result = transcribe_with_primary(audio, language_choice)
-    specialized_result = transcribe_with_specialized(audio, language_choice)
-
-    # Choose best result (prefer specialized if successful)
-    if not specialized_result.startswith("Specialized") and specialized_result.strip():
-        best_transcription = specialized_result
-        best_source = "Specialized Model"
-    elif not primary_result.startswith("Primary") and primary_result.strip():
-        best_transcription = primary_result
-        best_source = "Primary Model"
-    else:
-        return (
-            f"❌ Both models failed:\nPrimary: {primary_result}\nSpecialized: {specialized_result}",
-            "", "", "", "", "", "", "", "", "", ""
-        )
-
-    # Analyze pronunciation
-    error_analysis, comparison = analyze_pronunciation_errors(
-        intended_sentence, best_transcription, language_choice
-    )
-
-    # Compute metrics
-    wer, cer = compute_metrics(intended_sentence, best_transcription)
-    score, feedback, color = get_pronunciation_score(wer, cer)
-
-    # Create status message
-    status_msg = f"""✅ Analysis Complete!
+    """Main function with comprehensive error handling"""
+
+    print(f"\n🔍 Starting pronunciation analysis...")
+    print(f"📝 Language: {language_choice}")
+    print(f"🎯 Target: {intended_sentence[:50]}...")
+
+    try:
+        # Validate inputs
+        if audio is None:
+            return ("❌ Please record audio first", "", "", "", "", "", "", "", "", "", "")
+
+        if not intended_sentence or not intended_sentence.strip():
+            return ("❌ Please generate a sentence first", "", "", "", "", "", "", "", "", "", "")
+
+        # Preprocess audio
+        processed_audio, audio_error = preprocess_audio(audio)
+        if processed_audio is None:
+            return (f"❌ Audio processing failed: {audio_error}", "", "", "", "", "", "", "", "", "", "")
+
+        # Get models for this language
+        primary_model = "openai/whisper-large-v2"  # Always reliable
+        specialized_model = SPECIALIZED_MODELS.get(language_choice, primary_model)
+
+        print(f"🤖 Using models: Primary={primary_model}, Specialized={specialized_model}")
+
+        # Try primary transcription
+        primary_text, primary_error = transcribe_audio_safe(audio, primary_model, language_choice)
+
+        # Try specialized transcription (if different from primary)
+        if specialized_model != primary_model:
+            specialized_text, specialized_error = transcribe_audio_safe(audio, specialized_model, language_choice)
+        else:
+            specialized_text, specialized_error = primary_text, primary_error
+
+        # Choose best result
+        best_text = None
+        best_source = "None"
+
+        if primary_text and not primary_text.startswith("Error"):
+            best_text = primary_text
+            best_source = "Primary Model"
+        elif specialized_text and not specialized_text.startswith("Error"):
+            best_text = specialized_text
+            best_source = "Specialized Model"
+
+        if not best_text:
+            error_details = f"Primary: {primary_error or 'Unknown error'}\nSpecialized: {specialized_error or 'Unknown error'}"
+            return (f"❌ Both models failed:\n{error_details}",
+                    primary_text or "Failed", specialized_text or "Failed",
+                    "", "", "", "", "", "", "", "")
+
+        print(f"✅ Best transcription from {best_source}: '{best_text}'")
+
+        # Compute metrics
+        wer, cer, metric_error = compute_metrics_safe(intended_sentence, best_text)
+        if metric_error:
+            print(f"⚠️ Metric computation warning: {metric_error}")
+
+        # Get score and feedback
+        score, feedback = get_pronunciation_score(wer, cer)
+
+        # Create visual comparisons
+        word_diff, char_diff = highlight_differences_safe(intended_sentence, best_text)
+
+        # Transliterations
+        intended_translit = transliterate_text(intended_sentence, language_choice, to_romanized=True)
+        actual_translit = transliterate_text(best_text, language_choice, to_romanized=True)
+
+        # Create status message
+        status_msg = f"""✅ Analysis Complete!
 
 {score}
 {feedback}
@@ -530,22 +468,31 @@ def compare_pronunciation(audio, language_choice, intended_sentence):
 📊 Word Accuracy: {(1-wer)*100:.1f}%
 📈 Character Accuracy: {(1-cer)*100:.1f}%
 
-🔍 Analysis:
-""" + "\n".join(error_analysis)
-
-    return (
-        status_msg,
-        primary_result,
-        specialized_result,
-        f"{wer:.3f} ({(1-wer)*100:.1f}%)",
-        f"{cer:.3f} ({(1-cer)*100:.1f}%)",
-        comparison["intended_orig"],
-        comparison["actual_orig"],
-        comparison["intended_translit"],
-        comparison["actual_translit"],
-        comparison["word_diff_orig"],
-        comparison["char_diff_orig"]
-    )
+🔍 Quick Tips:
+• Green = Correct pronunciation ✅
+• Red = Wrong/missing words ❌
+• Orange = Added/substituted words 🔄
+"""
+
+        return (
+            status_msg,
+            primary_text or "Failed",
+            specialized_text or "Failed",
+            f"{wer:.3f} ({(1-wer)*100:.1f}%)",
+            f"{cer:.3f} ({(1-cer)*100:.1f}%)",
+            intended_sentence,
+            best_text,
+            intended_translit,
+            actual_translit,
+            word_diff,
+            char_diff
+        )
+
+    except Exception as e:
+        error_msg = f"❌ Unexpected error: {str(e)}"
+        print(f"{error_msg}")
+        print(f"🔍 Full traceback: {traceback.format_exc()}")
+        return (error_msg, "", "", "", "", "", "", "", "", "", "")
 
 # ---------------- UI ---------------- #
 def create_interface():
@@ -555,21 +502,18 @@ def create_interface():
 
     **Perfect your pronunciation in English, Tamil, Malayalam, and Hindi!**
 
-    This tool uses specialized AI models to give you detailed feedback on your pronunciation,
-    including transliteration to help you understand exactly where you need improvement.
-
-    ### How to use:
+    ### 🚀 How to use:
     1. 🌐 Select your target language
     2. 🎲 Generate a practice sentence
-    3. 🎤 Record yourself saying the sentence clearly
-    4. 🔍 Get detailed pronunciation analysis with transliteration
+    3. 🎤 Record yourself clearly (at least 0.5 seconds)
+    4. 🔍 Get detailed analysis with visual feedback
     """)
 
     with gr.Row():
         with gr.Column(scale=2):
             language_dropdown = gr.Dropdown(
                 choices=list(LANG_CODES.keys()),
-                value="Tamil",
+                value="English",  # Start with English for reliability
                 label="🌐 Select Language"
             )
         with gr.Column(scale=1):
@@ -585,29 +529,28 @@ def create_interface():
             audio_input = gr.Audio(
                 sources=["microphone", "upload"],
                 type="filepath",
-                label="🎤 Record Your Pronunciation"
+                label="🎤 Record Your Pronunciation (speak clearly for at least 0.5 seconds)"
             )
 
         analyze_btn = gr.Button("🔍 Analyze Pronunciation", variant="secondary", size="lg")
 
-        with gr.Row():
-            status_output = gr.Textbox(
-                label="📊 Analysis Results",
-                interactive=False,
-                lines=8
-            )
+        status_output = gr.Textbox(
+            label="📊 Analysis Results",
+            interactive=False,
+            lines=10
+        )
 
-        with gr.Accordion("🤖 Model Outputs", open=False):
+        with gr.Accordion("🤖 Model Outputs (Debug Info)", open=False):
             with gr.Row():
-                primary_output = gr.Textbox(label="Primary Model (IndicWhisper)", interactive=False)
-                specialized_output = gr.Textbox(label="Specialized Model", interactive=False)
+                primary_output = gr.Textbox(label="Primary Model Output", interactive=False)
+                specialized_output = gr.Textbox(label="Specialized Model Output", interactive=False)
 
         with gr.Accordion("📈 Detailed Metrics", open=False):
             with gr.Row():
                 wer_output = gr.Textbox(label="Word Error Rate", interactive=False)
                 cer_output = gr.Textbox(label="Character Error Rate", interactive=False)
 
-        gr.Markdown("### 🔍 Detailed Comparison")
+        gr.Markdown("### 🔍 Side-by-Side Comparison")
 
         with gr.Row():
             with gr.Column():
@@ -619,8 +562,8 @@ def create_interface():
                 intended_translit = gr.Textbox(label="🎯 Target (Romanized)", interactive=False)
                 actual_translit = gr.Textbox(label="🗣️ What You Said (Romanized)", interactive=False)
 
-        gr.Markdown("### 🎨 Visual Comparison")
-        gr.Markdown("**Green** = Correct, **Red** = Wrong/Missing, **Orange** = Added/Substituted")
+        gr.Markdown("### 🎨 Visual Feedback")
+        gr.Markdown("**🟢 Green** = Correct | **🔴 Red** = Wrong/Missing | **🟠 Orange** = Added/Substituted")
 
         word_diff_html = gr.HTML(label="🔀 Word-by-Word Comparison")
         char_diff_html = gr.HTML(label="🔍 Character-by-Character Analysis")
@@ -649,22 +592,13 @@ def create_interface():
         )
 
         gr.Markdown("""
-        ### 📚 Pro Tips for Better Pronunciation:
-
-        - **Speak slowly and clearly** - Don't rush through the sentence
-        - **Pronounce each syllable** - Break down complex words
-        - **Check the romanized version** - Use it to understand correct pronunciation
-        - **Practice repeatedly** - Use the same sentence multiple times to track improvement
-        - **Focus on problem areas** - Pay attention to red-highlighted parts
-        - **Record in a quiet environment** - Minimize background noise
-
-        ### 🎯 Understanding the Feedback:
-
-        - **Green highlights** = Perfect pronunciation ✅
-        - **Red highlights** = Missing or mispronounced ❌
-        - **Orange highlights** = Added or substituted 🔄
-        - **Transliteration** = Helps you see pronunciation patterns
-        - **Error rates** = Lower is better (0% = perfect)
+        ### 🔧 Troubleshooting:
+
+        - **"Audio too short"** → Record for at least 0.5 seconds
+        - **"Model loading failed"** → Try refreshing the page
+        - **"Empty transcription"** → Speak louder and clearer
+        - **Script mismatch** → Make sure you're speaking the correct language
+        - **General errors** → Check the debug info section for details
         """)
 
     return demo
@@ -672,10 +606,16 @@ def create_interface():
 # ---------------- LAUNCH ---------------- #
 if __name__ == "__main__":
     print("🚀 Starting Enhanced Pronunciation Comparator...")
+    print("🔧 System Check:")
+    print(f"  - PyTorch: {torch.__version__}")
+    print(f"  - Device: {DEVICE}")
+    print(f"  - Transliteration: {'✅' if INDIC_OK else '❌'}")
+
     demo = create_interface()
     demo.launch(
         server_name="0.0.0.0",
        server_port=7860,
         share=True,
-        show_error=True
+        show_error=True,
+        debug=True  # Enable debug mode for better error reporting
     )