sudhanm committed on
Commit 189dfd8 · verified · 1 Parent(s): 60fa434

Update app.py

Files changed (1):
  1. app.py +139 -250

app.py CHANGED
@@ -1,25 +1,19 @@
  import gradio as gr
- import random
- import difflib
- import re
- import warnings
  import torch
  import numpy as np
- import librosa
- import soundfile as sf
  import jiwer

- # Optional: Indic transliteration
  try:
      from indic_transliteration import sanscript
      from indic_transliteration.sanscript import transliterate
      INDIC_OK = True
  except:
      INDIC_OK = False
-     sanscript = None
-     transliterate = None

- # Optional: HF Spaces GPU decorator
  try:
      import spaces
      GPU_DECORATOR = spaces.GPU
@@ -34,60 +28,31 @@ warnings.filterwarnings("ignore")
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
  DEVICE_INDEX = 0 if DEVICE == "cuda" else -1
  DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

  print(f"🔧 Using device: {DEVICE}")

- LANG_CODES = {
-     "English": "en",
-     "Tamil": "ta",
-     "Malayalam": "ml",
- }

- # AI4Bharat IndicWhisper community port
  INDICWHISPER_MODEL = "parthiv11/indic_whisper_nodcil"

  SPECIALIZED_MODELS = {
      "English": "openai/whisper-base.en",
      "Tamil": "vasista22/whisper-tamil-large-v2",
      "Malayalam": "thennal/whisper-medium-ml",
  }

  SCRIPT_PATTERNS = {
      "Tamil": re.compile(r"[\u0B80-\u0BFF]"),
      "Malayalam": re.compile(r"[\u0D00-\u0D7F]"),
      "English": re.compile(r"[A-Za-z]"),
  }
-
  SENTENCE_BANK = {
-     "English": [
-         "The sun sets over the beautiful horizon.",
-         "Learning new languages opens many doors.",
-         "I enjoy reading books in the evening.",
-         "Technology has changed our daily lives.",
-         "Music brings people together across cultures.",
-         "Education is the key to a bright future.",
-         "The flowers bloom beautifully in spring.",
-         "Hard work always pays off in the end.",
-     ],
-     "Tamil": [
-         "இன்று நல்ல வானிலை உள்ளது.",
-         "நான் தமிழ் கற்றுக்கொண்டு இருக்கிறேன்.",
-         "எனக்கு புத்தகம் படிக்க விருப்பம்.",
-         "தமிழ் மொழி மிகவும் அழகானது.",
-         "குடும்பத்துடன் நேரம் செலவிடுவது முக்கியம்.",
-         "கல்வி நமது எதிர்காலத்தின் திறவுகோல்.",
-         "பறவைகள் காலையில் இனிமையாக பாடுகின்றன.",
-         "உழைப்பு எப்போதும் வெற்றியைத் தரும்.",
-     ],
-     "Malayalam": [
-         "എനിക്ക് മലയാളം വളരെ ഇഷ്ടമാണ്.",
-         "ഇന്ന് മഴപെയ്യുന്നു.",
-         "ഞാൻ പുസ്തകം വായിക്കുന്നു.",
-         "കേരളത്തിന്റെ പ്രകൃതി സുന്ദരമാണ്.",
-         "വിദ്യാഭ്യാസം ജീവിതത്തിൽ പ്രധാനമാണ്.",
-         "സംഗീതം മനസ്സിന് സന്തോഷം നൽകുന്നു.",
-         "കുടുംബസമയം വളരെ വിലപ്പെട്ടതാണ്.",
-         "കഠിനാധ്വാനം എപ്പോഴും ഫലം നൽകും.",
-     ],
  }

  # Model cache
@@ -100,180 +65,86 @@ def get_random_sentence(language_choice):
      return random.choice(SENTENCE_BANK[language_choice])

  def is_script(text, lang_name):
-     pattern = SCRIPT_PATTERNS.get(lang_name)
-     if not pattern:
-         return True
-     return bool(pattern.search(text or ""))

  def transliterate_to_hk(text, lang_choice):
      if not INDIC_OK:
          return text
-     mapping = {
-         "Tamil": sanscript.TAMIL,
-         "Malayalam": sanscript.MALAYALAM,
-         "English": None
-     }
      script = mapping.get(lang_choice)
      if script and is_script(text, lang_choice):
-         try:
-             return transliterate(text, script, sanscript.HK)
-         except:
-             return text
      return text

  def preprocess_audio(audio_path, target_sr=16000):
      try:
          audio, sr = librosa.load(audio_path, sr=target_sr, mono=True)
-         if audio is None or len(audio) == 0:
-             return None, None
          audio = audio.astype(np.float32)
-         max_abs = np.max(np.abs(audio))
-         if max_abs > 0:
-             audio /= max_abs
          audio, _ = librosa.effects.trim(audio, top_db=20)
-         if len(audio) < target_sr * 0.1:
-             return None, None
          return audio, target_sr
-     except Exception as e:
-         print(f"Audio preprocessing error: {e}")
-         return None, None
-
- # Normalization for WER
- JIWER_TRANSFORM = jiwer.Compose([
-     jiwer.ToLowerCase(),
-     jiwer.RemovePunctuation(),
-     jiwer.RemoveMultipleSpaces(),
-     jiwer.Strip(),
-     jiwer.ReduceToListOfListOfWords(),
- ])

- def compute_wer(ref, hyp):
-     try:
-         return jiwer.wer(ref, hyp, truth_transform=JIWER_TRANSFORM, hypothesis_transform=JIWER_TRANSFORM)
-     except:
-         return 1.0
-
- def compute_cer(ref, hyp):
-     try:
-         return jiwer.cer(ref, hyp)
-     except:
-         return 1.0
-
- def highlight_differences(ref, hyp):
-     if not ref.strip() or not hyp.strip():
-         return "No text to compare"
-     ref_words = ref.strip().split()
-     hyp_words = hyp.strip().split()
-     sm = difflib.SequenceMatcher(None, ref_words, hyp_words)
-     out_html = []
-     for tag, i1, i2, j1, j2 in sm.get_opcodes():
-         if tag == 'equal':
-             out_html.extend([f"<span style='color:green; background-color:#e8f5e8;'>{w}</span>" for w in ref_words[i1:i2]])
-         elif tag == 'replace':
-             out_html.extend([f"<span style='color:red; text-decoration:line-through;'>{w}</span>" for w in ref_words[i1:i2]])
-             out_html.extend([f"<span style='color:orange;'>→{w}</span>" for w in hyp_words[j1:j2]])
-         elif tag == 'delete':
-             out_html.extend([f"<span style='color:red; text-decoration:line-through;'>{w}</span>" for w in ref_words[i1:i2]])
-         elif tag == 'insert':
-             out_html.extend([f"<span style='color:orange;'>+{w}</span>" for w in hyp_words[j1:j2]])
-     return " ".join(out_html)
-
- def char_level_highlight(ref, hyp):
-     if not ref.strip() or not hyp.strip():
-         return "No text to compare"
-     sm = difflib.SequenceMatcher(None, list(ref), list(hyp))
-     out = []
-     for tag, i1, i2, j1, j2 in sm.get_opcodes():
-         if tag == 'equal':
-             out.extend([f"<span style='color:green;'>{c}</span>" for c in ref[i1:i2]])
-         elif tag in ('replace', 'delete'):
-             out.extend([f"<span style='color:red;'>{c}</span>" for c in ref[i1:i2]])
-         elif tag == 'insert':
-             out.extend([f"<span style='color:orange;'>{c}</span>" for c in hyp[j1:j2]])
-     return "".join(out)
-
- def get_pronunciation_score(wer_val, cer_val):
-     combined = (wer_val * 0.7) + (cer_val * 0.3)
-     if combined <= 0.1:
-         return "🏆 Excellent! (90%+)", "Your pronunciation is outstanding!"
-     elif combined <= 0.2:
-         return "🎉 Very Good! (80-90%)", "Great pronunciation with minor areas for improvement."
-     elif combined <= 0.4:
-         return "👍 Good! (60-80%)", "Good effort! Keep practicing."
-     elif combined <= 0.6:
-         return "📚 Needs Practice (40-60%)", "Focus on clearer pronunciation."
-     else:
-         return "💪 Keep Trying! (<40%)", "Don't give up!"
- # ---------------- LOADERS ---------------- #
  @GPU_DECORATOR
  def load_indicwhisper():
      global indicwhisper_pipeline, WHISPER_JAX_AVAILABLE
-     if indicwhisper_pipeline is not None:
-         return indicwhisper_pipeline
-     # Try JAX first
      try:
-         from whisper_jax import FlaxWhisperPipeline
-         import jax.numpy as jnp
-         print(f"🔄 Loading JAX IndicWhisper: {INDICWHISPER_MODEL}")
-         indicwhisper_pipeline = FlaxWhisperPipeline(
-             INDICWHISPER_MODEL, dtype=jnp.bfloat16, batch_size=1
-         )
          WHISPER_JAX_AVAILABLE = True
-         print("✅ JAX Loaded!")
-         return indicwhisper_pipeline
-     except Exception as e:
-         print(f"⚠️ JAX unavailable: {e}")
-         WHISPER_JAX_AVAILABLE = False
-     # Fallback to Transformers
-     try:
-         from transformers import pipeline
-         indicwhisper_pipeline = pipeline(
-             "automatic-speech-recognition",
-             model=INDICWHISPER_MODEL,
-             device=DEVICE_INDEX
-         )
-         print("✅ Transformers IndicWhisper loaded!")
          return indicwhisper_pipeline
      except Exception as e:
-         print(f"❌ Failed to load IndicWhisper: {e}")
-         raise

  @GPU_DECORATOR
  def load_specialized_model(language):
-     if language in fallback_models:
-         return fallback_models[language]
      from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
-     model_name = SPECIALIZED_MODELS[language]
-     processor = AutoProcessor.from_pretrained(model_name)
-     model = AutoModelForSpeechSeq2Seq.from_pretrained(
-         model_name, torch_dtype=DTYPE,
-         low_cpu_mem_usage=True
-     ).to(DEVICE)
-     fallback_models[language] = {"processor": processor, "model": model}
      return fallback_models[language]

  # ---------------- TRANSCRIBE ---------------- #
  @GPU_DECORATOR
  def transcribe_with_primary_model(audio_path, language):
      try:
-         pl = load_indicwhisper()
-         lang_code = LANG_CODES.get(language, "en")
-         # JAX
          if WHISPER_JAX_AVAILABLE:
-             result = pl(audio_path, task="transcribe", language=lang_code)
-             if isinstance(result, dict) and "text" in result:
-                 return result["text"].strip()
-             return str(result).strip()
-         # Transformers
          if hasattr(pl, "model") and hasattr(pl, "tokenizer"):
              try:
                  forced_ids = pl.tokenizer.get_decoder_prompt_ids(language=lang_code, task="transcribe")
                  pl.model.config.forced_decoder_ids = forced_ids
              except: pass
-         out = pl(audio_path)
-         if isinstance(out, dict) and 'text' in out:
-             return out['text'].strip()
          return str(out).strip()
      except Exception as e:
          return f"Error: {str(e)}"
@@ -281,92 +152,110 @@ def transcribe_with_primary_model(audio_path, language):
  @GPU_DECORATOR
  def transcribe_with_specialized_model(audio_path, language):
      try:
-         c = load_specialized_model(language)
          audio, sr = preprocess_audio(audio_path)
-         if audio is None:
-             return "Error: Audio too short"
-         inputs = c["processor"](audio, sampling_rate=sr, return_tensors="pt")
-         input_features = inputs.input_features.to(DEVICE)
-         generate_kwargs = {"inputs": input_features, "max_length": 200, "num_beams": 3}
          if language != "English":
              try:
-                 forced_ids = c["processor"].tokenizer.get_decoder_prompt_ids(
-                     language=LANG_CODES[language], task="transcribe"
-                 )
-                 generate_kwargs["forced_decoder_ids"] = forced_ids
              except: pass
-         with torch.no_grad():
-             ids = c["model"].generate(**generate_kwargs)
-         text = c["processor"].batch_decode(ids, skip_special_tokens=True)[0]
          return text.strip()
      except Exception as e:
          return f"Error: {str(e)}"

  @GPU_DECORATOR
  def transcribe_audio(audio_path, language, use_specialized=False):
-     try:
-         if use_specialized:
-             return transcribe_with_specialized_model(audio_path, language)
-         else:
-             return transcribe_with_primary_model(audio_path, language)
-     except:
-         if not use_specialized:
-             return transcribe_audio(audio_path, language, use_specialized=True)
-         return "Error"

- # ---------------- MAIN FUNCTION ---------------- #
  @GPU_DECORATOR
- def compare_pronunciation(audio, language_choice, intended_sentence):
-     if audio is None:
-         return ("❌ Please record audio first.", "", "", "", "", "", "", "")
-     if not intended_sentence.strip():
-         return ("❌ Please generate a practice sentence first.", "", "", "", "", "", "", "")
-     primary_text = transcribe_audio(audio, language_choice, use_specialized=False)
-     specialized_text = transcribe_audio(audio, language_choice, use_specialized=True)
-     actual_text = primary_text if not primary_text.startswith("Error:") else specialized_text
-     if actual_text.startswith("Error:"):
-         return (f"❌ {actual_text}", "", "", "", "", "", "", "")
-     wer_val = compute_wer(intended_sentence, actual_text)
-     cer_val = compute_cer(intended_sentence, actual_text)
-     score_text, feedback = get_pronunciation_score(wer_val, cer_val)
-     diff_html = highlight_differences(intended_sentence, actual_text)
-     char_html = char_level_highlight(intended_sentence, actual_text)
-     return (
-         f"✅ Analysis Complete - {score_text}\n💬 {feedback}",
-         primary_text, specialized_text,
-         f"{wer_val:.3f} ({(1-wer_val)*100:.1f}% word accuracy)",
-         f"{cer_val:.3f} ({(1-cer_val)*100:.1f}% char accuracy)",
-         diff_html, char_html,
-         f"🎯 Target: {intended_sentence}"
-     )
  # ---------------- UI ---------------- #
  def create_interface():
-     with gr.Blocks(title="🎙️ IndicWhisper Pronunciation Trainer") as demo:
-         gr.Markdown("# 🎙️ IndicWhisper-based Pronunciation Trainer")
-         with gr.Row():
-             lang_choice = gr.Dropdown(choices=list(LANG_CODES.keys()), value="Tamil", label="🌍 Language")
-             gen_btn = gr.Button("🎲 Generate Sentence")
-         intended_display = gr.Textbox(label="📝 Practice Sentence", interactive=False, lines=3)
-         audio_input = gr.Audio(sources=["microphone","upload"], type="filepath", label="🎤 Record")
-         analyze_btn = gr.Button("🔍 Analyze")
-         status_output = gr.Textbox(label="📊 Results", interactive=False, lines=4)
          with gr.Row():
-             pass1_out = gr.Textbox(label="🏆 Primary (IndicWhisper)", interactive=False)
-             pass2_out = gr.Textbox(label="🔧 Specialized", interactive=False)
-         wer_out = gr.Textbox(label="📈 Word Accuracy", interactive=False)
-         cer_out = gr.Textbox(label="📊 Char Accuracy", interactive=False)
-         diff_html_box = gr.HTML(label="Word-Level Analysis")
-         char_html_box = gr.HTML(label="Character-Level Analysis")
-         target_display = gr.Textbox(label="🎯 Reference", interactive=False, visible=False)
-         gen_btn.click(get_random_sentence, [lang_choice], [intended_display])
-         analyze_btn.click(compare_pronunciation,
-                           [audio_input, lang_choice, intended_display],
-                           [status_output, pass1_out, pass2_out, wer_out, cer_out, diff_html_box, char_html_box, target_display])
-         lang_choice.change(get_random_sentence, [lang_choice], [intended_display])
      return demo

- # ---------------- LAUNCH ---------------- #
  if __name__ == "__main__":
      demo = create_interface()
      demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
 
  import gradio as gr
+ import random, difflib, re, warnings, contextlib
  import torch
  import numpy as np
+ import librosa, soundfile as sf
  import jiwer

+ # Optional transliteration
  try:
      from indic_transliteration import sanscript
      from indic_transliteration.sanscript import transliterate
      INDIC_OK = True
  except:
      INDIC_OK = False

+ # Optional HF Spaces decorator
  try:
      import spaces
      GPU_DECORATOR = spaces.GPU
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
  DEVICE_INDEX = 0 if DEVICE == "cuda" else -1
  DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
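+ # Mixed precision: CUDA autocast when available; contextlib.nullcontext is a no-op on CPU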
+ amp_ctx = torch.cuda.amp.autocast if DEVICE == "cuda" else contextlib.nullcontext
  print(f"🔧 Using device: {DEVICE}")

+ LANG_CODES = {"English": "en", "Tamil": "ta", "Malayalam": "ml"}
+ # Primary: IndicWhisper
  INDICWHISPER_MODEL = "parthiv11/indic_whisper_nodcil"

+ # Specialized fallback models
  SPECIALIZED_MODELS = {
      "English": "openai/whisper-base.en",
      "Tamil": "vasista22/whisper-tamil-large-v2",
      "Malayalam": "thennal/whisper-medium-ml",
  }

+ # Script detection and sentence bank
  SCRIPT_PATTERNS = {
      "Tamil": re.compile(r"[\u0B80-\u0BFF]"),
      "Malayalam": re.compile(r"[\u0D00-\u0D7F]"),
      "English": re.compile(r"[A-Za-z]"),
  }
  SENTENCE_BANK = {
+     "English": ["The sun sets over the beautiful horizon.", "Hard work always pays off in the end."],
+     "Tamil": ["இன்று நல்ல வானிலை உள்ளது.", "உழைப்பு எப்போதும் வெற்றியைத் தரும்."],
+     "Malayalam": ["എനിക്ക് മലയാളം വളരെ ഇഷ്ടമാണ്.", "കഠിനാധ്വാനം എപ്പോഴും ഫലം നൽകും."]
  }

  # Model cache
 
      return random.choice(SENTENCE_BANK[language_choice])

  def is_script(text, lang_name):
+     p = SCRIPT_PATTERNS.get(lang_name)
+     return not p or bool(p.search(text or ""))

  def transliterate_to_hk(text, lang_choice):
      if not INDIC_OK:
          return text
+     mapping = {"Tamil": sanscript.TAMIL, "Malayalam": sanscript.MALAYALAM, "English": None}
      script = mapping.get(lang_choice)
      if script and is_script(text, lang_choice):
+         try: return transliterate(text, script, sanscript.HK)
+         except: return text
      return text
  def preprocess_audio(audio_path, target_sr=16000):
      try:
          audio, sr = librosa.load(audio_path, sr=target_sr, mono=True)
+         if audio is None or len(audio) == 0: return None, None
          audio = audio.astype(np.float32)
+         m = np.max(np.abs(audio))
+         if m > 0: audio /= m
          audio, _ = librosa.effects.trim(audio, top_db=20)
+         if len(audio) < int(target_sr*0.1): return None, None
          return audio, target_sr
+     except: return None, None
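
+ # Text normalization shared by the WER metric: lowercase, no punctuation, collapsed whitespace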
+ JIWER_TRANSFORM = jiwer.Compose([jiwer.ToLowerCase(), jiwer.RemovePunctuation(),
+                                  jiwer.RemoveMultipleSpaces(), jiwer.Strip(),
+                                  jiwer.ReduceToListOfListOfWords()])
+ def compute_wer(ref, hyp):
+     try: return jiwer.wer(ref, hyp, truth_transform=JIWER_TRANSFORM, hypothesis_transform=JIWER_TRANSFORM)
+     except: return 1.0
+ def compute_cer(ref, hyp):
+     try: return jiwer.cer(ref, hyp)
+     except: return 1.0
+ # ---------------- MODEL LOADERS ---------------- #
  @GPU_DECORATOR
  def load_indicwhisper():
      global indicwhisper_pipeline, WHISPER_JAX_AVAILABLE
+     if indicwhisper_pipeline: return indicwhisper_pipeline
      try:
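+         # Prefer the whisper-jax pipeline for speed; the except branch falls back to transformers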
+         from whisper_jax import FlaxWhisperPipeline; import jax.numpy as jnp
+         indicwhisper_pipeline = FlaxWhisperPipeline(INDICWHISPER_MODEL, dtype=jnp.bfloat16, batch_size=1)
          WHISPER_JAX_AVAILABLE = True
+         print("✅ JAX IndicWhisper loaded!")
          return indicwhisper_pipeline
      except Exception as e:
+         print(f"⚠️ JAX unavailable: {e}"); WHISPER_JAX_AVAILABLE = False
+         from transformers import pipeline
+         indicwhisper_pipeline = pipeline("automatic-speech-recognition", model=INDICWHISPER_MODEL, device=DEVICE_INDEX)
+         print("✅ Transformers IndicWhisper loaded!")
+         return indicwhisper_pipeline

  @GPU_DECORATOR
  def load_specialized_model(language):
+     if language in fallback_models: return fallback_models[language]
      from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
+     name = SPECIALIZED_MODELS[language]
+     proc = AutoProcessor.from_pretrained(name)
+     model = AutoModelForSpeechSeq2Seq.from_pretrained(name, torch_dtype=DTYPE).to(DEVICE)
+     fallback_models[language] = {"processor": proc, "model": model}
      return fallback_models[language]

  # ---------------- TRANSCRIBE ---------------- #
  @GPU_DECORATOR
  def transcribe_with_primary_model(audio_path, language):
      try:
+         pl = load_indicwhisper(); lang_code = LANG_CODES.get(language, "en")
          if WHISPER_JAX_AVAILABLE:
+             res = pl(audio_path, task="transcribe", language=lang_code)
+             if isinstance(res, dict): return res.get("text","").strip()
+             return str(res).strip()
          if hasattr(pl, "model") and hasattr(pl, "tokenizer"):
              try:
                  forced_ids = pl.tokenizer.get_decoder_prompt_ids(language=lang_code, task="transcribe")
                  pl.model.config.forced_decoder_ids = forced_ids
              except: pass
+         with amp_ctx():
+             out = pl(audio_path)
+         if isinstance(out, dict): return (out.get("text") or "").strip()
          return str(out).strip()
      except Exception as e:
          return f"Error: {str(e)}"

  @GPU_DECORATOR
  def transcribe_with_specialized_model(audio_path, language):
      try:
+         comp = load_specialized_model(language)
          audio, sr = preprocess_audio(audio_path)
+         if audio is None: return "Error: Audio too short"
+         inputs = comp["processor"](audio, sampling_rate=sr, return_tensors="pt")
+         feats = inputs.input_features.to(DEVICE)
+         gen_kwargs = {"inputs": feats, "max_length": 200, "num_beams": 3}
 
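+         # Pin the decoder prompt to the target language so Whisper does not auto-detect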
          if language != "English":
              try:
+                 forced_ids = comp["processor"].tokenizer.get_decoder_prompt_ids(language=LANG_CODES[language], task="transcribe")
+                 gen_kwargs["forced_decoder_ids"] = forced_ids
              except: pass
+         with torch.no_grad(), amp_ctx():
+             ids = comp["model"].generate(**gen_kwargs)
+         text = comp["processor"].batch_decode(ids, skip_special_tokens=True)[0]
          return text.strip()
      except Exception as e:
          return f"Error: {str(e)}"

  @GPU_DECORATOR
  def transcribe_audio(audio_path, language, use_specialized=False):
+     if use_specialized:
+         return transcribe_with_specialized_model(audio_path, language)
+     else:
+         return transcribe_with_primary_model(audio_path, language)

+ # ---------------- MAIN ---------------- #
  @GPU_DECORATOR
+ def compare_pronunciation(audio, lang_choice, intended):
+     if audio is None: return ("❌ Please record audio first.","","","","","","","")
+     if not intended.strip(): return ("❌ Please generate a sentence first.","","","","","","","")
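+     # Run both recognizers; keep the primary transcript unless it errored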
+     ptext = transcribe_audio(audio, lang_choice, False)
+     stext = transcribe_audio(audio, lang_choice, True)
+     actual = ptext if not ptext.startswith("Error:") else stext
+     if actual.startswith("Error:"): return (f"❌ {actual}","","","","","","","")
+     wer_val, cer_val = compute_wer(intended, actual), compute_cer(intended, actual)
+     score, feedback = get_score(wer_val, cer_val)
+     return (f"✅ Done - {score}\n💬 {feedback}",
+             ptext, stext,
+             f"{wer_val:.3f} ({(1-wer_val)*100:.1f}%)",
+             f"{cer_val:.3f} ({(1-cer_val)*100:.1f}%)",
+             diff_html(intended, actual),
+             char_html(intended, actual),
+             f"🎯 Target: {intended}")
+
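+ # Combined score weights word errors (0.7) over character errors (0.3)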
+ def get_score(wer, cer):
+     c = (wer*0.7)+(cer*0.3)
+     if c <= 0.1: return "🏆 Excellent!", "Outstanding!"
+     elif c <= 0.2: return "🎉 Very Good!", "Minor improvements needed."
+     elif c <= 0.4: return "👍 Good!", "Keep practicing."
+     elif c <= 0.6: return "📚 Needs Practice", "Focus on clearer pronunciation."
+     else: return "💪 Keep Trying!", "Don't give up!"
+
+ def diff_html(ref, hyp): return highlight_differences(ref, hyp)
+ def char_html(ref, hyp): return char_level_highlight(ref, hyp)
+
+ # Diff functions
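+ # SequenceMatcher opcodes label aligned spans as equal, replace, delete, or insert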
+ def highlight_differences(ref, hyp):
+     ref_w, hyp_w = ref.split(), hyp.split()
+     sm = difflib.SequenceMatcher(None, ref_w, hyp_w)
+     out = []
+     for tag, i1, i2, j1, j2 in sm.get_opcodes():
+         if tag == 'equal': out += [f"<span style='color:green'>{w}</span>" for w in ref_w[i1:i2]]
+         elif tag == 'replace':
+             out += [f"<span style='color:red'>{w}</span>" for w in ref_w[i1:i2]]
+             out += [f"<span style='color:orange'>→{w}</span>" for w in hyp_w[j1:j2]]
+         elif tag == 'delete':
+             out += [f"<span style='color:red'>{w}</span>" for w in ref_w[i1:i2]]
+         elif tag == 'insert':
+             out += [f"<span style='color:orange'>+{w}</span>" for w in hyp_w[j1:j2]]
+     return " ".join(out)
+
+ def char_level_highlight(ref, hyp):
+     sm = difflib.SequenceMatcher(None, list(ref), list(hyp))
+     out = []
+     for tag, i1, i2, j1, j2 in sm.get_opcodes():
+         if tag == 'equal': out += [f"<span style='color:green'>{c}</span>" for c in ref[i1:i2]]
+         elif tag in ('replace', 'delete'): out += [f"<span style='color:red'>{c}</span>" for c in ref[i1:i2]]
+         elif tag == 'insert': out += [f"<span style='color:orange'>{c}</span>" for c in hyp[j1:j2]]
+     return "".join(out)

  # ---------------- UI ---------------- #
  def create_interface():
+     with gr.Blocks() as demo:
+         gr.Markdown("# 🎙️ IndicWhisper Pronunciation Trainer")
          with gr.Row():
+             lang = gr.Dropdown(choices=list(LANG_CODES.keys()), value="Tamil", label="Language")
+             btn = gr.Button("🎲 Generate Sentence")
+         intended = gr.Textbox(label="Practice Sentence", interactive=False, lines=3)
+         audio = gr.Audio(sources=["microphone","upload"], type="filepath", label="Record")
+         analyze = gr.Button("🔍 Analyze")
+         status = gr.Textbox(label="Results", interactive=False, lines=4)
+         pass1 = gr.Textbox(label="Primary (IndicWhisper)")
+         pass2 = gr.Textbox(label="Specialized")
+         wer = gr.Textbox(label="Word Accuracy")
+         cer = gr.Textbox(label="Char Accuracy")
+         diff = gr.HTML(label="Word Diff")
+         chars = gr.HTML(label="Char Diff")
+         target = gr.Textbox(label="Reference", visible=False)
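+         # Wiring: generate a sentence on click or language change; analysis fills all outputs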
+         btn.click(get_random_sentence, [lang], [intended])
+         analyze.click(compare_pronunciation, [audio, lang, intended],
+                       [status, pass1, pass2, wer, cer, diff, chars, target])
+         lang.change(get_random_sentence, [lang], [intended])
      return demo

  if __name__ == "__main__":
      demo = create_interface()
      demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
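
A quick, hypothetical smoke test for the new scoring path (not part of the commit); it is a sketch that assumes app.py and its dependencies are importable locally:

# smoke_test.py (illustrative only; "app" is this Space's app.py)
from app import compute_wer, compute_cer, get_score

ref = "The sun sets over the beautiful horizon."
hyp = "The sun set over the beautiful horizon."
wer_val, cer_val = compute_wer(ref, hyp), compute_cer(ref, hyp)
print(wer_val, cer_val)             # roughly 0.143 and 0.025 for this pair
print(get_score(wer_val, cer_val))  # expected: ("🎉 Very Good!", "Minor improvements needed.")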