import os
import torch
import torchaudio
import streamlit as st
from huggingface_hub import login
from transformers import AutoProcessor, AutoModelForCTC
# ================================
# 1️⃣ Authenticate with Hugging Face Hub (Securely)
# ================================
HF_TOKEN = os.getenv("hf_token")
if HF_TOKEN is None:
    raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")
login(token=HF_TOKEN)
# ================================
# 2️⃣ Load Conformer Model & Processor
# ================================
# The base "facebook/wav2vec2-conformer-rel-pos-large" checkpoint has no ASR head and cannot
# transcribe audio; use the CTC fine-tuned variant so the processor and model support decoding.
MODEL_NAME = "facebook/wav2vec2-conformer-rel-pos-large-960h-ft"
processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForCTC.from_pretrained(MODEL_NAME)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"✅ Conformer model loaded on {device}")
# ================================
# 3️⃣ Streamlit UI: Fine-Tuning Hyperparameter Selection
# ================================
st.sidebar.title("🔧 Fine-Tuning Hyperparameters")
num_epochs = st.sidebar.slider("Epochs", min_value=1, max_value=10, value=3)
learning_rate = st.sidebar.select_slider("Learning Rate", options=[5e-4, 1e-4, 5e-5, 1e-5], value=5e-5)
batch_size = st.sidebar.select_slider("Batch Size", options=[2, 4, 8, 16], value=8)
attack_strength = st.sidebar.slider("Attack Strength", 0.0, 0.9, 0.1)
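# Note: these sidebar values are only displayed in this demo; no fine-tuning run is triggered below.
# If a training loop were added, they would map onto transformers.TrainingArguments, e.g. (sketch):
#   TrainingArguments(num_train_epochs=num_epochs, learning_rate=learning_rate,
#                     per_device_train_batch_size=batch_size, ...)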
# ================================
# 4️⃣ Streamlit ASR Web App (Fast Decoding & Security Features)
# ================================
st.title("🎙️ Speech-to-Text ASR Model with Security Features 🎶")
audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
if audio_file:
    # Persist the upload to a temporary file so torchaudio can read it
    audio_path = "temp_audio.wav"
    with open(audio_path, "wb") as f:
        f.write(audio_file.read())

    # Load, downmix to mono if needed, and resample to the 16 kHz rate the model expects
    waveform, sample_rate = torchaudio.load(audio_path)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    waveform = waveform.to(dtype=torch.float32)

    # Simulate an adversarial attack by injecting random noise, then clamp to the valid audio range
    adversarial_waveform = waveform + (attack_strength * torch.randn_like(waveform))
    adversarial_waveform = torch.clamp(adversarial_waveform, -1.0, 1.0)
    inputs = processor(adversarial_waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
    input_values = inputs.input_values.to(device)
    attention_mask = inputs.attention_mask.to(device) if "attention_mask" in inputs else None

    # CTC inference: take the most likely token per frame, then let the tokenizer collapse repeats/blanks
    with torch.inference_mode():
        logits = model(input_values, attention_mask=attention_mask).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)[0]
    if attack_strength > 0.1:
        st.warning("⚠️ Adversarial attack detected! Transcription may be affected.")
    st.success("📄 Secure Transcription:")
    st.write(transcription)
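    # Note: the upload is written to a fixed path ("temp_audio.wav") and overwritten on each run;
    # add an os.remove(audio_path) here if the file should not persist between uploads.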