sudhanm commited on
Commit
5a75be5
·
verified ·
1 Parent(s): 9d552f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +385 -162
app.py CHANGED
@@ -4,41 +4,66 @@ import difflib
4
  import re
5
  import jiwer
6
  import torch
7
- from parler_tts import ParlerTTSForConditionalGeneration
8
- from transformers import AutoTokenizer
9
- from faster_whisper import WhisperModel
 
 
 
 
 
 
 
 
10
  from indic_transliteration import sanscript
11
  from indic_transliteration.sanscript import transliterate
12
- import soundfile as sf
 
13
 
14
  # ---------------- CONFIG ---------------- #
15
- MODEL_NAME = "large-v2"
16
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
  LANG_CODES = {
19
  "English": "en",
20
- "Tamil": "ta",
21
  "Malayalam": "ml",
22
  "Hindi": "hi",
23
  "Sanskrit": "sa"
24
  }
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  LANG_PRIMERS = {
27
- "English": ("The transcript should be in English only.",
28
- "Write only in English without translation. Example: This is an English sentence."),
29
- "Tamil": ("เฎจเฎ•เฎฒเฏ เฎคเฎฎเฎฟเฎดเฏ เฎŽเฎดเฏเฎคเฏเฎคเฏเฎ•เฏเฎ•เฎณเฎฟเฎฒเฏ เฎฎเฎŸเฏเฎŸเฏเฎฎเฏ เฎ‡เฎฐเฏเฎ•เฏเฎ• เฎตเฏ‡เฎฃเฏเฎŸเฏเฎฎเฏ.",
30
- "เฎคเฎฎเฎฟเฎดเฏ เฎŽเฎดเฏเฎคเฏเฎคเฏเฎ•เฏเฎ•เฎณเฎฟเฎฒเฏ เฎฎเฎŸเฏเฎŸเฏเฎฎเฏ เฎŽเฎดเฏเฎคเฎตเฏเฎฎเฏ, เฎฎเฏŠเฎดเฎฟเฎชเฏ†เฎฏเฎฐเฏเฎชเฏเฎชเฏ เฎšเฏ†เฎฏเฏเฎฏเฎ•เฏเฎ•เฏ‚เฎŸเฎพเฎคเฏ. เฎ‰เฎคเฎพเฎฐเฎฃเฎฎเฏ: เฎ‡เฎคเฏ เฎ’เฎฐเฏ เฎคเฎฎเฎฟเฎดเฏ เฎตเฎพเฎ•เฏเฎ•เฎฟเฎฏเฎฎเฏ."),
31
- "Malayalam": ("เดŸเตเดฐเดพเตปเดธเตเด–เตเดฐเดฟเดชเตเดฑเตเดฑเต เดฎเดฒเดฏเดพเดณ เดฒเดฟเดชเดฟเดฏเดฟเตฝ เด†เดฏเดฟเดฐเดฟเด•เตเด•เดฃเด‚.",
32
- "เดฎเดฒเดฏเดพเดณ เดฒเดฟเดชเดฟเดฏเดฟเตฝ เดฎเดพเดคเตเดฐเด‚ เดŽเดดเตเดคเตเด•, เดตเดฟเดตเตผเดคเตเดคเดจเด‚ เดšเต†เดฏเตเดฏเดฐเตเดคเต. เด‰เดฆเดพเดนเดฐเดฃเด‚: เด‡เดคเตŠเดฐเต เดฎเดฒเดฏเดพเดณ เดตเดพเด•เตเดฏเดฎเดพเดฃเต. เดŽเดจเดฟเด•เตเด•เต เดฎเดฒเดฏเดพเดณเด‚ เด…เดฑเดฟเดฏเดพเด‚."),
33
- "Hindi": ("เคชเฅเคฐเคคเคฟเคฒเคฟเคชเคฟ เค•เฅ‡เคตเคฒ เคฆเฅ‡เคตเคจเคพเค—เคฐเฅ€ เคฒเคฟเคชเคฟ เคฎเฅ‡เค‚ เคนเฅ‹เคจเฅ€ เคšเคพเคนเคฟเคเฅค",
34
- "เค•เฅ‡เคตเคฒ เคฆเฅ‡เคตเคจเคพเค—เคฐเฅ€ เคฒเคฟเคชเคฟ เคฎเฅ‡เค‚ เคฒเคฟเค–เฅ‡เค‚, เค…เคจเฅเคตเคพเคฆ เคจ เค•เคฐเฅ‡เค‚เฅค เค‰เคฆเคพเคนเคฐเคฃ: เคฏเคน เคเค• เคนเคฟเค‚เคฆเฅ€ เคตเคพเค•เฅเคฏ เคนเฅˆเฅค"),
35
- "Sanskrit": ("เคชเฅเคฐเคคเคฟเคฒเคฟเคชเคฟ เค•เฅ‡เคตเคฒ เคฆเฅ‡เคตเคจเคพเค—เคฐเฅ€ เคฒเคฟเคชเคฟ เคฎเฅ‡เค‚ เคนเฅ‹เคจเฅ€ เคšเคพเคนเคฟเคเฅค",
36
- "เค•เฅ‡เคตเคฒ เคฆเฅ‡เคตเคจเคพเค—เคฐเฅ€ เคฒเคฟเคชเคฟ เคฎเฅ‡เค‚ เคฒเคฟเค–เฅ‡เค‚, เค…เคจเฅเคตเคพเคฆ เคจ เค•เคฐเฅ‡เค‚เฅค เค‰เคฆเคพเคนเคฐเคฃ: เค…เคนเค‚ เคธเค‚เคธเฅเค•เฅƒเคคเค‚ เคœเคพเคจเคพเคฎเคฟเฅค")
37
  }
38
 
39
  SCRIPT_PATTERNS = {
40
  "Tamil": re.compile(r"[เฎ€-เฏฟ]"),
41
- "Malayalam": re.compile(r"[เด€-เตฟ]"),
42
  "Hindi": re.compile(r"[เค€-เฅฟ]"),
43
  "Sanskrit": re.compile(r"[เค€-เฅฟ]"),
44
  "English": re.compile(r"[A-Za-z]")
@@ -46,206 +71,404 @@ SCRIPT_PATTERNS = {
46
 
47
  SENTENCE_BANK = {
48
  "English": [
49
- "The sun sets over the horizon.",
50
- "Learning languages is fun.",
51
- "I like to drink coffee in the morning."
 
 
52
  ],
53
  "Tamil": [
54
  "เฎ‡เฎฉเฏเฎฑเฏ เฎจเฎฒเฏเฎฒ เฎตเฎพเฎฉเฎฟเฎฒเฏˆ เฎ‰เฎณเฏเฎณเฎคเฏ.",
55
- "เฎจเฎพเฎฉเฏ เฎคเฎฎเฎฟเฎดเฏ เฎ•เฎฑเฏเฎฑเฏเฎ•เฏเฎ•เฏŠเฎฃเฏเฎŸเฏ เฎ‡เฎฐเฏเฎ•เฏเฎ•เฎฟเฎฑเฏ‡เฎฉเฏ.",
56
- "เฎŽเฎฉเฎ•เฏเฎ•เฏ เฎชเฏเฎคเฏเฎคเฎ•เฎฎเฏ เฎชเฎŸเฎฟเฎ•เฏเฎ• เฎตเฎฟเฎฐเฏเฎชเฏเฎชเฎฎเฏ."
 
 
57
  ],
58
  "Malayalam": [
59
  "เดŽเดจเดฟเด•เตเด•เต เดฎเดฒเดฏเดพเดณเด‚ เดตเดณเดฐเต† เด‡เดทเตเดŸเดฎเดพเดฃเต.",
60
  "เด‡เดจเตเดจเต เดฎเดดเดชเต†เดฏเตเดฏเตเดจเตเดจเต.",
61
- "เดžเดพเตป เดชเตเดธเตเดคเด•เด‚ เดตเดพเดฏเดฟเด•เตเด•เตเดจเตเดจเต."
 
 
62
  ],
63
  "Hindi": [
64
- "เค†เคœ เคฎเฅŒเคธเคฎ เค…เคšเฅเค›เคพ เคนเฅˆเฅค",
65
- "เคฎเฅเคเฅ‡ เคนเคฟเค‚เคฆเฅ€ เคฌเฅ‹เคฒเคจเคพ เคชเคธเค‚เคฆ เคนเฅˆเฅค",
66
- "เคฎเฅˆเค‚ เค•เคฟเคคเคพเคฌ เคชเคขเคผ เคฐเคนเคพ เคนเฅ‚เคเฅค"
 
 
67
  ],
68
  "Sanskrit": [
69
  "เค…เคนเค‚ เค—เฅเคฐเคจเฅเคฅเค‚ เคชเค เคพเคฎเคฟเฅค",
70
  "เค…เคฆเฅเคฏ เคธเฅ‚เคฐเฅเคฏเคƒ เคคเฅ‡เคœเคธเฅเคตเฅ€ เค…เคธเฅเคคเคฟเฅค",
71
- "เคฎเคฎ เคจเคพเคฎ เคฐเคพเคฎเคƒเฅค"
 
 
72
  ]
73
  }
74
 
75
- VOICE_STYLE = {
76
- "English": "An English female voice with a neutral Indian accent.",
77
- "Tamil": "A female speaker with a clear Tamil accent.",
78
- "Malayalam": "A female speaker with a clear Malayali accent.",
79
- "Hindi": "A female speaker with a neutral Hindi accent.",
80
- "Sanskrit": "A female speaker reading in classical Sanskrit style."
81
- }
82
-
83
- # ---------------- LOAD MODELS ---------------- #
84
- print("Loading Whisper model...")
85
- whisper_model = WhisperModel(MODEL_NAME, device=DEVICE)
86
-
87
- print("Loading Parler-TTS model...")
88
- parler_model_id = "parler-tts/parler-tts-mini-v1" # You may switch to larger models if desired
89
- parler_tts_model = ParlerTTSForConditionalGeneration.from_pretrained(parler_model_id).to(DEVICE)
90
- parler_tts_tokenizer = AutoTokenizer.from_pretrained(parler_model_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  # ---------------- HELPERS ---------------- #
93
  def get_random_sentence(language_choice):
 
94
  return random.choice(SENTENCE_BANK[language_choice])
95
 
96
  def is_script(text, lang_name):
 
97
  pattern = SCRIPT_PATTERNS.get(lang_name)
98
  return bool(pattern.search(text)) if pattern else True
99
 
100
  def transliterate_to_hk(text, lang_choice):
 
101
  mapping = {
102
  "Tamil": sanscript.TAMIL,
103
- "Malayalam": sanscript.MALAYALAM,
104
  "Hindi": sanscript.DEVANAGARI,
105
  "Sanskrit": sanscript.DEVANAGARI,
106
  "English": None
107
  }
108
- return transliterate(text, mapping[lang_choice], sanscript.HK) if mapping[lang_choice] else text
109
-
110
- def transcribe_once(audio_path, lang_code, initial_prompt, beam_size, temperature, condition_on_previous_text):
111
- segments, _ = whisper_model.transcribe(
112
- audio_path,
113
- language=lang_code,
114
- task="transcribe",
115
- initial_prompt=initial_prompt,
116
- beam_size=beam_size,
117
- temperature=temperature,
118
- condition_on_previous_text=condition_on_previous_text,
119
- word_timestamps=False
120
- )
121
- return "".join(s.text for s in segments).strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
  def highlight_differences(ref, hyp):
124
- ref_words, hyp_words = ref.strip().split(), hyp.strip().split()
 
 
 
125
  sm = difflib.SequenceMatcher(None, ref_words, hyp_words)
126
  out_html = []
 
127
  for tag, i1, i2, j1, j2 in sm.get_opcodes():
128
  if tag == 'equal':
129
- out_html.extend([f"<span style='color:green'>{w}</span>" for w in ref_words[i1:i2]])
130
  elif tag == 'replace':
131
- out_html.extend([f"<span style='color:red'>{w}</span>" for w in ref_words[i1:i2]])
132
- out_html.extend([f"<span style='color:orange'>{w}</span>" for w in hyp_words[j1:j2]])
133
  elif tag == 'delete':
134
- out_html.extend([f"<span style='color:red;text-decoration:line-through'>{w}</span>" for w in ref_words[i1:i2]])
135
  elif tag == 'insert':
136
- out_html.extend([f"<span style='color:orange'>{w}</span>" for w in hyp_words[j1:j2]])
 
137
  return " ".join(out_html)
138
 
139
  def char_level_highlight(ref, hyp):
 
140
  sm = difflib.SequenceMatcher(None, list(ref), list(hyp))
141
  out = []
 
142
  for tag, i1, i2, j1, j2 in sm.get_opcodes():
143
  if tag == 'equal':
144
  out.extend([f"<span style='color:green'>{c}</span>" for c in ref[i1:i2]])
145
  elif tag in ('replace', 'delete'):
146
- out.extend([f"<span style='color:red;text-decoration:underline'>{c}</span>" for c in ref[i1:i2]])
147
  elif tag == 'insert':
148
- out.extend([f"<span style='color:orange'>{c}</span>" for c in hyp[j1:j2]])
 
149
  return "".join(out)
150
 
151
- def synthesize_tts(text, lang_choice):
152
- if not text.strip():
153
- return None
154
- description = VOICE_STYLE.get(lang_choice, "")
155
- description_input = parler_tts_tokenizer(description, return_tensors='pt').to(DEVICE)
156
- prompt_input = parler_tts_tokenizer(text, return_tensors='pt').to(DEVICE)
157
- generation = parler_tts_model.generate(
158
- input_ids=description_input.input_ids,
159
- attention_mask=description_input.attention_mask,
160
- prompt_input_ids=prompt_input.input_ids,
161
- prompt_attention_mask=prompt_input.attention_mask
162
- )
163
- audio_arr = generation.cpu().numpy().squeeze()
164
- # Parler-TTS default sample rate is 24000
165
- return 24000, audio_arr
166
-
167
- # ---------------- MAIN ---------------- #
168
- def compare_pronunciation(audio, language_choice, intended_sentence,
169
- pass1_beam, pass1_temp, pass1_condition):
170
  if audio is None or not intended_sentence.strip():
171
- return ("No audio or intended sentence.", "", "", "", "", "",
172
  None, None, "", "")
173
-
174
- lang_code = LANG_CODES[language_choice]
175
- primer_weak, primer_strong = LANG_PRIMERS[language_choice]
176
-
177
- # Pass 1: raw transcription with user-configured decoding parameters
178
- actual_text = transcribe_once(audio, lang_code, primer_weak,
179
- pass1_beam, pass1_temp, pass1_condition)
180
-
181
- # Pass 2: strict transcription biased by intended sentence (fixed decoding params)
182
- strict_prompt = f"{primer_strong}\nTarget: {intended_sentence}"
183
- corrected_text = transcribe_once(audio, lang_code, strict_prompt,
184
- beam_size=5, temperature=0.0, condition_on_previous_text=False)
185
-
186
- # Compute WER and CER
187
- wer_val = jiwer.wer(intended_sentence, actual_text)
188
- cer_val = jiwer.cer(intended_sentence, actual_text)
189
-
190
- # Transliteration of Pass 1 output
191
- hk_translit = transliterate_to_hk(actual_text, language_choice) if is_script(actual_text, language_choice) else f"[Script mismatch: expected {language_choice}]"
192
-
193
- # Highlight word-level and character-level differences
194
- diff_html = highlight_differences(intended_sentence, actual_text)
195
- char_html = char_level_highlight(intended_sentence, actual_text)
196
-
197
- # Synthesized TTS audios for intended and Pass 1 text
198
- tts_intended = synthesize_tts(intended_sentence, language_choice)
199
- tts_pass1 = synthesize_tts(actual_text, language_choice)
200
-
201
- return (actual_text, corrected_text, hk_translit, f"{wer_val:.2f}", f"{cer_val:.2f}",
202
- diff_html, tts_intended, tts_pass1, char_html, intended_sentence)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
 
204
  # ---------------- UI ---------------- #
205
- with gr.Blocks() as demo:
206
- gr.Markdown("## ๐ŸŽ™ Pronunciation Comparator + Parler-TTS + Highlights")
207
-
208
- with gr.Row():
209
- lang_choice = gr.Dropdown(choices=list(LANG_CODES.keys()), value="Malayalam", label="Language")
210
- gen_btn = gr.Button("๐ŸŽฒ Generate Sentence")
211
-
212
- intended_display = gr.Textbox(label="Generated Sentence (Read aloud)", interactive=False)
213
-
214
- with gr.Row():
215
- audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath")
216
- pass1_beam = gr.Slider(1, 10, value=8, step=1, label="Pass 1 Beam Size")
217
- pass1_temp = gr.Slider(0.0, 1.0, value=0.4, step=0.1, label="Pass 1 Temperature")
218
- pass1_condition = gr.Checkbox(value=True, label="Pass 1: Condition on previous text")
219
-
220
- with gr.Row():
221
- pass1_out = gr.Textbox(label="Pass 1: What You Actually Said")
222
- pass2_out = gr.Textbox(label="Pass 2: Target-Biased Output")
223
- hk_out = gr.Textbox(label="Harvard-Kyoto Transliteration (Pass 1)")
224
-
225
- with gr.Row():
226
- wer_out = gr.Textbox(label="Word Error Rate")
227
- cer_out = gr.Textbox(label="Character Error Rate")
228
-
229
- diff_html_box = gr.HTML(label="Word Differences Highlighted")
230
- char_html_box = gr.HTML(label="Character-Level Highlighting (mispronounced = red underline)")
231
-
232
- with gr.Row():
233
- intended_tts_audio = gr.Audio(label="TTS - Intended Sentence", type="numpy")
234
- pass1_tts_audio = gr.Audio(label="TTS - Pass1 Output", type="numpy")
235
-
236
- gen_btn.click(fn=get_random_sentence, inputs=[lang_choice], outputs=[intended_display])
237
-
238
- submit_btn = gr.Button("Analyze Pronunciation")
239
-
240
- submit_btn.click(
241
- fn=compare_pronunciation,
242
- inputs=[audio_input, lang_choice, intended_display, pass1_beam, pass1_temp, pass1_condition],
243
- outputs=[
244
- pass1_out, pass2_out, hk_out, wer_out, cer_out,
245
- diff_html_box, intended_tts_audio, pass1_tts_audio,
246
- char_html_box, intended_display
247
- ]
248
- )
249
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  if __name__ == "__main__":
251
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import re
5
  import jiwer
6
  import torch
7
+ import torchaudio
8
+ import numpy as np
9
+ from transformers import (
10
+ AutoProcessor,
11
+ AutoModelForSpeechSeq2Seq,
12
+ AutoTokenizer,
13
+ AutoModel
14
+ )
15
+ from TTS.api import TTS
16
+ import librosa
17
+ import soundfile as sf
18
  from indic_transliteration import sanscript
19
  from indic_transliteration.sanscript import transliterate
20
+ import warnings
21
+ warnings.filterwarnings("ignore")
22
 
23
  # ---------------- CONFIG ---------------- #
 
24
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
25
 
26
  LANG_CODES = {
27
  "English": "en",
28
+ "Tamil": "ta",
29
  "Malayalam": "ml",
30
  "Hindi": "hi",
31
  "Sanskrit": "sa"
32
  }
33
 
34
+ # AI4Bharat model configurations
35
+ ASR_MODELS = {
36
+ "English": "openai/whisper-base.en",
37
+ "Tamil": "ai4bharat/whisper-medium-ta",
38
+ "Malayalam": "ai4bharat/whisper-medium-ml",
39
+ "Hindi": "ai4bharat/whisper-medium-hi",
40
+ "Sanskrit": "ai4bharat/whisper-medium-hi" # Fallback to Hindi for Sanskrit
41
+ }
42
+
43
+ TTS_MODELS = {
44
+ "English": "tts_models/en/ljspeech/tacotron2-DDC",
45
+ "Tamil": "tts_models/ta/mai/tacotron2-DDC",
46
+ "Malayalam": "tts_models/ml/mai/tacotron2-DDC",
47
+ "Hindi": "tts_models/hi/mai/tacotron2-DDC",
48
+ "Sanskrit": "tts_models/hi/mai/tacotron2-DDC" # Fallback to Hindi
49
+ }
50
+
51
  LANG_PRIMERS = {
52
+ "English": ("Transcribe in English.",
53
+ "Write only in English. Example: This is an English sentence."),
54
+ "Tamil": ("เฎคเฎฎเฎฟเฎดเฎฟเฎฒเฏ เฎŽเฎดเฏเฎคเฏเฎ•.",
55
+ "เฎคเฎฎเฎฟเฎดเฏ เฎŽเฎดเฏเฎคเฏเฎคเฏเฎ•เฏเฎ•เฎณเฎฟเฎฒเฏ เฎฎเฎŸเฏเฎŸเฏเฎฎเฏ เฎŽเฎดเฏเฎคเฎตเฏเฎฎเฏ. เฎ‰เฎคเฎพเฎฐเฎฃเฎฎเฏ: เฎ‡เฎคเฏ เฎ’เฎฐเฏ เฎคเฎฎเฎฟเฎดเฏ เฎตเฎพเฎ•เฏเฎ•เฎฟเฎฏเฎฎเฏ."),
56
+ "Malayalam": ("เดฎเดฒเดฏเดพเดณเดคเตเดคเดฟเตฝ เดŽเดดเตเดคเตเด•.",
57
+ "เดฎเดฒเดฏเดพเดณ เดฒเดฟเดชเดฟเดฏเดฟเตฝ เดฎเดพเดคเตเดฐเด‚ เดŽเดดเตเดคเตเด•. เด‰เดฆเดพเดนเดฐเดฃเด‚: เด‡เดคเตŠเดฐเต เดฎเดฒเดฏเดพเดณ เดตเดพเด•เตเดฏเดฎเดพเดฃเต."),
58
+ "Hindi": ("เคนเคฟเค‚เคฆเฅ€ เคฎเฅ‡เค‚ เคฒเคฟเค–เฅ‡เค‚เฅค",
59
+ "เค•เฅ‡เคตเคฒ เคฆเฅ‡เคตเคจเคพเค—เคฐเฅ€ เคฒเคฟเคชเคฟ เคฎเฅ‡เค‚ เคฒเคฟเค–เฅ‡เค‚เฅค เค‰เคฆเคพเคนเคฐเคฃ: เคฏเคน เคเค• เคนเคฟเค‚เคฆเฅ€ เคตเคพเค•เฅเคฏ เคนเฅˆเฅค"),
60
+ "Sanskrit": ("เคธเค‚เคธเฅเค•เฅƒเคคเฅ‡ เคฒเคฟเค–เคคเฅค",
61
+ "เคฆเฅ‡เคตเคจเคพเค—เคฐเฅ€ เคฒเคฟเคชเคฟ เคฎเฅ‡เค‚ เคฒเคฟเค–เฅ‡เค‚เฅค เค‰เคฆเคพเคนเคฐเคฃ: เค…เคนเค‚ เคธเค‚เคธเฅเค•เฅƒเคคเค‚ เคœเคพเคจเคพเคฎเคฟเฅค")
62
  }
63
 
64
  SCRIPT_PATTERNS = {
65
  "Tamil": re.compile(r"[เฎ€-เฏฟ]"),
66
+ "Malayalam": re.compile(r"[เด€-เตฟ]"),
67
  "Hindi": re.compile(r"[เค€-เฅฟ]"),
68
  "Sanskrit": re.compile(r"[เค€-เฅฟ]"),
69
  "English": re.compile(r"[A-Za-z]")
 
71
 
72
  SENTENCE_BANK = {
73
  "English": [
74
+ "The sun sets over the beautiful horizon.",
75
+ "Learning new languages opens many doors.",
76
+ "I enjoy reading books in the evening.",
77
+ "Technology has changed our daily lives.",
78
+ "Music brings people together across cultures."
79
  ],
80
  "Tamil": [
81
  "เฎ‡เฎฉเฏเฎฑเฏ เฎจเฎฒเฏเฎฒ เฎตเฎพเฎฉเฎฟเฎฒเฏˆ เฎ‰เฎณเฏเฎณเฎคเฏ.",
82
+ "เฎจเฎพเฎฉเฏ เฎคเฎฎเฎฟเฎดเฏ เฎ•เฎฑเฏเฎฑเฏเฎ•เฏเฎ•เฏŠเฎฃเฏเฎŸเฏ เฎ‡เฎฐเฏเฎ•เฏเฎ•เฎฟเฎฑเฏ‡เฎฉเฏ.",
83
+ "เฎŽเฎฉเฎ•เฏเฎ•เฏ เฎชเฏเฎคเฏเฎคเฎ•เฎฎเฏ เฎชเฎŸเฎฟเฎ•เฏเฎ• เฎตเฎฟเฎฐเฏเฎชเฏเฎชเฎฎเฏ.",
84
+ "เฎคเฎฎเฎฟเฎดเฏ เฎฎเฏŠเฎดเฎฟ เฎฎเฎฟเฎ•เฎตเฏเฎฎเฏ เฎ…เฎดเฎ•เฎพเฎฉเฎคเฏ.",
85
+ "เฎ•เฏเฎŸเฏเฎฎเฏเฎชเฎคเฏเฎคเฏเฎŸเฎฉเฏ เฎจเฏ‡เฎฐเฎฎเฏ เฎšเฏ†เฎฒเฎตเฎฟเฎŸเฏเฎตเฎคเฏ เฎฎเฏเฎ•เฏเฎ•เฎฟเฎฏเฎฎเฏ."
86
  ],
87
  "Malayalam": [
88
  "เดŽเดจเดฟเด•เตเด•เต เดฎเดฒเดฏเดพเดณเด‚ เดตเดณเดฐเต† เด‡เดทเตเดŸเดฎเดพเดฃเต.",
89
  "เด‡เดจเตเดจเต เดฎเดดเดชเต†เดฏเตเดฏเตเดจเตเดจเต.",
90
+ "เดžเดพเตป เดชเตเดธเตเดคเด•เด‚ เดตเดพเดฏเดฟเด•เตเด•เตเดจเตเดจเต.",
91
+ "เด•เต‡เดฐเดณเดคเตเดคเดฟเดจเตเดฑเต† เดชเตเดฐเด•เตƒเดคเดฟ เดธเตเดจเตเดฆเดฐเดฎเดพเดฃเต.",
92
+ "เดตเดฟเดฆเตเดฏเดพเดญเตเดฏเดพเดธเด‚ เดœเต€เดตเดฟเดคเดคเตเดคเดฟเตฝ เดชเตเดฐเดงเดพเดจเดฎเดพเดฃเต."
93
  ],
94
  "Hindi": [
95
+ "เค†เคœ เคฎเฅŒเคธเคฎ เคฌเคนเฅเคค เค…เคšเฅเค›เคพ เคนเฅˆเฅค",
96
+ "เคฎเฅเคเฅ‡ เคนเคฟเค‚เคฆเฅ€ เคฌเฅ‹เคฒเคจเคพ เคชเคธเค‚เคฆ เคนเฅˆเฅค",
97
+ "เคฎเฅˆเค‚ เคฐเฅ‹เคœ เค•เคฟเคคเคพเคฌ เคชเคขเคผเคคเคพ เคนเฅ‚เคเฅค",
98
+ "เคญเคพเคฐเคค เค•เฅ€ เคธเค‚เคธเฅเค•เฅƒเคคเคฟ เคตเคฟเคตเคฟเคงเคคเคพเคชเฅ‚เคฐเฅเคฃ เคนเฅˆเฅค",
99
+ "เคถเคฟเค•เฅเคทเคพ เคนเคฎเคพเคฐเฅ‡ เคญเคตเคฟเคทเฅเคฏ เค•เฅ€ เค•เฅเค‚เคœเฅ€ เคนเฅˆเฅค"
100
  ],
101
  "Sanskrit": [
102
  "เค…เคนเค‚ เค—เฅเคฐเคจเฅเคฅเค‚ เคชเค เคพเคฎเคฟเฅค",
103
  "เค…เคฆเฅเคฏ เคธเฅ‚เคฐเฅเคฏเคƒ เคคเฅ‡เคœเคธเฅเคตเฅ€ เค…เคธเฅเคคเคฟเฅค",
104
+ "เคฎเคฎ เคจเคพเคฎ เคฐเคพเคฎเคƒเฅค",
105
+ "เคตเคฟเคฆเฅเคฏเคพ เคธเคฐเฅเคตเคคเฅเคฐ เคชเฅ‚เคœเฅเคฏเคคเฅ‡เฅค",
106
+ "เคธเคคเฅเคฏเคฎเฅ‡เคต เคœเคฏเคคเฅ‡เฅค"
107
  ]
108
  }
109
 
110
+ # ---------------- MODEL CACHE ---------------- #
111
+ asr_models = {}
112
+ tts_models = {}
113
+
114
def load_asr_model(language):
    """Load (and cache) the ASR processor/model pair for *language*.

    Models are cached in the module-level ``asr_models`` dict so each one is
    downloaded/initialized only once.  When a language-specific model cannot
    be loaded, the English model is used as a fallback.

    Returns:
        dict with keys ``"processor"`` and ``"model"``.

    Raises:
        RuntimeError: if not even the English fallback model can be loaded.
            (The old code fell through to ``asr_models["English"]`` and died
            with an unrelated ``KeyError`` in that case.)
    """
    if language not in asr_models:
        try:
            model_name = ASR_MODELS[language]
            print(f"Loading ASR model for {language}: {model_name}")

            processor = AutoProcessor.from_pretrained(model_name)
            model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name).to(DEVICE)

            asr_models[language] = {"processor": processor, "model": model}
            print(f"✅ ASR model loaded for {language}")
        except Exception as e:
            print(f"❌ Failed to load ASR for {language}: {e}")
            # Fallback to English model
            if language != "English":
                print(f"🔄 Falling back to English ASR for {language}")
                asr_models[language] = load_asr_model("English")
            else:
                # No fallback left: fail with a clear error instead of the
                # KeyError the missing cache entry would otherwise produce.
                raise RuntimeError("Could not load any ASR model") from e

    return asr_models[language]
135
+
136
def load_tts_model(language):
    """Load (and cache) the Coqui TTS model for *language*.

    Cached in the module-level ``tts_models`` dict.  Falls back to the
    English model when the language-specific model fails to load.

    Raises:
        RuntimeError: if not even the English fallback model can be loaded.
            (The old code fell through to ``tts_models["English"]`` and died
            with an unrelated ``KeyError`` in that case.)
    """
    if language not in tts_models:
        try:
            model_name = TTS_MODELS[language]
            print(f"Loading TTS model for {language}: {model_name}")

            tts = TTS(model_name=model_name).to(DEVICE)
            tts_models[language] = tts
            print(f"✅ TTS model loaded for {language}")
        except Exception as e:
            print(f"❌ Failed to load TTS for {language}: {e}")
            # Fallback to English
            if language != "English":
                print(f"🔄 Falling back to English TTS for {language}")
                tts_models[language] = load_tts_model("English")
            else:
                # No fallback left: raise a clear error rather than KeyError.
                raise RuntimeError("Could not load any TTS model") from e

    return tts_models[language]
155
 
156
  # ---------------- HELPERS ---------------- #
157
def get_random_sentence(language_choice):
    """Pick one practice sentence at random for the chosen language."""
    candidates = SENTENCE_BANK[language_choice]
    return random.choice(candidates)
160
 
161
def is_script(text, lang_name):
    """Return True when *text* contains at least one character of the
    script expected for *lang_name*.

    Languages without a registered pattern are accepted unconditionally.
    """
    pattern = SCRIPT_PATTERNS.get(lang_name)
    if pattern is None:
        return True
    return pattern.search(text) is not None
165
 
166
def transliterate_to_hk(text, lang_choice):
    """Transliterate Indic-script *text* to Harvard-Kyoto romanization.

    English (or any language whose script check fails) is returned
    unchanged.  Transliteration errors are swallowed and the original text
    returned, matching the previous best-effort behavior — but the old bare
    ``except:`` also caught KeyboardInterrupt/SystemExit; narrowed here to
    ``except Exception``.
    """
    mapping = {
        "Tamil": sanscript.TAMIL,
        "Malayalam": sanscript.MALAYALAM,
        "Hindi": sanscript.DEVANAGARI,
        "Sanskrit": sanscript.DEVANAGARI,
        "English": None
    }

    script = mapping.get(lang_choice)
    if script and is_script(text, lang_choice):
        try:
            return transliterate(text, script, sanscript.HK)
        except Exception:
            # Best-effort: fall back to the untransliterated text.
            return text
    return text
183
+
184
def preprocess_audio(audio_path, target_sr=16000):
    """Load, peak-normalize and silence-trim an audio file for ASR.

    Args:
        audio_path: path to the recorded audio file.
        target_sr: sample rate to resample to (Whisper expects 16 kHz).

    Returns:
        (audio, target_sr) as a float numpy array and int, or (None, None)
        on any failure.
    """
    try:
        # Load audio (resampled to target_sr, mono)
        audio, sr = librosa.load(audio_path, sr=target_sr)

        # Normalize to peak amplitude 1.0.  Guard against empty or silent
        # input: the old unconditional division produced NaNs (0/0) there.
        peak = np.max(np.abs(audio)) if audio.size else 0.0
        if peak > 0:
            audio = audio / peak

        # Remove leading/trailing silence
        audio, _ = librosa.effects.trim(audio, top_db=20)

        return audio, target_sr
    except Exception as e:
        print(f"Audio preprocessing error: {e}")
        return None, None
200
+
201
def transcribe_with_ai4bharat(audio_path, language, initial_prompt=""):
    """Transcribe *audio_path* with the per-language speech model.

    NOTE(review): *initial_prompt* is accepted for API compatibility with
    the callers but is not currently passed to the model — confirm whether
    prompt biasing is actually intended here.

    Returns:
        The transcription string, or an ``"Error: ..."`` message on failure.
    """
    try:
        # Fetch cached processor/model for this language
        components = load_asr_model(language)
        processor = components["processor"]
        model = components["model"]

        # Normalize/trim the recording first
        waveform, sample_rate = preprocess_audio(audio_path)
        if waveform is None:
            return "Error: Could not process audio"

        # Build model inputs and move them onto the compute device
        features = processor(waveform, sampling_rate=sample_rate, return_tensors="pt")
        features = {name: tensor.to(DEVICE) for name, tensor in features.items()}

        # Greedy generation; no gradients needed for inference
        with torch.no_grad():
            generated_ids = model.generate(**features, max_length=200)

        # Turn token ids back into text
        text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return text.strip()

    except Exception as e:
        print(f"Transcription error for {language}: {e}")
        return f"Error: Transcription failed - {str(e)}"
230
+
231
def synthesize_with_ai4bharat(text, language):
    """Synthesize *text* with the per-language TTS model.

    Returns:
        (sample_rate, audio) suitable for ``gr.Audio(type="numpy")``, or
        None for empty input or on any failure.
    """
    if not text.strip():
        return None

    try:
        # Load TTS model
        tts = load_tts_model(language)

        # The old code wrote to a hard-coded "/tmp/..." path (non-portable,
        # e.g. absent on Windows) and never deleted it, leaking one WAV per
        # call.  Use a real temp file and remove it when done.
        import os
        import tempfile

        fd, audio_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        try:
            tts.tts_to_file(text=text, file_path=audio_path)
            # Load generated audio at the model's nominal 22.05 kHz rate
            audio, sr = librosa.load(audio_path, sr=22050)
        finally:
            try:
                os.remove(audio_path)
            except OSError:
                pass  # best-effort cleanup

        return sr, audio

    except Exception as e:
        print(f"TTS error for {language}: {e}")
        return None
252
 
253
def highlight_differences(ref, hyp):
    """Render a word-level diff of *hyp* against *ref* as HTML.

    Matching words come out green, reference words that were replaced or
    dropped are struck through in red, and extra/substituted hypothesis
    words are orange.
    """
    ref_tokens = ref.strip().split()
    hyp_tokens = hyp.strip().split()

    matcher = difflib.SequenceMatcher(None, ref_tokens, hyp_tokens)
    pieces = []

    for op, r1, r2, h1, h2 in matcher.get_opcodes():
        if op == 'equal':
            for w in ref_tokens[r1:r2]:
                pieces.append(f"<span style='color:green; font-weight:bold'>{w}</span>")
        elif op == 'replace':
            for w in ref_tokens[r1:r2]:
                pieces.append(f"<span style='color:red; text-decoration:line-through'>{w}</span>")
            for w in hyp_tokens[h1:h2]:
                pieces.append(f"<span style='color:orange; font-weight:bold'> → {w}</span>")
        elif op == 'delete':
            for w in ref_tokens[r1:r2]:
                pieces.append(f"<span style='color:red; text-decoration:line-through'>{w}</span>")
        elif op == 'insert':
            for w in hyp_tokens[h1:h2]:
                pieces.append(f"<span style='color:orange; font-weight:bold'>+{w}</span>")

    return " ".join(pieces)
273
 
274
def char_level_highlight(ref, hyp):
    """Render a character-level diff of *hyp* against *ref* as HTML.

    Reference characters that match are green; those replaced or missing
    are red with underline; characters inserted by the speaker are orange
    on a yellow background.
    """
    matcher = difflib.SequenceMatcher(None, list(ref), list(hyp))
    spans = []

    for op, r1, r2, h1, h2 in matcher.get_opcodes():
        if op == 'equal':
            spans.extend(
                f"<span style='color:green'>{c}</span>" for c in ref[r1:r2]
            )
        elif op in ('replace', 'delete'):
            spans.extend(
                f"<span style='color:red; text-decoration:underline; font-weight:bold'>{c}</span>"
                for c in ref[r1:r2]
            )
        elif op == 'insert':
            spans.extend(
                f"<span style='color:orange; background-color:yellow'>{c}</span>"
                for c in hyp[h1:h2]
            )

    return "".join(spans)
288
 
289
+ # ---------------- MAIN FUNCTION ---------------- #
290
# ---------------- MAIN FUNCTION ---------------- #
def compare_pronunciation(audio, language_choice, intended_sentence):
    """Run the full pronunciation-analysis pipeline.

    Args:
        audio: filepath of the user's recording (or None).
        language_choice: key into LANG_PRIMERS / SENTENCE_BANK.
        intended_sentence: the sentence the user was asked to read.

    Returns:
        An 11-tuple matching the Gradio outputs:
        (status, pass-1 transcription, pass-2 transcription,
         HK transliteration, WER, CER, word-diff HTML,
         reference TTS audio, user-text TTS audio, char-diff HTML,
         intended sentence).

    BUGFIX: the early/no-input and exception returns used to be 10-tuples
    while the click handler declares 11 outputs, which made Gradio error
    out on exactly those paths; they now return 11 values.
    """
    if audio is None or not intended_sentence.strip():
        return ("❌ No audio or intended sentence provided.", "", "", "", "", "",
                "", None, None, "", "")

    try:
        print(f"Processing audio for {language_choice}")

        # Pass 1: Raw transcription
        primer_weak, _ = LANG_PRIMERS[language_choice]
        actual_text = transcribe_with_ai4bharat(audio, language_choice, primer_weak)

        # Pass 2: Target-biased transcription
        _, primer_strong = LANG_PRIMERS[language_choice]
        strict_prompt = f"{primer_strong}\nTarget: {intended_sentence}"
        corrected_text = transcribe_with_ai4bharat(audio, language_choice, strict_prompt)

        # Error metrics.  jiwer can raise on degenerate input; treat that as
        # a total mismatch.  (Narrowed from the old bare ``except:``.)
        try:
            wer_val = jiwer.wer(intended_sentence, actual_text)
            cer_val = jiwer.cer(intended_sentence, actual_text)
        except Exception:
            wer_val, cer_val = 1.0, 1.0

        # Transliteration (flag transcripts that came back in the wrong script)
        hk_translit = transliterate_to_hk(actual_text, language_choice)
        if not is_script(actual_text, language_choice):
            hk_translit = f"⚠️ Script mismatch: expected {language_choice} script"

        # Visual feedback
        diff_html = highlight_differences(intended_sentence, actual_text)
        char_html = char_level_highlight(intended_sentence, actual_text)

        # TTS synthesis
        tts_intended = synthesize_with_ai4bharat(intended_sentence, language_choice)
        tts_actual = synthesize_with_ai4bharat(actual_text, language_choice)

        # Status message graded by word error rate
        status = f"✅ Analysis complete for {language_choice}"
        if wer_val < 0.1:
            status += " - Excellent pronunciation! 🎉"
        elif wer_val < 0.3:
            status += " - Good pronunciation! 👍"
        elif wer_val < 0.5:
            status += " - Needs improvement 📚"
        else:
            status += " - Keep practicing! 💪"

        return (
            status,
            actual_text,
            corrected_text,
            hk_translit,
            f"{wer_val:.3f}",
            f"{cer_val:.3f}",
            diff_html,
            tts_intended,
            tts_actual,
            char_html,
            intended_sentence
        )

    except Exception as e:
        error_msg = f"❌ Error during analysis: {str(e)}"
        print(error_msg)
        return (error_msg, "", "", "", "", "", "", None, None, "", "")
357
 
358
  # ---------------- UI ---------------- #
359
def create_interface():
    """Build and return the Gradio Blocks UI for the pronunciation trainer."""
    with gr.Blocks(title="🎙️ AI4Bharat Pronunciation Trainer", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🎙️ AI4Bharat Pronunciation Trainer

        Practice pronunciation in **Tamil, Malayalam, Hindi, Sanskrit & English** using state-of-the-art AI4Bharat models!

        📋 **How to use:**
        1. Select your target language
        2. Generate a practice sentence
        3. Record yourself reading it aloud
        4. Get detailed feedback with error analysis
        """)

        # Language picker + sentence generator side by side
        with gr.Row():
            with gr.Column(scale=2):
                lang_choice = gr.Dropdown(
                    choices=list(LANG_CODES.keys()),
                    value="Tamil",
                    label="🌍 Select Language"
                )
            with gr.Column(scale=1):
                gen_btn = gr.Button("🎲 Generate Practice Sentence", variant="primary")

        # Read-only display of the sentence the user should read aloud
        intended_display = gr.Textbox(
            label="📝 Practice Sentence (Read this aloud)",
            placeholder="Click 'Generate Practice Sentence' to get started...",
            interactive=False,
            lines=2
        )

        with gr.Row():
            # Recording passed to compare_pronunciation as a filepath
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="🎤 Record Your Pronunciation"
            )

        analyze_btn = gr.Button("🔍 Analyze Pronunciation", variant="primary", size="lg")

        status_output = gr.Textbox(label="📊 Analysis Status", interactive=False)

        # Transcriptions and error metrics, two columns
        with gr.Row():
            with gr.Column():
                pass1_out = gr.Textbox(label="🎯 What You Actually Said", interactive=False)
                wer_out = gr.Textbox(label="📈 Word Error Rate (lower = better)", interactive=False)

            with gr.Column():
                pass2_out = gr.Textbox(label="🔧 Target-Biased Output", interactive=False)
                cer_out = gr.Textbox(label="📊 Character Error Rate (lower = better)", interactive=False)

        hk_out = gr.Textbox(label="🔤 Romanization (Harvard-Kyoto)", interactive=False)

        # HTML diff views produced by highlight_differences / char_level_highlight
        with gr.Accordion("📝 Detailed Feedback", open=True):
            diff_html_box = gr.HTML(label="🔍 Word-Level Differences")
            char_html_box = gr.HTML(label="🔤 Character-Level Analysis")

        # TTS playback: reference sentence vs. what the ASR heard
        with gr.Row():
            intended_tts_audio = gr.Audio(label="🔊 Reference Pronunciation", type="numpy")
            actual_tts_audio = gr.Audio(label="🔊 Your Pronunciation (TTS)", type="numpy")

        gr.Markdown("""
        ### 🎨 Color Guide:
        - 🟢 **Green**: Correctly pronounced
        - 🔴 **Red**: Missing or incorrect words
        - 🟠 **Orange**: Extra or substituted words
        - 🟡 **Yellow background**: Inserted characters
        """)

        # Event handlers
        gen_btn.click(
            fn=get_random_sentence,
            inputs=[lang_choice],
            outputs=[intended_display]
        )

        # Order of outputs must match the tuple compare_pronunciation returns
        analyze_btn.click(
            fn=compare_pronunciation,
            inputs=[audio_input, lang_choice, intended_display],
            outputs=[
                status_output, pass1_out, pass2_out, hk_out,
                wer_out, cer_out, diff_html_box,
                intended_tts_audio, actual_tts_audio,
                char_html_box, intended_display
            ]
        )

        # Auto-generate sentence on language change
        lang_choice.change(
            fn=get_random_sentence,
            inputs=[lang_choice],
            outputs=[intended_display]
        )

    return demo
454
+
455
+ # ---------------- LAUNCH ---------------- #
456
# ---------------- LAUNCH ---------------- #
if __name__ == "__main__":
    print("🚀 Starting AI4Bharat Pronunciation Trainer...")

    # Pre-load English models for faster startup; other languages are
    # loaded lazily on first use.  Failure here is non-fatal — the app
    # still starts and retries on demand.
    print("📦 Pre-loading English models...")
    try:
        load_asr_model("English")
        load_tts_model("English")
        print("✅ English models loaded successfully")
    except Exception as e:
        print(f"⚠️ Warning: Could not pre-load English models: {e}")

    demo = create_interface()
    # share=True creates a public Gradio link; 0.0.0.0/7860 is the
    # standard Hugging Face Spaces binding.
    demo.launch(
        share=True,
        show_error=True,
        server_name="0.0.0.0",
        server_port=7860
    )