sudhanm committed
Commit 382a648 · verified · 1 Parent(s): 382858d

Update app.py

Files changed (1): app.py (+138, -70)
app.py CHANGED
@@ -65,17 +65,51 @@ SENTENCE_BANK = {
 }
 
 # Global variables for models (will be loaded lazily)
-whisper_models = {}
-whisper_processors = {}
+current_model = None
+current_processor = None
+current_language = None
+
+def clear_gpu_memory():
+    """Clear GPU memory to prevent OOM errors"""
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
 
 def load_model(language_choice):
-    """Load model for specific language if not already loaded"""
-    if language_choice not in whisper_models:
-        model_id = MODEL_CONFIGS[language_choice]
-        print(f"Loading {language_choice} model: {model_id}")
-        whisper_models[language_choice] = WhisperForConditionalGeneration.from_pretrained(model_id).to(DEVICE)
-        whisper_processors[language_choice] = WhisperProcessor.from_pretrained(model_id)
+    """Load model for specific language, unload previous if different"""
+    global current_model, current_processor, current_language
+
+    if current_language == language_choice and current_model is not None:
+        return current_model, current_processor
+
+    # Clear previous model if different language
+    if current_model is not None:
+        print(f"Unloading previous model for {current_language}")
+        del current_model
+        del current_processor
+        clear_gpu_memory()
+
+    # Load new model
+    model_id = MODEL_CONFIGS[language_choice]
+    print(f"Loading {language_choice} model: {model_id}")
+
+    try:
+        current_processor = WhisperProcessor.from_pretrained(model_id)
+        current_model = WhisperForConditionalGeneration.from_pretrained(
+            model_id,
+            torch_dtype=torch.float16,  # Use half precision to save memory
+            device_map="auto"
+        )
+        current_language = language_choice
         print(f"{language_choice} model loaded successfully!")
+        return current_model, current_processor
+
+    except Exception as e:
+        print(f"Error loading model: {e}")
+        # Fallback to CPU if GPU fails
+        current_processor = WhisperProcessor.from_pretrained(model_id)
+        current_model = WhisperForConditionalGeneration.from_pretrained(model_id)
+        current_language = language_choice
+        return current_model, current_processor
 
 # ---------------- HELPERS ---------------- #
 def get_random_sentence(language_choice):
@@ -95,38 +129,50 @@ def transliterate_to_hk(text, lang_choice):
 
 @spaces.GPU
 def transcribe_once(audio_path, language_choice, initial_prompt, beam_size, temperature, condition_on_previous_text):
-    # Load model if not already loaded
-    load_model(language_choice)
-
-    # Get the appropriate model and processor for the language
-    model = whisper_models[language_choice]
-    processor = whisper_processors[language_choice]
-    lang_code = LANG_CODES[language_choice]
-
-    # Load and process audio
-    import librosa
-    audio, sr = librosa.load(audio_path, sr=16000)
-
-    # Process audio with the specific model's processor
-    input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features.to(DEVICE)
-
-    # Generate forced decoder ids for the language
-    forced_decoder_ids = processor.get_decoder_prompt_ids(language=lang_code, task="transcribe")
-
-    # Generate transcription
-    with torch.no_grad():
-        predicted_ids = model.generate(
-            input_features,
-            forced_decoder_ids=forced_decoder_ids,
-            max_length=448,
-            num_beams=beam_size,
-            temperature=temperature if temperature > 0 else None,
-            do_sample=temperature > 0,
-        )
-
-    # Decode the transcription
-    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
-    return transcription.strip()
+    try:
+        # Load model if not already loaded
+        model, processor = load_model(language_choice)
+        lang_code = LANG_CODES[language_choice]
+
+        # Load and process audio
+        import librosa
+        audio, sr = librosa.load(audio_path, sr=16000)
+
+        # Process audio with the specific model's processor
+        input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features
+
+        # Move to GPU if available
+        if torch.cuda.is_available():
+            input_features = input_features.to("cuda")
+
+        # Generate forced decoder ids for the language
+        forced_decoder_ids = processor.get_decoder_prompt_ids(language=lang_code, task="transcribe")
+
+        # Generate transcription with memory-efficient settings
+        with torch.no_grad():
+            predicted_ids = model.generate(
+                input_features,
+                forced_decoder_ids=forced_decoder_ids,
+                max_length=200,  # Reduced max length to save memory
+                num_beams=min(beam_size, 4),  # Limit beam size for memory
+                temperature=temperature if temperature > 0 else None,
+                do_sample=temperature > 0,
+                no_repeat_ngram_size=2,
+                early_stopping=True
+            )
+
+        # Decode the transcription
+        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+
+        # Clear GPU cache after inference
+        clear_gpu_memory()
+
+        return transcription.strip()
+
+    except Exception as e:
+        print(f"Transcription error: {e}")
+        clear_gpu_memory()
+        return f"Error during transcription: {str(e)}"
 
 def highlight_differences(ref, hyp):
     ref_words, hyp_words = ref.strip().split(), hyp.strip().split()
@@ -226,54 +272,76 @@ def char_level_highlight(ref, hyp):
 def compare_pronunciation(audio, language_choice, intended_sentence,
                           pass1_beam, pass1_temp, pass1_condition):
     if audio is None or not intended_sentence.strip():
-        return ("No audio or intended sentence.", "", "", "", "", "", "", "")
+        return ("No audio or intended sentence.", "", "", "", "", "", "", "", "❌ Please provide audio and sentence")
 
-    primer_weak, primer_strong = LANG_PRIMERS[language_choice]
+    try:
+        primer_weak, primer_strong = LANG_PRIMERS[language_choice]
 
-    # Pass 1: raw transcription with user-configured decoding parameters
-    actual_text = transcribe_once(audio, language_choice, primer_weak,
-                                  pass1_beam, pass1_temp, pass1_condition)
-
-    # Pass 2: strict transcription biased by intended sentence (fixed decoding params)
-    strict_prompt = f"{primer_strong}\nTarget: {intended_sentence}"
-    corrected_text = transcribe_once(audio, language_choice, strict_prompt,
-                                     beam_size=5, temperature=0.0, condition_on_previous_text=False)
-
-    # Compute WER and CER
-    wer_val = jiwer.wer(intended_sentence, actual_text)
-    cer_val = jiwer.cer(intended_sentence, actual_text)
-
-    # Transliteration of Pass 1 output
-    hk_translit = transliterate_to_hk(actual_text, language_choice) if is_script(actual_text, language_choice) else f"[Script mismatch: expected {language_choice}]"
-
-    # Highlight word-level and character-level differences
-    diff_html = highlight_differences(intended_sentence, actual_text)
-    char_html = char_level_highlight(intended_sentence, actual_text)
-
-    return (actual_text, corrected_text, hk_translit, f"{wer_val:.2f}", f"{cer_val:.2f}",
-            diff_html, char_html, intended_sentence)
+        # Pass 1: raw transcription with user-configured decoding parameters
+        status_msg = f"🔄 Transcribing with {language_choice} model..."
+        actual_text = transcribe_once(audio, language_choice, primer_weak,
+                                      pass1_beam, pass1_temp, pass1_condition)
+
+        if actual_text.startswith("Error"):
+            return (actual_text, "", "", "", "", "", "", "", "❌ Transcription failed")
+
+        # Pass 2: strict transcription biased by intended sentence (fixed decoding params)
+        strict_prompt = f"{primer_strong}\nTarget: {intended_sentence}"
+        corrected_text = transcribe_once(audio, language_choice, strict_prompt,
+                                         beam_size=3, temperature=0.0, condition_on_previous_text=False)
+
+        # Compute WER and CER
+        try:
+            wer_val = jiwer.wer(intended_sentence, actual_text)
+            cer_val = jiwer.cer(intended_sentence, actual_text)
+        except:
+            wer_val = 1.0
+            cer_val = 1.0
+
+        # Transliteration of Pass 1 output
+        hk_translit = transliterate_to_hk(actual_text, language_choice) if is_script(actual_text, language_choice) else f"[Script mismatch: expected {language_choice}]"
+
+        # Highlight word-level and character-level differences
+        diff_html = highlight_differences(intended_sentence, actual_text)
+        char_html = char_level_highlight(intended_sentence, actual_text)
+
+        # Success status
+        status_msg = f"✅ Analysis complete! WER: {wer_val:.2f}"
+
+        return (actual_text, corrected_text, hk_translit, f"{wer_val:.2f}", f"{cer_val:.2f}",
+                diff_html, char_html, intended_sentence, status_msg)
+
+    except Exception as e:
+        error_msg = f"❌ Error: {str(e)}"
+        clear_gpu_memory()
+        return ("Error occurred", "", "", "", "", "", "", "", error_msg)
 
 # ---------------- UI ---------------- #
 with gr.Blocks(title="Pronunciation Comparator") as demo:
     gr.Markdown("## 🎙 Pronunciation Comparator - English, Tamil & Malayalam")
     gr.Markdown("Practice pronunciation with specialized Whisper models for each language!")
+    gr.Markdown("⚠️ **Note**: Models load on-demand to optimize memory usage. First use may take longer.")
 
     with gr.Row():
         lang_choice = gr.Dropdown(choices=list(LANG_CODES.keys()), value="Malayalam", label="Language")
         gen_btn = gr.Button("🎲 Generate Sentence")
 
     intended_display = gr.Textbox(label="Generated Sentence (Read aloud)", interactive=False)
+
+    # Status indicator
+    status_display = gr.Textbox(label="Status", interactive=False, value="🟢 Ready")
 
     with gr.Row():
         audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Record your pronunciation")
 
        with gr.Column():
-            gr.Markdown("### Transcription Parameters")
-            pass1_beam = gr.Slider(1, 10, value=8, step=1, label="Pass 1 Beam Size")
-            pass1_temp = gr.Slider(0.0, 1.0, value=0.4, step=0.1, label="Pass 1 Temperature")
-            pass1_condition = gr.Checkbox(value=True, label="Pass 1: Condition on previous text")
+            gr.Markdown("### ⚙️ Transcription Parameters")
+            with gr.Row():
+                pass1_beam = gr.Slider(1, 4, value=2, step=1, label="Beam Size (lower = faster)")
+                pass1_temp = gr.Slider(0.0, 0.8, value=0.2, step=0.1, label="Temperature")
+            pass1_condition = gr.Checkbox(value=False, label="Condition on previous text")
 
-    submit_btn = gr.Button("🔍 Analyze Pronunciation", variant="primary")
+    submit_btn = gr.Button("🔍 Analyze Pronunciation", variant="primary", size="lg")
 
     gr.Markdown("### 📊 Analysis Results")
     with gr.Row():
@@ -299,7 +367,7 @@ with gr.Blocks(title="Pronunciation Comparator") as demo:
         inputs=[audio_input, lang_choice, intended_display, pass1_beam, pass1_temp, pass1_condition],
         outputs=[
             pass1_out, pass2_out, hk_out, wer_out, cer_out,
-            diff_html_box, char_html_box, intended_display
+            diff_html_box, char_html_box, intended_display, status_display
         ]
     )
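The core change above replaces the grow-only `whisper_models` / `whisper_processors` dicts with a single-slot cache: at most one model stays resident, and switching language evicts the previous one before loading the next. A minimal, self-contained sketch of that pattern (illustrative names, not app.py's):

```python
# Single-slot cache: keep at most one loaded object, evicting the previous
# entry when a different key is requested (mirrors the new load_model()).
import gc

_current_key = None
_current_obj = None

def load_cached(key, factory):
    """Return the object for `key`, unloading whatever was cached before."""
    global _current_key, _current_obj
    if _current_key == key and _current_obj is not None:
        return _current_obj          # cache hit: no reload
    _current_obj = None              # drop the old reference first
    gc.collect()                     # stands in for clear_gpu_memory()
    _current_obj = factory(key)      # load the requested object
    _current_key = key
    return _current_obj

if __name__ == "__main__":
    loads = []
    make = lambda k: (loads.append(k), f"model-for-{k}")[1]
    a = load_cached("Malayalam", make)
    b = load_cached("Malayalam", make)   # second call reuses the cached object
    c = load_cached("Tamil", make)       # different key evicts, then reloads
    assert a is b and loads == ["Malayalam", "Tamil"]
```

The trade-off is deliberate: a language switch pays a full reload, but peak memory stays at one model instead of three.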
 
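For reference, a hedged sketch of the Whisper decoding pattern `transcribe_once` uses, with `openai/whisper-tiny` standing in for the Space's per-language checkpoints and silent dummy audio in place of a recording. It matches the commit's `forced_decoder_ids` usage; recent transformers releases prefer passing `language=` and `task=` directly to `generate()`.

```python
import numpy as np
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

model_id = "openai/whisper-tiny"  # illustrative; the Space uses per-language models
processor = WhisperProcessor.from_pretrained(model_id)
model = WhisperForConditionalGeneration.from_pretrained(model_id)

audio = np.zeros(16000, dtype=np.float32)  # one second of silence at 16 kHz
input_features = processor(audio, sampling_rate=16000,
                           return_tensors="pt").input_features

# Pin the decoder to a language and task, as the commit does per selection.
forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
with torch.no_grad():
    predicted_ids = model.generate(input_features,
                                   forced_decoder_ids=forced_decoder_ids,
                                   max_length=200)
print(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
```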
 
373