sudhanm commited on
Commit
fa0e345
·
verified ·
1 Parent(s): 189dfd8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -53
app.py CHANGED
@@ -13,7 +13,7 @@ try:
13
  except:
14
  INDIC_OK = False
15
 
16
- # Optional HF Spaces decorator
17
  try:
18
  import spaces
19
  GPU_DECORATOR = spaces.GPU
@@ -43,16 +43,24 @@ SPECIALIZED_MODELS = {
43
  "Malayalam": "thennal/whisper-medium-ml",
44
  }
45
 
46
- # Scripts and banking
47
  SCRIPT_PATTERNS = {
48
  "Tamil": re.compile(r"[஀-௿]"),
49
  "Malayalam": re.compile(r"[ഀ-ൿ]"),
50
  "English": re.compile(r"[A-Za-z]"),
51
  }
52
  SENTENCE_BANK = {
53
- "English": ["The sun sets over the beautiful horizon.", "Hard work always pays off in the end."],
54
- "Tamil": ["இன்று நல்ல வானிலை உள்ளது.", "உழைப்பு எப்போதும் வெற்றியைத் தரும்."],
55
- "Malayalam": ["എനിക്ക് മലയാളം വളരെ ഇഷ്ടമാണ്.", "കഠിനാധ്വാനം എപ്പോഴും ഫലം നൽകും."]
 
 
 
 
 
 
 
 
 
56
  }
57
 
58
  # Model cache
@@ -90,9 +98,11 @@ def preprocess_audio(audio_path, target_sr=16000):
90
  return audio, target_sr
91
  except: return None, None
92
 
93
- JIWER_TRANSFORM = jiwer.Compose([jiwer.ToLowerCase(), jiwer.RemovePunctuation(),
94
- jiwer.RemoveMultipleSpaces(), jiwer.Strip(),
95
- jiwer.ReduceToListOfListOfWords()])
 
 
96
  def compute_wer(ref,hyp):
97
  try: return jiwer.wer(ref, hyp, truth_transform=JIWER_TRANSFORM, hypothesis_transform=JIWER_TRANSFORM)
98
  except: return 1.0
@@ -103,24 +113,45 @@ def compute_cer(ref,hyp):
103
  # ---------------- MODEL LOADERS ---------------- #
104
  @GPU_DECORATOR
105
  def load_indicwhisper():
 
 
 
 
106
  global indicwhisper_pipeline, WHISPER_JAX_AVAILABLE
107
- if indicwhisper_pipeline: return indicwhisper_pipeline
 
 
 
108
  try:
109
- from whisper_jax import FlaxWhisperPipeline; import jax.numpy as jnp
110
- indicwhisper_pipeline = FlaxWhisperPipeline(INDICWHISPER_MODEL, dtype=jnp.bfloat16, batch_size=1)
 
 
 
 
111
  WHISPER_JAX_AVAILABLE = True
112
  print("✅ JAX IndicWhisper loaded!")
113
  return indicwhisper_pipeline
114
  except Exception as e:
115
- print(f"⚠️ JAX unavailable: {e}"); WHISPER_JAX_AVAILABLE = False
 
 
 
116
  from transformers import pipeline
117
- indicwhisper_pipeline = pipeline("automatic-speech-recognition", model=INDICWHISPER_MODEL, device=DEVICE_INDEX)
 
 
 
 
 
 
118
  print("✅ Transformers IndicWhisper loaded!")
119
  return indicwhisper_pipeline
120
 
121
  @GPU_DECORATOR
122
  def load_specialized_model(language):
123
- if language in fallback_models: return fallback_models[language]
 
124
  from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
125
  name = SPECIALIZED_MODELS[language]
126
  proc = AutoProcessor.from_pretrained(name)
@@ -132,18 +163,27 @@ def load_specialized_model(language):
132
  @GPU_DECORATOR
133
  def transcribe_with_primary_model(audio_path, language):
134
  try:
135
- pl = load_indicwhisper(); lang_code = LANG_CODES.get(language, "en")
 
 
 
136
  if WHISPER_JAX_AVAILABLE:
137
- res = pl(audio_path, task="transcribe", language=lang_code)
138
- if isinstance(res, dict): return res.get("text","").strip()
139
- return str(res).strip()
140
- if hasattr(pl, "model") and hasattr(pl, "tokenizer"):
141
- try:
142
- forced_ids = pl.tokenizer.get_decoder_prompt_ids(language=lang_code, task="transcribe")
143
- pl.model.config.forced_decoder_ids = forced_ids
144
- except: pass
 
 
 
 
 
 
145
  with amp_ctx():
146
- out = pl(audio_path)
147
  if isinstance(out, dict): return (out.get("text") or "").strip()
148
  return str(out).strip()
149
  except Exception as e:
@@ -160,7 +200,9 @@ def transcribe_with_specialized_model(audio_path, language):
160
  gen_kwargs = {"inputs": feats, "max_length": 200, "num_beams": 3}
161
  if language != "English":
162
  try:
163
- forced_ids = comp["processor"].tokenizer.get_decoder_prompt_ids(LANG_CODES[language], task="transcribe")
 
 
164
  gen_kwargs["forced_decoder_ids"] = forced_ids
165
  except: pass
166
  with torch.no_grad(), amp_ctx():
@@ -178,24 +220,6 @@ def transcribe_audio(audio_path, language, use_specialized=False):
178
  return transcribe_with_primary_model(audio_path, language)
179
 
180
  # ---------------- MAIN ---------------- #
181
- @GPU_DECORATOR
182
- def compare_pronunciation(audio, lang_choice, intended):
183
- if audio is None: return ("❌ Please record audio first.","","","","","","","")
184
- if not intended.strip(): return ("❌ Please generate a sentence first.","","","","","","","")
185
- ptext = transcribe_audio(audio, lang_choice, False)
186
- stext = transcribe_audio(audio, lang_choice, True)
187
- actual = ptext if not ptext.startswith("Error:") else stext
188
- if actual.startswith("Error:"): return (f"❌ {actual}","","","","","","","")
189
- wer_val, cer_val = compute_wer(intended, actual), compute_cer(intended, actual)
190
- score, feedback = get_score(wer_val, cer_val)
191
- return (f"✅ Done - {score}\n💬 {feedback}",
192
- ptext, stext,
193
- f"{wer_val:.3f} ({(1-wer_val)*100:.1f}%)",
194
- f"{cer_val:.3f} ({(1-cer_val)*100:.1f}%)",
195
- diff_html(intended, actual),
196
- char_html(intended, actual),
197
- f"🎯 Target: {intended}")
198
-
199
  def get_score(wer, cer):
200
  c = (wer*0.7)+(cer*0.3)
201
  if c <= 0.1: return "🏆 Excellent!","Outstanding!"
@@ -204,11 +228,7 @@ def get_score(wer, cer):
204
  elif c <= 0.6: return "📚 Needs Practice","Focus on clearer pronunciation."
205
  else: return "💪 Keep Trying!","Don't give up!"
206
 
207
- def diff_html(ref,hyp): return highlight_differences(ref,hyp)
208
- def char_html(ref,hyp): return char_level_highlight(ref,hyp)
209
-
210
- # Diff functions
211
- def highlight_differences(ref,hyp):
212
  ref_w, hyp_w = ref.split(), hyp.split()
213
  sm = difflib.SequenceMatcher(None, ref_w, hyp_w)
214
  out=[]
@@ -217,13 +237,11 @@ def highlight_differences(ref,hyp):
217
  elif tag=='replace':
218
  out += [f"<span style='color:red'>{w}</span>" for w in ref_w[i1:i2]]
219
  out += [f"<span style='color:orange'>→{w}</span>" for w in hyp_w[j1:j2]]
220
- elif tag=='delete':
221
- out += [f"<span style='color:red'>{w}</span>" for w in ref_w[i1:i2]]
222
- elif tag=='insert':
223
- out += [f"<span style='color:orange'>+{w}</span>" for w in hyp_w[j1:j2]]
224
  return " ".join(out)
225
 
226
- def char_level_highlight(ref,hyp):
227
  sm = difflib.SequenceMatcher(None, list(ref), list(hyp))
228
  out=[]
229
  for tag,i1,i2,j1,j2 in sm.get_opcodes():
@@ -232,6 +250,29 @@ def char_level_highlight(ref,hyp):
232
  elif tag=='insert': out += [f"<span style='color:orange'>{c}</span>" for c in hyp[j1:j2]]
233
  return "".join(out)
234
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  # ---------------- UI ---------------- #
236
  def create_interface():
237
  with gr.Blocks() as demo:
@@ -256,6 +297,7 @@ def create_interface():
256
  lang.change(get_random_sentence, [lang], [intended])
257
  return demo
258
 
 
259
  if __name__ == "__main__":
260
  demo = create_interface()
261
  demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
 
13
  except:
14
  INDIC_OK = False
15
 
16
+ # Optional HF Spaces GPU decorator
17
  try:
18
  import spaces
19
  GPU_DECORATOR = spaces.GPU
 
43
  "Malayalam": "thennal/whisper-medium-ml",
44
  }
45
 
 
46
  SCRIPT_PATTERNS = {
47
  "Tamil": re.compile(r"[஀-௿]"),
48
  "Malayalam": re.compile(r"[ഀ-ൿ]"),
49
  "English": re.compile(r"[A-Za-z]"),
50
  }
51
  SENTENCE_BANK = {
52
+ "English": [
53
+ "The sun sets over the beautiful horizon.",
54
+ "Hard work always pays off in the end."
55
+ ],
56
+ "Tamil": [
57
+ "இன்று நல்ல வானிலை உள்ளது.",
58
+ "உழைப்பு எப்போதும் வெற்றியைத் தரும்."
59
+ ],
60
+ "Malayalam": [
61
+ "എനിക്ക് മലയാളം വളരെ ഇഷ്ടമാണ്.",
62
+ "കഠിനാധ്വാനം എപ്പോഴും ഫലം നൽകും."
63
+ ]
64
  }
65
 
66
  # Model cache
 
98
  return audio, target_sr
99
  except: return None, None
100
 
101
+ JIWER_TRANSFORM = jiwer.Compose([
102
+ jiwer.ToLowerCase(), jiwer.RemovePunctuation(),
103
+ jiwer.RemoveMultipleSpaces(), jiwer.Strip(),
104
+ jiwer.ReduceToListOfListOfWords()
105
+ ])
106
  def compute_wer(ref,hyp):
107
  try: return jiwer.wer(ref, hyp, truth_transform=JIWER_TRANSFORM, hypothesis_transform=JIWER_TRANSFORM)
108
  except: return 1.0
 
113
  # ---------------- MODEL LOADERS ---------------- #
114
  @GPU_DECORATOR
115
  def load_indicwhisper():
116
+ """
117
+ Load IndicWhisper (parthiv11/indic_whisper_nodcil) with matching config/weights.
118
+ Prefer whisper-jax if available, else Transformers pipeline.
119
+ """
120
  global indicwhisper_pipeline, WHISPER_JAX_AVAILABLE
121
+ if indicwhisper_pipeline is not None:
122
+ return indicwhisper_pipeline
123
+
124
+ # Try JAX first
125
  try:
126
+ from whisper_jax import FlaxWhisperPipeline
127
+ import jax.numpy as jnp
128
+ print(f"🔄 Loading JAX IndicWhisper: {INDICWHISPER_MODEL}")
129
+ indicwhisper_pipeline = FlaxWhisperPipeline(
130
+ INDICWHISPER_MODEL, dtype=jnp.bfloat16, batch_size=1
131
+ )
132
  WHISPER_JAX_AVAILABLE = True
133
  print("✅ JAX IndicWhisper loaded!")
134
  return indicwhisper_pipeline
135
  except Exception as e:
136
+ print(f"⚠️ JAX unavailable: {e}")
137
+ WHISPER_JAX_AVAILABLE = False
138
+
139
+ # Transformers fallback — use model+tokenizer+feature_extractor from same repo
140
  from transformers import pipeline
141
+ indicwhisper_pipeline = pipeline(
142
+ "automatic-speech-recognition",
143
+ model=INDICWHISPER_MODEL,
144
+ tokenizer=INDICWHISPER_MODEL,
145
+ feature_extractor=INDICWHISPER_MODEL,
146
+ device=DEVICE_INDEX
147
+ )
148
  print("✅ Transformers IndicWhisper loaded!")
149
  return indicwhisper_pipeline
150
 
151
  @GPU_DECORATOR
152
  def load_specialized_model(language):
153
+ if language in fallback_models:
154
+ return fallback_models[language]
155
  from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
156
  name = SPECIALIZED_MODELS[language]
157
  proc = AutoProcessor.from_pretrained(name)
 
163
  @GPU_DECORATOR
164
  def transcribe_with_primary_model(audio_path, language):
165
  try:
166
+ pipe = load_indicwhisper()
167
+ lang_code = LANG_CODES.get(language, "en")
168
+
169
+ # JAX path
170
  if WHISPER_JAX_AVAILABLE:
171
+ result = pipe(audio_path, task="transcribe", language=lang_code)
172
+ if isinstance(result, dict) and "text" in result:
173
+ return result["text"].strip()
174
+ return str(result).strip()
175
+
176
+ # Transformers path
177
+ try:
178
+ if hasattr(pipe, "model") and hasattr(pipe, "tokenizer"):
179
+ forced_ids = pipe.tokenizer.get_decoder_prompt_ids(language=lang_code, task="transcribe")
180
+ if forced_ids is not None:
181
+ pipe.model.config.forced_decoder_ids = forced_ids
182
+ except Exception as e:
183
+ print(f"⚠️ Lang forcing failed: {e}")
184
+
185
  with amp_ctx():
186
+ out = pipe(audio_path)
187
  if isinstance(out, dict): return (out.get("text") or "").strip()
188
  return str(out).strip()
189
  except Exception as e:
 
200
  gen_kwargs = {"inputs": feats, "max_length": 200, "num_beams": 3}
201
  if language != "English":
202
  try:
203
+ forced_ids = comp["processor"].tokenizer.get_decoder_prompt_ids(
204
+ LANG_CODES[language], task="transcribe"
205
+ )
206
  gen_kwargs["forced_decoder_ids"] = forced_ids
207
  except: pass
208
  with torch.no_grad(), amp_ctx():
 
220
  return transcribe_with_primary_model(audio_path, language)
221
 
222
  # ---------------- MAIN ---------------- #
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  def get_score(wer, cer):
224
  c = (wer*0.7)+(cer*0.3)
225
  if c <= 0.1: return "🏆 Excellent!","Outstanding!"
 
228
  elif c <= 0.6: return "📚 Needs Practice","Focus on clearer pronunciation."
229
  else: return "💪 Keep Trying!","Don't give up!"
230
 
231
+ def diff_html(ref,hyp):
 
 
 
 
232
  ref_w, hyp_w = ref.split(), hyp.split()
233
  sm = difflib.SequenceMatcher(None, ref_w, hyp_w)
234
  out=[]
 
237
  elif tag=='replace':
238
  out += [f"<span style='color:red'>{w}</span>" for w in ref_w[i1:i2]]
239
  out += [f"<span style='color:orange'>→{w}</span>" for w in hyp_w[j1:j2]]
240
+ elif tag=='delete': out += [f"<span style='color:red'>{w}</span>" for w in ref_w[i1:i2]]
241
+ elif tag=='insert': out += [f"<span style='color:orange'>+{w}</span>" for w in hyp_w[j1:j2]]
 
 
242
  return " ".join(out)
243
 
244
+ def char_html(ref,hyp):
245
  sm = difflib.SequenceMatcher(None, list(ref), list(hyp))
246
  out=[]
247
  for tag,i1,i2,j1,j2 in sm.get_opcodes():
 
250
  elif tag=='insert': out += [f"<span style='color:orange'>{c}</span>" for c in hyp[j1:j2]]
251
  return "".join(out)
252
 
253
+ @GPU_DECORATOR
254
+ def compare_pronunciation(audio, lang_choice, intended):
255
+ if audio is None:
256
+ return ("❌ Please record first","","","","","","","")
257
+ if not intended.strip():
258
+ return ("❌ Please generate a sentence first","","","","","","","")
259
+
260
+ ptext = transcribe_audio(audio, lang_choice, False)
261
+ stext = transcribe_audio(audio, lang_choice, True)
262
+ actual = ptext if not ptext.startswith("Error:") else stext
263
+ if actual.startswith("Error:"):
264
+ return (f"❌ {actual}","","","","","","","")
265
+
266
+ wer_val, cer_val = compute_wer(intended, actual), compute_cer(intended, actual)
267
+ score, feedback = get_score(wer_val, cer_val)
268
+ return (f"✅ Done - {score}\n💬 {feedback}",
269
+ ptext, stext,
270
+ f"{wer_val:.3f} ({(1-wer_val)*100:.1f}%)",
271
+ f"{cer_val:.3f} ({(1-cer_val)*100:.1f}%)",
272
+ diff_html(intended, actual),
273
+ char_html(intended, actual),
274
+ f"🎯 Target: {intended}")
275
+
276
  # ---------------- UI ---------------- #
277
  def create_interface():
278
  with gr.Blocks() as demo:
 
297
  lang.change(get_random_sentence, [lang], [intended])
298
  return demo
299
 
300
+ # ---------------- LAUNCH ---------------- #
301
  if __name__ == "__main__":
302
  demo = create_interface()
303
  demo.launch(server_name="0.0.0.0", server_port=7860, share=True)