import os
import tarfile

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import torchaudio
from cryptography.fernet import Fernet
from huggingface_hub import login
from transformers import (
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
)

# ================================
# 1️⃣ Authenticate with Hugging Face Hub (Securely)
# ================================
# NOTE(review): env-var names are conventionally UPPER_CASE — confirm the
# secret really is stored under the lowercase key "hf_token".
HF_TOKEN = os.getenv("hf_token")
if HF_TOKEN is None:
    raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")
login(token=HF_TOKEN)

# ================================
# 2️⃣ Load Model & Processor
# ================================
MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"

processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)

# Run on GPU when available; inference inputs below are moved to the same device.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"✅ Model loaded on {device}")

# ================================
# 3️⃣ Load Dataset (With Fixes)
# ================================
DATASET_TAR_PATH = "dev-clean.tar.gz"
EXTRACT_PATH = "./librispeech_dev_clean"
AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")

if not os.path.exists(AUDIO_FOLDER):
    print("🔄 Extracting dataset...")
    try:
        with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
            try:
                # SECURITY: the "data" filter rejects absolute paths and
                # ".." members, guarding against tar path traversal
                # (CVE-2007-4559).
                tar.extractall(EXTRACT_PATH, filter="data")
            except TypeError:
                # Older Python without the `filter` parameter.
                tar.extractall(EXTRACT_PATH)
        print("✅ Extraction complete.")
    except Exception as e:
        raise RuntimeError(f"❌ Dataset extraction failed: {e}")
else:
    print("✅ Dataset already extracted.")


def find_audio_files(base_folder):
    """Recursively collect the paths of all .flac files under *base_folder*."""
    return [
        os.path.join(root, file)
        for root, _, files in os.walk(base_folder)
        for file in files
        if file.endswith(".flac")
    ]


audio_files = find_audio_files(AUDIO_FOLDER)
if not audio_files:
    # BUGFIX: the original had a raw line break inside this string literal,
    # which is a SyntaxError; the message is now a single literal.
    raise FileNotFoundError(
        f"❌ No .flac files found in {AUDIO_FOLDER}. Check dataset structure!"
    )
print(f"✅ Found {len(audio_files)} audio files in dataset!")


# ================================
# 4️⃣ Load Transcripts (Fixed Mapping)
# ================================
def load_transcripts():
    """Map LibriSpeech utterance ids to their transcript text.

    Each transcript .txt file holds lines of the form
    "<file_id> <transcript>"; returns a dict {file_id: transcript}.
    """
    transcript_dict = {}
    for root, _, files in os.walk(AUDIO_FOLDER):
        for file in files:
            if file.endswith(".txt"):
                with open(os.path.join(root, file), "r", encoding="utf-8") as f:
                    for line in f:
                        parts = line.strip().split(" ", 1)
                        if len(parts) == 2:
                            file_id, text = parts
                            transcript_dict[file_id] = text
    return transcript_dict


transcripts = load_transcripts()
if not transcripts:
    raise FileNotFoundError("❌ No transcripts found! Check dataset structure.")
print(f"✅ Loaded {len(transcripts)} transcripts.")


# ================================
# 5️⃣ Preprocess Dataset (Fixed `input_ids` Issue)
# ================================
def load_and_process_audio(audio_path):
    """Load one .flac file and return its processor input features (tensor)."""
    waveform, sample_rate = torchaudio.load(audio_path)
    # Only resample when needed — LibriSpeech is already 16 kHz.
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=16000
        )(waveform)
    waveform = waveform.to(dtype=torch.float32)
    input_features = processor(
        waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt"
    ).input_features[0]
    return input_features


dataset = []
# Cap at 100 files to keep preprocessing time reasonable for a demo app.
for audio_file in audio_files[:100]:
    file_id = os.path.basename(audio_file).replace(".flac", "")
    if file_id in transcripts:
        input_features = load_and_process_audio(audio_file)
        tokenized = processor.tokenizer(
            transcripts[file_id],
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        labels = tokenized.input_ids[0]
        # BUGFIX: padding positions must be -100 so the seq2seq loss ignores
        # them during fine-tuning; masking via the attention mask (rather than
        # comparing against pad_token_id) stays correct even when the pad and
        # eos tokens share an id.
        labels[tokenized.attention_mask[0] == 0] = -100
        dataset.append({"input_features": input_features, "labels": labels})

# 80/20 train/eval split.
train_size = int(0.8 * len(dataset))
train_dataset = dataset[:train_size]
eval_dataset = dataset[train_size:]
# BUGFIX: this f-string also contained a raw line break (SyntaxError).
print(f"✅ Dataset Prepared! Training: {len(train_dataset)}, Evaluation: {len(eval_dataset)}")

# ================================
# 6️⃣ Streamlit UI: Fine-Tuning Hyperparameter Selection
# ================================
st.sidebar.title("🔧 Fine-Tuning Hyperparameters")
num_epochs = st.sidebar.slider("Epochs", min_value=1, max_value=10, value=3)
learning_rate = st.sidebar.select_slider(
    "Learning Rate", options=[5e-4, 1e-4, 5e-5, 1e-5], value=5e-5
)
batch_size = st.sidebar.select_slider("Batch Size", options=[2, 4, 8, 16], value=8)
attack_strength = st.sidebar.slider("Attack Strength", 0.0, 0.9, 0.1)

# ================================
# 7️⃣ Streamlit ASR Web App (Fixed Security & Processing)
# ================================
st.title("🎙️ Speech-to-Text ASR Model with Security Features 🎶")

audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
if audio_file:
    audio_path = "temp_audio.wav"
    try:
        with open(audio_path, "wb") as f:
            f.write(audio_file.read())
        waveform, sample_rate = torchaudio.load(audio_path)
    finally:
        # BUGFIX: the temp file was never removed; clean it up once loaded.
        if os.path.exists(audio_path):
            os.remove(audio_path)

    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=16000
        )(waveform)
    waveform = waveform.to(dtype=torch.float32)
    # BUGFIX: a stereo upload survives .squeeze() as a 2-D array, which the
    # processor does not expect — downmix multi-channel audio to mono first.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Apply adversarial attack noise with limit
    noise = torch.randn_like(waveform) * attack_strength
    adversarial_waveform = torch.clamp(waveform + noise, -1.0, 1.0)

    input_features = processor(
        adversarial_waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt"
    ).input_features.to(device)

    with torch.inference_mode():
        # BUGFIX: the original passed
        # attention_mask=torch.ones(input_features.shape) — a
        # (1, n_mels, n_frames) mask, the wrong shape for generate();
        # fixed-size log-mel encoder inputs need no attention mask, so it
        # is simply dropped.
        generated_ids = model.generate(
            input_features,
            max_length=200,
            num_beams=2,
            do_sample=False,
            use_cache=True,
        )

    transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # NOTE(review): this warning only reflects the slider setting the user
    # chose — it is not an actual detection of adversarial input.
    if attack_strength > 0.1:
        st.warning("⚠️ Adversarial attack detected! Transcription may be affected.")

    st.success("📄 Secure Transcription:")
    st.write(transcription)