sudhanm committed on
Commit
60fa434
·
verified ·
1 Parent(s): b7a8eef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +221 -457
app.py CHANGED
@@ -2,60 +2,59 @@ import gradio as gr
2
  import random
3
  import difflib
4
  import re
5
- import jiwer
6
  import torch
7
  import numpy as np
8
- from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, WhisperProcessor
9
  import librosa
10
  import soundfile as sf
11
- from indic_transliteration import sanscript
12
- from indic_transliteration.sanscript import transliterate
13
- import warnings
14
 
15
- # Optional: only available on HF Spaces runtime
 
 
 
 
 
 
 
 
 
 
16
  try:
17
  import spaces
18
  GPU_DECORATOR = spaces.GPU
19
- except Exception:
20
- def GPU_DECORATOR(fn):
21
- return fn
 
22
 
23
  warnings.filterwarnings("ignore")
24
 
25
  # ---------------- CONFIG ---------------- #
26
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
27
- CUDA_DEVICE_INDEX = 0 if torch.cuda.is_available() else -1 # for transformers pipeline device
 
28
  print(f"🔧 Using device: {DEVICE}")
29
 
30
  LANG_CODES = {
31
  "English": "en",
32
  "Tamil": "ta",
33
- "Malayalam": "ml"
34
  }
35
 
36
- # Primary model
37
- INDICWHISPER_MODEL = "openai/whisper-large-v2"
38
 
39
- # Specialized models
40
  SPECIALIZED_MODELS = {
41
  "English": "openai/whisper-base.en",
42
  "Tamil": "vasista22/whisper-tamil-large-v2",
43
- "Malayalam": "thennal/whisper-medium-ml"
44
- }
45
-
46
- LANG_PRIMERS = {
47
- "English": ("Transcribe in English.",
48
- "Write only in English. Example: This is an English sentence."),
49
- "Tamil": ("தமிழில் எழுதுக.",
50
- "தமிழ் எழுத்துக்களில் மட்டும் எழுதவும். உதாரணம்: இது ஒரு தமிழ் வாக்கியம்."),
51
- "Malayalam": ("മലയാളത്തിൽ എഴുതുക.",
52
- "മലയാള ലിപിയിൽ മാത്രം എഴുതുക. ഉദാഹരണം: ഇതൊരു മലയാള വാക്യമാണ്.")
53
  }
54
 
55
  SCRIPT_PATTERNS = {
56
  "Tamil": re.compile(r"[஀-௿]"),
57
  "Malayalam": re.compile(r"[ഀ-ൿ]"),
58
- "English": re.compile(r"[A-Za-z]")
59
  }
60
 
61
  SENTENCE_BANK = {
@@ -67,7 +66,7 @@ SENTENCE_BANK = {
67
  "Music brings people together across cultures.",
68
  "Education is the key to a bright future.",
69
  "The flowers bloom beautifully in spring.",
70
- "Hard work always pays off in the end."
71
  ],
72
  "Tamil": [
73
  "இன்று நல்ல வானிலை உள்ளது.",
@@ -77,7 +76,7 @@ SENTENCE_BANK = {
77
  "குடும்பத்துடன் நேரம் செலவிடுவது முக்கியம்.",
78
  "கல்வி நமது எதிர்காலத்தின் திறவுகோல்.",
79
  "பறவைகள் காலையில் இனிமையாக பாடுகின்றன.",
80
- "உழைப்பு எப்போதும் வெற்றியைத் தரும்."
81
  ],
82
  "Malayalam": [
83
  "എനിക്ക് മലയാളം വളരെ ഇഷ്ടമാണ്.",
@@ -87,85 +86,14 @@ SENTENCE_BANK = {
87
  "വിദ്യാഭ്യാസം ജീവിതത്തിൽ പ്രധാനമാണ്.",
88
  "സംഗീതം മനസ്സിന് സന്തോഷം നൽകുന്നു.",
89
  "കുടുംബസമയം വളരെ വിലപ്പെട്ടതാണ്.",
90
- "കഠിനാധ്വാനം എപ്പോഴും ഫലം നൽകും."
91
- ]
92
  }
93
 
94
- # ---------------- MODEL CACHE ---------------- #
95
  indicwhisper_pipeline = None
96
  fallback_models = {}
97
- WHISPER_JAX_AVAILABLE = False # default false; will set true if we load it
98
-
99
- @GPU_DECORATOR
100
- def load_indicwhisper():
101
- """Load primary high-performance model (prefer transformers pipeline, optionally JAX if available)."""
102
- global indicwhisper_pipeline, WHISPER_JAX_AVAILABLE
103
-
104
- if indicwhisper_pipeline is not None:
105
- return indicwhisper_pipeline
106
-
107
- # Try JAX first (optional)
108
- try:
109
- from whisper_jax import FlaxWhisperPipeline
110
- import jax.numpy as jnp
111
- print(f"🔄 Loading JAX-optimized model: {INDICWHISPER_MODEL}")
112
- indicwhisper_pipeline = FlaxWhisperPipeline(
113
- INDICWHISPER_MODEL,
114
- dtype=jnp.bfloat16,
115
- batch_size=1
116
- )
117
- WHISPER_JAX_AVAILABLE = True
118
- print("✅ JAX-optimized model loaded successfully!")
119
- return indicwhisper_pipeline
120
- except Exception as e:
121
- print(f"⚠️ JAX loading failed: {e}")
122
- WHISPER_JAX_AVAILABLE = False
123
-
124
- # Fallback to transformers pipeline
125
- try:
126
- from transformers import pipeline
127
- print(f"🔄 Loading transformers ASR pipeline: {INDICWHISPER_MODEL}")
128
- indicwhisper_pipeline = pipeline(
129
- task="automatic-speech-recognition",
130
- model=INDICWHISPER_MODEL,
131
- device=CUDA_DEVICE_INDEX # 0 for CUDA, -1 for CPU
132
- )
133
- print("✅ Transformers ASR pipeline loaded!")
134
- return indicwhisper_pipeline
135
- except Exception as e:
136
- print(f"❌ Failed to load primary model: {e}")
137
- indicwhisper_pipeline = None
138
- raise Exception(f"Could not load primary model: {str(e)}")
139
-
140
- @GPU_DECORATOR
141
- def load_specialized_model(language: str):
142
- """Load language-specific specialized model with processor."""
143
- if language in fallback_models:
144
- return fallback_models[language]
145
-
146
- model_name = SPECIALIZED_MODELS[language]
147
- print(f"🔄 Loading specialized model for {language}: {model_name}")
148
-
149
- try:
150
- # WhisperProcessor ensures get_decoder_prompt_ids is available
151
- try:
152
- processor = WhisperProcessor.from_pretrained(model_name)
153
- except Exception:
154
- processor = AutoProcessor.from_pretrained(model_name)
155
-
156
- model = AutoModelForSpeechSeq2Seq.from_pretrained(
157
- model_name,
158
- torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
159
- low_cpu_mem_usage=True
160
- )
161
- model.to(DEVICE)
162
-
163
- fallback_models[language] = {"processor": processor, "model": model, "model_name": model_name}
164
- print(f"✅ Specialized model loaded for {language}")
165
- return fallback_models[language]
166
- except Exception as e:
167
- print(f"❌ Failed to load specialized {model_name}: {e}")
168
- raise Exception(f"Could not load specialized {language} model: {str(e)}")
169
 
170
  # ---------------- HELPERS ---------------- #
171
  def get_random_sentence(language_choice):
@@ -178,168 +106,61 @@ def is_script(text, lang_name):
178
  return bool(pattern.search(text or ""))
179
 
180
  def transliterate_to_hk(text, lang_choice):
 
 
181
  mapping = {
182
  "Tamil": sanscript.TAMIL,
183
  "Malayalam": sanscript.MALAYALAM,
184
  "English": None
185
  }
186
  script = mapping.get(lang_choice)
187
- if not text:
188
- return ""
189
  if script and is_script(text, lang_choice):
190
  try:
191
  return transliterate(text, script, sanscript.HK)
192
- except Exception as e:
193
- print(f"Transliteration error: {e}")
194
  return text
195
  return text
196
 
197
  def preprocess_audio(audio_path, target_sr=16000):
198
- """Load, normalize, trim, return float32 audio."""
199
  try:
200
  audio, sr = librosa.load(audio_path, sr=target_sr, mono=True)
201
  if audio is None or len(audio) == 0:
202
  return None, None
203
- # Normalize
204
- m = np.max(np.abs(audio))
205
- if m > 0:
206
- audio = audio / m
207
- # Trim silence
208
  audio, _ = librosa.effects.trim(audio, top_db=20)
209
- # Ensure min length
210
- if len(audio) < int(target_sr * 0.1):
211
  return None, None
212
- # Ensure float32
213
- if audio.dtype != np.float32:
214
- audio = audio.astype(np.float32)
215
  return audio, target_sr
216
  except Exception as e:
217
  print(f"Audio preprocessing error: {e}")
218
  return None, None
219
 
220
- @GPU_DECORATOR
221
- def transcribe_with_primary_model(audio_path, language):
222
- """Transcribe using primary model (JAX if available else transformers pipeline)."""
223
- try:
224
- pipe = load_indicwhisper()
225
-
226
- lang_code = LANG_CODES.get(language, "en")
227
-
228
- if WHISPER_JAX_AVAILABLE:
229
- # whisper-jax expects array or path; pass path is okay
230
- result = pipe(audio_path, task="transcribe", language=lang_code)
231
- # whisper-jax returns dict with 'text'
232
- if isinstance(result, dict) and "text" in result:
233
- return (result["text"] or "").strip()
234
- return str(result).strip()
235
-
236
- # transformers pipeline
237
- # Some transformers versions accept language/task via generate_kwargs
238
- generate_kwargs = {}
239
- try:
240
- # If underlying model is Whisper, we can set forced decoder ids
241
- model = pipe.model if hasattr(pipe, "model") else None
242
- tokenizer = getattr(pipe, "tokenizer", None)
243
- processor = getattr(pipe, "feature_extractor", None)
244
- if hasattr(pipe, "tokenizer") and hasattr(model, "config"):
245
- try:
246
- forced_ids = pipe.tokenizer.get_decoder_prompt_ids(language=lang_code, task="transcribe")
247
- model.config.forced_decoder_ids = forced_ids
248
- except Exception as e:
249
- print(f"⚠️ Primary model language forcing failed: {e}")
250
- except Exception as e:
251
- print(f"⚠️ Primary model prompt config error: {e}")
252
-
253
- out = pipe(audio_path, generate_kwargs=generate_kwargs)
254
- if isinstance(out, dict) and "text" in out:
255
- return (out["text"] or "").strip()
256
- elif isinstance(out, str):
257
- return out.strip()
258
- else:
259
- return str(out).strip()
260
- except Exception as e:
261
- print(f"Primary model transcription error: {e}")
262
- return f"Error: {str(e)[:200]}"
263
 
264
- @GPU_DECORATOR
265
- def transcribe_with_specialized_model(audio_path, language):
266
- """Transcribe using language-specific models."""
267
  try:
268
- components = load_specialized_model(language)
269
- processor = components["processor"]
270
- model = components["model"]
271
-
272
- audio, sr = preprocess_audio(audio_path)
273
- if audio is None:
274
- return "Error: Audio too short or could not be processed"
275
-
276
- inputs = processor(audio, sampling_rate=sr, return_tensors="pt")
277
- # WhisperProcessor returns input_features
278
- input_features = inputs.get("input_features", None)
279
- if input_features is None:
280
- # Fallback: some processors use feature_extractor path
281
- input_features = inputs.get("input_values", None)
282
- if input_features is None:
283
- return "Error: Could not prepare input features"
284
-
285
- input_features = input_features.to(DEVICE)
286
-
287
- generate_kwargs = {
288
- "max_length": 200,
289
- "num_beams": 3,
290
- "do_sample": False
291
- }
292
-
293
- if language != "English":
294
- lang_code = LANG_CODES.get(language, "en")
295
- try:
296
- if hasattr(processor, "get_decoder_prompt_ids"):
297
- forced_decoder_ids = processor.get_decoder_prompt_ids(
298
- language=lang_code,
299
- task="transcribe"
300
- )
301
- generate_kwargs["forced_decoder_ids"] = forced_decoder_ids
302
- elif hasattr(model, "config") and hasattr(processor, "tokenizer"):
303
- forced_decoder_ids = processor.tokenizer.get_decoder_prompt_ids(
304
- language=lang_code,
305
- task="transcribe"
306
- )
307
- model.config.forced_decoder_ids = forced_decoder_ids
308
- except Exception as e:
309
- print(f"⚠️ Language forcing failed: {e}")
310
 
311
- with torch.no_grad():
312
- predicted_ids = model.generate(input_features=input_features, **generate_kwargs)
313
-
314
- transcription = processor.batch_decode(
315
- predicted_ids,
316
- skip_special_tokens=True,
317
- clean_up_tokenization_spaces=True
318
- )[0]
319
- return (transcription or "").strip() or "(No transcription generated)"
320
- except Exception as e:
321
- print(f"Specialized model transcription error: {e}")
322
- return f"Error: {str(e)[:200]}"
323
-
324
- @GPU_DECORATOR
325
- def transcribe_audio(audio_path, language, initial_prompt="", use_specialized=False):
326
- """Dispatch to primary or specialized path with fallback."""
327
  try:
328
- if use_specialized:
329
- print(f"🔄 Using specialized model for {language}")
330
- return transcribe_with_specialized_model(audio_path, language)
331
- else:
332
- print(f"🔄 Using primary model for {language}")
333
- return transcribe_with_primary_model(audio_path, language)
334
- except Exception as e:
335
- print(f"Transcription failed, trying specialized model: {e}")
336
- if not use_specialized:
337
- return transcribe_audio(audio_path, language, initial_prompt, use_specialized=True)
338
- else:
339
- return f"Error: All transcription methods failed - {str(e)[:200]}"
340
 
341
  def highlight_differences(ref, hyp):
342
- if not (ref or "").strip() or not (hyp or "").strip():
343
  return "No text to compare"
344
  ref_words = ref.strip().split()
345
  hyp_words = hyp.strip().split()
@@ -347,262 +168,205 @@ def highlight_differences(ref, hyp):
347
  out_html = []
348
  for tag, i1, i2, j1, j2 in sm.get_opcodes():
349
  if tag == 'equal':
350
- out_html.extend([f"<span style='color:green; font-weight:bold; background-color:#e8f5e8; padding:2px 4px; margin:1px; border-radius:3px;'>{w}</span>" for w in ref_words[i1:i2]])
351
  elif tag == 'replace':
352
- out_html.extend([f"<span style='color:red; text-decoration:line-through; background-color:#ffe8e8; padding:2px 4px; margin:1px; border-radius:3px;'>{w}</span>" for w in ref_words[i1:i2]])
353
- out_html.extend([f"<span style='color:orange; font-weight:bold; background-color:#fff3cd; padding:2px 4px; margin:1px; border-radius:3px;'>→{w}</span>" for w in hyp_words[j1:j2]])
354
  elif tag == 'delete':
355
- out_html.extend([f"<span style='color:red; text-decoration:line-through; background-color:#ffe8e8; padding:2px 4px; margin:1px; border-radius:3px;'>{w}</span>" for w in ref_words[i1:i2]])
356
  elif tag == 'insert':
357
- out_html.extend([f"<span style='color:orange; font-weight:bold; background-color:#fff3cd; padding:2px 4px; margin:1px; border-radius:3px;'>+{w}</span>" for w in hyp_words[j1:j2]])
358
  return " ".join(out_html)
359
 
360
  def char_level_highlight(ref, hyp):
361
- if not (ref or "").strip() or not (hyp or "").strip():
362
  return "No text to compare"
363
  sm = difflib.SequenceMatcher(None, list(ref), list(hyp))
364
  out = []
365
  for tag, i1, i2, j1, j2 in sm.get_opcodes():
366
  if tag == 'equal':
367
- out.extend([f"<span style='color:green; background-color:#e8f5e8;'>{c}</span>" for c in ref[i1:i2]])
368
  elif tag in ('replace', 'delete'):
369
- out.extend([f"<span style='color:red; text-decoration:underline; background-color:#ffe8e8; font-weight:bold;'>{c}</span>" for c in ref[i1:i2]])
370
  elif tag == 'insert':
371
- out.extend([f"<span style='color:orange; background-color:#fff3cd; font-weight:bold;'>{c}</span>" for c in hyp[j1:j2]])
372
  return "".join(out)
373
 
374
  def get_pronunciation_score(wer_val, cer_val):
375
- combined_score = (wer_val * 0.7) + (cer_val * 0.3)
376
- if combined_score <= 0.1:
377
  return "🏆 Excellent! (90%+)", "Your pronunciation is outstanding!"
378
- elif combined_score <= 0.2:
379
  return "🎉 Very Good! (80-90%)", "Great pronunciation with minor areas for improvement."
380
- elif combined_score <= 0.4:
381
- return "👍 Good! (60-80%)", "Good effort! Keep practicing for better accuracy."
382
- elif combined_score <= 0.6:
383
- return "📚 Needs Practice (40-60%)", "Focus on clearer pronunciation of highlighted words."
384
  else:
385
- return "💪 Keep Trying! (<40%)", "Don't give up! Practice makes perfect."
386
 
387
- # ---------------- MAIN FUNCTION ---------------- #
388
  @GPU_DECORATOR
389
- def compare_pronunciation(audio, language_choice, intended_sentence):
390
- print(f"🔍 Starting analysis with language: {language_choice}")
391
- print(f"📝 Audio file: {audio}")
392
- print(f"🎯 Intended sentence: {intended_sentence}")
393
-
394
- if audio is None:
395
- return ("❌ Please record audio first.", "", "", "", "", "", "", "")
396
- if not (intended_sentence or "").strip():
397
- return ("❌ Please generate a practice sentence first.", "", "", "", "", "", "", "")
398
-
399
  try:
400
- print("🔄 Pass 1: Primary model transcription...")
401
- primary_text = transcribe_audio(audio, language_choice, use_specialized=False)
402
- print(f" Primary: {primary_text}")
403
-
404
- print("🔄 Pass 2: Specialized model transcription...")
405
- specialized_text = transcribe_audio(audio, language_choice, use_specialized=True)
406
- print(f"✅ Specialized: {specialized_text}")
407
-
408
- actual_text = primary_text if not str(primary_text).startswith("Error:") else specialized_text
409
-
410
- if str(actual_text).startswith("Error:"):
411
- return (f"❌ {actual_text}", "", "", "", "", "", "", "")
412
-
413
- try:
414
- wer_val = jiwer.wer(intended_sentence, actual_text)
415
- cer_val = jiwer.cer(intended_sentence, actual_text)
416
- except Exception as e:
417
- print(f"❌ Metrics error: {e}")
418
- wer_val, cer_val = 1.0, 1.0
419
-
420
- score_text, feedback = get_pronunciation_score(wer_val, cer_val)
421
-
422
- actual_hk = transliterate_to_hk(actual_text, language_choice)
423
- target_hk = transliterate_to_hk(intended_sentence, language_choice)
424
- if language_choice != "English" and not is_script(actual_text, language_choice):
425
- actual_hk = f"⚠️ Expected {language_choice} script, got mixed/other script"
426
-
427
- diff_html = highlight_differences(intended_sentence, actual_text)
428
- char_html = char_level_highlight(intended_sentence, actual_text)
429
-
430
- status = f"✅ Analysis Complete - {score_text}\n💬 {feedback}\n🚀 Powered by High-Performance ASR Models"
431
-
432
- return (
433
- status,
434
- primary_text or "(No primary transcription)",
435
- specialized_text or "(No specialized transcription)",
436
- f"{wer_val:.3f} ({(1-wer_val)*100:.1f}% word accuracy)",
437
- f"{cer_val:.3f} ({(1-cer_val)*100:.1f}% character accuracy)",
438
- diff_html,
439
- char_html,
440
- f"🎯 Target: {intended_sentence}"
441
  )
442
-
 
 
443
  except Exception as e:
444
- error_msg = f" Analysis Error: {str(e)[:200]}"
445
- return (error_msg, str(e), "", "", "", "", "", "")
446
-
447
- # ---------------- UI ---------------- #
448
- def create_interface():
449
- with gr.Blocks(title="🎙️ SOTA Multilingual Pronunciation Trainer") as demo:
450
- gr.Markdown("""
451
- # 🎙️ Advanced Multilingual Pronunciation Trainer
452
-
453
- Practice pronunciation in Tamil, Malayalam & English using high-performance ASR models!
454
-
455
- ### 🏆 Powered by Advanced Models:
456
- - Dual-Model Analysis: Primary + specialized model comparison
457
- - High Accuracy: Language-specific fine-tuned models
458
- - Robust Performance: Automatic fallback for reliability
459
-
460
- ### 📋 How to Use:
461
- 1. Select your target language 🌍
462
- 2. Generate a practice sentence 🎲
463
- 3. Record yourself reading it aloud 🎤
464
- 4. Get detailed feedback with advanced accuracy 📊
465
-
466
- ### 🎯 Features:
467
- - Dual-pass analysis for comprehensive assessment
468
- - Visual highlighting of pronunciation errors
469
- - Romanization for Indic scripts
470
- - Advanced metrics (Word & Character accuracy)
471
- """)
472
-
473
- with gr.Row():
474
- with gr.Column(scale=3):
475
- lang_choice = gr.Dropdown(
476
- choices=list(LANG_CODES.keys()),
477
- value="Tamil",
478
- label="🌍 Select Language"
479
- )
480
- with gr.Column(scale=1):
481
- gen_btn = gr.Button("🎲 Generate Sentence", variant="primary")
482
-
483
- intended_display = gr.Textbox(
484
- label="📝 Practice Sentence (Read this aloud)",
485
- placeholder="Click 'Generate Sentence' to get started...",
486
- interactive=False,
487
- lines=3
488
- )
489
-
490
- audio_input = gr.Audio(
491
- sources=["microphone", "upload"],
492
- type="filepath",
493
- label="🎤 Record Your Pronunciation"
494
  )
 
 
 
 
 
495
 
496
- analyze_btn = gr.Button("🔍 Analyze with Advanced Models", variant="primary")
497
-
498
- status_output = gr.Textbox(
499
- label="📊 Advanced Analysis Results",
500
- interactive=False,
501
- lines=4
502
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
503
 
504
- with gr.Row():
505
- with gr.Column():
506
- pass1_out = gr.Textbox(
507
- label="🏆 Primary Model Output",
508
- interactive=False,
509
- lines=2
510
- )
511
- wer_out = gr.Textbox(
512
- label="📈 Word Accuracy",
513
- interactive=False
514
- )
515
- with gr.Column():
516
- pass2_out = gr.Textbox(
517
- label="🔧 Specialized Model Comparison",
518
- interactive=False,
519
- lines=2
520
  )
 
 
 
 
 
 
 
 
521
 
522
- cer_out = gr.Textbox(
523
- label="📊 Character Accuracy",
524
- interactive=False
525
- )
526
-
527
- with gr.Accordion("📝 Detailed Visual Feedback", open=True):
528
- gr.Markdown("""
529
- ### 🎨 Color Guide:
530
- - 🟢 Green: Correctly pronounced words/characters
531
- - 🔴 Red: Missing or mispronounced (strikethrough)
532
- - 🟠 Orange: Extra words or substitutions
533
- """)
534
- diff_html_box = gr.HTML(label="🔍 Word-Level Analysis", show_label=True)
535
- char_html_box = gr.HTML(label="🔤 Character-Level Analysis", show_label=True)
536
-
537
- target_display = gr.Textbox(
538
- label="🎯 Reference Text",
539
- interactive=False,
540
- visible=False
541
- )
542
-
543
- gen_btn.click(
544
- fn=get_random_sentence,
545
- inputs=[lang_choice],
546
- outputs=[intended_display]
547
- )
548
-
549
- analyze_btn.click(
550
- fn=compare_pronunciation,
551
- inputs=[audio_input, lang_choice, intended_display],
552
- outputs=[
553
- status_output, # status
554
- pass1_out, # primary transcription
555
- pass2_out, # specialized transcription
556
- wer_out, # wer formatted
557
- cer_out, # cer formatted
558
- diff_html_box, # diff_html
559
- char_html_box, # char_html
560
- target_display # target_display
561
- ]
562
- )
563
-
564
- lang_choice.change(
565
- fn=get_random_sentence,
566
- inputs=[lang_choice],
567
- outputs=[intended_display]
568
- )
569
 
570
- gr.Markdown("""
571
- ---
572
- ### 🏆 Advanced Technology Stack:
573
- - Primary ASR: OpenAI Whisper Large v2 (High-performance multilingual model)
574
- - Specialized Models: Fine-tuned language-specific models
575
- - Tamil: vasista22/whisper-tamil-large-v2 (IIT Madras Speech Lab)
576
- - Malayalam: thennal/whisper-medium-ml (Common Voice trained)
577
- - English: openai/whisper-base.en (English-optimized)
578
- - Dual Analysis: Primary + specialized model comparison
579
- - Automatic Fallback: Ensures reliable results always
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
580
 
581
- ### 🔧 Technical Details:
582
- - Metrics: WER (Word Error Rate) and CER (Character Error Rate)
583
- - Transliteration: Harvard-Kyoto system for Indic scripts
584
- - Analysis: Dual-model comparison for comprehensive feedback
585
- - Languages: English, Tamil, and Malayalam
586
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
587
  return demo
588
 
589
  # ---------------- LAUNCH ---------------- #
590
  if __name__ == "__main__":
591
- print("🚀 Starting Advanced Multilingual Pronunciation Trainer...")
592
- print(f"🔧 Device: {DEVICE}")
593
- try:
594
- torch_ver = getattr(torch, '__version__', 'unknown')
595
- except Exception:
596
- torch_ver = 'unknown'
597
- print(f"🔧 PyTorch version: {torch_ver}")
598
- print("🏆 Using High-Performance Dual-Model Approach")
599
- print("⚡ Automatic model selection with specialized fallbacks")
600
- print("📊 Advanced analysis with robust error handling")
601
-
602
  demo = create_interface()
603
- demo.launch(
604
- share=True,
605
- show_error=True,
606
- server_name="0.0.0.0",
607
- server_port=7860
608
- )
 
2
  import random
3
  import difflib
4
  import re
5
+ import warnings
6
  import torch
7
  import numpy as np
 
8
  import librosa
9
  import soundfile as sf
10
+ import jiwer
 
 
11
 
12
+ # Optional: Indic transliteration
13
+ try:
14
+ from indic_transliteration import sanscript
15
+ from indic_transliteration.sanscript import transliterate
16
+ INDIC_OK = True
17
+ except:
18
+ INDIC_OK = False
19
+ sanscript = None
20
+ transliterate = None
21
+
# Optional: HF Spaces GPU decorator
try:
    import spaces
    GPU_DECORATOR = spaces.GPU
except Exception:
    # `spaces` could not be imported (or has no `GPU` attribute) in this
    # environment. `except Exception` is used instead of a bare `except` so
    # that KeyboardInterrupt / SystemExit are not swallowed at import time.
    def GPU_DECORATOR(fn=None, **_options):
        """Pass-through stand-in used when ``spaces.GPU`` is unavailable.

        Works both bare (``@GPU_DECORATOR``) and when called with keyword
        options (``@GPU_DECORATOR(...)``); any options are ignored.

        Args:
            fn: The function being decorated, or ``None`` when the decorator
                is called with keyword options only.
            **_options: Ignored.

        Returns:
            ``fn`` unchanged, or an identity decorator when ``fn`` is ``None``.
        """
        if fn is None:
            return lambda wrapped: wrapped
        return fn
30
 
31
  warnings.filterwarnings("ignore")
32
 
33
  # ---------------- CONFIG ---------------- #
34
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
35
+ DEVICE_INDEX = 0 if DEVICE == "cuda" else -1
36
+ DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
37
  print(f"🔧 Using device: {DEVICE}")
38
 
39
  LANG_CODES = {
40
  "English": "en",
41
  "Tamil": "ta",
42
+ "Malayalam": "ml",
43
  }
44
 
45
+ # AI4Bharat IndicWhisper community port
46
+ INDICWHISPER_MODEL = "parthiv11/indic_whisper_nodcil"
47
 
 
48
  SPECIALIZED_MODELS = {
49
  "English": "openai/whisper-base.en",
50
  "Tamil": "vasista22/whisper-tamil-large-v2",
51
+ "Malayalam": "thennal/whisper-medium-ml",
 
 
 
 
 
 
 
 
 
52
  }
53
 
54
  SCRIPT_PATTERNS = {
55
  "Tamil": re.compile(r"[஀-௿]"),
56
  "Malayalam": re.compile(r"[ഀ-ൿ]"),
57
+ "English": re.compile(r"[A-Za-z]"),
58
  }
59
 
60
  SENTENCE_BANK = {
 
66
  "Music brings people together across cultures.",
67
  "Education is the key to a bright future.",
68
  "The flowers bloom beautifully in spring.",
69
+ "Hard work always pays off in the end.",
70
  ],
71
  "Tamil": [
72
  "இன்று நல்ல வானிலை உள்ளது.",
 
76
  "குடும்பத்துடன் நேரம் செலவிடுவது முக்கியம்.",
77
  "கல்வி நமது எதிர்காலத்தின் திறவுகோல்.",
78
  "பறவைகள் காலையில் இனிமையாக பாடுகின்றன.",
79
+ "உழைப்பு எப்போதும் வெற்றியைத் தரும்.",
80
  ],
81
  "Malayalam": [
82
  "എനിക്ക് മലയാളം വളരെ ഇഷ്ടമാണ്.",
 
86
  "വിദ്യാഭ്യാസം ജീവിതത്തിൽ പ്രധാനമാണ്.",
87
  "സംഗീതം മനസ്സിന് സന്തോഷം നൽകുന്നു.",
88
  "കുടുംബസമയം വളരെ വിലപ്പെട്ടതാണ്.",
89
+ "കഠിനാധ്വാനം എപ്പോഴും ഫലം നൽകും.",
90
+ ],
91
  }
92
 
93
+ # Model cache
94
  indicwhisper_pipeline = None
95
  fallback_models = {}
96
+ WHISPER_JAX_AVAILABLE = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  # ---------------- HELPERS ---------------- #
99
  def get_random_sentence(language_choice):
 
106
  return bool(pattern.search(text or ""))
107
 
def transliterate_to_hk(text, lang_choice):
    """Romanize ``text`` into the ``sanscript.HK`` scheme when possible.

    Args:
        text: Text to romanize.
        lang_choice: Language name; only "Tamil" and "Malayalam" map to a
            source script, anything else is returned untouched.

    Returns:
        The transliterated string, or ``text`` unchanged when the optional
        transliteration package is missing, the language has no script
        mapping, ``text`` does not contain that script, or transliteration
        raises.
    """
    if not INDIC_OK:
        return text
    mapping = {
        "Tamil": sanscript.TAMIL,
        "Malayalam": sanscript.MALAYALAM,
        "English": None,
    }
    script = mapping.get(lang_choice)
    if script and is_script(text, lang_choice):
        try:
            return transliterate(text, script, sanscript.HK)
        except Exception as e:
            # Best effort: keep the original text, but report the failure
            # instead of discarding it silently.
            print(f"Transliteration error: {e}")
            return text
    return text
123
 
def preprocess_audio(audio_path, target_sr=16000):
    """Load an audio file as peak-normalized, silence-trimmed float32 samples.

    Args:
        audio_path: Path handed to ``librosa.load``.
        target_sr: Sample rate requested from ``librosa.load``.

    Returns:
        ``(samples, target_sr)`` on success, or ``(None, None)`` when the file
        is empty, the trimmed signal is shorter than 0.1 s worth of samples,
        or any step raises (the error is printed).
    """
    try:
        samples, _loaded_sr = librosa.load(audio_path, sr=target_sr, mono=True)
        if samples is None or len(samples) == 0:
            return None, None

        samples = samples.astype(np.float32)

        # Scale so the loudest sample has magnitude 1; skip for pure silence.
        peak = np.max(np.abs(samples))
        if peak > 0:
            samples /= peak

        trimmed, _ = librosa.effects.trim(samples, top_db=20)

        shortest_allowed = target_sr * 0.1
        if len(trimmed) < shortest_allowed:
            return None, None
        return trimmed, target_sr
    except Exception as e:
        print(f"Audio preprocessing error: {e}")
        return None, None
140
 
141
+ # Normalization for WER
142
+ JIWER_TRANSFORM = jiwer.Compose([
143
+ jiwer.ToLowerCase(),
144
+ jiwer.RemovePunctuation(),
145
+ jiwer.RemoveMultipleSpaces(),
146
+ jiwer.Strip(),
147
+ jiwer.ReduceToListOfListOfWords(),
148
+ ])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
def compute_wer(ref, hyp):
    """Return the word error rate of ``hyp`` against ``ref``.

    Both strings are normalized with ``JIWER_TRANSFORM`` before scoring.

    Args:
        ref: Reference (intended) sentence.
        hyp: Hypothesis (transcribed) sentence.

    Returns:
        The value returned by ``jiwer.wer``, or ``1.0`` if scoring raises
        (the error is printed so the failure is visible).
    """
    try:
        try:
            return jiwer.wer(
                ref,
                hyp,
                truth_transform=JIWER_TRANSFORM,
                hypothesis_transform=JIWER_TRANSFORM,
            )
        except TypeError:
            # NOTE(review): some jiwer releases spell the reference-side
            # keyword `reference_transform` rather than `truth_transform`;
            # confirm against the installed version. Retry with that name
            # instead of silently reporting a 100% error rate.
            return jiwer.wer(
                ref,
                hyp,
                reference_transform=JIWER_TRANSFORM,
                hypothesis_transform=JIWER_TRANSFORM,
            )
    except Exception as e:
        print(f"WER computation error: {e}")
        return 1.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
def compute_cer(ref, hyp):
    """Return the character error rate of ``hyp`` against ``ref``.

    Args:
        ref: Reference (intended) sentence.
        hyp: Hypothesis (transcribed) sentence.

    Returns:
        The value returned by ``jiwer.cer``, or ``1.0`` if scoring raises
        (the error is printed so the failure is visible).
    """
    try:
        return jiwer.cer(ref, hyp)
    except Exception as e:
        print(f"CER computation error: {e}")
        return 1.0
 
 
 
 
 
 
 
 
 
161
 
162
  def highlight_differences(ref, hyp):
163
+ if not ref.strip() or not hyp.strip():
164
  return "No text to compare"
165
  ref_words = ref.strip().split()
166
  hyp_words = hyp.strip().split()
 
168
  out_html = []
169
  for tag, i1, i2, j1, j2 in sm.get_opcodes():
170
  if tag == 'equal':
171
+ out_html.extend([f"<span style='color:green; background-color:#e8f5e8;'>{w}</span>" for w in ref_words[i1:i2]])
172
  elif tag == 'replace':
173
+ out_html.extend([f"<span style='color:red; text-decoration:line-through;'>{w}</span>" for w in ref_words[i1:i2]])
174
+ out_html.extend([f"<span style='color:orange;'>→{w}</span>" for w in hyp_words[j1:j2]])
175
  elif tag == 'delete':
176
+ out_html.extend([f"<span style='color:red; text-decoration:line-through;'>{w}</span>" for w in ref_words[i1:i2]])
177
  elif tag == 'insert':
178
+ out_html.extend([f"<span style='color:orange;'>+{w}</span>" for w in hyp_words[j1:j2]])
179
  return " ".join(out_html)
180
 
def char_level_highlight(ref, hyp):
    """Build a character-level HTML diff of ``hyp`` against ``ref``.

    Matching characters are green, reference characters that were replaced
    or dropped are red, and extra hypothesis characters are orange.

    Args:
        ref: Reference text (``None`` is treated as empty).
        hyp: Hypothesis text (``None`` is treated as empty).

    Returns:
        Concatenated ``<span>`` markup, or ``"No text to compare"`` when
        either side is empty or whitespace-only.
    """
    import html  # local import: only this block needs it

    ref = ref or ""
    hyp = hyp or ""
    if not ref.strip() or not hyp.strip():
        return "No text to compare"

    def span(color, ch):
        # Escape each character so text such as "<" or "&" is displayed
        # literally instead of being interpreted as markup.
        return f"<span style='color:{color};'>{html.escape(ch)}</span>"

    sm = difflib.SequenceMatcher(None, list(ref), list(hyp))
    out = []
    for tag, i1, i2, j1, j2 in sm.get_opcodes():
        if tag == 'equal':
            out.extend(span("green", c) for c in ref[i1:i2])
        elif tag in ('replace', 'delete'):
            # For a replacement only the reference side is shown.
            out.extend(span("red", c) for c in ref[i1:i2])
        elif tag == 'insert':
            out.extend(span("orange", c) for c in hyp[j1:j2])
    return "".join(out)
194
 
def get_pronunciation_score(wer_val, cer_val, wer_weight=0.7, cer_weight=0.3):
    """Map word/character error rates to a rating label and a feedback line.

    The two rates are blended as ``wer_val * wer_weight + cer_val *
    cer_weight`` and the result is compared against fixed upper bounds.

    Args:
        wer_val: Word error rate.
        cer_val: Character error rate.
        wer_weight: Weight applied to ``wer_val`` (default 0.7).
        cer_weight: Weight applied to ``cer_val`` (default 0.3).

    Returns:
        A ``(label, feedback)`` tuple of strings.
    """
    combined = (wer_val * wer_weight) + (cer_val * cer_weight)

    # (inclusive upper bound on the combined error, label, feedback),
    # checked in ascending order.
    bands = (
        (0.1, "🏆 Excellent! (90%+)", "Your pronunciation is outstanding!"),
        (0.2, "🎉 Very Good! (80-90%)", "Great pronunciation with minor areas for improvement."),
        (0.4, "👍 Good! (60-80%)", "Good effort! Keep practicing."),
        (0.6, "📚 Needs Practice (40-60%)", "Focus on clearer pronunciation."),
    )
    for upper_bound, label, feedback in bands:
        if combined <= upper_bound:
            return label, feedback
    # Anything above the last bound (or not comparable, e.g. NaN).
    return "💪 Keep Trying! (<40%)", "Don't give up!"
207
 
208
+ # ---------------- LOADERS ---------------- #
209
  @GPU_DECORATOR
def load_indicwhisper():
    """Return the cached primary ASR pipeline, building it on first use.

    A ``whisper_jax`` Flax pipeline is attempted first; if that fails for any
    reason, a ``transformers`` ASR pipeline is built instead. The module-level
    ``indicwhisper_pipeline`` and ``WHISPER_JAX_AVAILABLE`` are updated to
    reflect which path succeeded.

    Raises:
        Exception: Whatever the ``transformers`` fallback raised, re-raised
            after being printed.
    """
    global indicwhisper_pipeline, WHISPER_JAX_AVAILABLE

    # Already built on an earlier call.
    if indicwhisper_pipeline is not None:
        return indicwhisper_pipeline

    # Preferred backend: whisper-jax.
    try:
        from whisper_jax import FlaxWhisperPipeline
        import jax.numpy as jnp

        print(f"🔄 Loading JAX IndicWhisper: {INDICWHISPER_MODEL}")
        indicwhisper_pipeline = FlaxWhisperPipeline(
            INDICWHISPER_MODEL,
            dtype=jnp.bfloat16,
            batch_size=1,
        )
        WHISPER_JAX_AVAILABLE = True
        print("✅ JAX Loaded!")
        return indicwhisper_pipeline
    except Exception as jax_error:
        print(f"⚠️ JAX unavailable: {jax_error}")
        WHISPER_JAX_AVAILABLE = False

    # Fallback backend: transformers pipeline.
    try:
        from transformers import pipeline

        indicwhisper_pipeline = pipeline(
            "automatic-speech-recognition",
            model=INDICWHISPER_MODEL,
            device=DEVICE_INDEX,
        )
        print("✅ Transformers IndicWhisper loaded!")
        return indicwhisper_pipeline
    except Exception as load_error:
        print(f"❌ Failed to load IndicWhisper: {load_error}")
        raise
241
 
242
+ @GPU_DECORATOR
def load_specialized_model(language):
    """Return the processor/model pair for ``language``, loading it once.

    Args:
        language: Key into ``SPECIALIZED_MODELS``.

    Returns:
        A dict with keys ``"processor"`` and ``"model"``; the same dict is
        stored in ``fallback_models`` and reused on later calls.
    """
    cached = fallback_models.get(language)
    if cached is not None:
        return cached

    from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

    checkpoint = SPECIALIZED_MODELS[language]
    # Processor is created before the model, matching the load order above.
    bundle = {
        "processor": AutoProcessor.from_pretrained(checkpoint),
        "model": AutoModelForSpeechSeq2Seq.from_pretrained(
            checkpoint,
            torch_dtype=DTYPE,
            low_cpu_mem_usage=True,
        ).to(DEVICE),
    }
    fallback_models[language] = bundle
    return bundle
255
+
256
+ # ---------------- TRANSCRIBE ---------------- #
257
+ @GPU_DECORATOR
def transcribe_with_primary_model(audio_path, language):
    """Transcribe ``audio_path`` with the primary pipeline.

    Args:
        audio_path: Path of the audio file, passed straight to the pipeline.
        language: Language name looked up in ``LANG_CODES`` (defaults to
            ``"en"`` when unknown).

    Returns:
        The stripped transcription, or a string starting with ``"Error: "``
        if anything raises.
    """
    try:
        pl = load_indicwhisper()
        lang_code = LANG_CODES.get(language, "en")

        # whisper-jax path: language/task are passed per call.
        if WHISPER_JAX_AVAILABLE:
            result = pl(audio_path, task="transcribe", language=lang_code)
            if isinstance(result, dict) and "text" in result:
                # Guard against a None text so .strip() cannot fail.
                return (result["text"] or "").strip()
            return str(result).strip()

        # transformers path: try to pin language/task on the model config.
        if hasattr(pl, "model") and hasattr(pl, "tokenizer"):
            try:
                forced_ids = pl.tokenizer.get_decoder_prompt_ids(
                    language=lang_code, task="transcribe"
                )
                pl.model.config.forced_decoder_ids = forced_ids
            except Exception as e:
                # Best effort: continue without forcing, but say so.
                print(f"⚠️ Primary model language forcing failed: {e}")

        out = pl(audio_path)
        if isinstance(out, dict) and 'text' in out:
            return (out['text'] or "").strip()
        return str(out).strip()
    except Exception as e:
        print(f"Primary model transcription error: {e}")
        return f"Error: {str(e)}"
280
 
281
@GPU_DECORATOR
def transcribe_with_specialized_model(audio_path, language):
    """Transcribe ``audio_path`` with the language-specific checkpoint.

    Args:
        audio_path: Audio file path passed to ``preprocess_audio``.
        language: Key used for ``load_specialized_model`` and ``LANG_CODES``.

    Returns:
        The stripped transcription text, ``"Error: Audio too short"`` when
        ``preprocess_audio`` yields no audio, or another string starting with
        ``"Error:"`` if anything raised.
    """
    try:
        c = load_specialized_model(language)
        processor, model = c["processor"], c["model"]

        audio, sr = preprocess_audio(audio_path)
        if audio is None:
            return "Error: Audio too short"

        inputs = processor(audio, sampling_rate=sr, return_tensors="pt")
        # The model is loaded with torch_dtype=DTYPE, so cast the features to
        # the model's own dtype as well as moving them to the device;
        # otherwise a half-precision model receives float32 input.
        input_features = inputs.input_features.to(device=DEVICE, dtype=model.dtype)

        generate_kwargs = {"inputs": input_features, "max_length": 200, "num_beams": 3}
        # The English entry in SPECIALIZED_MODELS is an English-only
        # checkpoint, so a language prompt is only forced for the others.
        if language != "English":
            try:
                forced_ids = processor.tokenizer.get_decoder_prompt_ids(
                    language=LANG_CODES[language], task="transcribe"
                )
                generate_kwargs["forced_decoder_ids"] = forced_ids
            except Exception as e:
                # Best effort only: decode without a forced language prompt,
                # but report why instead of swallowing the failure.
                print(f"⚠️ Could not force decoder language for {language}: {e}")

        with torch.no_grad():
            ids = model.generate(**generate_kwargs)
        text = processor.batch_decode(ids, skip_special_tokens=True)[0]
        return text.strip()
    except Exception as e:
        return f"Error: {str(e)}"
304
 
305
@GPU_DECORATOR
def transcribe_audio(audio_path, language, use_specialized=False):
    """Dispatch transcription to the primary or the specialized model.

    Args:
        audio_path: Audio file path to transcribe.
        language: Display name of the language.
        use_specialized: When true, use the language-specific checkpoint;
            otherwise use the primary pipeline.

    Returns:
        The transcription text, or a string starting with ``"Error:"``.
        If the primary path raises, the specialized model is tried once.
    """
    try:
        if use_specialized:
            return transcribe_with_specialized_model(audio_path, language)
        return transcribe_with_primary_model(audio_path, language)
    except Exception as e:
        # `except Exception` (not a bare except) so KeyboardInterrupt and
        # SystemExit are not turned into a retry.
        if not use_specialized:
            return transcribe_audio(audio_path, language, use_specialized=True)
        # Keep the "Error:" prefix: compare_pronunciation detects failures
        # with startswith("Error:").
        return f"Error: {e}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
316
 
317
+ # ---------------- MAIN FUNCTION ---------------- #
318
@GPU_DECORATOR
def compare_pronunciation(audio, language_choice, intended_sentence):
    """Transcribe a recording and score it against the practice sentence.

    Args:
        audio: Audio file path from the UI, or ``None`` if nothing was recorded.
        language_choice: Display name of the selected language.
        intended_sentence: The sentence the user was asked to read.

    Returns:
        An 8-tuple matching the UI outputs: status text, primary transcript,
        specialized transcript, WER summary, CER summary, word-level diff
        HTML, character-level diff HTML and the reference line. On failure
        only the status text is filled in.
    """
    blanks = ("",) * 7  # Placeholder values for the seven non-status outputs.

    if audio is None:
        return ("❌ Please record audio first.", *blanks)
    # Treat a missing sentence (None) the same as an empty one instead of
    # raising AttributeError on .strip().
    if not (intended_sentence or "").strip():
        return ("❌ Please generate a practice sentence first.", *blanks)

    primary_text = transcribe_audio(audio, language_choice, use_specialized=False)
    specialized_text = transcribe_audio(audio, language_choice, use_specialized=True)

    # Prefer the primary transcript; use the specialized one only if the
    # primary model reported an error.
    actual_text = primary_text if not primary_text.startswith("Error:") else specialized_text
    if actual_text.startswith("Error:"):
        return (f"❌ {actual_text}", *blanks)

    wer_val = compute_wer(intended_sentence, actual_text)
    cer_val = compute_cer(intended_sentence, actual_text)
    score_text, feedback = get_pronunciation_score(wer_val, cer_val)
    diff_html = highlight_differences(intended_sentence, actual_text)
    char_html = char_level_highlight(intended_sentence, actual_text)

    # An error rate above 1.0 would otherwise show a negative accuracy, so
    # the displayed percentage is floored at 0.
    word_accuracy = max(0.0, 1 - wer_val) * 100
    char_accuracy = max(0.0, 1 - cer_val) * 100

    return (
        f"✅ Analysis Complete - {score_text}\n💬 {feedback}",
        primary_text,
        specialized_text,
        f"{wer_val:.3f} ({word_accuracy:.1f}% word accuracy)",
        f"{cer_val:.3f} ({char_accuracy:.1f}% char accuracy)",
        diff_html,
        char_html,
        f"🎯 Target: {intended_sentence}",
    )
342
 
343
+ # ---------------- UI ---------------- #
344
def create_interface():
    """Build the Gradio Blocks UI for the pronunciation trainer and return it.

    Wires the "Generate Sentence" button and the language dropdown to
    ``get_random_sentence``, and the "Analyze" button to
    ``compare_pronunciation``.
    """
    # NOTE(review): the extent of the two gr.Row() contexts is reconstructed
    # (language + button in the first row, the four result textboxes in the
    # second) — confirm against the intended layout.
    with gr.Blocks(title="🎙️ IndicWhisper Pronunciation Trainer") as demo:
        gr.Markdown("# 🎙️ IndicWhisper-based Pronunciation Trainer")

        with gr.Row():
            language_dd = gr.Dropdown(
                choices=list(LANG_CODES.keys()),
                value="Tamil",
                label="🌍 Language",
            )
            generate_button = gr.Button("🎲 Generate Sentence")

        sentence_box = gr.Textbox(label="📝 Practice Sentence", interactive=False, lines=3)
        recording = gr.Audio(sources=["microphone","upload"], type="filepath", label="🎤 Record")
        analyze_button = gr.Button("🔍 Analyze")
        status_box = gr.Textbox(label="📊 Results", interactive=False, lines=4)

        with gr.Row():
            primary_box = gr.Textbox(label="🏆 Primary (IndicWhisper)", interactive=False)
            specialized_box = gr.Textbox(label="🔧 Specialized", interactive=False)
            wer_box = gr.Textbox(label="📈 Word Accuracy", interactive=False)
            cer_box = gr.Textbox(label="📊 Char Accuracy", interactive=False)

        word_diff_view = gr.HTML(label="Word-Level Analysis")
        char_diff_view = gr.HTML(label="Character-Level Analysis")
        reference_box = gr.Textbox(label="🎯 Reference", interactive=False, visible=False)

        # Output order must match the tuple returned by compare_pronunciation.
        analysis_inputs = [recording, language_dd, sentence_box]
        analysis_outputs = [
            status_box,
            primary_box,
            specialized_box,
            wer_box,
            cer_box,
            word_diff_view,
            char_diff_view,
            reference_box,
        ]

        generate_button.click(get_random_sentence, [language_dd], [sentence_box])
        analyze_button.click(compare_pronunciation, analysis_inputs, analysis_outputs)
        # Changing the language also replaces the practice sentence.
        language_dd.change(get_random_sentence, [language_dd], [sentence_box])

    return demo
368
 
369
  # ---------------- LAUNCH ---------------- #
370
if __name__ == "__main__":
    # Build the UI, then serve it on all interfaces at port 7860 with a
    # public share link requested.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860, "share": True}
    demo = create_interface()
    demo.launch(**launch_options)