File size: 5,326 Bytes
cd7aa15 fcd8965 cd7aa15 49df20f 723513d 49df20f cd7aa15 49df20f cd7aa15 f0a5b40 cd7aa15 8d55ac9 723513d 49df20f 551152e 723513d cd7aa15 f0a5b40 cd7aa15 098a61e cd7aa15 49df20f cd7aa15 3a79217 098a61e cd7aa15 fcd8965 f0a5b40 49df20f 1bb8243 49df20f 1bb8243 49df20f f0a5b40 49df20f c3f9689 49df20f 3a79217 49df20f f0a5b40 cd7aa15 49df20f cd7aa15 f0a5b40 49df20f f0a5b40 49df20f 723513d f0a5b40 26ada0a f0a5b40 3a79217 723513d f0a5b40 cd7aa15 49df20f cd7aa15 f0a5b40 49df20f cd7aa15 49df20f cd7aa15 49df20f f0a5b40 cd7aa15 f0a5b40 cd7aa15 f0a5b40 8dd61a6 f0a5b40 49df20f 8d55ac9 49df20f f0a5b40 49df20f f0a5b40 cd7aa15 f0a5b40 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
import os
import tarfile
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import torchaudio
from datasets import load_dataset, DatasetDict
from huggingface_hub import login
from transformers import (
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
)
# ================================
# 1. Authenticate with Hugging Face Hub (securely)
# ================================
# The token is read from the environment variable "hf_token" so it never has
# to be written into the script itself.
HF_TOKEN = os.getenv("hf_token")
# An empty value is as unusable as a missing one, so reject both up front
# instead of letting login() fail later with a less helpful error.
if not HF_TOKEN:
    raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")
login(token=HF_TOKEN)
# ================================
# 2. Load Model & Processor
# ================================
MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"
processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# The status message must be a single-line literal; the mis-decoded check mark
# had split it over two lines.
print(f"✅ Model loaded on {device}")
# ================================
# 3. Load and Prepare Dataset
# ================================
DATASET_TAR_PATH = "dev-clean.tar.gz"
EXTRACT_PATH = "./librispeech_dev_clean"

# Unpack the local archive once; later runs see the directory and skip this.
# NOTE(review): nothing visible in this script reads EXTRACT_PATH afterwards --
# the examples below come from load_dataset(). Confirm the archive is needed.
if not os.path.exists(EXTRACT_PATH):
    print("Extracting dataset...")
    with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
        tar.extractall(EXTRACT_PATH)
    print("✅ Extraction complete.")
else:
    print("✅ Dataset already extracted.")

# Load dataset with transcripts.
# The "clean" configuration of librispeech_asr exposes the splits "train.100",
# "train.360", "validation" and "test"; there is no plain "train" split, so the
# previous value raised at load time. "validation" is the dev-clean portion,
# which matches the archive named above.
dataset = load_dataset("librispeech_asr", "clean", split="validation")

# Ensure dataset has transcripts
if "text" not in dataset.column_names:
    raise ValueError("❌ Dataset is missing transcription text!")
# Preprocessing function
def preprocess_data(batch):
    """Turn one dataset example into model inputs and label ids.

    Reads the audio file named by ``batch["file"]``, resamples it to 16 kHz,
    runs it through the module-level ``processor`` and tokenizes
    ``batch["text"]``.

    Args:
        batch: A single example (mapping) with at least the keys ``"file"``
            (path to an audio file) and ``"text"`` (its transcript).

    Returns:
        The same mapping with ``"input_features"`` and ``"labels"`` added.
    """
    # Process audio
    waveform, sample_rate = torchaudio.load(batch["file"])
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    # torchaudio.load returns (channels, frames) by default. Averaging the
    # channel axis yields a 1-D signal for mono and multi-channel files alike,
    # whereas squeeze() left multi-channel audio 2-D.
    mono = waveform.mean(dim=0)
    batch["input_features"] = processor(
        mono.numpy(), sampling_rate=16000, return_tensors="pt"
    ).input_features[0]

    # Tokenize transcript text.
    # Labels are deliberately NOT padded here: padding to max_length stores
    # real pad-token ids that are then counted in the training loss. Leaving
    # the sequences at their natural length lets the data collator pad each
    # batch and mask the padding with the ignore index.
    batch["labels"] = processor.tokenizer(
        batch["text"], truncation=True, return_tensors="pt"
    ).input_ids[0]
    return batch
# Apply preprocessing; the raw columns are dropped once features exist.
dataset = dataset.map(preprocess_data, remove_columns=["file", "audio", "text"])

# Split into train & eval: first 80% of the rows train, the remainder evaluates.
total_examples = len(dataset)
train_size = int(0.8 * total_examples)
train_dataset = dataset.select(range(train_size))
eval_dataset = dataset.select(range(train_size, total_examples))
# The message must be a single-line literal; the mis-decoded check mark had
# split it over two lines.
print(f"✅ Dataset Prepared! Training: {len(train_dataset)}, Evaluation: {len(eval_dataset)}")
# ================================
# 4. Training Arguments & Trainer
# ================================
training_args = TrainingArguments(
    output_dir="./asr_model_finetuned",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=500,
    save_total_limit=2,
    push_to_hub=True,
    hub_model_id="tahirsher/ASR_Model_for_Transcription_into_Text",
    hub_token=HF_TOKEN,
)


class _SpeechSeq2SeqCollator:
    """Batch speech examples by padding audio features and labels separately.

    DataCollatorForSeq2Seq hands the non-label fields to ``tokenizer.pad``,
    which requires ``input_ids``. The examples produced here only carry
    ``input_features``, so that call fails. This collator pads the audio side
    with the feature extractor and the label side with the tokenizer.
    """

    def __init__(self, processor):
        # Processor exposing ``feature_extractor`` and ``tokenizer``.
        self.processor = processor

    def __call__(self, features):
        """Collate a list of examples into one batch of tensors.

        Args:
            features: Examples, each with ``"input_features"`` and ``"labels"``.

        Returns:
            A batch with padded ``input_features`` and ``labels``; label
            positions added as padding are set to -100 so the loss skips them.
        """
        audio_inputs = [{"input_features": example["input_features"]} for example in features]
        batch = self.processor.feature_extractor.pad(audio_inputs, return_tensors="pt")

        label_inputs = [{"input_ids": example["labels"]} for example in features]
        padded_labels = self.processor.tokenizer.pad(label_inputs, return_tensors="pt")
        # attention_mask is 0 exactly where the tokenizer added padding.
        batch["labels"] = padded_labels["input_ids"].masked_fill(
            padded_labels["attention_mask"].ne(1), -100
        )
        return batch


data_collator = _SpeechSeq2SeqCollator(processor)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)
# ================================
# 5. Fine-Tuning Execution & Training Stats
# ================================
if st.button("Start Fine-Tuning"):
    with st.spinner("Fine-tuning in progress... Please wait!"):
        trainer.train()
    st.success("✅ Fine-Tuning Completed! Model updated.")

    # Plot training loss. Only log entries that carry a "loss" key are
    # training-loss records.
    loss_entries = [entry for entry in trainer.state.log_history if "loss" in entry]
    losses = [entry["loss"] for entry in loss_entries]
    # Plot against the recorded step so the "Steps" axis is truthful; fall
    # back to the entry index if a record carries no step.
    steps = [entry.get("step", index) for index, entry in enumerate(loss_entries)]

    # Build an explicit Figure rather than drawing on pyplot's global state,
    # and hand that Figure (not the pyplot module) to Streamlit.
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(steps, losses, label="Training Loss", color="blue")
    ax.set_xlabel("Steps")
    ax.set_ylabel("Loss")
    ax.set_title("Training Loss Over Time")
    ax.legend()
    st.pyplot(fig)
    # Release the figure so repeated runs do not accumulate open figures.
    plt.close(fig)
# ================================
# 6. Streamlit ASR Web App (Proper Decoding)
# ================================
st.title("🎙️ Speech-to-Text ASR Model with Fine-Tuning 🎶")
audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])

if audio_file:
    # Write the upload to a private temporary file that keeps the original
    # extension. A fixed "temp_audio.wav" mislabels mp3/flac uploads, is
    # shared by every session and is never cleaned up.
    suffix = os.path.splitext(audio_file.name)[1] or ".wav"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(audio_file.read())
        audio_path = tmp.name
    try:
        waveform, sample_rate = torchaudio.load(audio_path)
    finally:
        # The decoded waveform is in memory; the file is no longer needed.
        os.remove(audio_path)

    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    # Downmix to a single channel; squeeze() left stereo uploads 2-D.
    mono = waveform.mean(dim=0)
    input_features = processor(mono.numpy(), sampling_rate=16000, return_tensors="pt").input_features
    input_tensor = input_features.to(device)

    # Use generate() for proper transcription. Decoding is deterministic beam
    # search: sampling would make the same audio yield different transcripts.
    with torch.no_grad():
        generated_ids = model.generate(
            input_tensor,
            max_length=500,
            num_beams=5,
            do_sample=False,
        )
    transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # Display transcription
    st.success("Transcription:")
    st.write(transcription)
|