Spaces:
Sleeping
Sleeping
File size: 1,886 Bytes
f9bbc64 80317ce f9bbc64 03f8630 f9bbc64 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import gradio as gr
import os
import torch
import torchaudio
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from deepmultilingualpunctuation import PunctuationModel
# Model repository loaded for both the CTC acoustic model and its processor.
MODEL = "xTorch8/fine-tuned-mms"
# Access token for the repository, read from the TOKEN environment variable
# (None when unset).
TOKEN = os.getenv("TOKEN")
model = Wav2Vec2ForCTC.from_pretrained(MODEL, token=TOKEN)
processor = Wav2Vec2Processor.from_pretrained(MODEL, token=TOKEN)
# torchaudio.set_audio_backend is deprecated in torchaudio 2.x and may not
# exist in newer releases; only call it where it is still available so the
# module import does not fail with AttributeError.
if hasattr(torchaudio, "set_audio_backend"):
    torchaudio.set_audio_backend("soundfile")
# Restores punctuation on the raw (unpunctuated) CTC transcript.
language_model = PunctuationModel()
# Sample rate that audio is resampled to before being handed to the processor.
_TARGET_SAMPLE_RATE = 16000


def _load_mono_waveform(path, is_video, target_sample_rate=_TARGET_SAMPLE_RATE):
    """Load *path* with torchaudio, downmix to mono and resample.

    Returns a 1-D float tensor sampled at *target_sample_rate*.
    """
    if is_video:
        # NOTE(review): video uploads are decoded with an explicit "wav"
        # format hint; confirm the upload path really is WAV data here.
        waveform, sample_rate = torchaudio.load(path, format="wav")
    else:
        waveform, sample_rate = torchaudio.load(path)
    # torchaudio.load returns (channels, frames). Average the channels so a
    # stereo file becomes one signal instead of a 2-D array that the
    # processor would treat as a batch (of which only item 0 was kept).
    waveform = waveform.mean(dim=0)
    if sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=target_sample_rate
        )
        waveform = resampler(waveform)
    return waveform


def transcription(audio_stream, is_video=False):
    """Transcribe an uploaded audio/video file and restore punctuation.

    Args:
        audio_stream: Path to the uploaded file. If a tuple is passed, its
            first element is used as the path.
        is_video: Whether the upload was flagged as a video file.

    Returns:
        The punctuated transcript; ``""`` when no file was provided; the raw
        decoded text when it is blank; or the error message (a ``str``) if
        any step raises. This function never raises, so the UI always gets
        text back.
    """
    try:
        if isinstance(audio_stream, tuple):
            audio_stream = audio_stream[0]
        if audio_stream is None:
            # Nothing was uploaded; avoid handing None to the audio loader.
            return ""
        waveform = _load_mono_waveform(audio_stream, is_video)
        input_values = processor(
            waveform.numpy(),
            return_tensors="pt",
            sampling_rate=_TARGET_SAMPLE_RATE,
        ).input_values
        with torch.no_grad():
            logits = model(input_values).logits
        # Greedy CTC decoding: most likely token per frame.
        predicted_ids = torch.argmax(logits, dim=-1)
        text = processor.batch_decode(predicted_ids)[0]
        if not text.strip():
            # Nothing recognised; skip punctuation restoration on blank text.
            return text
        return language_model.restore_punctuation(text)
    except Exception as e:
        # Surface the failure as text rather than an exception object.
        return str(e)
# Gradio UI: an uploaded file plus a "video" checkbox go in, and the
# transcript text comes out.
_ui_inputs = [
    gr.Audio(label="Upload Audio/Video", type="filepath"),
    gr.Checkbox(label="Is this a video file?"),
]
_ui_output = gr.Textbox(label="Transcription Output")

demo = gr.Interface(
    fn=transcription,
    inputs=_ui_inputs,
    outputs=_ui_output,
    title="MMS Audio/Video Transcription",
    allow_flagging="never",
)

if __name__ == "__main__":
    # Launch the demo only when this file is executed directly.
    demo.launch()
# Trigger rebuild