Athspi committed on
Commit
5989272
·
verified ·
1 Parent(s): 3e73331

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -7
app.py CHANGED
@@ -123,10 +123,14 @@ LANGUAGE_NAME_TO_CODE = {
123
  "Sundanese": "su",
124
  }
125
 
126
- # Load the fine-tuned Sinhala model and processor
127
  processor = AutoProcessor.from_pretrained(SINHALA_MODEL)
128
  sinhala_model = AutoModelForSpeechSeq2Seq.from_pretrained(SINHALA_MODEL)
129
 
 
 
 
 
130
  def transcribe_audio(audio_file, language="Auto Detect", model_size="Base (Faster)"):
131
  """Transcribe the audio file."""
132
  # Load the appropriate model
@@ -153,14 +157,14 @@ def transcribe_audio(audio_file, language="Auto Detect", model_size="Base (Faste
153
  raw_audio = np.array(raw_audio, dtype=np.float32)
154
 
155
  # Process the audio and generate transcription
156
- inputs = model_processor(raw_audio, return_tensors="pt", sampling_rate=16000)
157
  with torch.no_grad():
158
- generated_ids = model.generate(inputs.input_features)
159
  transcription = model_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
160
  detected_language = "si"
161
  else:
162
  # Use Whisper for auto-detection
163
- result = model.transcribe(processed_audio_path, fp16=False)
164
  transcription = result["text"]
165
  detected_language = result.get("language", "unknown")
166
  else:
@@ -170,15 +174,15 @@ def transcribe_audio(audio_file, language="Auto Detect", model_size="Base (Faste
170
  raw_audio = np.array(raw_audio, dtype=np.float32)
171
 
172
  # Process the audio and generate transcription
173
- inputs = model_processor(raw_audio, return_tensors="pt", sampling_rate=16000)
174
  with torch.no_grad():
175
- generated_ids = model.generate(inputs.input_features)
176
  transcription = model_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
177
  detected_language = "si"
178
  else:
179
  # Use Whisper for transcription with the selected language
180
  language_code = LANGUAGE_NAME_TO_CODE.get(language, "en")
181
- result = model.transcribe(processed_audio_path, language=language_code, fp16=False)
182
  transcription = result["text"]
183
  detected_language = language_code
184
 
 
123
  "Sundanese": "su",
124
  }
125
 
126
+ # Preload the fine-tuned Sinhala model and processor
127
  processor = AutoProcessor.from_pretrained(SINHALA_MODEL)
128
  sinhala_model = AutoModelForSpeechSeq2Seq.from_pretrained(SINHALA_MODEL)
129
 
130
+ # Move model to GPU if available
131
+ device = "cuda" if torch.cuda.is_available() else "cpu"
132
+ sinhala_model.to(device)
133
+
134
  def transcribe_audio(audio_file, language="Auto Detect", model_size="Base (Faster)"):
135
  """Transcribe the audio file."""
136
  # Load the appropriate model
 
157
  raw_audio = np.array(raw_audio, dtype=np.float32)
158
 
159
  # Process the audio and generate transcription
160
+ inputs = model_processor(raw_audio, return_tensors="pt", sampling_rate=16000).input_features.to(device)
161
  with torch.no_grad():
162
+ generated_ids = model.generate(inputs)
163
  transcription = model_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
164
  detected_language = "si"
165
  else:
166
  # Use Whisper for auto-detection
167
+ result = model.transcribe(processed_audio_path, fp16=(device == "cuda"))
168
  transcription = result["text"]
169
  detected_language = result.get("language", "unknown")
170
  else:
 
174
  raw_audio = np.array(raw_audio, dtype=np.float32)
175
 
176
  # Process the audio and generate transcription
177
+ inputs = model_processor(raw_audio, return_tensors="pt", sampling_rate=16000).input_features.to(device)
178
  with torch.no_grad():
179
+ generated_ids = model.generate(inputs)
180
  transcription = model_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
181
  detected_language = "si"
182
  else:
183
  # Use Whisper for transcription with the selected language
184
  language_code = LANGUAGE_NAME_TO_CODE.get(language, "en")
185
+ result = model.transcribe(processed_audio_path, language=language_code, fp16=(device == "cuda"))
186
  transcription = result["text"]
187
  detected_language = language_code
188