Athspi committed
Commit d600bb8 · verified · Parent: 0f8086c

Update app.py

Files changed (1)
  1. app.py +36 -8
app.py CHANGED
@@ -2,6 +2,8 @@ import gradio as gr
 import whisper
 import os
 from pydub import AudioSegment
+from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
+import torch
 
 # Mapping of model names to Whisper model sizes
 MODELS = {
@@ -13,7 +15,7 @@ MODELS = {
 }
 
 # Fine-tuned Sinhala model
-SINHALA_MODEL = "malakazzz/Subhaka-whisper-small-Sinhala-Fine_Tune"
+SINHALA_MODEL = "Subhaka/whisper-small-Sinhala-Fine_Tune"
 
 # Mapping of full language names to language codes
 LANGUAGE_NAME_TO_CODE = {
@@ -119,15 +121,21 @@ LANGUAGE_NAME_TO_CODE = {
     "Sundanese": "su",
 }
 
+# Load the fine-tuned Sinhala model and processor
+processor = AutoProcessor.from_pretrained(SINHALA_MODEL)
+sinhala_model = AutoModelForSpeechSeq2Seq.from_pretrained(SINHALA_MODEL)
+
 def transcribe_audio(audio_file, language="Auto Detect", model_size="Base (Faster)"):
     """Transcribe the audio file."""
     # Load the appropriate model
     if language == "Sinhala":
         # Use the fine-tuned Sinhala model
-        model = gr.load(SINHALA_MODEL)
+        model = sinhala_model
+        model_processor = processor
     else:
         # Use the selected Whisper model
         model = whisper.load_model(MODELS[model_size])
+        model_processor = None
 
     # Convert audio to 16kHz mono for better compatibility with Whisper
     audio = AudioSegment.from_file(audio_file)
@@ -137,18 +145,38 @@ def transcribe_audio(audio_file, language="Auto Detect", model_size="Base (Faste
 
     # Transcribe the audio
     if language == "Auto Detect":
-        result = model.transcribe(processed_audio_path, fp16=False)  # Auto-detect language
-        detected_language = result.get("language", "unknown")
+        if model_processor:
+            # Use the fine-tuned Sinhala model for transcription
+            inputs = model_processor(processed_audio_path, return_tensors="pt", sampling_rate=16000)
+            with torch.no_grad():
+                generated_ids = model.generate(inputs.input_features)
+            transcription = model_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+            detected_language = "si"
+        else:
+            # Use Whisper for auto-detection
+            result = model.transcribe(processed_audio_path, fp16=False)
+            transcription = result["text"]
+            detected_language = result.get("language", "unknown")
     else:
-        language_code = LANGUAGE_NAME_TO_CODE.get(language, "en")  # Default to English if not found
-        result = model.transcribe(processed_audio_path, language=language_code, fp16=False)
-        detected_language = language_code
+        if model_processor:
+            # Use the fine-tuned Sinhala model for transcription
+            inputs = model_processor(processed_audio_path, return_tensors="pt", sampling_rate=16000)
+            with torch.no_grad():
+                generated_ids = model.generate(inputs.input_features)
+            transcription = model_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+            detected_language = "si"
+        else:
+            # Use Whisper for transcription with the selected language
+            language_code = LANGUAGE_NAME_TO_CODE.get(language, "en")
+            result = model.transcribe(processed_audio_path, language=language_code, fp16=False)
+            transcription = result["text"]
+            detected_language = language_code
 
     # Clean up processed audio file
     os.remove(processed_audio_path)
 
     # Return transcription and detected language
-    return f"Detected Language: {detected_language}\n\nTranscription:\n{result['text']}"
+    return f"Detected Language: {detected_language}\n\nTranscription:\n{transcription}"
 
 # Define the Gradio interface
 with gr.Blocks() as demo:
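One caveat with the new Sinhala branch as committed: the transformers Whisper processor expects raw audio samples at 16 kHz, not a file path, so passing processed_audio_path as a string would fail at runtime. Below is a minimal sketch of the adjustment, reusing the 16 kHz mono export the function already produces with pydub; the normalization constant assumes 16-bit samples, and the variable names (processed_audio_path, processor, sinhala_model) are taken from the committed code.

import numpy as np
import torch
from pydub import AudioSegment

# Load the 16 kHz mono file the function already writes, as float32 in [-1, 1]
audio = AudioSegment.from_file(processed_audio_path)
samples = np.array(audio.get_array_of_samples()).astype(np.float32) / 32768.0

# Feed the samples (not the path) to the processor, then generate as before
inputs = processor(samples, sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    generated_ids = sinhala_model.generate(inputs.input_features)
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

Loading the model and processor once at module scope, as the commit does, is a reasonable design choice: it avoids re-fetching the weights on every transcription request.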