# Source: Hugging Face file view of app.py (uploader: tahirsher).
# Commit 49df20f (verified): "Update app.py" - file size 5.33 kB.
import os
import tarfile
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import torchaudio
from datasets import load_dataset, DatasetDict
from huggingface_hub import login
from transformers import (
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
)
# ================================
# 1️⃣ Authenticate with Hugging Face Hub (Securely)
# ================================
# The token is read from the "hf_token" environment variable so it never has
# to be written into the source. Surrounding whitespace (e.g. a trailing
# newline from a pasted secret) is stripped before use.
HF_TOKEN = (os.getenv("hf_token") or "").strip()

# Reject an unset variable *and* an empty/blank one: a blank token would pass
# an `is None` check and only fail later, less clearly, inside login().
if not HF_TOKEN:
    raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")

login(token=HF_TOKEN)
# ================================
# 2️⃣ Load Model & Processor
# ================================
MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"

# Streamlit re-executes this script from the top on every widget interaction.
# Without caching, the checkpoint would be re-loaded on each rerun, and any
# weights updated by fine-tuning would be thrown away on the next interaction.
# Fall back gracefully on Streamlit versions that predate `cache_resource`.
_cache_resource = getattr(st, "cache_resource", None) or getattr(
    st, "experimental_singleton", lambda func: func
)


@_cache_resource
def _load_model_and_processor(model_name):
    """Load the processor and model for *model_name* and place the model on a device.

    Args:
        model_name: Hub repository id passed to ``from_pretrained``.

    Returns:
        A ``(processor, model, device)`` tuple, where ``device`` is ``"cuda"``
        when ``torch.cuda.is_available()`` and ``"cpu"`` otherwise.

    NOTE(review): when the cache decorator is active, the returned objects are
    shared by every rerun/session, so in-place fine-tuning of the model is
    visible to all of them - confirm this is the intended behaviour.
    """
    loaded_processor = AutoProcessor.from_pretrained(model_name)
    loaded_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name)
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    loaded_model.to(target_device)
    print(f"✅ Model loaded on {target_device}")
    return loaded_processor, loaded_model, target_device


# Module-level names kept as before so the rest of the script is unaffected.
processor, model, device = _load_model_and_processor(MODEL_NAME)
# ================================
# 3️⃣ Load and Prepare Dataset
# ================================
DATASET_TAR_PATH = "dev-clean.tar.gz"
EXTRACT_PATH = "./librispeech_dev_clean"


def extract_archive(archive_path, dest_dir):
    """Unpack the gzip-compressed tar file *archive_path* into *dest_dir*.

    Members are validated so that a crafted archive cannot write outside
    *dest_dir* (e.g. via ``../`` components).

    Args:
        archive_path: path of the ``.tar.gz`` file to read.
        dest_dir: directory the members are extracted into.

    Raises:
        ValueError: (interpreters without extraction filters) a member would
            land outside *dest_dir*.
        tarfile.TarError: the archive is unreadable, or - on interpreters
            with extraction filters - a member is rejected by the filter.
        OSError: the archive cannot be opened or a file cannot be written.
    """
    with tarfile.open(archive_path, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # Preferred: the standard library's own "data" extraction filter.
            tar.extractall(dest_dir, filter="data")
            return

        # Fallback for interpreters without extraction filters: refuse any
        # member whose resolved path escapes the destination directory.
        dest_root = os.path.realpath(dest_dir)
        for member in tar.getmembers():
            target = os.path.realpath(os.path.join(dest_root, member.name))
            if os.path.commonpath([dest_root, target]) != dest_root:
                raise ValueError(
                    f"Refusing to extract {member.name!r}: path escapes {dest_dir!r}"
                )
        tar.extractall(dest_dir)


if os.path.exists(EXTRACT_PATH):
    print("✅ Dataset already extracted.")
elif os.path.isfile(DATASET_TAR_PATH):
    print("🔄 Extracting dataset...")
    extract_archive(DATASET_TAR_PATH, EXTRACT_PATH)
    print("✅ Extraction complete.")
else:
    # Nothing else in this script reads EXTRACT_PATH, so a missing archive is
    # reported instead of aborting the whole app with FileNotFoundError.
    print(f"⚠️ {DATASET_TAR_PATH} not found; skipping local extraction.")
# Load dataset with transcripts.
# The "clean" configuration of `librispeech_asr` provides the splits
# "train.100", "train.360", "validation" and "test"; there is no split that
# is literally named "train". "validation" is the dev-clean subset, i.e. the
# same data as the dev-clean archive this script refers to.
dataset = load_dataset("librispeech_asr", "clean", split="validation")

# Fail early if the transcripts needed as training labels are not present.
if "text" not in dataset.column_names:
    raise ValueError("❌ Dataset is missing transcription text!")
# Preprocessing Function
def preprocess_data(batch):
    """Turn one dataset example into model inputs and label ids.

    Args:
        batch: a single example providing ``"file"`` (path of the audio file
            to load) and ``"text"`` (its transcript).

    Returns:
        The same mapping with two keys added: ``"input_features"`` (output of
        the processor for the 16 kHz mono waveform) and ``"labels"`` (token
        ids of the transcript).
    """
    target_rate = 16000

    # Process audio
    waveform, sample_rate = torchaudio.load(batch["file"])
    # The loaded tensor is (channels, frames). Averaging over channels yields
    # a 1-D mono signal; a plain squeeze() would leave stereo audio 2-D and
    # the processor would then see it as two separate utterances.
    waveform = waveform.mean(dim=0)
    if sample_rate != target_rate:
        # Only build a resampler when the rate actually differs.
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_rate)
        waveform = resampler(waveform)
    batch["input_features"] = processor(
        waveform.numpy(), sampling_rate=target_rate, return_tensors="pt"
    ).input_features[0]

    # Tokenize transcript text.
    # No padding here: padding every label row to the tokenizer's maximum
    # length fills it with real pad-token ids that the loss would be computed
    # on. Variable-length labels are padded per batch by the data collator.
    batch["labels"] = processor.tokenizer(batch["text"], truncation=True).input_ids
    return batch
# Run the preprocessing over every example and discard the raw columns that
# are no longer needed once features and labels have been computed.
dataset = dataset.map(
    preprocess_data,
    remove_columns=["file", "audio", "text"],
)

# Ordered 80/20 split: the leading 80 % of rows become the training set and
# the remaining rows the evaluation set (no shuffling).
train_size = int(0.8 * len(dataset))
train_dataset, eval_dataset = (
    dataset.select(row_indices)
    for row_indices in (range(train_size), range(train_size, len(dataset)))
)

print(f"βœ… Dataset Prepared! Training: {len(train_dataset)}, Evaluation: {len(eval_dataset)}")
# ================================
# 4️⃣ Training Arguments & Trainer
# ================================
# Newer transformers releases name this argument `eval_strategy`, older ones
# `evaluation_strategy`. Use whichever field the installed version declares
# so construction works on both.
_eval_strategy_arg = (
    "eval_strategy"
    if "eval_strategy" in getattr(TrainingArguments, "__dataclass_fields__", {})
    else "evaluation_strategy"
)

training_args = TrainingArguments(
    output_dir="./asr_model_finetuned",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=500,
    save_total_limit=2,
    push_to_hub=True,
    hub_model_id="tahirsher/ASR_Model_for_Transcription_into_Text",
    hub_token=HF_TOKEN,
    **{_eval_strategy_arg: "epoch"},
)


def data_collator(features):
    """Collate preprocessed examples into one padded training batch.

    ``DataCollatorForSeq2Seq`` pads the non-label inputs with the tokenizer,
    which requires ``input_ids``; these examples carry ``input_features``
    instead, so audio and labels are padded separately here.

    Args:
        features: list of examples, each with ``"input_features"`` and
            ``"labels"`` (as produced by ``preprocess_data``).

    Returns:
        The padded batch from the feature extractor with a ``"labels"``
        tensor added, in which padded positions are set to ``-100`` so the
        loss ignores them.
    """
    audio_inputs = [{"input_features": example["input_features"]} for example in features]
    batch = processor.feature_extractor.pad(audio_inputs, return_tensors="pt")

    label_inputs = [{"input_ids": example["labels"]} for example in features]
    padded_labels = processor.tokenizer.pad(label_inputs, return_tensors="pt")
    labels = padded_labels["input_ids"].masked_fill(
        padded_labels["attention_mask"].ne(1), -100
    )

    # If every row already begins with the decoder start token, drop it: the
    # model prepends that token itself when it derives decoder inputs from
    # the labels, so keeping it would duplicate it.
    start_id = model.config.decoder_start_token_id
    if start_id is not None and labels.size(1) > 0 and (labels[:, 0] == start_id).all().item():
        labels = labels[:, 1:]

    batch["labels"] = labels
    return batch


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)
# ================================
# 5️⃣ Fine-Tuning Execution & Training Stats
# ================================
if st.button("Start Fine-Tuning"):
    with st.spinner("Fine-tuning in progress... Please wait!"):
        trainer.train()
    # Make sure the model is back in inference mode (dropout disabled) before
    # the transcription code further down uses it in this same run.
    model.eval()
    st.success("✅ Fine-Tuning Completed! Model updated.")

    # Plot Training Loss
    train_loss = trainer.state.log_history
    # Only some log entries carry a training "loss" value.
    loss_entries = [entry for entry in train_loss if "loss" in entry]
    losses = [entry["loss"] for entry in loss_entries]
    # Use the logged step number for the x-axis when present, so the axis
    # matches its "Steps" label; otherwise fall back to the entry's position.
    steps = [entry.get("step", position) for position, entry in enumerate(loss_entries)]

    if losses:
        # Draw on an explicit figure instead of pyplot's global state, hand
        # that figure to Streamlit, then close it so figures do not pile up
        # across reruns.
        fig, ax = plt.subplots(figsize=(8, 5))
        # The marker keeps a run with a single logged value visible.
        ax.plot(steps, losses, label="Training Loss", color="blue", marker="o")
        ax.set_xlabel("Steps")
        ax.set_ylabel("Loss")
        ax.set_title("Training Loss Over Time")
        ax.legend()
        st.pyplot(fig)
        plt.close(fig)
    else:
        st.info("No training-loss values were logged during this run, so there is nothing to plot.")
# ================================
# 6️⃣ Streamlit ASR Web App (Proper Decoding)
# ================================
st.title("🎙️ Speech-to-Text ASR Model with Fine-Tuning 🎶")
audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])

if audio_file:
    # Write the upload to a unique temporary file that keeps the original
    # extension. A single fixed "temp_audio.wav" would be shared by concurrent
    # sessions, mislabel mp3/flac uploads, and never be cleaned up.
    suffix = os.path.splitext(audio_file.name)[1] or ".wav"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_file:
        tmp_file.write(audio_file.read())
        audio_path = tmp_file.name
    try:
        waveform, sample_rate = torchaudio.load(audio_path)
    finally:
        os.remove(audio_path)

    # (channels, frames) -> mono. squeeze() alone would leave stereo uploads
    # 2-D and the processor would treat the channels as separate utterances.
    waveform = waveform.mean(dim=0)
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

    input_features = processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt").input_features
    input_tensor = input_features.to(device)

    # Never request more tokens than the decoder has positions for, when the
    # model config declares such a limit.
    max_length = min(500, getattr(model.config, "max_target_positions", None) or 500)

    # Use `generate()` for proper transcription.
    # Plain beam search (no sampling) gives the same transcript for the same
    # audio; sampling on top of beams would make the output vary between runs.
    model.eval()
    with torch.no_grad():
        generated_ids = model.generate(
            input_tensor,
            max_length=max_length,
            num_beams=5,
            do_sample=False,
        )
    transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # Display transcription
    st.success("📄 Transcription:")
    st.write(transcription)