import os
import tarfile
import torch
import torchaudio
import numpy as np
import streamlit as st
import matplotlib.pyplot as plt
from huggingface_hub import login
from transformers import (
AutoProcessor,
AutoModelForSpeechSeq2Seq,
TrainingArguments,
Trainer,
DataCollatorForSeq2Seq,
)
from cryptography.fernet import Fernet
# ================================
# 1️⃣ Authenticate with Hugging Face Hub (Securely)
# ================================
HF_TOKEN = os.getenv("hf_token")
if HF_TOKEN is None:
raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")
login(token=HF_TOKEN)
# ================================
# 2️⃣ Load Model & Processor
# ================================
MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"
processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"βœ… Model loaded on {device}")
# ================================
# 3️⃣ Load Dataset (From Extracted Folder)
# ================================
DATASET_TAR_PATH = "dev-clean.tar.gz"
EXTRACT_PATH = "./librispeech_dev_clean"
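# The LibriSpeech dev-clean archive is expected to sit next to this script
# (e.g. uploaded to the Space); extraction is skipped if it was already done.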
if not os.path.exists(EXTRACT_PATH):
    print("🔄 Extracting dataset...")
    with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
        tar.extractall(EXTRACT_PATH)
    print("✅ Extraction complete.")
else:
    print("✅ Dataset already extracted.")
AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")
def find_audio_files(base_folder):
return [os.path.join(root, file)
for root, _, files in os.walk(base_folder)
for file in files if file.endswith(".flac")]
audio_files = find_audio_files(AUDIO_FOLDER)
if not audio_files:
raise FileNotFoundError(f"❌ No .flac files found in {AUDIO_FOLDER}. Check dataset structure!")
print(f"βœ… Found {len(audio_files)} audio files in dataset!")
# ================================
# 4️⃣ Load Transcripts
# ================================
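# LibriSpeech ships one *.trans.txt file per chapter folder; each line has the
# form "<utterance-id> <transcript>", which is what split(" ", 1) below relies on.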
def load_transcripts():
transcript_dict = {}
for root, _, files in os.walk(AUDIO_FOLDER):
for file in files:
if file.endswith(".txt"):
with open(os.path.join(root, file), "r", encoding="utf-8") as f:
for line in f:
parts = line.strip().split(" ", 1)
if len(parts) == 2:
file_id, text = parts
transcript_dict[file_id] = text
return transcript_dict
transcripts = load_transcripts()
if not transcripts:
raise FileNotFoundError("❌ No transcripts found! Check dataset structure.")
print(f"βœ… Loaded {len(transcripts)} transcripts.")
# ================================
# 5️⃣ Preprocess Dataset (Fixing `input_ids` issue)
# ================================
def load_and_process_audio(audio_path):
    waveform, sample_rate = torchaudio.load(audio_path)
    # Resample to the 16 kHz rate the processor's feature extractor expects.
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    waveform = waveform.to(dtype=torch.float32)
    # Convert the raw waveform into log-mel input features for the model.
    input_features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]
    return input_features
dataset = []
for audio_file in audio_files[:100]:  # cap at 100 clips to keep preprocessing fast
    file_id = os.path.basename(audio_file).replace(".flac", "")
    if file_id not in transcripts:
        continue
    labels = processor.tokenizer(
        transcripts[file_id], padding="max_length", truncation=True, return_tensors="pt"
    ).input_ids[0]
    dataset.append({"input_features": load_and_process_audio(audio_file), "labels": labels})
train_size = int(0.8 * len(dataset))
train_dataset, eval_dataset = dataset[:train_size], dataset[train_size:]
print(f"βœ… Dataset Prepared! Training: {len(train_dataset)}, Evaluation: {len(eval_dataset)}")
# ================================
# 6️⃣ Streamlit UI: Fine-Tuning Hyperparameter Selection
# ================================
st.sidebar.title("🔧 Fine-Tuning Hyperparameters")
num_epochs = st.sidebar.slider("Epochs", min_value=1, max_value=10, value=3)
learning_rate = st.sidebar.select_slider("Learning Rate", options=[5e-4, 1e-4, 5e-5, 1e-5], value=5e-5)
batch_size = st.sidebar.select_slider("Batch Size", options=[2, 4, 8, 16], value=8)
attack_strength = st.sidebar.slider("Attack Strength", 0.0, 0.9, 0.1)
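# ================================
# (Hedged sketch) Wire the selected hyperparameters into an actual training run.
# TrainingArguments/Trainer are imported above but never used; this is one
# minimal way to use them. The "Start Fine-Tuning" button and the
# "./asr_finetuned" output directory are assumptions, not the original design.
# The default collator suffices here only because input_features and labels are
# already fixed-size tensors; real speech fine-tuning usually wants a custom
# collator that masks label padding with -100.
# ================================
if st.sidebar.button("🚀 Start Fine-Tuning"):
    training_args = TrainingArguments(
        output_dir="./asr_finetuned",        # assumed output path
        num_train_epochs=num_epochs,
        learning_rate=learning_rate,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        remove_unused_columns=False,         # keep input_features/labels dicts intact
        logging_steps=10,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
    trainer.train()
    st.sidebar.success("✅ Fine-tuning finished.")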
# ================================
# 7️⃣ Streamlit ASR Web App (Fast Decoding & Security Features)
# ================================
st.title("🎙️ Speech-to-Text ASR Model with Security Features 🎶")
audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
if audio_file:
    # Preserve the uploaded file's extension so torchaudio picks the right decoder.
    suffix = os.path.splitext(audio_file.name)[1] or ".wav"
    audio_path = f"temp_audio{suffix}"
    with open(audio_path, "wb") as f:
        f.write(audio_file.read())
    waveform, sample_rate = torchaudio.load(audio_path)
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    waveform = waveform.to(dtype=torch.float32)
# Simulate an adversarial attack by injecting random noise
adversarial_waveform = waveform + (attack_strength * torch.randn_like(waveform))
adversarial_waveform = torch.clamp(adversarial_waveform, -1.0, 1.0)
    input_features = processor(adversarial_waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features.to(device)
    with torch.inference_mode():
        # Whisper-style seq2seq models consume fixed-size log-mel features directly;
        # the attention mask the original passed had the wrong shape and is not needed.
        generated_ids = model.generate(input_features, max_length=200, num_beams=2, do_sample=False, use_cache=True)
    transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    if attack_strength > 0.1:
        st.warning("⚠️ Adversarial noise injected! Transcription may be affected.")
    st.success("📄 Secure Transcription:")
st.write(transcription)
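    # ================================
    # (Hedged sketch) Fernet is imported above but never used; encrypting the
    # transcription is one plausible reading of the "Security Features" in the
    # title. The per-session key kept in st.session_state is an assumption.
    # ================================
    if "fernet_key" not in st.session_state:
        st.session_state["fernet_key"] = Fernet.generate_key()
    fernet = Fernet(st.session_state["fernet_key"])
    encrypted_transcription = fernet.encrypt(transcription.encode("utf-8"))
    st.caption(f"🔒 Fernet-encrypted copy: {encrypted_transcription[:48].decode('utf-8')}…")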