# Source: AnnaAgent-Demo/src/ms_patient.py
# Uploaded by sci-m-wang ("Upload 14 files"), commit 1d4c295 (verified), 5.61 kB.
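'''
AnnaAgent: 具有三级记忆结构的情绪与认知动态的模拟心理障碍患者
1. 首先获取患者的基本信息、病史、症状报告等信息
2. 根据患者的病史、症状报告等信息,生成患者的认知与情绪状态

AnnaAgent: a simulated patient with a mental disorder, whose emotional and
cognitive dynamics are driven by a three-level memory structure.
1. First, collect the patient's basic information, history and symptom report.
2. From that history and symptom report, generate the patient's cognitive and
   emotional state.
'''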
from openai import OpenAI
import os
from fill_scales import fill_scales, fill_scales_previous
from event_trigger import event_trigger, situationalising_events
from emotion_modulator_fc import emotion_modulation
from querier import query, is_need
from complaint_elicitor import switch_complaint, transform_chain
from complaint_chain_fc import gen_complaint_chain
from short_term_memory import summarize_scale_changes
from style_analyzer import analyze_style
import random
# from anna_agent_template import prompt_template
# OpenAI connection settings, read from the environment once at import time.
# OPENAI_API_KEY has no fallback, so api_key is None when it is not set.
api_key = os.environ.get("OPENAI_API_KEY")
# Endpoint of the OpenAI-compatible service; falls back to the official API.
base_url = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
# Chat model used to generate the patient's replies.
model_name = os.environ.get("OPENAI_MODEL_NAME", "gpt-3.5-turbo")
class MsPatient:
    """Simulated counseling client ("Seeker") answered by an OpenAI-compatible chat model.

    On construction the patient is profiled from ``portrait``, ``report`` and
    ``previous_conversations``. The following are produced and rendered into a
    system prompt through the language-specific ``prompt_template``:
    scales for the previous course of treatment and the current state, a
    complaint chain, a recent situation, and a speaking-style analysis.
    ``chat`` then generates one Seeker reply per counselor message.
    """

    @staticmethod
    def _load_prompt_template(language):
        """Return ``prompt_template`` from the module matching ``language``.

        Raises:
            ValueError: if ``language`` is neither "Chinese" nor "English".
        """
        if language == "Chinese":
            from anna_agent_template import prompt_template
        elif language == "English":
            from anna_agent_template_en import prompt_template
        else:
            raise ValueError(
                f"Unsupported language {language!r}; expected 'Chinese' or 'English'"
            )
        return prompt_template

    @staticmethod
    def _sample_seeker_statements(previous_conversations: list, k: int = 3) -> list:
        """Pick up to ``k`` Seeker utterances from earlier sessions, without replacement.

        Sampling without replacement keeps one utterance from being repeated in
        the prompt. When fewer than ``k`` Seeker turns exist, all of them are
        returned (in random order). When there are none, ``[]`` is returned.
        """
        seeker_utterances = [
            utterance["content"]
            for utterance in previous_conversations
            if utterance["role"] == "Seeker"
        ]
        return random.sample(seeker_utterances, min(k, len(seeker_utterances)))

    def __init__(self, portrait: dict, report: dict, previous_conversations: list, language: str = "Chinese"):
        """Profile the patient and build its system prompt.

        Args:
            portrait: patient profile. The keys "gender", "age", "occupation"
                and "marital_status" are read directly. The whole dict is
                passed to the helper generators.
            report: case report. It is passed to ``fill_scales_previous`` here
                and to ``query`` during ``chat``.
            previous_conversations: earlier-session turns as dicts with "role"
                and "content" keys. "Seeker" turns are sampled as example
                statements.
            language: "Chinese" or "English"; selects the prompt template.

        Raises:
            ValueError: if ``language`` is unsupported. This is checked before
                any of the helper generators run.
        """
        # Resolve the template first so an unsupported language fails fast,
        # instead of after several slow helper calls.
        prompt_template = self._load_prompt_template(language)
        self.portrait = portrait  # age, gender, occupation, marital_status, symptom
        self.configuration = {
            "gender": self.portrait["gender"],
            "age": self.portrait["age"],
            "occupation": self.portrait["occupation"],
            "marriage": self.portrait["marital_status"],
        }
        self.report = report
        self.previous_conversations = previous_conversations
        # Scales (BDI, GHQ, SASS) for the previous course of treatment.
        self.p_bdi, self.p_ghq, self.p_sass = fill_scales_previous(self.portrait, self.report)
        self.conversation = []  # counseling transcript (roles "Counselor"/"Seeker")
        self.messages = []  # LLM message list (roles "user"/"assistant")
        # Chain describing how the chief complaint evolves over the session.
        self.complaint_chain = gen_complaint_chain(self.portrait)
        # Recent life event.
        # NOTE(review): self.event is not used in this class, and
        # situationalising_events below receives the portrait rather than the
        # event -- confirm this is intended.
        self.event = event_trigger(self.portrait)
        # Short-term memory: the current situation.
        self.situation = situationalising_events(self.portrait)
        self.configuration["situation"] = self.situation
        # Speaking style inferred from earlier sessions.
        self.style = analyze_style(self.portrait, self.previous_conversations)
        self.configuration["style"] = self.style
        self.configuration["language"] = language
        # Placeholder: the real status is derived from the scale changes below.
        self.configuration["status"] = ""
        self.configuration["statement"] = self._sample_seeker_statements(self.previous_conversations)
        # Scales for the current state, filled from the provisional prompt.
        self.bdi, self.ghq, self.sass = fill_scales(prompt_template.format(**self.configuration))
        scales = {
            "p_bdi": self.p_bdi,
            "p_ghq": self.p_ghq,
            "p_sass": self.p_sass,
            "bdi": self.bdi,
            "ghq": self.ghq,
            "sass": self.sass,
        }
        # Summarize how the patient changed between the two sets of scales.
        self.status = summarize_scale_changes(scales)
        self.configuration["status"] = self.status
        # Final system prompt, now including the derived status.
        self.system = prompt_template.format(**self.configuration)
        self.chain_index = 1
        self.client = OpenAI(
            api_key=api_key,
            base_url=base_url,
        )

    def chat(self, message):
        """Send one counselor ``message`` and return the simulated Seeker reply.

        The counselor turn and the reply are appended to
        ``self.conversation`` (roles "Counselor"/"Seeker") and to
        ``self.messages`` (roles "user"/"assistant"). Each turn recomputes the
        emotion and the active complaint-chain stage, and sends them to the
        model as a trailing system message. When ``is_need`` flags the message
        as touching earlier sessions, information from ``query`` is added too.
        """
        self.conversation.append({"role": "Counselor", "content": message})
        self.messages.append({"role": "user", "content": message})
        # Per-turn state: current emotion and the active complaint stage.
        emotion = emotion_modulation(self.portrait, self.conversation)
        self.chain_index = switch_complaint(self.complaint_chain, self.chain_index, self.conversation)
        complaint = transform_chain(self.complaint_chain)[self.chain_index]
        # NOTE(review): this state note is always in Chinese, even when the
        # patient was built with language="English".
        state_note = f"当前的情绪状态是:{emotion},当前的主诉是:{complaint}"
        if is_need(message):
            # The counselor refers to earlier sessions: retrieve supporting facts.
            sup_information = query(message, self.previous_conversations, self.report)
            state_note += f",涉及到之前疗程的信息是:{sup_information}"
        response = self.client.chat.completions.create(
            model=model_name,
            messages=[{"role": "system", "content": self.system}]
            + self.messages
            + [{"role": "system", "content": state_note}],
        )
        reply = response.choices[0].message.content
        self.conversation.append({"role": "Seeker", "content": reply})
        self.messages.append({"role": "assistant", "content": reply})
        return reply

    def get_system_prompt(self):
        """Return the system prompt built during construction."""
        return self.system