# Hugging Face Space page status at capture time: Sleeping
import gradio as gr | |
import random, difflib, re, warnings, contextlib | |
import torch | |
import numpy as np | |
import librosa, soundfile as sf | |
import jiwer | |
# Optional transliteration: indic_transliteration is only needed by
# transliterate_to_hk(); without it that helper returns its input unchanged.
try:
    from indic_transliteration import sanscript
    from indic_transliteration.sanscript import transliterate
    INDIC_OK = True
except ImportError:
    INDIC_OK = False
# Optional HF Spaces decorator: use spaces.GPU when the package imports,
# otherwise a no-op stand-in so decorated functions still work locally.
try:
    import spaces
    GPU_DECORATOR = spaces.GPU
except Exception:  # best-effort optional dependency; not limited to ImportError
    class _NoOp:
        """No-op replacement for ``spaces.GPU``.

        Supports both bare use (``@GPU_DECORATOR``) and parameterized use
        (``@GPU_DECORATOR(duration=...)``); the function is returned unchanged.
        """

        def __call__(self, f=None, **_kwargs):
            if f is None:
                # Called with options only: return the actual decorator.
                return lambda fn: fn
            return f

    GPU_DECORATOR = _NoOp()
warnings.filterwarnings("ignore")
# ---------------- CONFIG ---------------- #
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Device argument for transformers.pipeline: GPU 0, or -1 for CPU.
DEVICE_INDEX = 0 if DEVICE == "cuda" else -1
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

if DEVICE == "cuda":
    def amp_ctx():
        """Return a CUDA mixed-precision context.

        Uses torch.autocast, which replaces the deprecated torch.cuda.amp.autocast.
        """
        return torch.autocast(device_type="cuda")
else:
    # On CPU, inference runs in full precision with no autocast.
    amp_ctx = contextlib.nullcontext
print(f"🔧 Using device: {DEVICE}")
# Whisper language code for each language offered in the UI.
LANG_CODES = dict(English="en", Tamil="ta", Malayalam="ml")

# Primary recognizer: IndicWhisper checkpoint on the Hugging Face Hub.
INDICWHISPER_MODEL = "parthiv11/indic_whisper_nodcil"

# Per-language specialised checkpoints (second transcript / fallback).
SPECIALIZED_MODELS = dict(
    English="openai/whisper-base.en",
    Tamil="vasista22/whisper-tamil-large-v2",
    Malayalam="thennal/whisper-medium-ml",
)
# Script detectors (one Unicode block per language) and the practice-sentence bank.
SCRIPT_PATTERNS = {
    "Tamil": re.compile(r"[\u0B80-\u0BFF]"),      # Tamil block, U+0B80..U+0BFF
    "Malayalam": re.compile(r"[\u0D00-\u0D7F]"),  # Malayalam block, U+0D00..U+0D7F
    "English": re.compile(r"[A-Za-z]"),
}
SENTENCE_BANK = {
    "English": ["The sun sets over the beautiful horizon.", "Hard work always pays off in the end."],
    "Tamil": ["இன்று நல்ல வானிலை உள்ளது.", "உழைப்பு எப்போதும் வெற்றியைத் தரும்."],
    "Malayalam": ["എനിക്ക് മലയാളം വളരെ ഇഷ്ടമാണ്.", "കഠിനാധ്വാനം എപ്പോഴും ഫലം നൽകും."]
}
# Lazily populated model cache; the loader functions fill these on first use.
indicwhisper_pipeline: object = None  # primary ASR pipeline (JAX or transformers backend)
fallback_models: dict = {}            # language -> {"processor": ..., "model": ...}
WHISPER_JAX_AVAILABLE: bool = False   # set to True once the JAX backend has loaded
# ---------------- HELPERS ---------------- # | |
def get_random_sentence(language_choice):
    """Pick one practice sentence for *language_choice* at random.

    Raises KeyError when the language has no entry in SENTENCE_BANK.
    """
    candidates = SENTENCE_BANK[language_choice]
    return random.choice(candidates)
def is_script(text, lang_name):
    """Return True if *text* contains at least one character of *lang_name*'s script.

    Languages without a registered pattern always count as matching. A None
    or empty *text* never matches a registered pattern.
    """
    pattern = SCRIPT_PATTERNS.get(lang_name)
    if pattern is None:
        return True
    return pattern.search(text or "") is not None
def transliterate_to_hk(text, lang_choice):
    """Romanize Tamil/Malayalam *text* into sanscript.HK (Harvard-Kyoto).

    Returns *text* unchanged when any of these holds:
    - indic_transliteration is not installed;
    - the language has no source-script mapping (e.g. English);
    - *text* contains no character of that script;
    - transliteration raises.
    """
    if not INDIC_OK:
        return text
    source_scripts = {"Tamil": sanscript.TAMIL, "Malayalam": sanscript.MALAYALAM}
    script = source_scripts.get(lang_choice)
    if script is None or not is_script(text, lang_choice):
        return text
    try:
        return transliterate(text, script, sanscript.HK)
    except Exception:
        # Best-effort: fall back to the original text instead of failing the UI.
        return text
def preprocess_audio(audio_path, target_sr=16000, *, top_db=20, min_duration=0.1):
    """Load, peak-normalize and silence-trim an audio file.

    Args:
        audio_path: path to a file librosa can decode.
        target_sr: sample rate the audio is resampled to (mono).
        top_db: threshold in dB below peak passed to ``librosa.effects.trim``.
        min_duration: minimum length in seconds required after trimming.

    Returns:
        ``(audio, target_sr)``, where ``audio`` is a float32 array scaled so
        its peak magnitude is 1.
        ``(None, None)`` if the file cannot be loaded, is empty, or is shorter
        than *min_duration* after trimming.
    """
    try:
        audio, _ = librosa.load(audio_path, sr=target_sr, mono=True)
        if audio is None or len(audio) == 0:
            return None, None
        audio = audio.astype(np.float32)
        peak = np.max(np.abs(audio))
        if peak > 0:
            audio /= peak
        audio, _ = librosa.effects.trim(audio, top_db=top_db)
        if len(audio) < int(target_sr * min_duration):
            return None, None
        return audio, target_sr
    except Exception as exc:
        # Best-effort: callers treat (None, None) as "unusable audio", but log why.
        print(f"⚠️ Could not preprocess audio {audio_path!r}: {exc}")
        return None, None
# Text normalization shared by WER and CER: case, punctuation and extra
# whitespace differences are not counted as pronunciation errors.
_JIWER_NORMALIZE = [
    jiwer.ToLowerCase(),
    jiwer.RemovePunctuation(),
    jiwer.RemoveMultipleSpaces(),
    jiwer.Strip(),
]
JIWER_TRANSFORM = jiwer.Compose(_JIWER_NORMALIZE + [jiwer.ReduceToListOfListOfWords()])
JIWER_CHAR_TRANSFORM = jiwer.Compose(_JIWER_NORMALIZE + [jiwer.ReduceToListOfListOfChars()])


def _jiwer_rate(metric, ref, hyp, transform):
    """Call a jiwer metric (``jiwer.wer`` / ``jiwer.cer``) with *transform* on both sides.

    jiwer >= 3.0 names the reference-side keyword ``reference_transform``.
    ``truth_transform`` is the jiwer 2.x name, deprecated in 3.0.
    The modern name is tried first.
    """
    try:
        return metric(ref, hyp, reference_transform=transform, hypothesis_transform=transform)
    except TypeError:
        return metric(ref, hyp, truth_transform=transform, hypothesis_transform=transform)


def compute_wer(ref, hyp):
    """Word error rate of *hyp* against *ref* after normalization; 1.0 if scoring fails."""
    try:
        return _jiwer_rate(jiwer.wer, ref, hyp, JIWER_TRANSFORM)
    except Exception:
        return 1.0


def compute_cer(ref, hyp):
    """Character error rate of *hyp* against *ref* after the same normalization as WER; 1.0 if scoring fails."""
    try:
        return _jiwer_rate(jiwer.cer, ref, hyp, JIWER_CHAR_TRANSFORM)
    except Exception:
        return 1.0
# ---------------- MODEL LOADERS ---------------- # | |
def load_indicwhisper():
    """Return the cached IndicWhisper pipeline, loading it on first use.

    Tries whisper-jax's FlaxWhisperPipeline first. If importing or building it
    fails for any reason, falls back to a transformers
    "automatic-speech-recognition" pipeline on DEVICE_INDEX.

    The module-level WHISPER_JAX_AVAILABLE flag records which backend is
    active, because the two backends are called with different arguments.
    Errors from the transformers fallback propagate to the caller.
    """
    global indicwhisper_pipeline, WHISPER_JAX_AVAILABLE
    if indicwhisper_pipeline is not None:
        return indicwhisper_pipeline
    try:
        from whisper_jax import FlaxWhisperPipeline
        import jax.numpy as jnp
        indicwhisper_pipeline = FlaxWhisperPipeline(INDICWHISPER_MODEL, dtype=jnp.bfloat16, batch_size=1)
        WHISPER_JAX_AVAILABLE = True
        print("✅ JAX IndicWhisper loaded!")
        return indicwhisper_pipeline
    except Exception as e:
        print(f"⚠️ JAX unavailable: {e}")
        WHISPER_JAX_AVAILABLE = False
    from transformers import pipeline
    indicwhisper_pipeline = pipeline("automatic-speech-recognition", model=INDICWHISPER_MODEL, device=DEVICE_INDEX)
    print("✅ Transformers IndicWhisper loaded!")
    return indicwhisper_pipeline
def load_specialized_model(language):
    """Return the cached ``{"processor", "model"}`` pair for *language*, loading it on first use.

    The checkpoint name comes from SPECIALIZED_MODELS. The model is loaded in
    DTYPE and moved to DEVICE. Raises KeyError for a language without an entry.
    """
    cached = fallback_models.get(language)
    if cached is not None:
        return cached

    from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

    checkpoint = SPECIALIZED_MODELS[language]
    components = {
        "processor": AutoProcessor.from_pretrained(checkpoint),
        "model": AutoModelForSpeechSeq2Seq.from_pretrained(checkpoint, torch_dtype=DTYPE).to(DEVICE),
    }
    fallback_models[language] = components
    return components
# ---------------- TRANSCRIBE ---------------- # | |
def _transformers_asr(pl, audio_path, lang_code):
    """Run a transformers ASR pipeline with Whisper's language and task forced.

    Forcing is passed per call through ``generate_kwargs``, so the shared
    pipeline object is not mutated.

    If that is rejected (older transformers, or a checkpoint whose generation
    config lacks language tokens), fall back best-effort to the legacy
    ``forced_decoder_ids`` on the model config. If that also fails, transcribe
    without forcing a language.
    """
    try:
        return pl(audio_path, generate_kwargs={"language": lang_code, "task": "transcribe"})
    except (TypeError, ValueError):
        pass
    if hasattr(pl, "model") and hasattr(pl, "tokenizer"):
        try:
            forced_ids = pl.tokenizer.get_decoder_prompt_ids(language=lang_code, task="transcribe")
            pl.model.config.forced_decoder_ids = forced_ids
        except Exception:
            pass  # best-effort: continue without forcing the language
    return pl(audio_path)


def transcribe_with_primary_model(audio_path, language):
    """Transcribe *audio_path* with the primary IndicWhisper pipeline.

    *language* is mapped through LANG_CODES (default "en") to the Whisper
    language code.

    Returns the stripped transcript, or ``"Error: <message>"`` on any failure.
    Callers check for that prefix.
    """
    try:
        pl = load_indicwhisper()
        lang_code = LANG_CODES.get(language, "en")
        if WHISPER_JAX_AVAILABLE:
            res = pl(audio_path, task="transcribe", language=lang_code)
            if isinstance(res, dict):
                return res.get("text", "").strip()
            return str(res).strip()
        with amp_ctx():
            out = _transformers_asr(pl, audio_path, lang_code)
        if isinstance(out, dict):
            return (out.get("text") or "").strip()
        return str(out).strip()
    except Exception as e:
        return f"Error: {str(e)}"
def _generate_ids(model, processor, gen_kwargs, lang_code):
    """Call ``model.generate``, forcing Whisper's language/task when *lang_code* is given.

    Tries, in order:
    1. the ``language``/``task`` generate arguments;
    2. if those are rejected, ``forced_decoder_ids`` built by the tokenizer
       (keyword arguments: the tokenizer's first positional parameter is ``task``);
    3. if those cannot be built, generation without language forcing.
    """
    if lang_code is None:
        return model.generate(**gen_kwargs)
    try:
        return model.generate(**gen_kwargs, language=lang_code, task="transcribe")
    except (TypeError, ValueError):
        pass
    try:
        forced_ids = processor.tokenizer.get_decoder_prompt_ids(language=lang_code, task="transcribe")
    except Exception:
        return model.generate(**gen_kwargs)  # best-effort: no language forcing
    return model.generate(**gen_kwargs, forced_decoder_ids=forced_ids)


def transcribe_with_specialized_model(audio_path, language):
    """Transcribe *audio_path* with the per-language checkpoint from SPECIALIZED_MODELS.

    Audio is loaded through preprocess_audio, then decoded with beam search
    (3 beams, max_length 200).

    Returns the stripped transcript, or ``"Error: <message>"`` on failure.
    """
    try:
        comp = load_specialized_model(language)
        processor, model = comp["processor"], comp["model"]
        audio, sr = preprocess_audio(audio_path)
        if audio is None:
            return "Error: Audio too short or unreadable"
        inputs = processor(audio, sampling_rate=sr, return_tensors="pt")
        # Match the dtype the model was loaded in (DTYPE) instead of relying on autocast.
        feats = inputs.input_features.to(DEVICE, dtype=DTYPE)
        gen_kwargs = {"inputs": feats, "max_length": 200, "num_beams": 3}
        # The English checkpoint is English-only (".en"), so no language is forced for it.
        lang_code = None if language == "English" else LANG_CODES[language]
        with torch.no_grad(), amp_ctx():
            ids = _generate_ids(model, processor, gen_kwargs, lang_code)
        return processor.batch_decode(ids, skip_special_tokens=True)[0].strip()
    except Exception as e:
        return f"Error: {str(e)}"
def transcribe_audio(audio_path, language, use_specialized=False):
    """Transcribe with the specialized checkpoint if *use_specialized*, else with IndicWhisper.

    Both backends return the transcript or an ``"Error: ..."`` string.
    """
    backend = transcribe_with_specialized_model if use_specialized else transcribe_with_primary_model
    return backend(audio_path, language)
# ---------------- MAIN ---------------- # | |
def compare_pronunciation(audio, lang_choice, intended): | |
if audio is None: return ("❌ Please record audio first.","","","","","","","") | |
if not intended.strip(): return ("❌ Please generate a sentence first.","","","","","","","") | |
ptext = transcribe_audio(audio, lang_choice, False) | |
stext = transcribe_audio(audio, lang_choice, True) | |
actual = ptext if not ptext.startswith("Error:") else stext | |
if actual.startswith("Error:"): return (f"❌ {actual}","","","","","","","") | |
wer_val, cer_val = compute_wer(intended, actual), compute_cer(intended, actual) | |
score, feedback = get_score(wer_val, cer_val) | |
return (f"✅ Done - {score}\n💬 {feedback}", | |
ptext, stext, | |
f"{wer_val:.3f} ({(1-wer_val)*100:.1f}%)", | |
f"{cer_val:.3f} ({(1-cer_val)*100:.1f}%)", | |
diff_html(intended, actual), | |
char_html(intended, actual), | |
f"🎯 Target: {intended}") | |
def get_score(wer, cer):
    """Map WER/CER to a ``(headline, feedback)`` pair.

    The rates are blended 70/30 (word errors weigh more). The blend is then
    compared against ascending upper bounds, and the first bound it does not
    exceed selects the message. Anything above 0.6 (or NaN) gets the final
    encouragement.
    """
    combined = (wer * 0.7) + (cer * 0.3)
    bands = (
        (0.1, ("🏆 Excellent!", "Outstanding!")),
        (0.2, ("🎉 Very Good!", "Minor improvements needed.")),
        (0.4, ("👍 Good!", "Keep practicing.")),
        (0.6, ("📚 Needs Practice", "Focus on clearer pronunciation.")),
    )
    for upper_bound, verdict in bands:
        if combined <= upper_bound:
            return verdict
    return "💪 Keep Trying!", "Don't give up!"
def diff_html(ref, hyp):
    """Return the word-level HTML diff of *hyp* against *ref* (delegates to highlight_differences)."""
    rendered = highlight_differences(ref, hyp)
    return rendered
def char_html(ref, hyp):
    """Return the character-level HTML diff of *hyp* against *ref* (delegates to char_level_highlight)."""
    rendered = char_level_highlight(ref, hyp)
    return rendered
# Diff functions | |
def highlight_differences(ref, hyp):
    """Render a word-level diff of *hyp* against *ref* as space-joined HTML spans.

    Words are aligned on a normalized key: NFC-normalized, Unicode punctuation
    removed, lower-cased. Case- or punctuation-only differences, which
    compute_wer's normalization also ignores, therefore show as matches; the
    original words are what gets displayed.

    Colours:
    - green: matched reference word;
    - red: missing or mismatched reference word;
    - orange "→w": substituted transcript word;
    - orange "+w": extra transcript word.

    Every word is HTML-escaped because *hyp* is ASR output rendered by gr.HTML.
    """
    from html import escape
    import unicodedata

    def key(word):
        word = unicodedata.normalize("NFC", word)
        return "".join(ch for ch in word if not unicodedata.category(ch).startswith("P")).lower()

    def span(color, text):
        return f"<span style='color:{color}'>{escape(text)}</span>"

    ref_w, hyp_w = ref.split(), hyp.split()
    sm = difflib.SequenceMatcher(None, [key(w) for w in ref_w], [key(w) for w in hyp_w], autojunk=False)
    out = []
    for tag, i1, i2, j1, j2 in sm.get_opcodes():
        if tag == "equal":
            out += [span("green", w) for w in ref_w[i1:i2]]
        elif tag == "replace":
            out += [span("red", w) for w in ref_w[i1:i2]]
            out += [span("orange", "→" + w) for w in hyp_w[j1:j2]]
        elif tag == "delete":
            out += [span("red", w) for w in ref_w[i1:i2]]
        elif tag == "insert":
            out += [span("orange", "+" + w) for w in hyp_w[j1:j2]]
    return " ".join(out)
def char_level_highlight(ref, hyp):
    """Render a character-level diff of *hyp* against *ref* as concatenated HTML spans.

    Colours:
    - green: matching reference character;
    - red: reference character that was replaced or deleted (the replacement
      itself is not shown);
    - orange: character that appears only in *hyp*.

    Characters are HTML-escaped because *hyp* is ASR output. ``autojunk`` is
    disabled so that long inputs (200+ characters) are not diffed with
    frequent characters such as spaces treated as junk.
    """
    from html import escape

    def span(color, ch):
        return f"<span style='color:{color}'>{escape(ch)}</span>"

    sm = difflib.SequenceMatcher(None, ref, hyp, autojunk=False)
    out = []
    for tag, i1, i2, j1, j2 in sm.get_opcodes():
        if tag == "equal":
            out.extend(span("green", c) for c in ref[i1:i2])
        elif tag in ("replace", "delete"):
            out.extend(span("red", c) for c in ref[i1:i2])
        elif tag == "insert":
            out.extend(span("orange", c) for c in hyp[j1:j2])
    return "".join(out)
# ---------------- UI ---------------- # | |
def create_interface():
    """Build and return the Gradio Blocks UI for the pronunciation trainer.

    Layout, top to bottom:
    - a row with the language picker and the "Generate Sentence" button;
    - the practice sentence;
    - the recorder;
    - the Analyze button;
    - the result widgets filled by compare_pronunciation.

    Pressing the button or changing the language draws a new sentence via
    get_random_sentence.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# 🎙️ IndicWhisper Pronunciation Trainer")
        with gr.Row():
            language_dd = gr.Dropdown(choices=list(LANG_CODES.keys()), value="Tamil", label="Language")
            new_sentence_btn = gr.Button("🎲 Generate Sentence")
        sentence_box = gr.Textbox(label="Practice Sentence", interactive=False, lines=3)
        recording = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Record")
        analyze_btn = gr.Button("🔍 Analyze")

        # Output widgets, in the same order as compare_pronunciation's 8-tuple.
        result_widgets = [
            gr.Textbox(label="Results", interactive=False, lines=4),
            gr.Textbox(label="Primary (IndicWhisper)"),
            gr.Textbox(label="Specialized"),
            gr.Textbox(label="Word Accuracy"),
            gr.Textbox(label="Char Accuracy"),
            gr.HTML(label="Word Diff"),
            gr.HTML(label="Char Diff"),
            gr.Textbox(label="Reference", visible=False),
        ]

        new_sentence_btn.click(get_random_sentence, [language_dd], [sentence_box])
        analyze_btn.click(compare_pronunciation, [recording, language_dd, sentence_box], result_widgets)
        language_dd.change(get_random_sentence, [language_dd], [sentence_box])
    return demo
if __name__ == "__main__":
    # Serve on all interfaces, port 7860; share=True asks Gradio for a public share link.
    demo = create_interface()
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860, "share": True}
    demo.launch(**launch_options)