tahirsher committed on
Commit
2e48e3c
·
verified ·
1 Parent(s): 81653fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -20
app.py CHANGED
@@ -188,34 +188,25 @@ if audio_file:
188
  with open(audio_path, "wb") as f:
189
  f.write(audio_file.read())
190
 
191
- # ✅ Ensure Model Precision Matches Input
192
- if device == "cuda":
193
- model.half() # Use FP16 for speed on GPU
194
- else:
195
- model.float() # Ensure CPU uses FP32
196
-
197
- # ✅ Load and preprocess audio
198
  waveform, sample_rate = torchaudio.load(audio_path)
199
  waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
200
 
201
- # ✅ Convert to input format (Match FP16 for GPU)
202
- input_features = processor(
203
- waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt"
204
- ).input_features.to(device)
205
 
206
- if device == "cuda":
207
- input_features = input_features.half() # ✅ Convert to FP16
208
 
209
- # ✅ Optimized Inference
210
- with torch.inference_mode():
211
  generated_ids = model.generate(
212
- input_features,
213
- max_length=200, # ⏩ Reduced length for speed
214
- num_beams=2, # ⏩ Lower beams for faster decoding
215
- do_sample=False, # ⏩ Disables unnecessary sampling
216
- use_cache=True # ✅ Speeds up processing
217
  )
218
  transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
219
 
 
220
  st.success("📄 Transcription:")
221
  st.write(transcription)
 
 
188
  with open(audio_path, "wb") as f:
189
  f.write(audio_file.read())
190
 
 
 
 
 
 
 
 
191
  waveform, sample_rate = torchaudio.load(audio_path)
192
  waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
193
 
194
+ input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features
 
 
 
195
 
196
+ input_tensor = input_features.to(device)
 
197
 
198
+ # ✅ FIX: Use `generate()` for Proper Transcription
199
+ with torch.no_grad():
200
  generated_ids = model.generate(
201
+ input_tensor,
202
+ max_length=500,
203
+ num_beams=5,
204
+ do_sample=True,
205
+ top_k=50
206
  )
207
  transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
208
 
209
+ # Display transcription
210
  st.success("📄 Transcription:")
211
  st.write(transcription)
212
+