File size: 3,835 Bytes
f0a5b40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a68fd86
 
f0a5b40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

# Hugging Face checkpoint used for both the processor and the model; keeping it
# in one place guarantees the two are always loaded from the same repository.
MODEL_ID = "AqeelShafy7/AudioSangraha-Audio_to_Text"

# Load Processor & Model
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_ID)

# Move model to GPU if available
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"Model loaded on {device}")

from datasets import load_dataset
import torchaudio
import torch

# LibriSpeech's "clean" configuration exposes the splits "train.100",
# "train.360", "validation" and "test" -- there is no plain "train" split, so
# asking for one fails. Use the 100-hour clean training subset.
# trust_remote_code=True lets `datasets` run the dataset's Hub-hosted loader.
dataset = load_dataset("librispeech_asr", "clean", split="train.100", trust_remote_code=True)


# Function to turn decoded audio + transcript into model inputs
def preprocess_audio(batch):
    """Convert one dataset example into model inputs and label token ids.

    Takes the decoded audio in ``batch["audio"]``, down-mixes it to mono,
    resamples it to 16 kHz when its rate differs, and stores the feature
    extractor's output under ``batch["input_values"]``. ``batch["text"]`` is
    tokenized into ``batch["labels"]``. The mutated batch is returned, as
    ``Dataset.map`` expects.
    """
    audio = batch["audio"]
    # `datasets` has already decoded the Audio column into samples plus their
    # rate; audio["path"] may not point at a loadable file on disk (it can be
    # a bare name inside the downloaded archive), so don't re-read it.
    waveform = torch.as_tensor(audio["array"], dtype=torch.float32)
    sample_rate = audio["sampling_rate"]

    # Average channels if a multi-channel array slips through
    # (assumes channels-first layout -- TODO confirm for non-mono sources).
    if waveform.dim() > 1:
        waveform = waveform.mean(dim=0)

    # Resample to 16 kHz (ASR feature extractors usually require it), but only
    # when the source rate actually differs.
    if sample_rate != 16000:
        waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=16000)

    features = processor(waveform.numpy(), sampling_rate=16000)
    # Wav2Vec2-style extractors emit "input_values" while Whisper/Speech2Text
    # style ones emit "input_features"; read whichever this processor produces
    # rather than assuming `.input_values` (AttributeError otherwise).
    # NOTE(review): the column keeps the name "input_values" for the rest of
    # the script; the model may expect a different forward() keyword.
    input_key = processor.feature_extractor.model_input_names[0]
    batch["input_values"] = features[input_key][0]
    batch["labels"] = processor.tokenizer(batch["text"]).input_ids
    return batch

# Run preprocess_audio over every example; the raw "audio" column is dropped
# from the result once features and labels have been derived from it.
dataset = dataset.map(function=preprocess_audio, remove_columns=["audio"])

from transformers import TrainingArguments, Trainer, DataCollatorForSeq2Seq

# Define Training Arguments.
# No eval_dataset is passed to the Trainer below, so evaluation is left at its
# default ("no"): requesting per-epoch evaluation without an eval set makes
# Trainer raise at construction time (and the old `evaluation_strategy`
# keyword has been renamed to `eval_strategy` in newer releases).
training_args = TrainingArguments(
    output_dir="./asr_model_finetuned",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=500,
    save_total_limit=2,
    # Keep all columns so the collator still receives "input_values" even when
    # the model's forward() names its audio argument differently (e.g.
    # `input_features`); otherwise Trainer would drop the column silently.
    remove_unused_columns=False,
    push_to_hub=True,  # Enable uploading to Hugging Face Hub
)


def _collate_speech_batch(features):
    """Pad a list of preprocessed examples into one model-ready batch.

    Audio features and label ids have unrelated lengths, so they are padded
    separately: features with the processor's feature extractor, labels with
    its tokenizer. Label padding is replaced by -100 so the loss ignores it.
    Audio is returned under the feature extractor's primary input name, which
    is the keyword the model's forward() takes.
    """
    input_key = processor.feature_extractor.model_input_names[0]
    batch = processor.feature_extractor.pad(
        [{input_key: example["input_values"]} for example in features],
        return_tensors="pt",
    )

    labels_batch = processor.tokenizer.pad(
        [{"input_ids": example["labels"]} for example in features],
        return_tensors="pt",
    )
    labels = labels_batch["input_ids"].masked_fill(labels_batch["attention_mask"].ne(1), -100)

    # Seq2seq models build decoder inputs by shifting labels right and
    # prepending decoder_start_token_id; if the tokenizer already emitted that
    # token, drop it so it is not duplicated.
    start_id = getattr(model.config, "decoder_start_token_id", None)
    if start_id is not None and bool((labels[:, 0] == start_id).all()):
        labels = labels[:, 1:]

    batch["labels"] = labels
    return batch


# DataCollatorForSeq2Seq pads through tokenizer.pad, which rejects examples
# that carry audio features instead of input_ids; use the speech collator.
data_collator = _collate_speech_batch

# Define Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    eval_dataset=None,  # We use only training data here
    tokenizer=processor.feature_extractor,
    data_collator=data_collator,
)

# Start Fine-Tuning
trainer.train()

# Streamlit front end: transcribe an uploaded clip with the fine-tuned model
# and optionally take one fine-tuning step on a user-supplied correction.
# NOTE(review): Streamlit re-executes this entire file on every widget
# interaction, which reloads the model and re-runs trainer.train() above, so
# any update made here is lost on the next rerun. The setup/training code
# should be cached (e.g. st.cache_resource) or moved to a separate script.
import os
import tempfile

import streamlit as st
import soundfile as sf
import numpy as np

st.title("🎙️ Automatic Speech Recognition with Fine-Tuning 🎶")

# Upload audio file
audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])

if audio_file:
    # Write the upload to a unique temporary file that keeps its original
    # extension, instead of a shared "temp_audio.wav" that concurrent sessions
    # would overwrite and that mislabels mp3/flac uploads.
    suffix = os.path.splitext(audio_file.name)[1] or ".wav"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(audio_file.getvalue())
        tmp_path = tmp.name
    try:
        waveform, sample_rate = torchaudio.load(tmp_path)
    finally:
        os.remove(tmp_path)

    # torchaudio.load returns (channels, time); average the channels to mono.
    waveform = waveform.mean(dim=0)

    # Resample to 16kHz when needed
    if sample_rate != 16000:
        waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=16000)

    # Convert to model input, using the feature extractor's own output name
    # ("input_values" or "input_features" depending on the model family).
    input_key = processor.feature_extractor.model_input_names[0]
    input_values = processor(waveform.numpy(), sampling_rate=16000)[input_key][0]
    input_tensor = torch.from_numpy(np.asarray(input_values, dtype=np.float32)).unsqueeze(0).to(device)

    # Perform transcription. A seq2seq model decodes autoregressively, so use
    # generate(); a bare forward() call needs decoder inputs and its logits
    # are not a transcription. eval() disables dropout left on by training.
    model.eval()
    with torch.no_grad():
        predicted_ids = model.generate(input_tensor)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

    # Display transcription
    st.success("Transcription:")
    st.write(transcription)

    # Fine-tune with user input
    user_correction = st.text_area("Correct the transcription (if needed):")

    if st.button("Fine-Tune Model"):
        if user_correction.strip():
            # Convert correction to label ids; drop a leading decoder start
            # token because the model prepends it when shifting labels right.
            corrected_input = processor.tokenizer(user_correction.strip()).input_ids
            start_id = getattr(model.config, "decoder_start_token_id", None)
            if start_id is not None and corrected_input and corrected_input[0] == start_id:
                corrected_input = corrected_input[1:]
            labels = torch.tensor([corrected_input], device=device)

            # Take exactly one gradient step on this (audio, correction) pair.
            # Calling trainer.train() here would re-run the full configured
            # training on the original dataset and ignore the correction.
            model.train()
            optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate)
            loss = model(**{input_key: input_tensor, "labels": labels}).loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            model.eval()

            st.success("Model fine-tuned successfully! Try another audio file.")
        else:
            st.warning("Enter a corrected transcription before fine-tuning.")