File size: 5,943 Bytes
cd7aa15 fcd8965 cd7aa15 49df20f 723513d cd7aa15 49df20f cd7aa15 f6dc6c7 f0a5b40 cd7aa15 8d55ac9 723513d 49df20f f6dc6c7 551152e f6dc6c7 551152e 723513d cd7aa15 f0a5b40 f6dc6c7 098a61e cd7aa15 3a9d859 cd7aa15 3a79217 098a61e cd7aa15 f6dc6c7 cd7aa15 f6dc6c7 f0a5b40 771c2e9 1bb8243 771c2e9 f6dc6c7 1bb8243 771c2e9 f6dc6c7 771c2e9 3a9d859 941924a 3a9d859 f6dc6c7 3a9d859 f6dc6c7 3a9d859 f6dc6c7 3a9d859 f6dc6c7 3a9d859 f6dc6c7 941924a f6dc6c7 3a9d859 f6dc6c7 f0a5b40 cd7aa15 f6dc6c7 cd7aa15 f6dc6c7 f0a5b40 cd7aa15 f6dc6c7 cd7aa15 f6dc6c7 f0a5b40 cd7aa15 f0a5b40 cd7aa15 f0a5b40 f6dc6c7 a312467 f6dc6c7 49df20f f6dc6c7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import os
import tarfile
import torch
import torchaudio
import numpy as np
import streamlit as st
import matplotlib.pyplot as plt
from huggingface_hub import login
from transformers import (
AutoProcessor,
AutoModelForSpeechSeq2Seq,
TrainingArguments,
Trainer,
DataCollatorForSeq2Seq,
)
from cryptography.fernet import Fernet
# ================================
# 1️⃣ Authenticate with Hugging Face Hub (Securely)
# ================================
# The token comes from the environment (e.g. a deployment secret) so it is never
# hard-coded. "hf_token" is the name this app has always used; "HF_TOKEN" is the
# standard variable name used by huggingface_hub and is accepted as a fallback.
HF_TOKEN = os.getenv("hf_token") or os.getenv("HF_TOKEN")
# `not HF_TOKEN` also rejects an empty string, which would otherwise only fail
# later inside login().
if not HF_TOKEN:
    raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")
login(token=HF_TOKEN)
# ================================
# 2️⃣ Load Model & Processor
# ================================
MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"


@st.cache_resource
def _load_model_and_processor(model_name):
    """Load the processor and speech seq2seq model, moving the model to GPU if present.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` keeps one loaded copy per server process instead of
    reloading the weights on each rerun.

    Args:
        model_name: Hugging Face Hub model id passed to ``from_pretrained``.

    Returns:
        Tuple ``(processor, model, device)``, where ``device`` is ``"cuda"`` or ``"cpu"``.
    """
    loaded_processor = AutoProcessor.from_pretrained(model_name)
    loaded_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name)
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    loaded_model.to(target_device)
    return loaded_processor, loaded_model, target_device


processor, model, device = _load_model_and_processor(MODEL_NAME)
print(f"✅ Model loaded on {device}")
# ================================
# 3️⃣ Load Dataset (From Extracted Folder)
# ================================
DATASET_TAR_PATH = "dev-clean.tar.gz"
EXTRACT_PATH = "./librispeech_dev_clean"

if not os.path.exists(EXTRACT_PATH):
    print("π Extracting dataset...")
    # Extract into a side directory and rename it into place only after extraction
    # finished, so an interrupted run never leaves a half-filled EXTRACT_PATH that
    # the exists() check above would later treat as complete.
    partial_path = EXTRACT_PATH + ".partial"
    with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # The "data" filter refuses absolute paths, ".." components, links that
            # escape the destination and special files (Python 3.12+ and the
            # security backports to earlier releases).
            tar.extractall(partial_path, filter="data")
        else:
            tar.extractall(partial_path)
    os.replace(partial_path, EXTRACT_PATH)
    print("✅ Extraction complete.")
else:
    print("✅ Dataset already extracted.")

# Root the rest of the script searches for .flac audio and .txt transcripts; assumes
# the archive contains a top-level LibriSpeech/dev-clean/ directory.
AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")
def find_audio_files(base_folder, extensions=(".flac",)):
    """Recursively collect audio file paths under *base_folder*.

    Args:
        base_folder: Directory to walk. A missing directory yields an empty list.
        extensions: A filename suffix or tuple of suffixes to match, compared
            case-sensitively with ``str.endswith``. Defaults to ``(".flac",)``.

    Returns:
        Full paths of matching files, sorted. ``os.walk`` order depends on the
        filesystem, so sorting keeps downstream selections (first N files,
        train/eval split) reproducible across runs and machines.
    """
    return sorted(
        os.path.join(root, name)
        for root, _, names in os.walk(base_folder)
        for name in names
        if name.endswith(extensions)
    )
# Discover every .flac file under the extracted dataset; fail fast if none exist,
# since everything below depends on them.
audio_files = find_audio_files(AUDIO_FOLDER)
if not audio_files:
    raise FileNotFoundError(f"❌ No .flac files found in {AUDIO_FOLDER}. Check dataset structure!")
print(f"✅ Found {len(audio_files)} audio files in dataset!")
# ================================
# 4️⃣ Load Transcripts
# ================================
def load_transcripts(base_folder=None):
    """Build a mapping of utterance id -> transcript text from ``*.txt`` files.

    Every ``.txt`` file under *base_folder* is read as UTF-8. Each line is expected
    to hold ``<utterance_id> <transcript>``; lines without both parts are skipped.
    If the same id appears more than once, the last occurrence seen wins.

    Args:
        base_folder: Directory to search recursively. Defaults to the module-level
            ``AUDIO_FOLDER`` (the original behaviour).

    Returns:
        dict mapping utterance id to its transcript, stripped of surrounding whitespace.
    """
    if base_folder is None:
        base_folder = AUDIO_FOLDER
    transcript_dict = {}
    for root, _, files in os.walk(base_folder):
        for file in files:
            if not file.endswith(".txt"):
                continue
            with open(os.path.join(root, file), "r", encoding="utf-8") as f:
                for line in f:
                    # split(maxsplit=1) tolerates tabs / repeated spaces between
                    # the id and the text, unlike split(" ", 1).
                    parts = line.strip().split(maxsplit=1)
                    if len(parts) == 2:
                        file_id, text = parts
                        transcript_dict[file_id] = text
    return transcript_dict
# Load all transcripts under AUDIO_FOLDER; an empty result means the dataset layout
# is not what this script expects, so stop early.
transcripts = load_transcripts()
if not transcripts:
    raise FileNotFoundError("❌ No transcripts found! Check dataset structure.")
print(f"✅ Loaded {len(transcripts)} transcripts.")
# ================================
# 5️⃣ Preprocess Dataset (Fixing `input_ids` issue)
# ================================
def load_and_process_audio(audio_path):
    """Load one audio file and return the module-level processor's input features.

    The audio is downmixed to mono and resampled to 16 kHz before feature extraction.

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        The ``input_features`` tensor for this single clip (batch dimension removed).
    """
    waveform, sample_rate = torchaudio.load(audio_path)
    # torchaudio.load returns (channels, frames). Averaging the channels gives a 1-D
    # mono signal. The previous squeeze() left stereo input 2-D. The processor most
    # likely interpreted that as a batch of clips, and [0] silently kept only channel 0.
    waveform = waveform.mean(dim=0)
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    features = processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt").input_features
    return features[0]
MAX_EXAMPLES = 100  # only the first N discovered audio files are preprocessed
TRAIN_FRACTION = 0.8  # share of examples used for training; the rest is for evaluation


@st.cache_resource
def _build_dataset(max_examples=MAX_EXAMPLES):
    """Turn the first *max_examples* transcribed audio files into model examples.

    Reads the module-level ``audio_files`` and ``transcripts``. Files with no
    matching transcript are skipped. Cached with ``st.cache_resource`` because
    Streamlit re-executes the script on every widget interaction, and feature
    extraction for every file on each rerun is wasted work.

    Returns:
        List of ``{"input_features": Tensor, "labels": Tensor}`` dicts in
        ``audio_files`` order.
    """
    examples = []
    for audio_path in audio_files[:max_examples]:
        # splitext strips only the real extension (".replace" would hit any ".flac").
        file_id = os.path.splitext(os.path.basename(audio_path))[0]
        if file_id not in transcripts:
            continue
        input_features = load_and_process_audio(audio_path)
        # NOTE(review): padding="max_length" with no explicit max_length pads to the
        # tokenizer's model_max_length. Padded positions hold pad_token_id rather than
        # -100, so they would count toward the loss if fed to a Trainer. Confirm
        # this before fine-tuning.
        labels = processor.tokenizer(
            transcripts[file_id], padding="max_length", truncation=True, return_tensors="pt"
        ).input_ids[0]
        examples.append({"input_features": input_features, "labels": labels})
    return examples


dataset = _build_dataset()
train_size = int(TRAIN_FRACTION * len(dataset))
train_dataset = dataset[:train_size]
eval_dataset = dataset[train_size:]
print(f"✅ Dataset Prepared! Training: {len(train_dataset)}, Evaluation: {len(eval_dataset)}")
# ================================
# 6️⃣ Streamlit UI: Fine-Tuning Hyperparameter Selection
# ================================
# Sidebar controls for fine-tuning settings.
# NOTE(review): num_epochs / learning_rate / batch_size are not consumed anywhere in
# the code shown (Trainer is imported but never constructed). Confirm whether they
# are used.
sidebar = st.sidebar
sidebar.title("π§ Fine-Tuning Hyperparameters")
num_epochs = sidebar.slider("Epochs", value=3, min_value=1, max_value=10)
learning_rate = sidebar.select_slider(
    "Learning Rate",
    value=5e-5,
    options=[5e-4, 1e-4, 5e-5, 1e-5],
)
batch_size = sidebar.select_slider(
    "Batch Size",
    value=8,
    options=[2, 4, 8, 16],
)
# ================================
# 7️⃣ Streamlit ASR Web App (Fast Decoding & Adversarial Attack Detection)
# ================================
st.title("ποΈ Speech-to-Text ASR Model with Security Features πΆ")
audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
if audio_file:
    import tempfile

    # Write the upload to a unique temp file rather than a shared "temp_audio.wav"
    # (concurrent sessions would overwrite each other). Keep the original
    # extension so the audio backend is not told an mp3/flac is a .wav.
    suffix = os.path.splitext(audio_file.name)[1] or ".wav"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(audio_file.getvalue())
        audio_path = tmp.name
    try:
        waveform, sample_rate = torchaudio.load(audio_path)
    finally:
        os.remove(audio_path)

    # torchaudio.load returns (channels, frames). Downmix to a single channel so
    # stereo uploads are not split into a batch of two clips further down.
    waveform = waveform.mean(dim=0, keepdim=True)
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

    # Simulate an adversarial attack by adding Gaussian noise scaled by the chosen
    # strength. (The previous st.sidebar.slider call passed these six numbers
    # positionally as min/max/value/step/format/key, which errors.)
    attack_strength = st.sidebar.select_slider(
        "Attack Strength", options=[0.0, 0.1, 0.2, 0.5, 0.7, 0.9], value=0.0
    )
    adversarial_waveform = waveform + (attack_strength * torch.randn_like(waveform))
    adversarial_waveform = torch.clamp(adversarial_waveform, -1.0, 1.0)

    input_features = processor(
        adversarial_waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt"
    ).input_features.to(device)
    with torch.inference_mode():
        generated_ids = model.generate(
            input_features,
            max_length=200,
            num_beams=2,
            do_sample=False,
            use_cache=True,
            # NOTE(review): this all-ones mask has the full shape of input_features
            # (including the feature axis), not (batch, frames). Confirm the model
            # ignores or accepts it.
            attention_mask=torch.ones(input_features.shape, dtype=torch.long).to(device),
        )
    transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # NOTE: this "detection" does not analyse the audio. It only reports when the
    # user-selected noise level is above 0.1.
    if attack_strength > 0.1:
        st.warning("⚠️ Adversarial attack detected! Transcription secured.")
    st.success("π Secure Transcription:")
    st.write(transcription)
|