tahirsher committed
Commit 8d55ac9 · verified · 1 Parent(s): 9b5528f

Update app.py

Files changed (1)
  1. app.py +5 -26
app.py CHANGED
@@ -14,24 +14,19 @@ from transformers import (
 )
 
 # ================================
-# 1️⃣ Authenticate with Hugging Face Hub
+# 1️⃣ Authenticate with Hugging Face Hub (Securely)
 # ================================
-
-# Get HF token securely from environment variables
-HF_TOKEN = os.getenv("hf_token")
+HF_TOKEN = os.getenv("hf_token")  # Ensure it's set in Hugging Face Spaces Secrets
 
 if HF_TOKEN is None:
     raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")
 
-# Login using the stored token
 login(token=HF_TOKEN)
 
 # ================================
 # 2️⃣ Load Model & Processor
 # ================================
 MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"
-
-# Load ASR model and processor
 processor = AutoProcessor.from_pretrained(MODEL_NAME)
 model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)
 
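Note that os.getenv is case-sensitive, so the lookup above only succeeds if the Space secret is saved under the exact name "hf_token". A minimal sketch of a more defensive lookup (the alternate "HF_TOKEN" name is an illustrative assumption, not part of this commit):

import os
from huggingface_hub import login

# Accept either spelling of the secret name before failing loudly.
HF_TOKEN = os.getenv("hf_token") or os.getenv("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("Hugging Face API token not found. Please set it in Secrets.")
login(token=HF_TOKEN)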
@@ -163,9 +158,11 @@ if audio_file:
     # Convert audio to model input
     input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
 
+    # ✅ FIX: Ensure input tensor is correctly formatted
+    input_tensor = input_features.unsqueeze(0).to(device)  # Adds batch dimension
+
     # Perform ASR inference
     with torch.no_grad():
-        input_tensor = torch.tensor([input_features]).to(device)
         logits = model(input_tensor).logits
         predicted_ids = torch.argmax(logits, dim=-1)
         transcription = processor.batch_decode(predicted_ids)[0]
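The fix restores the batch dimension with unsqueeze(0) instead of re-wrapping an existing tensor in torch.tensor([...]), which PyTorch does not reliably accept for multi-element tensors. A minimal sketch of the shape change, using a dummy log-mel tensor (the 80×3000 shape is an assumption for illustration):

import torch

feats = torch.randn(80, 3000)   # one log-mel example: (n_mels, frames)
batched = feats.unsqueeze(0)    # (1, 80, 3000): adds the batch dimension as a view, no copy
print(batched.shape)            # torch.Size([1, 80, 3000])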
@@ -173,21 +170,3 @@ if audio_file:
     # Display transcription
     st.success("📄 Transcription:")
     st.write(transcription)
-
-    # ================================
-    # 8️⃣ Fine-Tune Model with User Correction
-    # ================================
-    user_correction = st.text_area("🔧 Correct the transcription (if needed):", transcription)
-
-    if st.button("Fine-Tune with Correction"):
-        if user_correction:
-            corrected_input = processor.tokenizer(user_correction).input_ids
-
-            # Dynamically add new example to dataset
-            dataset.append({"input_features": input_features, "labels": corrected_input})
-
-            # Perform quick re-training (1 epoch)
-            trainer.args.num_train_epochs = 1
-            trainer.train()
-
-            st.success("✅ Model fine-tuned with new correction! Try another audio file.")
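Since the checkpoint is loaded with AutoModelForSpeechSeq2Seq, an encoder-decoder model, the forward-pass-plus-argmax decoding kept above is usually replaced by autoregressive generation. A minimal sketch of generate-based decoding under that assumption (not part of this commit; it reuses the model, processor, input_tensor, and torch names already defined in app.py):

# Decode autoregressively instead of taking argmax over a single forward pass.
with torch.no_grad():
    predicted_ids = model.generate(input_tensor)
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]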