File size: 5,326 Bytes
cd7aa15
fcd8965
cd7aa15
 
 
 
49df20f
723513d
49df20f
cd7aa15
 
 
 
 
49df20f
cd7aa15
f0a5b40
cd7aa15
8d55ac9
723513d
49df20f
551152e
 
 
 
 
723513d
 
 
cd7aa15
 
 
 
f0a5b40
 
 
cd7aa15
098a61e
cd7aa15
49df20f
cd7aa15
3a79217
 
098a61e
cd7aa15
 
 
 
 
fcd8965
 
f0a5b40
49df20f
 
1bb8243
49df20f
 
 
1bb8243
49df20f
 
 
 
f0a5b40
49df20f
 
 
 
 
 
 
 
 
 
 
 
c3f9689
49df20f
 
3a79217
49df20f
f0a5b40
cd7aa15
49df20f
cd7aa15
f0a5b40
 
49df20f
f0a5b40
 
 
 
 
 
 
 
 
49df20f
 
723513d
f0a5b40
 
26ada0a
f0a5b40
 
 
 
3a79217
723513d
f0a5b40
 
 
cd7aa15
49df20f
cd7aa15
 
 
 
 
f0a5b40
49df20f
 
 
 
 
 
 
 
 
 
 
 
cd7aa15
49df20f
cd7aa15
49df20f
f0a5b40
 
 
 
cd7aa15
 
f0a5b40
 
cd7aa15
f0a5b40
 
8dd61a6
f0a5b40
49df20f
8d55ac9
49df20f
f0a5b40
49df20f
 
 
 
 
 
 
 
f0a5b40
 
cd7aa15
f0a5b40
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import os
import tarfile
import tempfile

import torch
import torchaudio
import numpy as np
import streamlit as st
import matplotlib.pyplot as plt
from huggingface_hub import login
from datasets import load_dataset, DatasetDict
from transformers import (
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
)

# ================================
# 1️⃣ Authenticate with Hugging Face Hub (Securely)
# ================================
# The token is read from the `hf_token` environment variable so it never has
# to be hard-coded. Surrounding whitespace (e.g. a trailing newline pasted into
# a secrets UI) is stripped, and a missing variable is normalised to "".
HF_TOKEN = (os.getenv("hf_token") or "").strip()

# Fail fast for both "variable not set" and "variable set but blank": a blank
# token would otherwise be passed to `login` and fail there with a less
# helpful error.
if not HF_TOKEN:
    raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")

login(token=HF_TOKEN)

# ================================
# 2️⃣ Load Model & Processor
# ================================
MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"
device = "cuda" if torch.cuda.is_available() else "cpu"


@st.cache_resource
def _load_model_and_processor(model_name, target_device):
    """Load the processor and speech seq2seq model, cached across reruns.

    Streamlit re-executes this script from the top on every widget
    interaction. Without caching, the checkpoint would be re-instantiated
    (and moved to the device again) on each rerun. `st.cache_resource`
    keeps a single shared instance per (model_name, target_device) pair and
    returns that same object, so in-place changes to the model (such as
    fine-tuning) remain visible on later reruns.

    Args:
        model_name: Hub repository id of the checkpoint to load.
        target_device: Device string the model is moved to ("cuda" or "cpu").

    Returns:
        A `(processor, model)` tuple.
    """
    loaded_processor = AutoProcessor.from_pretrained(model_name)
    loaded_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name)
    loaded_model.to(target_device)
    return loaded_processor, loaded_model


processor, model = _load_model_and_processor(MODEL_NAME, device)
print(f"✅ Model loaded on {device}")

# ================================
# 3️⃣ Load and Prepare Dataset
# ================================
DATASET_TAR_PATH = "dev-clean.tar.gz"
EXTRACT_PATH = "./librispeech_dev_clean"

# NOTE(review): in the visible part of this script nothing reads from
# EXTRACT_PATH afterwards; `load_dataset` below is given a dataset name, not
# this directory. Confirm whether the local archive is still needed.
if not os.path.exists(EXTRACT_PATH):
    print("🔄 Extracting dataset...")
    # Extract into a staging directory and rename only on success, so an
    # interrupted extraction is not mistaken for a complete one on the next run.
    staging_path = EXTRACT_PATH + ".partial"
    with tarfile.open(DATASET_TAR_PATH, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # The "data" filter rejects absolute paths, `..` traversal and
            # special files. It is only available on Python versions that
            # ship tarfile extraction filters.
            tar.extractall(staging_path, filter="data")
        else:
            tar.extractall(staging_path)
    os.rename(staging_path, EXTRACT_PATH)
    print("✅ Extraction complete.")
else:
    print("✅ Dataset already extracted.")

# Load dataset with transcripts
# NOTE(review): the Hub "clean" configuration of librispeech_asr appears to
# expose splits named "train.100", "train.360", "validation" and "test" rather
# than a plain "train" — confirm the split name against the installed
# `datasets` version ("validation" corresponds to the dev-clean archive above).
dataset = load_dataset("librispeech_asr", "clean", split="train")

# Ensure dataset has transcripts
if "text" not in dataset.column_names:
    raise ValueError("❌ Dataset is missing transcription text!")

# Preprocessing Function
def preprocess_data(batch):
    """Turn one dataset example into model inputs and label token ids.

    Reads the audio file at ``batch["file"]`` and the transcript in
    ``batch["text"]``; writes ``batch["input_features"]`` and
    ``batch["labels"]`` and returns the same mapping.

    Args:
        batch: A single example with at least the keys ``"file"`` (path
            loadable by ``torchaudio.load``) and ``"text"`` (transcript).

    Returns:
        The input mapping with ``"input_features"`` (processor output for the
        16 kHz mono waveform) and ``"labels"`` (unpadded token ids) added.
    """
    target_rate = 16000

    # Process audio
    waveform, sample_rate = torchaudio.load(batch["file"])
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_rate)(waveform)

    # Downmix to mono. `squeeze()` only drops a singleton channel axis, so a
    # stereo file would otherwise reach the processor as a 2-D array. For a
    # single-channel file the mean over the channel axis is the signal itself.
    if waveform.dim() > 1:
        waveform = waveform.mean(dim=0)

    batch["input_features"] = processor(
        waveform.numpy(), sampling_rate=target_rate, return_tensors="pt"
    ).input_features[0]

    # Tokenize transcript text.
    # Labels are deliberately left unpadded: padding every example to the
    # tokenizer's maximum length stores real pad-token ids that would count
    # toward the loss. Padding (and masking of the padded positions) is the
    # batch collator's job.
    batch["labels"] = processor.tokenizer(batch["text"], truncation=True).input_ids

    return batch

# Apply preprocessing
# Every example goes through `preprocess_data`; the raw source columns are
# dropped from the mapped dataset.
dataset = dataset.map(
    preprocess_data,
    remove_columns=["file", "audio", "text"],
)

# Split into train & eval
# Order-preserving 80/20 split: the leading 80% of rows train the model and
# the remaining rows are held out for evaluation.
example_count = len(dataset)
train_size = int(0.8 * example_count)
train_dataset = dataset.select(range(0, train_size))
eval_dataset = dataset.select(range(train_size, example_count))

print(f"βœ… Dataset Prepared! Training: {len(train_dataset)}, Evaluation: {len(eval_dataset)}")

# ================================
# 4️⃣ Training Arguments & Trainer
# ================================
# NOTE(review): recent `transformers` releases renamed `evaluation_strategy`
# to `eval_strategy`; confirm which spelling the installed version expects.
training_args = TrainingArguments(
    output_dir="./asr_model_finetuned",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=500,
    save_total_limit=2,
    push_to_hub=True,
    hub_model_id="tahirsher/ASR_Model_for_Transcription_into_Text",
    hub_token=HF_TOKEN,
)


def _collate_speech_batch(features):
    """Collate speech examples into a padded batch of tensors.

    `DataCollatorForSeq2Seq` pads model inputs through the text tokenizer,
    which expects an `input_ids` field; the examples here carry
    `input_features` instead. This collator pads the audio features with the
    processor's feature extractor and the labels with its tokenizer.

    Args:
        features: List of examples, each with "input_features" and "labels".

    Returns:
        The padded batch with "input_features" and "labels" tensors. Padded
        label positions are set to -100 so the loss ignores them.
    """
    audio_inputs = [{"input_features": example["input_features"]} for example in features]
    batch = processor.feature_extractor.pad(audio_inputs, return_tensors="pt")

    label_inputs = [{"input_ids": example["labels"]} for example in features]
    padded_labels = processor.tokenizer.pad(label_inputs, return_tensors="pt")
    labels = padded_labels["input_ids"].masked_fill(
        padded_labels["attention_mask"].ne(1), -100
    )

    # If every sequence already starts with the decoder start token, drop it:
    # the model prepends that token itself when it derives decoder inputs
    # from the labels, so keeping it would duplicate it.
    start_token_id = model.config.decoder_start_token_id
    if start_token_id is not None and bool((labels[:, 0] == start_token_id).all()):
        labels = labels[:, 1:]

    batch["labels"] = labels
    return batch


data_collator = _collate_speech_batch

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)

# ================================
# 5️⃣ Fine-Tuning Execution & Training Stats
# ================================
if st.button("Start Fine-Tuning"):
    with st.spinner("Fine-tuning in progress... Please wait!"):
        trainer.train()
    st.success("✅ Fine-Tuning Completed! Model updated.")

    # Plot Training Loss
    # Only log entries that carry a "loss" value are plotted (evaluation and
    # summary entries do not have one).
    loss_entries = [entry for entry in trainer.state.log_history if "loss" in entry]

    if loss_entries:
        # Use the recorded optimisation step for the x-axis so the "Steps"
        # label is accurate; fall back to the entry position if it is absent.
        steps = [entry.get("step", position) for position, entry in enumerate(loss_entries)]
        losses = [entry["loss"] for entry in loss_entries]

        # Draw on an explicit figure instead of the global pyplot state, and
        # close it after rendering so figures do not pile up across reruns.
        fig, ax = plt.subplots(figsize=(8, 5))
        ax.plot(steps, losses, label="Training Loss", color="blue")
        ax.set_xlabel("Steps")
        ax.set_ylabel("Loss")
        ax.set_title("Training Loss Over Time")
        ax.legend()
        st.pyplot(fig)
        plt.close(fig)
    else:
        # Happens when training ends before the first logging interval.
        st.info("No training-loss values were logged during this run.")

# ================================
# 6️⃣ Streamlit ASR Web App (Proper Decoding)
# ================================
st.title("🎙️ Speech-to-Text ASR Model with Fine-Tuning 🎢")

audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])

if audio_file:
    # Write the upload to a unique temporary file that keeps the original
    # extension: a fixed "temp_audio.wav" mislabels mp3/flac uploads and is
    # shared (and overwritten) by concurrent sessions.
    upload_suffix = os.path.splitext(audio_file.name)[1] or ".wav"
    with tempfile.NamedTemporaryFile(suffix=upload_suffix, delete=False) as tmp_file:
        tmp_file.write(audio_file.read())
        audio_path = tmp_file.name

    try:
        waveform, sample_rate = torchaudio.load(audio_path)
    finally:
        # The decoded waveform is in memory; the file is no longer needed.
        os.remove(audio_path)

    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

    # Downmix to mono: `squeeze()` would leave a stereo upload as a 2-D array.
    if waveform.dim() > 1:
        waveform = waveform.mean(dim=0)

    input_features = processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt").input_features

    input_tensor = input_features.to(device)

    # Fine-tuning (if it ran earlier in this script run) leaves the model in
    # training mode; switch to eval mode so dropout is disabled for inference.
    model.eval()

    # Use `generate()` with plain beam search. Sampling is disabled so the
    # same audio always yields the same transcript.
    # NOTE(review): `max_length=500` may exceed the decoder's maximum target
    # length for some checkpoints — confirm against the model config.
    with torch.no_grad():
        generated_ids = model.generate(
            input_tensor,
            max_length=500,
            num_beams=5,
            do_sample=False,
        )
        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # Display transcription
    st.success("📄 Transcription:")
    st.write(transcription)