VanYsa committed on
Commit 0d7c0a3 · 1 Parent(s): 4bb9fa9

Update app.py

Files changed (1)
  1. app.py +161 -4
app.py CHANGED
@@ -1,7 +1,164 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import json
+ import librosa
+ import os
+ import soundfile as sf
+ import tempfile
+ import uuid

+ import torch

+ from nemo.collections.asr.models import ASRModel
+ from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchMultiTaskAED
+ from nemo.collections.asr.parts.utils.transcribe_utils import get_buffered_pred_feat_multitaskAED
+
+ SAMPLE_RATE = 16000  # Hz
+ MAX_AUDIO_MINUTES = 10  # won't try to transcribe if longer than this
+
+ model = ASRModel.from_pretrained("nvidia/canary-1b")
+ model.eval()
+
+ # make sure beam size is always 1 for consistency
+ model.change_decoding_strategy(None)
+ decoding_cfg = model.cfg.decoding
+ decoding_cfg.beam.beam_size = 1
+ model.change_decoding_strategy(decoding_cfg)
+
+ # setup for buffered inference
+ model.cfg.preprocessor.dither = 0.0
+ model.cfg.preprocessor.pad_to = 0
+
+ feature_stride = model.cfg.preprocessor['window_stride']
+ model_stride_in_secs = feature_stride * 8  # model stride is 8 for FastConformer
+
+ frame_asr = FrameBatchMultiTaskAED(
+     asr_model=model,
+     frame_len=40.0,
+     total_buffer=40.0,
+     batch_size=16,
+ )
+
+ amp_dtype = torch.float16
+
+ def convert_audio(audio_filepath, tmpdir, utt_id):
+     """
+     Convert the input audio file to a mono 16 kHz wav file.
+     Does not convert, and raises an error, if the audio is too long.
+     Returns the output filename and duration.
+     """
+
+     data, sr = librosa.load(audio_filepath, sr=None, mono=True)
+
+     duration = librosa.get_duration(y=data, sr=sr)
+
+     if duration / 60.0 > MAX_AUDIO_MINUTES:
+         raise gr.Error(
+             f"This demo can transcribe up to {MAX_AUDIO_MINUTES} minutes of audio. "
+             "If you wish, you may trim the audio using the Audio viewer in Step 1 "
+             "(click on the scissors icon to start trimming audio)."
+         )
+
+     if sr != SAMPLE_RATE:
+         data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE)
+
+     out_filename = os.path.join(tmpdir, utt_id + '.wav')
+
+     # save output audio
+     sf.write(out_filename, data, SAMPLE_RATE)
+
+     return out_filename, duration
+
+
+ def transcribe(audio_filepath):
+
+     if audio_filepath is None:
+         raise gr.Error("Please provide some input audio: either upload an audio file or use the microphone")
+
+     utt_id = uuid.uuid4()
+     with tempfile.TemporaryDirectory() as tmpdir:
+         converted_audio_filepath, duration = convert_audio(audio_filepath, tmpdir, str(utt_id))
+
+         # make manifest file and save
+         manifest_data = {
+             "audio_filepath": converted_audio_filepath,
+             "source_lang": "en",
+             "target_lang": "en",
+             "taskname": "asr",
+             "pnc": "yes",
+             "answer": "predict",
+             "duration": str(duration),
+         }
+
+         manifest_filepath = os.path.join(tmpdir, f'{utt_id}.json')
+
+         with open(manifest_filepath, 'w') as fout:
+             line = json.dumps(manifest_data)
+             fout.write(line + '\n')
+
+         # call transcribe, passing in manifest filepath
+         if duration < 40:
+             output_text = model.transcribe(manifest_filepath)[0]
+         else:  # do buffered inference
+             with torch.cuda.amp.autocast(dtype=amp_dtype):  # TODO: make it work if no cuda
+                 with torch.no_grad():
+                     hyps = get_buffered_pred_feat_multitaskAED(
+                         frame_asr,
+                         model.cfg.preprocessor,
+                         model_stride_in_secs,
+                         model.device,
+                         manifest=manifest_filepath,
+                         filepaths=None,
+                     )
+
+                     output_text = hyps[0].text
+
+     return output_text
+
+
+
+ with gr.Blocks(
+     title="NeMo Canary Model",
+     css="""
+         textarea { font-size: 18px;}
+         #model_output_text_box span {
+             font-size: 18px;
+             font-weight: bold;
+         }
+     """,
+     theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)  # make text slightly bigger (default is text_md)
+ ) as demo:
+
+     gr.HTML("<h1 style='text-align: center'>NeMo Canary model: Transcribe & Translate audio</h1>")
+
+     with gr.Row():
+         with gr.Column():
+             gr.HTML(
+                 "<p><b>Step 1:</b> Upload an audio file or record with your microphone.</p>"
+             )
+
+             audio_file = gr.Audio(sources=["microphone", "upload"], type="filepath")
+
+
+         with gr.Column():
+
+             gr.HTML("<p><b>Step 2:</b> Run the model.</p>")
+
+             go_button = gr.Button(
+                 value="Run model",
+                 variant="primary",  # make "primary" so it stands out (default is "secondary")
+             )
+
+             model_output_text_box = gr.Textbox(
+                 label="Model Output",
+                 elem_id="model_output_text_box",
+             )
+
+     go_button.click(
+         fn=transcribe,
+         inputs=[audio_file],
+         outputs=[model_output_text_box]
+     )
+
+ print(torch.cuda.is_available())  # log whether a GPU is available
+ demo.queue()
+ demo.launch()
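
Note on the "# TODO: make it work if no cuda" comment in the buffered-inference branch: torch.cuda.amp.autocast is the one CUDA-only call in that path. A minimal sketch of a device-agnostic guard is below; it is not part of this commit, and the autocast_ctx helper name is made up for illustration.

import contextlib
import torch

def autocast_ctx(dtype=torch.float16):
    """Return a CUDA autocast context when a GPU is available, else a no-op.

    Hypothetical helper, not part of the committed app.py.
    """
    if torch.cuda.is_available():
        # mixed-precision inference on GPU
        return torch.cuda.amp.autocast(dtype=dtype)
    # CPU fallback: plain full-precision execution
    return contextlib.nullcontext()

# Possible usage inside transcribe(), in place of the hard-coded CUDA autocast:
#     with autocast_ctx(amp_dtype):
#         with torch.no_grad():
#             hyps = get_buffered_pred_feat_multitaskAED(...)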