Update app.py
Browse files
app.py
CHANGED
@@ -156,14 +156,17 @@ if audio_file:
|
|
156 |
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
|
157 |
|
158 |
# Convert audio to model input
|
159 |
-
input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features
|
160 |
|
161 |
# ✅ FIX: Ensure input tensor is correctly formatted
|
162 |
-
input_tensor = input_features.
|
|
|
|
|
|
|
163 |
|
164 |
# Perform ASR inference
|
165 |
with torch.no_grad():
|
166 |
-
logits = model(input_tensor).logits
|
167 |
predicted_ids = torch.argmax(logits, dim=-1)
|
168 |
transcription = processor.batch_decode(predicted_ids)[0]
|
169 |
|
|
|
156 |
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
|
157 |
|
158 |
# Convert audio to model input
|
159 |
+
input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features
|
160 |
|
161 |
# ✅ FIX: Ensure input tensor is correctly formatted
|
162 |
+
input_tensor = input_features.to(device) # Move to GPU/CPU
|
163 |
+
|
164 |
+
# ✅ FIX: Provide decoder_input_ids
|
165 |
+
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]]).to(device)
|
166 |
|
167 |
# Perform ASR inference
|
168 |
with torch.no_grad():
|
169 |
+
logits = model(input_tensor, decoder_input_ids=decoder_input_ids).logits
|
170 |
predicted_ids = torch.argmax(logits, dim=-1)
|
171 |
transcription = processor.batch_decode(predicted_ids)[0]
|
172 |
|