sudhanm committed on
Commit a950033 · verified · 1 Parent(s): e5c7bcd

Update app.py

Files changed (1)
  1. app.py +262 -294
app.py CHANGED
@@ -10,35 +10,26 @@ import librosa
10
  import soundfile as sf
11
  from indic_transliteration import sanscript
12
  from indic_transliteration.sanscript import transliterate
 
13
  import warnings
14
  import spaces
15
- warnings.filterwarnings("ignore")
16
 
17
- # Try to import whisper_jax, fallback to transformers if not available
18
- try:
19
- from whisper_jax import FlaxWhisperPipeline
20
- import jax.numpy as jnp
21
- WHISPER_JAX_AVAILABLE = True
22
- print("🚀 Using JAX-optimized IndicWhisper (70x faster!)")
23
- except ImportError:
24
- WHISPER_JAX_AVAILABLE = False
25
- print("⚠️ whisper_jax not available, using transformers fallback")
26
 
27
  # ---------------- CONFIG ---------------- #
28
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
29
  print(f"🔧 Using device: {DEVICE}")
 
30
 
31
  LANG_CODES = {
32
  "English": "en",
33
- "Tamil": "ta",
34
  "Malayalam": "ml"
35
  }
36
 
37
- # SOTA IndicWhisper model - one model for all languages!
38
- INDICWHISPER_MODEL = "parthiv11/indic_whisper_nodcil"
39
 
40
- # Fallback models if IndicWhisper fails
41
- FALLBACK_MODELS = {
42
  "English": "openai/whisper-base.en",
43
  "Tamil": "vasista22/whisper-tamil-large-v2",
44
  "Malayalam": "thennal/whisper-medium-ml"
@@ -55,7 +46,7 @@ LANG_PRIMERS = {
55
 
56
  SCRIPT_PATTERNS = {
57
  "Tamil": re.compile(r"[஀-௿]"),
58
- "Malayalam": re.compile(r"[ഀ-ൿ]"),
59
  "English": re.compile(r"[A-Za-z]")
60
  }
61
 
@@ -72,7 +63,7 @@ SENTENCE_BANK = {
72
  ],
73
  "Tamil": [
74
  "இன்று நல்ல வானிலை உள்ளது.",
75
- "நான் தமிழ் கற்றுக்கொண்டு இருக்கிறேன்.",
76
  "எனக்கு புத்தகம் படிக்க விருப்பம்.",
77
  "தமிழ் மொழி மிகவும் அழகானது.",
78
  "குடும்பத்துடன் நேரம் செலவிடுவது முக்கியம்.",
@@ -92,89 +83,47 @@ SENTENCE_BANK = {
92
  ]
93
  }
94
 
95
  # ---------------- MODEL CACHE ---------------- #
96
  indicwhisper_pipeline = None
97
  fallback_models = {}
 
98
 
99
- @spaces.GPU
100
- def load_indicwhisper():
101
- """Load the SOTA IndicWhisper model"""
102
- global indicwhisper_pipeline
103
-
104
- if indicwhisper_pipeline is None:
105
- try:
106
- print(f"🔄 Loading SOTA IndicWhisper: {INDICWHISPER_MODEL}")
107
-
108
- if WHISPER_JAX_AVAILABLE:
109
- # Use JAX-optimized version (70x faster!)
110
- indicwhisper_pipeline = FlaxWhisperPipeline(
111
- INDICWHISPER_MODEL,
112
- dtype=jnp.bfloat16,
113
- batch_size=1
114
- )
115
- print("✅ IndicWhisper loaded with JAX optimization (70x faster!)")
116
- else:
117
- # Fallback to transformers if whisper_jax not available
118
- from transformers import pipeline
119
- indicwhisper_pipeline = pipeline(
120
- "automatic-speech-recognition",
121
- model=INDICWHISPER_MODEL,
122
- device=DEVICE if DEVICE == "cuda" else -1
123
- )
124
- print("✅ IndicWhisper loaded with transformers (fallback mode)")
125
-
126
- except Exception as e:
127
- print(f"❌ Failed to load IndicWhisper: {e}")
128
- indicwhisper_pipeline = None
129
- raise Exception(f"Could not load IndicWhisper model: {str(e)}")
130
-
131
- return indicwhisper_pipeline
132
-
133
- @spaces.GPU
134
- def load_fallback_model(language):
135
- """Load fallback model if IndicWhisper fails"""
136
- if language not in fallback_models:
137
- model_name = FALLBACK_MODELS[language]
138
- print(f"🔄 Loading fallback model for {language}: {model_name}")
139
-
140
- try:
141
- processor = AutoProcessor.from_pretrained(model_name)
142
- model = AutoModelForSpeechSeq2Seq.from_pretrained(
143
- model_name,
144
- torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
145
- low_cpu_mem_usage=True,
146
- use_safetensors=True
147
- ).to(DEVICE)
148
-
149
- fallback_models[language] = {"processor": processor, "model": model, "model_name": model_name}
150
- print(f"✅ Fallback model loaded for {language}")
151
-
152
- except Exception as e:
153
- print(f"❌ Failed to load fallback {model_name}: {e}")
154
- raise Exception(f"Could not load fallback {language} model")
155
-
156
- return fallback_models[language]
157
 
158
- # ---------------- HELPERS ---------------- #
159
  def get_random_sentence(language_choice):
160
- """Get random sentence for practice"""
161
  return random.choice(SENTENCE_BANK[language_choice])
162
 
163
  def is_script(text, lang_name):
164
- """Check if text is in expected script"""
165
  pattern = SCRIPT_PATTERNS.get(lang_name)
166
  if not pattern:
167
  return True
168
- return bool(pattern.search(text))
169
 
170
  def transliterate_to_hk(text, lang_choice):
171
- """Transliterate Indic text to Harvard-Kyoto"""
172
  mapping = {
173
  "Tamil": sanscript.TAMIL,
174
  "Malayalam": sanscript.MALAYALAM,
175
  "English": None
176
  }
177
-
178
  script = mapping.get(lang_choice)
179
  if script and is_script(text, lang_choice):
180
  try:
@@ -185,142 +134,191 @@ def transliterate_to_hk(text, lang_choice):
185
  return text
186
 
187
  def preprocess_audio(audio_path, target_sr=16000):
188
- """Preprocess audio for ASR"""
189
  try:
190
- # Load audio
191
  audio, sr = librosa.load(audio_path, sr=target_sr)
192
-
193
- # Normalize audio
194
  if np.max(np.abs(audio)) > 0:
195
  audio = audio / np.max(np.abs(audio))
196
-
197
- # Remove silence from beginning and end
198
  audio, _ = librosa.effects.trim(audio, top_db=20)
199
-
200
- # Ensure minimum length
201
- if len(audio) < target_sr * 0.1: # Less than 0.1 seconds
202
  return None, None
203
-
204
  return audio, target_sr
205
  except Exception as e:
206
  print(f"Audio preprocessing error: {e}")
207
  return None, None
208
 
209
  @spaces.GPU
210
- def transcribe_with_indicwhisper(audio_path, language):
211
- """Transcribe using SOTA IndicWhisper"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  try:
213
- pipeline = load_indicwhisper()
214
-
215
- if WHISPER_JAX_AVAILABLE and hasattr(pipeline, '__call__'):
216
- # JAX-optimized version
217
- result = pipeline(audio_path)
218
- if isinstance(result, dict) and 'text' in result:
219
- return result['text'].strip()
220
  elif isinstance(result, str):
221
  return result.strip()
222
  else:
223
  return str(result).strip()
224
  else:
225
- # Transformers fallback
226
- result = pipeline(audio_path)
227
- return result.get('text', '').strip()
228
-
229
  except Exception as e:
230
- print(f"IndicWhisper transcription error: {e}")
231
  raise e
232
 
233
  @spaces.GPU
234
- def transcribe_with_fallback(audio_path, language):
235
- """Transcribe using fallback models"""
236
  try:
237
- components = load_fallback_model(language)
238
  processor = components["processor"]
239
  model = components["model"]
240
-
241
- # Preprocess audio
242
  audio, sr = preprocess_audio(audio_path)
243
  if audio is None:
244
  return "Error: Audio too short or could not be processed"
245
-
246
- # Prepare inputs
247
  inputs = processor(
248
- audio,
249
- sampling_rate=sr,
250
  return_tensors="pt",
251
  padding=True
252
  )
253
-
254
- # Move to device
255
  input_features = inputs.input_features.to(DEVICE)
256
-
257
- # Generate transcription
258
  with torch.no_grad():
259
- generate_kwargs = {
260
- "input_features": input_features,
261
  "max_length": 200,
262
  "num_beams": 3,
263
  "do_sample": False
264
  }
265
-
266
- # Language forcing for non-English
267
- if language != "English":
268
- lang_code = LANG_CODES.get(language, "en")
269
- try:
270
- if hasattr(processor, 'get_decoder_prompt_ids'):
271
- forced_decoder_ids = processor.get_decoder_prompt_ids(
272
- language=lang_code,
273
- task="transcribe"
274
- )
275
- generate_kwargs["forced_decoder_ids"] = forced_decoder_ids
276
- except Exception as e:
277
- print(f"⚠️ Language forcing failed: {e}")
278
-
279
- predicted_ids = model.generate(**generate_kwargs)
280
-
281
- # Decode
282
  transcription = processor.batch_decode(
283
- predicted_ids,
284
  skip_special_tokens=True,
285
  clean_up_tokenization_spaces=True
286
  )[0]
287
-
288
  return transcription.strip() or "(No transcription generated)"
289
-
290
  except Exception as e:
291
- print(f"Fallback transcription error: {e}")
292
  return f"Error: {str(e)[:150]}..."
293
 
294
  @spaces.GPU
295
- def transcribe_audio(audio_path, language, initial_prompt="", use_fallback=False):
296
- """Main transcription function with IndicWhisper + fallback"""
297
  try:
298
- if use_fallback:
299
- print(f"🔄 Using fallback model for {language}")
300
- return transcribe_with_fallback(audio_path, language)
301
  else:
302
- print(f"🔄 Using SOTA IndicWhisper for {language}")
303
- return transcribe_with_indicwhisper(audio_path, language)
304
-
305
  except Exception as e:
306
- print(f"Transcription failed, trying fallback: {e}")
307
- if not use_fallback:
308
- # Retry with fallback
309
- return transcribe_audio(audio_path, language, initial_prompt, use_fallback=True)
310
  else:
311
  return f"Error: All transcription methods failed - {str(e)[:100]}"
312
 
313
  def highlight_differences(ref, hyp):
314
- """Highlight word-level differences with better styling"""
315
  if not ref.strip() or not hyp.strip():
316
  return "No text to compare"
317
-
318
  ref_words = ref.strip().split()
319
  hyp_words = hyp.strip().split()
320
-
321
  sm = difflib.SequenceMatcher(None, ref_words, hyp_words)
322
  out_html = []
323
-
324
  for tag, i1, i2, j1, j2 in sm.get_opcodes():
325
  if tag == 'equal':
326
  out_html.extend([f"<span style='color:green; font-weight:bold; background-color:#e8f5e8; padding:2px 4px; margin:1px; border-radius:3px;'>{w}</span>" for w in ref_words[i1:i2]])
@@ -331,17 +329,15 @@ def highlight_differences(ref, hyp):
331
  out_html.extend([f"<span style='color:red; text-decoration:line-through; background-color:#ffe8e8; padding:2px 4px; margin:1px; border-radius:3px;'>{w}</span>" for w in ref_words[i1:i2]])
332
  elif tag == 'insert':
333
  out_html.extend([f"<span style='color:orange; font-weight:bold; background-color:#fff3cd; padding:2px 4px; margin:1px; border-radius:3px;'>+{w}</span>" for w in hyp_words[j1:j2]])
334
-
335
  return " ".join(out_html)
336
 
337
  def char_level_highlight(ref, hyp):
338
- """Highlight character-level differences"""
339
  if not ref.strip() or not hyp.strip():
340
  return "No text to compare"
341
-
342
  sm = difflib.SequenceMatcher(None, list(ref), list(hyp))
343
  out = []
344
-
345
  for tag, i1, i2, j1, j2 in sm.get_opcodes():
346
  if tag == 'equal':
347
  out.extend([f"<span style='color:green; background-color:#e8f5e8;'>{c}</span>" for c in ref[i1:i2]])
@@ -349,14 +345,10 @@ def char_level_highlight(ref, hyp):
349
  out.extend([f"<span style='color:red; text-decoration:underline; background-color:#ffe8e8; font-weight:bold;'>{c}</span>" for c in ref[i1:i2]])
350
  elif tag == 'insert':
351
  out.extend([f"<span style='color:orange; background-color:#fff3cd; font-weight:bold;'>{c}</span>" for c in hyp[j1:j2]])
352
-
353
  return "".join(out)
354
 
355
  def get_pronunciation_score(wer_val, cer_val):
356
- """Calculate pronunciation score and feedback"""
357
- # Weight WER more heavily than CER
358
  combined_score = (wer_val * 0.7) + (cer_val * 0.3)
359
-
360
  if combined_score <= 0.1:
361
  return "🏆 Excellent! (90%+)", "Your pronunciation is outstanding!"
362
  elif combined_score <= 0.2:
@@ -368,83 +360,75 @@ def get_pronunciation_score(wer_val, cer_val):
368
  else:
369
  return "💪 Keep Trying! (<40%)", "Don't give up! Practice makes perfect."
370
 
371
- # ---------------- MAIN FUNCTION ---------------- #
372
  @spaces.GPU
373
  def compare_pronunciation(audio, language_choice, intended_sentence):
374
- """Main function to compare pronunciation using SOTA IndicWhisper"""
375
- print(f"🔍 Starting SOTA analysis with language: {language_choice}")
376
  print(f"📝 Audio file: {audio}")
377
  print(f"🎯 Intended sentence: {intended_sentence}")
378
-
379
  if audio is None:
380
  print("❌ No audio provided")
381
  return ("❌ Please record audio first.", "", "", "", "", "", "", "")
382
-
383
  if not intended_sentence.strip():
384
  print("❌ No intended sentence")
385
  return ("❌ Please generate a practice sentence first.", "", "", "", "", "", "", "")
386
-
387
  try:
388
- print(f"🔍 Analyzing pronunciation using SOTA IndicWhisper...")
389
-
390
- # Pass 1: SOTA IndicWhisper transcription
391
- print("🔄 Starting Pass 1: SOTA IndicWhisper transcription...")
392
- actual_text = transcribe_audio(audio, language_choice, use_fallback=False)
393
- print(f"✅ SOTA Pass 1 result: {actual_text}")
394
-
395
- # Pass 2: Fallback model for comparison
396
- print("🔄 Starting Pass 2: Fallback model transcription...")
397
- fallback_text = transcribe_audio(audio, language_choice, use_fallback=True)
398
- print(f"✅ Fallback Pass 2 result: {fallback_text}")
399
-
400
- # Handle transcription errors
401
- if actual_text.startswith("Error:"):
402
  print(f"❌ Transcription error: {actual_text}")
403
  return (f"❌ {actual_text}", "", "", "", "", "", "", "")
404
-
405
- # Calculate error metrics using the better transcription
 
 
 
406
  try:
407
  print("🔄 Calculating error metrics...")
408
- wer_val = jiwer.wer(intended_sentence, actual_text)
409
- cer_val = jiwer.cer(intended_sentence, actual_text)
410
  print(f"✅ WER: {wer_val:.3f}, CER: {cer_val:.3f}")
411
  except Exception as e:
412
  print(f"❌ Error calculating metrics: {e}")
413
  wer_val, cer_val = 1.0, 1.0
414
-
415
- # Get pronunciation score and feedback
416
  score_text, feedback = get_pronunciation_score(wer_val, cer_val)
417
- print(f"✅ Score: {score_text}")
418
-
419
- # Transliterations
420
  print("🔄 Generating transliterations...")
421
  actual_hk = transliterate_to_hk(actual_text, language_choice)
422
  target_hk = transliterate_to_hk(intended_sentence, language_choice)
423
-
424
- # Handle script mismatches
425
  if not is_script(actual_text, language_choice) and language_choice != "English":
426
  actual_hk = f"⚠️ Expected {language_choice} script, got mixed/other script"
427
-
428
- # Visual feedback
429
  print("🔄 Generating visual feedback...")
430
  diff_html = highlight_differences(intended_sentence, actual_text)
431
  char_html = char_level_highlight(intended_sentence, actual_text)
432
-
433
- # Status message with SOTA info
434
- status = f"✅ SOTA Analysis Complete - {score_text}\n💬 {feedback}\n🚀 Powered by IndicWhisper (AI4Bharat SOTA)"
435
- print(f"✅ SOTA analysis completed successfully")
436
-
437
  return (
438
  status,
439
- actual_text or "(No transcription)",
440
- fallback_text or "(No fallback transcription)",
441
  f"{wer_val:.3f} ({(1-wer_val)*100:.1f}% word accuracy)",
442
  f"{cer_val:.3f} ({(1-cer_val)*100:.1f}% character accuracy)",
443
  diff_html,
444
  char_html,
445
  f"🎯 Target: {intended_sentence}"
446
  )
447
-
448
  except Exception as e:
449
  error_msg = f"❌ Analysis Error: {str(e)[:200]}"
450
  print(f"❌ FATAL ERROR: {e}")
@@ -452,175 +436,159 @@ def compare_pronunciation(audio, language_choice, intended_sentence):
452
  traceback.print_exc()
453
  return (error_msg, str(e), "", "", "", "", "", "")
454
 
455
- # ---------------- UI ---------------- #
456
  def create_interface():
457
  with gr.Blocks(title="🎙️ SOTA Multilingual Pronunciation Trainer") as demo:
458
-
459
  gr.Markdown("""
460
- # 🎙️ SOTA Multilingual Pronunciation Trainer
461
-
462
- **Practice pronunciation in Tamil, Malayalam & English** using **IndicWhisper - the State-of-the-Art ASR model**!
463
-
464
- ### 🏆 **Powered by IndicWhisper:**
465
- - **SOTA Performance:** Lowest WER on 39/59 benchmarks for Indian languages
466
- - **JAX-Optimized:** 70x faster than standard implementations
467
- - **AI4Bharat Research:** Built by IIT Madras for maximum accuracy
468
-
469
  ### 📋 How to Use:
470
- 1. **Select** your target language 🌍
471
- 2. **Generate** a practice sentence 🎲
472
- 3. **Record** yourself reading it aloud 🎤
473
- 4. **Get** detailed feedback with SOTA-level accuracy 📊
474
-
475
  ### 🎯 Features:
476
- - **SOTA + Fallback analysis** for comprehensive assessment
477
- - **Visual highlighting** of pronunciation errors
478
- - **Romanization** for Indic scripts
479
- - **Advanced metrics** (Word & Character accuracy)
480
  """)
481
-
482
  with gr.Row():
483
  with gr.Column(scale=3):
484
  lang_choice = gr.Dropdown(
485
- choices=list(LANG_CODES.keys()),
486
  value="Tamil",
487
  label="🌍 Select Language"
488
  )
489
  with gr.Column(scale=1):
490
  gen_btn = gr.Button("🎲 Generate Sentence", variant="primary")
491
-
492
  intended_display = gr.Textbox(
493
  label="📝 Practice Sentence (Read this aloud)",
494
  placeholder="Click 'Generate Sentence' to get started...",
495
  interactive=False,
496
  lines=3
497
  )
498
-
499
  audio_input = gr.Audio(
500
- sources=["microphone", "upload"],
501
  type="filepath",
502
  label="🎤 Record Your Pronunciation"
503
  )
504
-
505
- analyze_btn = gr.Button("🔍 Analyze with SOTA IndicWhisper", variant="primary")
506
-
507
  status_output = gr.Textbox(
508
- label="📊 SOTA Analysis Results",
509
  interactive=False,
510
  lines=4
511
  )
512
-
513
  with gr.Row():
514
  with gr.Column():
515
  pass1_out = gr.Textbox(
516
- label="🏆 SOTA IndicWhisper Output",
517
  interactive=False,
518
  lines=2
519
  )
520
  wer_out = gr.Textbox(
521
- label="📈 Word Accuracy",
522
  interactive=False
523
  )
524
-
525
  with gr.Column():
526
  pass2_out = gr.Textbox(
527
- label="🔧 Fallback Model Comparison",
528
  interactive=False,
529
  lines=2
530
  )
 
531
  cer_out = gr.Textbox(
532
- label="📊 Character Accuracy",
533
  interactive=False
534
  )
535
-
536
  with gr.Accordion("📝 Detailed Visual Feedback", open=True):
537
  gr.Markdown("""
538
  ### 🎨 Color Guide:
539
- - 🟢 **Green**: Correctly pronounced words/characters
540
- - 🔴 **Red**: Missing or mispronounced (strikethrough)
541
- - 🟠 **Orange**: Extra words or substitutions
542
  """)
543
-
544
- diff_html_box = gr.HTML(
545
- label="🔍 Word-Level Analysis",
546
- show_label=True
547
- )
548
- char_html_box = gr.HTML(
549
- label="🔤 Character-Level Analysis",
550
- show_label=True
551
- )
552
-
553
  target_display = gr.Textbox(
554
  label="🎯 Reference Text",
555
  interactive=False,
556
  visible=False
557
  )
558
-
559
- # Event handlers for buttons
560
  gen_btn.click(
561
  fn=get_random_sentence,
562
  inputs=[lang_choice],
563
  outputs=[intended_display]
564
  )
565
-
566
  analyze_btn.click(
567
  fn=compare_pronunciation,
568
  inputs=[audio_input, lang_choice, intended_display],
569
  outputs=[
570
- status_output, # status
571
- pass1_out, # SOTA IndicWhisper
572
- pass2_out, # fallback comparison
573
- wer_out, # wer formatted
574
- cer_out, # cer formatted
575
- diff_html_box, # diff_html
576
- char_html_box, # char_html
577
- target_display # target_display
578
  ]
579
  )
580
-
581
- # Auto-generate sentence on language change
582
  lang_choice.change(
583
  fn=get_random_sentence,
584
  inputs=[lang_choice],
585
  outputs=[intended_display]
586
  )
587
-
588
- # Footer
589
  gr.Markdown("""
590
  ---
591
- ### 🏆 **SOTA Technology Stack:**
592
- - **Primary ASR**: IndicWhisper (AI4Bharat/IIT Madras) - SOTA for Indian languages
593
- - **JAX Optimization**: 70x speed improvement with `parthiv11/indic_whisper_nodcil`
594
- - **Fallback Models**: Specialized fine-tuned models for comparison
595
- - **Benchmark Performance**: Lowest WER on 39/59 Vistaar benchmarks
596
- - **Training Data**: 10,700+ hours across 12 Indian languages
597
-
598
- ### 🔧 **Technical Details:**
599
- - **Metrics**: WER (Word Error Rate) and CER (Character Error Rate)
600
- - **Transliteration**: Harvard-Kyoto system for Indic scripts
601
- - **Analysis**: SOTA + Fallback comparison for comprehensive feedback
602
- - **Languages**: English, Tamil, and Malayalam with SOTA accuracy
603
-
604
- **Note**: Using the most advanced ASR models available for Indian language pronunciation assessment.
605
- **Research**: Based on "Vistaar: Diverse Benchmarks and Training Sets for Indian Language ASR" (AI4Bharat, 2023)
606
  """)
607
-
608
  return demo
609
 
610
- # ---------------- LAUNCH ---------------- #
611
  if __name__ == "__main__":
612
- print("🚀 Starting SOTA Multilingual Pronunciation Trainer...")
613
- print(f"🔧 Device: {DEVICE}")
614
  print(f"🔧 PyTorch version: {torch.__version__}")
615
- print("🏆 Using IndicWhisper - State-of-the-Art for Indian Languages")
616
- print("⚡ JAX optimization: 70x speed improvement available")
617
- print("📊 SOTA Performance: Lowest WER on 39/59 benchmarks")
618
  print("🎮 GPU functions decorated with @spaces.GPU for HuggingFace Spaces")
619
-
620
  demo = create_interface()
621
  demo.launch(
622
  share=True,
623
  show_error=True,
624
  server_name="0.0.0.0",
625
  server_port=7860
626
- )
 
10
  import soundfile as sf
11
  from indic_transliteration import sanscript
12
  from indic_transliteration.sanscript import transliterate
13
+ import unicodedata
14
  import warnings
15
  import spaces
 
16
 
17
+ warnings.filterwarnings("ignore")
18
 
19
  # ---------------- CONFIG ---------------- #
20
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
21
  print(f"🔧 Using device: {DEVICE}")
22
+ DEVICE_INDEX = 0 if DEVICE == "cuda" else -1
23
 
24
  LANG_CODES = {
25
  "English": "en",
26
+ "Tamil": "ta",
27
  "Malayalam": "ml"
28
  }
29
 
30
+ INDICWHISPER_MODEL = "openai/whisper-large-v2"
 
31
 
32
+ SPECIALIZED_MODELS = {
 
33
  "English": "openai/whisper-base.en",
34
  "Tamil": "vasista22/whisper-tamil-large-v2",
35
  "Malayalam": "thennal/whisper-medium-ml"
 
46
 
47
  SCRIPT_PATTERNS = {
48
  "Tamil": re.compile(r"[஀-௿]"),
49
+ "Malayalam": re.compile(r"[ഀ-ൿ]"),
50
  "English": re.compile(r"[A-Za-z]")
51
  }
52
 
 
63
  ],
64
  "Tamil": [
65
  "இன்று நல்ல வானிலை உள்ளது.",
66
+ "நான் தமிழ் கற்றுக்கொண்டு இருக்கிறேன்.",
67
  "எனக்கு புத்தகம் படிக்க விருப்பம்.",
68
  "தமிழ் மொழி மிகவும் அழகானது.",
69
  "குடும்பத்துடன் நேரம் செலவிடுவது முக்கியம்.",
 
83
  ]
84
  }
85
 
86
+ # Controls for stricter script checking and normalization
87
+ STRICT_SCRIPT_CHECK = False # set True for strict script-only validation
88
+ NORMALIZE_TEXT_FOR_METRICS = True
89
+
90
  # ---------------- MODEL CACHE ---------------- #
91
  indicwhisper_pipeline = None
92
  fallback_models = {}
93
+ WHISPER_JAX_AVAILABLE = False
94
 
95
+ def normalize_text(s: str) -> str:
96
+ if not NORMALIZE_TEXT_FOR_METRICS:
97
+ return s
98
+ # Normalize unicode and collapse whitespace; do not remove language-specific punctuation
99
+ s = unicodedata.normalize("NFC", s)
100
+ s = re.sub(r"\s+", " ", s).strip()
101
+ return s
102
 
 
103
  def get_random_sentence(language_choice):
 
104
  return random.choice(SENTENCE_BANK[language_choice])
105
 
106
  def is_script(text, lang_name):
 
107
  pattern = SCRIPT_PATTERNS.get(lang_name)
108
  if not pattern:
109
  return True
110
+ if not STRICT_SCRIPT_CHECK:
111
+ # any occurrence of script chars counts as match
112
+ return bool(pattern.search(text))
113
+ # strict: allow only spaces and target script chars
114
+ for ch in text:
115
+ if ch.isspace():
116
+ continue
117
+ if not pattern.match(ch):
118
+ return False
119
+ return True
120
 
121
  def transliterate_to_hk(text, lang_choice):
 
122
  mapping = {
123
  "Tamil": sanscript.TAMIL,
124
  "Malayalam": sanscript.MALAYALAM,
125
  "English": None
126
  }
 
127
  script = mapping.get(lang_choice)
128
  if script and is_script(text, lang_choice):
129
  try:
 
134
  return text
135
 
136
  def preprocess_audio(audio_path, target_sr=16000):
 
137
  try:
 
138
  audio, sr = librosa.load(audio_path, sr=target_sr)
 
 
139
  if np.max(np.abs(audio)) > 0:
140
  audio = audio / np.max(np.abs(audio))
 
 
141
  audio, _ = librosa.effects.trim(audio, top_db=20)
142
+ if len(audio) < target_sr * 0.1:
 
 
143
  return None, None
 
144
  return audio, target_sr
145
  except Exception as e:
146
  print(f"Audio preprocessing error: {e}")
147
  return None, None
148
 
149
  @spaces.GPU
150
+ def load_indicwhisper():
151
+ global indicwhisper_pipeline, WHISPER_JAX_AVAILABLE
152
+ if indicwhisper_pipeline is None:
153
+ try:
154
+ # Try JAX pipeline
155
+ try:
156
+ from whisper_jax import FlaxWhisperPipeline
157
+ import jax.numpy as jnp
158
+ print(f"🔄 Loading JAX-optimized model: {INDICWHISPER_MODEL}")
159
+ indicwhisper_pipeline = FlaxWhisperPipeline(
160
+ INDICWHISPER_MODEL,
161
+ dtype=jnp.bfloat16,
162
+ batch_size=1
163
+ )
164
+ WHISPER_JAX_AVAILABLE = True
165
+ print("✅ JAX-optimized model loaded successfully!")
166
+ return indicwhisper_pipeline
167
+ except Exception as e:
168
+ print(f"⚠️ JAX loading failed: {e}")
169
+ WHISPER_JAX_AVAILABLE = False
170
+
171
+ # Fallback to transformers pipeline
172
+ print(f"🔄 Loading transformers pipeline: {INDICWHISPER_MODEL}")
173
+ from transformers import pipeline
174
+ indicwhisper_pipeline = pipeline(
175
+ "automatic-speech-recognition",
176
+ model=INDICWHISPER_MODEL,
177
+ device=DEVICE_INDEX,
178
+ torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32
179
+ )
180
+ print("✅ High-performance model loaded with transformers!")
181
+ except Exception as e:
182
+ print(f"❌ Failed to load primary model: {e}")
183
+ indicwhisper_pipeline = None
184
+ raise Exception(f"Could not load high-performance model: {str(e)}")
185
+ return indicwhisper_pipeline
186
+
187
+ @spaces.GPU
188
+ def load_specialized_model(language):
189
+ if language not in fallback_models:
190
+ model_name = SPECIALIZED_MODELS[language]
191
+ print(f"🔄 Loading specialized model for {language}: {model_name}")
192
+ try:
193
+ processor = AutoProcessor.from_pretrained(model_name)
194
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
195
+ model_name,
196
+ torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
197
+ low_cpu_mem_usage=True,
198
+ use_safetensors=True
199
+ ).to(DEVICE)
200
+ model.eval()
201
+ fallback_models[language] = {"processor": processor, "model": model, "model_name": model_name}
202
+ print(f"✅ Specialized model loaded for {language}")
203
+ except Exception as e:
204
+ print(f"❌ Failed to load specialized {model_name}: {e}")
205
+ raise Exception(f"Could not load specialized {language} model")
206
+ return fallback_models[language]
207
+
208
+ @spaces.GPU
209
+ def transcribe_with_primary_model(audio_path, language):
210
  try:
211
+ pipe = load_indicwhisper()
212
+
213
+ if callable(pipe):
214
+ # Try to set forced decoder ids when available
215
+ if language != "English":
216
+ lang_code = LANG_CODES.get(language, "en")
217
+ try:
218
+ if hasattr(pipe, "model") and hasattr(pipe, "tokenizer"):
219
+ if hasattr(pipe.model, "config"):
220
+ forced_ids = pipe.tokenizer.get_decoder_prompt_ids(
221
+ language=lang_code, task="transcribe"
222
+ )
223
+ pipe.model.config.forced_decoder_ids = forced_ids
224
+ except Exception as e:
225
+ print(f"⚠️ Language forcing failed: {e}")
226
+
227
+ result = pipe(audio_path)
228
+ if isinstance(result, dict) and "text" in result:
229
+ return result["text"].strip()
230
  elif isinstance(result, str):
231
  return result.strip()
232
  else:
233
  return str(result).strip()
234
  else:
235
+ return "Error: Pipeline not properly initialized"
 
 
 
236
  except Exception as e:
237
+ print(f"Primary model transcription error: {e}")
238
  raise e
239
 
240
  @spaces.GPU
241
+ def transcribe_with_specialized_model(audio_path, language):
 
242
  try:
243
+ components = load_specialized_model(language)
244
  processor = components["processor"]
245
  model = components["model"]
246
+
 
247
  audio, sr = preprocess_audio(audio_path)
248
  if audio is None:
249
  return "Error: Audio too short or could not be processed"
250
+
 
251
  inputs = processor(
252
+ audio,
253
+ sampling_rate=sr,
254
  return_tensors="pt",
255
  padding=True
256
  )
 
 
257
  input_features = inputs.input_features.to(DEVICE)
258
+
259
+ forced_decoder_ids = None
260
+ if language != "English":
261
+ lang_code = LANG_CODES.get(language, "en")
262
+ try:
263
+ if hasattr(processor, "get_decoder_prompt_ids"):
264
+ forced_decoder_ids = processor.get_decoder_prompt_ids(
265
+ language=lang_code,
266
+ task="transcribe"
267
+ )
268
+ except Exception as e:
269
+ print(f"⚠️ Language forcing failed: {e}")
270
+
271
  with torch.no_grad():
272
+ gen_kwargs = {
 
273
  "max_length": 200,
274
  "num_beams": 3,
275
  "do_sample": False
276
  }
277
+ if forced_decoder_ids:
278
+ gen_kwargs["forced_decoder_ids"] = forced_decoder_ids
279
+
280
+ predicted_ids = model.generate(
281
+ input_features,
282
+ **gen_kwargs
283
+ )
284
+
 
 
 
 
 
 
 
 
 
285
  transcription = processor.batch_decode(
286
+ predicted_ids,
287
  skip_special_tokens=True,
288
  clean_up_tokenization_spaces=True
289
  )[0]
290
+
291
  return transcription.strip() or "(No transcription generated)"
 
292
  except Exception as e:
293
+ print(f"Specialized model transcription error: {e}")
294
  return f"Error: {str(e)[:150]}..."
295
 
296
  @spaces.GPU
297
+ def transcribe_audio(audio_path, language, initial_prompt="", use_specialized=False):
 
298
  try:
299
+ if use_specialized:
300
+ print(f"🔄 Using specialized model for {language}")
301
+ return transcribe_with_specialized_model(audio_path, language)
302
  else:
303
+ print(f"🔄 Using high-performance primary model for {language}")
304
+ return transcribe_with_primary_model(audio_path, language)
 
305
  except Exception as e:
306
+ print(f"Transcription failed, trying specialized model: {e}")
307
+ if not use_specialized:
308
+ return transcribe_audio(audio_path, language, initial_prompt, use_specialized=True)
 
309
  else:
310
  return f"Error: All transcription methods failed - {str(e)[:100]}"
311
 
312
  def highlight_differences(ref, hyp):
 
313
  if not ref.strip() or not hyp.strip():
314
  return "No text to compare"
315
+
316
  ref_words = ref.strip().split()
317
  hyp_words = hyp.strip().split()
318
+
319
  sm = difflib.SequenceMatcher(None, ref_words, hyp_words)
320
  out_html = []
321
+
322
  for tag, i1, i2, j1, j2 in sm.get_opcodes():
323
  if tag == 'equal':
324
  out_html.extend([f"<span style='color:green; font-weight:bold; background-color:#e8f5e8; padding:2px 4px; margin:1px; border-radius:3px;'>{w}</span>" for w in ref_words[i1:i2]])
 
329
  out_html.extend([f"<span style='color:red; text-decoration:line-through; background-color:#ffe8e8; padding:2px 4px; margin:1px; border-radius:3px;'>{w}</span>" for w in ref_words[i1:i2]])
330
  elif tag == 'insert':
331
  out_html.extend([f"<span style='color:orange; font-weight:bold; background-color:#fff3cd; padding:2px 4px; margin:1px; border-radius:3px;'>+{w}</span>" for w in hyp_words[j1:j2]])
332
+
333
  return " ".join(out_html)
334
 
335
  def char_level_highlight(ref, hyp):
 
336
  if not ref.strip() or not hyp.strip():
337
  return "No text to compare"
338
+
339
  sm = difflib.SequenceMatcher(None, list(ref), list(hyp))
340
  out = []
 
341
  for tag, i1, i2, j1, j2 in sm.get_opcodes():
342
  if tag == 'equal':
343
  out.extend([f"<span style='color:green; background-color:#e8f5e8;'>{c}</span>" for c in ref[i1:i2]])
 
345
  out.extend([f"<span style='color:red; text-decoration:underline; background-color:#ffe8e8; font-weight:bold;'>{c}</span>" for c in ref[i1:i2]])
346
  elif tag == 'insert':
347
  out.extend([f"<span style='color:orange; background-color:#fff3cd; font-weight:bold;'>{c}</span>" for c in hyp[j1:j2]])
 
348
  return "".join(out)
349
 
350
  def get_pronunciation_score(wer_val, cer_val):
 
 
351
  combined_score = (wer_val * 0.7) + (cer_val * 0.3)
 
352
  if combined_score <= 0.1:
353
  return "🏆 Excellent! (90%+)", "Your pronunciation is outstanding!"
354
  elif combined_score <= 0.2:
 
360
  else:
361
  return "💪 Keep Trying! (<40%)", "Don't give up! Practice makes perfect."
362
 
 
363
  @spaces.GPU
364
  def compare_pronunciation(audio, language_choice, intended_sentence):
365
+ print(f"🔍 Starting advanced analysis with language: {language_choice}")
 
366
  print(f"📝 Audio file: {audio}")
367
  print(f"🎯 Intended sentence: {intended_sentence}")
368
+
369
  if audio is None:
370
  print("❌ No audio provided")
371
  return ("❌ Please record audio first.", "", "", "", "", "", "", "")
372
+
373
  if not intended_sentence.strip():
374
  print("❌ No intended sentence")
375
  return ("❌ Please generate a practice sentence first.", "", "", "", "", "", "", "")
376
+
377
  try:
378
+ print(f"🔄 Starting Pass 1: High-performance model transcription...")
379
+ primary_text = transcribe_audio(audio, language_choice, use_specialized=False)
380
+ print(f"✅ Primary model result: {primary_text}")
381
+
382
+ print("🔄 Starting Pass 2: Specialized model transcription...")
383
+ specialized_text = transcribe_audio(audio, language_choice, use_specialized=True)
384
+ print(f"✅ Specialized model result: {specialized_text}")
385
+
386
+ actual_text = primary_text if not str(primary_text).startswith("Error:") else specialized_text
387
+
388
+ if str(actual_text).startswith("Error:"):
 
 
 
389
  print(f"❌ Transcription error: {actual_text}")
390
  return (f"❌ {actual_text}", "", "", "", "", "", "", "")
391
+
392
+ # Normalize for metrics if enabled
393
+ ref_for_metrics = normalize_text(intended_sentence)
394
+ hyp_for_metrics = normalize_text(actual_text)
395
+
396
  try:
397
  print("🔄 Calculating error metrics...")
398
+ wer_val = jiwer.wer(ref_for_metrics, hyp_for_metrics)
399
+ cer_val = jiwer.cer(ref_for_metrics, hyp_for_metrics)
400
  print(f"✅ WER: {wer_val:.3f}, CER: {cer_val:.3f}")
401
  except Exception as e:
402
  print(f"❌ Error calculating metrics: {e}")
403
  wer_val, cer_val = 1.0, 1.0
404
+
 
405
  score_text, feedback = get_pronunciation_score(wer_val, cer_val)
406
+
 
 
407
  print("🔄 Generating transliterations...")
408
  actual_hk = transliterate_to_hk(actual_text, language_choice)
409
  target_hk = transliterate_to_hk(intended_sentence, language_choice)
410
+
 
411
  if not is_script(actual_text, language_choice) and language_choice != "English":
412
  actual_hk = f"⚠️ Expected {language_choice} script, got mixed/other script"
413
+
 
414
  print("🔄 Generating visual feedback...")
415
  diff_html = highlight_differences(intended_sentence, actual_text)
416
  char_html = char_level_highlight(intended_sentence, actual_text)
417
+
418
+ status = f"✅ Advanced Analysis Complete - {score_text}\n💬 {feedback}\n🚀 Powered by High-Performance ASR Models"
419
+ print(f"✅ Advanced analysis completed successfully")
420
+
 
421
  return (
422
  status,
423
+ primary_text or "(No primary transcription)",
424
+ specialized_text or "(No specialized transcription)",
425
  f"{wer_val:.3f} ({(1-wer_val)*100:.1f}% word accuracy)",
426
  f"{cer_val:.3f} ({(1-cer_val)*100:.1f}% character accuracy)",
427
  diff_html,
428
  char_html,
429
  f"🎯 Target: {intended_sentence}"
430
  )
431
+
432
  except Exception as e:
433
  error_msg = f"❌ Analysis Error: {str(e)[:200]}"
434
  print(f"❌ FATAL ERROR: {e}")
 
436
  traceback.print_exc()
437
  return (error_msg, str(e), "", "", "", "", "", "")
438
 
 
439
  def create_interface():
440
  with gr.Blocks(title="🎙️ SOTA Multilingual Pronunciation Trainer") as demo:
441
+
442
  gr.Markdown("""
443
+ # 🎙️ Advanced Multilingual Pronunciation Trainer
444
+
445
+ Practice pronunciation in Tamil, Malayalam & English using high-performance ASR models!
446
+
447
+ ### 🏆 Powered by Advanced Models:
448
+ - Dual-Model Analysis: Primary + specialized model comparison
449
+ - High Accuracy: Language-specific fine-tuned models
450
+ - Robust Performance: Automatic fallback for reliability
451
+
452
  ### 📋 How to Use:
453
+ 1. Select your target language 🌍
454
+ 2. Generate a practice sentence 🎲
455
+ 3. Record yourself reading it aloud 🎤
456
+ 4. Get detailed feedback with advanced accuracy 📊
457
+
458
  ### 🎯 Features:
459
+ - Dual-pass analysis for comprehensive assessment
460
+ - Visual highlighting of pronunciation errors
461
+ - Romanization for Indic scripts
462
+ - Advanced metrics (Word & Character accuracy)
463
  """)
464
+
465
  with gr.Row():
466
  with gr.Column(scale=3):
467
  lang_choice = gr.Dropdown(
468
+ choices=list(LANG_CODES.keys()),
469
  value="Tamil",
470
  label="🌍 Select Language"
471
  )
472
  with gr.Column(scale=1):
473
  gen_btn = gr.Button("🎲 Generate Sentence", variant="primary")
474
+
475
  intended_display = gr.Textbox(
476
  label="📝 Practice Sentence (Read this aloud)",
477
  placeholder="Click 'Generate Sentence' to get started...",
478
  interactive=False,
479
  lines=3
480
  )
481
+
482
  audio_input = gr.Audio(
483
+ sources=["microphone", "upload"],
484
  type="filepath",
485
  label="🎤 Record Your Pronunciation"
486
  )
487
+
488
+ analyze_btn = gr.Button("🔍 Analyze with Advanced Models", variant="primary")
489
+
490
  status_output = gr.Textbox(
491
+ label="📊 Advanced Analysis Results",
492
  interactive=False,
493
  lines=4
494
  )
495
+
496
  with gr.Row():
497
  with gr.Column():
498
  pass1_out = gr.Textbox(
499
+ label="🏆 Primary Model Output",
500
  interactive=False,
501
  lines=2
502
  )
503
  wer_out = gr.Textbox(
504
+ label="📈 Word Accuracy",
505
  interactive=False
506
  )
 
507
  with gr.Column():
508
  pass2_out = gr.Textbox(
509
+ label="🔧 Specialized Model Comparison",
510
  interactive=False,
511
  lines=2
512
  )
513
+
514
  cer_out = gr.Textbox(
515
+ label="📊 Character Accuracy",
516
  interactive=False
517
  )
518
+
519
  with gr.Accordion("📝 Detailed Visual Feedback", open=True):
520
  gr.Markdown("""
521
  ### 🎨 Color Guide:
522
+ - 🟢 Green: Correctly pronounced words/characters
523
+ - 🔴 Red: Missing or mispronounced (strikethrough)
524
+ - 🟠 Orange: Extra words or substitutions
525
  """)
526
+ diff_html_box = gr.HTML(label="🔍 Word-Level Analysis", show_label=True)
527
+ char_html_box = gr.HTML(label="🔤 Character-Level Analysis", show_label=True)
528
+
 
 
 
 
 
 
 
529
  target_display = gr.Textbox(
530
  label="🎯 Reference Text",
531
  interactive=False,
532
  visible=False
533
  )
534
+
 
535
  gen_btn.click(
536
  fn=get_random_sentence,
537
  inputs=[lang_choice],
538
  outputs=[intended_display]
539
  )
540
+
541
  analyze_btn.click(
542
  fn=compare_pronunciation,
543
  inputs=[audio_input, lang_choice, intended_display],
544
  outputs=[
545
+ status_output,
546
+ pass1_out,
547
+ pass2_out,
548
+ wer_out,
549
+ cer_out,
550
+ diff_html_box,
551
+ char_html_box,
552
+ target_display
553
  ]
554
  )
555
+
 
556
  lang_choice.change(
557
  fn=get_random_sentence,
558
  inputs=[lang_choice],
559
  outputs=[intended_display]
560
  )
561
+
 
562
  gr.Markdown("""
563
  ---
564
+ ### 🏆 Advanced Technology Stack:
565
+ - Primary ASR: OpenAI Whisper Large v2 (High-performance multilingual model)
566
+ - Specialized Models:
567
+ - Tamil: vasista22/whisper-tamil-large-v2
568
+ - Malayalam: thennal/whisper-medium-ml
569
+ - English: OpenAI Whisper Base EN
570
+ - Dual Analysis and Automatic Fallback
571
+
572
+ ### 🔧 Technical Details:
573
+ - Metrics: WER and CER
574
+ - Transliteration: Harvard-Kyoto for Indic scripts
575
+ - Languages: English, Tamil, Malayalam
 
 
 
576
  """)
 
577
  return demo
578
 
 
579
  if __name__ == "__main__":
580
+ print("🚀 Starting Advanced Multilingual Pronunciation Trainer...")
581
+ print(f"🔧 Device: {DEVICE} (index={DEVICE_INDEX})")
582
  print(f"🔧 PyTorch version: {torch.__version__}")
583
+ print("🏆 Using High-Performance Dual-Model Approach")
584
+ print("⚡ Automatic model selection with specialized fallbacks")
585
+ print("📊 Advanced analysis with robust error handling")
586
  print("🎮 GPU functions decorated with @spaces.GPU for HuggingFace Spaces")
587
+
588
  demo = create_interface()
589
  demo.launch(
590
  share=True,
591
  show_error=True,
592
  server_name="0.0.0.0",
593
  server_port=7860
594
+ )
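For reference, the scoring path introduced in this version of app.py reduces to: normalize both strings (NFC plus whitespace collapse), compute WER and CER with jiwer, then combine them with the 0.7/0.3 weighting used in get_pronunciation_score. A minimal, self-contained sketch of that path (illustrative only, not part of the commit; the hypothesis string is a made-up ASR output, and jiwer is assumed to be installed):

import re
import unicodedata

import jiwer  # same dependency app.py uses for WER/CER


def normalize_text(s: str) -> str:
    # Mirrors the normalize_text added in this commit: NFC-normalize, collapse whitespace.
    s = unicodedata.normalize("NFC", s)
    return re.sub(r"\s+", " ", s).strip()


reference = "இன்று நல்ல வானிலை உள்ளது."    # sentence from the Tamil practice bank
hypothesis = "இன்று நல்ல வானிலை  உள்ளது"   # made-up ASR hypothesis, for illustration only

ref, hyp = normalize_text(reference), normalize_text(hypothesis)
wer_val = jiwer.wer(ref, hyp)  # word error rate
cer_val = jiwer.cer(ref, hyp)  # character error rate

# Same weighting as get_pronunciation_score: WER counts 70%, CER 30%.
combined = wer_val * 0.7 + cer_val * 0.3
print(f"WER={wer_val:.3f}  CER={cer_val:.3f}  combined={combined:.3f}")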