Spaces:
Sleeping
Sleeping
import gradio as gr | |
import os | |
import torch | |
import torchaudio | |
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC | |
from deepmultilingualpunctuation import PunctuationModel | |
# Hub repository holding the fine-tuned MMS CTC checkpoint.
MODEL = "xTorch8/fine-tuned-mms"
# Optional access token read from the `TOKEN` environment variable
# (None when the variable is not set).
TOKEN = os.environ.get("TOKEN")

# Acoustic model first, then its matching processor, both pulled from the
# same repository with the same credentials.
model = Wav2Vec2ForCTC.from_pretrained(MODEL, token=TOKEN)
processor = Wav2Vec2Processor.from_pretrained(MODEL, token=TOKEN)

# Request the soundfile backend for torchaudio's file decoding.
# NOTE(review): newer torchaudio releases deprecate set_audio_backend (it is a
# no-op under dispatcher-based I/O) — confirm it still matters for the
# installed version.
torchaudio.set_audio_backend("soundfile")

# Punctuation-restoration model applied to the raw CTC text output.
language_model = PunctuationModel()
def transcription(audio_stream, is_video=False):
    """Transcribe an uploaded audio (or video soundtrack) file and punctuate it.

    Args:
        audio_stream: Path to the uploaded file, as delivered by the
            ``gr.Audio(type="filepath")`` input. If a tuple is passed, its
            first element is used as the source.
        is_video: When true, the file is decoded with ``format="wav"``.
            NOTE(review): this presumes the upload already reaches this
            function as WAV data — confirm against the Gradio version in use.

    Returns:
        The punctuated transcription text. On failure, returns a string
        starting with ``"Error: "``. Errors are logged and reported as text
        instead of raised, so the UI always shows a message.
    """
    try:
        if isinstance(audio_stream, tuple):
            audio_stream = audio_stream[0]

        if is_video:
            waveform, sample_rate = torchaudio.load(audio_stream, format="wav")
        else:
            waveform, sample_rate = torchaudio.load(audio_stream)

        # torchaudio.load yields (channels, frames). Average the channels into a
        # single mono signal. A 2-D array handed to the processor would be
        # treated as a batch, and only the first channel would be decoded.
        mono = waveform.mean(dim=0)

        # Resample to the rate the feature extractor expects. The 16 kHz
        # fallback matches the previously hard-coded value.
        target_sample_rate = getattr(processor.feature_extractor, "sampling_rate", 16000)
        if sample_rate != target_sample_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=target_sample_rate
            )
            mono = resampler(mono)

        input_values = processor(
            mono.numpy(), return_tensors="pt", sampling_rate=target_sample_rate
        ).input_values

        with torch.no_grad():
            logits = model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        text = processor.batch_decode(predicted_ids)[0]

        # Nothing recognised (e.g. silence): no need to run the punctuation model.
        if not text.strip():
            return text
        return language_model.restore_punctuation(text)
    except Exception as e:
        # UI boundary: record the full traceback, then surface a readable message.
        import logging

        logging.getLogger(__name__).exception("Transcription failed")
        return f"Error: {e}"
# Web UI: the uploaded file path and the "is video" checkbox are handed to
# `transcription` in this order, and its return value is shown as text.
demo = gr.Interface(
    transcription,
    [
        gr.Audio(type="filepath", label="Upload Audio/Video"),
        gr.Checkbox(label="Is this a video file?"),
    ],
    gr.Textbox(label="Transcription Output"),
    title="MMS Audio/Video Transcription",
    # NOTE(review): newer Gradio releases deprecate `allow_flagging` in favour
    # of `flagging_mode` — confirm against the pinned Gradio version.
    allow_flagging="never",
)

# Start the server only when this file is executed directly, not on import.
if __name__ == "__main__":
    demo.launch()
# Trigger rebuild