Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -13,7 +13,7 @@ try:
|
|
13 |
except:
|
14 |
INDIC_OK = False
|
15 |
|
16 |
-
# Optional HF Spaces decorator
|
17 |
try:
|
18 |
import spaces
|
19 |
GPU_DECORATOR = spaces.GPU
|
@@ -43,16 +43,24 @@ SPECIALIZED_MODELS = {
|
|
43 |
"Malayalam": "thennal/whisper-medium-ml",
|
44 |
}
|
45 |
|
46 |
-
# Scripts and banking
|
47 |
SCRIPT_PATTERNS = {
|
48 |
"Tamil": re.compile(r"[-]"),
|
49 |
"Malayalam": re.compile(r"[ഀ-ൿ]"),
|
50 |
"English": re.compile(r"[A-Za-z]"),
|
51 |
}
|
52 |
SENTENCE_BANK = {
|
53 |
-
"English": [
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
}
|
57 |
|
58 |
# Model cache
|
@@ -90,9 +98,11 @@ def preprocess_audio(audio_path, target_sr=16000):
|
|
90 |
return audio, target_sr
|
91 |
except: return None, None
|
92 |
|
93 |
-
JIWER_TRANSFORM = jiwer.Compose([
|
94 |
-
|
95 |
-
|
|
|
|
|
96 |
def compute_wer(ref,hyp):
|
97 |
try: return jiwer.wer(ref, hyp, truth_transform=JIWER_TRANSFORM, hypothesis_transform=JIWER_TRANSFORM)
|
98 |
except: return 1.0
|
@@ -103,24 +113,45 @@ def compute_cer(ref,hyp):
|
|
103 |
# ---------------- MODEL LOADERS ---------------- #
|
104 |
@GPU_DECORATOR
|
105 |
def load_indicwhisper():
|
|
|
|
|
|
|
|
|
106 |
global indicwhisper_pipeline, WHISPER_JAX_AVAILABLE
|
107 |
-
if indicwhisper_pipeline
|
|
|
|
|
|
|
108 |
try:
|
109 |
-
from whisper_jax import FlaxWhisperPipeline
|
110 |
-
|
|
|
|
|
|
|
|
|
111 |
WHISPER_JAX_AVAILABLE = True
|
112 |
print("✅ JAX IndicWhisper loaded!")
|
113 |
return indicwhisper_pipeline
|
114 |
except Exception as e:
|
115 |
-
print(f"⚠️ JAX unavailable: {e}")
|
|
|
|
|
|
|
116 |
from transformers import pipeline
|
117 |
-
indicwhisper_pipeline = pipeline(
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
print("✅ Transformers IndicWhisper loaded!")
|
119 |
return indicwhisper_pipeline
|
120 |
|
121 |
@GPU_DECORATOR
|
122 |
def load_specialized_model(language):
|
123 |
-
if language in fallback_models:
|
|
|
124 |
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
|
125 |
name = SPECIALIZED_MODELS[language]
|
126 |
proc = AutoProcessor.from_pretrained(name)
|
@@ -132,18 +163,27 @@ def load_specialized_model(language):
|
|
132 |
@GPU_DECORATOR
|
133 |
def transcribe_with_primary_model(audio_path, language):
|
134 |
try:
|
135 |
-
|
|
|
|
|
|
|
136 |
if WHISPER_JAX_AVAILABLE:
|
137 |
-
|
138 |
-
if isinstance(
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
with amp_ctx():
|
146 |
-
out =
|
147 |
if isinstance(out, dict): return (out.get("text") or "").strip()
|
148 |
return str(out).strip()
|
149 |
except Exception as e:
|
@@ -160,7 +200,9 @@ def transcribe_with_specialized_model(audio_path, language):
|
|
160 |
gen_kwargs = {"inputs": feats, "max_length": 200, "num_beams": 3}
|
161 |
if language != "English":
|
162 |
try:
|
163 |
-
forced_ids = comp["processor"].tokenizer.get_decoder_prompt_ids(
|
|
|
|
|
164 |
gen_kwargs["forced_decoder_ids"] = forced_ids
|
165 |
except: pass
|
166 |
with torch.no_grad(), amp_ctx():
|
@@ -178,24 +220,6 @@ def transcribe_audio(audio_path, language, use_specialized=False):
|
|
178 |
return transcribe_with_primary_model(audio_path, language)
|
179 |
|
180 |
# ---------------- MAIN ---------------- #
|
181 |
-
@GPU_DECORATOR
|
182 |
-
def compare_pronunciation(audio, lang_choice, intended):
|
183 |
-
if audio is None: return ("❌ Please record audio first.","","","","","","","")
|
184 |
-
if not intended.strip(): return ("❌ Please generate a sentence first.","","","","","","","")
|
185 |
-
ptext = transcribe_audio(audio, lang_choice, False)
|
186 |
-
stext = transcribe_audio(audio, lang_choice, True)
|
187 |
-
actual = ptext if not ptext.startswith("Error:") else stext
|
188 |
-
if actual.startswith("Error:"): return (f"❌ {actual}","","","","","","","")
|
189 |
-
wer_val, cer_val = compute_wer(intended, actual), compute_cer(intended, actual)
|
190 |
-
score, feedback = get_score(wer_val, cer_val)
|
191 |
-
return (f"✅ Done - {score}\n💬 {feedback}",
|
192 |
-
ptext, stext,
|
193 |
-
f"{wer_val:.3f} ({(1-wer_val)*100:.1f}%)",
|
194 |
-
f"{cer_val:.3f} ({(1-cer_val)*100:.1f}%)",
|
195 |
-
diff_html(intended, actual),
|
196 |
-
char_html(intended, actual),
|
197 |
-
f"🎯 Target: {intended}")
|
198 |
-
|
199 |
def get_score(wer, cer):
|
200 |
c = (wer*0.7)+(cer*0.3)
|
201 |
if c <= 0.1: return "🏆 Excellent!","Outstanding!"
|
@@ -204,11 +228,7 @@ def get_score(wer, cer):
|
|
204 |
elif c <= 0.6: return "📚 Needs Practice","Focus on clearer pronunciation."
|
205 |
else: return "💪 Keep Trying!","Don't give up!"
|
206 |
|
207 |
-
def diff_html(ref,hyp):
|
208 |
-
def char_html(ref,hyp): return char_level_highlight(ref,hyp)
|
209 |
-
|
210 |
-
# Diff functions
|
211 |
-
def highlight_differences(ref,hyp):
|
212 |
ref_w, hyp_w = ref.split(), hyp.split()
|
213 |
sm = difflib.SequenceMatcher(None, ref_w, hyp_w)
|
214 |
out=[]
|
@@ -217,13 +237,11 @@ def highlight_differences(ref,hyp):
|
|
217 |
elif tag=='replace':
|
218 |
out += [f"<span style='color:red'>{w}</span>" for w in ref_w[i1:i2]]
|
219 |
out += [f"<span style='color:orange'>→{w}</span>" for w in hyp_w[j1:j2]]
|
220 |
-
elif tag=='delete':
|
221 |
-
|
222 |
-
elif tag=='insert':
|
223 |
-
out += [f"<span style='color:orange'>+{w}</span>" for w in hyp_w[j1:j2]]
|
224 |
return " ".join(out)
|
225 |
|
226 |
-
def
|
227 |
sm = difflib.SequenceMatcher(None, list(ref), list(hyp))
|
228 |
out=[]
|
229 |
for tag,i1,i2,j1,j2 in sm.get_opcodes():
|
@@ -232,6 +250,29 @@ def char_level_highlight(ref,hyp):
|
|
232 |
elif tag=='insert': out += [f"<span style='color:orange'>{c}</span>" for c in hyp[j1:j2]]
|
233 |
return "".join(out)
|
234 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
235 |
# ---------------- UI ---------------- #
|
236 |
def create_interface():
|
237 |
with gr.Blocks() as demo:
|
@@ -256,6 +297,7 @@ def create_interface():
|
|
256 |
lang.change(get_random_sentence, [lang], [intended])
|
257 |
return demo
|
258 |
|
|
|
259 |
if __name__ == "__main__":
|
260 |
demo = create_interface()
|
261 |
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
|
|
|
13 |
except:
|
14 |
INDIC_OK = False
|
15 |
|
16 |
+
# Optional HF Spaces GPU decorator
|
17 |
try:
|
18 |
import spaces
|
19 |
GPU_DECORATOR = spaces.GPU
|
|
|
43 |
"Malayalam": "thennal/whisper-medium-ml",
|
44 |
}
|
45 |
|
|
|
46 |
SCRIPT_PATTERNS = {
|
47 |
"Tamil": re.compile(r"[-]"),
|
48 |
"Malayalam": re.compile(r"[ഀ-ൿ]"),
|
49 |
"English": re.compile(r"[A-Za-z]"),
|
50 |
}
|
51 |
SENTENCE_BANK = {
|
52 |
+
"English": [
|
53 |
+
"The sun sets over the beautiful horizon.",
|
54 |
+
"Hard work always pays off in the end."
|
55 |
+
],
|
56 |
+
"Tamil": [
|
57 |
+
"இன்று நல்ல வானிலை உள்ளது.",
|
58 |
+
"உழைப்பு எப்போதும் வெற்றியைத் தரும்."
|
59 |
+
],
|
60 |
+
"Malayalam": [
|
61 |
+
"എനിക്ക് മലയാളം വളരെ ഇഷ്ടമാണ്.",
|
62 |
+
"കഠിനാധ്വാനം എപ്പോഴും ഫലം നൽകും."
|
63 |
+
]
|
64 |
}
|
65 |
|
66 |
# Model cache
|
|
|
98 |
return audio, target_sr
|
99 |
except: return None, None
|
100 |
|
101 |
+
JIWER_TRANSFORM = jiwer.Compose([
|
102 |
+
jiwer.ToLowerCase(), jiwer.RemovePunctuation(),
|
103 |
+
jiwer.RemoveMultipleSpaces(), jiwer.Strip(),
|
104 |
+
jiwer.ReduceToListOfListOfWords()
|
105 |
+
])
|
106 |
def compute_wer(ref,hyp):
|
107 |
try: return jiwer.wer(ref, hyp, truth_transform=JIWER_TRANSFORM, hypothesis_transform=JIWER_TRANSFORM)
|
108 |
except: return 1.0
|
|
|
113 |
# ---------------- MODEL LOADERS ---------------- #
|
114 |
@GPU_DECORATOR
|
115 |
def load_indicwhisper():
|
116 |
+
"""
|
117 |
+
Load IndicWhisper (parthiv11/indic_whisper_nodcil) with matching config/weights.
|
118 |
+
Prefer whisper-jax if available, else Transformers pipeline.
|
119 |
+
"""
|
120 |
global indicwhisper_pipeline, WHISPER_JAX_AVAILABLE
|
121 |
+
if indicwhisper_pipeline is not None:
|
122 |
+
return indicwhisper_pipeline
|
123 |
+
|
124 |
+
# Try JAX first
|
125 |
try:
|
126 |
+
from whisper_jax import FlaxWhisperPipeline
|
127 |
+
import jax.numpy as jnp
|
128 |
+
print(f"🔄 Loading JAX IndicWhisper: {INDICWHISPER_MODEL}")
|
129 |
+
indicwhisper_pipeline = FlaxWhisperPipeline(
|
130 |
+
INDICWHISPER_MODEL, dtype=jnp.bfloat16, batch_size=1
|
131 |
+
)
|
132 |
WHISPER_JAX_AVAILABLE = True
|
133 |
print("✅ JAX IndicWhisper loaded!")
|
134 |
return indicwhisper_pipeline
|
135 |
except Exception as e:
|
136 |
+
print(f"⚠️ JAX unavailable: {e}")
|
137 |
+
WHISPER_JAX_AVAILABLE = False
|
138 |
+
|
139 |
+
# Transformers fallback — use model+tokenizer+feature_extractor from same repo
|
140 |
from transformers import pipeline
|
141 |
+
indicwhisper_pipeline = pipeline(
|
142 |
+
"automatic-speech-recognition",
|
143 |
+
model=INDICWHISPER_MODEL,
|
144 |
+
tokenizer=INDICWHISPER_MODEL,
|
145 |
+
feature_extractor=INDICWHISPER_MODEL,
|
146 |
+
device=DEVICE_INDEX
|
147 |
+
)
|
148 |
print("✅ Transformers IndicWhisper loaded!")
|
149 |
return indicwhisper_pipeline
|
150 |
|
151 |
@GPU_DECORATOR
|
152 |
def load_specialized_model(language):
|
153 |
+
if language in fallback_models:
|
154 |
+
return fallback_models[language]
|
155 |
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
|
156 |
name = SPECIALIZED_MODELS[language]
|
157 |
proc = AutoProcessor.from_pretrained(name)
|
|
|
163 |
@GPU_DECORATOR
|
164 |
def transcribe_with_primary_model(audio_path, language):
|
165 |
try:
|
166 |
+
pipe = load_indicwhisper()
|
167 |
+
lang_code = LANG_CODES.get(language, "en")
|
168 |
+
|
169 |
+
# JAX path
|
170 |
if WHISPER_JAX_AVAILABLE:
|
171 |
+
result = pipe(audio_path, task="transcribe", language=lang_code)
|
172 |
+
if isinstance(result, dict) and "text" in result:
|
173 |
+
return result["text"].strip()
|
174 |
+
return str(result).strip()
|
175 |
+
|
176 |
+
# Transformers path
|
177 |
+
try:
|
178 |
+
if hasattr(pipe, "model") and hasattr(pipe, "tokenizer"):
|
179 |
+
forced_ids = pipe.tokenizer.get_decoder_prompt_ids(language=lang_code, task="transcribe")
|
180 |
+
if forced_ids is not None:
|
181 |
+
pipe.model.config.forced_decoder_ids = forced_ids
|
182 |
+
except Exception as e:
|
183 |
+
print(f"⚠️ Lang forcing failed: {e}")
|
184 |
+
|
185 |
with amp_ctx():
|
186 |
+
out = pipe(audio_path)
|
187 |
if isinstance(out, dict): return (out.get("text") or "").strip()
|
188 |
return str(out).strip()
|
189 |
except Exception as e:
|
|
|
200 |
gen_kwargs = {"inputs": feats, "max_length": 200, "num_beams": 3}
|
201 |
if language != "English":
|
202 |
try:
|
203 |
+
forced_ids = comp["processor"].tokenizer.get_decoder_prompt_ids(
|
204 |
+
LANG_CODES[language], task="transcribe"
|
205 |
+
)
|
206 |
gen_kwargs["forced_decoder_ids"] = forced_ids
|
207 |
except: pass
|
208 |
with torch.no_grad(), amp_ctx():
|
|
|
220 |
return transcribe_with_primary_model(audio_path, language)
|
221 |
|
222 |
# ---------------- MAIN ---------------- #
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
223 |
def get_score(wer, cer):
|
224 |
c = (wer*0.7)+(cer*0.3)
|
225 |
if c <= 0.1: return "🏆 Excellent!","Outstanding!"
|
|
|
228 |
elif c <= 0.6: return "📚 Needs Practice","Focus on clearer pronunciation."
|
229 |
else: return "💪 Keep Trying!","Don't give up!"
|
230 |
|
231 |
+
def diff_html(ref,hyp):
|
|
|
|
|
|
|
|
|
232 |
ref_w, hyp_w = ref.split(), hyp.split()
|
233 |
sm = difflib.SequenceMatcher(None, ref_w, hyp_w)
|
234 |
out=[]
|
|
|
237 |
elif tag=='replace':
|
238 |
out += [f"<span style='color:red'>{w}</span>" for w in ref_w[i1:i2]]
|
239 |
out += [f"<span style='color:orange'>→{w}</span>" for w in hyp_w[j1:j2]]
|
240 |
+
elif tag=='delete': out += [f"<span style='color:red'>{w}</span>" for w in ref_w[i1:i2]]
|
241 |
+
elif tag=='insert': out += [f"<span style='color:orange'>+{w}</span>" for w in hyp_w[j1:j2]]
|
|
|
|
|
242 |
return " ".join(out)
|
243 |
|
244 |
+
def char_html(ref,hyp):
|
245 |
sm = difflib.SequenceMatcher(None, list(ref), list(hyp))
|
246 |
out=[]
|
247 |
for tag,i1,i2,j1,j2 in sm.get_opcodes():
|
|
|
250 |
elif tag=='insert': out += [f"<span style='color:orange'>{c}</span>" for c in hyp[j1:j2]]
|
251 |
return "".join(out)
|
252 |
|
253 |
+
@GPU_DECORATOR
|
254 |
+
def compare_pronunciation(audio, lang_choice, intended):
|
255 |
+
if audio is None:
|
256 |
+
return ("❌ Please record first","","","","","","","")
|
257 |
+
if not intended.strip():
|
258 |
+
return ("❌ Please generate a sentence first","","","","","","","")
|
259 |
+
|
260 |
+
ptext = transcribe_audio(audio, lang_choice, False)
|
261 |
+
stext = transcribe_audio(audio, lang_choice, True)
|
262 |
+
actual = ptext if not ptext.startswith("Error:") else stext
|
263 |
+
if actual.startswith("Error:"):
|
264 |
+
return (f"❌ {actual}","","","","","","","")
|
265 |
+
|
266 |
+
wer_val, cer_val = compute_wer(intended, actual), compute_cer(intended, actual)
|
267 |
+
score, feedback = get_score(wer_val, cer_val)
|
268 |
+
return (f"✅ Done - {score}\n💬 {feedback}",
|
269 |
+
ptext, stext,
|
270 |
+
f"{wer_val:.3f} ({(1-wer_val)*100:.1f}%)",
|
271 |
+
f"{cer_val:.3f} ({(1-cer_val)*100:.1f}%)",
|
272 |
+
diff_html(intended, actual),
|
273 |
+
char_html(intended, actual),
|
274 |
+
f"🎯 Target: {intended}")
|
275 |
+
|
276 |
# ---------------- UI ---------------- #
|
277 |
def create_interface():
|
278 |
with gr.Blocks() as demo:
|
|
|
297 |
lang.change(get_random_sentence, [lang], [intended])
|
298 |
return demo
|
299 |
|
300 |
+
# ---------------- LAUNCH ---------------- #
|
301 |
if __name__ == "__main__":
|
302 |
demo = create_interface()
|
303 |
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
|