|
import os |
|
import torch |
|
import torchaudio |
|
import numpy as np |
|
import gradio as gr |
|
from transformers import AutoFeatureExtractor, HubertForSequenceClassification |
|
|
|
|
|
# Directory from which both the feature extractor and the classifier are
# loaded.
MODEL_PATH = "./voice_emotion_checkpoint"

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Pre-processing pipeline that turns raw waveforms into model inputs.
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_PATH)

# HuBERT sequence classifier, moved to the selected device and switched to
# evaluation mode because this script only runs inference.
model = HubertForSequenceClassification.from_pretrained(MODEL_PATH)
model = model.to(DEVICE)
model.eval()

# Class-index -> label-name table taken from the model config. Keys are
# coerced to int so that integer class indices can be used for lookup.
id2label = {}
for raw_id, label_name in model.config.id2label.items():
    id2label[int(raw_id)] = label_name
|
|
|
|
|
def predict_emotion(audio_filepath):
    """Classify the emotion expressed in an audio file.

    The audio is loaded with torchaudio, averaged across channels to mono,
    resampled to the feature extractor's sampling rate when it differs, and
    run through the module-level HuBERT classifier.

    Args:
        audio_filepath: Path of the audio file to analyse. May be ``None``
            or empty when no audio was supplied (e.g. the form was
            submitted without an upload or recording).

    Returns:
        A ``(pred_label, label_probs)`` tuple: the label with the highest
        probability, and a dict mapping each class label to its softmax
        probability as a Python float.

    Raises:
        gr.Error: If no audio was provided or the audio has no samples.
    """
    # Fail with a readable UI message instead of an opaque loader error.
    if not audio_filepath:
        raise gr.Error("No audio provided. Please upload or record audio first.")

    waveform, sr = torchaudio.load(audio_filepath)
    if waveform.numel() == 0:
        raise gr.Error("The audio contains no samples. Please try another recording.")

    # Down-mix to mono by averaging over the channel axis.
    if waveform.ndim > 1:
        waveform = waveform.mean(dim=0)

    # The feature extractor expects audio at its own sampling rate.
    target_sr = feature_extractor.sampling_rate
    if sr != target_sr:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sr, new_freq=target_sr
        )
        sr = target_sr

    inputs = feature_extractor(
        waveform.numpy(),
        sampling_rate=sr,
        return_tensors="pt",
        padding=True,
    )
    input_values = inputs.input_values.to(DEVICE)

    with torch.no_grad():
        # Single-item batch: keep only the first row of logits.
        logits = model(input_values).logits[0]

    # Softmax is computed on the CPU in float32, then handed to numpy.
    probs = torch.softmax(logits.float().cpu(), dim=-1).numpy()
    pred_id = int(np.argmax(probs))

    pred_label = id2label[pred_id]
    label_probs = {id2label[i]: float(p) for i, p in enumerate(probs)}
    return pred_label, label_probs
|
|
|
|
|
# Input widget: hands the chosen/recorded audio to the callback as a file path.
_audio_input = gr.Audio(type="filepath", label="Upload or Record Audio")

# Output widgets: the single predicted label, and the distribution over all
# known labels.
_top_prediction = gr.Label(num_top_classes=1, label="Predicted Emotion")
_all_predictions = gr.Label(
    num_top_classes=len(id2label),
    label="All Probabilities",
)

# Web UI wiring the audio input to predict_emotion and its two outputs.
demo = gr.Interface(
    fn=predict_emotion,
    inputs=_audio_input,
    outputs=[_top_prediction, _all_predictions],
    title="Vietnamese Speech Emotion Recognition",
    description="Upload hoặc record audio, mô hình sẽ dự đoán cảm xúc (angry, happy, sad, …).",
)


if __name__ == "__main__":
    # Start the server on port 7860, bound to 0.0.0.0, with sharing disabled.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )
|
|