import os
import tarfile
import torch
import torchaudio
import numpy as np
import streamlit as st
from datasets import load_dataset
from transformers import (
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
)
# ================================
# 1️⃣ Load Model & Processor
# ================================
MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"
# Load ASR model and processor
processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)
# Move model to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"β
Model loaded on {device}")
# ================================
# 2️⃣ Load Dataset (LibriSpeech) from Extracted Path
# ================================
DATASET_TAR_PATH = "dev-clean.tar.gz" # Uploaded dataset in your Hugging Face Space
EXTRACT_PATH = "./librispeech_dev_clean" # Extracted dataset folder
# Extract dataset only if not already extracted
if not os.path.exists(EXTRACT_PATH):
    print("Extracting dataset...")
    with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
        tar.extractall(EXTRACT_PATH)
    print("✅ Extraction complete.")
else:
    print("✅ Dataset already extracted.")
# Load dataset from extracted folder
dataset = load_dataset("librispeech_asr", data_dir=EXTRACT_PATH, split="train", trust_remote_code=True)
print(f"✅ Dataset Loaded Successfully! Size: {len(dataset)}")
# ================================
# 3️⃣ Preprocess Dataset
# ================================
def preprocess_audio(batch):
    """Converts raw audio to a model-compatible format."""
    audio = batch["audio"]
    waveform, sample_rate = torchaudio.load(audio["path"])
    # Resample to 16 kHz (the rate most ASR models expect)
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    # Convert to model input format
    # NOTE: CTC-style processors expose `input_values`; Whisper-style (seq2seq) processors
    # expose `input_features` -- adjust the attribute below to match MODEL_NAME's processor.
    batch["input_values"] = processor(waveform.squeeze().numpy(), sampling_rate=16000).input_values[0]
    batch["labels"] = processor.tokenizer(batch["text"]).input_ids
    return batch
# Apply preprocessing
dataset = dataset.map(preprocess_audio, remove_columns=["audio"])
print("✅ Dataset Preprocessed! Ready for Fine-Tuning.")
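# --- Optional: carve out a small validation split (sketch) -------------------------
# eval_dataset is left as None in the Trainer below; to enable epoch-level evaluation
# (and WER tracking), a held-out split could be created roughly like this:
#     split = dataset.train_test_split(test_size=0.1, seed=42)
#     train_ds, eval_ds = split["train"], split["test"]
# and then passed as train_dataset=train_ds, eval_dataset=eval_ds.
# ------------------------------------------------------------------------------------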
# ================================
# 4️⃣ Training Arguments & Trainer
# ================================
training_args = TrainingArguments(
    output_dir="./asr_model_finetuned",
    evaluation_strategy="no",  # no eval_dataset is passed to the Trainer below
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=500,
    save_total_limit=2,
    push_to_hub=True,  # requires a valid Hugging Face token in the Space
    save_on_each_node=True,  # save checkpoints on every node in multi-node runs
    # To keep the best checkpoint by WER, supply an eval_dataset and a compute_metrics
    # function, then re-enable: evaluation_strategy="epoch", metric_for_best_model="wer",
    # greater_is_better=False, load_best_model_at_end=True
)
# Data collator (for dynamic padding)
data_collator = DataCollatorForSeq2Seq(processor.tokenizer, model=model)
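# --- Optional: speech-aware data collator (sketch) ---------------------------------
# DataCollatorForSeq2Seq pads tokenized text, not raw audio arrays. If padding fails on
# the audio column, a collator along these lines (adapted from common speech seq2seq
# fine-tuning recipes; the "input_values" field name is an assumption matching
# preprocess_audio above) can be swapped in for `data_collator`.
from dataclasses import dataclass
from typing import Any, Dict, List

@dataclass
class SpeechDataCollator:
    processor: Any

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        # Pad the audio features with the feature extractor
        audio_inputs = [{"input_values": f["input_values"]} for f in features]
        batch = self.processor.feature_extractor.pad(audio_inputs, return_tensors="pt")
        # Pad the tokenized transcripts with the tokenizer
        label_inputs = [{"input_ids": f["labels"]} for f in features]
        labels_batch = self.processor.tokenizer.pad(label_inputs, return_tensors="pt")
        # Mask padding positions with -100 so they are ignored by the loss
        batch["labels"] = labels_batch["input_ids"].masked_fill(
            labels_batch["attention_mask"].ne(1), -100
        )
        return batch
# ------------------------------------------------------------------------------------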
# Define Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    eval_dataset=None,  # No validation dataset for now
    tokenizer=processor.feature_extractor,
    data_collator=data_collator,
)
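# --- Optional: WER metric (sketch) --------------------------------------------------
# To track word error rate during evaluation, a compute_metrics function along these
# lines could be passed to the Trainer (assumes the `evaluate` and `jiwer` packages are
# installed and that predictions arrive as token IDs, e.g. from a Seq2SeqTrainer with
# predict_with_generate=True). Not wired in here because no eval_dataset is defined.
#
#     import evaluate
#     wer_metric = evaluate.load("wer")
#
#     def compute_metrics(eval_pred):
#         pred_ids, label_ids = eval_pred.predictions, eval_pred.label_ids
#         # Restore the padding token where labels were masked with -100 before decoding
#         label_ids = np.where(label_ids != -100, label_ids, processor.tokenizer.pad_token_id)
#         pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
#         label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
#         return {"wer": wer_metric.compute(predictions=pred_str, references=label_str)}
# ------------------------------------------------------------------------------------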
# ================================
# 5️⃣ Fine-Tuning Execution
# ================================
if st.button("Start Fine-Tuning"):
with st.spinner("Fine-tuning in progress... Please wait!"):
trainer.train()
st.success("β
Fine-Tuning Completed! Model updated.")
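# --- Optional: persist the fine-tuned weights (sketch) ------------------------------
# After training finishes, the updated model and processor can be saved locally (and
# pushed to the Hub if push_to_hub is configured), e.g.:
#     trainer.save_model("./asr_model_finetuned")
#     processor.save_pretrained("./asr_model_finetuned")
# ------------------------------------------------------------------------------------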
# ================================
# 6️⃣ Streamlit ASR Web App
# ================================
st.title("ποΈ Speech-to-Text ASR with Fine-Tuning πΆ")
# Upload audio file
audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
if audio_file:
    # Save uploaded file temporarily
    audio_path = "temp_audio.wav"
    with open(audio_path, "wb") as f:
        f.write(audio_file.read())
    # Load and process audio
    waveform, sample_rate = torchaudio.load(audio_path)
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    # Convert audio to model input (see the NOTE in preprocess_audio about the feature name)
    input_values = processor(waveform.squeeze().numpy(), sampling_rate=16000).input_values[0]
    # Perform ASR inference
    # The checkpoint is loaded with AutoModelForSpeechSeq2Seq (encoder-decoder), so decode
    # with generate() rather than taking an argmax over CTC-style logits.
    with torch.no_grad():
        input_tensor = torch.tensor([input_values]).to(device)
        predicted_ids = model.generate(input_tensor)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    # Display transcription
    st.success("Transcription:")
    st.write(transcription)
    # ================================
    # 7️⃣ Fine-Tune Model with User Correction
    # ================================
    user_correction = st.text_area("Correct the transcription (if needed):", transcription)
    if st.button("Fine-Tune with Correction"):
        if user_correction:
            corrected_input = processor.tokenizer(user_correction).input_ids
            # Dynamically add the new example to the dataset
            dataset = dataset.add_item({"input_values": input_values, "labels": corrected_input})
            trainer.train_dataset = dataset  # point the Trainer at the updated dataset
            # Perform a quick re-training pass (1 epoch)
            trainer.args.num_train_epochs = 1
            trainer.train()
            st.success("✅ Model fine-tuned with new correction! Try another audio file.")