import os
import tarfile
import torch
import torchaudio
import streamlit as st
from huggingface_hub import login
from transformers import (
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
)
from cryptography.fernet import Fernet

# ================================
# 1️⃣ Authenticate with Hugging Face Hub
# ================================
HF_TOKEN = os.getenv("hf_token")
if HF_TOKEN is None:
    raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")
login(token=HF_TOKEN)

# ================================
# 2️⃣ Load Model & Processor
# ================================
MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"
processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"✅ Model loaded on {device}")

# ================================
# 3️⃣ Load Dataset
# ================================
DATASET_TAR_PATH = "dev-clean.tar.gz"
EXTRACT_PATH = "./librispeech_dev_clean"

if not os.path.exists(EXTRACT_PATH):
    print("🔄 Extracting dataset...")
    with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
        tar.extractall(EXTRACT_PATH)
    print("✅ Extraction complete.")
else:
    print("✅ Dataset already extracted.")

AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")


def find_audio_files(base_folder):
    """Recursively collect all .flac files under base_folder."""
    audio_files = []
    for root, _, files in os.walk(base_folder):
        for file in files:
            if file.endswith(".flac"):
                audio_files.append(os.path.join(root, file))
    return audio_files


audio_files = find_audio_files(AUDIO_FOLDER)
if not audio_files:
    raise FileNotFoundError(f"❌ No .flac files found in {AUDIO_FOLDER}. Check dataset structure!")
print(f"✅ Found {len(audio_files)} audio files in dataset!")

# ================================
# 4️⃣ Load Transcripts
# ================================
def load_transcripts():
    """Parse LibriSpeech *.trans.txt files into an {utterance_id: text} dict."""
    transcript_dict = {}
    for root, _, files in os.walk(AUDIO_FOLDER):
        for file in files:
            if file.endswith(".txt"):
                with open(os.path.join(root, file), "r", encoding="utf-8") as f:
                    for line in f:
                        parts = line.strip().split(" ", 1)
                        if len(parts) == 2:
                            file_id, text = parts
                            transcript_dict[file_id] = text
    return transcript_dict


transcripts = load_transcripts()
if not transcripts:
    raise FileNotFoundError("❌ No transcripts found! Check dataset structure.")
print(f"✅ Loaded {len(transcripts)} transcripts.")
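# LibriSpeech transcript lines are keyed by the utterance ID that also names
# the matching .flac file (speaker-chapter-utterance), so a reference can be
# recovered from an audio path. A small helper to that end — an illustrative
# addition, not part of the original script:
def transcript_for(audio_path):
    utterance_id = os.path.splitext(os.path.basename(audio_path))[0]
    return transcripts.get(utterance_id)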
Check dataset structure.") print(f"✅ Loaded {len(transcripts)} transcripts.") # ================================ # 5️⃣ Streamlit Sidebar: Fine-Tuning & Security # ================================ st.sidebar.title("🔧 Fine-Tuning & Security Settings") num_epochs = st.sidebar.slider("Epochs", min_value=1, max_value=10, value=3) learning_rate = st.sidebar.select_slider("Learning Rate", options=[5e-4, 1e-4, 5e-5, 1e-5], value=5e-5) batch_size = st.sidebar.select_slider("Batch Size", options=[2, 4, 8, 16], value=8) attack_strength = st.sidebar.slider("Adversarial Attack Strength", 0.1, 0.9, 0.3) enable_encryption = st.sidebar.checkbox("🔒 Encrypt Transcription", value=True) show_transcription = st.sidebar.checkbox("📖 Show Transcription", value=False) # ================================ # 6️⃣ Encryption Functionality # ================================ def generate_key(): return Fernet.generate_key() def encrypt_text(text, key): fernet = Fernet(key) return fernet.encrypt(text.encode()) def decrypt_text(encrypted_text, key): fernet = Fernet(key) return fernet.decrypt(encrypted_text).decode() encryption_key = generate_key() # ================================ # 7️⃣ Streamlit ASR Web App # ================================ st.title("🎙️ Speech-to-Text ASR Model Finetuned on Libri Speech Dataset with Security Features") audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"]) if audio_file: audio_path = "temp_audio.wav" with open(audio_path, "wb") as f: f.write(audio_file.read()) waveform, sample_rate = torchaudio.load(audio_path) waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform) waveform = waveform.to(dtype=torch.float32) # ================================ # ✅ Improved Adversarial Attack Handling # ================================ noise = attack_strength * torch.randn_like(waveform) # Apply noise but then perform denoising to counteract attack effects adversarial_waveform = waveform + noise adversarial_waveform = torch.clamp(adversarial_waveform, -1.0, 1.0) denoised_waveform = torchaudio.functional.vad(adversarial_waveform, sample_rate=16000) input_features = processor(denoised_waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features.to(device) with torch.inference_mode(): generated_ids = model.generate( input_features, max_length=200, num_beams=2, do_sample=False, use_cache=True, attention_mask=torch.ones(input_features.shape, dtype=torch.long).to(device), language="en" ) transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] if attack_strength > 0.3: st.warning("⚠️ Adversarial attack detected! Mitigated using denoising.") # ================================ # ✅ Encryption Handling # ================================ if enable_encryption: encrypted_transcription = encrypt_text(transcription, encryption_key) st.info("🔒 Transcription is encrypted. To view, enable 'Show Transcription' in the sidebar.") if show_transcription: decrypted_text = decrypt_text(encrypted_transcription, encryption_key) st.success("📄 Secure Transcription:") st.write(decrypted_text) else: st.write("🔒 [Encrypted] Transcription is hidden. Enable 'Show Transcription' to view.") else: st.success("📄 Transcription:") st.write(transcription)