sudhanm commited on
Commit
386695f
·
verified ·
1 Parent(s): 663c2a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -14
app.py CHANGED
@@ -253,27 +253,47 @@ def transcribe_audio(audio_path, language, initial_prompt="", force_language=Tru
253
 
254
  # Generate transcription
255
  with torch.no_grad():
256
- # Set generation parameters
257
  generate_kwargs = {
258
  "input_features": input_features,
259
  "max_length": 200,
260
- "num_beams": 5,
261
- "temperature": 0.0,
262
  "do_sample": False
263
  }
264
 
265
- # Add language forcing if supported
266
- if hasattr(model.config, 'forced_decoder_ids') and force_language:
267
  lang_code = LANG_CODES.get(language, "en")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
  try:
269
- forced_decoder_ids = processor.get_decoder_prompt_ids(
270
- language=lang_code,
271
- task="transcribe"
272
- )
273
- generate_kwargs["forced_decoder_ids"] = forced_decoder_ids
274
- except:
275
- pass # Skip if not supported
 
276
 
 
277
  predicted_ids = model.generate(**generate_kwargs)
278
 
279
  # Decode
@@ -283,11 +303,29 @@ def transcribe_audio(audio_path, language, initial_prompt="", force_language=Tru
283
  clean_up_tokenization_spaces=True
284
  )[0]
285
 
286
- return transcription.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
 
288
  except Exception as e:
289
  print(f"Transcription error for {language}: {e}")
290
- return f"Error: Transcription failed - {str(e)[:100]}"
291
 
292
  def highlight_differences(ref, hyp):
293
  """Highlight word-level differences with better styling"""
 
253
 
254
  # Generate transcription
255
  with torch.no_grad():
256
+ # Basic generation parameters
257
  generate_kwargs = {
258
  "input_features": input_features,
259
  "max_length": 200,
260
+ "num_beams": 3, # Reduced for better compatibility
 
261
  "do_sample": False
262
  }
263
 
264
+ # Try different approaches for language forcing
265
+ if force_language and language != "English":
266
  lang_code = LANG_CODES.get(language, "en")
267
+
268
+ # Method 1: Try forced_decoder_ids (OpenAI Whisper style)
269
+ try:
270
+ if hasattr(processor, 'get_decoder_prompt_ids'):
271
+ forced_decoder_ids = processor.get_decoder_prompt_ids(
272
+ language=lang_code,
273
+ task="transcribe"
274
+ )
275
+ # Test if model accepts this parameter
276
+ test_kwargs = generate_kwargs.copy()
277
+ test_kwargs["max_length"] = 10
278
+ test_kwargs["forced_decoder_ids"] = forced_decoder_ids
279
+ _ = model.generate(**test_kwargs) # Test run
280
+ generate_kwargs["forced_decoder_ids"] = forced_decoder_ids
281
+ print(f"✅ Using forced_decoder_ids for {language}")
282
+ except Exception as e:
283
+ print(f"⚠️ forced_decoder_ids not supported: {e}")
284
+
285
+ # Method 2: Try language parameter
286
  try:
287
+ test_kwargs = generate_kwargs.copy()
288
+ test_kwargs["max_length"] = 10
289
+ test_kwargs["language"] = lang_code
290
+ _ = model.generate(**test_kwargs) # Test run
291
+ generate_kwargs["language"] = lang_code
292
+ print(f"✅ Using language parameter for {language}")
293
+ except Exception as e:
294
+ print(f"⚠️ language parameter not supported: {e}")
295
 
296
+ # Generate with whatever parameters work
297
  predicted_ids = model.generate(**generate_kwargs)
298
 
299
  # Decode
 
303
  clean_up_tokenization_spaces=True
304
  )[0]
305
 
306
+ # Post-process transcription
307
+ transcription = transcription.strip()
308
+
309
+ # If we get empty transcription, try again with simpler parameters
310
+ if not transcription and generate_kwargs.get("num_beams", 1) > 1:
311
+ print("🔄 Retrying with greedy decoding...")
312
+ simple_kwargs = {
313
+ "input_features": input_features,
314
+ "max_length": 200,
315
+ "do_sample": False
316
+ }
317
+ predicted_ids = model.generate(**simple_kwargs)
318
+ transcription = processor.batch_decode(
319
+ predicted_ids,
320
+ skip_special_tokens=True,
321
+ clean_up_tokenization_spaces=True
322
+ )[0].strip()
323
+
324
+ return transcription or "(No transcription generated)"
325
 
326
  except Exception as e:
327
  print(f"Transcription error for {language}: {e}")
328
+ return f"Error: {str(e)[:150]}..."
329
 
330
  def highlight_differences(ref, hyp):
331
  """Highlight word-level differences with better styling"""