Update app.py
app.py CHANGED
@@ -253,27 +253,47 @@ def transcribe_audio(audio_path, language, initial_prompt="", force_language=Tru
 
         # Generate transcription
         with torch.no_grad():
-            # …
+            # Basic generation parameters
             generate_kwargs = {
                 "input_features": input_features,
                 "max_length": 200,
-                "num_beams": …
-                "temperature": 0.0,
+                "num_beams": 3,  # Reduced for better compatibility
                 "do_sample": False
             }
 
-            # …
-            if …
+            # Try different approaches for language forcing
+            if force_language and language != "English":
                 lang_code = LANG_CODES.get(language, "en")
+
+                # Method 1: Try forced_decoder_ids (OpenAI Whisper style)
+                try:
+                    if hasattr(processor, 'get_decoder_prompt_ids'):
+                        forced_decoder_ids = processor.get_decoder_prompt_ids(
+                            language=lang_code,
+                            task="transcribe"
+                        )
+                        # Test if model accepts this parameter
+                        test_kwargs = generate_kwargs.copy()
+                        test_kwargs["max_length"] = 10
+                        test_kwargs["forced_decoder_ids"] = forced_decoder_ids
+                        _ = model.generate(**test_kwargs)  # Test run
+                        generate_kwargs["forced_decoder_ids"] = forced_decoder_ids
+                        print(f"✅ Using forced_decoder_ids for {language}")
+                except Exception as e:
+                    print(f"⚠️ forced_decoder_ids not supported: {e}")
+
+                    # Method 2: Try language parameter
                     try:
-                        …
-                        …
-                        …
-                        )
-                        generate_kwargs["…
-                        …
-                        …
+                        test_kwargs = generate_kwargs.copy()
+                        test_kwargs["max_length"] = 10
+                        test_kwargs["language"] = lang_code
+                        _ = model.generate(**test_kwargs)  # Test run
+                        generate_kwargs["language"] = lang_code
+                        print(f"✅ Using language parameter for {language}")
+                    except Exception as e:
+                        print(f"⚠️ language parameter not supported: {e}")
 
+            # Generate with whatever parameters work
             predicted_ids = model.generate(**generate_kwargs)
 
             # Decode
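
The hunk above probes whether the loaded checkpoint accepts each language-forcing argument by running a cheap, short test generation before committing to it, since different transformers versions expose different Whisper APIs. A minimal standalone sketch of that probe, assuming a standard Hugging Face Whisper checkpoint (the model name and function name below are illustrative, not part of app.py):

import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

# Illustrative checkpoint; app.py loads its own model and processor.
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

def probe_language_kwargs(input_features, lang_code):
    """Return whichever language-forcing kwargs this checkpoint accepts."""
    base = {"input_features": input_features, "max_length": 10, "do_sample": False}
    # Method 1: forced_decoder_ids (older transformers / OpenAI-style API)
    try:
        ids = processor.get_decoder_prompt_ids(language=lang_code, task="transcribe")
        with torch.no_grad():
            model.generate(**base, forced_decoder_ids=ids)  # cheap test run
        return {"forced_decoder_ids": ids}
    except Exception:
        pass
    # Method 2: the language= kwarg on generate() (newer transformers API)
    try:
        with torch.no_grad():
            model.generate(**base, language=lang_code)
        return {"language": lang_code}
    except Exception:
        return {}  # neither works: fall back to Whisper's own language detection

Capping the probe at max_length=10 keeps the test run cheap; whichever kwargs survive are merged into the real generate() call.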
@@ -283,11 +303,29 @@ def transcribe_audio(audio_path, language, initial_prompt="", force_language=Tru
             clean_up_tokenization_spaces=True
         )[0]
 
-        …
+        # Post-process transcription
+        transcription = transcription.strip()
+
+        # If we get empty transcription, try again with simpler parameters
+        if not transcription and generate_kwargs.get("num_beams", 1) > 1:
+            print("🔄 Retrying with greedy decoding...")
+            simple_kwargs = {
+                "input_features": input_features,
+                "max_length": 200,
+                "do_sample": False
+            }
+            predicted_ids = model.generate(**simple_kwargs)
+            transcription = processor.batch_decode(
+                predicted_ids,
+                skip_special_tokens=True,
+                clean_up_tokenization_spaces=True
+            )[0].strip()
+
+        return transcription or "(No transcription generated)"
 
     except Exception as e:
         print(f"Transcription error for {language}: {e}")
-        return f"Error: …
+        return f"Error: {str(e)[:150]}..."
 
 def highlight_differences(ref, hyp):
     """Highlight word-level differences with better styling"""