Update app.py

app.py CHANGED
@@ -65,17 +65,51 @@ SENTENCE_BANK = {
 }
 
 # Global variables for models (will be loaded lazily)
+current_model = None
+current_processor = None
+current_language = None
+
+def clear_gpu_memory():
+    """Clear GPU memory to prevent OOM errors"""
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
 
 def load_model(language_choice):
-    """Load model for specific language
+    """Load model for specific language, unload previous if different"""
+    global current_model, current_processor, current_language
+
+    if current_language == language_choice and current_model is not None:
+        return current_model, current_processor
+
+    # Clear previous model if different language
+    if current_model is not None:
+        print(f"Unloading previous model for {current_language}")
+        del current_model
+        del current_processor
+        clear_gpu_memory()
+
+    # Load new model
+    model_id = MODEL_CONFIGS[language_choice]
+    print(f"Loading {language_choice} model: {model_id}")
+
+    try:
+        current_processor = WhisperProcessor.from_pretrained(model_id)
+        current_model = WhisperForConditionalGeneration.from_pretrained(
+            model_id,
+            torch_dtype=torch.float16,  # Use half precision to save memory
+            device_map="auto"
+        )
+        current_language = language_choice
         print(f"{language_choice} model loaded successfully!")
+        return current_model, current_processor
+
+    except Exception as e:
+        print(f"Error loading model: {e}")
+        # Fallback to CPU if GPU fails
+        current_processor = WhisperProcessor.from_pretrained(model_id)
+        current_model = WhisperForConditionalGeneration.from_pretrained(model_id)
+        current_language = language_choice
+        return current_model, current_processor
 
 # ---------------- HELPERS ---------------- #
 def get_random_sentence(language_choice):
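The new lazy loader keeps at most one checkpoint resident and swaps it when the language changes. A standalone sketch of the same cache-or-reload pattern, with dummy loaders in place of Whisper checkpoints (all names below are illustrative, not from the app):

# Cache-or-reload: reuse the resident "model" for repeat requests,
# replace it when a different language is requested.
_cache = {"model": None, "lang": None}

def lazy_load(lang, loaders):
    if _cache["lang"] == lang and _cache["model"] is not None:
        return _cache["model"]           # cache hit: no reload
    _cache["model"] = loaders[lang]()    # miss: drop old reference, load new
    _cache["lang"] = lang
    return _cache["model"]

loaders = {"Malayalam": lambda: "ml-model", "Tamil": lambda: "ta-model"}
assert lazy_load("Malayalam", loaders) == "ml-model"
assert lazy_load("Malayalam", loaders) == "ml-model"  # second call reuses cache
assert lazy_load("Tamil", loaders) == "ta-model"      # swaps the resident model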
@@ -95,38 +129,50 @@ def transliterate_to_hk(text, lang_choice):
 
 @spaces.GPU
 def transcribe_once(audio_path, language_choice, initial_prompt, beam_size, temperature, condition_on_previous_text):
+    try:
+        # Load model if not already loaded
+        model, processor = load_model(language_choice)
+        lang_code = LANG_CODES[language_choice]
+
+        # Load and process audio
+        import librosa
+        audio, sr = librosa.load(audio_path, sr=16000)
+
+        # Process audio with the specific model's processor
+        input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features
+
+        # Match the model's device and dtype (the model may be float16 on GPU)
+        input_features = input_features.to(model.device, dtype=model.dtype)
+
+        # Generate forced decoder ids for the language
+        forced_decoder_ids = processor.get_decoder_prompt_ids(language=lang_code, task="transcribe")
+
+        # Generate transcription with memory-efficient settings
+        with torch.no_grad():
+            predicted_ids = model.generate(
+                input_features,
+                forced_decoder_ids=forced_decoder_ids,
+                max_length=200,  # Reduced max length to save memory
+                num_beams=min(beam_size, 4),  # Limit beam size for memory
+                temperature=temperature if temperature > 0 else None,
+                do_sample=temperature > 0,
+                no_repeat_ngram_size=2,
+                early_stopping=True
+            )
+
+        # Decode the transcription
+        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+
+        # Clear GPU cache after inference
+        clear_gpu_memory()
+
+        return transcription.strip()
+
+    except Exception as e:
+        print(f"Transcription error: {e}")
+        clear_gpu_memory()
+        return f"Error during transcription: {str(e)}"
 
 def highlight_differences(ref, hyp):
     ref_words, hyp_words = ref.strip().split(), hyp.strip().split()
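Note that initial_prompt and condition_on_previous_text are accepted by transcribe_once but never used inside the new body, apparently kept only for signature compatibility with the UI wiring. The processor -> generate -> decode path itself can be sanity-checked in isolation; a minimal sketch using the public openai/whisper-tiny checkpoint and one second of silence (checkpoint choice and dummy input are illustrative assumptions, not part of this commit):

# Round-trip the same inference path on a tiny checkpoint.
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

silence = [0.0] * 16000  # 1 second of audio at 16 kHz
features = processor(silence, sampling_rate=16000, return_tensors="pt").input_features
features = features.to(model.device, dtype=model.dtype)  # same dtype guard as above

ids = model.generate(features, max_length=32, num_beams=1)
print(processor.batch_decode(ids, skip_special_tokens=True)[0])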
@@ -226,54 +272,76 @@ def char_level_highlight(ref, hyp):
 def compare_pronunciation(audio, language_choice, intended_sentence,
                           pass1_beam, pass1_temp, pass1_condition):
     if audio is None or not intended_sentence.strip():
-        return ("No audio or intended sentence.", "", "", "", "", "", "", "")
+        return ("No audio or intended sentence.", "", "", "", "", "", "", "", "❌ Please provide audio and sentence")
 
+    try:
+        primer_weak, primer_strong = LANG_PRIMERS[language_choice]
+
+        # Pass 1: raw transcription with user-configured decoding parameters
+        status_msg = f"🔄 Transcribing with {language_choice} model..."
+        actual_text = transcribe_once(audio, language_choice, primer_weak,
+                                      pass1_beam, pass1_temp, pass1_condition)
+
+        if actual_text.startswith("Error"):
+            return (actual_text, "", "", "", "", "", "", "", "❌ Transcription failed")
+
+        # Pass 2: strict transcription biased by intended sentence (fixed decoding params)
+        strict_prompt = f"{primer_strong}\nTarget: {intended_sentence}"
+        corrected_text = transcribe_once(audio, language_choice, strict_prompt,
+                                         beam_size=3, temperature=0.0, condition_on_previous_text=False)
+
+        # Compute WER and CER
+        try:
+            wer_val = jiwer.wer(intended_sentence, actual_text)
+            cer_val = jiwer.cer(intended_sentence, actual_text)
+        except Exception:
+            wer_val = 1.0
+            cer_val = 1.0
+
+        # Transliteration of Pass 1 output
+        hk_translit = transliterate_to_hk(actual_text, language_choice) if is_script(actual_text, language_choice) else f"[Script mismatch: expected {language_choice}]"
+
+        # Highlight word-level and character-level differences
+        diff_html = highlight_differences(intended_sentence, actual_text)
+        char_html = char_level_highlight(intended_sentence, actual_text)
+
+        # Success status
+        status_msg = f"✅ Analysis complete! WER: {wer_val:.2f}"
+
+        return (actual_text, corrected_text, hk_translit, f"{wer_val:.2f}", f"{cer_val:.2f}",
+                diff_html, char_html, intended_sentence, status_msg)
+
+    except Exception as e:
+        error_msg = f"❌ Error: {str(e)}"
+        clear_gpu_memory()
+        return ("Error occurred", "", "", "", "", "", "", "", error_msg)
 
 # ---------------- UI ---------------- #
 with gr.Blocks(title="Pronunciation Comparator") as demo:
     gr.Markdown("## 🎙️ Pronunciation Comparator - English, Tamil & Malayalam")
     gr.Markdown("Practice pronunciation with specialized Whisper models for each language!")
+    gr.Markdown("⚠️ **Note**: Models load on-demand to optimize memory usage. First use may take longer.")
 
     with gr.Row():
         lang_choice = gr.Dropdown(choices=list(LANG_CODES.keys()), value="Malayalam", label="Language")
         gen_btn = gr.Button("🎲 Generate Sentence")
 
     intended_display = gr.Textbox(label="Generated Sentence (Read aloud)", interactive=False)
+
+    # Status indicator
+    status_display = gr.Textbox(label="Status", interactive=False, value="🟢 Ready")
 
     with gr.Row():
         audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Record your pronunciation")
 
         with gr.Column():
-            gr.Markdown("### Transcription Parameters")
+            gr.Markdown("### ⚙️ Transcription Parameters")
+            with gr.Row():
+                pass1_beam = gr.Slider(1, 4, value=2, step=1, label="Beam Size (lower = faster)")
+                pass1_temp = gr.Slider(0.0, 0.8, value=0.2, step=0.1, label="Temperature")
+                pass1_condition = gr.Checkbox(value=False, label="Condition on previous text")
 
-    submit_btn = gr.Button("🚀 Analyze Pronunciation", variant="primary")
+    submit_btn = gr.Button("🚀 Analyze Pronunciation", variant="primary", size="lg")
 
     gr.Markdown("### 📊 Analysis Results")
     with gr.Row():
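For reference, the jiwer metrics computed above behave as follows on a near-miss; this standalone example is illustrative, not taken from the app:

# One substituted word out of six gives WER = 1/6 ~= 0.17;
# CER counts character-level edits against the reference instead.
import jiwer

ref = "the cat sat on the mat"
hyp = "the cat sat on a mat"
print(jiwer.wer(ref, hyp))  # 0.1666... (1 substitution / 6 reference words)
print(jiwer.cer(ref, hyp))  # character edit distance / reference length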
@@ -299,7 +367,7 @@ with gr.Blocks(title="Pronunciation Comparator") as demo:
         inputs=[audio_input, lang_choice, intended_display, pass1_beam, pass1_temp, pass1_condition],
         outputs=[
             pass1_out, pass2_out, hk_out, wer_out, cer_out,
-            diff_html_box, char_html_box, intended_display
+            diff_html_box, char_html_box, intended_display, status_display
         ]
     )
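For clarity, the call this hunk edits (presumably submit_btn.click, given the surrounding context lines) maps each returned value to one output component in order, so compare_pronunciation must now return a nine-element tuple ending with the status message. A sketch of the full wiring under that assumption:

# Presumed surrounding call (reconstructed context, not new behavior):
# Gradio pairs the i-th returned value with the i-th output component.
submit_btn.click(
    fn=compare_pronunciation,
    inputs=[audio_input, lang_choice, intended_display,
            pass1_beam, pass1_temp, pass1_condition],
    outputs=[pass1_out, pass2_out, hk_out, wer_out, cer_out,
             diff_html_box, char_html_box, intended_display, status_display],
)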