Update app.py
Browse files
app.py
CHANGED
@@ -14,24 +14,19 @@ from transformers import (
|
|
14 |
)
|
15 |
|
16 |
# ================================
|
17 |
-
# 1️⃣ Authenticate with Hugging Face Hub
|
18 |
# ================================
|
19 |
-
|
20 |
-
# Get HF token securely from environment variables
|
21 |
-
HF_TOKEN = os.getenv("hf_token")
|
22 |
|
23 |
if HF_TOKEN is None:
|
24 |
raise ValueError("β Hugging Face API token not found. Please set it in Secrets.")
|
25 |
|
26 |
-
# Login using the stored token
|
27 |
login(token=HF_TOKEN)
|
28 |
|
29 |
# ================================
|
30 |
# 2️⃣ Load Model & Processor
|
31 |
# ================================
|
32 |
MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"
|
33 |
-
|
34 |
-
# Load ASR model and processor
|
35 |
processor = AutoProcessor.from_pretrained(MODEL_NAME)
|
36 |
model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)
|
37 |
|
@@ -163,9 +158,11 @@ if audio_file:
|
|
163 |
# Convert audio to model input
|
164 |
input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
|
165 |
|
|
|
|
|
|
|
166 |
# Perform ASR inference
|
167 |
with torch.no_grad():
|
168 |
-
input_tensor = torch.tensor([input_features]).to(device)
|
169 |
logits = model(input_tensor).logits
|
170 |
predicted_ids = torch.argmax(logits, dim=-1)
|
171 |
transcription = processor.batch_decode(predicted_ids)[0]
|
@@ -173,21 +170,3 @@ if audio_file:
|
|
173 |
# Display transcription
|
174 |
st.success("π Transcription:")
|
175 |
st.write(transcription)
|
176 |
-
|
177 |
-
# ================================
|
178 |
-
# 8️⃣ Fine-Tune Model with User Correction
|
179 |
-
# ================================
|
180 |
-
user_correction = st.text_area("π§ Correct the transcription (if needed):", transcription)
|
181 |
-
|
182 |
-
if st.button("Fine-Tune with Correction"):
|
183 |
-
if user_correction:
|
184 |
-
corrected_input = processor.tokenizer(user_correction).input_ids
|
185 |
-
|
186 |
-
# Dynamically add new example to dataset
|
187 |
-
dataset.append({"input_features": input_features, "labels": corrected_input})
|
188 |
-
|
189 |
-
# Perform quick re-training (1 epoch)
|
190 |
-
trainer.args.num_train_epochs = 1
|
191 |
-
trainer.train()
|
192 |
-
|
193 |
-
st.success("β
Model fine-tuned with new correction! Try another audio file.")
|
|
|
14 |
)
|
15 |
|
16 |
# ================================
|
17 |
+
# 1️⃣ Authenticate with Hugging Face Hub (Securely)
|
18 |
# ================================
|
19 |
+
HF_TOKEN = os.getenv("hf_token") # Ensure it's set in Hugging Face Spaces Secrets
|
|
|
|
|
20 |
|
21 |
if HF_TOKEN is None:
|
22 |
raise ValueError("β Hugging Face API token not found. Please set it in Secrets.")
|
23 |
|
|
|
24 |
login(token=HF_TOKEN)
|
25 |
|
26 |
# ================================
|
27 |
# 2️⃣ Load Model & Processor
|
28 |
# ================================
|
29 |
MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"
|
|
|
|
|
30 |
processor = AutoProcessor.from_pretrained(MODEL_NAME)
|
31 |
model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)
|
32 |
|
|
|
158 |
# Convert audio to model input
|
159 |
input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
|
160 |
|
161 |
+
# ✅ FIX: Ensure input tensor is correctly formatted
|
162 |
+
input_tensor = input_features.unsqueeze(0).to(device) # Adds batch dimension
|
163 |
+
|
164 |
# Perform ASR inference.
# The checkpoint is loaded with AutoModelForSpeechSeq2Seq, i.e. an
# encoder-decoder model. Such models produce text by autoregressive
# decoding, so we call generate() instead of taking an argmax over the
# logits of a single forward pass (that is the CTC recipe, and a bare
# forward pass here has no decoder inputs to work with).
with torch.no_grad():
    predicted_ids = model.generate(input_tensor)

# Decode the generated token ids to text; special tokens (start/end/
# language/task markers) are dropped so only the transcript is shown.
transcription = processor.batch_decode(
    predicted_ids, skip_special_tokens=True
)[0]
|
|
|
170 |
# Display transcription
|
171 |
st.success("π Transcription:")
|
172 |
st.write(transcription)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|