import gradio as gr
import whisper
import torch
import numpy as np
from feature import (
    AudioTextEmotionModel,
    extract_audio_features,
    extract_text_features
)
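# `feature` is the project-local module providing the fusion model and the feature
# extractors; judging from the zero-tensor fallbacks and the model constructor below,
# the audio extractor is expected to return 180-dim acoustic vectors and the text
# extractor 768-dim embeddings.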

# Select the device: GPU if available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the trained audio-text emotion model
emotion_model = AudioTextEmotionModel(audio_input_dim=180, text_input_dim=768, hidden_dim=128, output_dim=3)
# "model_weights.pth" is expected to sit next to this script; map_location lets the checkpoint load on CPU-only machines
emotion_model.load_state_dict(torch.load("model_weights.pth", map_location=device))
emotion_model.to(device)
emotion_model.eval()

# Whisper model for speech-to-text transcription
whisper_model = whisper.load_model("base")
# Emotion class labels: positive / neutral / negative
EMOTION_LABELS = {0: '正面', 1: '中性', 2: '負面'}

# Main emotion prediction function (supports audio, text, or both)
def analyze_input(audio, text_input):
    audio_feat = None
    text_feat = None
    result_text = ""
    
    # If audio is provided: transcribe it and extract acoustic features
    if audio:
        result = whisper_model.transcribe(audio, language="zh")
        transcribed_text = result["text"]
        result_text += f"🎧 語音轉文字:「{transcribed_text}」\n"
        audio_feat = extract_audio_features(audio)
    else:
        transcribed_text = None

    # If there is text (typed by the user, or falling back to the transcription)
    text = text_input or transcribed_text
    if text:
        text_feat = extract_text_features(text)
        result_text += f"✏️ 文字內容:「{text}」\n"
    
    if audio_feat is None and text_feat is None:
        return "請提供語音或文字輸入進行情緒辨識。"

    # Build the model inputs; a missing modality is replaced with an all-zero tensor of the expected shape
    audio_tensor = (
        torch.tensor(audio_feat, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)
        if audio_feat is not None else
        torch.zeros(1, 1, 180).to(device)
    )
    text_tensor = (
        torch.tensor(text_feat, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)
        if text_feat is not None else
        torch.zeros(1, 1, 768).to(device)
    )

    with torch.no_grad():
        output = emotion_model(audio_tensor, text_tensor)
        pred = torch.argmax(output, dim=1).item()

    result_text += f"📊 預測情緒:{EMOTION_LABELS[pred]}"
    return result_text

# Gradio Chat UI
with gr.Blocks() as demo:
    gr.Markdown("## 🎧 中文語音情緒辨識聊天機器人\n支援語音輸入、文字輸入,或兩者結合分析")

    chatbot = gr.Chatbot()
    with gr.Row():
        audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="語音")
        text_input = gr.Textbox(lines=2, placeholder="輸入文字內容...", label="文字")
    send_btn = gr.Button("送出分析")

    def chat_handler(audio, text, history):
        # Analyze the inputs and append the result to the chat history as one turn
        response = analyze_input(audio, text)
        history = history or []
        history.append(("👤", response))
        # Returning None / "" clears the audio and text inputs after sending
        return history, None, ""

    send_btn.click(fn=chat_handler,
                   inputs=[audio_input, text_input, chatbot],
                   outputs=[chatbot, audio_input, text_input])

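# share=True asks Gradio to open a temporary public link in addition to the local
# server; set share=False to keep the demo local-only.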
demo.launch(share=True)