"""AnnaAgent: a simulated patient with a psychological disorder, whose emotional
and cognitive dynamics are modelled with a three-level memory structure.

1. First obtain the patient's basic information, medical history, symptom
   reports and similar data.
2. From the medical history, symptom reports, etc., generate the patient's
   cognitive and emotional state.
"""
from openai import OpenAI
import os
from fill_scales import fill_scales, fill_scales_previous
from event_trigger import event_trigger, situationalising_events
from emotion_modulator_fc import emotion_modulation
from querier import query, is_need
from complaint_elicitor import switch_complaint, transform_chain
from complaint_chain_fc import gen_complaint_chain
from short_term_memory import summarize_scale_changes
from style_analyzer import analyze_style
import random

# OpenAI credentials, endpoint and model, read from the environment.
# api_key is None when OPENAI_API_KEY is unset; the base URL defaults to the
# public OpenAI endpoint and the model to "gpt-3.5-turbo".
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")

# Maximum number of past Seeker utterances placed in the prompt's "statement" field.
_NUM_STATEMENT_EXAMPLES = 3


class MsPatient:
    """LLM-backed simulated counselling client ("Seeker").

    Construction assembles the patient's profile: previous-session scales,
    complaint chain, recent situation, speaking style, current scales and a
    status summary. It renders them into a system prompt. Each call to
    :meth:`chat` then produces one client reply to a counsellor message.

    Notable attributes:
        conversation: transcript of {"role": "Counselor" | "Seeker", "content": ...}.
        messages: the same turns in chat-completion format ("user" / "assistant").
        system: the rendered system prompt.
        chain_index: current position in the complaint chain.
    """

    def __init__(
        self,
        portrait: dict,
        report: dict,
        previous_conversations: list,
        language: str = "Chinese",
    ) -> None:
        """Build the patient's profile and system prompt.

        Args:
            portrait: patient profile; must contain at least "gender", "age",
                "occupation" and "marital_status".
            report: report from earlier sessions, passed to the scale and
                query helpers.
            previous_conversations: earlier-session turns, each a dict with
                "role" and "content"; "Seeker" turns supply example statements.
            language: "Chinese" or "English"; selects the prompt template.

        Raises:
            ValueError: if ``language`` is not supported.
        """
        # Resolve the template first so an unsupported language fails fast,
        # before any of the profile-building helpers below are called.
        if language == "Chinese":
            from anna_agent_template import prompt_template
        elif language == "English":
            from anna_agent_template_en import prompt_template
        else:
            raise ValueError(
                f"Unsupported language {language!r}; expected 'Chinese' or 'English'"
            )

        self.portrait = portrait
        self.report = report
        self.previous_conversations = previous_conversations

        # Template fields taken directly from the portrait.
        self.configuration = {
            "gender": portrait["gender"],
            "age": portrait["age"],
            "occupation": portrait["occupation"],
            "marriage": portrait["marital_status"],
        }

        # Scales filled in for the previous session(s).
        self.p_bdi, self.p_ghq, self.p_sass = fill_scales_previous(self.portrait, self.report)

        self.conversation = []  # counselling transcript ("Counselor" / "Seeker" roles)
        self.messages = []  # message list sent to the LLM ("user" / "assistant" roles)

        # Chain describing how the chief complaint / cognition evolves.
        self.complaint_chain = gen_complaint_chain(self.portrait)
        # Recent event. NOTE(review): stored but not read anywhere in this file.
        self.event = event_trigger(self.portrait)
        # Short-term memory: situational summary of events.
        self.situation = situationalising_events(self.portrait)
        self.configuration["situation"] = self.situation

        # Speaking style derived from earlier sessions.
        self.style = analyze_style(self.portrait, self.previous_conversations)
        self.configuration["style"] = self.style
        self.configuration["language"] = language
        # The template is rendered once (for fill_scales) before the status
        # summary exists, so "status" needs a placeholder for now.
        self.configuration["status"] = ""

        # Example statements: up to three distinct past Seeker utterances.
        # random.sample avoids repeating an utterance and, unlike
        # random.choices, does not raise when there are no Seeker turns.
        seeker_utterances = [
            turn["content"]
            for turn in self.previous_conversations
            if turn["role"] == "Seeker"
        ]
        self.configuration["statement"] = random.sample(
            seeker_utterances,
            k=min(_NUM_STATEMENT_EXAMPLES, len(seeker_utterances)),
        )

        # Scales filled in for the current state.
        self.bdi, self.ghq, self.sass = fill_scales(prompt_template.format(**self.configuration))
        scales = {
            "p_bdi": self.p_bdi,
            "p_ghq": self.p_ghq,
            "p_sass": self.p_sass,
            "bdi": self.bdi,
            "ghq": self.ghq,
            "sass": self.sass,
        }
        # Summarize recent status from the change between previous and current scales.
        self.status = summarize_scale_changes(scales)
        self.configuration["status"] = self.status

        # Final system prompt with every template field filled in.
        self.system = prompt_template.format(**self.configuration)
        # NOTE(review): the complaint chain starts at index 1, not 0 — confirm
        # this matches what transform_chain returns.
        self.chain_index = 1
        self.client = OpenAI(api_key=api_key, base_url=base_url)

    def chat(self, message: str):
        """Record a counsellor message and return the simulated client's reply.

        The per-turn emotion and complaint-chain position are recomputed from
        the conversation so far. If ``is_need`` judges that the message
        concerns earlier sessions, information from ``query`` is added to the
        per-turn system note. The counsellor message and the reply are appended
        to both ``conversation`` and ``messages``. If anything raises, the
        half-recorded turn is rolled back, so a retry does not duplicate the
        message, and the exception propagates.

        Args:
            message: the counsellor's utterance.

        Returns:
            The reply text from the chat completion.
        """
        previous_chain_index = self.chain_index
        self.conversation.append({"role": "Counselor", "content": message})
        self.messages.append({"role": "user", "content": message})
        try:
            reply = self._generate_reply(message)
        except Exception:
            # Undo this turn so the history stays consistent for a retry.
            del self.conversation[-1]
            del self.messages[-1]
            self.chain_index = previous_chain_index
            raise
        self.conversation.append({"role": "Seeker", "content": reply})
        self.messages.append({"role": "assistant", "content": reply})
        return reply

    def _generate_reply(self, message):
        """Compute this turn's emotion/complaint state and request the model's reply."""
        emotion = emotion_modulation(self.portrait, self.conversation)
        self.chain_index = switch_complaint(self.complaint_chain, self.chain_index, self.conversation)
        complaint = transform_chain(self.complaint_chain)[self.chain_index]

        # Per-turn state note. NOTE(review): always Chinese, even when the
        # English template is in use.
        state_note = f"当前的情绪状态是:{emotion},当前的主诉是:{complaint}"
        if is_need(message):
            # The message touches on earlier sessions: add retrieved context.
            sup_information = query(message, self.previous_conversations, self.report)
            state_note += f",涉及到之前疗程的信息是:{sup_information}"

        response = self.client.chat.completions.create(
            model=model_name,
            messages=[{"role": "system", "content": self.system}]
            + self.messages
            + [{"role": "system", "content": state_note}],
        )
        return response.choices[0].message.content

    def get_system_prompt(self) -> str:
        """Return the system prompt rendered during construction."""
        return self.system