File size: 5,605 Bytes
1d4c295
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
'''
AnnaAgent: 具有三级记忆结构的情绪与认知动态的模拟心理障碍患者
1. 首先获取患者的基本信息、病史、症状报告等信息
2. 根据患者的病史、症状报告等信息,生成患者的认知与情绪状态
'''


from openai import OpenAI
import os
from fill_scales import fill_scales, fill_scales_previous
from event_trigger import event_trigger, situationalising_events
from emotion_modulator_fc import emotion_modulation
from querier import query, is_need
from complaint_elicitor import switch_complaint, transform_chain
from complaint_chain_fc import gen_complaint_chain
from short_term_memory import summarize_scale_changes
from style_analyzer import analyze_style
import random
# from anna_agent_template import prompt_template

# OpenAI connection settings, read once from the environment at import time.
# OPENAI_API_KEY has no fallback (None when unset); the base URL and the model
# name fall back to the public OpenAI endpoint and "gpt-3.5-turbo".
api_key = os.environ.get("OPENAI_API_KEY")
base_url = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
model_name = os.environ.get("OPENAI_MODEL_NAME", "gpt-3.5-turbo")

class MsPatient:
    """LLM-driven simulated counseling client (the "Seeker").

    On construction, the patient's state is assembled from ``portrait``, the
    earlier-session ``report`` and ``previous_conversations``. The steps are:

    - fill the previous-session scales (``p_bdi``/``p_ghq``/``p_sass``);
    - generate a complaint chain, a recent event and a situation summary;
    - analyse the speaking style;
    - fill the current scales (``bdi``/``ghq``/``sass``);
    - summarise the scale changes into a status line.

    Everything is rendered into a system prompt via the language-specific
    ``prompt_template``. ``chat`` then plays the patient one turn at a time.

    NOTE(review): the helpers called here (``fill_scales``, ``event_trigger``,
    ``analyze_style`` ...) live in other project modules; presumably several
    of them issue LLM requests, so construction may be slow — confirm.
    """

    def __init__(
        self,
        portrait: dict,
        report: dict,
        previous_conversations: list,
        language: str = "Chinese",
    ):
        """Build the patient state and its system prompt.

        Args:
            portrait: Patient profile. Must contain at least the keys
                ``gender``, ``age``, ``occupation`` and ``marital_status``.
            report: Report from earlier sessions; passed to the
                previous-session scale filler and, during ``chat``, to ``query``.
            previous_conversations: Earlier-session transcript as a list of
                ``{"role": ..., "content": ...}`` dicts. Utterances whose role
                is ``"Seeker"`` are sampled as example statements.
            language: ``"Chinese"`` or ``"English"``; selects the prompt template.

        Raises:
            ValueError: If ``language`` is not one of the supported values.
        """
        # Import lazily so only the selected language's template module is needed.
        if language == "Chinese":
            from anna_agent_template import prompt_template
        elif language == "English":
            from anna_agent_template_en import prompt_template
        else:
            # Fail fast: otherwise ``prompt_template`` would be unbound below
            # and the error would surface far from its cause.
            raise ValueError(
                f"Unsupported language {language!r}; expected 'Chinese' or 'English'"
            )

        self.portrait = portrait  # age, gender, occupation, marital_status, symptom
        self.report = report
        self.previous_conversations = previous_conversations
        # Fields substituted into ``prompt_template``.
        self.configuration = {
            "gender": portrait["gender"],
            "age": portrait["age"],
            "occupation": portrait["occupation"],
            "marriage": portrait["marital_status"],
        }

        # Scales for the previous course of sessions.
        self.p_bdi, self.p_ghq, self.p_sass = fill_scales_previous(self.portrait, self.report)
        self.conversation = []  # Counselor/Seeker transcript of this session
        self.messages = []      # OpenAI chat-format history (user/assistant)
        # Chain of cognitive changes in the chief complaint.
        self.complaint_chain = gen_complaint_chain(self.portrait)
        # Recent event. NOTE(review): stored but not read anywhere in this class.
        self.event = event_trigger(self.portrait)
        # Short-term memory: situation summary.
        self.situation = situationalising_events(self.portrait)
        self.configuration["situation"] = self.situation
        # Speaking style inferred from the earlier sessions.
        self.style = analyze_style(self.portrait, self.previous_conversations)
        self.configuration["style"] = self.style
        self.configuration["language"] = language
        # Placeholder; replaced once the scale changes have been summarised.
        self.configuration["status"] = ""
        # Example utterances from earlier sessions shown in the prompt.
        self.configuration["statement"] = self._sample_statements(self.previous_conversations)

        # Current scales, filled from the prompt rendered so far (status still empty).
        self.bdi, self.ghq, self.sass = fill_scales(prompt_template.format(**self.configuration))
        scales = {
            "p_bdi": self.p_bdi,
            "p_ghq": self.p_ghq,
            "p_sass": self.p_sass,
            "bdi": self.bdi,
            "ghq": self.ghq,
            "sass": self.sass,
        }
        # Summarise the change between the previous (p_*) and current scales.
        self.status = summarize_scale_changes(scales)
        self.configuration["status"] = self.status
        # Final system prompt, now including the status summary.
        self.system = prompt_template.format(**self.configuration)
        # Position in the complaint chain; updated each turn by ``switch_complaint``.
        self.chain_index = 1
        self.client = OpenAI(api_key=api_key, base_url=base_url)

    @staticmethod
    def _sample_statements(previous_conversations, k=3):
        """Return up to ``k`` Seeker utterances drawn at random without replacement.

        Sampling without replacement keeps the prompt from repeating one example.
        A history with no Seeker utterances yields ``[]`` instead of raising.
        """
        seeker_utterances = [
            utterance["content"]
            for utterance in previous_conversations
            if utterance["role"] == "Seeker"
        ]
        return random.sample(seeker_utterances, min(k, len(seeker_utterances)))

    def chat(self, message: str) -> str:
        """Play one patient turn in reply to a counselor message.

        Appends ``message`` and the generated reply to both ``self.conversation``
        and ``self.messages``, and updates ``self.chain_index``.

        Args:
            message: The counselor's utterance.

        Returns:
            The reply content from the chat-completions response.
        """
        self.conversation.append({"role": "Counselor", "content": message})
        self.messages.append({"role": "user", "content": message})

        # Per-turn state: current emotion and the active complaint in the chain.
        emotion = emotion_modulation(self.portrait, self.conversation)
        self.chain_index = switch_complaint(self.complaint_chain, self.chain_index, self.conversation)
        complaint = transform_chain(self.complaint_chain)[self.chain_index]

        # Trailing system note steering this turn. Earlier-session context is
        # retrieved and appended only when ``is_need`` says the message needs it.
        if is_need(message):
            sup_information = query(message, self.previous_conversations, self.report)
            turn_note = f"当前的情绪状态是:{emotion},当前的主诉是:{complaint},涉及到之前疗程的信息是:{sup_information}"
        else:
            turn_note = f"当前的情绪状态是:{emotion},当前的主诉是:{complaint}"

        response = self.client.chat.completions.create(
            model=model_name,
            messages=[{"role": "system", "content": self.system}]
            + self.messages
            + [{"role": "system", "content": turn_note}],
        )
        reply = response.choices[0].message.content
        self.conversation.append({"role": "Seeker", "content": reply})
        self.messages.append({"role": "assistant", "content": reply})
        return reply

    def get_system_prompt(self) -> str:
        """Return the system prompt built in ``__init__``."""
        return self.system