Athspi committed on
Commit ce80eeb · verified
1 Parent(s): 09f2e07

Update app.py

Files changed (1)
  app.py (+18, -32)
app.py CHANGED
@@ -137,14 +137,11 @@ def transcribe_with_whisper(audio_file, language="Auto Detect", model_size="Base (Faster)"):
         result = model.transcribe(processed_audio_path, fp16=False)
         detected_language = result.get("language", "unknown")
     else:
-        language_code = LANGUAGE_NAME_TO_CODE.get(language, "en") # Default to English if not found
+        language_code = LANGUAGE_NAME_TO_CODE.get(language, "en")
         result = model.transcribe(processed_audio_path, language=language_code, fp16=False)
         detected_language = language_code
 
-    # Clean up processed audio file
     os.remove(processed_audio_path)
-
-    # Return transcription and detected language
     return f"Detected Language: {detected_language}\n\nTranscription:\n{result['text']}"
 
 def transcribe_with_sinhala_model(audio_file):
@@ -152,24 +149,18 @@ def transcribe_with_sinhala_model(audio_file):
     processor = AutoProcessor.from_pretrained(SINHALA_MODEL)
     model = AutoModelForCTC.from_pretrained(SINHALA_MODEL)
 
-    # Convert audio to 16kHz mono
     audio = AudioSegment.from_file(audio_file)
     audio = audio.set_frame_rate(16000).set_channels(1)
     processed_audio_path = "processed_audio.wav"
     audio.export(processed_audio_path, format="wav")
 
-    # Load and process audio
     audio_input, _ = torchaudio.load(processed_audio_path)
     input_values = processor(audio_input.squeeze(), return_tensors="pt", sampling_rate=16000).input_values
     logits = model(input_values).logits
     predicted_ids = torch.argmax(logits, dim=-1)
 
-    # Decode prediction
     transcription = processor.batch_decode(predicted_ids)[0]
-
-    # Clean up processed audio file
     os.remove(processed_audio_path)
-
     return f"Transcription:\n{transcription}"
 
 def transcribe_audio(audio_file, language="Auto Detect", model_size="Base (Faster)"):
@@ -179,35 +170,30 @@ def transcribe_audio(audio_file, language="Auto Detect", model_size="Base (Faster)"):
     else:
         return transcribe_with_whisper(audio_file, language, model_size)
 
-# Define the Gradio interface
+# Gradio interface
 with gr.Blocks() as demo:
     gr.Markdown("# Audio Transcription and Language Detection")
 
-    with gr.Tab("Transcribe Audio"):
-        gr.Markdown("Upload an audio file, select a language (or choose 'Auto Detect'), and choose a model for transcription.")
-        transcribe_audio_input = gr.Audio(type="filepath", label="Upload Audio File")
-        language_dropdown = gr.Dropdown(
-            choices=list(LANGUAGE_NAME_TO_CODE.keys()),
-            label="Select Language",
-            value="Auto Detect"
-        )
-        model_dropdown = gr.Dropdown(
-            choices=list(MODELS.keys()),
-            label="Select Model",
-            value="Base (Faster)"
-        )
-        transcribe_output = gr.Textbox(label="Transcription and Detected Language")
-        transcribe_button = gr.Button("Transcribe Audio")
+    transcribe_audio_input = gr.Audio(type="filepath", label="Upload Audio File")
+    language_dropdown = gr.Dropdown(
+        choices=list(LANGUAGE_NAME_TO_CODE.keys()),
+        label="Select Language",
+        value="Auto Detect"
+    )
+    model_dropdown = gr.Dropdown(
+        choices=list(MODELS.keys()),
+        label="Select Whisper Model",
+        value="Base (Faster)"
+    )
+    transcribe_output = gr.Textbox(label="Transcription")
+    transcribe_button = gr.Button("Transcribe Audio")
 
-    # Update model dropdown based on language selection
     def update_model_dropdown(language):
         if language == "Sinhala":
-            return gr.Dropdown(interactive=False, value="Fine-Tuned Sinhala Model")
-        else:
-            return gr.Dropdown(choices=list(MODELS.keys()), interactive=True, value="Base (Faster)")
-
+            return gr.update(interactive=False, value="Base (Faster)") # Disable dropdown
+        return gr.update(interactive=True, value="Base (Faster)")
+
     language_dropdown.change(update_model_dropdown, inputs=language_dropdown, outputs=model_dropdown)
     transcribe_button.click(transcribe_audio, inputs=[transcribe_audio_input, language_dropdown, model_dropdown], outputs=transcribe_output)
 
-# Launch the Gradio interface
 demo.launch()
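
The updated update_model_dropdown callback returns gr.update(...) to toggle the existing model dropdown in place instead of constructing a fresh gr.Dropdown on every change event. A minimal, self-contained sketch of that pattern is below; the choice lists are illustrative placeholders, since app.py builds them from LANGUAGE_NAME_TO_CODE and MODELS, which are defined outside this diff.

    import gradio as gr

    # Illustrative placeholders; app.py derives these from LANGUAGE_NAME_TO_CODE and MODELS.
    LANGUAGE_CHOICES = ["Auto Detect", "English", "Sinhala"]
    MODEL_CHOICES = ["Base (Faster)", "Large (Accurate)"]

    def update_model_dropdown(language):
        # gr.update() changes properties of the existing component rather than
        # replacing it with a newly constructed gr.Dropdown.
        if language == "Sinhala":
            return gr.update(interactive=False, value="Base (Faster)")
        return gr.update(interactive=True, value="Base (Faster)")

    with gr.Blocks() as demo:
        language_dropdown = gr.Dropdown(choices=LANGUAGE_CHOICES, value="Auto Detect", label="Select Language")
        model_dropdown = gr.Dropdown(choices=MODEL_CHOICES, value="Base (Faster)", label="Select Whisper Model")
        language_dropdown.change(update_model_dropdown, inputs=language_dropdown, outputs=model_dropdown)

    demo.launch()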