File size: 5,981 Bytes
cd7aa15 fcd8965 cd7aa15 49df20f 723513d cd7aa15 49df20f cd7aa15 f6dc6c7 f0a5b40 cd7aa15 8d55ac9 723513d 49df20f f6dc6c7 551152e f6dc6c7 551152e 723513d cd7aa15 f0a5b40 f6dc6c7 098a61e cd7aa15 a06453c cd7aa15 3a79217 098a61e a06453c f6dc6c7 a06453c f6dc6c7 f0a5b40 a06453c 771c2e9 a06453c 1bb8243 771c2e9 f6dc6c7 a06453c f6dc6c7 771c2e9 a06453c 3a9d859 941924a 3a9d859 f6dc6c7 3a9d859 f6dc6c7 a06453c f6dc6c7 3a9d859 a06453c 3a9d859 f6dc6c7 dfe80a0 f6dc6c7 a06453c 941924a f6dc6c7 a06453c 3a9d859 f6dc6c7 f0a5b40 cd7aa15 f6dc6c7 cd7aa15 f6dc6c7 dfe80a0 f0a5b40 cd7aa15 a06453c cd7aa15 1cf13ee f0a5b40 cd7aa15 f0a5b40 cd7aa15 f0a5b40 dfe80a0 f6dc6c7 a06453c f6dc6c7 a312467 a06453c 49df20f f6dc6c7 dfe80a0 f6dc6c7 a06453c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import os
import tarfile
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import torchaudio
from cryptography.fernet import Fernet
from huggingface_hub import login
from transformers import (
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
)
# ================================
# 1️⃣ Authenticate with Hugging Face Hub (Securely)
# ================================
# The token comes from the environment (e.g. a Space secret) instead of being
# hard-coded. "hf_token" is checked first for backward compatibility, then the
# conventional upper-case "HF_TOKEN" name.
HF_TOKEN = os.getenv("hf_token") or os.getenv("HF_TOKEN")
if not HF_TOKEN:
    # An empty value is not a usable token either, so it is rejected the same
    # way as a missing variable, with an actionable message.
    raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")
login(token=HF_TOKEN)
# ================================
# 2️⃣ Load Model & Processor
# ================================
MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"


@st.cache_resource
def _load_model_and_processor(model_name):
    """Load the processor and speech seq2seq model for *model_name*.

    Streamlit re-executes this script from the top on every widget
    interaction; ``st.cache_resource`` keeps one shared instance per server
    process so the model is not re-downloaded and re-initialised on each rerun.

    Args:
        model_name: Hugging Face Hub model id to load.

    Returns:
        ``(processor, model, device)``. ``device`` is ``"cuda"`` when
        ``torch.cuda.is_available()`` is true, otherwise ``"cpu"``; the model
        has already been moved to it.
    """
    loaded_processor = AutoProcessor.from_pretrained(model_name)
    loaded_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name)
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    loaded_model.to(target_device)
    return loaded_processor, loaded_model, target_device


processor, model, device = _load_model_and_processor(MODEL_NAME)
print(f"✅ Model loaded on {device}")
# ================================
# 3️⃣ Load Dataset (From Extracted Folder)
# ================================
DATASET_TAR_PATH = "dev-clean.tar.gz"
EXTRACT_PATH = "./librispeech_dev_clean"


def _safe_extractall(tar, dest):
    """Extract every member of the open archive *tar* into directory *dest*.

    When this Python provides the stdlib extraction filters
    (``tarfile.data_filter``), the ``"data"`` filter is used; it rejects
    absolute paths, ``..`` traversal and links escaping *dest*. On older
    interpreters without filters, member paths are checked manually for
    traversal outside *dest* before extracting.

    Raises:
        ValueError: if (fallback path) a member would land outside *dest*.
    """
    if hasattr(tarfile, "data_filter"):
        tar.extractall(dest, filter="data")
        return
    dest_root = os.path.realpath(dest)
    for member in tar.getmembers():
        target = os.path.realpath(os.path.join(dest_root, member.name))
        if os.path.commonpath([dest_root, target]) != dest_root:
            raise ValueError(f"Unsafe path in archive: {member.name}")
    tar.extractall(dest)


if not os.path.exists(EXTRACT_PATH):
    print("Extracting dataset...")
    # Extract into a sibling ".partial" directory and rename it into place only
    # after extraction finishes, so an interrupted run does not leave a
    # half-populated EXTRACT_PATH that the existence check above would then
    # treat as complete on every later run.
    partial_path = EXTRACT_PATH + ".partial"
    with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
        _safe_extractall(tar, partial_path)
    os.replace(partial_path, EXTRACT_PATH)
    print("✅ Extraction complete.")
else:
    print("✅ Dataset already extracted.")

AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")
def find_audio_files(base_folder, extensions=(".flac",)):
    """Return the paths of all matching audio files under *base_folder*, sorted.

    Walks *base_folder* recursively and keeps files whose name ends with one
    of *extensions* (a case-sensitive suffix match, like the original
    ``.flac`` check).

    The result is sorted because ``os.walk`` yields entries in a
    filesystem-dependent order. The dataset-building code slices a prefix of
    this list and splits it into train/eval sets, so a stable order keeps that
    selection reproducible across machines.

    Args:
        base_folder: Directory to search. A missing directory yields ``[]``.
        extensions: A suffix string or tuple of suffixes to accept.

    Returns:
        Sorted list of file paths, each joined onto the directory it was found in.
    """
    return sorted(
        os.path.join(root, name)
        for root, _, names in os.walk(base_folder)
        for name in names
        if name.endswith(extensions)
    )
# Collect the dataset's .flac files; fail fast with a clear message if the
# extracted layout does not match the expected LibriSpeech/dev-clean path.
audio_files = find_audio_files(AUDIO_FOLDER)
if not audio_files:
    raise FileNotFoundError(f"❌ No .flac files found in {AUDIO_FOLDER}. Check dataset structure!")
print(f"✅ Found {len(audio_files)} audio files in dataset!")
# ================================
# 4️⃣ Load Transcripts
# ================================
def load_transcripts(base_folder=None):
    """Build a mapping from utterance id to transcript text.

    Every ``.txt`` file under *base_folder* is read as UTF-8. Each line is
    expected to hold an utterance id, whitespace, then the transcript
    (presumably the LibriSpeech ``*.trans.txt`` layout; confirm for other
    datasets). Lines without both parts are skipped, and a later duplicate id
    overwrites an earlier one.

    Args:
        base_folder: Directory to search. Defaults to the module-level
            ``AUDIO_FOLDER``, looked up at call time, so existing
            no-argument calls keep working.

    Returns:
        dict mapping utterance id -> transcript string.
    """
    if base_folder is None:
        base_folder = AUDIO_FOLDER
    transcript_dict = {}
    for root, _, files in os.walk(base_folder):
        for file in files:
            if not file.endswith(".txt"):
                continue
            with open(os.path.join(root, file), "r", encoding="utf-8") as f:
                for line in f:
                    # Split once on any whitespace run: this tolerates tabs or
                    # repeated spaces between id and text without leaving a
                    # leading space in the transcript.
                    parts = line.split(maxsplit=1)
                    if len(parts) == 2:
                        file_id, text = parts
                        transcript_dict[file_id] = text.rstrip()
    return transcript_dict
# Load all transcripts once; an empty mapping means the .txt files were not
# where AUDIO_FOLDER points, so stop early rather than build an empty dataset.
transcripts = load_transcripts()
if not transcripts:
    raise FileNotFoundError("❌ No transcripts found! Check dataset structure.")
print(f"✅ Loaded {len(transcripts)} transcripts.")
# ================================
# 5️⃣ Preprocess Dataset (Fixing `input_ids` issue)
# ================================
def load_and_process_audio(audio_path, target_sr=16000):
    """Load one audio file and return the processor's input features for it.

    The audio is converted to float32, down-mixed to mono, resampled to
    *target_sr* Hz when its native rate differs, and passed through the
    module-level ``processor``.

    Args:
        audio_path: Path of a file ``torchaudio.load`` can decode.
        target_sr: Sampling rate handed to the processor (16 kHz by default,
            the value previously hard-coded).

    Returns:
        The ``input_features`` tensor for this single clip (batch dim removed).
    """
    waveform, sample_rate = torchaudio.load(audio_path)
    # torchaudio.load returns (channels, frames). Averaging the channels yields
    # a single 1-D clip. With the old squeeze(), a multi-channel file stayed
    # 2-D, which the processor presumably treated as a batch, and [0] kept
    # only the first channel. For mono input the result is unchanged.
    waveform = waveform.to(dtype=torch.float32).mean(dim=0)
    if sample_rate != target_sr:
        waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=target_sr)
    input_features = processor(waveform.numpy(), sampling_rate=target_sr, return_tensors="pt").input_features[0]
    return input_features
# Number of audio files (in find_audio_files order) used to build the sample dataset.
MAX_DATASET_FILES = 100


def _utterance_id(audio_path):
    """Return the file name of *audio_path* without its extension (the transcript key)."""
    return os.path.splitext(os.path.basename(audio_path))[0]


def _encode_labels(text):
    """Tokenize *text* into a fixed-length label tensor for seq2seq training.

    Padding positions (attention mask == 0) are replaced by -100, the default
    ``ignore_index`` of the cross-entropy loss used by Hugging Face seq2seq
    models, so the model is not trained to predict padding. Masking by the
    attention mask rather than by ``pad_token_id`` keeps a genuine end token
    even if the tokenizer reuses it as its pad token.
    """
    encoded = processor.tokenizer(
        text,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
        return_attention_mask=True,
    )
    return encoded.input_ids[0].masked_fill(encoded.attention_mask[0] == 0, -100)


# Pair each audio file that has a transcript with its features and labels.
dataset = [
    {
        "input_features": load_and_process_audio(path),
        "labels": _encode_labels(transcripts[_utterance_id(path)]),
    }
    for path in audio_files[:MAX_DATASET_FILES]
    if _utterance_id(path) in transcripts
]
# Deterministic 80/20 split in list order (no shuffling).
train_size = int(0.8 * len(dataset))
train_dataset, eval_dataset = dataset[:train_size], dataset[train_size:]
print(f"✅ Dataset Prepared! Training: {len(train_dataset)}, Evaluation: {len(eval_dataset)}")
# ================================
# 6️⃣ Streamlit UI: Fine-Tuning Hyperparameter Selection
# ================================
st.sidebar.title("🔧 Fine-Tuning Hyperparameters")
# NOTE(review): num_epochs, learning_rate and batch_size are not consumed
# anywhere in this file (no training loop runs); wire them up or remove them.
num_epochs = st.sidebar.slider("Epochs", min_value=1, max_value=10, value=3)
learning_rate = st.sidebar.select_slider("Learning Rate", options=[5e-4, 1e-4, 5e-5, 1e-5], value=5e-5)
batch_size = st.sidebar.select_slider("Batch Size", options=[2, 4, 8, 16], value=8)
# Scale (standard deviation) of the Gaussian noise added to uploaded audio; 0 disables it.
attack_strength = st.sidebar.slider("Attack Strength", 0.0, 0.9, 0.1)
# ================================
# 7️⃣ Streamlit ASR Web App (Fast Decoding & Security Features)
# ================================
# Sampling rate handed to the processor for uploaded audio.
TARGET_SAMPLE_RATE = 16000

st.title("🎙️ Speech-to-Text ASR Model with Security Features 🎶")
audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
if audio_file:
    # Write the upload to a uniquely named temp file instead of a shared
    # "temp_audio.wav". Concurrent sessions can no longer overwrite each
    # other's audio, the upload's real extension (mp3/flac/wav) is kept for
    # the decoder, and the file is deleted once it has been read.
    suffix = os.path.splitext(audio_file.name)[1] or ".wav"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(audio_file.getvalue())
        audio_path = tmp.name
    try:
        waveform, sample_rate = torchaudio.load(audio_path)
    finally:
        os.remove(audio_path)

    # torchaudio.load returns (channels, frames); down-mix so a stereo upload
    # is transcribed as one clip rather than split across channels.
    waveform = waveform.to(dtype=torch.float32).mean(dim=0)
    if sample_rate != TARGET_SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=TARGET_SAMPLE_RATE)

    # Simulate an adversarial attack by injecting Gaussian noise scaled by the
    # sidebar "Attack Strength", then clamp back to the [-1, 1] sample range.
    adversarial_waveform = waveform + attack_strength * torch.randn_like(waveform)
    adversarial_waveform = torch.clamp(adversarial_waveform, -1.0, 1.0)

    inputs = processor(
        adversarial_waveform.numpy(),
        sampling_rate=TARGET_SAMPLE_RATE,
        return_tensors="pt",
        return_attention_mask=True,
    )
    input_features = inputs.input_features.to(device)

    # Use the attention mask produced by the processor (when it returns one)
    # rather than an all-ones tensor shaped like the feature tensor, which
    # presumably does not match the mask shape generate() expects.
    generate_kwargs = {}
    attention_mask = inputs.get("attention_mask")
    if attention_mask is not None:
        generate_kwargs["attention_mask"] = attention_mask.to(device)

    with torch.inference_mode():
        generated_ids = model.generate(
            input_features,
            max_length=200,
            num_beams=2,
            do_sample=False,
            use_cache=True,
            **generate_kwargs,
        )
    transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # NOTE: this is not a real detector; it only reflects the slider value.
    if attack_strength > 0.1:
        st.warning("⚠️ Adversarial attack detected! Transcription may be affected.")
    st.success("Secure Transcription:")
    st.write(transcription)