eabo committed
Commit cd814fd · 1 Parent(s): cdef8c2

Initial commit

Files changed (5):
  1. app.py +7 -0
  2. data_files/.DS_Store +0 -0
  3. requirements.txt +2 -0
  4. utils.py +36 -0
  5. whisperui.py +216 -0
app.py ADDED
@@ -0,0 +1,7 @@
+ import gradio as gr
+ from whisperui import WhisperModelUI
+
+ my_app = gr.Blocks()
+ iface = WhisperModelUI(my_app)
+ iface.create_whisper_ui()
+ iface.launch_ui()
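Running app.py (e.g. python app.py) builds the Blocks layout via create_whisper_ui() and then serves it through the class's launch_ui() method.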
data_files/.DS_Store ADDED
Binary file (6.15 kB).
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ git+https://github.com/openai/whisper.git
+ pytube
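Note that gradio itself is not pinned here; on Hugging Face Spaces the Gradio runtime is presumably supplied by the Space's SDK configuration, so only whisper and pytube need to be declared.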
utils.py ADDED
@@ -0,0 +1,36 @@
+ import whisper
+ import os
+
+
+ def whisper_decode(model, audio):
+     # pad/trim the audio to fit the model's 30-second window
+     audio = whisper.pad_or_trim(audio)
+     # make log-Mel spectrogram and move to the same device as the model
+     mel = whisper.log_mel_spectrogram(audio).to(model.device)
+     # detect the spoken language
+     _, probs = model.detect_language(mel)
+     print(f"Detected language: {max(probs, key=probs.get)}")
+
+     # decode the audio, translating it to English
+     options = whisper.DecodingOptions(
+         task='translate',
+         fp16=False)
+     result = whisper.decode(model, mel, options)
+     # print the recognized text
+     print(result.text)
+
+
+ def whisper_transcribe(model, audio):
+     # transcribe the full audio in its spoken language
+     result = model.transcribe(audio)
+     print(result["text"])
+
+
+ def try_whisper_model(model_type, choice):
+     # load the requested checkpoint and run it on the bundled sample file
+     model = whisper.load_model(model_type)
+     data_file = os.path.join(os.path.curdir, 'data_files', 'bharat.mp3')
+     audio = whisper.load_audio(data_file)
+     if choice == 'decode':
+         whisper_decode(model, audio)
+     elif choice == 'transcribe':
+         whisper_transcribe(model, audio)
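A rough usage sketch, assuming the data_files/bharat.mp3 sample referenced above is present:

    from utils import try_whisper_model

    # full transcription of the sample in its spoken language
    try_whisper_model('base', 'transcribe')

    # lower-level decode path: first 30 seconds only, translated to English
    try_whisper_model('base', 'decode')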
whisperui.py ADDED
@@ -0,0 +1,216 @@
+ import whisper
+ import gradio as gr
+ import os
+ from pytube import YouTube
+
+
+ class WhisperModelUI(object):
+     def __init__(self, ui_obj):
+         self.name = "Whisper Model Processor UI"
+         self.description = "This class is designed to build the UI for our Whisper model"
+         self.ui_obj = ui_obj
+         self.audio_files_list = ['No content']
+         # placeholder until load_whisper_model() swaps in a loaded model
+         self.whisper_model = whisper.model.Whisper
+         self.video_store_path = 'data_files'
+
+     def load_content(self, file_list):
+         video_out_path = os.path.join(os.getcwd(), self.video_store_path)
+
+         # list the stored .mp4/.mp3 files so the dropdown can offer them
+         self.audio_files_list = [f for f in os.listdir(video_out_path)
+                                  if os.path.isfile(os.path.join(video_out_path, f))
+                                  and (f.endswith(".mp4") or f.endswith(".mp3"))]
+
+         return gr.Dropdown.update(choices=self.audio_files_list)
+
+     def load_whisper_model(self, model_type):
+         try:
+             asr_model = whisper.load_model(model_type.lower())
+             self.whisper_model = asr_model
+             status = "{} model was loaded successfully".format(model_type)
+         except Exception:
+             status = "Error loading {} model".format(model_type)
+
+         return status, str(self.whisper_model)
+
+     def load_youtube_video(self, video_url):
+         video_out_path = os.path.join(os.getcwd(), self.video_store_path)
+         yt = YouTube(video_url)
+         # grab the highest-resolution progressive (audio+video) mp4 stream
+         local_video_path = yt.streams.filter(progressive=True, file_extension='mp4').order_by(
+             'resolution').desc().first().download(video_out_path)
+         return local_video_path
+
+     def get_video_to_text(self,
+                           transcribe_or_decode,
+                           video_list_dropdown_file_name,
+                           language_detect,
+                           translate_or_transcribe
+                           ):
+         debug_text = ""
+         try:
+             video_out_path = os.path.join(os.getcwd(), self.video_store_path)
+             video_full_path = os.path.join(video_out_path, video_list_dropdown_file_name)
+             if not os.path.isfile(video_full_path):
+                 video_text = "The selected video/audio could not be located.."
+             else:
+                 video_text = "Bad choice or result.."
+                 if transcribe_or_decode == 'Transcribe':
+                     video_text, debug_text = self.run_asr_with_transcribe(video_full_path, language_detect,
+                                                                           translate_or_transcribe)
+                 elif transcribe_or_decode == 'Decode':
+                     audio = whisper.load_audio(video_full_path)
+                     video_text, debug_text = self.run_asr_with_decode(audio, language_detect,
+                                                                       translate_or_transcribe)
+         except Exception:
+             video_text = "Error processing audio..."
+         return video_text, debug_text
+
+     def run_asr_with_decode(self, audio, language_detect, translate_or_transcribe):
+         debug_info = "None.."
+
+         if 'encoder' not in dir(self.whisper_model) or 'decoder' not in dir(self.whisper_model):
+             return "Model is not loaded, please load the model first", debug_info
+
+         if self.whisper_model.encoder is None or self.whisper_model.decoder is None:
+             return "Model is not loaded, please load the model first", debug_info
+
+         try:
+             # pad/trim it to fit 30 seconds
+             audio = whisper.pad_or_trim(audio)
+
+             # make log-Mel spectrogram and move to the same device as the model
+             mel = whisper.log_mel_spectrogram(audio).to(self.whisper_model.device)
+
+             if language_detect == 'Detect':
+                 # detect the spoken language
+                 _, probs = self.whisper_model.detect_language(mel)
+                 # print(f"Detected language: {max(probs, key=probs.get)}")
+
+             # decode the audio
+             # mps crashes if fp16=False is not used
+             task_type = 'transcribe'
+             if translate_or_transcribe == 'Translate':
+                 task_type = 'translate'
+
+             if language_detect != 'Detect':
+                 options = whisper.DecodingOptions(fp16=False,
+                                                   language=language_detect,
+                                                   task=task_type)
+             else:
+                 options = whisper.DecodingOptions(fp16=False,
+                                                   task=task_type)
+
+             result = whisper.decode(self.whisper_model, mel, options)
+             result_text = result.text
+             debug_info = str(result)
+         except Exception:
+             result_text = "Error converting audio to text.."
+         return result_text, debug_info
+
+     def run_asr_with_transcribe(self, audio_path, language_detect, translate_or_transcribe):
+         result_text = "Error..."
+         debug_info = "None.."
+
+         if 'encoder' not in dir(self.whisper_model) or 'decoder' not in dir(self.whisper_model):
+             return "Model is not loaded, please load the model first", debug_info
+
+         if self.whisper_model.encoder is None or self.whisper_model.decoder is None:
+             return "Model is not loaded, please load the model first", debug_info
+
+         task_type = 'transcribe'
+         if translate_or_transcribe == 'Translate':
+             task_type = 'translate'
+
+         transcribe_options = dict(beam_size=5, best_of=5,
+                                   fp16=False,
+                                   task=task_type,
+                                   without_timestamps=False)
+         if language_detect != 'Detect':
+             transcribe_options['language'] = language_detect
+
+         transcription = self.whisper_model.transcribe(audio_path, **transcribe_options)
+         if transcription is not None:
+             result_text = transcription['text']
+             debug_info = str(transcription)
+         return result_text, debug_info
+
+     def create_whisper_ui(self):
+         with self.ui_obj:
+             gr.Markdown("AI Translation & Transcription")
+             with gr.Tabs():
+                 with gr.TabItem("From a YouTube URL"):
+                     with gr.Row():
+                         with gr.Column():
+                             asr_model_type = gr.Radio(['Tiny', 'Base', 'Small', 'Medium', 'Large'],
+                                                       label="Model type (accuracy)",
+                                                       value='Base'
+                                                       )
+                             model_status_lbl = gr.Label(label="Loading status")
+                             load_model_btn = gr.Button("Load model")
+                             youtube_url = gr.Textbox(label="YouTube URL",
+                                                      # e.g. "https://www.youtube.com/watch?v=Y2nHd7El8iw"
+                                                      value=""
+                                                      )
+                             youtube_video = gr.Video(label="Video")
+                             get_video_btn = gr.Button("Load YouTube URL")
+                         with gr.Column():
+                             video_list_dropdown = gr.Dropdown(self.audio_files_list, label="Saved videos")
+                             load_video_list_btn = gr.Button("Load all videos")
+                             transcribe_or_decode = gr.Radio(['Transcribe', 'Decode'],
+                                                             label="Option (Transcribe = full transcription)",
+                                                             value='Transcribe'
+                                                             )
+                             language_detect = gr.Dropdown(['Detect', 'English', 'Hindi', 'Japanese'],
+                                                           label="Auto-detect or choose a language")
+                             translate_or_transcribe = gr.Dropdown(['Transcribe', 'Translate'],
+                                                                   label="Choose Translate or Transcribe")
+                             get_video_txt_btn = gr.Button("Start conversion!")
+                             video_text = gr.Textbox(label="Text", lines=10)
+                 with gr.TabItem("Debug info"):
+                     with gr.Row():
+                         with gr.Column():
+                             debug_text = gr.Textbox(label="Debug Details", lines=20)
+             load_model_btn.click(
+                 self.load_whisper_model,
+                 [
+                     asr_model_type
+                 ],
+                 [
+                     model_status_lbl,
+                     debug_text
+                 ]
+             )
+             get_video_btn.click(
+                 self.load_youtube_video,
+                 [
+                     youtube_url
+                 ],
+                 [
+                     youtube_video
+                 ]
+             )
+             load_video_list_btn.click(
+                 self.load_content,
+                 [
+                     video_list_dropdown
+                 ],
+                 [
+                     video_list_dropdown
+                 ]
+             )
+             get_video_txt_btn.click(
+                 self.get_video_to_text,
+                 [
+                     transcribe_or_decode,
+                     video_list_dropdown,
+                     language_detect,
+                     translate_or_transcribe
+                 ],
+                 [
+                     video_text,
+                     debug_text
+                 ]
+             )
+
+     def launch_ui(self):
+         self.ui_obj.launch(debug=True)
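For completeness, a minimal sketch of driving the transcription path directly, without clicking through the UI; the media file name below is hypothetical:

    import gradio as gr
    from whisperui import WhisperModelUI

    ui = WhisperModelUI(gr.Blocks())
    status, _ = ui.load_whisper_model('Base')  # populates ui.whisper_model
    text, debug = ui.run_asr_with_transcribe(
        'data_files/sample.mp4',               # hypothetical saved video
        'Detect',                              # or 'English' / 'Hindi' / 'Japanese'
        'Transcribe')                          # or 'Translate'
    print(status, text)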