File size: 1,886 Bytes
f9bbc64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80317ce
 
 
f9bbc64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03f8630
f9bbc64
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import logging
import os

import gradio as gr
import torch
import torchaudio
from deepmultilingualpunctuation import PunctuationModel
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

# Hugging Face Hub id of the fine-tuned MMS (wav2vec2 CTC) checkpoint.
MODEL = "xTorch8/fine-tuned-mms"
# Hub access token taken from the environment; None when TOKEN is unset.
TOKEN = os.getenv("TOKEN")

# Acoustic model and its matching processor (feature extractor + tokenizer),
# loaded once at import time and shared by every request.
model = Wav2Vec2ForCTC.from_pretrained(MODEL, token = TOKEN)
processor = Wav2Vec2Processor.from_pretrained(MODEL, token = TOKEN)

# torchaudio.set_audio_backend is deprecated (a no-op under the backend
# dispatcher since torchaudio 2.1) and may be absent in newer releases.
# Only call it when it still exists so the app does not fail at import time.
_set_audio_backend = getattr(torchaudio, "set_audio_backend", None)
if _set_audio_backend is not None:
    _set_audio_backend("soundfile")

# Restores punctuation in the raw, unpunctuated CTC transcript.
language_model = PunctuationModel()

def transcription(audio_stream, is_video = False):
    """Transcribe an uploaded audio/video file and restore its punctuation.

    Args:
        audio_stream: Path of the uploaded file, as delivered by the
            ``gr.Audio(type="filepath")`` input. If a tuple is received, its
            first element is used as the path.
        is_video: When True, the file is decoded with ``format="wav"`` forced.
            NOTE(review): this assumes the Gradio audio component has already
            converted the upload to WAV — confirm for the deployed Gradio version.

    Returns:
        The punctuated transcription. If any step fails, returns a short
        ``"<ExceptionType>: <message>"`` string so the output textbox always
        receives text; the full traceback is logged.
    """
    try:
        if isinstance(audio_stream, tuple):
            # NOTE(review): gr.Audio(type="numpy") yields (sample_rate, data), in
            # which case [0] is the sample rate, not a path. With the "filepath"
            # input this app uses, this branch is not expected to run.
            audio_stream = audio_stream[0]

        if is_video:
            waveform, sample_rate = torchaudio.load(audio_stream, format = "wav")
        else:
            waveform, sample_rate = torchaudio.load(audio_stream)

        # torchaudio.load returns (channels, frames). Down-mix to mono: a
        # multi-channel array would otherwise be treated by the processor as a
        # batch, and only the first channel's transcript would be kept below.
        # For mono input this equals the previous squeeze().
        waveform = waveform.mean(dim = 0)

        target_sample_rate = 16000
        if sample_rate != target_sample_rate:
            waveform = torchaudio.functional.resample(
                waveform, orig_freq = sample_rate, new_freq = target_sample_rate
            )

        input_values = processor(
            waveform.numpy(), return_tensors = "pt", sampling_rate = target_sample_rate
        ).input_values

        with torch.no_grad():
            logits = model(input_values).logits

        # Greedy CTC decoding: most likely token per frame, then collapse.
        predicted_ids = torch.argmax(logits, dim = -1)
        text = processor.batch_decode(predicted_ids)[0]

        # Nothing was recognised (e.g. silence); skip punctuation restoration.
        if not text.strip():
            return text

        return language_model.restore_punctuation(text)
    except Exception as e:
        # UI boundary: keep the traceback in the server log and show a readable
        # message in the textbox instead of an Exception object.
        logging.getLogger(__name__).exception("Transcription failed")
        return f"{type(e).__name__}: {e}"
    
# Gradio UI: a file upload plus a "video?" checkbox feed `transcription`,
# whose result is shown in a single text box. Flagging is disabled.
_audio_input = gr.Audio(label = "Upload Audio/Video", type="filepath")
_video_flag = gr.Checkbox(label = "Is this a video file?")
_text_output = gr.Textbox(label = "Transcription Output")

demo = gr.Interface(
    fn = transcription,
    inputs = [_audio_input, _video_flag],
    outputs = _text_output,
    title = "MMS Audio/Video Transcription",
    allow_flagging = "never",
)

if __name__ == "__main__":
    demo.launch()
# Trigger rebuild