tzu committed on
Commit
2c82342
·
1 Parent(s): 2fd576f

Create new file

Browse files
Files changed (1) hide show
  1. app.py +45 -0
app.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+
4
+
5
+
6
+ #def predict(image):
7
+ # predictions = pipeline(image)
8
+ # return {p["label"]: p["score"] for p in predictions}
9
+
10
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
11
+ from datasets import load_dataset
12
+ import torch
13
+
14
# Lazily populated cache so the checkpoint is loaded once per process rather
# than on every request handled by predict().
_ASR_COMPONENTS = {}


def _load_asr_components(checkpoint="facebook/wav2vec2-base-960h"):
    """Return the cached ``(processor, model)`` pair for *checkpoint*."""
    if checkpoint not in _ASR_COMPONENTS:
        _ASR_COMPONENTS[checkpoint] = (
            Wav2Vec2Processor.from_pretrained(checkpoint),
            Wav2Vec2ForCTC.from_pretrained(checkpoint),
        )
    return _ASR_COMPONENTS[checkpoint]


def predict(speech):
    """Transcribe the first sample of the dummy LibriSpeech validation split.

    Args:
        speech: Value supplied by the caller.
            NOTE(review): this argument is not used yet; the audio always
            comes from the dummy dataset below. Confirm whether the caller's
            input is meant to be transcribed instead.

    Returns:
        The result of ``processor.batch_decode`` for the single-sample batch.
    """
    processor, model = _load_asr_components()

    # Load the dummy dataset and read its sound files.
    ds = load_dataset(
        "patrickvonplaten/librispeech_asr_dummy", "clean", split="validation"
    )

    # Turn the raw waveform into model inputs (batch size 1).
    input_values = processor(
        ds[0]["audio"]["array"], return_tensors="pt", padding="longest"
    ).input_values

    # Inference only: avoid building an autograd graph for the forward pass.
    with torch.no_grad():
        logits = model(input_values).logits

    # Take the argmax over the last dimension and decode the ids.
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(predicted_ids)
34
+
35
# Wire the transcription function into a Gradio UI. The callable must be
# `predict`; `speech` is only the name of its parameter and is not defined at
# module scope, so passing it here raised NameError on startup.
demo = gr.Interface(fn=predict, inputs="text", outputs="text")

demo.launch()
38
+
39
+
40
+ #gr.Interface(
41
+ # predict,
42
+ # inputs=gr.inputs.speech(label="Upload", type="filepath"),
43
+ # outputs=gr.outputs.Label(num_top_classes=2),
44
+ # title="Audio",
45
+ #).launch()