MusIre commited on
Commit
78cc121
·
1 Parent(s): 038e82c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -1
app.py CHANGED
@@ -5,6 +5,7 @@ subprocess.run(["pip", "install", "torch", "torchvision", "torchaudio", "-f", "h
5
 
6
  import gradio as gr
7
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
 
8
 
9
  # Load model and processor
10
  processor = WhisperProcessor.from_pretrained("openai/whisper-large")
@@ -13,8 +14,16 @@ model.config.forced_decoder_ids = None
13
 
14
  # Function to perform ASR on audio data
15
  def transcribe_audio(audio_data):
 
 
 
 
 
 
 
 
16
  # Apply custom preprocessing to the audio data if needed
17
- processed_input = processor(audio_data, return_tensors="pt").input_features
18
 
19
  # Generate token ids
20
  predicted_ids = model.generate(processed_input)
 
5
 
6
  import gradio as gr
7
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
8
+ import torchaudio
9
 
10
  # Load model and processor
11
  processor = WhisperProcessor.from_pretrained("openai/whisper-large")
 
14
 
15
  # Function to perform ASR on audio data
16
  def transcribe_audio(audio_data):
17
+ # Convert audio data to mono and normalize
18
+ audio_data = torchaudio.functional.to_mono(audio_data)
19
+ audio_data = torchaudio.functional.gain(audio_data, gain_db=5.0)
20
+
21
+ # Resample if needed (Whisper model requires 16 kHz sampling rate)
22
+ if audio_data[1] != 16000:
23
+ audio_data = torchaudio.transforms.Resample(audio_data[1], 16000)(audio_data[0])
24
+
25
  # Apply custom preprocessing to the audio data if needed
26
+ processed_input = processor(audio_data[0].numpy(), return_tensors="pt").input_features
27
 
28
  # Generate token ids
29
  predicted_ids = model.generate(processed_input)