import os
import tarfile

import torch
import torchaudio
import numpy as np
import streamlit as st
import matplotlib.pyplot as plt
from huggingface_hub import login
from datasets import load_dataset, DatasetDict
from transformers import (
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
)

# ================================
# 1️⃣ Authenticate with Hugging Face Hub (Securely)
# ================================
HF_TOKEN = os.getenv("hf_token")
if HF_TOKEN is None:
    raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")

login(token=HF_TOKEN)

# ================================
# 2️⃣ Load Model & Processor
# ================================
MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"

processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"✅ Model loaded on {device}")

# ================================
# 3️⃣ Load and Prepare Dataset
# ================================
DATASET_TAR_PATH = "dev-clean.tar.gz"
EXTRACT_PATH = "./librispeech_dev_clean"

if not os.path.exists(EXTRACT_PATH):
    print("🔄 Extracting dataset...")
    with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
        tar.extractall(EXTRACT_PATH)
    print("✅ Extraction complete.")
else:
    print("✅ Dataset already extracted.")

# Load dataset with transcripts
# Note: the "clean" config of librispeech_asr exposes "train.100" / "train.360" splits,
# not a plain "train" split.
dataset = load_dataset("librispeech_asr", "clean", split="train.100")

# Ensure dataset has transcripts
if "text" not in dataset.column_names:
    raise ValueError("❌ Dataset is missing transcription text!")


# Preprocessing Function
def preprocess_data(batch):
    # Load the audio and resample it to the 16 kHz rate expected by the processor
    waveform, sample_rate = torchaudio.load(batch["file"])
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    batch["input_features"] = processor(
        waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt"
    ).input_features[0]

    # Tokenize transcript text
    batch["labels"] = processor.tokenizer(
        batch["text"], padding="max_length", truncation=True, return_tensors="pt"
    ).input_ids[0]
    return batch


# Apply preprocessing
dataset = dataset.map(preprocess_data, remove_columns=["file", "audio", "text"])

# Split into train & eval (80/20)
train_size = int(0.8 * len(dataset))
train_dataset = dataset.select(range(train_size))
eval_dataset = dataset.select(range(train_size, len(dataset)))

print(f"✅ Dataset Prepared! Training: {len(train_dataset)}, Evaluation: {len(eval_dataset)}")

# ================================
# 4️⃣ Training Arguments & Trainer
# ================================
training_args = TrainingArguments(
    output_dir="./asr_model_finetuned",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=500,
    save_total_limit=2,
    push_to_hub=True,
    hub_model_id="tahirsher/ASR_Model_for_Transcription_into_Text",
    hub_token=HF_TOKEN,
)

data_collator = DataCollatorForSeq2Seq(tokenizer=processor.tokenizer, model=model, return_tensors="pt")

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)

# ================================
# 5️⃣ Fine-Tuning Execution & Training Stats
# ================================
if st.button("Start Fine-Tuning"):
    with st.spinner("Fine-tuning in progress... Please wait!"):
        trainer.train()
    st.success("✅ Fine-Tuning Completed! Model updated.")

    # Plot Training Loss
    train_loss = trainer.state.log_history
    losses = [entry["loss"] for entry in train_loss if "loss" in entry]

    plt.figure(figsize=(8, 5))
    plt.plot(range(len(losses)), losses, label="Training Loss", color="blue")
    plt.xlabel("Steps")
    plt.ylabel("Loss")
    plt.title("Training Loss Over Time")
    plt.legend()
    st.pyplot(plt)

# ================================
# 6️⃣ Streamlit ASR Web App (Proper Decoding)
# ================================
st.title("🎙️ Speech-to-Text ASR Model with Fine-Tuning 🎶")

audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])

if audio_file:
    # Persist the upload to a temporary file so torchaudio can read it
    audio_path = "temp_audio.wav"
    with open(audio_path, "wb") as f:
        f.write(audio_file.read())

    waveform, sample_rate = torchaudio.load(audio_path)
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

    input_features = processor(
        waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt"
    ).input_features
    input_tensor = input_features.to(device)

    # ✅ FIX: Use `generate()` for Proper Transcription
    with torch.no_grad():
        generated_ids = model.generate(
            input_tensor,
            max_length=500,
            num_beams=5,
            do_sample=True,
            top_k=50,
        )

    transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # Display transcription
    st.success("📄 Transcription:")
    st.write(transcription)