MusIre commited on
Commit
52fc07d
·
1 Parent(s): 2331d07

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -12
app.py CHANGED
@@ -3,6 +3,7 @@ subprocess.run(["pip", "install", "gradio", "--upgrade"])
3
  subprocess.run(["pip", "install", "transformers"])
4
  subprocess.run(["pip", "install", "torchaudio", "--upgrade"])
5
 
 
6
  import gradio as gr
7
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
8
  import torchaudio
@@ -13,22 +14,29 @@ model = Wav2Vec2ForCTC.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-it
13
 
14
  # Function to perform ASR on audio data
15
  def transcribe_audio(audio_data):
16
- # Convert audio data to mono and normalize
17
- audio_data = torchaudio.transforms.Resample(audio_data[1], 16000)(audio_data[0])
18
- audio_data = torchaudio.functional.gain(audio_data, gain_db=5.0)
 
 
 
 
 
 
 
19
 
20
- # Apply custom preprocessing to the audio data if needed
21
- input_values = processor(audio_data[0].numpy(), return_tensors="pt").input_values
 
22
 
23
- # Perform ASR
24
- with torch.no_grad():
25
- logits = model(input_values).logits
26
 
27
- # Decode the output
28
- predicted_ids = torch.argmax(logits, dim=-1)
29
- transcription = processor.batch_decode(predicted_ids)
30
 
31
- return transcription[0]
 
32
 
33
  # Create Gradio interface
34
  audio_input = gr.Audio()
 
3
  subprocess.run(["pip", "install", "transformers"])
4
  subprocess.run(["pip", "install", "torchaudio", "--upgrade"])
5
 
6
+
7
  import gradio as gr
8
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
9
  import torchaudio
 
14
 
15
  # Function to perform ASR on audio data
16
  def transcribe_audio(audio_data):
17
+ if audio_data is None:
18
+ return "No audio data received."
19
+
20
+ try:
21
+ # Convert audio data to mono and normalize
22
+ audio_data = torchaudio.transforms.Resample(audio_data[1], 16000)(audio_data[0])
23
+ audio_data = torchaudio.functional.gain(audio_data, gain_db=5.0)
24
+
25
+ # Apply custom preprocessing to the audio data if needed
26
+ input_values = processor(audio_data[0].numpy(), return_tensors="pt").input_values
27
 
28
+ # Perform ASR
29
+ with torch.no_grad():
30
+ logits = model(input_values).logits
31
 
32
+ # Decode the output
33
+ predicted_ids = torch.argmax(logits, dim=-1)
34
+ transcription = processor.batch_decode(predicted_ids)
35
 
36
+ return transcription[0]
 
 
37
 
38
+ except Exception as e:
39
+ return f"An error occurred: {str(e)}"
40
 
41
  # Create Gradio interface
42
  audio_input = gr.Audio()