# Vietnamese speech emotion recognition demo (HuBERT classifier + Gradio UI).
import os
import torch
import torchaudio
import numpy as np
import gradio as gr
from transformers import AutoFeatureExtractor, HubertForSequenceClassification
# ==== 1. Paths and device configuration ====
MODEL_PATH = "./voice_emotion_checkpoint" # Change if needed
# Prefer GPU when available; the model and all inputs are moved to this device.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ==== 2. Load feature extractor and model ====
# NOTE: both loads happen at import time and require the checkpoint directory
# to exist locally — the app fails fast here if MODEL_PATH is wrong.
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_PATH)
model = HubertForSequenceClassification.from_pretrained(MODEL_PATH).to(DEVICE)
model.eval()
# If you have an id2label.json file:
# import json
# with open(os.path.join(MODEL_PATH, "id2label.json"), "r", encoding="utf-8") as f:
# id2label = json.load(f)
# Otherwise:
# Normalize config keys to int (HF configs may store them as strings).
id2label = {int(k): v for k, v in model.config.id2label.items()}
# ==== 3. Hàm xử lý và dự đoán ====
def predict_emotion(audio_filepath):
    """Predict the emotion expressed in an audio file.

    Args:
        audio_filepath: Path to an audio file readable by torchaudio
            (Gradio passes the uploaded/recorded file path here).

    Returns:
        A tuple ``(pred_label, label_probs)`` where ``pred_label`` is the
        top-scoring emotion name and ``label_probs`` maps every label to
        its softmax probability (floats summing to ~1).
    """
    # 1) Load audio; waveform is a Tensor of shape [channels, time].
    waveform, sr = torchaudio.load(audio_filepath)
    # 2) Stereo -> mono by averaging channels (stay in torch — no numpy
    #    round trip; resample below operates on tensors anyway).
    if waveform.ndim > 1:
        waveform = waveform.mean(dim=0)
    # 3) Resample to the extractor's expected rate (typically 16 kHz).
    target_sr = feature_extractor.sampling_rate
    if sr != target_sr:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sr, new_freq=target_sr
        )
        sr = target_sr
    # 4) Feature extraction — convert to numpy once, at the boundary.
    inputs = feature_extractor(
        waveform.numpy(),
        sampling_rate=sr,
        return_tensors="pt",
        padding=True,
    )
    input_values = inputs.input_values.to(DEVICE)
    # 5) Inference; softmax computed directly on the logits tensor
    #    (original converted logits to numpy and back for softmax).
    with torch.no_grad():
        logits = model(input_values).logits[0]
    probs = torch.softmax(logits, dim=-1).cpu().numpy()
    pred_id = int(np.argmax(probs))
    # 6) Build outputs for the two Gradio Label components.
    pred_label = id2label[pred_id]
    label_probs = {id2label[i]: float(probs[i]) for i in range(len(probs))}
    return pred_label, label_probs
# ==== 4. Xây dựng giao diện Gradio ====
# ==== 4. Build the Gradio interface ====
# Single audio input; two Label outputs: the top prediction and the full
# probability distribution over all known emotion labels.
demo = gr.Interface(
    fn=predict_emotion,
    inputs=gr.Audio(type="filepath", label="Upload or Record Audio"),
    outputs=[
        gr.Label(num_top_classes=1, label="Predicted Emotion"),
        gr.Label(num_top_classes=len(id2label), label="All Probabilities"),
    ],
    title="Vietnamese Speech Emotion Recognition",
    description="Upload hoặc record audio, mô hình sẽ dự đoán cảm xúc (angry, happy, sad, …).",
)
# Entry point: bind on all interfaces (0.0.0.0) so the app is reachable
# from outside a container; no public Gradio share link.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)