File size: 5,978 Bytes
cd7aa15
fcd8965
cd7aa15
 
 
 
723513d
cd7aa15
 
 
 
 
26ada0a
cd7aa15
f0a5b40
cd7aa15
8d55ac9
723513d
26ada0a
551152e
 
 
 
 
723513d
 
 
cd7aa15
 
 
 
f0a5b40
 
 
 
cd7aa15
098a61e
cd7aa15
723513d
cd7aa15
3a79217
 
098a61e
3a79217
cd7aa15
 
 
 
 
fcd8965
 
f0a5b40
3a79217
1bb8243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393feaa
cd7aa15
723513d
cd7aa15
d2d38cf
 
 
cd7aa15
d2d38cf
f0a5b40
cd7aa15
3a79217
7d7504d
d2d38cf
7d7504d
d2d38cf
 
3a79217
f0a5b40
723513d
3a79217
 
 
 
 
f0a5b40
cd7aa15
723513d
cd7aa15
f0a5b40
 
723513d
f0a5b40
 
 
 
 
 
 
 
 
723513d
fc83c0b
723513d
f0a5b40
 
859561d
26ada0a
f0a5b40
859561d
f0a5b40
 
 
3a79217
723513d
f0a5b40
 
 
cd7aa15
723513d
cd7aa15
 
 
 
 
f0a5b40
cd7aa15
723513d
cd7aa15
 
f0a5b40
 
 
 
 
cd7aa15
 
 
f0a5b40
 
cd7aa15
 
f0a5b40
 
3a79217
8dd61a6
f0a5b40
8d55ac9
8dd61a6
 
 
 
8d55ac9
cd7aa15
f0a5b40
8dd61a6
f0a5b40
 
 
 
cd7aa15
f0a5b40
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import os
import shutil
import tarfile
import tempfile

import numpy as np
import streamlit as st
import torch
import torchaudio
from huggingface_hub import login
from transformers import (
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,  # βœ… Fix: Use correct data collator
)

# ================================
# 1️⃣ Authenticate with Hugging Face Hub (Securely)
# ================================
# The token is taken from the process environment (Hugging Face Spaces
# exposes repository Secrets there), so it never has to live in the code.
HF_TOKEN = os.environ.get("HF_TOKEN")

if HF_TOKEN is not None:
    login(token=HF_TOKEN)
else:
    raise ValueError("❌ Hugging Face API token not found. Please set it in Secrets.")

# ================================
# 2️⃣ Load Model & Processor
# ================================
MODEL_NAME = "AqeelShafy7/AudioSangraha-Audio_to_Text"

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


@st.cache_resource
def _load_model_and_processor(model_name, target_device):
    """Load the processor and speech seq2seq model once per server process.

    Streamlit re-executes this script from the top on every widget
    interaction. Loading the model at module level therefore rebuilt it from
    the pretrained checkpoint on each rerun, throwing away any weights that
    the "Start Fine-Tuning" button had updated in an earlier run.
    ``st.cache_resource`` returns the same model object on every rerun (and
    to every session), so in-place fine-tuning stays in effect.

    Args:
        model_name: Hub repo id (or local path) of the checkpoint to load.
        target_device: Torch device string the model is moved to.

    Returns:
        A ``(processor, model)`` tuple.
    """
    loaded_processor = AutoProcessor.from_pretrained(model_name)
    loaded_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name)
    loaded_model.to(target_device)
    return loaded_processor, loaded_model


processor, model = _load_model_and_processor(MODEL_NAME, device)
print(f"✅ Model loaded on {device}")

# ================================
# 3️⃣ Load Dataset (Recursively from Extracted Path)
# ================================
DATASET_TAR_PATH = "dev-clean.tar.gz"
EXTRACT_PATH = "./librispeech_dev_clean"


def _extract_dataset(tar_path, dest_path):
    """Extract the gzip tarball ``tar_path`` so its contents end up in ``dest_path``.

    Members are first extracted into a temporary sibling directory, which is
    renamed to ``dest_path`` only after extraction finished. An interrupted
    run therefore cannot leave a half-populated ``dest_path`` that the
    ``os.path.exists`` check below would later mistake for a complete
    extraction. On failure the staging directory is removed and the
    original exception is re-raised.
    """
    parent = os.path.dirname(os.path.abspath(dest_path))
    staging = tempfile.mkdtemp(prefix=".extracting-", dir=parent)
    try:
        with tarfile.open(tar_path, "r:gz") as tar:
            if hasattr(tarfile, "data_filter"):
                # The "data" filter refuses absolute paths, ".." components and
                # links that would place files outside the staging directory.
                tar.extractall(staging, filter="data")
            else:
                tar.extractall(staging)
        os.replace(staging, dest_path)
    except BaseException:
        shutil.rmtree(staging, ignore_errors=True)
        raise


# Extract dataset if not already extracted
if not os.path.exists(EXTRACT_PATH):
    print("🔄 Extracting dataset...")
    _extract_dataset(DATASET_TAR_PATH, EXTRACT_PATH)
    print("✅ Extraction complete.")
else:
    print("✅ Dataset already extracted.")

# Base directory where audio files are stored
AUDIO_FOLDER = os.path.join(EXTRACT_PATH, "LibriSpeech", "dev-clean")

# Recursively find all `.flac` files inside the dataset directory
def find_audio_files(base_folder):
    """Return the paths of all ``.flac`` files anywhere below ``base_folder``.

    ``os.walk`` yields directories and file names in an arbitrary,
    filesystem-dependent order. The result is sorted so that the file order
    is identical on every machine and run. This keeps any subset or split
    taken from it (e.g. ``audio_files[:100]``) reproducible.

    Args:
        base_folder: Directory to search. A missing directory yields ``[]``,
            because ``os.walk`` produces nothing for it.

    Returns:
        Sorted list of ``os.path.join(root, name)`` paths whose file name
        ends with ``.flac``.
    """
    return sorted(
        os.path.join(root, name)
        for root, _subdirs, names in os.walk(base_folder)
        for name in names
        if name.endswith(".flac")
    )

# Collect every .flac file of the split; with none found there is nothing to
# train on, so stop immediately with a pointer to the expected directory.
audio_files = find_audio_files(AUDIO_FOLDER)

if audio_files:
    print(f"βœ… Found {len(audio_files)} audio files in dataset!")
else:
    raise FileNotFoundError(f"❌ No .flac files found in {AUDIO_FOLDER}. Check dataset structure!")

# ================================
# 4️⃣ Preprocess Dataset (Fixed input_features)
# ================================
def load_and_process_audio(audio_path):
    """Load one audio file and convert it into model input features.

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        The processor's ``input_features`` for this clip with the leading
        batch dimension removed (``.input_features[0]``).
    """
    waveform, sample_rate = torchaudio.load(audio_path)

    # torchaudio.load returns a (channels, frames) tensor. Average
    # multi-channel audio down to one channel. Otherwise a stereo clip is
    # still 2-D after ``squeeze``; the processor can then treat it as a batch
    # of separate clips, and ``[0]`` below would keep only the first channel.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # The processor is told the audio is 16 kHz, so resample only when needed.
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

    return processor(
        waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt"
    ).input_features[0]

# Manually create dataset structure
def _utterance_id(audio_path):
    """Return the LibriSpeech utterance id: the file name without its extension."""
    return os.path.splitext(os.path.basename(audio_path))[0]


def _read_librispeech_transcripts(audio_paths):
    """Map utterance id -> reference transcript for the given ``.flac`` files.

    LibriSpeech stores each chapter's transcripts in a ``*.trans.txt`` file in
    the same directory as the audio. Every line has the form
    ``<utterance-id> <TRANSCRIPT>``.
    """
    transcripts = {}
    for folder in {os.path.dirname(path) for path in audio_paths}:
        for name in os.listdir(folder):
            if not name.endswith(".trans.txt"):
                continue
            with open(os.path.join(folder, name), encoding="utf-8") as fh:
                for line in fh:
                    utt_id, _, text = line.strip().partition(" ")
                    if utt_id:
                        transcripts[utt_id] = text
    return transcripts


def _make_example(audio_path, transcripts):
    """Build one training example: audio features plus tokenized transcript.

    Raises:
        KeyError: if no transcript line exists for ``audio_path``.
    """
    utt_id = _utterance_id(audio_path)
    if utt_id not in transcripts:
        raise KeyError(f"❌ No transcript found for {audio_path}")
    return {
        "input_features": load_and_process_audio(audio_path),
        # Target token ids: the seq2seq loss is computed against these, so
        # they must hold the reference text rather than an empty list.
        "labels": processor.tokenizer(transcripts[utt_id]).input_ids,
    }


selected_audio_files = audio_files[:100]
librispeech_transcripts = _read_librispeech_transcripts(selected_audio_files)
dataset = [_make_example(path, librispeech_transcripts) for path in selected_audio_files]

# Split dataset into train and eval
train_size = int(0.9 * len(dataset))
train_dataset = dataset[:train_size]
eval_dataset = dataset[train_size:]

print(f"✅ Dataset Loaded! Training: {len(train_dataset)}, Evaluation: {len(eval_dataset)}")

# ================================
# 5️⃣ Training Arguments & Trainer
# ================================
training_args = TrainingArguments(
    output_dir="./asr_model_finetuned",
    eval_strategy="epoch",  # Fixed deprecated evaluation_strategy
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=500,
    save_total_limit=2,
    push_to_hub=True,  # Upload saved checkpoints to the Hub repo below
    hub_model_id="tahirsher/ASR_Model_for_Transcription_into_Text",  # Replace with your Hugging Face repo
    hub_token=HF_TOKEN,  # Credentials used for that upload
)


class _SpeechSeq2SeqCollator:
    """Collate ``{"input_features", "labels"}`` speech examples into a batch.

    ``DataCollatorForSeq2Seq`` targets text models. It pads the non-label
    fields through ``tokenizer.pad``, which requires an ``input_ids`` field
    that these speech examples do not have. This collator handles the two
    fields separately:

    * ``input_features`` are padded/stacked by the processor's feature
      extractor.
    * ``labels`` (token id lists) are padded by the tokenizer. Padding
      positions are then set to -100, the label index ignored by the loss.
    """

    def __init__(self, speech_processor, decoder_start_token_id):
        self.speech_processor = speech_processor
        self.decoder_start_token_id = decoder_start_token_id

    def __call__(self, features):
        batch = self.speech_processor.feature_extractor.pad(
            [{"input_features": f["input_features"]} for f in features],
            return_tensors="pt",
        )
        label_batch = self.speech_processor.tokenizer.pad(
            [{"input_ids": f["labels"]} for f in features],
            return_tensors="pt",
        )
        labels = label_batch["input_ids"].masked_fill(label_batch["attention_mask"].ne(1), -100)
        # The model typically prepends the decoder start token itself when it
        # derives decoder inputs from labels; drop it here if the tokenizer
        # already added it, so it is not duplicated.
        if labels.shape[1] > 0 and bool((labels[:, 0] == self.decoder_start_token_id).all()):
            labels = labels[:, 1:]
        batch["labels"] = labels
        return batch


data_collator = _SpeechSeq2SeqCollator(processor, model.config.decoder_start_token_id)

# Define Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)

# ================================
# 6️⃣ Fine-Tuning Execution
# ================================
# Streamlit lays elements out in the order the script creates them, so the
# page title is created first; otherwise it would render below the button.
st.title("🎙️ Speech-to-Text ASR with Fine-Tuning 🎢")

if st.button("Start Fine-Tuning"):
    with st.spinner("Fine-tuning in progress... Please wait!"):
        trainer.train()
    st.success("✅ Fine-Tuning Completed! Model updated.")

# ================================
# 7️⃣ Streamlit ASR Web App
# ================================
# Upload audio file
audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])

if audio_file:
    # Save the upload to a uniquely named temporary file that keeps the
    # original extension. A fixed "temp_audio.wav" path would be shared by
    # concurrent sessions, and it labelled mp3/flac data as .wav, which can
    # confuse audio backends that choose a decoder by extension.
    suffix = os.path.splitext(audio_file.name)[1] or ".wav"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(audio_file.getvalue())
        audio_path = tmp.name
    try:
        waveform, sample_rate = torchaudio.load(audio_path)
    finally:
        os.remove(audio_path)

    # torchaudio returns (channels, frames); average stereo down to mono so
    # the processor sees a single clip rather than one clip per channel.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

    # Convert audio to model input on the model's device.
    input_features = processor(
        waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt"
    ).input_features.to(device)

    # Fine-tuning may have left the model in training mode (dropout active).
    model.eval()

    # A single forward pass fed only the decoder start token scores just the
    # first output position, so argmax over its logits yields one token.
    # ``generate`` runs the full autoregressive decoding loop instead.
    with torch.no_grad():
        predicted_ids = model.generate(input_features)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

    # Display transcription
    st.success("📄 Transcription:")
    st.write(transcription)