sudhanm commited on
Commit
b7a8eef
·
verified ·
1 Parent(s): a950033

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +210 -196
app.py CHANGED
@@ -5,21 +5,27 @@ import re
5
  import jiwer
6
  import torch
7
  import numpy as np
8
- from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
9
  import librosa
10
  import soundfile as sf
11
  from indic_transliteration import sanscript
12
  from indic_transliteration.sanscript import transliterate
13
- import unicodedata
14
  import warnings
15
- import spaces
 
 
 
 
 
 
 
16
 
17
  warnings.filterwarnings("ignore")
18
 
19
  # ---------------- CONFIG ---------------- #
20
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
21
  print(f"🔧 Using device: {DEVICE}")
22
- DEVICE_INDEX = 0 if DEVICE == "cuda" else -1
23
 
24
  LANG_CODES = {
25
  "English": "en",
@@ -27,8 +33,10 @@ LANG_CODES = {
27
  "Malayalam": "ml"
28
  }
29
 
 
30
  INDICWHISPER_MODEL = "openai/whisper-large-v2"
31
 
 
32
  SPECIALIZED_MODELS = {
33
  "English": "openai/whisper-base.en",
34
  "Tamil": "vasista22/whisper-tamil-large-v2",
@@ -83,23 +91,83 @@ SENTENCE_BANK = {
83
  ]
84
  }
85
 
86
- # Controls for stricter script checking and normalization
87
- STRICT_SCRIPT_CHECK = False # set True for strict script-only validation
88
- NORMALIZE_TEXT_FOR_METRICS = True
89
-
90
  # ---------------- MODEL CACHE ---------------- #
91
  indicwhisper_pipeline = None
92
  fallback_models = {}
93
- WHISPER_JAX_AVAILABLE = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
- def normalize_text(s: str) -> str:
96
- if not NORMALIZE_TEXT_FOR_METRICS:
97
- return s
98
- # Normalize unicode and collapse whitespace; do not remove language-specific punctuation
99
- s = unicodedata.normalize("NFC", s)
100
- s = re.sub(r"\s+", " ", s).strip()
101
- return s
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  def get_random_sentence(language_choice):
104
  return random.choice(SENTENCE_BANK[language_choice])
105
 
@@ -107,16 +175,7 @@ def is_script(text, lang_name):
107
  pattern = SCRIPT_PATTERNS.get(lang_name)
108
  if not pattern:
109
  return True
110
- if not STRICT_SCRIPT_CHECK:
111
- # any occurrence of script chars counts as match
112
- return bool(pattern.search(text))
113
- # strict: allow only spaces and target script chars
114
- for ch in text:
115
- if ch.isspace():
116
- continue
117
- if not pattern.match(ch):
118
- return False
119
- return True
120
 
121
  def transliterate_to_hk(text, lang_choice):
122
  mapping = {
@@ -125,6 +184,8 @@ def transliterate_to_hk(text, lang_choice):
125
  "English": None
126
  }
127
  script = mapping.get(lang_choice)
 
 
128
  if script and is_script(text, lang_choice):
129
  try:
130
  return transliterate(text, script, sanscript.HK)
@@ -134,111 +195,75 @@ def transliterate_to_hk(text, lang_choice):
134
  return text
135
 
136
  def preprocess_audio(audio_path, target_sr=16000):
 
137
  try:
138
- audio, sr = librosa.load(audio_path, sr=target_sr)
139
- if np.max(np.abs(audio)) > 0:
140
- audio = audio / np.max(np.abs(audio))
 
 
 
 
 
141
  audio, _ = librosa.effects.trim(audio, top_db=20)
142
- if len(audio) < target_sr * 0.1:
 
143
  return None, None
 
 
 
144
  return audio, target_sr
145
  except Exception as e:
146
  print(f"Audio preprocessing error: {e}")
147
  return None, None
148
 
149
- @spaces.GPU
150
- def load_indicwhisper():
151
- global indicwhisper_pipeline, WHISPER_JAX_AVAILABLE
152
- if indicwhisper_pipeline is None:
153
- try:
154
- # Try JAX pipeline
155
- try:
156
- from whisper_jax import FlaxWhisperPipeline
157
- import jax.numpy as jnp
158
- print(f"🔄 Loading JAX-optimized model: {INDICWHISPER_MODEL}")
159
- indicwhisper_pipeline = FlaxWhisperPipeline(
160
- INDICWHISPER_MODEL,
161
- dtype=jnp.bfloat16,
162
- batch_size=1
163
- )
164
- WHISPER_JAX_AVAILABLE = True
165
- print("✅ JAX-optimized model loaded successfully!")
166
- return indicwhisper_pipeline
167
- except Exception as e:
168
- print(f"⚠️ JAX loading failed: {e}")
169
- WHISPER_JAX_AVAILABLE = False
170
-
171
- # Fallback to transformers pipeline
172
- print(f"🔄 Loading transformers pipeline: {INDICWHISPER_MODEL}")
173
- from transformers import pipeline
174
- indicwhisper_pipeline = pipeline(
175
- "automatic-speech-recognition",
176
- model=INDICWHISPER_MODEL,
177
- device=DEVICE_INDEX,
178
- torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32
179
- )
180
- print("✅ High-performance model loaded with transformers!")
181
- except Exception as e:
182
- print(f"❌ Failed to load primary model: {e}")
183
- indicwhisper_pipeline = None
184
- raise Exception(f"Could not load high-performance model: {str(e)}")
185
- return indicwhisper_pipeline
186
-
187
- @spaces.GPU
188
- def load_specialized_model(language):
189
- if language not in fallback_models:
190
- model_name = SPECIALIZED_MODELS[language]
191
- print(f"🔄 Loading specialized model for {language}: {model_name}")
192
- try:
193
- processor = AutoProcessor.from_pretrained(model_name)
194
- model = AutoModelForSpeechSeq2Seq.from_pretrained(
195
- model_name,
196
- torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
197
- low_cpu_mem_usage=True,
198
- use_safetensors=True
199
- ).to(DEVICE)
200
- model.eval()
201
- fallback_models[language] = {"processor": processor, "model": model, "model_name": model_name}
202
- print(f"✅ Specialized model loaded for {language}")
203
- except Exception as e:
204
- print(f"❌ Failed to load specialized {model_name}: {e}")
205
- raise Exception(f"Could not load specialized {language} model")
206
- return fallback_models[language]
207
-
208
- @spaces.GPU
209
  def transcribe_with_primary_model(audio_path, language):
 
210
  try:
211
  pipe = load_indicwhisper()
212
 
213
- if callable(pipe):
214
- # Try to set forced decoder ids when available
215
- if language != "English":
216
- lang_code = LANG_CODES.get(language, "en")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  try:
218
- if hasattr(pipe, "model") and hasattr(pipe, "tokenizer"):
219
- if hasattr(pipe.model, "config"):
220
- forced_ids = pipe.tokenizer.get_decoder_prompt_ids(
221
- language=lang_code, task="transcribe"
222
- )
223
- pipe.model.config.forced_decoder_ids = forced_ids
224
  except Exception as e:
225
- print(f"⚠️ Language forcing failed: {e}")
 
 
226
 
227
- result = pipe(audio_path)
228
- if isinstance(result, dict) and "text" in result:
229
- return result["text"].strip()
230
- elif isinstance(result, str):
231
- return result.strip()
232
- else:
233
- return str(result).strip()
234
  else:
235
- return "Error: Pipeline not properly initialized"
236
  except Exception as e:
237
  print(f"Primary model transcription error: {e}")
238
- raise e
239
 
240
- @spaces.GPU
241
  def transcribe_with_specialized_model(audio_path, language):
 
242
  try:
243
  components = load_specialized_model(language)
244
  processor = components["processor"]
@@ -248,15 +273,23 @@ def transcribe_with_specialized_model(audio_path, language):
248
  if audio is None:
249
  return "Error: Audio too short or could not be processed"
250
 
251
- inputs = processor(
252
- audio,
253
- sampling_rate=sr,
254
- return_tensors="pt",
255
- padding=True
256
- )
257
- input_features = inputs.input_features.to(DEVICE)
 
 
 
 
 
 
 
 
 
258
 
259
- forced_decoder_ids = None
260
  if language != "English":
261
  lang_code = LANG_CODES.get(language, "en")
262
  try:
@@ -265,60 +298,53 @@ def transcribe_with_specialized_model(audio_path, language):
265
  language=lang_code,
266
  task="transcribe"
267
  )
 
 
 
 
 
 
 
268
  except Exception as e:
269
  print(f"⚠️ Language forcing failed: {e}")
270
 
271
  with torch.no_grad():
272
- gen_kwargs = {
273
- "max_length": 200,
274
- "num_beams": 3,
275
- "do_sample": False
276
- }
277
- if forced_decoder_ids:
278
- gen_kwargs["forced_decoder_ids"] = forced_decoder_ids
279
-
280
- predicted_ids = model.generate(
281
- input_features,
282
- **gen_kwargs
283
- )
284
 
285
  transcription = processor.batch_decode(
286
  predicted_ids,
287
  skip_special_tokens=True,
288
  clean_up_tokenization_spaces=True
289
  )[0]
290
-
291
- return transcription.strip() or "(No transcription generated)"
292
  except Exception as e:
293
  print(f"Specialized model transcription error: {e}")
294
- return f"Error: {str(e)[:150]}..."
295
 
296
- @spaces.GPU
297
  def transcribe_audio(audio_path, language, initial_prompt="", use_specialized=False):
 
298
  try:
299
  if use_specialized:
300
  print(f"🔄 Using specialized model for {language}")
301
  return transcribe_with_specialized_model(audio_path, language)
302
  else:
303
- print(f"🔄 Using high-performance primary model for {language}")
304
  return transcribe_with_primary_model(audio_path, language)
305
  except Exception as e:
306
  print(f"Transcription failed, trying specialized model: {e}")
307
  if not use_specialized:
308
  return transcribe_audio(audio_path, language, initial_prompt, use_specialized=True)
309
  else:
310
- return f"Error: All transcription methods failed - {str(e)[:100]}"
311
 
312
  def highlight_differences(ref, hyp):
313
- if not ref.strip() or not hyp.strip():
314
  return "No text to compare"
315
-
316
  ref_words = ref.strip().split()
317
  hyp_words = hyp.strip().split()
318
-
319
  sm = difflib.SequenceMatcher(None, ref_words, hyp_words)
320
  out_html = []
321
-
322
  for tag, i1, i2, j1, j2 in sm.get_opcodes():
323
  if tag == 'equal':
324
  out_html.extend([f"<span style='color:green; font-weight:bold; background-color:#e8f5e8; padding:2px 4px; margin:1px; border-radius:3px;'>{w}</span>" for w in ref_words[i1:i2]])
@@ -329,13 +355,11 @@ def highlight_differences(ref, hyp):
329
  out_html.extend([f"<span style='color:red; text-decoration:line-through; background-color:#ffe8e8; padding:2px 4px; margin:1px; border-radius:3px;'>{w}</span>" for w in ref_words[i1:i2]])
330
  elif tag == 'insert':
331
  out_html.extend([f"<span style='color:orange; font-weight:bold; background-color:#fff3cd; padding:2px 4px; margin:1px; border-radius:3px;'>+{w}</span>" for w in hyp_words[j1:j2]])
332
-
333
  return " ".join(out_html)
334
 
335
  def char_level_highlight(ref, hyp):
336
- if not ref.strip() or not hyp.strip():
337
  return "No text to compare"
338
-
339
  sm = difflib.SequenceMatcher(None, list(ref), list(hyp))
340
  out = []
341
  for tag, i1, i2, j1, j2 in sm.get_opcodes():
@@ -360,63 +384,50 @@ def get_pronunciation_score(wer_val, cer_val):
360
  else:
361
  return "💪 Keep Trying! (<40%)", "Don't give up! Practice makes perfect."
362
 
363
- @spaces.GPU
 
364
  def compare_pronunciation(audio, language_choice, intended_sentence):
365
- print(f"🔍 Starting advanced analysis with language: {language_choice}")
366
  print(f"📝 Audio file: {audio}")
367
  print(f"🎯 Intended sentence: {intended_sentence}")
368
 
369
  if audio is None:
370
- print("❌ No audio provided")
371
  return ("❌ Please record audio first.", "", "", "", "", "", "", "")
372
-
373
- if not intended_sentence.strip():
374
- print("❌ No intended sentence")
375
  return ("❌ Please generate a practice sentence first.", "", "", "", "", "", "", "")
376
 
377
  try:
378
- print(f"🔄 Starting Pass 1: High-performance model transcription...")
379
  primary_text = transcribe_audio(audio, language_choice, use_specialized=False)
380
- print(f"✅ Primary model result: {primary_text}")
381
 
382
- print("🔄 Starting Pass 2: Specialized model transcription...")
383
  specialized_text = transcribe_audio(audio, language_choice, use_specialized=True)
384
- print(f"✅ Specialized model result: {specialized_text}")
385
 
386
  actual_text = primary_text if not str(primary_text).startswith("Error:") else specialized_text
387
 
388
  if str(actual_text).startswith("Error:"):
389
- print(f"❌ Transcription error: {actual_text}")
390
  return (f"❌ {actual_text}", "", "", "", "", "", "", "")
391
 
392
- # Normalize for metrics if enabled
393
- ref_for_metrics = normalize_text(intended_sentence)
394
- hyp_for_metrics = normalize_text(actual_text)
395
-
396
  try:
397
- print("🔄 Calculating error metrics...")
398
- wer_val = jiwer.wer(ref_for_metrics, hyp_for_metrics)
399
- cer_val = jiwer.cer(ref_for_metrics, hyp_for_metrics)
400
- print(f"✅ WER: {wer_val:.3f}, CER: {cer_val:.3f}")
401
  except Exception as e:
402
- print(f"❌ Error calculating metrics: {e}")
403
  wer_val, cer_val = 1.0, 1.0
404
 
405
  score_text, feedback = get_pronunciation_score(wer_val, cer_val)
406
 
407
- print("🔄 Generating transliterations...")
408
  actual_hk = transliterate_to_hk(actual_text, language_choice)
409
  target_hk = transliterate_to_hk(intended_sentence, language_choice)
410
-
411
- if not is_script(actual_text, language_choice) and language_choice != "English":
412
  actual_hk = f"⚠️ Expected {language_choice} script, got mixed/other script"
413
 
414
- print("🔄 Generating visual feedback...")
415
  diff_html = highlight_differences(intended_sentence, actual_text)
416
  char_html = char_level_highlight(intended_sentence, actual_text)
417
 
418
- status = f"✅ Advanced Analysis Complete - {score_text}\n💬 {feedback}\n🚀 Powered by High-Performance ASR Models"
419
- print(f"✅ Advanced analysis completed successfully")
420
 
421
  return (
422
  status,
@@ -431,14 +442,11 @@ def compare_pronunciation(audio, language_choice, intended_sentence):
431
 
432
  except Exception as e:
433
  error_msg = f"❌ Analysis Error: {str(e)[:200]}"
434
- print(f"❌ FATAL ERROR: {e}")
435
- import traceback
436
- traceback.print_exc()
437
  return (error_msg, str(e), "", "", "", "", "", "")
438
 
 
439
  def create_interface():
440
  with gr.Blocks(title="🎙️ SOTA Multilingual Pronunciation Trainer") as demo:
441
-
442
  gr.Markdown("""
443
  # 🎙️ Advanced Multilingual Pronunciation Trainer
444
 
@@ -446,12 +454,12 @@ def create_interface():
446
 
447
  ### 🏆 Powered by Advanced Models:
448
  - Dual-Model Analysis: Primary + specialized model comparison
449
- - High Accuracy: Language-specific fine-tuned models
450
  - Robust Performance: Automatic fallback for reliability
451
 
452
  ### 📋 How to Use:
453
  1. Select your target language 🌍
454
- 2. Generate a practice sentence 🎲
455
  3. Record yourself reading it aloud 🎤
456
  4. Get detailed feedback with advanced accuracy 📊
457
 
@@ -520,7 +528,7 @@ def create_interface():
520
  gr.Markdown("""
521
  ### 🎨 Color Guide:
522
  - 🟢 Green: Correctly pronounced words/characters
523
- - 🔴 Red: Missing or mispronounced (strikethrough)
524
  - 🟠 Orange: Extra words or substitutions
525
  """)
526
  diff_html_box = gr.HTML(label="🔍 Word-Level Analysis", show_label=True)
@@ -542,14 +550,14 @@ def create_interface():
542
  fn=compare_pronunciation,
543
  inputs=[audio_input, lang_choice, intended_display],
544
  outputs=[
545
- status_output,
546
- pass1_out,
547
- pass2_out,
548
- wer_out,
549
- cer_out,
550
- diff_html_box,
551
- char_html_box,
552
- target_display
553
  ]
554
  )
555
 
@@ -563,27 +571,33 @@ def create_interface():
563
  ---
564
  ### 🏆 Advanced Technology Stack:
565
  - Primary ASR: OpenAI Whisper Large v2 (High-performance multilingual model)
566
- - Specialized Models:
567
- - Tamil: vasista22/whisper-tamil-large-v2
568
- - Malayalam: thennal/whisper-medium-ml
569
- - English: OpenAI Whisper Base EN
570
- - Dual Analysis and Automatic Fallback
 
571
 
572
  ### 🔧 Technical Details:
573
- - Metrics: WER and CER
574
- - Transliteration: Harvard-Kyoto for Indic scripts
575
- - Languages: English, Tamil, Malayalam
 
576
  """)
577
  return demo
578
 
 
579
  if __name__ == "__main__":
580
  print("🚀 Starting Advanced Multilingual Pronunciation Trainer...")
581
- print(f"🔧 Device: {DEVICE} (index={DEVICE_INDEX})")
582
- print(f"🔧 PyTorch version: {torch.__version__}")
 
 
 
 
583
  print("🏆 Using High-Performance Dual-Model Approach")
584
  print("⚡ Automatic model selection with specialized fallbacks")
585
  print("📊 Advanced analysis with robust error handling")
586
- print("🎮 GPU functions decorated with @spaces.GPU for HuggingFace Spaces")
587
 
588
  demo = create_interface()
589
  demo.launch(
 
5
  import jiwer
6
  import torch
7
  import numpy as np
8
+ from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, WhisperProcessor
9
  import librosa
10
  import soundfile as sf
11
  from indic_transliteration import sanscript
12
  from indic_transliteration.sanscript import transliterate
 
13
  import warnings
14
+
15
+ # Optional: only available on HF Spaces runtime
16
+ try:
17
+ import spaces
18
+ GPU_DECORATOR = spaces.GPU
19
+ except Exception:
20
+ def GPU_DECORATOR(fn):
21
+ return fn
22
 
23
  warnings.filterwarnings("ignore")
24
 
25
  # ---------------- CONFIG ---------------- #
26
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
27
+ CUDA_DEVICE_INDEX = 0 if torch.cuda.is_available() else -1 # for transformers pipeline device
28
  print(f"🔧 Using device: {DEVICE}")
 
29
 
30
  LANG_CODES = {
31
  "English": "en",
 
33
  "Malayalam": "ml"
34
  }
35
 
36
+ # Primary model
37
  INDICWHISPER_MODEL = "openai/whisper-large-v2"
38
 
39
+ # Specialized models
40
  SPECIALIZED_MODELS = {
41
  "English": "openai/whisper-base.en",
42
  "Tamil": "vasista22/whisper-tamil-large-v2",
 
91
  ]
92
  }
93
 
 
 
 
 
94
  # ---------------- MODEL CACHE ---------------- #
95
  indicwhisper_pipeline = None
96
  fallback_models = {}
97
+ WHISPER_JAX_AVAILABLE = False # default false; will set true if we load it
98
+
99
+ @GPU_DECORATOR
100
+ def load_indicwhisper():
101
+ """Load primary high-performance model (prefer transformers pipeline, optionally JAX if available)."""
102
+ global indicwhisper_pipeline, WHISPER_JAX_AVAILABLE
103
+
104
+ if indicwhisper_pipeline is not None:
105
+ return indicwhisper_pipeline
106
+
107
+ # Try JAX first (optional)
108
+ try:
109
+ from whisper_jax import FlaxWhisperPipeline
110
+ import jax.numpy as jnp
111
+ print(f"🔄 Loading JAX-optimized model: {INDICWHISPER_MODEL}")
112
+ indicwhisper_pipeline = FlaxWhisperPipeline(
113
+ INDICWHISPER_MODEL,
114
+ dtype=jnp.bfloat16,
115
+ batch_size=1
116
+ )
117
+ WHISPER_JAX_AVAILABLE = True
118
+ print("✅ JAX-optimized model loaded successfully!")
119
+ return indicwhisper_pipeline
120
+ except Exception as e:
121
+ print(f"⚠️ JAX loading failed: {e}")
122
+ WHISPER_JAX_AVAILABLE = False
123
+
124
+ # Fallback to transformers pipeline
125
+ try:
126
+ from transformers import pipeline
127
+ print(f"🔄 Loading transformers ASR pipeline: {INDICWHISPER_MODEL}")
128
+ indicwhisper_pipeline = pipeline(
129
+ task="automatic-speech-recognition",
130
+ model=INDICWHISPER_MODEL,
131
+ device=CUDA_DEVICE_INDEX # 0 for CUDA, -1 for CPU
132
+ )
133
+ print("✅ Transformers ASR pipeline loaded!")
134
+ return indicwhisper_pipeline
135
+ except Exception as e:
136
+ print(f"❌ Failed to load primary model: {e}")
137
+ indicwhisper_pipeline = None
138
+ raise Exception(f"Could not load primary model: {str(e)}")
139
+
140
+ @GPU_DECORATOR
141
+ def load_specialized_model(language: str):
142
+ """Load language-specific specialized model with processor."""
143
+ if language in fallback_models:
144
+ return fallback_models[language]
145
 
146
+ model_name = SPECIALIZED_MODELS[language]
147
+ print(f"🔄 Loading specialized model for {language}: {model_name}")
 
 
 
 
 
148
 
149
+ try:
150
+ # WhisperProcessor ensures get_decoder_prompt_ids is available
151
+ try:
152
+ processor = WhisperProcessor.from_pretrained(model_name)
153
+ except Exception:
154
+ processor = AutoProcessor.from_pretrained(model_name)
155
+
156
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
157
+ model_name,
158
+ torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
159
+ low_cpu_mem_usage=True
160
+ )
161
+ model.to(DEVICE)
162
+
163
+ fallback_models[language] = {"processor": processor, "model": model, "model_name": model_name}
164
+ print(f"✅ Specialized model loaded for {language}")
165
+ return fallback_models[language]
166
+ except Exception as e:
167
+ print(f"❌ Failed to load specialized {model_name}: {e}")
168
+ raise Exception(f"Could not load specialized {language} model: {str(e)}")
169
+
170
+ # ---------------- HELPERS ---------------- #
171
  def get_random_sentence(language_choice):
172
  return random.choice(SENTENCE_BANK[language_choice])
173
 
 
175
  pattern = SCRIPT_PATTERNS.get(lang_name)
176
  if not pattern:
177
  return True
178
+ return bool(pattern.search(text or ""))
 
 
 
 
 
 
 
 
 
179
 
180
  def transliterate_to_hk(text, lang_choice):
181
  mapping = {
 
184
  "English": None
185
  }
186
  script = mapping.get(lang_choice)
187
+ if not text:
188
+ return ""
189
  if script and is_script(text, lang_choice):
190
  try:
191
  return transliterate(text, script, sanscript.HK)
 
195
  return text
196
 
197
  def preprocess_audio(audio_path, target_sr=16000):
198
+ """Load, normalize, trim, return float32 audio."""
199
  try:
200
+ audio, sr = librosa.load(audio_path, sr=target_sr, mono=True)
201
+ if audio is None or len(audio) == 0:
202
+ return None, None
203
+ # Normalize
204
+ m = np.max(np.abs(audio))
205
+ if m > 0:
206
+ audio = audio / m
207
+ # Trim silence
208
  audio, _ = librosa.effects.trim(audio, top_db=20)
209
+ # Ensure min length
210
+ if len(audio) < int(target_sr * 0.1):
211
  return None, None
212
+ # Ensure float32
213
+ if audio.dtype != np.float32:
214
+ audio = audio.astype(np.float32)
215
  return audio, target_sr
216
  except Exception as e:
217
  print(f"Audio preprocessing error: {e}")
218
  return None, None
219
 
220
+ @GPU_DECORATOR
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  def transcribe_with_primary_model(audio_path, language):
222
+ """Transcribe using primary model (JAX if available else transformers pipeline)."""
223
  try:
224
  pipe = load_indicwhisper()
225
 
226
+ lang_code = LANG_CODES.get(language, "en")
227
+
228
+ if WHISPER_JAX_AVAILABLE:
229
+ # whisper-jax expects array or path; pass path is okay
230
+ result = pipe(audio_path, task="transcribe", language=lang_code)
231
+ # whisper-jax returns dict with 'text'
232
+ if isinstance(result, dict) and "text" in result:
233
+ return (result["text"] or "").strip()
234
+ return str(result).strip()
235
+
236
+ # transformers pipeline
237
+ # Some transformers versions accept language/task via generate_kwargs
238
+ generate_kwargs = {}
239
+ try:
240
+ # If underlying model is Whisper, we can set forced decoder ids
241
+ model = pipe.model if hasattr(pipe, "model") else None
242
+ tokenizer = getattr(pipe, "tokenizer", None)
243
+ processor = getattr(pipe, "feature_extractor", None)
244
+ if hasattr(pipe, "tokenizer") and hasattr(model, "config"):
245
  try:
246
+ forced_ids = pipe.tokenizer.get_decoder_prompt_ids(language=lang_code, task="transcribe")
247
+ model.config.forced_decoder_ids = forced_ids
 
 
 
 
248
  except Exception as e:
249
+ print(f"⚠️ Primary model language forcing failed: {e}")
250
+ except Exception as e:
251
+ print(f"⚠️ Primary model prompt config error: {e}")
252
 
253
+ out = pipe(audio_path, generate_kwargs=generate_kwargs)
254
+ if isinstance(out, dict) and "text" in out:
255
+ return (out["text"] or "").strip()
256
+ elif isinstance(out, str):
257
+ return out.strip()
 
 
258
  else:
259
+ return str(out).strip()
260
  except Exception as e:
261
  print(f"Primary model transcription error: {e}")
262
+ return f"Error: {str(e)[:200]}"
263
 
264
+ @GPU_DECORATOR
265
  def transcribe_with_specialized_model(audio_path, language):
266
+ """Transcribe using language-specific models."""
267
  try:
268
  components = load_specialized_model(language)
269
  processor = components["processor"]
 
273
  if audio is None:
274
  return "Error: Audio too short or could not be processed"
275
 
276
+ inputs = processor(audio, sampling_rate=sr, return_tensors="pt")
277
+ # WhisperProcessor returns input_features
278
+ input_features = inputs.get("input_features", None)
279
+ if input_features is None:
280
+ # Fallback: some processors use feature_extractor path
281
+ input_features = inputs.get("input_values", None)
282
+ if input_features is None:
283
+ return "Error: Could not prepare input features"
284
+
285
+ input_features = input_features.to(DEVICE)
286
+
287
+ generate_kwargs = {
288
+ "max_length": 200,
289
+ "num_beams": 3,
290
+ "do_sample": False
291
+ }
292
 
 
293
  if language != "English":
294
  lang_code = LANG_CODES.get(language, "en")
295
  try:
 
298
  language=lang_code,
299
  task="transcribe"
300
  )
301
+ generate_kwargs["forced_decoder_ids"] = forced_decoder_ids
302
+ elif hasattr(model, "config") and hasattr(processor, "tokenizer"):
303
+ forced_decoder_ids = processor.tokenizer.get_decoder_prompt_ids(
304
+ language=lang_code,
305
+ task="transcribe"
306
+ )
307
+ model.config.forced_decoder_ids = forced_decoder_ids
308
  except Exception as e:
309
  print(f"⚠️ Language forcing failed: {e}")
310
 
311
  with torch.no_grad():
312
+ predicted_ids = model.generate(input_features=input_features, **generate_kwargs)
 
 
 
 
 
 
 
 
 
 
 
313
 
314
  transcription = processor.batch_decode(
315
  predicted_ids,
316
  skip_special_tokens=True,
317
  clean_up_tokenization_spaces=True
318
  )[0]
319
+ return (transcription or "").strip() or "(No transcription generated)"
 
320
  except Exception as e:
321
  print(f"Specialized model transcription error: {e}")
322
+ return f"Error: {str(e)[:200]}"
323
 
324
+ @GPU_DECORATOR
325
  def transcribe_audio(audio_path, language, initial_prompt="", use_specialized=False):
326
+ """Dispatch to primary or specialized path with fallback."""
327
  try:
328
  if use_specialized:
329
  print(f"🔄 Using specialized model for {language}")
330
  return transcribe_with_specialized_model(audio_path, language)
331
  else:
332
+ print(f"🔄 Using primary model for {language}")
333
  return transcribe_with_primary_model(audio_path, language)
334
  except Exception as e:
335
  print(f"Transcription failed, trying specialized model: {e}")
336
  if not use_specialized:
337
  return transcribe_audio(audio_path, language, initial_prompt, use_specialized=True)
338
  else:
339
+ return f"Error: All transcription methods failed - {str(e)[:200]}"
340
 
341
  def highlight_differences(ref, hyp):
342
+ if not (ref or "").strip() or not (hyp or "").strip():
343
  return "No text to compare"
 
344
  ref_words = ref.strip().split()
345
  hyp_words = hyp.strip().split()
 
346
  sm = difflib.SequenceMatcher(None, ref_words, hyp_words)
347
  out_html = []
 
348
  for tag, i1, i2, j1, j2 in sm.get_opcodes():
349
  if tag == 'equal':
350
  out_html.extend([f"<span style='color:green; font-weight:bold; background-color:#e8f5e8; padding:2px 4px; margin:1px; border-radius:3px;'>{w}</span>" for w in ref_words[i1:i2]])
 
355
  out_html.extend([f"<span style='color:red; text-decoration:line-through; background-color:#ffe8e8; padding:2px 4px; margin:1px; border-radius:3px;'>{w}</span>" for w in ref_words[i1:i2]])
356
  elif tag == 'insert':
357
  out_html.extend([f"<span style='color:orange; font-weight:bold; background-color:#fff3cd; padding:2px 4px; margin:1px; border-radius:3px;'>+{w}</span>" for w in hyp_words[j1:j2]])
 
358
  return " ".join(out_html)
359
 
360
  def char_level_highlight(ref, hyp):
361
+ if not (ref or "").strip() or not (hyp or "").strip():
362
  return "No text to compare"
 
363
  sm = difflib.SequenceMatcher(None, list(ref), list(hyp))
364
  out = []
365
  for tag, i1, i2, j1, j2 in sm.get_opcodes():
 
384
  else:
385
  return "💪 Keep Trying! (<40%)", "Don't give up! Practice makes perfect."
386
 
387
+ # ---------------- MAIN FUNCTION ---------------- #
388
+ @GPU_DECORATOR
389
  def compare_pronunciation(audio, language_choice, intended_sentence):
390
+ print(f"🔍 Starting analysis with language: {language_choice}")
391
  print(f"📝 Audio file: {audio}")
392
  print(f"🎯 Intended sentence: {intended_sentence}")
393
 
394
  if audio is None:
 
395
  return ("❌ Please record audio first.", "", "", "", "", "", "", "")
396
+ if not (intended_sentence or "").strip():
 
 
397
  return ("❌ Please generate a practice sentence first.", "", "", "", "", "", "", "")
398
 
399
  try:
400
+ print("🔄 Pass 1: Primary model transcription...")
401
  primary_text = transcribe_audio(audio, language_choice, use_specialized=False)
402
+ print(f"✅ Primary: {primary_text}")
403
 
404
+ print("🔄 Pass 2: Specialized model transcription...")
405
  specialized_text = transcribe_audio(audio, language_choice, use_specialized=True)
406
+ print(f"✅ Specialized: {specialized_text}")
407
 
408
  actual_text = primary_text if not str(primary_text).startswith("Error:") else specialized_text
409
 
410
  if str(actual_text).startswith("Error:"):
 
411
  return (f"❌ {actual_text}", "", "", "", "", "", "", "")
412
 
 
 
 
 
413
  try:
414
+ wer_val = jiwer.wer(intended_sentence, actual_text)
415
+ cer_val = jiwer.cer(intended_sentence, actual_text)
 
 
416
  except Exception as e:
417
+ print(f"❌ Metrics error: {e}")
418
  wer_val, cer_val = 1.0, 1.0
419
 
420
  score_text, feedback = get_pronunciation_score(wer_val, cer_val)
421
 
 
422
  actual_hk = transliterate_to_hk(actual_text, language_choice)
423
  target_hk = transliterate_to_hk(intended_sentence, language_choice)
424
+ if language_choice != "English" and not is_script(actual_text, language_choice):
 
425
  actual_hk = f"⚠️ Expected {language_choice} script, got mixed/other script"
426
 
 
427
  diff_html = highlight_differences(intended_sentence, actual_text)
428
  char_html = char_level_highlight(intended_sentence, actual_text)
429
 
430
+ status = f"✅ Analysis Complete - {score_text}\n💬 {feedback}\n🚀 Powered by High-Performance ASR Models"
 
431
 
432
  return (
433
  status,
 
442
 
443
  except Exception as e:
444
  error_msg = f"❌ Analysis Error: {str(e)[:200]}"
 
 
 
445
  return (error_msg, str(e), "", "", "", "", "", "")
446
 
447
+ # ---------------- UI ---------------- #
448
  def create_interface():
449
  with gr.Blocks(title="🎙️ SOTA Multilingual Pronunciation Trainer") as demo:
 
450
  gr.Markdown("""
451
  # 🎙️ Advanced Multilingual Pronunciation Trainer
452
 
 
454
 
455
  ### 🏆 Powered by Advanced Models:
456
  - Dual-Model Analysis: Primary + specialized model comparison
457
+ - High Accuracy: Language-specific fine-tuned models
458
  - Robust Performance: Automatic fallback for reliability
459
 
460
  ### 📋 How to Use:
461
  1. Select your target language 🌍
462
+ 2. Generate a practice sentence 🎲
463
  3. Record yourself reading it aloud 🎤
464
  4. Get detailed feedback with advanced accuracy 📊
465
 
 
528
  gr.Markdown("""
529
  ### 🎨 Color Guide:
530
  - 🟢 Green: Correctly pronounced words/characters
531
+ - 🔴 Red: Missing or mispronounced (strikethrough)
532
  - 🟠 Orange: Extra words or substitutions
533
  """)
534
  diff_html_box = gr.HTML(label="🔍 Word-Level Analysis", show_label=True)
 
550
  fn=compare_pronunciation,
551
  inputs=[audio_input, lang_choice, intended_display],
552
  outputs=[
553
+ status_output, # status
554
+ pass1_out, # primary transcription
555
+ pass2_out, # specialized transcription
556
+ wer_out, # wer formatted
557
+ cer_out, # cer formatted
558
+ diff_html_box, # diff_html
559
+ char_html_box, # char_html
560
+ target_display # target_display
561
  ]
562
  )
563
 
 
571
  ---
572
  ### 🏆 Advanced Technology Stack:
573
  - Primary ASR: OpenAI Whisper Large v2 (High-performance multilingual model)
574
+ - Specialized Models: Fine-tuned language-specific models
575
+ - Tamil: vasista22/whisper-tamil-large-v2 (IIT Madras Speech Lab)
576
+ - Malayalam: thennal/whisper-medium-ml (Common Voice trained)
577
+ - English: openai/whisper-base.en (English-optimized)
578
+ - Dual Analysis: Primary + specialized model comparison
579
+ - Automatic Fallback: Ensures reliable results always
580
 
581
  ### 🔧 Technical Details:
582
+ - Metrics: WER (Word Error Rate) and CER (Character Error Rate)
583
+ - Transliteration: Harvard-Kyoto system for Indic scripts
584
+ - Analysis: Dual-model comparison for comprehensive feedback
585
+ - Languages: English, Tamil, and Malayalam
586
  """)
587
  return demo
588
 
589
+ # ---------------- LAUNCH ---------------- #
590
  if __name__ == "__main__":
591
  print("🚀 Starting Advanced Multilingual Pronunciation Trainer...")
592
+ print(f"🔧 Device: {DEVICE}")
593
+ try:
594
+ torch_ver = getattr(torch, '__version__', 'unknown')
595
+ except Exception:
596
+ torch_ver = 'unknown'
597
+ print(f"🔧 PyTorch version: {torch_ver}")
598
  print("🏆 Using High-Performance Dual-Model Approach")
599
  print("⚡ Automatic model selection with specialized fallbacks")
600
  print("📊 Advanced analysis with robust error handling")
 
601
 
602
  demo = create_interface()
603
  demo.launch(