# Space status at capture time: Runtime error
import functools

import gradio as gr
import numpy as np
import torch
import transformers
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

#def predict(image):
# predictions = pipeline(image)
# return {p["label"]: p["score"] for p in predictions}
MODEL_ID = "facebook/wav2vec2-base-960h"
# Rate the waveform is resampled to and declared as when it is handed to the
# processor; the wav2vec2-base-960h feature extractor is configured for 16 kHz.
TARGET_SAMPLING_RATE = 16_000


@functools.lru_cache(maxsize=1)
def _load_asr():
    """Load the wav2vec2 processor and CTC model once, then reuse them.

    Returns:
        tuple: ``(Wav2Vec2Processor, Wav2Vec2ForCTC)`` for ``MODEL_ID``.
    """
    processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
    model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
    return processor, model


@functools.lru_cache(maxsize=1)
def _dummy_sample():
    """Return ``(sampling_rate, samples)`` for the first LibriSpeech dummy clip."""
    ds = load_dataset(
        "patrickvonplaten/librispeech_asr_dummy", "clean", split="validation"
    )
    audio = ds[0]["audio"]
    return audio["sampling_rate"], audio["array"]


def _prepare_waveform(speech):
    """Convert *speech* into a 1-D float32 waveform at ``TARGET_SAMPLING_RATE``.

    Args:
        speech: ``None``, a ``(sampling_rate, samples)`` tuple, or a bare
            waveform array (assumed to already be at ``TARGET_SAMPLING_RATE``).

    Returns:
        numpy.ndarray: mono float32 samples (a new array; inputs are not mutated).

    Raises:
        ValueError: if the sampling rate is not positive or the audio is not
            1-D / 2-D.
    """
    if speech is None:
        sampling_rate, samples = _dummy_sample()
    elif isinstance(speech, tuple) and len(speech) == 2:
        sampling_rate, samples = speech
    else:
        sampling_rate, samples = TARGET_SAMPLING_RATE, speech

    sampling_rate = int(sampling_rate)
    if sampling_rate <= 0:
        raise ValueError(f"sampling rate must be positive, got {sampling_rate}")

    samples = np.asarray(samples)
    if samples.dtype.kind in "iu":
        # Integer PCM (e.g. int16 from a microphone): scale to roughly [-1, 1].
        scale = float(np.iinfo(samples.dtype).max)
        samples = samples.astype(np.float32) / scale
    else:
        samples = samples.astype(np.float32)

    if samples.ndim == 2:
        # Multi-channel audio: average over the channel axis, taken to be the
        # shorter one (e.g. Gradio yields (n_samples, n_channels)).
        samples = samples.mean(axis=int(np.argmin(samples.shape)))
    elif samples.ndim != 1:
        raise ValueError(f"expected 1-D or 2-D audio, got shape {samples.shape}")

    if sampling_rate != TARGET_SAMPLING_RATE and samples.size:
        # Linear-interpolation resampling: dependency-free and adequate for a
        # demo, though a band-limited resampler would reduce aliasing.
        n_out = max(1, round(samples.shape[0] * TARGET_SAMPLING_RATE / sampling_rate))
        src_t = np.arange(samples.shape[0]) / sampling_rate
        dst_t = np.arange(n_out) / TARGET_SAMPLING_RATE
        samples = np.interp(dst_t, src_t, samples).astype(np.float32)
    return samples


def predict(speech):
    """Transcribe *speech* with wav2vec2-base-960h using greedy CTC decoding.

    Args:
        speech: ``(sampling_rate, samples)`` as passed by
            ``gr.Audio(type="numpy")``; a bare waveform array, assumed to be
            sampled at 16 kHz; or ``None`` to transcribe the first clip of the
            LibriSpeech dummy dataset.

    Returns:
        list[str]: the output of ``processor.batch_decode`` for a batch of one.

    Raises:
        ValueError: for a non-positive sampling rate or audio that is not 1-D/2-D.
    """
    waveform = _prepare_waveform(speech)
    if waveform.size == 0:
        # Nothing to transcribe; avoid feeding an empty tensor to the model.
        return [""]

    processor, model = _load_asr()
    # tokenize (batch size 1)
    input_values = processor(
        waveform,
        sampling_rate=TARGET_SAMPLING_RATE,
        return_tensors="pt",
        padding="longest",
    ).input_values
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(input_values).logits
    # Greedy CTC: take the most likely token per frame, then let the
    # processor collapse repeats/blanks when decoding.
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(predicted_ids)
def transcribe(audio):
    """Gradio callback: return the transcription of *audio* as one string.

    ``predict`` returns the list produced by ``processor.batch_decode``; the
    text output needs a plain string, so the single-element list is unwrapped.
    """
    transcriptions = predict(audio)
    return transcriptions[0] if transcriptions else ""


# type="numpy" makes Gradio pass audio as a (sampling_rate, samples) tuple,
# for both uploads and recordings, with no extra decoding library needed.
demo = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(type="numpy", label="Upload"),
    outputs="text",
    title="Audio",
)
demo.launch()