import os
import tarfile
import torch
import torchaudio
import numpy as np
import streamlit as st
from huggingface_hub import login
from transformers import (
AutoProcessor,
AutoModelForSpeechSeq2Seq,
TrainingArguments,
Trainer,
    DataCollatorForSeq2Seq,  # pads token IDs; see the collator note in section 5
)
# ================================
# 1️⃣ Authenticate with Hugging Face Hub (Securely)
# ================================
HF_TOKEN = os.getenv("HF_TOKEN") # Ensure it's set in Hugging Face Spaces Secrets
if HF_TOKEN is None:
raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")
login(token=HF_TOKEN)
# ================================
# 2️⃣ Load Model & Processor
# ================================
MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"
processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)
# Move model to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"βœ… Model loaded on {device}")
# ================================
# 3️⃣ Load Dataset (Recursively from Extracted Path)
# ================================
DATASET_TAR_PATH = "dev-clean.tar.gz"
EXTRACT_PATH = "./librispeech_dev_clean"
# Extract dataset if not already extracted
if not os.path.exists(EXTRACT_PATH):
print("πŸ”„ Extracting dataset...")
with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
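        # NOTE (suggestion): on Python versions that support the "filter"
        # argument (3.12+, backported to recent 3.8-3.11 patch releases),
        # tar.extractall(EXTRACT_PATH, filter="data") is the safer call,
        # guarding against path-traversal entries in untrusted archives.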
tar.extractall(EXTRACT_PATH)
print("βœ… Extraction complete.")
else:
print("βœ… Dataset already extracted.")
# Base directory where audio files are stored
AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")
# Recursively find all `.flac` files inside the dataset directory
def find_audio_files(base_folder):
"""Recursively search for all .flac files in subdirectories."""
audio_files = []
for root, _, files in os.walk(base_folder):
for file in files:
if file.endswith(".flac"):
audio_files.append(os.path.join(root, file))
return audio_files
# Get all audio files
audio_files = find_audio_files(AUDIO_FOLDER)
if not audio_files:
raise FileNotFoundError(f"❌ No .flac files found in {AUDIO_FOLDER}. Check dataset structure!")
print(f"βœ… Found {len(audio_files)} audio files in dataset!")
# ================================
# 4️⃣ Preprocess Dataset
# ================================
def load_and_process_audio(audio_path):
"""Loads and processes a single audio file into model format."""
waveform, sample_rate = torchaudio.load(audio_path)
# Resample to 16kHz
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
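    # NOTE (suggested guard, not in the original): stereo input would survive
    # .squeeze() below as a 2-D array; averaging channels keeps it 1-D
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)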
# Convert to model input format
input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
return input_features
# Manually create the dataset structure; only the first 100 files are used,
# and labels are left as empty placeholders (see the note below)
dataset = [{"input_features": load_and_process_audio(f), "labels": []} for f in audio_files[:100]]
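# NOTE: with empty labels the Trainer has no target text to learn from. In the
# standard LibriSpeech layout each chapter folder holds a
# "<speaker>-<chapter>.trans.txt" file mapping utterance IDs to transcripts; a
# sketch for attaching real labels (the helper name is ours, not the original's):
def load_transcripts(base_folder):
    """Map utterance ID (e.g. '84-121123-0000') -> transcript text."""
    transcripts = {}
    for root, _, files in os.walk(base_folder):
        for file in files:
            if file.endswith(".trans.txt"):
                with open(os.path.join(root, file), encoding="utf-8") as fh:
                    for line in fh:
                        utt_id, _, text = line.strip().partition(" ")
                        transcripts[utt_id] = text
    return transcripts

# Labels would then be tokenized transcripts, e.g. for an audio path p:
#   utt_id = os.path.splitext(os.path.basename(p))[0]
#   labels = processor.tokenizer(transcripts[utt_id]).input_ids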
# Split dataset into train and eval
train_size = int(0.9 * len(dataset))
train_dataset = dataset[:train_size]
eval_dataset = dataset[train_size:]
print(f"βœ… Dataset Loaded! Training: {len(train_dataset)}, Evaluation: {len(eval_dataset)}")
# ================================
# 5️⃣ Training Arguments & Trainer
# ================================
training_args = TrainingArguments(
output_dir="./asr_model_finetuned",
    eval_strategy="epoch",  # evaluate once per epoch ("eval_strategy" supersedes the deprecated "evaluation_strategy")
save_strategy="epoch",
learning_rate=5e-5,
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
num_train_epochs=3,
weight_decay=0.01,
logging_dir="./logs",
logging_steps=500,
save_total_limit=2,
    push_to_hub=True,
    hub_model_id="tahirsher/ASR_Model_for_Transcription_into_Text",  # replace with your own Hub repo
hub_token=HF_TOKEN,
)
# NOTE: DataCollatorForSeq2Seq pads token IDs ("input_ids"/"labels"), not the
# log-mel "input_features" these examples carry; a speech-specific collator is
# the safer choice (a sketch follows)
data_collator = DataCollatorForSeq2Seq(tokenizer=processor.tokenizer, model=model, return_tensors="pt")
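# A minimal padding collator sketched for Whisper-style inputs (it assumes each
# example carries an "input_features" tensor and token-ID "labels"); it is
# defined here as a drop-in alternative, not wired into the Trainer below:
from dataclasses import dataclass

@dataclass
class SpeechSeq2SeqCollator:
    processor: AutoProcessor

    def __call__(self, features):
        # Pad the log-mel spectrograms to a uniform length.
        inputs = [{"input_features": f["input_features"]} for f in features]
        batch = self.processor.feature_extractor.pad(inputs, return_tensors="pt")
        # Pad the label token IDs, masking padding with -100 so the loss skips it.
        labels = [{"input_ids": f["labels"]} for f in features]
        padded = self.processor.tokenizer.pad(labels, return_tensors="pt")
        batch["labels"] = padded["input_ids"].masked_fill(
            padded["attention_mask"].eq(0), -100
        )
        return batch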
# Define Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=data_collator,
)
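# NOTE (suggestion): transformers also ships Seq2SeqTrainer /
# Seq2SeqTrainingArguments, whose predict_with_generate=True option evaluates
# with real decoding; that pair is the more common choice for ASR fine-tuning.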
# ================================
# 6️⃣ Fine-Tuning Execution
# ================================
if st.button("Start Fine-Tuning"):
with st.spinner("Fine-tuning in progress... Please wait!"):
trainer.train()
st.success("βœ… Fine-Tuning Completed! Model updated.")
# ================================
# 7️⃣ Streamlit ASR Web App
# ================================
st.title("πŸŽ™οΈ Speech-to-Text ASR with Fine-Tuning 🎢")
# Upload audio file
audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
if audio_file:
# Save uploaded file temporarily
audio_path = "temp_audio.wav"
with open(audio_path, "wb") as f:
f.write(audio_file.read())
# Load and process audio
waveform, sample_rate = torchaudio.load(audio_path)
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
# Convert audio to model input
input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features
    # Move the input tensor to the same device as the model
    input_tensor = input_features.to(device)

    # Perform ASR inference with autoregressive decoding; a single forward
    # pass seeded with decoder_start_token_id would emit only one token,
    # whereas generate() decodes the full transcription
    with torch.no_grad():
        predicted_ids = model.generate(input_tensor)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
# Display transcription
st.success("πŸ“„ Transcription:")
st.write(transcription)