import os
import torch
import torchaudio
import numpy as np
import gradio as gr
from transformers import AutoFeatureExtractor, HubertForSequenceClassification
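
# Gradio demo for Vietnamese speech emotion recognition using a fine-tuned HuBERT
# sequence classifier. Requires torch, torchaudio, numpy, gradio, and transformers.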

# ==== 1. Path and device configuration ====
MODEL_PATH = "./voice_emotion_checkpoint"  # Change this path if needed
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ==== 2. Load the feature extractor and model ====
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_PATH)
model = HubertForSequenceClassification.from_pretrained(MODEL_PATH).to(DEVICE)
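# Switch to eval mode (disables dropout) for deterministic inference.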
model.eval()

# If you have an id2label.json file:
# import json
# with open(os.path.join(MODEL_PATH, "id2label.json"), "r", encoding="utf-8") as f:
#     id2label = json.load(f)
# Otherwise, fall back to the model config (its keys are strings, so cast them to int):
id2label = {int(k): v for k, v in model.config.id2label.items()}

# ==== 3. Preprocessing and prediction ====
def predict_emotion(audio_filepath):
    # 1) Load the audio file and convert it to a numpy array
    waveform, sr = torchaudio.load(audio_filepath)           # waveform: Tensor[channels, time]
    waveform = waveform.numpy()                              # -> numpy array
    # 2) Stereo -> mono
    if waveform.ndim > 1:
        waveform = np.mean(waveform, axis=0)
    # 3) Resample to the feature extractor's sampling rate (16 kHz for HuBERT) if needed
    target_sr = feature_extractor.sampling_rate
    if sr != target_sr:
        waveform = torchaudio.functional.resample(
            torch.from_numpy(waveform), orig_freq=sr, new_freq=target_sr
        ).numpy()
        sr = target_sr
    # 4) Feature extraction
    inputs = feature_extractor(
        waveform,
        sampling_rate=sr,
        return_tensors="pt",
        padding=True
    )
    input_values = inputs.input_values.to(DEVICE)
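    # Note: with a single clip, padding=True adds no padding; only input_values are
    # passed to the model, so any attention_mask returned by the extractor is unused.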
    # 5) Inference
    with torch.no_grad():
        logits = model(input_values).logits[0]
        probs = torch.softmax(logits, dim=-1).cpu().numpy()
        pred_id = int(np.argmax(probs))
    # 6) Prepare the outputs: top label plus the full probability distribution
    pred_label = id2label[pred_id]
    label_probs = {id2label[i]: float(probs[i]) for i in range(len(probs))}
    return pred_label, label_probs
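
# Quick sanity check without launching the UI (hypothetical local file "sample.wav"):
#   label, probs = predict_emotion("sample.wav")
#   print(label, probs)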

# ==== 4. Build the Gradio interface ====
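# Note: gr.Label accepts either a plain string or a {label: confidence} dict, which
# matches the two values returned by predict_emotion.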
demo = gr.Interface(
    fn=predict_emotion,
    inputs=gr.Audio(type="filepath", label="Upload or Record Audio"),
    outputs=[
        gr.Label(num_top_classes=1, label="Predicted Emotion"),
        gr.Label(num_top_classes=len(id2label), label="All Probabilities"),
    ],
    title="Vietnamese Speech Emotion Recognition",
    description="Upload hoặc record audio, mô hình sẽ dự đoán cảm xúc (angry, happy, sad, …).",
)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)