Athspi commited on
Commit
7d07125
·
verified ·
1 Parent(s): 5989272

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -61
app.py CHANGED
@@ -1,11 +1,9 @@
1
  import gradio as gr
2
  import whisper
 
3
  import os
4
  from pydub import AudioSegment
5
- from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
6
- import torch
7
- import librosa
8
- import numpy as np
9
 
10
  # Mapping of model names to Whisper model sizes
11
  MODELS = {
@@ -16,8 +14,14 @@ MODELS = {
16
  "Large (Most Accurate)": "large"
17
  }
18
 
19
- # Fine-tuned Sinhala model
20
- SINHALA_MODEL = "Subhaka/whisper-small-Sinhala-Fine_Tune"
 
 
 
 
 
 
21
 
22
  # Mapping of full language names to language codes
23
  LANGUAGE_NAME_TO_CODE = {
@@ -123,68 +127,41 @@ LANGUAGE_NAME_TO_CODE = {
123
  "Sundanese": "su",
124
  }
125
 
126
- # Preload the fine-tuned Sinhala model and processor
127
- processor = AutoProcessor.from_pretrained(SINHALA_MODEL)
128
- sinhala_model = AutoModelForSpeechSeq2Seq.from_pretrained(SINHALA_MODEL)
129
-
130
- # Move model to GPU if available
131
- device = "cuda" if torch.cuda.is_available() else "cpu"
132
- sinhala_model.to(device)
133
-
134
  def transcribe_audio(audio_file, language="Auto Detect", model_size="Base (Faster)"):
135
  """Transcribe the audio file."""
136
- # Load the appropriate model
137
- if language == "Sinhala":
138
- # Use the fine-tuned Sinhala model
139
- model = sinhala_model
140
- model_processor = processor
141
- else:
142
- # Use the selected Whisper model
143
- model = whisper.load_model(MODELS[model_size])
144
- model_processor = None
145
-
146
- # Convert audio to 16kHz mono for better compatibility with Whisper
147
  audio = AudioSegment.from_file(audio_file)
148
  audio = audio.set_frame_rate(16000).set_channels(1)
149
  processed_audio_path = "processed_audio.wav"
150
  audio.export(processed_audio_path, format="wav")
151
 
152
- # Transcribe the audio
153
- if language == "Auto Detect":
154
- if model_processor:
155
- # Load the audio as a NumPy array
156
- raw_audio, _ = librosa.load(processed_audio_path, sr=16000)
157
- raw_audio = np.array(raw_audio, dtype=np.float32)
158
-
159
- # Process the audio and generate transcription
160
- inputs = model_processor(raw_audio, return_tensors="pt", sampling_rate=16000).input_features.to(device)
161
- with torch.no_grad():
162
- generated_ids = model.generate(inputs)
163
- transcription = model_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
164
- detected_language = "si"
165
- else:
166
- # Use Whisper for auto-detection
167
- result = model.transcribe(processed_audio_path, fp16=(device == "cuda"))
168
- transcription = result["text"]
169
- detected_language = result.get("language", "unknown")
170
  else:
171
- if model_processor:
172
- # Load the audio as a NumPy array
173
- raw_audio, _ = librosa.load(processed_audio_path, sr=16000)
174
- raw_audio = np.array(raw_audio, dtype=np.float32)
175
-
176
- # Process the audio and generate transcription
177
- inputs = model_processor(raw_audio, return_tensors="pt", sampling_rate=16000).input_features.to(device)
178
- with torch.no_grad():
179
- generated_ids = model.generate(inputs)
180
- transcription = model_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
181
- detected_language = "si"
182
  else:
183
- # Use Whisper for transcription with the selected language
184
- language_code = LANGUAGE_NAME_TO_CODE.get(language, "en")
185
- result = model.transcribe(processed_audio_path, language=language_code, fp16=(device == "cuda"))
186
- transcription = result["text"]
187
  detected_language = language_code
 
 
188
 
189
  # Clean up processed audio file
190
  os.remove(processed_audio_path)
@@ -194,7 +171,7 @@ def transcribe_audio(audio_file, language="Auto Detect", model_size="Base (Faste
194
 
195
  # Define the Gradio interface
196
  with gr.Blocks() as demo:
197
- gr.Markdown("# Audio Transcription and Language Detection")
198
 
199
  with gr.Tab("Transcribe Audio"):
200
  gr.Markdown("Upload an audio file, select a language (or choose 'Auto Detect'), and choose a model for transcription.")
@@ -215,8 +192,8 @@ with gr.Blocks() as demo:
215
 
216
  # Update model dropdown based on language selection
217
  def update_model_dropdown(language):
218
- if language == "Sinhala":
219
- return gr.Dropdown(interactive=False, value="Base (Faster)") # Set a valid value
220
  else:
221
  return gr.Dropdown(choices=list(MODELS.keys()), interactive=True, value="Base (Faster)")
222
 
 
1
  import gradio as gr
2
  import whisper
3
+ import torch
4
  import os
5
  from pydub import AudioSegment
6
+ from transformers import AutoProcessor, AutoModelForCTC
 
 
 
7
 
8
  # Mapping of model names to Whisper model sizes
9
  MODELS = {
 
14
  "Large (Most Accurate)": "large"
15
  }
16
 
17
+ # Fine-tuned Wav2Vec2 models for specific languages
18
+ WAV2VEC2_MODELS = {
19
+ "Tamil": {
20
+ "processor": "Amrrs/wav2vec2-large-xlsr-53-tamil",
21
+ "model": "Amrrs/wav2vec2-large-xlsr-53-tamil"
22
+ },
23
+ # Add more Wav2Vec2 models for other languages here
24
+ }
25
 
26
  # Mapping of full language names to language codes
27
  LANGUAGE_NAME_TO_CODE = {
 
127
  "Sundanese": "su",
128
  }
129
 
 
 
 
 
 
 
 
 
130
  def transcribe_audio(audio_file, language="Auto Detect", model_size="Base (Faster)"):
131
  """Transcribe the audio file."""
132
+ # Convert audio to 16kHz mono for better compatibility
 
 
 
 
 
 
 
 
 
 
133
  audio = AudioSegment.from_file(audio_file)
134
  audio = audio.set_frame_rate(16000).set_channels(1)
135
  processed_audio_path = "processed_audio.wav"
136
  audio.export(processed_audio_path, format="wav")
137
 
138
+ # Load the appropriate model
139
+ if language in WAV2VEC2_MODELS:
140
+ # Use the fine-tuned Wav2Vec2 model for the selected language
141
+ processor = AutoProcessor.from_pretrained(WAV2VEC2_MODELS[language]["processor"])
142
+ model = AutoModelForCTC.from_pretrained(WAV2VEC2_MODELS[language]["model"])
143
+
144
+ # Load audio and process
145
+ inputs = processor(AudioSegment.from_file(processed_audio_path).raw_data, sampling_rate=16000, return_tensors="pt")
146
+ with torch.no_grad():
147
+ logits = model(inputs.input_values).logits
148
+ predicted_ids = torch.argmax(logits, dim=-1)
149
+ transcription = processor.decode(predicted_ids[0])
150
+ detected_language = language
 
 
 
 
 
151
  else:
152
+ # Use the selected Whisper model
153
+ model = whisper.load_model(MODELS[model_size])
154
+
155
+ # Transcribe the audio
156
+ if language == "Auto Detect":
157
+ result = model.transcribe(processed_audio_path, fp16=False) # Auto-detect language
158
+ detected_language = result.get("language", "unknown")
 
 
 
 
159
  else:
160
+ language_code = LANGUAGE_NAME_TO_CODE.get(language, "en") # Default to English if not found
161
+ result = model.transcribe(processed_audio_path, language=language_code, fp16=False)
 
 
162
  detected_language = language_code
163
+
164
+ transcription = result["text"]
165
 
166
  # Clean up processed audio file
167
  os.remove(processed_audio_path)
 
171
 
172
  # Define the Gradio interface
173
  with gr.Blocks() as demo:
174
+ gr.Markdown("# Audio Transcription with Fine-Tuned Models")
175
 
176
  with gr.Tab("Transcribe Audio"):
177
  gr.Markdown("Upload an audio file, select a language (or choose 'Auto Detect'), and choose a model for transcription.")
 
192
 
193
  # Update model dropdown based on language selection
194
  def update_model_dropdown(language):
195
+ if language in WAV2VEC2_MODELS:
196
+ return gr.Dropdown(interactive=False, value=f"Fine-Tuned {language} Model")
197
  else:
198
  return gr.Dropdown(choices=list(MODELS.keys()), interactive=True, value="Base (Faster)")
199