File size: 3,451 Bytes
1d4c295
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from openai import OpenAI
from random import randint
from emotion_pertuber import perturb_state
import json
import os

# 设置OpenAI API密钥和基础URL
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")

tools = [
    {
        "type": "function",
        "function": {
            'name': 'emotion_inference',
            'description': '根据profile和对话记录,推理下一句情绪',
            'parameters': {
                "type": "object",
                "properties": {
                    "emotion": {
                        "type": "string",
                        "enum": [
                            "admiration", "amusement", "anger", "annoyance", "approval", "caring",
                            "confusion", "curiosity", "desire", "disappointment", "disapproval",
                            "disgust", "embarrassment", "excitement", "fear", "gratitude", "grief",
                            "joy", "love", "nervousness", "optimism", "pride", "realization",
                            "relief", "remorse", "sadness", "surprise", "neutral"
                        ],
                        "description": "推理出的情绪类别,必须是GoEmotions定义的27种情绪之一。"
                    }
                },
                "required": ["emotion"]
            },
        }
    }
]

# 根据profile和dialogue推测emotion
def emotion_inferencer(profile, conversation):
    client = OpenAI(
        api_key=api_key,
        base_url=base_url,
    )

    # 提取患者信息
    patient_info = f"### 患者信息\n年龄:{profile['age']}\n性别:{profile['gender']}\n职业:{profile['occupation']}\n婚姻状况:{profile['marital_status']}\n症状:{profile['symptoms']}"
    
    # 提取对话记录
    dialogue_history = "\n".join([f"{conv['role']}: {conv['content']}" for conv in conversation])

    response = client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "user", "content": f"### 任务\n根据患者情况及咨访对话历史记录推测患者下一句话最可能的情绪。\n{patient_info}\n### 对话记录\n{dialogue_history}"}
        ],
        # functions=[tools[0]["function"]],
        # function_call={"name": "emotion_inference"}
        tools=tools,
        tool_choice={"type": "function", "function": {"name": "emotion_inference"}}
    )
    # print(response)

    emotion = json.loads(response.choices[0].message.tool_calls[0].function.arguments)["emotion"]

    return emotion

def emotion_modulation(profile, conversation):
    indicator = randint(0,100)
    emotion = emotion_inferencer(profile,conversation)
    # print(emotion)
    if indicator > 90:
        return perturb_state(emotion)
    else:
        return emotion

# unit test
# while True:
#     # 模拟患者信息
#     profile = {
#         "drisk": 3,
#         "srisk": 2,
#         "age": "42",
#         "gender": "女",
#         "marital_status": "离婚",
#         "occupation": "教师",
#         "symptoms": "缺乏自信心,自我价值感低,有自罪感,无望感;体重剧烈增加;精神运动性激越;有自杀想法"
#         }

#     conversation = [
#         {"role": "咨询师", "content": "你好,请问有什么可以帮您?"}
#     ]

#     print(emotion_modulation(profile,conversation))