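# Gradio demo for BornoNet: classify Bengali character pronunciations from uploaded
# audio with a fine-tuned wav2vec 2.0 encoder and a linear classification head.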
import gradio as gr
import librosa
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from transformers import Wav2Vec2Model, Wav2Vec2Processor


class Wav2Vec2Classifier(torch.nn.Module):
    """wav2vec 2.0 encoder with mean pooling and a linear classification head."""

    def __init__(self, num_classes):
        super().__init__()
        self.wav2vec2 = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
        self.dropout = torch.nn.Dropout(0.3)
        self.classifier = torch.nn.Linear(self.wav2vec2.config.hidden_size, num_classes)

    def forward(self, input_values, attention_mask=None):
        outputs = self.wav2vec2(input_values, attention_mask=attention_mask)
        # Mean-pool the frame-level hidden states into one utterance-level embedding.
        pooled_output = outputs.last_hidden_state.mean(dim=1)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits


# Load the processor, the fine-tuned classifier weights, and the label-encoder classes from the Hub.
processor = Wav2Vec2Processor.from_pretrained("hrid0yyy/BornoNet")
num_classes = 50
model = Wav2Vec2Classifier(num_classes=num_classes)
state_dict = torch.load(hf_hub_download("hrid0yyy/BornoNet", "pytorch_model.bin"), map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
le_classes = np.load(hf_hub_download("hrid0yyy/BornoNet", "label_encoder_classes.npy"), allow_pickle=True)


def predict(audio):
    try:
        # Load and resample the uploaded audio to 16 kHz, the rate wav2vec 2.0 expects.
        y, sr = librosa.load(audio, sr=16000)
        inputs = processor(y, sampling_rate=sr, return_tensors="pt", padding=True)
        with torch.no_grad():
            logits = model(inputs.input_values)
        # Map the highest-scoring class index back to its Bengali character label.
        predicted = le_classes[torch.argmax(logits, dim=1).item()]
        return f"Predicted character: {predicted}"
    except Exception as e:
        return f"Error processing audio: {str(e)}"


iface = gr.Interface(
    fn=predict,
    inputs=gr.Audio(type="filepath", label="Upload an MP3 file (16kHz)"),
    outputs=gr.Textbox(label="Prediction"),
    title="BornoNet: Bengali Speech Recognition",
    description="Upload a 16kHz MP3 file to classify Bengali speech into characters (e.g., ত, অ, ক).",
)

iface.launch()