Athspi committed (verified)
Commit 5a84705 · Parent(s): ce80eeb

Update app.py

Files changed (1)
  1. app.py +83 -58
app.py CHANGED
@@ -2,9 +2,7 @@ import gradio as gr
 import whisper
 import os
 from pydub import AudioSegment
-from transformers import AutoProcessor, AutoModelForCTC
-import torchaudio
-import torch
+from transformers import WhisperForConditionalGeneration, WhisperProcessor
 
 # Mapping of model names to Whisper model sizes
 MODELS = {
@@ -12,12 +10,11 @@ MODELS = {
     "Base (Faster)": "base",
     "Small (Balanced)": "small",
     "Medium (Accurate)": "medium",
-    "Large (Most Accurate)": "large"
+    "Large (Most Accurate)": "large",
+    "Fine-Tuned Hindi": "yash-04/whisper-base-hindi",  # Hindi fine-tuned model
+    "Fine-Tuned Tamil": "mahimairaja/whisper-base-tamil"  # Tamil fine-tuned model
 }
 
-# Fine-tuned Sinhala model (using Hugging Face Transformers)
-SINHALA_MODEL = "IAmNotAnanth/wav2vec2-large-xls-r-300m-sinhala"
-
 # Mapping of full language names to language codes
 LANGUAGE_NAME_TO_CODE = {
     "Auto Detect": "Auto Detect",
@@ -122,78 +119,106 @@ LANGUAGE_NAME_TO_CODE = {
     "Sundanese": "su",
 }
 
-def transcribe_with_whisper(audio_file, language="Auto Detect", model_size="Base (Faster)"):
-    """Transcribe using OpenAI's Whisper models."""
-    model = whisper.load_model(MODELS[model_size])
+def detect_language(audio_file):
+    """Detect the language of the audio file."""
+    # Load the Whisper model (use "base" for faster detection)
+    model = whisper.load_model("base")
 
-    # Convert audio to 16kHz mono for compatibility with Whisper
+    # Convert audio to 16kHz mono for better compatibility with Whisper
     audio = AudioSegment.from_file(audio_file)
     audio = audio.set_frame_rate(16000).set_channels(1)
     processed_audio_path = "processed_audio.wav"
    audio.export(processed_audio_path, format="wav")
 
-    # Transcribe the audio
-    if language == "Auto Detect":
-        result = model.transcribe(processed_audio_path, fp16=False)
-        detected_language = result.get("language", "unknown")
-    else:
-        language_code = LANGUAGE_NAME_TO_CODE.get(language, "en")
-        result = model.transcribe(processed_audio_path, language=language_code, fp16=False)
-        detected_language = language_code
-
+    # Detect the language
+    result = model.transcribe(processed_audio_path, task="detect_language", fp16=False)
+    detected_language = result.get("language", "unknown")
+
+    # Clean up processed audio file
     os.remove(processed_audio_path)
-    return f"Detected Language: {detected_language}\n\nTranscription:\n{result['text']}"
+
+    return f"Detected Language: {detected_language}"
 
-def transcribe_with_sinhala_model(audio_file):
-    """Transcribe using the fine-tuned Sinhala Wav2Vec2 model."""
-    processor = AutoProcessor.from_pretrained(SINHALA_MODEL)
-    model = AutoModelForCTC.from_pretrained(SINHALA_MODEL)
+def transcribe_audio(audio_file, language="Auto Detect", model_size="Base (Faster)"):
+    """Transcribe the audio file."""
+    # Map language to fine-tuned model
+    language_to_model = {
+        "Hindi": "yash-04/whisper-base-hindi",
+        "Tamil": "mahimairaja/whisper-base-tamil",
+        # Add more mappings as needed
+    }
 
+    # Load the selected Whisper model
+    if language in language_to_model:
+        model_name = language_to_model[language]
+        model = WhisperForConditionalGeneration.from_pretrained(model_name)
+        processor = WhisperProcessor.from_pretrained(model_name)
+    else:
+        model = whisper.load_model(MODELS[model_size])
+        processor = None  # Use default Whisper processor
+
+    # Convert audio to 16kHz mono for better compatibility with Whisper
     audio = AudioSegment.from_file(audio_file)
     audio = audio.set_frame_rate(16000).set_channels(1)
     processed_audio_path = "processed_audio.wav"
     audio.export(processed_audio_path, format="wav")
 
-    audio_input, _ = torchaudio.load(processed_audio_path)
-    input_values = processor(audio_input.squeeze(), return_tensors="pt", sampling_rate=16000).input_values
-    logits = model(input_values).logits
-    predicted_ids = torch.argmax(logits, dim=-1)
-
-    transcription = processor.batch_decode(predicted_ids)[0]
+    # Transcribe the audio
+    if language == "Auto Detect":
+        if processor:
+            inputs = processor(processed_audio_path, return_tensors="pt", sampling_rate=16000)
+            result = model.generate(inputs.input_features)
+            transcription = processor.batch_decode(result, skip_special_tokens=True)[0]
+        else:
+            result = model.transcribe(processed_audio_path, fp16=False)
+            transcription = result["text"]
+        detected_language = result.get("language", "unknown")
+    else:
+        language_code = LANGUAGE_NAME_TO_CODE.get(language, "en")  # Default to English if not found
+        if processor:
+            inputs = processor(processed_audio_path, return_tensors="pt", sampling_rate=16000)
+            result = model.generate(inputs.input_features, language=language_code)
+            transcription = processor.batch_decode(result, skip_special_tokens=True)[0]
+        else:
+            result = model.transcribe(processed_audio_path, language=language_code, fp16=False)
+            transcription = result["text"]
+        detected_language = language_code
+
+    # Clean up processed audio file
     os.remove(processed_audio_path)
-    return f"Transcription:\n{transcription}"
 
-def transcribe_audio(audio_file, language="Auto Detect", model_size="Base (Faster)"):
-    """Wrapper to select the correct transcription method."""
-    if language == "Sinhala":
-        return transcribe_with_sinhala_model(audio_file)
-    else:
-        return transcribe_with_whisper(audio_file, language, model_size)
+    # Return transcription and detected language
+    return f"Detected Language: {detected_language}\n\nTranscription:\n{transcription}"
 
-# Gradio interface
+# Define the Gradio interface
 with gr.Blocks() as demo:
     gr.Markdown("# Audio Transcription and Language Detection")
 
-    transcribe_audio_input = gr.Audio(type="filepath", label="Upload Audio File")
-    language_dropdown = gr.Dropdown(
-        choices=list(LANGUAGE_NAME_TO_CODE.keys()),
-        label="Select Language",
-        value="Auto Detect"
-    )
-    model_dropdown = gr.Dropdown(
-        choices=list(MODELS.keys()),
-        label="Select Whisper Model",
-        value="Base (Faster)"
-    )
-    transcribe_output = gr.Textbox(label="Transcription")
-    transcribe_button = gr.Button("Transcribe Audio")
-
-    def update_model_dropdown(language):
-        if language == "Sinhala":
-            return gr.update(interactive=False, value="Base (Faster)")  # Disable dropdown
-        return gr.update(interactive=True, value="Base (Faster)")
-
-    language_dropdown.change(update_model_dropdown, inputs=language_dropdown, outputs=model_dropdown)
+    with gr.Tab("Detect Language"):
+        gr.Markdown("Upload an audio file to detect its language.")
+        detect_audio_input = gr.Audio(type="filepath", label="Upload Audio File")
+        detect_language_output = gr.Textbox(label="Detected Language")
+        detect_button = gr.Button("Detect Language")
+
+    with gr.Tab("Transcribe Audio"):
+        gr.Markdown("Upload an audio file, select a language (or choose 'Auto Detect'), and choose a model for transcription.")
+        transcribe_audio_input = gr.Audio(type="filepath", label="Upload Audio File")
+        language_dropdown = gr.Dropdown(
+            choices=list(LANGUAGE_NAME_TO_CODE.keys()),  # Full language names
+            label="Select Language",
+            value="Auto Detect"
+        )
+        model_dropdown = gr.Dropdown(
+            choices=list(MODELS.keys()),  # Model options
+            label="Select Model",
+            value="Base (Faster)"  # Default to "Base" model
        )
+        transcribe_output = gr.Textbox(label="Transcription and Detected Language")
+        transcribe_button = gr.Button("Transcribe Audio")
+
+    # Link buttons to functions
+    detect_button.click(detect_language, inputs=detect_audio_input, outputs=detect_language_output)
     transcribe_button.click(transcribe_audio, inputs=[transcribe_audio_input, language_dropdown, model_dropdown], outputs=transcribe_output)
 
-demo.launch()
+# Launch the Gradio interface
+demo.launch()
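
A note on the new detect_language helper: the openai-whisper transcribe() call only recognizes task="transcribe" and task="translate", so the task="detect_language" argument in the diff does not trigger a detection-only path; the library exposes detection through model.detect_language() instead. A minimal sketch of that route (not part of this commit), reusing the same 16 kHz mono file the helper already writes to processed_audio.wav:

import whisper

# Load the small/fast model used for detection in the commit
model = whisper.load_model("base")

# Load and trim/pad the audio to the 30-second window Whisper expects
audio = whisper.load_audio("processed_audio.wav")
audio = whisper.pad_or_trim(audio)
mel = whisper.log_mel_spectrogram(audio).to(model.device)

# detect_language returns (language_token, {language_code: probability})
_, probs = model.detect_language(mel)
detected_language = max(probs, key=probs.get)
print(f"Detected Language: {detected_language}")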
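Similarly, in the fine-tuned branch of transcribe_audio the WhisperProcessor is called with a file path, but its feature extractor expects a raw waveform array sampled at 16 kHz. A sketch of the Transformers route, assuming the "yash-04/whisper-base-hindi" checkpoint from the diff ships processor files and using pydub (already imported in app.py) to decode the audio; the forced_decoder_ids call is one common way to pin the language for a fine-tuned checkpoint, newer transformers versions also accept language/task arguments on generate():

import numpy as np
from pydub import AudioSegment
from transformers import WhisperForConditionalGeneration, WhisperProcessor

# Checkpoint name taken from the diff; assumed to include tokenizer/feature-extractor files
model_name = "yash-04/whisper-base-hindi"
processor = WhisperProcessor.from_pretrained(model_name)
model = WhisperForConditionalGeneration.from_pretrained(model_name)

# Decode to a 16 kHz mono float32 waveform instead of passing a file path
seg = AudioSegment.from_file("processed_audio.wav")
seg = seg.set_frame_rate(16000).set_channels(1).set_sample_width(2)
speech = np.array(seg.get_array_of_samples(), dtype=np.float32) / 32768.0

inputs = processor(speech, sampling_rate=16000, return_tensors="pt")

# Force Hindi transcription rather than letting the decoder guess the language
forced_ids = processor.get_decoder_prompt_ids(language="hindi", task="transcribe")
generated = model.generate(inputs.input_features, forced_decoder_ids=forced_ids)
transcription = processor.batch_decode(generated, skip_special_tokens=True)[0]
print(transcription)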