import os
import torch
import torchaudio
import numpy as np
import gradio as gr
from transformers import AutoFeatureExtractor, HubertForSequenceClassification

# ==== 1. Path and device configuration ====
MODEL_PATH = "./voice_emotion_checkpoint"  # Change if needed
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ==== 2. Load the feature extractor and model ====
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_PATH)
model = HubertForSequenceClassification.from_pretrained(MODEL_PATH).to(DEVICE)
model.eval()

# If you have an id2label.json file:
# import json
# with open(os.path.join(MODEL_PATH, "id2label.json"), "r", encoding="utf-8") as f:
#     id2label = json.load(f)
# Otherwise, fall back to the mapping stored in the model config:
id2label = {int(k): v for k, v in model.config.id2label.items()}
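
# Optional sanity check (assumption: the checkpoint saved a fine-tuned label map);
# generic names like "LABEL_0" here would mean the mapping was not stored.
# print(f"Loaded {len(id2label)} labels: {list(id2label.values())}")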

# ==== 3. Processing and prediction function ====
def predict_emotion(audio_filepath):
    # 1) Load the file and convert it to numpy
    waveform, sr = torchaudio.load(audio_filepath)  # waveform: Tensor[channels, time]
    waveform = waveform.numpy()

    # 2) Stereo -> mono
    if waveform.ndim > 1:
        waveform = np.mean(waveform, axis=0)

    # 3) Resample to 16 kHz if needed
    target_sr = feature_extractor.sampling_rate
    if sr != target_sr:
        waveform = torchaudio.functional.resample(
            torch.from_numpy(waveform), orig_freq=sr, new_freq=target_sr
        ).numpy()
        sr = target_sr

    # 4) Feature extraction
    inputs = feature_extractor(
        waveform,
        sampling_rate=sr,
        return_tensors="pt",
        padding=True,
    )
    input_values = inputs.input_values.to(DEVICE)
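    # Note: some feature extractors also return an attention_mask; when present,
    # it should be forwarded as model(input_values, attention_mask=...).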

    # 5) Inference
    with torch.no_grad():
        logits = model(input_values).logits
    probs = torch.softmax(logits, dim=-1)[0].cpu().numpy()
    pred_id = int(np.argmax(probs))

    # 6) Prepare the output
    pred_label = id2label[pred_id]
    label_probs = {id2label[i]: float(probs[i]) for i in range(len(probs))}
    return pred_label, label_probs
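
# Smoke-test the pipeline without the UI (hypothetical sample path):
# label, probs = predict_emotion("samples/test.wav")
# print(label, probs)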

# ==== 4. Build the Gradio interface ====
demo = gr.Interface(
    fn=predict_emotion,
    inputs=gr.Audio(type="filepath", label="Upload or Record Audio"),
    outputs=[
        gr.Label(num_top_classes=1, label="Predicted Emotion"),
        gr.Label(num_top_classes=len(id2label), label="All Probabilities"),
    ],
    title="Vietnamese Speech Emotion Recognition",
    description="Upload or record audio; the model will predict the emotion (angry, happy, sad, …).",
)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)