File size: 3,007 Bytes
cd7aa15
 
 
 
723513d
15b7647
f0a5b40
cd7aa15
15b7647
723513d
15b7647
f6dc6c7
15b7647
 
 
 
723513d
 
15b7647
cd7aa15
15b7647
 
 
f0a5b40
15b7647
 
 
098a61e
cd7aa15
15b7647
3a9d859
15b7647
e3021fc
 
 
15b7647
f0a5b40
cd7aa15
15b7647
cd7aa15
15b7647
f0a5b40
 
 
 
cd7aa15
 
f0a5b40
 
8d19597
 
 
 
15b7647
 
e3021fc
15b7647
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import tempfile

import streamlit as st
import torch
import torchaudio
from huggingface_hub import login
from transformers import AutoModelForCTC, AutoModelForSpeechSeq2Seq, AutoProcessor

# ================================
# 1️⃣ Authenticate with Hugging Face Hub (Securely)
# ================================
# The token comes from the environment variable "hf_token" (e.g. a deployment
# secret) instead of being hard-coded in the source.
HF_TOKEN = os.getenv("hf_token")

# A secret that is defined but blank yields "" rather than None; an empty string
# is not a usable token, so treat it exactly like a missing one.
if not HF_TOKEN:
    raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")

login(token=HF_TOKEN)

# ================================
# 2️⃣ Load Conformer Model & Processor
# ================================
# wav2vec2-conformer is an encoder-only CTC model: it is loaded with
# AutoModelForCTC (AutoModelForSpeechSeq2Seq does not accept this config) and
# decoded with argmax + processor.batch_decode, not model.generate().
# The "-960h-ft" checkpoint is used because the bare
# "facebook/wav2vec2-conformer-rel-pos-large" checkpoint is pretrained-only and
# ships no CTC head / tokenizer, so it cannot produce text.
MODEL_NAME = "facebook/wav2vec2-conformer-rel-pos-large-960h-ft"
TARGET_SAMPLE_RATE = 16000  # sampling rate passed to the processor


@st.cache_resource
def load_asr_model(model_name: str = MODEL_NAME):
    """Load the processor and CTC model once per server process.

    Streamlit re-executes this whole script on every widget interaction;
    caching the resource prevents reloading the model on each rerun.

    Returns:
        (processor, model, device) where device is "cuda" or "cpu".
    """
    asr_processor = AutoProcessor.from_pretrained(model_name)
    asr_model = AutoModelForCTC.from_pretrained(model_name)
    asr_device = "cuda" if torch.cuda.is_available() else "cpu"
    asr_model.to(asr_device)
    print(f"✅ Conformer Model loaded on {asr_device}")
    return asr_processor, asr_model, asr_device


processor, model, device = load_asr_model()

# ================================
# 3️⃣ Streamlit UI: Fine-Tuning Hyperparameter Selection
# ================================
# NOTE(review): num_epochs / learning_rate / batch_size are not read anywhere in
# this script; only attack_strength affects the transcription below.
st.sidebar.title("🔧 Fine-Tuning Hyperparameters")
num_epochs = st.sidebar.slider("Epochs", min_value=1, max_value=10, value=3)
learning_rate = st.sidebar.select_slider("Learning Rate", options=[5e-4, 1e-4, 5e-5, 1e-5], value=5e-5)
batch_size = st.sidebar.select_slider("Batch Size", options=[2, 4, 8, 16], value=8)
attack_strength = st.sidebar.slider("Attack Strength", 0.0, 0.9, 0.1)


def _load_mono_16k(path: str) -> torch.Tensor:
    """Load an audio file as a float32 mono waveform of shape (1, samples) at 16 kHz."""
    waveform, sample_rate = torchaudio.load(path)
    # Downmix all channels: a stereo (2, N) tensor would otherwise be treated by
    # the processor as a batch of two utterances and only one would be shown.
    waveform = waveform.to(dtype=torch.float32).mean(dim=0, keepdim=True)
    if sample_rate != TARGET_SAMPLE_RATE:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sample_rate, new_freq=TARGET_SAMPLE_RATE
        )
    return waveform


def _add_noise(waveform: torch.Tensor, strength: float) -> torch.Tensor:
    """Simulate an attack by adding Gaussian noise (std = strength), clipped to [-1, 1]."""
    noisy = waveform + strength * torch.randn_like(waveform)
    return torch.clamp(noisy, -1.0, 1.0)


def _transcribe(waveform: torch.Tensor) -> str:
    """Greedy CTC transcription of a (1, samples) 16 kHz waveform."""
    inputs = processor(waveform.squeeze(0).numpy(), sampling_rate=TARGET_SAMPLE_RATE, return_tensors="pt")
    input_values = inputs.input_values.to(device)
    attention_mask = inputs.attention_mask.to(device) if "attention_mask" in inputs else None

    with torch.inference_mode():
        logits = model(input_values, attention_mask=attention_mask).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(predicted_ids)[0]


# ================================
# 4️⃣ Streamlit ASR Web App (Fast Decoding & Security Features)
# ================================
st.title("🎙️ Speech-to-Text ASR Model with Security Features 🎶")

audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])

if audio_file:
    # Use a private temp file per upload: a fixed "temp_audio.wav" path is shared
    # by concurrent sessions, mislabels mp3/flac data and is never cleaned up.
    suffix = os.path.splitext(audio_file.name)[1] or ".wav"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(audio_file.getvalue())  # getvalue() is independent of read position
        audio_path = tmp.name
    try:
        waveform = _load_mono_16k(audio_path)
    finally:
        os.remove(audio_path)

    adversarial_waveform = _add_noise(waveform, attack_strength)
    transcription = _transcribe(adversarial_waveform)

    # NOTE(review): this is driven purely by the slider value; nothing here
    # actually detects an attack in the audio.
    if attack_strength > 0.1:
        st.warning("⚠️ Adversarial attack detected! Transcription may be affected.")

    st.success("📄 Secure Transcription:")
    st.write(transcription)