MusIre commited on
Commit
cd65652
·
1 Parent(s): 7634d42

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -50
app.py CHANGED
@@ -4,53 +4,31 @@ subprocess.run(["pip", "install", "transformers"])
4
  subprocess.run(["pip", "install", "torchaudio", "--upgrade"])
5
 
6
  import gradio as gr
7
- from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
8
- import torchaudio
9
- import torch
10
-
11
- # Load model and processor
12
- processor = Wav2Vec2Processor.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-italian")
13
- model = Wav2Vec2ForCTC.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-italian")
14
-
15
- def preprocess_audio(audio_data):
16
- # Apply any custom preprocessing to the audio data here if needed
17
- return processor(audio_data, return_tensors="pt").input_features
18
-
19
- # Function to perform ASR on audio data
20
- def transcribe_audio(input_features):
21
- print("Received audio data:", input_features) # Debug print
22
-
23
- # Check if audio_data is None or not a tuple of length 2
24
- if audio_data is None or not isinstance(input_features, tuple) or len(input_features) != 2:
25
- return "Invalid audio data format."
26
-
27
- sample_rate, waveform = input_features
28
-
29
- # Check if waveform is None or not a NumPy array
30
- if waveform is None or not isinstance(waveform, torch.Tensor):
31
- return "Invalid audio data format."
32
-
33
- try:
34
- # Convert audio data to mono and normalize
35
- audio_data = torchaudio.transforms.Resample(sample_rate, 100000)(waveform)
36
- audio_data = torchaudio.functional.gain(input_features, gain_db=5.0)
37
-
38
- # Apply custom preprocessing to the audio data if needed
39
- input_values = processor(input_features[0], return_tensors="pt").input_values
40
-
41
- # Perform ASR
42
- with torch.no_grad():
43
- logits = model(input_values).logits
44
-
45
- # Decode the output
46
- predicted_ids = torch.argmax(logits, dim=-1)
47
- transcription = processor.batch_decode(predicted_ids)
48
-
49
- return transcription[0]
50
-
51
- except Exception as e:
52
- return f"An error occurred: {str(e)}"
53
-
54
- # Create Gradio interface
55
- audio_input = gr.Audio(sources=["microphone"])
56
- gr.Interface(fn=transcribe_audio, inputs=audio_input, outputs="text").launch()
 
4
  subprocess.run(["pip", "install", "torchaudio", "--upgrade"])
5
 
6
  import gradio as gr
7
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
8
+
9
+ # Load Whisper ASR model and processor
10
+ model_name = "openai/whisper-small"
11
+ processor = WhisperProcessor.from_pretrained(model_name)
12
+ model = WhisperForConditionalGeneration.from_pretrained(model_name)
13
+ forced_decoder_ids = processor.get_decoder_prompt_ids(language="italian", task="transcribe")
14
+
15
+ def transcribe_audio(input_audio):
16
+ # Process audio using the Whisper processor
17
+ input_features = processor(input_audio, return_tensors="pt").input_features
18
+
19
+ # Generate token ids
20
+ predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
21
+
22
+ # Decode token ids to text
23
+ transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
24
+
25
+ return transcription[0]
26
+
27
+ iface = gr.Interface(
28
+ fn=transcribe_audio,
29
+ inputs=gr.Audio(source="microphone", type="wav", label="Speak"),
30
+ outputs="text",
31
+ live=True
32
+ )
33
+
34
+ iface.launch()