import gradio as gr
import whisper
import torch
import numpy as np
from feature import (
    AudioTextEmotionModel,
    extract_audio_features,
    extract_text_features,
)
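
# feature is assumed to be a local module that defines the fusion model and
# the audio/text feature extractors used below.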

# Run on the GPU when available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the trained audio+text emotion classifier and switch it to inference mode
emotion_model = AudioTextEmotionModel(audio_input_dim=180, text_input_dim=768, hidden_dim=128, output_dim=3)
emotion_model.load_state_dict(torch.load("model_weights.pth", map_location=device))
emotion_model.to(device)
emotion_model.eval()
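
# For reference: AudioTextEmotionModel is expected to take a (batch, 1, 180)
# audio-feature tensor and a (batch, 1, 768) text-embedding tensor and return
# (batch, 3) logits, matching how it is called in analyze_input below.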

# Whisper "base" model handles the Chinese speech-to-text step
whisper_model = whisper.load_model("base")

# Class indices must match the order used when the classifier was trained
EMOTION_LABELS = {0: "Positive", 1: "Neutral", 2: "Negative"}


def analyze_input(audio, text_input):
    """Run emotion recognition on a voice recording, a text string, or both."""
    audio_feat = None
    text_feat = None
    result_text = ""

    if audio:
        # Transcribe the recording so a transcript is available even when
        # the user typed nothing
        result = whisper_model.transcribe(audio, language="zh")
        transcribed_text = result["text"]
        result_text += f"🎧 Speech-to-text: 「{transcribed_text}」\n"
        audio_feat = extract_audio_features(audio)
    else:
        transcribed_text = None

    # Prefer typed text; fall back to the Whisper transcript
    text = text_input or transcribed_text
    if text:
        text_feat = extract_text_features(text)
        result_text += f"✏️ Text: 「{text}」\n"

    if audio_feat is None and text_feat is None:
        return "Please provide voice or text input for emotion recognition."

    # A missing modality is replaced with a zero tensor of the expected
    # shape, so the model always receives both inputs
    audio_tensor = (
        torch.tensor(audio_feat, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)
        if audio_feat is not None
        else torch.zeros(1, 1, 180).to(device)
    )
    text_tensor = (
        torch.tensor(text_feat, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)
        if text_feat is not None
        else torch.zeros(1, 1, 768).to(device)
    )

    # Inference only, so no gradients are needed
    with torch.no_grad():
        output = emotion_model(audio_tensor, text_tensor)
        pred = torch.argmax(output, dim=1).item()

    result_text += f"📊 Predicted emotion: {EMOTION_LABELS[pred]}"
    return result_text
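
# Example calls (hypothetical inputs): analyze_input("clip.wav", "") exercises
# both branches via the Whisper transcript; analyze_input(None, "今天很開心")
# is text-only and zero-fills the audio branch.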


with gr.Blocks() as demo:
    gr.Markdown("## 🎧 Chinese Speech Emotion Recognition Chatbot\nSupports voice input, text input, or a combination of both")

    chatbot = gr.Chatbot()
    with gr.Row():
        audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Voice")
        text_input = gr.Textbox(lines=2, placeholder="Enter text...", label="Text")
    send_btn = gr.Button("Analyze")

    def chat_handler(audio, text, history):
        response = analyze_input(audio, text)
        history = history or []
        history.append(("👤", response))
        # Return cleared audio and text inputs alongside the updated history
        return history, None, ""

    send_btn.click(fn=chat_handler,
                   inputs=[audio_input, text_input, chatbot],
                   outputs=[chatbot, audio_input, text_input])

demo.launch(share=True)
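
# Note: with share=True, launch() serves the app locally and also prints a
# temporary public gradio.live URL, which is handy for quick demos. Running
# this script assumes feature.py and model_weights.pth sit alongside it.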