import os
import tempfile

import streamlit as st
import torch
import torchaudio
from huggingface_hub import login
from transformers import AutoModelForCTC, AutoModelForSpeechSeq2Seq, AutoProcessor

# Hugging Face access token, read from the environment (e.g. app secrets).
# "hf_token" is checked first for backward compatibility; "HF_TOKEN" is the
# conventional name used by huggingface_hub and is accepted as a fallback.
HF_TOKEN = os.getenv("hf_token") or os.getenv("HF_TOKEN")

# Test truthiness rather than `is None`: an empty token is just as unusable
# and would otherwise only fail later inside login().
if not HF_TOKEN:
    raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")

login(token=HF_TOKEN)

# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
# Wav2Vec2-Conformer is an encoder-only CTC architecture: it has no decoder, so
# it cannot be loaded through AutoModelForSpeechSeq2Seq or driven with
# .generate(). The bare "facebook/wav2vec2-conformer-rel-pos-large" checkpoint
# is also pretraining-only (no CTC head, no tokenizer vocabulary), so the 960h
# fine-tuned variant of the same model is used for transcription.
MODEL_NAME = "facebook/wav2vec2-conformer-rel-pos-large-960h-ft"

# Sampling rate the processor and model are fed with.
TARGET_SAMPLE_RATE = 16000

device = "cuda" if torch.cuda.is_available() else "cpu"


@st.cache_resource
def load_model(model_name, target_device):
    """Load and cache the processor and CTC model for ``model_name``.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` keeps the loaded weights alive across reruns
    instead of re-initialising them each time.

    Returns:
        ``(processor, model)`` with the model moved to ``target_device`` and
        put in eval mode.
    """
    loaded_processor = AutoProcessor.from_pretrained(model_name)
    loaded_model = AutoModelForCTC.from_pretrained(model_name)
    loaded_model.to(target_device)
    loaded_model.eval()
    print(f"✅ Conformer Model loaded on {target_device}")
    return loaded_processor, loaded_model


processor, model = load_model(MODEL_NAME, device)


# ---------------------------------------------------------------------------
# Sidebar controls
# ---------------------------------------------------------------------------
st.sidebar.title("🔧 Fine-Tuning Hyperparameters")
# NOTE(review): num_epochs, learning_rate and batch_size are collected here but
# not consumed by the transcription flow below -- no fine-tuning happens.
num_epochs = st.sidebar.slider("Epochs", min_value=1, max_value=10, value=3)
learning_rate = st.sidebar.select_slider(
    "Learning Rate", options=[5e-4, 1e-4, 5e-5, 1e-5], value=5e-5
)
batch_size = st.sidebar.select_slider("Batch Size", options=[2, 4, 8, 16], value=8)
# Standard deviation of the Gaussian noise mixed into the uploaded audio.
attack_strength = st.sidebar.slider("Attack Strength", 0.0, 0.9, 0.1)

# Noise level above which the "attack" warning is shown.
ATTACK_WARNING_THRESHOLD = 0.1


# ---------------------------------------------------------------------------
# Audio helpers
# ---------------------------------------------------------------------------
def _load_uploaded_audio(uploaded):
    """Decode a Streamlit upload into ``(waveform, sample_rate)`` via torchaudio.

    The bytes go to a per-upload temporary file that keeps the upload's own
    extension: a fixed path would be shared by every concurrent session, and
    naming mp3/flac data ``.wav`` can mislead format detection. The temporary
    file is always removed afterwards.
    """
    suffix = os.path.splitext(uploaded.name)[1] or ".wav"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        # getvalue() returns the whole buffer regardless of the read position.
        tmp.write(uploaded.getvalue())
        tmp_path = tmp.name
    try:
        return torchaudio.load(tmp_path)
    finally:
        os.remove(tmp_path)


def _to_mono(waveform, sample_rate, target_rate=TARGET_SAMPLE_RATE):
    """Return a 1-D float32 waveform, channels averaged and resampled to ``target_rate``.

    ``torchaudio.load`` yields ``(channels, frames)``; averaging the channels
    stops a stereo file from reaching the processor as a batch of two
    separate utterances (of which only the first would be transcribed).
    """
    mono = waveform.to(dtype=torch.float32)
    if mono.dim() > 1:
        mono = mono.mean(dim=0)
    if sample_rate != target_rate:
        mono = torchaudio.functional.resample(
            mono, orig_freq=sample_rate, new_freq=target_rate
        )
    return mono


def _add_noise(waveform, strength):
    """Add zero-mean Gaussian noise with std ``strength``, clipped to [-1.0, 1.0]."""
    noisy = waveform + strength * torch.randn_like(waveform)
    return torch.clamp(noisy, -1.0, 1.0)


def _transcribe(waveform, asr_processor, asr_model, target_device):
    """Greedy CTC transcription of a 1-D waveform sampled at ``TARGET_SAMPLE_RATE``.

    CTC models are decoded by taking the most likely token per frame; the
    processor's ``batch_decode`` then turns the frame ids into text.
    """
    inputs = asr_processor(
        waveform.numpy(), sampling_rate=TARGET_SAMPLE_RATE, return_tensors="pt"
    )
    input_values = inputs.input_values.to(target_device)
    attention_mask = (
        inputs.attention_mask.to(target_device) if "attention_mask" in inputs else None
    )
    with torch.inference_mode():
        logits = asr_model(input_values, attention_mask=attention_mask).logits
    predicted_ids = torch.argmax(logits, dim=-1).cpu()
    return asr_processor.batch_decode(predicted_ids)[0]


# ---------------------------------------------------------------------------
# Page
# ---------------------------------------------------------------------------
st.title("🎙️ Speech-to-Text ASR Model with Security Features 🎶")

audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])

if audio_file:
    try:
        raw_waveform, sample_rate = _load_uploaded_audio(audio_file)
    except RuntimeError as err:
        st.error(f"Could not decode the uploaded audio: {err}")
        st.stop()

    waveform = _to_mono(raw_waveform, sample_rate)
    if waveform.numel() == 0:
        st.error("The uploaded audio file contains no samples.")
        st.stop()

    # Simulated adversarial perturbation: Gaussian noise scaled by the slider.
    adversarial_waveform = _add_noise(waveform, attack_strength)
    transcription = _transcribe(adversarial_waveform, processor, model, device)

    if attack_strength > ATTACK_WARNING_THRESHOLD:
        st.warning("⚠️ Adversarial attack detected! Transcription may be affected.")

    st.success("🔒 Secure Transcription:")
    st.write(transcription)