MusIre committed on
Commit
f47a9e0
·
1 Parent(s): 78cc121

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -12
app.py CHANGED
@@ -1,16 +1,14 @@
1
  import subprocess
2
- subprocess.run(["pip", "install", "datasets"])
3
  subprocess.run(["pip", "install", "transformers"])
4
- subprocess.run(["pip", "install", "torch", "torchvision", "torchaudio", "-f", "https://download.pytorch.org/whl/torch_stable.html"])
5
 
6
  import gradio as gr
7
- from transformers import WhisperProcessor, WhisperForConditionalGeneration
8
  import torchaudio
9
 
10
  # Load model and processor
11
- processor = WhisperProcessor.from_pretrained("openai/whisper-large")
12
- model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
13
- model.config.forced_decoder_ids = None
14
 
15
  # Function to perform ASR on audio data
16
  def transcribe_audio(audio_data):
@@ -18,18 +16,20 @@ def transcribe_audio(audio_data):
18
  audio_data = torchaudio.functional.to_mono(audio_data)
19
  audio_data = torchaudio.functional.gain(audio_data, gain_db=5.0)
20
 
21
- # Resample if needed (Whisper model requires 16 kHz sampling rate)
22
  if audio_data[1] != 16000:
23
  audio_data = torchaudio.transforms.Resample(audio_data[1], 16000)(audio_data[0])
24
 
25
  # Apply custom preprocessing to the audio data if needed
26
- processed_input = processor(audio_data[0].numpy(), return_tensors="pt").input_features
27
 
28
- # Generate token ids
29
- predicted_ids = model.generate(processed_input)
 
30
 
31
- # Decode token ids to text
32
- transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
 
33
 
34
  return transcription[0]
35
 
 
1
  import subprocess
 
2
  subprocess.run(["pip", "install", "transformers"])
3
+ subprocess.run(["pip", "install", "torchaudio"])
4
 
5
  import gradio as gr
6
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
7
  import torchaudio
8
 
9
  # Load model and processor
10
+ processor = Wav2Vec2Processor.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-italian")
11
+ model = Wav2Vec2ForCTC.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-italian")
 
12
 
13
# Function to perform ASR on audio data
def transcribe_audio(audio_data):
    """Transcribe an audio clip to Italian text with the Wav2Vec2 model.

    Parameters
    ----------
    audio_data : tuple
        Assumed to be a ``(waveform_tensor, sample_rate)`` pair — the original
        code indexed ``audio_data[0]`` as the waveform and ``audio_data[1]`` as
        the rate. NOTE(review): a Gradio ``Audio`` input usually yields
        ``(sample_rate, numpy_array)`` in the *opposite* order — confirm
        against the component wiring elsewhere in the app.

    Returns
    -------
    str
        The greedy-CTC-decoded transcription.
    """
    import torch  # local import so this fix stands alone even if the top-level import is absent

    # Unpack once instead of passing the whole tuple into the DSP calls.
    # BUG FIX: the original called torchaudio.functional.to_mono(), which does
    # not exist in torchaudio (AttributeError), and fed the (waveform, rate)
    # tuple itself to to_mono/gain while also treating audio_data[1] as the
    # sample rate — the function could never run to completion.
    waveform, sample_rate = audio_data[0], audio_data[1]

    # Down-mix multi-channel audio to mono, then apply the original +5 dB gain.
    if waveform.dim() > 1 and waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    waveform = torchaudio.functional.gain(waveform, gain_db=5.0)

    # Resample if needed (the Wav2Vec2 model requires a 16 kHz sampling rate).
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)

    # Featurize for the model; sampling_rate is passed explicitly so the
    # processor can validate it matches the model's expected rate.
    input_values = processor(
        waveform.squeeze(0).numpy(),
        sampling_rate=16000,
        return_tensors="pt",
    ).input_values

    # Perform ASR — inference only, so disable gradient tracking.
    with torch.no_grad():
        logits = model(input_values).logits

    # Greedy CTC decoding: most likely token per frame, collapsed by the processor.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)

    return transcription[0]
35