File size: 2,118 Bytes
88dc3ba
5d9c950
b2593bc
5d9c950
527e644
b2593bc
f47a9e0
78cc121
a1917fb
b2593bc
 
f47a9e0
 
b2593bc
7634d42
 
 
 
b2593bc
7634d42
 
22fbcf1
 
7634d42
a1917fb
52fc07d
7634d42
a1917fb
22fbcf1
 
 
 
 
52fc07d
81137f5
7634d42
52fc07d
 
7634d42
78cc121
52fc07d
 
 
038e82c
52fc07d
 
 
038e82c
52fc07d
038e82c
52fc07d
 
527e644
b2593bc
7634d42
038e82c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# Bootstrap third-party dependencies at script start.
# NOTE(review): runtime pip installs are fragile; a requirements.txt /
# pinned environment is preferable in production.
import subprocess
import sys

# Use this interpreter's own pip (sys.executable -m pip) so packages are
# installed into the environment the script actually runs in, and fail
# loudly with check=True instead of continuing with missing dependencies.
subprocess.run([sys.executable, "-m", "pip", "install", "gradio", "--upgrade"], check=True)
subprocess.run([sys.executable, "-m", "pip", "install", "transformers"], check=True)
subprocess.run([sys.executable, "-m", "pip", "install", "torchaudio", "--upgrade"], check=True)

import gradio as gr
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torchaudio
import torch

# Load the pretrained Italian wav2vec2 CTC model and its processor once at
# module import time. from_pretrained downloads the weights from the
# Hugging Face Hub on first run; later runs read the local cache.
processor = Wav2Vec2Processor.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-italian")
model = Wav2Vec2ForCTC.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-italian")

def preprocess_audio(audio_data):
    """Convert a raw waveform into a model-ready batch tensor.

    Args:
        audio_data: 1-D waveform (list / NumPy array / tensor of floats),
            assumed to already be at the model's 16 kHz rate — TODO confirm
            at the call site.

    Returns:
        The processor's ``input_values`` tensor of shape (1, samples).
    """
    # BUG FIX: Wav2Vec2Processor exposes `.input_values`, not
    # `.input_features` (that attribute belongs to seq2seq feature
    # extractors such as Whisper's) — the original raised AttributeError.
    # Pass sampling_rate explicitly so the processor can validate it.
    return processor(audio_data, sampling_rate=16000, return_tensors="pt").input_values

# Gradio callback: run ASR on one microphone recording.
def transcribe_audio(input_features):
    """Transcribe a Gradio microphone recording to Italian text.

    Args:
        input_features: ``(sample_rate, waveform)`` tuple as produced by
            ``gr.Audio`` with the default ``type="numpy"``; the waveform is
            a NumPy array (int16 PCM or float), mono ``(samples,)`` or
            ``(samples, channels)``. (The parameter name is legacy — it is
            raw audio, not model features.)

    Returns:
        str: the decoded transcription, or a human-readable error message.
    """
    print("Received audio data:", input_features)  # Debug print

    # Validate the Gradio payload shape. (The original referenced an
    # undefined name `audio_data` here, raising NameError on every call.)
    if input_features is None or not isinstance(input_features, tuple) or len(input_features) != 2:
        return "Invalid audio data format."

    sample_rate, waveform = input_features

    # Gradio delivers a NumPy array, never a torch.Tensor — the original
    # isinstance(waveform, torch.Tensor) check rejected every valid input.
    if waveform is None:
        return "Invalid audio data format."

    try:
        # Convert to a float32 tensor; int16 PCM is scaled into [-1, 1],
        # which is what the wav2vec2 feature extractor expects.
        wave = torch.as_tensor(waveform)
        if torch.is_floating_point(wave):
            wave = wave.float()
        else:
            wave = wave.float() / 32768.0

        # Down-mix (samples, channels) recordings to mono.
        if wave.dim() == 2:
            wave = wave.mean(dim=1)

        # Resample to the model's 16 kHz rate (the original resampled to
        # 100000 Hz and then discarded the result).
        if sample_rate != 16000:
            wave = torchaudio.transforms.Resample(sample_rate, 16000)(wave)

        # Feed the actual waveform to the processor (the original passed
        # input_features[0], i.e. the sample-rate integer).
        input_values = processor(wave, sampling_rate=16000, return_tensors="pt").input_values

        # Perform ASR without building autograd state.
        with torch.no_grad():
            logits = model(input_values).logits

        # Greedy CTC decode.
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)

        return transcription[0]

    except Exception as e:
        # Boundary handler for the Gradio callback: surface the failure as
        # text in the UI instead of crashing the server.
        return f"An error occurred: {str(e)}"

# Create Gradio interface
# Microphone-only input; with gr.Audio's default type="numpy" the callback
# receives a (sample_rate, waveform) tuple. launch() blocks and serves the UI.
audio_input = gr.Audio(sources=["microphone"])
gr.Interface(fn=transcribe_audio, inputs=audio_input, outputs="text").launch()