import os import torch import torchaudio import numpy as np import gradio as gr from transformers import AutoFeatureExtractor, HubertForSequenceClassification # ==== 1. Cấu hình đường dẫn và thiết bị ==== MODEL_PATH = "./voice_emotion_checkpoint" # Thay đổi nếu cần DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # ==== 2. Load feature extractor và model ==== feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_PATH) model = HubertForSequenceClassification.from_pretrained(MODEL_PATH).to(DEVICE) model.eval() # Nếu bạn có file id2label.json: # import json # with open(os.path.join(MODEL_PATH, "id2label.json"), "r", encoding="utf-8") as f: # id2label = json.load(f) # Ngược lại: id2label = {int(k): v for k, v in model.config.id2label.items()} # ==== 3. Hàm xử lý và dự đoán ==== def predict_emotion(audio_filepath): # 1) Load file và chuyển về numpy waveform, sr = torchaudio.load(audio_filepath) # waveform: Tensor[chân âm][time] waveform = waveform.numpy() # -> numpy array # 2) Stereo -> mono if waveform.ndim > 1: waveform = np.mean(waveform, axis=0) # 3) Resample về 16 kHz nếu cần target_sr = feature_extractor.sampling_rate if sr != target_sr: waveform = torchaudio.functional.resample( torch.from_numpy(waveform), orig_freq=sr, new_freq=target_sr ).numpy() sr = target_sr # 4) Feature extraction inputs = feature_extractor( waveform, sampling_rate=sr, return_tensors="pt", padding=True ) input_values = inputs.input_values.to(DEVICE) # 5) Inference with torch.no_grad(): logits = model(input_values).logits.cpu().numpy()[0] probs = torch.softmax(torch.from_numpy(logits), dim=-1).numpy() pred_id = int(np.argmax(probs)) # 6) Chuẩn bị output pred_label = id2label[pred_id] label_probs = {id2label[i]: float(probs[i]) for i in range(len(probs))} return pred_label, label_probs # ==== 4. Xây dựng giao diện Gradio ==== demo = gr.Interface( fn=predict_emotion, inputs=gr.Audio(type="filepath", label="Upload or Record Audio"), outputs=[ gr.Label(num_top_classes=1, label="Predicted Emotion"), gr.Label(num_top_classes=len(id2label), label="All Probabilities"), ], title="Vietnamese Speech Emotion Recognition", description="Upload hoặc record audio, mô hình sẽ dự đoán cảm xúc (angry, happy, sad, …).", ) if __name__ == "__main__": demo.launch(server_name="0.0.0.0", server_port=7860, share=False)