Update app.py

app.py CHANGED
@@ -1,25 +1,19 @@
 import gradio as gr
-import random
-import difflib
-import re
-import warnings
 import torch
 import numpy as np
-import librosa
-import soundfile as sf
 import jiwer

-# Optional
 try:
     from indic_transliteration import sanscript
     from indic_transliteration.sanscript import transliterate
     INDIC_OK = True
 except:
     INDIC_OK = False
-    sanscript = None
-    transliterate = None

-# Optional
 try:
     import spaces
     GPU_DECORATOR = spaces.GPU
@@ -34,60 +28,31 @@ warnings.filterwarnings("ignore")
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 DEVICE_INDEX = 0 if DEVICE == "cuda" else -1
 DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
 print(f"🧠 Using device: {DEVICE}")

-LANG_CODES = {
-    "English": "en",
-    "Tamil": "ta",
-    "Malayalam": "ml",
-}

-#
 INDICWHISPER_MODEL = "parthiv11/indic_whisper_nodcil"

 SPECIALIZED_MODELS = {
     "English": "openai/whisper-base.en",
     "Tamil": "vasista22/whisper-tamil-large-v2",
     "Malayalam": "thennal/whisper-medium-ml",
 }

 SCRIPT_PATTERNS = {
     "Tamil": re.compile(r"[\u0B80-\u0BFF]"),
     "Malayalam": re.compile(r"[\u0D00-\u0D7F]"),
     "English": re.compile(r"[A-Za-z]"),
 }
-
 SENTENCE_BANK = {
-    "English": [
-        "The sun sets over the beautiful horizon.",
-        "I enjoy reading books in the evening.",
-        "Technology has changed our daily lives.",
-        "Music brings people together across cultures.",
-        "Education is the key to a bright future.",
-        "The flowers bloom beautifully in spring.",
-        "Hard work always pays off in the end.",
-    ],
-    "Tamil": [
-        "இன்று நல்ல வானிலை உள்ளது.",
-        "நான் தமிழ் கற்றுக்கொண்டு இருக்கிறேன்.",
-        "எனக்கு புத்தகம் படிக்க விருப்பம்.",
-        "தமிழ் மொழி மிகவும் அழகானது.",
-        "குடும்பத்துடன் நேரம் செலவிடுவது முக்கியம்.",
-        "கல்வி நமது எதிர்காலத்தின் திறவுகோல்.",
-        "பறவைகள் காலையில் இனிமையாக பாடுகின்றன.",
-        "உழைப்பு எப்போதும் வெற்றியைத் தரும்.",
-    ],
-    "Malayalam": [
-        "എനിക്ക് മലയാളം വളരെ ഇഷ്ടമാണ്.",
-        "ഇന്ന് മഴ പെയ്യുന്നു.",
-        "ഞാൻ പുസ്തകം വായിക്കുന്നു.",
-        "കേരളത്തിന്റെ പ്രകൃതി സുന്ദരമാണ്.",
-        "വിദ്യാഭ്യാസം ജീവിതത്തിൽ പ്രധാനമാണ്.",
-        "സംഗീതം മനസ്സിന് സന്തോഷം നൽകുന്നു.",
-        "കുടുംബസമയം വളരെ വിലപ്പെട്ടതാണ്.",
-        "കഠിനാധ്വാനം എപ്പോഴും ഫലം നൽകും.",
-    ],
 }

 # Model cache
@@ -100,180 +65,86 @@ def get_random_sentence(language_choice):
     return random.choice(SENTENCE_BANK[language_choice])

 def is_script(text, lang_name):
-    pattern = SCRIPT_PATTERNS.get(lang_name)
-    if not pattern:
-        return True
-    return bool(pattern.search(text or ""))

 def transliterate_to_hk(text, lang_choice):
     if not INDIC_OK:
         return text
-    mapping = {
-        "Tamil": sanscript.TAMIL,
-        "Malayalam": sanscript.MALAYALAM,
-        "English": None
-    }
     script = mapping.get(lang_choice)
     if script and is_script(text, lang_choice):
-        try:
-            return transliterate(text, script, sanscript.HK)
-        except:
-            return text
     return text

 def preprocess_audio(audio_path, target_sr=16000):
     try:
         audio, sr = librosa.load(audio_path, sr=target_sr, mono=True)
-        if audio is None or len(audio) == 0:
-            return None, None
         audio = audio.astype(np.float32)
-        max_abs = np.max(np.abs(audio))
-        if max_abs > 0:
-            audio /= max_abs
         audio, _ = librosa.effects.trim(audio, top_db=20)
-        if len(audio) < target_sr * 0.1:
-            return None, None
         return audio, target_sr
-    except Exception as e:
-        print(f"Audio preprocessing error: {e}")
-        return None, None
-
-# Normalization for WER
-JIWER_TRANSFORM = jiwer.Compose([
-    jiwer.ToLowerCase(),
-    jiwer.RemovePunctuation(),
-    jiwer.RemoveMultipleSpaces(),
-    jiwer.Strip(),
-    jiwer.ReduceToListOfListOfWords(),
-])

-def compute_wer(ref, hyp):
-    try:
-        return jiwer.wer(ref, hyp, truth_transform=JIWER_TRANSFORM, hypothesis_transform=JIWER_TRANSFORM)
-    except:
-        return 1.0
-
-def compute_cer(ref, hyp):
-    try:
-        return jiwer.cer(ref, hyp)
-    except:
-        return 1.0
-
-def highlight_differences(ref, hyp):
-    if not ref.strip() or not hyp.strip():
-        return "No text to compare"
-    ref_words = ref.strip().split()
-    hyp_words = hyp.strip().split()
-    sm = difflib.SequenceMatcher(None, ref_words, hyp_words)
-    out_html = []
-    for tag, i1, i2, j1, j2 in sm.get_opcodes():
-        if tag == 'equal':
-            out_html.extend([f"<span style='color:green; background-color:#e8f5e8;'>{w}</span>" for w in ref_words[i1:i2]])
-        elif tag == 'replace':
-            out_html.extend([f"<span style='color:red; text-decoration:line-through;'>{w}</span>" for w in ref_words[i1:i2]])
-            out_html.extend([f"<span style='color:orange;'>→{w}</span>" for w in hyp_words[j1:j2]])
-        elif tag == 'delete':
-            out_html.extend([f"<span style='color:red; text-decoration:line-through;'>{w}</span>" for w in ref_words[i1:i2]])
-        elif tag == 'insert':
-            out_html.extend([f"<span style='color:orange;'>+{w}</span>" for w in hyp_words[j1:j2]])
-    return " ".join(out_html)
-
-def char_level_highlight(ref, hyp):
-    if not ref.strip() or not hyp.strip():
-        return "No text to compare"
-    sm = difflib.SequenceMatcher(None, list(ref), list(hyp))
-    out = []
-    for tag, i1, i2, j1, j2 in sm.get_opcodes():
-        if tag == 'equal':
-            out.extend([f"<span style='color:green;'>{c}</span>" for c in ref[i1:i2]])
-        elif tag in ('replace', 'delete'):
-            out.extend([f"<span style='color:red;'>{c}</span>" for c in ref[i1:i2]])
-        elif tag == 'insert':
-            out.extend([f"<span style='color:orange;'>{c}</span>" for c in hyp[j1:j2]])
-    return "".join(out)
-
-def get_pronunciation_score(wer_val, cer_val):
-    combined = (wer_val * 0.7) + (cer_val * 0.3)
-    if combined <= 0.1:
-        return "🌟 Excellent! (90%+)", "Your pronunciation is outstanding!"
-    elif combined <= 0.2:
-        return "👍 Very Good! (80-90%)", "Great pronunciation with minor areas for improvement."
-    elif combined <= 0.4:
-        return "👏 Good! (60-80%)", "Good effort! Keep practicing."
-    elif combined <= 0.6:
-        return "📚 Needs Practice (40-60%)", "Focus on clearer pronunciation."
-    else:
-        return "💪 Keep Trying! (<40%)", "Don't give up!"

-# ---------------- LOADERS ---------------- #
 @GPU_DECORATOR
 def load_indicwhisper():
     global indicwhisper_pipeline, WHISPER_JAX_AVAILABLE
-    if indicwhisper_pipeline is not None:
-        return indicwhisper_pipeline
-    # Try JAX first
     try:
-        from whisper_jax import FlaxWhisperPipeline
-        import jax.numpy as jnp
-        print(f"🚀 Loading JAX IndicWhisper: {INDICWHISPER_MODEL}")
-        indicwhisper_pipeline = FlaxWhisperPipeline(
-            INDICWHISPER_MODEL, dtype=jnp.bfloat16, batch_size=1
-        )
         WHISPER_JAX_AVAILABLE = True
-        print("✅ JAX IndicWhisper loaded!")
-        return indicwhisper_pipeline
-    except Exception as e:
-        print(f"⚠️ JAX unavailable: {e}")
-        WHISPER_JAX_AVAILABLE = False
-    # Fallback to Transformers
-    try:
-        from transformers import pipeline
-        indicwhisper_pipeline = pipeline(
-            "automatic-speech-recognition",
-            model=INDICWHISPER_MODEL,
-            device=DEVICE_INDEX
-        )
-        print("✅ Transformers IndicWhisper loaded!")
         return indicwhisper_pipeline
     except Exception as e:
-        print(f"❌ IndicWhisper load failed: {e}")
-        raise

 @GPU_DECORATOR
 def load_specialized_model(language):
-    if language in fallback_models:
-        return fallback_models[language]
     from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
-    model_name = SPECIALIZED_MODELS[language]
-    processor = AutoProcessor.from_pretrained(model_name)
-    model = AutoModelForSpeechSeq2Seq.from_pretrained(
-        model_name, torch_dtype=DTYPE,
-        low_cpu_mem_usage=True
-    ).to(DEVICE)
-    fallback_models[language] = {"processor": processor, "model": model}
     return fallback_models[language]

 # ---------------- TRANSCRIBE ---------------- #
 @GPU_DECORATOR
 def transcribe_with_primary_model(audio_path, language):
     try:
-        pl = load_indicwhisper()
-        lang_code = LANG_CODES.get(language, "en")
-        # JAX
         if WHISPER_JAX_AVAILABLE:
-            result = pl(audio_path, task="transcribe", language=lang_code)
-            if isinstance(result, dict):
-                return result.get("text", "").strip()
-            return str(result).strip()
-        # Transformers
         if hasattr(pl, "model") and hasattr(pl, "tokenizer"):
             try:
                 forced_ids = pl.tokenizer.get_decoder_prompt_ids(language=lang_code, task="transcribe")
                 pl.model.config.forced_decoder_ids = forced_ids
             except: pass
-        out = pl(audio_path)
-        if isinstance(out, dict):
-            return out.get("text", "").strip()
         return str(out).strip()
     except Exception as e:
         return f"Error: {str(e)}"
@@ -281,92 +152,110 @@ def transcribe_with_primary_model(audio_path, language):
 @GPU_DECORATOR
 def transcribe_with_specialized_model(audio_path, language):
     try:
-        components = load_specialized_model(language)
         audio, sr = preprocess_audio(audio_path)
-        if audio is None:
-            return "Error: Audio too short or invalid"
-        inputs = components["processor"](audio, sampling_rate=sr, return_tensors="pt")
-        input_features = inputs.input_features.to(DEVICE)
-        generate_kwargs = {"inputs": input_features, "max_length": 200, "num_beams": 3}
         if language != "English":
             try:
-                forced_ids = components["processor"].tokenizer.get_decoder_prompt_ids(
-                    language=LANG_CODES[language], task="transcribe"
-                )
-                generate_kwargs["forced_decoder_ids"] = forced_ids
             except: pass
-        with torch.no_grad():
-            ids = components["model"].generate(**generate_kwargs)
-            text = components["processor"].batch_decode(ids, skip_special_tokens=True)[0]
         return text.strip()
     except Exception as e:
         return f"Error: {str(e)}"

 @GPU_DECORATOR
 def transcribe_audio(audio_path, language, use_specialized=False):
-    try:
-        if use_specialized:
-            return transcribe_with_specialized_model(audio_path, language)
-        else:
-            return transcribe_with_primary_model(audio_path, language)
-    except:
-        if not use_specialized:
-            return transcribe_audio(audio_path, language, use_specialized=True)
-        return "Error"

-# ---------------- MAIN ---------------- #
 @GPU_DECORATOR
-def compare_pronunciation(audio, ...):
-    if audio is None:
-        ...

 # ---------------- UI ---------------- #
 def create_interface():
-    with gr.Blocks(...) as demo:
-        gr.Markdown("# 🎙️ IndicWhisper Pronunciation Trainer")
-        with gr.Row():
-            lang_choice = gr.Dropdown(choices=list(LANG_CODES.keys()), value="Tamil", label="🌐 Language")
-            gen_btn = gr.Button("🎲 Generate Sentence")
-        intended_display = gr.Textbox(label="📝 Practice Sentence", interactive=False, lines=3)
-        audio_input = gr.Audio(sources=["microphone","upload"], type="filepath", label="🎤 Record")
-        analyze_btn = gr.Button("🔍 Analyze")
-        status_output = gr.Textbox(label="📊 Results", interactive=False, lines=4)
         with gr.Row():
-            ...
     return demo

-# ---------------- LAUNCH ---------------- #
 if __name__ == "__main__":
     demo = create_interface()
     demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
app.py after the change:

 import gradio as gr
+import random, difflib, re, warnings, contextlib
 import torch
 import numpy as np
+import librosa, soundfile as sf
 import jiwer

+# Optional transliteration
 try:
     from indic_transliteration import sanscript
     from indic_transliteration.sanscript import transliterate
     INDIC_OK = True
 except:
     INDIC_OK = False

+# Optional HF Spaces decorator
 try:
     import spaces
     GPU_DECORATOR = spaces.GPU
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 DEVICE_INDEX = 0 if DEVICE == "cuda" else -1
 DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
+amp_ctx = torch.cuda.amp.autocast if DEVICE == "cuda" else contextlib.nullcontext
 print(f"🧠 Using device: {DEVICE}")

+LANG_CODES = {"English": "en", "Tamil": "ta", "Malayalam": "ml"}

+# Primary: IndicWhisper
 INDICWHISPER_MODEL = "parthiv11/indic_whisper_nodcil"

+# Specialised fallbacks
 SPECIALIZED_MODELS = {
     "English": "openai/whisper-base.en",
     "Tamil": "vasista22/whisper-tamil-large-v2",
     "Malayalam": "thennal/whisper-medium-ml",
 }

+# Scripts and sentence bank
 SCRIPT_PATTERNS = {
     "Tamil": re.compile(r"[\u0B80-\u0BFF]"),
     "Malayalam": re.compile(r"[\u0D00-\u0D7F]"),
     "English": re.compile(r"[A-Za-z]"),
 }
 SENTENCE_BANK = {
+    "English": ["The sun sets over the beautiful horizon.", "Hard work always pays off in the end."],
+    "Tamil": ["இன்று நல்ல வானிலை உள்ளது.", "உழைப்பு எப்போதும் வெற்றியைத் தரும்."],
+    "Malayalam": ["എനിക്ക് മലയാളം വളരെ ഇഷ്ടമാണ്.", "കഠിനാധ്വാനം എപ്പോഴും ഫലം നൽകും."]
 }

 # Model cache
     return random.choice(SENTENCE_BANK[language_choice])

 def is_script(text, lang_name):
+    p = SCRIPT_PATTERNS.get(lang_name)
+    return not p or bool(p.search(text or ""))

 def transliterate_to_hk(text, lang_choice):
     if not INDIC_OK:
         return text
+    mapping = {"Tamil": sanscript.TAMIL, "Malayalam": sanscript.MALAYALAM, "English": None}
     script = mapping.get(lang_choice)
     if script and is_script(text, lang_choice):
+        try: return transliterate(text, script, sanscript.HK)
+        except: return text
     return text

 def preprocess_audio(audio_path, target_sr=16000):
     try:
         audio, sr = librosa.load(audio_path, sr=target_sr, mono=True)
+        if audio is None or len(audio) == 0: return None, None
         audio = audio.astype(np.float32)
+        m = np.max(np.abs(audio))
+        if m > 0: audio /= m
         audio, _ = librosa.effects.trim(audio, top_db=20)
+        if len(audio) < int(target_sr * 0.1): return None, None
         return audio, target_sr
+    except: return None, None
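A quick way to sanity-check the preprocessing contract above is a synthetic tone. This is a hedged smoke-test sketch (the file name tone.wav and the test itself are illustrative, not part of the commit):

import numpy as np, soundfile as sf

sr = 16000
t = np.linspace(0, 1.0, sr, endpoint=False)
sf.write("tone.wav", (0.25 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32), sr)  # 1 s, 440 Hz clip

audio, out_sr = preprocess_audio("tone.wav")            # function defined above
assert out_sr == sr and audio is not None
assert len(audio) >= int(sr * 0.1)                      # passed the too-short gate
assert abs(float(np.max(np.abs(audio))) - 1.0) < 1e-3   # peak-normalized to 1.0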

+JIWER_TRANSFORM = jiwer.Compose([jiwer.ToLowerCase(), jiwer.RemovePunctuation(),
+                                 jiwer.RemoveMultipleSpaces(), jiwer.Strip(),
+                                 jiwer.ReduceToListOfListOfWords()])
+def compute_wer(ref, hyp):
+    try: return jiwer.wer(ref, hyp, truth_transform=JIWER_TRANSFORM, hypothesis_transform=JIWER_TRANSFORM)
+    except: return 1.0
+def compute_cer(ref, hyp):
+    try: return jiwer.cer(ref, hyp)
+    except: return 1.0
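For intuition, a hedged worked example of the two metrics (values follow the standard WER/CER definitions; note that newer jiwer releases renamed truth_transform to reference_transform, in which case the bare except above would silently report 1.0):

ref = "I enjoy reading books"
hyp = "I enjoy reading book"
print(compute_wer(ref, hyp))   # 1 substitution / 4 reference words  -> expected ~0.25
print(compute_cer(ref, hyp))   # 1 deleted char / 21 reference chars -> expected ~0.048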

+# ---------------- MODEL LOADERS ---------------- #
 @GPU_DECORATOR
 def load_indicwhisper():
     global indicwhisper_pipeline, WHISPER_JAX_AVAILABLE
+    if indicwhisper_pipeline: return indicwhisper_pipeline
     try:
+        from whisper_jax import FlaxWhisperPipeline; import jax.numpy as jnp
+        indicwhisper_pipeline = FlaxWhisperPipeline(INDICWHISPER_MODEL, dtype=jnp.bfloat16, batch_size=1)
         WHISPER_JAX_AVAILABLE = True
+        print("✅ JAX IndicWhisper loaded!")
         return indicwhisper_pipeline
     except Exception as e:
+        print(f"⚠️ JAX unavailable: {e}"); WHISPER_JAX_AVAILABLE = False
+    from transformers import pipeline
+    indicwhisper_pipeline = pipeline("automatic-speech-recognition", model=INDICWHISPER_MODEL, device=DEVICE_INDEX)
+    print("✅ Transformers IndicWhisper loaded!")
+    return indicwhisper_pipeline
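Both loader branches hand back a callable with the same contract (audio path in, text or a {"text": ...} dict out), which keeps the transcription code below backend-agnostic. A hedged usage sketch, assuming the weights download cleanly (sample.wav is hypothetical):

pl = load_indicwhisper()   # JAX pipeline if whisper_jax is importable, else a Transformers pipeline
out = pl("sample.wav")
print((out["text"] if isinstance(out, dict) else str(out)).strip())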
 @GPU_DECORATOR
 def load_specialized_model(language):
+    if language in fallback_models: return fallback_models[language]
     from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
+    name = SPECIALIZED_MODELS[language]
+    proc = AutoProcessor.from_pretrained(name)
+    model = AutoModelForSpeechSeq2Seq.from_pretrained(name, torch_dtype=DTYPE).to(DEVICE)
+    fallback_models[language] = {"processor": proc, "model": model}
     return fallback_models[language]

 # ---------------- TRANSCRIBE ---------------- #
 @GPU_DECORATOR
 def transcribe_with_primary_model(audio_path, language):
     try:
+        pl = load_indicwhisper(); lang_code = LANG_CODES.get(language, "en")
         if WHISPER_JAX_AVAILABLE:
+            res = pl(audio_path, task="transcribe", language=lang_code)
+            if isinstance(res, dict): return res.get("text", "").strip()
+            return str(res).strip()
         if hasattr(pl, "model") and hasattr(pl, "tokenizer"):
             try:
                 forced_ids = pl.tokenizer.get_decoder_prompt_ids(language=lang_code, task="transcribe")
                 pl.model.config.forced_decoder_ids = forced_ids
             except: pass
+        with amp_ctx():
+            out = pl(audio_path)
+        if isinstance(out, dict): return (out.get("text") or "").strip()
         return str(out).strip()
     except Exception as e:
         return f"Error: {str(e)}"
 @GPU_DECORATOR
 def transcribe_with_specialized_model(audio_path, language):
     try:
+        comp = load_specialized_model(language)
         audio, sr = preprocess_audio(audio_path)
+        if audio is None: return "Error: Audio too short"
+        inputs = comp["processor"](audio, sampling_rate=sr, return_tensors="pt")
+        feats = inputs.input_features.to(DEVICE)
+        gen_kwargs = {"inputs": feats, "max_length": 200, "num_beams": 3}
         if language != "English":
             try:
+                forced_ids = comp["processor"].tokenizer.get_decoder_prompt_ids(language=LANG_CODES[language], task="transcribe")
+                gen_kwargs["forced_decoder_ids"] = forced_ids
             except: pass
+        with torch.no_grad(), amp_ctx():
+            ids = comp["model"].generate(**gen_kwargs)
+        text = comp["processor"].batch_decode(ids, skip_special_tokens=True)[0]
         return text.strip()
     except Exception as e:
         return f"Error: {str(e)}"

 @GPU_DECORATOR
 def transcribe_audio(audio_path, language, use_specialized=False):
+    if use_specialized:
+        return transcribe_with_specialized_model(audio_path, language)
+    else:
+        return transcribe_with_primary_model(audio_path, language)
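End to end, transcription is one call; a hedged sketch with a hypothetical recording:

print(transcribe_audio("sample.wav", "Tamil"))                        # primary IndicWhisper path
print(transcribe_audio("sample.wav", "Tamil", use_specialized=True))  # per-language fallback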

+# ---------------- MAIN ---------------- #
 @GPU_DECORATOR
+def compare_pronunciation(audio, lang_choice, intended):
+    if audio is None: return ("❌ Please record audio first.", "", "", "", "", "", "", "")
+    if not intended.strip(): return ("❌ Please generate a sentence first.", "", "", "", "", "", "", "")
+    ptext = transcribe_audio(audio, lang_choice, False)
+    stext = transcribe_audio(audio, lang_choice, True)
+    actual = ptext if not ptext.startswith("Error:") else stext
+    if actual.startswith("Error:"): return (f"❌ {actual}", "", "", "", "", "", "", "")
+    wer_val, cer_val = compute_wer(intended, actual), compute_cer(intended, actual)
+    score, feedback = get_score(wer_val, cer_val)
+    return (f"✅ Done - {score}\n💬 {feedback}",
+            ptext, stext,
+            f"{wer_val:.3f} ({(1 - wer_val) * 100:.1f}%)",
+            f"{cer_val:.3f} ({(1 - cer_val) * 100:.1f}%)",
+            diff_html(intended, actual),
+            char_html(intended, actual),
+            f"🎯 Target: {intended}")
+
+def get_score(wer, cer):
+    c = (wer * 0.7) + (cer * 0.3)
+    if c <= 0.1: return "🌟 Excellent!", "Outstanding!"
+    elif c <= 0.2: return "👍 Very Good!", "Minor improvements needed."
+    elif c <= 0.4: return "👏 Good!", "Keep practicing."
+    elif c <= 0.6: return "📚 Needs Practice", "Focus on clearer pronunciation."
+    else: return "💪 Keep Trying!", "Don't give up!"
+
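Plugging the example metrics from earlier into the 70/30 weighting shows how the bands resolve (illustrative numbers):

# c = 0.7 * 0.25 + 0.3 * 0.048 ≈ 0.19, which falls in the 0.1–0.2 band
print(get_score(0.25, 0.048))   # -> ("👍 Very Good!", "Minor improvements needed.")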
+def diff_html(ref, hyp): return highlight_differences(ref, hyp)
+def char_html(ref, hyp): return char_level_highlight(ref, hyp)
+
+# Diff functions
+def highlight_differences(ref, hyp):
+    ref_w, hyp_w = ref.split(), hyp.split()
+    sm = difflib.SequenceMatcher(None, ref_w, hyp_w)
+    out = []
+    for tag, i1, i2, j1, j2 in sm.get_opcodes():
+        if tag == 'equal': out += [f"<span style='color:green'>{w}</span>" for w in ref_w[i1:i2]]
+        elif tag == 'replace':
+            out += [f"<span style='color:red'>{w}</span>" for w in ref_w[i1:i2]]
+            out += [f"<span style='color:orange'>→{w}</span>" for w in hyp_w[j1:j2]]
+        elif tag == 'delete':
+            out += [f"<span style='color:red'>{w}</span>" for w in ref_w[i1:i2]]
+        elif tag == 'insert':
+            out += [f"<span style='color:orange'>+{w}</span>" for w in hyp_w[j1:j2]]
+    return " ".join(out)
+
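The markup above is driven directly by difflib's opcodes; a small stdlib-only illustration:

import difflib
sm = difflib.SequenceMatcher(None, ["good", "morning"], ["good", "evening"])
print(sm.get_opcodes())
# [('equal', 0, 1, 0, 1), ('replace', 1, 2, 1, 2)]
# -> "good" renders green; "morning" red, with "→evening" in orange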
+def char_level_highlight(ref, hyp):
+    sm = difflib.SequenceMatcher(None, list(ref), list(hyp))
+    out = []
+    for tag, i1, i2, j1, j2 in sm.get_opcodes():
+        if tag == 'equal': out += [f"<span style='color:green'>{c}</span>" for c in ref[i1:i2]]
+        elif tag in ('replace', 'delete'): out += [f"<span style='color:red'>{c}</span>" for c in ref[i1:i2]]
+        elif tag == 'insert': out += [f"<span style='color:orange'>{c}</span>" for c in hyp[j1:j2]]
+    return "".join(out)
 # ---------------- UI ---------------- #
 def create_interface():
+    with gr.Blocks() as demo:
+        gr.Markdown("# 🎙️ IndicWhisper Pronunciation Trainer")
         with gr.Row():
+            lang = gr.Dropdown(choices=list(LANG_CODES.keys()), value="Tamil", label="Language")
+            btn = gr.Button("🎲 Generate Sentence")
+        intended = gr.Textbox(label="Practice Sentence", interactive=False, lines=3)
+        audio = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Record")
+        analyze = gr.Button("🔍 Analyze")
+        status = gr.Textbox(label="Results", interactive=False, lines=4)
+        pass1 = gr.Textbox(label="Primary (IndicWhisper)")
+        pass2 = gr.Textbox(label="Specialized")
+        wer = gr.Textbox(label="Word Accuracy")
+        cer = gr.Textbox(label="Char Accuracy")
+        diff = gr.HTML(label="Word Diff")
+        chars = gr.HTML(label="Char Diff")
+        target = gr.Textbox(label="Reference", visible=False)
+        btn.click(get_random_sentence, [lang], [intended])
+        analyze.click(compare_pronunciation, [audio, lang, intended],
+                      [status, pass1, pass2, wer, cer, diff, chars, target])
+        lang.change(get_random_sentence, [lang], [intended])
     return demo

 if __name__ == "__main__":
     demo = create_interface()
     demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
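For completeness, a hedged sketch of the requirements.txt such a Space would need (package names inferred from the imports above, versions deliberately unpinned; whisper_jax is optional):

gradio
torch
numpy
librosa
soundfile
jiwer
transformers
indic-transliteration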