tahirsher committed on
Commit
8dd61a6
·
verified ·
1 Parent(s): 8d55ac9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -156,14 +156,17 @@ if audio_file:
156
  waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
157
 
158
  # Convert audio to model input
159
- input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
160
 
161
  # ✅ FIX: Ensure input tensor is correctly formatted
162
- input_tensor = input_features.unsqueeze(0).to(device) # Adds batch dimension
 
 
 
163
 
164
  # Perform ASR inference
165
  with torch.no_grad():
166
- logits = model(input_tensor).logits
167
  predicted_ids = torch.argmax(logits, dim=-1)
168
  transcription = processor.batch_decode(predicted_ids)[0]
169
 
 
156
  waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
157
 
158
  # Convert audio to model input
159
+ input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features
160
 
161
  # ✅ FIX: Ensure input tensor is correctly formatted
162
+ input_tensor = input_features.to(device) # Move to GPU/CPU
163
+
164
+ # ✅ FIX: Provide decoder_input_ids
165
+ decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]]).to(device)
166
 
167
  # Perform ASR inference
168
  with torch.no_grad():
169
+ logits = model(input_tensor, decoder_input_ids=decoder_input_ids).logits
170
  predicted_ids = torch.argmax(logits, dim=-1)
171
  transcription = processor.batch_decode(predicted_ids)[0]
172