# AnnaAgent-Demo / src/emotion_modulator_fc.py
# Uploaded by sci-m-wang ("Upload 14 files"), commit 1d4c295 (verified), 3.45 kB.
from openai import OpenAI
from random import randint
from emotion_pertuber import perturb_state
import json
import os
# OpenAI connection settings, read once from the environment at import time.
# OPENAI_BASE_URL and OPENAI_MODEL_NAME fall back to the public endpoint and
# gpt-3.5-turbo; OPENAI_API_KEY has no fallback and is None when unset.
api_key = os.environ.get("OPENAI_API_KEY")
base_url = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
model_name = os.environ.get("OPENAI_MODEL_NAME", "gpt-3.5-turbo")
# The label set the model may answer with: the 27 GoEmotions emotion
# categories plus "neutral" (28 labels in total).
EMOTION_LABELS = [
    "admiration", "amusement", "anger", "annoyance", "approval", "caring",
    "confusion", "curiosity", "desire", "disappointment", "disapproval",
    "disgust", "embarrassment", "excitement", "fear", "gratitude", "grief",
    "joy", "love", "nervousness", "optimism", "pride", "realization",
    "relief", "remorse", "sadness", "surprise", "neutral",
]

# OpenAI function-calling schema: a single `emotion_inference` tool whose only
# (required) argument is an `emotion` string restricted to EMOTION_LABELS.
# The description strings are sent to the model as part of the prompt.
tools = [
    {
        "type": "function",
        "function": {
            "name": "emotion_inference",
            "description": "根据profile和对话记录,推理下一句情绪",
            "parameters": {
                "type": "object",
                "properties": {
                    "emotion": {
                        "type": "string",
                        "enum": EMOTION_LABELS,
                        # Mentions "neutral" explicitly so the description no
                        # longer contradicts the 28-value enum above.
                        "description": "推理出的情绪类别,必须是GoEmotions定义的27种情绪或neutral(中性)之一。",
                    }
                },
                "required": ["emotion"],
            },
        },
    }
]
# Infer the patient's next emotion from the profile and the dialogue so far.
def emotion_inferencer(profile, conversation):
    """Ask the chat model for the most likely emotion of the patient's next turn.

    The model is forced (via ``tool_choice``) to call the ``emotion_inference``
    tool from the module-level ``tools`` list. The ``emotion`` argument of
    that call is returned.

    Args:
        profile: Mapping with at least the keys ``age``, ``gender``,
            ``occupation``, ``marital_status`` and ``symptoms``.
        conversation: Iterable of mappings with ``role`` and ``content`` keys,
            rendered one turn per line in the given order.

    Returns:
        The emotion label chosen by the model. NOTE(review): the value is
        not checked against the tool's enum, so a non-conforming model or
        backend could return a label outside that set.

    Raises:
        ValueError: if the model response contains no tool call.
    """
    # Patient-information section of the prompt.
    patient_info = f"### 患者信息\n年龄:{profile['age']}\n性别:{profile['gender']}\n职业:{profile['occupation']}\n婚姻状况:{profile['marital_status']}\n症状:{profile['symptoms']}"
    # Dialogue history, one "role: content" line per turn.
    dialogue_history = "\n".join(f"{turn['role']}: {turn['content']}" for turn in conversation)
    messages = [
        {"role": "user", "content": f"### 任务\n根据患者情况及咨访对话历史记录推测患者下一句话最可能的情绪。\n{patient_info}\n### 对话记录\n{dialogue_history}"}
    ]
    # The client owns an HTTP connection pool; the context manager closes it
    # when the request is done instead of leaving it to the garbage collector.
    with OpenAI(api_key=api_key, base_url=base_url) as client:
        response = client.chat.completions.create(
            model=model_name,
            messages=messages,
            tools=tools,
            tool_choice={"type": "function", "function": {"name": "emotion_inference"}},
        )
    tool_calls = response.choices[0].message.tool_calls
    if not tool_calls:
        # Fail with a clear message rather than a TypeError/IndexError.
        raise ValueError("emotion_inference: model response contained no tool call")
    arguments = json.loads(tool_calls[0].function.arguments)
    return arguments["emotion"]
def emotion_modulation(profile, conversation):
    """Return the emotion for the patient's next turn, occasionally perturbed.

    The emotion is inferred by ``emotion_inferencer``. With a 10% chance it
    is then passed through ``perturb_state`` (from ``emotion_pertuber``),
    presumably to keep the simulated patient from always taking the most
    likely emotion.

    Args:
        profile: Patient profile mapping, forwarded to ``emotion_inferencer``.
        conversation: Dialogue history, forwarded to ``emotion_inferencer``.

    Returns:
        Either the inferred emotion or ``perturb_state(emotion)``.
    """
    # randint is inclusive at both ends. Drawing from 1..100 and keeping the
    # top ten values gives exactly a 10% rate; 0..100 would give 10/101.
    roll = randint(1, 100)
    emotion = emotion_inferencer(profile, conversation)
    if roll > 90:
        return perturb_state(emotion)
    return emotion
# Manual smoke test: calls the live model in an endless loop; uncomment to run.
# while True:
# # Simulated patient profile
# profile = {
# "drisk": 3,
# "srisk": 2,
# "age": "42",
# "gender": "女",
# "marital_status": "离婚",
# "occupation": "教师",
# "symptoms": "缺乏自信心,自我价值感低,有自罪感,无望感;体重剧烈增加;精神运动性激越;有自杀想法"
# }
# conversation = [
# {"role": "咨询师", "content": "你好,请问有什么可以帮您?"}
# ]
# print(emotion_modulation(profile,conversation))