Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -5,21 +5,27 @@ import re
|
|
5 |
import jiwer
|
6 |
import torch
|
7 |
import numpy as np
|
8 |
-
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
|
9 |
import librosa
|
10 |
import soundfile as sf
|
11 |
from indic_transliteration import sanscript
|
12 |
from indic_transliteration.sanscript import transliterate
|
13 |
-
import unicodedata
|
14 |
import warnings
|
15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
warnings.filterwarnings("ignore")
|
18 |
|
19 |
# ---------------- CONFIG ---------------- #
|
20 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
21 |
print(f"🔧 Using device: {DEVICE}")
|
22 |
-
DEVICE_INDEX = 0 if DEVICE == "cuda" else -1
|
23 |
|
24 |
LANG_CODES = {
|
25 |
"English": "en",
|
@@ -27,8 +33,10 @@ LANG_CODES = {
|
|
27 |
"Malayalam": "ml"
|
28 |
}
|
29 |
|
|
|
30 |
INDICWHISPER_MODEL = "openai/whisper-large-v2"
|
31 |
|
|
|
32 |
SPECIALIZED_MODELS = {
|
33 |
"English": "openai/whisper-base.en",
|
34 |
"Tamil": "vasista22/whisper-tamil-large-v2",
|
@@ -83,23 +91,83 @@ SENTENCE_BANK = {
|
|
83 |
]
|
84 |
}
|
85 |
|
86 |
-
# Controls for stricter script checking and normalization
|
87 |
-
STRICT_SCRIPT_CHECK = False # set True for strict script-only validation
|
88 |
-
NORMALIZE_TEXT_FOR_METRICS = True
|
89 |
-
|
90 |
# ---------------- MODEL CACHE ---------------- #
|
91 |
indicwhisper_pipeline = None
|
92 |
fallback_models = {}
|
93 |
-
WHISPER_JAX_AVAILABLE = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
-
|
96 |
-
|
97 |
-
return s
|
98 |
-
# Normalize unicode and collapse whitespace; do not remove language-specific punctuation
|
99 |
-
s = unicodedata.normalize("NFC", s)
|
100 |
-
s = re.sub(r"\s+", " ", s).strip()
|
101 |
-
return s
|
102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
def get_random_sentence(language_choice):
|
104 |
return random.choice(SENTENCE_BANK[language_choice])
|
105 |
|
@@ -107,16 +175,7 @@ def is_script(text, lang_name):
|
|
107 |
pattern = SCRIPT_PATTERNS.get(lang_name)
|
108 |
if not pattern:
|
109 |
return True
|
110 |
-
|
111 |
-
# any occurrence of script chars counts as match
|
112 |
-
return bool(pattern.search(text))
|
113 |
-
# strict: allow only spaces and target script chars
|
114 |
-
for ch in text:
|
115 |
-
if ch.isspace():
|
116 |
-
continue
|
117 |
-
if not pattern.match(ch):
|
118 |
-
return False
|
119 |
-
return True
|
120 |
|
121 |
def transliterate_to_hk(text, lang_choice):
|
122 |
mapping = {
|
@@ -125,6 +184,8 @@ def transliterate_to_hk(text, lang_choice):
|
|
125 |
"English": None
|
126 |
}
|
127 |
script = mapping.get(lang_choice)
|
|
|
|
|
128 |
if script and is_script(text, lang_choice):
|
129 |
try:
|
130 |
return transliterate(text, script, sanscript.HK)
|
@@ -134,111 +195,75 @@ def transliterate_to_hk(text, lang_choice):
|
|
134 |
return text
|
135 |
|
136 |
def preprocess_audio(audio_path, target_sr=16000):
|
|
|
137 |
try:
|
138 |
-
audio, sr = librosa.load(audio_path, sr=target_sr)
|
139 |
-
if
|
140 |
-
|
|
|
|
|
|
|
|
|
|
|
141 |
audio, _ = librosa.effects.trim(audio, top_db=20)
|
142 |
-
|
|
|
143 |
return None, None
|
|
|
|
|
|
|
144 |
return audio, target_sr
|
145 |
except Exception as e:
|
146 |
print(f"Audio preprocessing error: {e}")
|
147 |
return None, None
|
148 |
|
149 |
-
@
|
150 |
-
def load_indicwhisper():
|
151 |
-
global indicwhisper_pipeline, WHISPER_JAX_AVAILABLE
|
152 |
-
if indicwhisper_pipeline is None:
|
153 |
-
try:
|
154 |
-
# Try JAX pipeline
|
155 |
-
try:
|
156 |
-
from whisper_jax import FlaxWhisperPipeline
|
157 |
-
import jax.numpy as jnp
|
158 |
-
print(f"🔄 Loading JAX-optimized model: {INDICWHISPER_MODEL}")
|
159 |
-
indicwhisper_pipeline = FlaxWhisperPipeline(
|
160 |
-
INDICWHISPER_MODEL,
|
161 |
-
dtype=jnp.bfloat16,
|
162 |
-
batch_size=1
|
163 |
-
)
|
164 |
-
WHISPER_JAX_AVAILABLE = True
|
165 |
-
print("✅ JAX-optimized model loaded successfully!")
|
166 |
-
return indicwhisper_pipeline
|
167 |
-
except Exception as e:
|
168 |
-
print(f"⚠️ JAX loading failed: {e}")
|
169 |
-
WHISPER_JAX_AVAILABLE = False
|
170 |
-
|
171 |
-
# Fallback to transformers pipeline
|
172 |
-
print(f"🔄 Loading transformers pipeline: {INDICWHISPER_MODEL}")
|
173 |
-
from transformers import pipeline
|
174 |
-
indicwhisper_pipeline = pipeline(
|
175 |
-
"automatic-speech-recognition",
|
176 |
-
model=INDICWHISPER_MODEL,
|
177 |
-
device=DEVICE_INDEX,
|
178 |
-
torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32
|
179 |
-
)
|
180 |
-
print("✅ High-performance model loaded with transformers!")
|
181 |
-
except Exception as e:
|
182 |
-
print(f"❌ Failed to load primary model: {e}")
|
183 |
-
indicwhisper_pipeline = None
|
184 |
-
raise Exception(f"Could not load high-performance model: {str(e)}")
|
185 |
-
return indicwhisper_pipeline
|
186 |
-
|
187 |
-
@spaces.GPU
|
188 |
-
def load_specialized_model(language):
|
189 |
-
if language not in fallback_models:
|
190 |
-
model_name = SPECIALIZED_MODELS[language]
|
191 |
-
print(f"🔄 Loading specialized model for {language}: {model_name}")
|
192 |
-
try:
|
193 |
-
processor = AutoProcessor.from_pretrained(model_name)
|
194 |
-
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
195 |
-
model_name,
|
196 |
-
torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
|
197 |
-
low_cpu_mem_usage=True,
|
198 |
-
use_safetensors=True
|
199 |
-
).to(DEVICE)
|
200 |
-
model.eval()
|
201 |
-
fallback_models[language] = {"processor": processor, "model": model, "model_name": model_name}
|
202 |
-
print(f"✅ Specialized model loaded for {language}")
|
203 |
-
except Exception as e:
|
204 |
-
print(f"❌ Failed to load specialized {model_name}: {e}")
|
205 |
-
raise Exception(f"Could not load specialized {language} model")
|
206 |
-
return fallback_models[language]
|
207 |
-
|
208 |
-
@spaces.GPU
|
209 |
def transcribe_with_primary_model(audio_path, language):
|
|
|
210 |
try:
|
211 |
pipe = load_indicwhisper()
|
212 |
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
217 |
try:
|
218 |
-
|
219 |
-
|
220 |
-
forced_ids = pipe.tokenizer.get_decoder_prompt_ids(
|
221 |
-
language=lang_code, task="transcribe"
|
222 |
-
)
|
223 |
-
pipe.model.config.forced_decoder_ids = forced_ids
|
224 |
except Exception as e:
|
225 |
-
print(f"⚠️
|
|
|
|
|
226 |
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
else:
|
233 |
-
return str(result).strip()
|
234 |
else:
|
235 |
-
return
|
236 |
except Exception as e:
|
237 |
print(f"Primary model transcription error: {e}")
|
238 |
-
|
239 |
|
240 |
-
@
|
241 |
def transcribe_with_specialized_model(audio_path, language):
|
|
|
242 |
try:
|
243 |
components = load_specialized_model(language)
|
244 |
processor = components["processor"]
|
@@ -248,15 +273,23 @@ def transcribe_with_specialized_model(audio_path, language):
|
|
248 |
if audio is None:
|
249 |
return "Error: Audio too short or could not be processed"
|
250 |
|
251 |
-
inputs = processor(
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
input_features
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
258 |
|
259 |
-
forced_decoder_ids = None
|
260 |
if language != "English":
|
261 |
lang_code = LANG_CODES.get(language, "en")
|
262 |
try:
|
@@ -265,60 +298,53 @@ def transcribe_with_specialized_model(audio_path, language):
|
|
265 |
language=lang_code,
|
266 |
task="transcribe"
|
267 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
268 |
except Exception as e:
|
269 |
print(f"⚠️ Language forcing failed: {e}")
|
270 |
|
271 |
with torch.no_grad():
|
272 |
-
|
273 |
-
"max_length": 200,
|
274 |
-
"num_beams": 3,
|
275 |
-
"do_sample": False
|
276 |
-
}
|
277 |
-
if forced_decoder_ids:
|
278 |
-
gen_kwargs["forced_decoder_ids"] = forced_decoder_ids
|
279 |
-
|
280 |
-
predicted_ids = model.generate(
|
281 |
-
input_features,
|
282 |
-
**gen_kwargs
|
283 |
-
)
|
284 |
|
285 |
transcription = processor.batch_decode(
|
286 |
predicted_ids,
|
287 |
skip_special_tokens=True,
|
288 |
clean_up_tokenization_spaces=True
|
289 |
)[0]
|
290 |
-
|
291 |
-
return transcription.strip() or "(No transcription generated)"
|
292 |
except Exception as e:
|
293 |
print(f"Specialized model transcription error: {e}")
|
294 |
-
return f"Error: {str(e)[:
|
295 |
|
296 |
-
@
|
297 |
def transcribe_audio(audio_path, language, initial_prompt="", use_specialized=False):
|
|
|
298 |
try:
|
299 |
if use_specialized:
|
300 |
print(f"🔄 Using specialized model for {language}")
|
301 |
return transcribe_with_specialized_model(audio_path, language)
|
302 |
else:
|
303 |
-
print(f"🔄 Using
|
304 |
return transcribe_with_primary_model(audio_path, language)
|
305 |
except Exception as e:
|
306 |
print(f"Transcription failed, trying specialized model: {e}")
|
307 |
if not use_specialized:
|
308 |
return transcribe_audio(audio_path, language, initial_prompt, use_specialized=True)
|
309 |
else:
|
310 |
-
return f"Error: All transcription methods failed - {str(e)[:
|
311 |
|
312 |
def highlight_differences(ref, hyp):
|
313 |
-
if not ref.strip() or not hyp.strip():
|
314 |
return "No text to compare"
|
315 |
-
|
316 |
ref_words = ref.strip().split()
|
317 |
hyp_words = hyp.strip().split()
|
318 |
-
|
319 |
sm = difflib.SequenceMatcher(None, ref_words, hyp_words)
|
320 |
out_html = []
|
321 |
-
|
322 |
for tag, i1, i2, j1, j2 in sm.get_opcodes():
|
323 |
if tag == 'equal':
|
324 |
out_html.extend([f"<span style='color:green; font-weight:bold; background-color:#e8f5e8; padding:2px 4px; margin:1px; border-radius:3px;'>{w}</span>" for w in ref_words[i1:i2]])
|
@@ -329,13 +355,11 @@ def highlight_differences(ref, hyp):
|
|
329 |
out_html.extend([f"<span style='color:red; text-decoration:line-through; background-color:#ffe8e8; padding:2px 4px; margin:1px; border-radius:3px;'>{w}</span>" for w in ref_words[i1:i2]])
|
330 |
elif tag == 'insert':
|
331 |
out_html.extend([f"<span style='color:orange; font-weight:bold; background-color:#fff3cd; padding:2px 4px; margin:1px; border-radius:3px;'>+{w}</span>" for w in hyp_words[j1:j2]])
|
332 |
-
|
333 |
return " ".join(out_html)
|
334 |
|
335 |
def char_level_highlight(ref, hyp):
|
336 |
-
if not ref.strip() or not hyp.strip():
|
337 |
return "No text to compare"
|
338 |
-
|
339 |
sm = difflib.SequenceMatcher(None, list(ref), list(hyp))
|
340 |
out = []
|
341 |
for tag, i1, i2, j1, j2 in sm.get_opcodes():
|
@@ -360,63 +384,50 @@ def get_pronunciation_score(wer_val, cer_val):
|
|
360 |
else:
|
361 |
return "💪 Keep Trying! (<40%)", "Don't give up! Practice makes perfect."
|
362 |
|
363 |
-
|
|
|
364 |
def compare_pronunciation(audio, language_choice, intended_sentence):
|
365 |
-
print(f"🔍 Starting
|
366 |
print(f"📝 Audio file: {audio}")
|
367 |
print(f"🎯 Intended sentence: {intended_sentence}")
|
368 |
|
369 |
if audio is None:
|
370 |
-
print("❌ No audio provided")
|
371 |
return ("❌ Please record audio first.", "", "", "", "", "", "", "")
|
372 |
-
|
373 |
-
if not intended_sentence.strip():
|
374 |
-
print("❌ No intended sentence")
|
375 |
return ("❌ Please generate a practice sentence first.", "", "", "", "", "", "", "")
|
376 |
|
377 |
try:
|
378 |
-
print(
|
379 |
primary_text = transcribe_audio(audio, language_choice, use_specialized=False)
|
380 |
-
print(f"✅ Primary
|
381 |
|
382 |
-
print("🔄
|
383 |
specialized_text = transcribe_audio(audio, language_choice, use_specialized=True)
|
384 |
-
print(f"✅ Specialized
|
385 |
|
386 |
actual_text = primary_text if not str(primary_text).startswith("Error:") else specialized_text
|
387 |
|
388 |
if str(actual_text).startswith("Error:"):
|
389 |
-
print(f"❌ Transcription error: {actual_text}")
|
390 |
return (f"❌ {actual_text}", "", "", "", "", "", "", "")
|
391 |
|
392 |
-
# Normalize for metrics if enabled
|
393 |
-
ref_for_metrics = normalize_text(intended_sentence)
|
394 |
-
hyp_for_metrics = normalize_text(actual_text)
|
395 |
-
|
396 |
try:
|
397 |
-
|
398 |
-
|
399 |
-
cer_val = jiwer.cer(ref_for_metrics, hyp_for_metrics)
|
400 |
-
print(f"✅ WER: {wer_val:.3f}, CER: {cer_val:.3f}")
|
401 |
except Exception as e:
|
402 |
-
print(f"❌
|
403 |
wer_val, cer_val = 1.0, 1.0
|
404 |
|
405 |
score_text, feedback = get_pronunciation_score(wer_val, cer_val)
|
406 |
|
407 |
-
print("🔄 Generating transliterations...")
|
408 |
actual_hk = transliterate_to_hk(actual_text, language_choice)
|
409 |
target_hk = transliterate_to_hk(intended_sentence, language_choice)
|
410 |
-
|
411 |
-
if not is_script(actual_text, language_choice) and language_choice != "English":
|
412 |
actual_hk = f"⚠️ Expected {language_choice} script, got mixed/other script"
|
413 |
|
414 |
-
print("🔄 Generating visual feedback...")
|
415 |
diff_html = highlight_differences(intended_sentence, actual_text)
|
416 |
char_html = char_level_highlight(intended_sentence, actual_text)
|
417 |
|
418 |
-
status = f"✅
|
419 |
-
print(f"✅ Advanced analysis completed successfully")
|
420 |
|
421 |
return (
|
422 |
status,
|
@@ -431,14 +442,11 @@ def compare_pronunciation(audio, language_choice, intended_sentence):
|
|
431 |
|
432 |
except Exception as e:
|
433 |
error_msg = f"❌ Analysis Error: {str(e)[:200]}"
|
434 |
-
print(f"❌ FATAL ERROR: {e}")
|
435 |
-
import traceback
|
436 |
-
traceback.print_exc()
|
437 |
return (error_msg, str(e), "", "", "", "", "", "")
|
438 |
|
|
|
439 |
def create_interface():
|
440 |
with gr.Blocks(title="🎙️ SOTA Multilingual Pronunciation Trainer") as demo:
|
441 |
-
|
442 |
gr.Markdown("""
|
443 |
# 🎙️ Advanced Multilingual Pronunciation Trainer
|
444 |
|
@@ -446,12 +454,12 @@ def create_interface():
|
|
446 |
|
447 |
### 🏆 Powered by Advanced Models:
|
448 |
- Dual-Model Analysis: Primary + specialized model comparison
|
449 |
-
- High Accuracy: Language-specific fine-tuned models
|
450 |
- Robust Performance: Automatic fallback for reliability
|
451 |
|
452 |
### 📋 How to Use:
|
453 |
1. Select your target language 🌍
|
454 |
-
2. Generate a practice sentence 🎲
|
455 |
3. Record yourself reading it aloud 🎤
|
456 |
4. Get detailed feedback with advanced accuracy 📊
|
457 |
|
@@ -520,7 +528,7 @@ def create_interface():
|
|
520 |
gr.Markdown("""
|
521 |
### 🎨 Color Guide:
|
522 |
- 🟢 Green: Correctly pronounced words/characters
|
523 |
-
- 🔴 Red: Missing or mispronounced (strikethrough)
|
524 |
- 🟠 Orange: Extra words or substitutions
|
525 |
""")
|
526 |
diff_html_box = gr.HTML(label="🔍 Word-Level Analysis", show_label=True)
|
@@ -542,14 +550,14 @@ def create_interface():
|
|
542 |
fn=compare_pronunciation,
|
543 |
inputs=[audio_input, lang_choice, intended_display],
|
544 |
outputs=[
|
545 |
-
status_output,
|
546 |
-
pass1_out,
|
547 |
-
pass2_out,
|
548 |
-
wer_out,
|
549 |
-
cer_out,
|
550 |
-
diff_html_box,
|
551 |
-
char_html_box,
|
552 |
-
target_display
|
553 |
]
|
554 |
)
|
555 |
|
@@ -563,27 +571,33 @@ def create_interface():
|
|
563 |
---
|
564 |
### 🏆 Advanced Technology Stack:
|
565 |
- Primary ASR: OpenAI Whisper Large v2 (High-performance multilingual model)
|
566 |
-
- Specialized Models:
|
567 |
-
- Tamil: vasista22/whisper-tamil-large-v2
|
568 |
-
- Malayalam: thennal/whisper-medium-ml
|
569 |
-
- English:
|
570 |
-
- Dual Analysis
|
|
|
571 |
|
572 |
### 🔧 Technical Details:
|
573 |
-
- Metrics: WER and CER
|
574 |
-
- Transliteration: Harvard-Kyoto for Indic scripts
|
575 |
-
-
|
|
|
576 |
""")
|
577 |
return demo
|
578 |
|
|
|
579 |
if __name__ == "__main__":
|
580 |
print("🚀 Starting Advanced Multilingual Pronunciation Trainer...")
|
581 |
-
print(f"🔧 Device: {DEVICE}
|
582 |
-
|
|
|
|
|
|
|
|
|
583 |
print("🏆 Using High-Performance Dual-Model Approach")
|
584 |
print("⚡ Automatic model selection with specialized fallbacks")
|
585 |
print("📊 Advanced analysis with robust error handling")
|
586 |
-
print("🎮 GPU functions decorated with @spaces.GPU for HuggingFace Spaces")
|
587 |
|
588 |
demo = create_interface()
|
589 |
demo.launch(
|
|
|
5 |
import jiwer
|
6 |
import torch
|
7 |
import numpy as np
|
8 |
+
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, WhisperProcessor
|
9 |
import librosa
|
10 |
import soundfile as sf
|
11 |
from indic_transliteration import sanscript
|
12 |
from indic_transliteration.sanscript import transliterate
|
|
|
13 |
import warnings
|
14 |
+
|
15 |
+
# Optional: the `spaces` package is only importable on the HF Spaces runtime.
# Everywhere else GPU_DECORATOR degrades to a transparent no-op so the same
# decorated functions still run locally.
try:
    import spaces
    GPU_DECORATOR = spaces.GPU
except Exception:  # any import-time failure is treated as "not on Spaces"
    def GPU_DECORATOR(fn=None, **_options):
        """No-op stand-in used when ``spaces.GPU`` cannot be imported.

        Supports both decorator spellings so call sites do not need to care
        which runtime they are on:

        * ``@GPU_DECORATOR`` -> returns ``fn`` unchanged.
        * ``@GPU_DECORATOR(duration=60)`` -> returns a decorator that
          returns the wrapped function unchanged; the options are ignored.
        """
        if fn is None:
            # Called with keyword options only: hand back an identity decorator.
            return lambda wrapped: wrapped
        return fn
|
22 |
|
23 |
warnings.filterwarnings("ignore")
|
24 |
|
25 |
# ---------------- CONFIG ---------------- #
|
26 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
27 |
+
CUDA_DEVICE_INDEX = 0 if torch.cuda.is_available() else -1 # for transformers pipeline device
|
28 |
print(f"🔧 Using device: {DEVICE}")
|
|
|
29 |
|
30 |
LANG_CODES = {
|
31 |
"English": "en",
|
|
|
33 |
"Malayalam": "ml"
|
34 |
}
|
35 |
|
36 |
+
# Primary model
|
37 |
INDICWHISPER_MODEL = "openai/whisper-large-v2"
|
38 |
|
39 |
+
# Specialized models
|
40 |
SPECIALIZED_MODELS = {
|
41 |
"English": "openai/whisper-base.en",
|
42 |
"Tamil": "vasista22/whisper-tamil-large-v2",
|
|
|
91 |
]
|
92 |
}
|
93 |
|
|
|
|
|
|
|
|
|
94 |
# ---------------- MODEL CACHE ---------------- #
|
95 |
indicwhisper_pipeline = None
|
96 |
fallback_models = {}
|
97 |
+
WHISPER_JAX_AVAILABLE = False # default false; will set true if we load it
|
98 |
+
|
99 |
+
@GPU_DECORATOR
def load_indicwhisper():
    """Return the primary ASR pipeline, loading and caching it on first use.

    The loaded object is stored in the module-level ``indicwhisper_pipeline``
    so later calls return it immediately. Two backends are attempted in order:

    1. ``whisper_jax.FlaxWhisperPipeline`` (optional dependency). On success
       ``WHISPER_JAX_AVAILABLE`` is set to ``True``.
    2. The transformers ``automatic-speech-recognition`` pipeline, placed on
       ``CUDA_DEVICE_INDEX``. ``WHISPER_JAX_AVAILABLE`` is ``False`` here.

    Returns:
        The loaded pipeline object.

    Raises:
        RuntimeError: if neither backend could be loaded. The original
            failure is attached as ``__cause__``.
    """
    global indicwhisper_pipeline, WHISPER_JAX_AVAILABLE

    if indicwhisper_pipeline is not None:
        return indicwhisper_pipeline

    # Try JAX first (optional): any failure here, including the package simply
    # not being installed, just routes us to the transformers fallback below.
    try:
        from whisper_jax import FlaxWhisperPipeline
        import jax.numpy as jnp
        print(f"🔄 Loading JAX-optimized model: {INDICWHISPER_MODEL}")
        indicwhisper_pipeline = FlaxWhisperPipeline(
            INDICWHISPER_MODEL,
            dtype=jnp.bfloat16,
            batch_size=1
        )
        WHISPER_JAX_AVAILABLE = True
        print("✅ JAX-optimized model loaded successfully!")
        return indicwhisper_pipeline
    except Exception as e:
        print(f"⚠️ JAX loading failed: {e}")
        WHISPER_JAX_AVAILABLE = False

    # Fallback to transformers pipeline
    try:
        from transformers import pipeline
        print(f"🔄 Loading transformers ASR pipeline: {INDICWHISPER_MODEL}")
        indicwhisper_pipeline = pipeline(
            task="automatic-speech-recognition",
            model=INDICWHISPER_MODEL,
            device=CUDA_DEVICE_INDEX  # 0 for CUDA, -1 for CPU
        )
        print("✅ Transformers ASR pipeline loaded!")
        return indicwhisper_pipeline
    except Exception as e:
        print(f"❌ Failed to load primary model: {e}")
        # Leave the cache empty so a later call can retry from scratch.
        indicwhisper_pipeline = None
        raise RuntimeError(f"Could not load primary model: {e}") from e
|
139 |
+
|
140 |
+
@GPU_DECORATOR
def load_specialized_model(language: str):
    """Return the language-specific model bundle, loading it on first use.

    Bundles are cached per language in the module-level ``fallback_models``
    dict. The checkpoint name is looked up in ``SPECIALIZED_MODELS``.

    Args:
        language: Key into ``SPECIALIZED_MODELS`` (a ``KeyError`` propagates
            for a language that has no entry).

    Returns:
        dict with keys ``"processor"``, ``"model"`` and ``"model_name"``.

    Raises:
        RuntimeError: if the processor or model cannot be loaded. The
            original failure is attached as ``__cause__``.
    """
    if language in fallback_models:
        return fallback_models[language]

    model_name = SPECIALIZED_MODELS[language]
    print(f"🔄 Loading specialized model for {language}: {model_name}")

    try:
        # WhisperProcessor ensures get_decoder_prompt_ids is available;
        # AutoProcessor is only a best-effort fallback.
        try:
            processor = WhisperProcessor.from_pretrained(model_name)
        except Exception as e:
            print(f"⚠️ WhisperProcessor unavailable for {model_name}, using AutoProcessor: {e}")
            processor = AutoProcessor.from_pretrained(model_name)

        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            low_cpu_mem_usage=True
        )
        model.to(DEVICE)

        bundle = {"processor": processor, "model": model, "model_name": model_name}
        fallback_models[language] = bundle
        print(f"✅ Specialized model loaded for {language}")
        return bundle
    except Exception as e:
        print(f"❌ Failed to load specialized {model_name}: {e}")
        raise RuntimeError(f"Could not load specialized {language} model: {e}") from e
|
169 |
+
|
170 |
+
# ---------------- HELPERS ---------------- #
|
171 |
def get_random_sentence(language_choice):
    """Pick one practice sentence at random from ``SENTENCE_BANK[language_choice]``."""
    candidates = SENTENCE_BANK[language_choice]
    return random.choice(candidates)
|
173 |
|
|
|
175 |
pattern = SCRIPT_PATTERNS.get(lang_name)
|
176 |
if not pattern:
|
177 |
return True
|
178 |
+
return bool(pattern.search(text or ""))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
|
180 |
def transliterate_to_hk(text, lang_choice):
|
181 |
mapping = {
|
|
|
184 |
"English": None
|
185 |
}
|
186 |
script = mapping.get(lang_choice)
|
187 |
+
if not text:
|
188 |
+
return ""
|
189 |
if script and is_script(text, lang_choice):
|
190 |
try:
|
191 |
return transliterate(text, script, sanscript.HK)
|
|
|
195 |
return text
|
196 |
|
197 |
def preprocess_audio(audio_path, target_sr=16000):
    """Load, peak-normalize and silence-trim an audio file.

    Args:
        audio_path: Path handed to ``librosa.load``.
        target_sr: Sample rate to resample to (default 16000).

    Returns:
        ``(audio, target_sr)`` where ``audio`` is a float32 array, or
        ``(None, None)`` when the file cannot be read, is empty, contains
        non-finite samples, or is shorter than 0.1 s after trimming.
        Errors are printed, never raised.
    """
    try:
        audio, _ = librosa.load(audio_path, sr=target_sr, mono=True)
        if audio is None or len(audio) == 0:
            return None, None

        # Normalize to unit peak. A NaN/inf peak means the samples are
        # corrupt; comparing it with 0 would silently skip normalization,
        # so reject the clip explicitly instead.
        peak = np.max(np.abs(audio))
        if not np.isfinite(peak):
            return None, None
        if peak > 0:
            audio = audio / peak

        # Trim leading/trailing silence
        audio, _ = librosa.effects.trim(audio, top_db=20)

        # Ensure min length (0.1 s of samples at target_sr)
        if len(audio) < int(target_sr * 0.1):
            return None, None

        # Ensure float32 (no copy when it already is)
        return audio.astype(np.float32, copy=False), target_sr
    except Exception as e:
        print(f"Audio preprocessing error: {e}")
        return None, None
|
219 |
|
220 |
+
@GPU_DECORATOR
def transcribe_with_primary_model(audio_path, language):
    """Transcribe ``audio_path`` with the primary model.

    Uses the pipeline returned by ``load_indicwhisper()``: the whisper-jax
    call style when ``WHISPER_JAX_AVAILABLE`` is set, otherwise the
    transformers pipeline.

    Args:
        audio_path: Path of the audio file, passed to the pipeline as-is.
        language: Display name looked up in ``LANG_CODES`` (defaults to
            ``"en"`` when missing).

    Returns:
        The stripped transcription text. On any failure a string starting
        with ``"Error: "`` is returned (truncated to 200 chars of the
        exception text); nothing is raised.
    """
    try:
        pipe = load_indicwhisper()
        lang_code = LANG_CODES.get(language, "en")

        if WHISPER_JAX_AVAILABLE:
            # whisper-jax takes task/language directly on the call.
            result = pipe(audio_path, task="transcribe", language=lang_code)
        else:
            # transformers pipeline: pin the language via the decoder prompt.
            # NOTE(review): recent transformers releases prefer passing
            # language/task through generate_kwargs - confirm against the
            # pinned version before switching.
            model = getattr(pipe, "model", None)
            if hasattr(pipe, "tokenizer") and hasattr(model, "config"):
                try:
                    forced_ids = pipe.tokenizer.get_decoder_prompt_ids(
                        language=lang_code, task="transcribe"
                    )
                    model.config.forced_decoder_ids = forced_ids
                except Exception as e:
                    # Best effort only: transcription proceeds without
                    # language forcing.
                    print(f"⚠️ Primary model language forcing failed: {e}")
            result = pipe(audio_path)

        # Both backends are expected to return a dict with 'text'; anything
        # else is stringified.
        if isinstance(result, dict) and "text" in result:
            return (result["text"] or "").strip()
        return str(result).strip()
    except Exception as e:
        print(f"Primary model transcription error: {e}")
        return f"Error: {str(e)[:200]}"
|
263 |
|
264 |
+
@GPU_DECORATOR
|
265 |
def transcribe_with_specialized_model(audio_path, language):
|
266 |
+
"""Transcribe using language-specific models."""
|
267 |
try:
|
268 |
components = load_specialized_model(language)
|
269 |
processor = components["processor"]
|
|
|
273 |
if audio is None:
|
274 |
return "Error: Audio too short or could not be processed"
|
275 |
|
276 |
+
inputs = processor(audio, sampling_rate=sr, return_tensors="pt")
|
277 |
+
# WhisperProcessor returns input_features
|
278 |
+
input_features = inputs.get("input_features", None)
|
279 |
+
if input_features is None:
|
280 |
+
# Fallback: some processors use feature_extractor path
|
281 |
+
input_features = inputs.get("input_values", None)
|
282 |
+
if input_features is None:
|
283 |
+
return "Error: Could not prepare input features"
|
284 |
+
|
285 |
+
input_features = input_features.to(DEVICE)
|
286 |
+
|
287 |
+
generate_kwargs = {
|
288 |
+
"max_length": 200,
|
289 |
+
"num_beams": 3,
|
290 |
+
"do_sample": False
|
291 |
+
}
|
292 |
|
|
|
293 |
if language != "English":
|
294 |
lang_code = LANG_CODES.get(language, "en")
|
295 |
try:
|
|
|
298 |
language=lang_code,
|
299 |
task="transcribe"
|
300 |
)
|
301 |
+
generate_kwargs["forced_decoder_ids"] = forced_decoder_ids
|
302 |
+
elif hasattr(model, "config") and hasattr(processor, "tokenizer"):
|
303 |
+
forced_decoder_ids = processor.tokenizer.get_decoder_prompt_ids(
|
304 |
+
language=lang_code,
|
305 |
+
task="transcribe"
|
306 |
+
)
|
307 |
+
model.config.forced_decoder_ids = forced_decoder_ids
|
308 |
except Exception as e:
|
309 |
print(f"⚠️ Language forcing failed: {e}")
|
310 |
|
311 |
with torch.no_grad():
|
312 |
+
predicted_ids = model.generate(input_features=input_features, **generate_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
313 |
|
314 |
transcription = processor.batch_decode(
|
315 |
predicted_ids,
|
316 |
skip_special_tokens=True,
|
317 |
clean_up_tokenization_spaces=True
|
318 |
)[0]
|
319 |
+
return (transcription or "").strip() or "(No transcription generated)"
|
|
|
320 |
except Exception as e:
|
321 |
print(f"Specialized model transcription error: {e}")
|
322 |
+
return f"Error: {str(e)[:200]}"
|
323 |
|
324 |
+
@GPU_DECORATOR
def transcribe_audio(audio_path, language, initial_prompt="", use_specialized=False):
    """Route a transcription request to the primary or specialized model.

    If the chosen path raises, the primary path retries once through the
    specialized model; a failure on the specialized path is reported as an
    ``"Error: ..."`` string. ``initial_prompt`` is only forwarded on the
    retry and is not otherwise used here.
    """
    try:
        if not use_specialized:
            print(f"🔄 Using primary model for {language}")
            return transcribe_with_primary_model(audio_path, language)
        print(f"🔄 Using specialized model for {language}")
        return transcribe_with_specialized_model(audio_path, language)
    except Exception as err:
        print(f"Transcription failed, trying specialized model: {err}")
        if use_specialized:
            # Nothing left to fall back to.
            return f"Error: All transcription methods failed - {str(err)[:200]}"
        return transcribe_audio(audio_path, language, initial_prompt, use_specialized=True)
|
340 |
|
341 |
def highlight_differences(ref, hyp):
|
342 |
+
if not (ref or "").strip() or not (hyp or "").strip():
|
343 |
return "No text to compare"
|
|
|
344 |
ref_words = ref.strip().split()
|
345 |
hyp_words = hyp.strip().split()
|
|
|
346 |
sm = difflib.SequenceMatcher(None, ref_words, hyp_words)
|
347 |
out_html = []
|
|
|
348 |
for tag, i1, i2, j1, j2 in sm.get_opcodes():
|
349 |
if tag == 'equal':
|
350 |
out_html.extend([f"<span style='color:green; font-weight:bold; background-color:#e8f5e8; padding:2px 4px; margin:1px; border-radius:3px;'>{w}</span>" for w in ref_words[i1:i2]])
|
|
|
355 |
out_html.extend([f"<span style='color:red; text-decoration:line-through; background-color:#ffe8e8; padding:2px 4px; margin:1px; border-radius:3px;'>{w}</span>" for w in ref_words[i1:i2]])
|
356 |
elif tag == 'insert':
|
357 |
out_html.extend([f"<span style='color:orange; font-weight:bold; background-color:#fff3cd; padding:2px 4px; margin:1px; border-radius:3px;'>+{w}</span>" for w in hyp_words[j1:j2]])
|
|
|
358 |
return " ".join(out_html)
|
359 |
|
360 |
def char_level_highlight(ref, hyp):
|
361 |
+
if not (ref or "").strip() or not (hyp or "").strip():
|
362 |
return "No text to compare"
|
|
|
363 |
sm = difflib.SequenceMatcher(None, list(ref), list(hyp))
|
364 |
out = []
|
365 |
for tag, i1, i2, j1, j2 in sm.get_opcodes():
|
|
|
384 |
else:
|
385 |
return "💪 Keep Trying! (<40%)", "Don't give up! Practice makes perfect."
|
386 |
|
387 |
+
# ---------------- MAIN FUNCTION ---------------- #
|
388 |
+
@GPU_DECORATOR
|
389 |
def compare_pronunciation(audio, language_choice, intended_sentence):
|
390 |
+
print(f"🔍 Starting analysis with language: {language_choice}")
|
391 |
print(f"📝 Audio file: {audio}")
|
392 |
print(f"🎯 Intended sentence: {intended_sentence}")
|
393 |
|
394 |
if audio is None:
|
|
|
395 |
return ("❌ Please record audio first.", "", "", "", "", "", "", "")
|
396 |
+
if not (intended_sentence or "").strip():
|
|
|
|
|
397 |
return ("❌ Please generate a practice sentence first.", "", "", "", "", "", "", "")
|
398 |
|
399 |
try:
|
400 |
+
print("🔄 Pass 1: Primary model transcription...")
|
401 |
primary_text = transcribe_audio(audio, language_choice, use_specialized=False)
|
402 |
+
print(f"✅ Primary: {primary_text}")
|
403 |
|
404 |
+
print("🔄 Pass 2: Specialized model transcription...")
|
405 |
specialized_text = transcribe_audio(audio, language_choice, use_specialized=True)
|
406 |
+
print(f"✅ Specialized: {specialized_text}")
|
407 |
|
408 |
actual_text = primary_text if not str(primary_text).startswith("Error:") else specialized_text
|
409 |
|
410 |
if str(actual_text).startswith("Error:"):
|
|
|
411 |
return (f"❌ {actual_text}", "", "", "", "", "", "", "")
|
412 |
|
|
|
|
|
|
|
|
|
413 |
try:
|
414 |
+
wer_val = jiwer.wer(intended_sentence, actual_text)
|
415 |
+
cer_val = jiwer.cer(intended_sentence, actual_text)
|
|
|
|
|
416 |
except Exception as e:
|
417 |
+
print(f"❌ Metrics error: {e}")
|
418 |
wer_val, cer_val = 1.0, 1.0
|
419 |
|
420 |
score_text, feedback = get_pronunciation_score(wer_val, cer_val)
|
421 |
|
|
|
422 |
actual_hk = transliterate_to_hk(actual_text, language_choice)
|
423 |
target_hk = transliterate_to_hk(intended_sentence, language_choice)
|
424 |
+
if language_choice != "English" and not is_script(actual_text, language_choice):
|
|
|
425 |
actual_hk = f"⚠️ Expected {language_choice} script, got mixed/other script"
|
426 |
|
|
|
427 |
diff_html = highlight_differences(intended_sentence, actual_text)
|
428 |
char_html = char_level_highlight(intended_sentence, actual_text)
|
429 |
|
430 |
+
status = f"✅ Analysis Complete - {score_text}\n💬 {feedback}\n🚀 Powered by High-Performance ASR Models"
|
|
|
431 |
|
432 |
return (
|
433 |
status,
|
|
|
442 |
|
443 |
except Exception as e:
|
444 |
error_msg = f"❌ Analysis Error: {str(e)[:200]}"
|
|
|
|
|
|
|
445 |
return (error_msg, str(e), "", "", "", "", "", "")
|
446 |
|
447 |
+
# ---------------- UI ---------------- #
|
448 |
def create_interface():
|
449 |
with gr.Blocks(title="🎙️ SOTA Multilingual Pronunciation Trainer") as demo:
|
|
|
450 |
gr.Markdown("""
|
451 |
# 🎙️ Advanced Multilingual Pronunciation Trainer
|
452 |
|
|
|
454 |
|
455 |
### 🏆 Powered by Advanced Models:
|
456 |
- Dual-Model Analysis: Primary + specialized model comparison
|
457 |
+
- High Accuracy: Language-specific fine-tuned models
|
458 |
- Robust Performance: Automatic fallback for reliability
|
459 |
|
460 |
### 📋 How to Use:
|
461 |
1. Select your target language 🌍
|
462 |
+
2. Generate a practice sentence 🎲
|
463 |
3. Record yourself reading it aloud 🎤
|
464 |
4. Get detailed feedback with advanced accuracy 📊
|
465 |
|
|
|
528 |
gr.Markdown("""
|
529 |
### 🎨 Color Guide:
|
530 |
- 🟢 Green: Correctly pronounced words/characters
|
531 |
+
- 🔴 Red: Missing or mispronounced (strikethrough)
|
532 |
- 🟠 Orange: Extra words or substitutions
|
533 |
""")
|
534 |
diff_html_box = gr.HTML(label="🔍 Word-Level Analysis", show_label=True)
|
|
|
550 |
fn=compare_pronunciation,
|
551 |
inputs=[audio_input, lang_choice, intended_display],
|
552 |
outputs=[
|
553 |
+
status_output, # status
|
554 |
+
pass1_out, # primary transcription
|
555 |
+
pass2_out, # specialized transcription
|
556 |
+
wer_out, # wer formatted
|
557 |
+
cer_out, # cer formatted
|
558 |
+
diff_html_box, # diff_html
|
559 |
+
char_html_box, # char_html
|
560 |
+
target_display # target_display
|
561 |
]
|
562 |
)
|
563 |
|
|
|
571 |
---
|
572 |
### 🏆 Advanced Technology Stack:
|
573 |
- Primary ASR: OpenAI Whisper Large v2 (High-performance multilingual model)
|
574 |
+
- Specialized Models: Fine-tuned language-specific models
|
575 |
+
- Tamil: vasista22/whisper-tamil-large-v2 (IIT Madras Speech Lab)
|
576 |
+
- Malayalam: thennal/whisper-medium-ml (Common Voice trained)
|
577 |
+
- English: openai/whisper-base.en (English-optimized)
|
578 |
+
- Dual Analysis: Primary + specialized model comparison
|
579 |
+
- Automatic Fallback: Ensures reliable results always
|
580 |
|
581 |
### 🔧 Technical Details:
|
582 |
+
- Metrics: WER (Word Error Rate) and CER (Character Error Rate)
|
583 |
+
- Transliteration: Harvard-Kyoto system for Indic scripts
|
584 |
+
- Analysis: Dual-model comparison for comprehensive feedback
|
585 |
+
- Languages: English, Tamil, and Malayalam
|
586 |
""")
|
587 |
return demo
|
588 |
|
589 |
+
# ---------------- LAUNCH ---------------- #
|
590 |
if __name__ == "__main__":
|
591 |
print("🚀 Starting Advanced Multilingual Pronunciation Trainer...")
|
592 |
+
print(f"🔧 Device: {DEVICE}")
|
593 |
+
try:
|
594 |
+
torch_ver = getattr(torch, '__version__', 'unknown')
|
595 |
+
except Exception:
|
596 |
+
torch_ver = 'unknown'
|
597 |
+
print(f"🔧 PyTorch version: {torch_ver}")
|
598 |
print("🏆 Using High-Performance Dual-Model Approach")
|
599 |
print("⚡ Automatic model selection with specialized fallbacks")
|
600 |
print("📊 Advanced analysis with robust error handling")
|
|
|
601 |
|
602 |
demo = create_interface()
|
603 |
demo.launch(
|