Update app.py
Browse files
app.py
CHANGED
@@ -188,34 +188,25 @@ if audio_file:
|
|
188 |
with open(audio_path, "wb") as f:
|
189 |
f.write(audio_file.read())
|
190 |
|
191 |
-
# ✅ Ensure Model Precision Matches Input
|
192 |
-
if device == "cuda":
|
193 |
-
model.half() # Use FP16 for speed on GPU
|
194 |
-
else:
|
195 |
-
model.float() # Ensure CPU uses FP32
|
196 |
-
|
197 |
-
# ✅ Load and preprocess audio
|
198 |
waveform, sample_rate = torchaudio.load(audio_path)
|
199 |
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
|
200 |
|
201 |
-
|
202 |
-
input_features = processor(
|
203 |
-
waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt"
|
204 |
-
).input_features.to(device)
|
205 |
|
206 |
-
|
207 |
-
input_features = input_features.half() # ✅ Convert to FP16
|
208 |
|
209 |
-
# ✅
|
210 |
-
with torch.
|
211 |
generated_ids = model.generate(
|
212 |
-
|
213 |
-
max_length=
|
214 |
-
num_beams=
|
215 |
-
do_sample=
|
216 |
-
|
217 |
)
|
218 |
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
219 |
|
|
|
220 |
st.success("π Transcription:")
|
221 |
st.write(transcription)
|
|
|
|
188 |
with open(audio_path, "wb") as f:
|
189 |
f.write(audio_file.read())
|
190 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
waveform, sample_rate = torchaudio.load(audio_path)
|
192 |
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
|
193 |
|
194 |
+
input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features
|
|
|
|
|
|
|
195 |
|
196 |
+
input_tensor = input_features.to(device)
|
|
|
197 |
|
198 |
+
# ✅ FIX: Use `generate()` for Proper Transcription
|
199 |
+
with torch.no_grad():
|
200 |
generated_ids = model.generate(
|
201 |
+
input_tensor,
|
202 |
+
max_length=500,
|
203 |
+
num_beams=5,
|
204 |
+
do_sample=True,
|
205 |
+
top_k=50
|
206 |
)
|
207 |
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
208 |
|
209 |
+
# Display transcription
|
210 |
st.success("π Transcription:")
|
211 |
st.write(transcription)
|
212 |
+
|