File size: 3,007 Bytes
cd7aa15 723513d 15b7647 f0a5b40 cd7aa15 15b7647 723513d 15b7647 f6dc6c7 15b7647 723513d 15b7647 cd7aa15 15b7647 f0a5b40 15b7647 098a61e cd7aa15 15b7647 3a9d859 15b7647 e3021fc 15b7647 f0a5b40 cd7aa15 15b7647 cd7aa15 15b7647 f0a5b40 cd7aa15 f0a5b40 8d19597 15b7647 e3021fc 15b7647 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import os
import tempfile

import streamlit as st
import torch
import torchaudio
from huggingface_hub import login
from transformers import AutoModelForCTC, AutoModelForSpeechSeq2Seq, AutoProcessor
# ================================
# 1️⃣ Authenticate with Hugging Face Hub (Securely)
# ================================
# The access token comes from the environment variable "hf_token" (for
# example a deployment secret) instead of being hard-coded in the source.
# A missing token is treated as fatal: the script stops with ValueError.
HF_TOKEN = os.environ.get("hf_token")
if HF_TOKEN is not None:
    login(token=HF_TOKEN)
else:
    raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")
# ================================
# 2️⃣ Load Conformer Model & Processor
# ================================
# wav2vec2-conformer is an encoder-only CTC architecture: it is not registered
# with AutoModelForSpeechSeq2Seq (loading it that way raises) and it has no
# autoregressive `generate()` decoder. We therefore load it with a CTC head and
# decode greedily (argmax per frame).
# NOTE(review): the bare "facebook/wav2vec2-conformer-rel-pos-large" checkpoint
# is the pretrained-only model (no CTC head / tokenizer to decode text with);
# the "-960h-ft" variant is the same model fine-tuned for transcription.
# Verify the checkpoint choice on the Hub.
MODEL_NAME = "facebook/wav2vec2-conformer-rel-pos-large-960h-ft"
TARGET_SAMPLE_RATE = 16000  # Hz; the rate the wav2vec2 feature extractor expects


@st.cache_resource
def _load_model(model_name: str = MODEL_NAME):
    """Load the processor and CTC model once per server process.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the large checkpoint would be re-initialized each time.

    Returns:
        A ``(processor, model, device)`` tuple, with the model moved to
        ``device`` ("cuda" when available, otherwise "cpu").
    """
    proc = AutoProcessor.from_pretrained(model_name)
    ctc_model = AutoModelForCTC.from_pretrained(model_name)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    ctc_model.to(dev)
    print(f"✅ Conformer Model loaded on {dev}")
    return proc, ctc_model, dev


processor, model, device = _load_model()

# ================================
# 3️⃣ Streamlit UI: Fine-Tuning Hyperparameter Selection
# ================================
st.sidebar.title("🔧 Fine-Tuning Hyperparameters")
# NOTE(review): epochs / learning rate / batch size are collected but never
# used in this script; no fine-tuning loop is visible here.
num_epochs = st.sidebar.slider("Epochs", min_value=1, max_value=10, value=3)
learning_rate = st.sidebar.select_slider("Learning Rate", options=[5e-4, 1e-4, 5e-5, 1e-5], value=5e-5)
batch_size = st.sidebar.select_slider("Batch Size", options=[2, 4, 8, 16], value=8)
# Standard deviation of the Gaussian noise added to the input (see below).
attack_strength = st.sidebar.slider("Attack Strength", 0.0, 0.9, 0.1)


# ================================
# 4️⃣ Streamlit ASR Web App (Fast Decoding & Security Features)
# ================================
def _load_mono_16k(uploaded) -> torch.Tensor:
    """Decode an uploaded audio file to a mono float32 tensor of shape (1, T) at 16 kHz."""
    suffix = os.path.splitext(uploaded.name)[1] or ".wav"
    # Per-upload temp file: a fixed path would be shared by every concurrent
    # Streamlit session, and the real suffix lets the decoder pick the format.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(uploaded.getvalue())
        tmp_path = tmp.name
    try:
        wav, sr = torchaudio.load(tmp_path)
    finally:
        os.remove(tmp_path)
    # Down-mix to mono; otherwise a stereo (2, T) signal is treated by the
    # processor as a batch of two utterances and only one is transcribed.
    wav = wav.mean(dim=0, keepdim=True)
    if sr != TARGET_SAMPLE_RATE:
        wav = torchaudio.transforms.Resample(orig_freq=sr, new_freq=TARGET_SAMPLE_RATE)(wav)
    return wav.to(dtype=torch.float32)


def _add_noise(wav: torch.Tensor, strength: float) -> torch.Tensor:
    """Simulate an attack by adding Gaussian noise, clamped to [-1, 1].

    The clamp assumes normalized float samples (torchaudio.load's default).
    """
    noisy = wav + strength * torch.randn_like(wav)
    return torch.clamp(noisy, -1.0, 1.0)


def _transcribe(wav: torch.Tensor) -> str:
    """Run the CTC model on a (1, T) 16 kHz waveform and greedily decode the text."""
    inputs = processor(wav.squeeze(0).numpy(), sampling_rate=TARGET_SAMPLE_RATE, return_tensors="pt")
    input_values = inputs.input_values.to(device)
    attention_mask = inputs.attention_mask.to(device) if "attention_mask" in inputs else None
    with torch.inference_mode():
        logits = model(input_values, attention_mask=attention_mask).logits
    # Greedy CTC decoding: best token per frame; the tokenizer collapses
    # repeats and drops the blank/pad token when decoding.
    predicted_ids = torch.argmax(logits, dim=-1).cpu()
    return processor.batch_decode(predicted_ids)[0]


st.title("🎙️ Speech-to-Text ASR Model with Security Features 🎶")
audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
if audio_file:
    waveform = _load_mono_16k(audio_file)
    adversarial_waveform = _add_noise(waveform, attack_strength)
    transcription = _transcribe(adversarial_waveform)
    # NOTE(review): this warning reflects the user's slider setting only; no
    # actual attack detection is performed.
    if attack_strength > 0.1:
        st.warning("⚠️ Adversarial attack detected! Transcription may be affected.")
    st.success("📄 Secure Transcription:")
    st.write(transcription)
|