File size: 6,359 Bytes
cd7aa15 fcd8965 cd7aa15 49df20f 723513d cd7aa15 f6dc6c7 f0a5b40 cd7aa15 e3021fc 723513d e3021fc f6dc6c7 551152e f6dc6c7 551152e 723513d cd7aa15 f0a5b40 f6dc6c7 098a61e cd7aa15 e3021fc cd7aa15 3a79217 098a61e a06453c f6dc6c7 a06453c f6dc6c7 f0a5b40 a06453c 771c2e9 e3021fc 1bb8243 771c2e9 f6dc6c7 a06453c f6dc6c7 771c2e9 a06453c 3a9d859 941924a 3a9d859 f6dc6c7 3a9d859 f6dc6c7 a06453c f6dc6c7 3a9d859 e3021fc 3a9d859 e3021fc f6dc6c7 e3021fc 941924a e3021fc 3a9d859 e3021fc f0a5b40 cd7aa15 e3021fc cd7aa15 e3021fc f0a5b40 cd7aa15 e3021fc cd7aa15 e3021fc f0a5b40 cd7aa15 f0a5b40 cd7aa15 f0a5b40 dfe80a0 f6dc6c7 e3021fc f6dc6c7 e3021fc f6dc6c7 a312467 e3021fc 49df20f f6dc6c7 e3021fc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
import os
import tarfile
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import torchaudio
from cryptography.fernet import Fernet
from huggingface_hub import login
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
)
# ================================
# 1️⃣ Authenticate with Hugging Face Hub
# ================================
# The token is read from the "hf_token" environment variable. An unset
# variable and an empty string are both treated as "missing" so that the
# failure is reported here instead of surfacing later inside ``login``.
HF_TOKEN = os.getenv("hf_token")
if not HF_TOKEN:
    raise ValueError("β Hugging Face API token not found. Please set it in Secrets.")
login(token=HF_TOKEN)
# ================================
# 2️⃣ Load Model & Processor
# ================================
MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"


@st.cache_resource
def _load_model_and_processor(model_name):
    """Load the processor and seq2seq speech model once per server process.

    Streamlit re-executes this script on every widget interaction; caching
    the loaded objects avoids re-reading the checkpoint on each rerun.

    Args:
        model_name: Hugging Face Hub repository id of the checkpoint.

    Returns:
        A ``(processor, model, device)`` tuple, where ``device`` is ``"cuda"``
        when CUDA is available and ``"cpu"`` otherwise, and ``model`` has
        already been moved to that device.
    """
    loaded_processor = AutoProcessor.from_pretrained(model_name)
    loaded_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name)
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    loaded_model.to(target_device)
    print(f"✅ Model loaded on {target_device}")
    return loaded_processor, loaded_model, target_device


processor, model, device = _load_model_and_processor(MODEL_NAME)
# ================================
# 3️⃣ Load Dataset
# ================================
DATASET_TAR_PATH = "dev-clean.tar.gz"
EXTRACT_PATH = "./librispeech_dev_clean"

if not os.path.exists(EXTRACT_PATH):
    print("π Extracting dataset...")
    # Extract into a sibling ".partial" directory and rename it only after
    # the whole archive has been unpacked. An interrupted extraction then
    # never leaves EXTRACT_PATH behind looking like a complete dataset.
    partial_path = EXTRACT_PATH.rstrip("/\\") + ".partial"
    with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # Extraction filters are available: refuse members that would
            # escape the destination directory (absolute paths, "..", links).
            tar.extractall(partial_path, filter="data")
        else:
            # Older Python without extraction filters: keep legacy behavior.
            tar.extractall(partial_path)
    os.replace(partial_path, EXTRACT_PATH)
    print("✅ Extraction complete.")
else:
    print("✅ Dataset already extracted.")

# Directory inside the extracted archive that is scanned for audio/transcripts.
AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")
def find_audio_files(base_folder):
    """Return the paths of all ``.flac`` files found under *base_folder*.

    The directory tree is walked recursively with ``os.walk``. The result is
    sorted so that the order does not depend on the filesystem's directory
    enumeration order.

    Args:
        base_folder: Root directory to search.

    Returns:
        A sorted list of paths (each one ``base_folder`` joined with the
        relative location of the file). The list is empty when nothing
        matches or when *base_folder* does not exist.
    """
    return sorted(
        os.path.join(root, name)
        for root, _, names in os.walk(base_folder)
        for name in names
        if name.endswith(".flac")
    )
# Fail fast when the extracted dataset does not contain any audio, which
# usually means the archive layout differs from what AUDIO_FOLDER expects.
audio_files = find_audio_files(AUDIO_FOLDER)
if not audio_files:
    raise FileNotFoundError(f"β No .flac files found in {AUDIO_FOLDER}. Check dataset structure!")
print(f"✅ Found {len(audio_files)} audio files in dataset!")
# ================================
# 4️⃣ Load Transcripts
# ================================
def load_transcripts(base_folder=None):
    """Collect utterance transcripts from every ``.txt`` file under a folder.

    Each line of a transcript file is expected to look like
    ``<utterance-id> <text>``. Lines that do not contain both parts (blank
    lines, a lone id) are skipped.

    Args:
        base_folder: Directory to search recursively. Defaults to the
            module-level ``AUDIO_FOLDER`` when omitted.

    Returns:
        A dict mapping utterance id to transcript text. If the same id occurs
        more than once, the occurrence read last wins.
    """
    if base_folder is None:
        base_folder = AUDIO_FOLDER

    transcript_dict = {}
    for root, _, names in os.walk(base_folder):
        for name in names:
            if not name.endswith(".txt"):
                continue
            with open(os.path.join(root, name), "r", encoding="utf-8") as handle:
                for line in handle:
                    # Split on the first run of whitespace so tabs or repeated
                    # spaces between id and text do not leak into the text.
                    parts = line.split(maxsplit=1)
                    if len(parts) == 2:
                        file_id, text = parts
                        transcript_dict[file_id] = text.strip()
    return transcript_dict
# Fail fast when no transcript lines could be parsed from the dataset.
transcripts = load_transcripts()
if not transcripts:
    raise FileNotFoundError("β No transcripts found! Check dataset structure.")
print(f"✅ Loaded {len(transcripts)} transcripts.")
# ================================
# 5️⃣ Streamlit Sidebar: Fine-Tuning & Security
# ================================
# All widgets in this section are rendered inside the sidebar container.
with st.sidebar:
    st.title("π§ Fine-Tuning & Security Settings")

    # Hyper-parameter selectors.
    num_epochs = st.slider("Epochs", min_value=1, max_value=10, value=3)
    learning_rate = st.select_slider(
        "Learning Rate",
        options=[5e-4, 1e-4, 5e-5, 1e-5],
        value=5e-5,
    )
    batch_size = st.select_slider(
        "Batch Size",
        options=[2, 4, 8, 16],
        value=8,
    )

    # Scale of the random noise added to the uploaded audio.
    attack_strength = st.slider(
        "Adversarial Attack Strength",
        min_value=0.1,
        max_value=0.9,
        value=0.3,
    )

    # Output handling toggles.
    enable_encryption = st.checkbox("π Encrypt Transcription", value=True)
    show_transcription = st.checkbox("π Show Transcription", value=False)
# ================================
# 6️⃣ Encryption Functionality
# ================================
def generate_key() -> bytes:
    """Return a freshly generated Fernet key.

    Returns:
        The value produced by ``Fernet.generate_key()``; it can be passed as
        the ``key`` argument of ``encrypt_text`` and ``decrypt_text``.
    """
    return Fernet.generate_key()
def encrypt_text(text, key):
    """Encrypt *text* with the Fernet key *key*.

    Args:
        text: The plaintext. A ``str`` is encoded as UTF-8 first; ``bytes``
            are encrypted as given.
        key: A Fernet key, e.g. one returned by ``generate_key()``.

    Returns:
        The Fernet token produced by ``Fernet(key).encrypt(...)``.
    """
    payload = text.encode("utf-8") if isinstance(text, str) else text
    return Fernet(key).encrypt(payload)
def decrypt_text(encrypted_text, key):
    """Decrypt a Fernet token and return the plaintext as ``str``.

    Args:
        encrypted_text: A token previously produced by ``encrypt_text``.
        key: The Fernet key that was used for encryption.

    Returns:
        The decrypted payload decoded as UTF-8.

    Raises:
        Whatever ``Fernet.decrypt`` raises for a token it rejects is
        propagated unchanged (presumably ``cryptography.fernet.InvalidToken``
        for a wrong key or a modified token; confirm against the library docs).
    """
    plaintext = Fernet(key).decrypt(encrypted_text)
    return plaintext.decode("utf-8")
# Streamlit re-executes this script on every interaction. Keep one key per
# browser session in ``st.session_state`` instead of generating a new key on
# each rerun, so a token stays decryptable across reruns of the same session.
if "encryption_key" not in st.session_state:
    st.session_state["encryption_key"] = generate_key()
encryption_key = st.session_state["encryption_key"]
# ================================
# 7️⃣ Streamlit ASR Web App
# ================================
TARGET_SAMPLE_RATE = 16000  # sampling rate handed to the processor below


def _load_uploaded_audio(uploaded_file):
    """Decode an uploaded audio file into a mono float32 waveform at 16 kHz.

    The upload is written to a uniquely named temporary file that keeps the
    upload's own extension (a fixed "temp_audio.wav" would mislabel mp3/flac
    uploads and be shared between concurrent sessions). The temporary file is
    removed once decoding has finished or failed.

    Args:
        uploaded_file: The object returned by ``st.file_uploader``.

    Returns:
        A ``torch.float32`` tensor with a single channel in its first
        dimension, resampled to ``TARGET_SAMPLE_RATE``.
    """
    suffix = os.path.splitext(getattr(uploaded_file, "name", ""))[1] or ".wav"
    # getvalue() returns the full buffer regardless of the current read
    # position, so a rerun never sees an already-consumed stream.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(uploaded_file.getvalue())
        tmp_path = tmp.name
    try:
        decoded, source_rate = torchaudio.load(tmp_path)
    finally:
        os.remove(tmp_path)

    decoded = decoded.to(dtype=torch.float32)
    # NOTE(review): assumes torchaudio.load returns (channels, samples);
    # multi-channel audio is averaged down to one channel so that the
    # processor receives a single 1-D signal.
    if decoded.size(0) > 1:
        decoded = decoded.mean(dim=0, keepdim=True)
    if source_rate != TARGET_SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(
            orig_freq=source_rate, new_freq=TARGET_SAMPLE_RATE
        )
        decoded = resampler(decoded)
    return decoded


st.title("ποΈ Speech-to-Text ASR Model Finetuneed on Libri Speech Dataset with Security Features")
audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])

if audio_file:
    waveform = _load_uploaded_audio(audio_file)

    # ================================
    # ✅ Simulated Adversarial Attack
    # ================================
    # Gaussian noise scaled by the sidebar slider is added to the signal and
    # the result is clamped back into [-1, 1].
    noise = attack_strength * torch.randn_like(waveform)
    adversarial_waveform = torch.clamp(waveform + noise, -1.0, 1.0)

    # NOTE(review): torchaudio's ``vad`` is a voice-activity detector, not a
    # denoiser; it does not remove the noise added above. The variable name is
    # kept from the original script.
    denoised_waveform = torchaudio.functional.vad(
        adversarial_waveform, sample_rate=TARGET_SAMPLE_RATE
    )
    if denoised_waveform.numel() == 0:
        # Nothing was kept by the detector; transcribe the untrimmed signal
        # rather than handing an empty array to the processor.
        denoised_waveform = adversarial_waveform

    input_features = processor(
        denoised_waveform.squeeze(0).numpy(),
        sampling_rate=TARGET_SAMPLE_RATE,
        return_tensors="pt",
    ).input_features.to(device)

    # All-ones mask with one entry per input frame.
    # NOTE(review): assumes Whisper-style features laid out as
    # (batch, mel_bins, frames) - confirm against the checkpoint. The original
    # built the mask with the full 3-D feature shape.
    attention_mask = torch.ones(
        input_features.shape[0],
        input_features.shape[-1],
        dtype=torch.long,
        device=device,
    )

    with torch.inference_mode():
        generated_ids = model.generate(
            input_features,
            max_length=200,
            num_beams=2,
            do_sample=False,
            use_cache=True,
            attention_mask=attention_mask,
            language="en",
        )
    transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # The warning is driven purely by the slider value, not by inspecting
    # the audio.
    if attack_strength > 0.3:
        st.warning("β οΈ Adversarial attack detected! Mitigated using denoising.")

    # ================================
    # ✅ Encryption Handling
    # ================================
    if enable_encryption:
        encrypted_transcription = encrypt_text(transcription, encryption_key)
        st.info("π Transcription is encrypted. To view, enable 'Show Transcription' in the sidebar.")
        if show_transcription:
            decrypted_text = decrypt_text(encrypted_transcription, encryption_key)
            st.success("π Secure Transcription:")
            st.write(decrypted_text)
        else:
            st.write("π [Encrypted] Transcription is hidden. Enable 'Show Transcription' to view.")
    else:
        st.success("π Transcription:")
        st.write(transcription)
|