tahirsher's picture
Update app.py
e3021fc verified
raw
history blame
6.36 kB
import os
import shutil
import tarfile
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import torchaudio
from cryptography.fernet import Fernet
from huggingface_hub import login
from transformers import (
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
)
# ================================
# 1️⃣ Authenticate with Hugging Face Hub
# ================================
# Read the token from the "hf_token" secret, falling back to the conventional
# HF_TOKEN variable. An empty value is treated the same as a missing one, since
# logging in with "" can only fail later with a less helpful error.
HF_TOKEN = os.getenv("hf_token") or os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")
login(token=HF_TOKEN)
# ================================
# 2️⃣ Load Model & Processor
# ================================
MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"


@st.cache_resource
def _load_model_and_processor(model_name):
    """Load the processor and seq2seq ASR model once per server process.

    Streamlit re-executes this script on every widget interaction. Without
    caching, the model would be re-instantiated (and possibly re-downloaded)
    each time a slider or checkbox changes.

    Args:
        model_name: Hugging Face Hub model identifier.

    Returns:
        (processor, model, device) where device is "cuda" when available,
        otherwise "cpu", and the model has already been moved to it.
    """
    proc = AutoProcessor.from_pretrained(model_name)
    mdl = AutoModelForSpeechSeq2Seq.from_pretrained(model_name)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    mdl.to(dev)
    mdl.eval()  # inference only: make sure dropout etc. are disabled
    print(f"✅ Model loaded on {dev}")
    return proc, mdl, dev


processor, model, device = _load_model_and_processor(MODEL_NAME)
# ================================
# 3️⃣ Load Dataset
# ================================
DATASET_TAR_PATH = "dev-clean.tar.gz"
EXTRACT_PATH = "./librispeech_dev_clean"


def _extract_dataset(tar_path, dest):
    """Extract the gzip'd tarball *tar_path* into *dest* unless *dest* exists.

    The archive is unpacked into a temporary sibling directory that is renamed
    to *dest* only after extraction finishes. An interrupted run therefore never
    leaves a half-populated *dest* that later runs would treat as complete; the
    leftover staging directory is discarded on the next attempt.

    Raises:
        FileNotFoundError: if *tar_path* is missing and *dest* does not exist.
        tarfile.TarError: if the archive is corrupt or a member is rejected by
            the extraction filter.
    """
    if os.path.exists(dest):
        print("✅ Dataset already extracted.")
        return
    print("🔄 Extracting dataset...")
    staging = dest.rstrip("/\\") + ".partial"
    shutil.rmtree(staging, ignore_errors=True)
    with tarfile.open(tar_path, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # Reject absolute paths, links escaping the target, special files.
            tar.extractall(staging, filter="data")
        else:
            # Older Python without extraction filters.
            tar.extractall(staging)
    os.rename(staging, dest)
    print("✅ Extraction complete.")


_extract_dataset(DATASET_TAR_PATH, EXTRACT_PATH)
AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")
def find_audio_files(base_folder):
    """Recursively collect the paths of all ``.flac`` files below *base_folder*.

    Paths are joined onto the directory names yielded by ``os.walk``, in walk
    order. A missing *base_folder* yields an empty list.
    """
    return [
        os.path.join(dirpath, name)
        for dirpath, _subdirs, names in os.walk(base_folder)
        for name in names
        if name.endswith(".flac")
    ]
# Fail fast if the extracted dataset does not have the expected layout.
audio_files = find_audio_files(AUDIO_FOLDER)
if not audio_files:
    raise FileNotFoundError(f"❌ No .flac files found in {AUDIO_FOLDER}. Check dataset structure!")
print(f"✅ Found {len(audio_files)} audio files in dataset!")
# ================================
# 4️⃣ Load Transcripts
# ================================
def load_transcripts(base_folder=None):
    """Build an ID → text mapping from every ``.txt`` file under a folder.

    Each line is split once on the first space into ``"<id> <text>"``. Lines
    with no text part (including blank lines) are skipped. If an ID appears
    more than once, the last occurrence read wins.

    Args:
        base_folder: Directory searched recursively. Defaults to the
            module-level ``AUDIO_FOLDER`` so existing zero-argument calls
            behave as before.

    Returns:
        dict mapping ID string to transcript string (empty if none found).
    """
    if base_folder is None:
        base_folder = AUDIO_FOLDER
    transcript_dict = {}
    for root, _, files in os.walk(base_folder):
        for file in files:
            if not file.endswith(".txt"):
                continue
            with open(os.path.join(root, file), "r", encoding="utf-8") as f:
                for line in f:
                    parts = line.strip().split(" ", 1)
                    if len(parts) == 2:
                        file_id, text = parts
                        transcript_dict[file_id] = text
    return transcript_dict
# Fail fast if no transcript lines could be read from the dataset.
transcripts = load_transcripts()
if not transcripts:
    raise FileNotFoundError("❌ No transcripts found! Check dataset structure.")
print(f"✅ Loaded {len(transcripts)} transcripts.")
# ================================
# 5️⃣ Streamlit Sidebar: Fine-Tuning & Security
# ================================
st.sidebar.title("🔧 Fine-Tuning & Security Settings")
# NOTE(review): num_epochs, learning_rate and batch_size are not read anywhere
# else in this file as shown; no fine-tuning is driven by these controls.
num_epochs = st.sidebar.slider("Epochs", min_value=1, max_value=10, value=3)
learning_rate = st.sidebar.select_slider("Learning Rate", options=[5e-4, 1e-4, 5e-5, 1e-5], value=5e-5)
batch_size = st.sidebar.select_slider("Batch Size", options=[2, 4, 8, 16], value=8)
# Multiplier applied to standard-normal noise added to the uploaded audio.
attack_strength = st.sidebar.slider("Adversarial Attack Strength", 0.1, 0.9, 0.3)
enable_encryption = st.sidebar.checkbox("🔒 Encrypt Transcription", value=True)
show_transcription = st.sidebar.checkbox("📖 Show Transcription", value=False)
# ================================
# 6️⃣ Encryption Functionality
# ================================
def generate_key():
    """Return a newly generated random Fernet key (bytes)."""
    fresh_key = Fernet.generate_key()
    return fresh_key
def encrypt_text(text, key):
    """Encrypt the string *text* under Fernet *key*; return the token bytes."""
    cipher = Fernet(key)
    token = cipher.encrypt(text.encode())
    return token
def decrypt_text(encrypted_text, key):
    """Decrypt a Fernet token with *key* and return the decoded string."""
    cipher = Fernet(key)
    plaintext_bytes = cipher.decrypt(encrypted_text)
    return plaintext_bytes.decode()
# Keep one key per browser session. Streamlit re-runs this script on every
# widget interaction, and a key regenerated on each rerun could never decrypt
# ciphertext produced during an earlier run.
if "encryption_key" not in st.session_state:
    st.session_state["encryption_key"] = generate_key()
encryption_key = st.session_state["encryption_key"]
# ================================
# 7️⃣ Streamlit ASR Web App
# ================================
st.title("🎙️ Speech-to-Text ASR Model Fine-Tuned on LibriSpeech Dataset with Security Features")
audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
if audio_file:
    # Write the upload to a unique temporary file. The old fixed
    # "temp_audio.wav" in the working directory was shared by every session
    # and never removed. Keep the upload's own extension so an mp3/flac is not
    # presented to the audio decoder as a .wav.
    suffix = os.path.splitext(audio_file.name)[1] or ".wav"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        # getvalue() returns the whole upload regardless of stream position.
        tmp.write(audio_file.getvalue())
        audio_path = tmp.name
    try:
        waveform, sample_rate = torchaudio.load(audio_path)
    finally:
        os.remove(audio_path)

    # Downmix to mono (a no-op for single-channel input). Otherwise a
    # multi-channel file reaches the processor as a 2-D array, which is
    # presumably treated as a batch, and only item [0] is decoded below.
    waveform = waveform.mean(dim=0, keepdim=True)
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    waveform = waveform.to(dtype=torch.float32)

    # ================================
    # ✅ Improved Adversarial Attack Handling
    # ================================
    # Simulated attack: Gaussian noise scaled by the sidebar strength, then
    # clipped back into the [-1, 1] sample range.
    noise = attack_strength * torch.randn_like(waveform)
    adversarial_waveform = torch.clamp(waveform + noise, -1.0, 1.0)
    # NOTE: torchaudio's vad() trims non-speech from the start of the signal;
    # it does not remove noise from what remains. If it trims everything,
    # fall back to the untrimmed audio so the processor never receives an
    # empty array.
    denoised_waveform = torchaudio.functional.vad(adversarial_waveform, sample_rate=16000)
    if denoised_waveform.numel() == 0:
        denoised_waveform = adversarial_waveform

    input_features = processor(
        denoised_waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt"
    ).input_features.to(device)
    # All-ones encoder mask over time frames, shaped (batch, frames). Assumes
    # input_features is (batch, mel_bins, frames) as for Whisper-style
    # processors (TODO confirm). The previous mask copied the full 3-D shape.
    attention_mask = torch.ones(
        (input_features.shape[0], input_features.shape[-1]), dtype=torch.long, device=device
    )
    with torch.inference_mode():
        generated_ids = model.generate(
            input_features,
            max_length=200,
            num_beams=2,
            do_sample=False,
            use_cache=True,
            attention_mask=attention_mask,
            language="en",
        )
    transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    if attack_strength > 0.3:
        st.warning("⚠️ Adversarial attack detected! Mitigated using denoising.")

    # ================================
    # ✅ Encryption Handling
    # ================================
    if enable_encryption:
        encrypted_transcription = encrypt_text(transcription, encryption_key)
        st.info("🔒 Transcription is encrypted. To view, enable 'Show Transcription' in the sidebar.")
        if show_transcription:
            decrypted_text = decrypt_text(encrypted_transcription, encryption_key)
            st.success("📄 Secure Transcription:")
            st.write(decrypted_text)
        else:
            st.write("🔒 [Encrypted] Transcription is hidden. Enable 'Show Transcription' to view.")
    else:
        st.success("📄 Transcription:")
        st.write(transcription)