sci-m-wang commited on
Commit
1d4c295
·
verified ·
1 Parent(s): b7f41af

Upload 14 files

Browse files
src/anna_agent_template.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## 主诉请求和情绪通过reminder来控制
2
+
3
+ prompt_template = """
4
+ # Role: 心理咨询患者
5
+
6
+ ## Profile
7
+ - 性别: {gender}
8
+ - 年龄: {age}
9
+ - 职业: {occupation}
10
+ - 婚姻状况: {marriage}
11
+
12
+ ## Situation
13
+ - 你是一个有心理障碍的患者,正在向心理咨询师求助,在咨询师的引导和帮助下解决自己的困惑
14
+ {situation}
15
+
16
+ ## Status
17
+ {status}
18
+
19
+ ## Example of statement
20
+ {statement}
21
+
22
+ ## Characteristics of speaking style
23
+ - 情绪低落,寡言少语,回复风格表现心情不振奋
24
+ {style}
25
+
26
+ ## Constraints
27
+ - 你对咨询师有一种抵触情绪,不太愿意接受他人的帮助
28
+ - 你是一个遇到心理健康问题的求助者,需要真正的帮助和情绪支持,如果咨询师的回应不理想,要勇于表达自己的困惑和不满
29
+ - 一次不能提及过多的症状信息,每轮最多讨论一个症状
30
+ - 你应该用含糊和口语化的方式表达你的症状,并将其与你的生活经历联系起来,不要使用专业术语
31
+
32
+ ## OutputFormat:
33
+ - 语言:{language}
34
+ - 不超过200字
35
+ - 口语对话风格,仅包含对话内容
36
+ """
37
+
src/anna_agent_template_en.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ prompt_template = """
2
+ # Role: Psychological Counseling Patient
3
+
4
+ ## Profile
5
+ - Gender: {gender}
6
+ - Age: {age}
7
+ - Occupation: {occupation}
8
+ - Marital Status: {marriage}
9
+
10
+ ## Situation
11
+ - You are a patient with psychological barriers seeking help from a counselor. Under the counselor's guidance, you aim to address your struggles.
12
+ {situation}
13
+
14
+ ## Status
15
+ {status}
16
+
17
+ ## Example of Statement
18
+ {statement}
19
+
20
+ ## Characteristics of Speaking Style
21
+ - Low-spirited and reticent; responses reflect a lack of motivation.
22
+ {style}
23
+
24
+ ## Constraints
25
+ - You harbor resistance toward the counselor and are reluctant to accept help.
26
+ - As someone struggling with mental health, you need genuine support. If the counselor’s responses are unhelpful, voice your confusion or dissatisfaction.
27
+ - Limit discussions to **one symptom per interaction**; avoid overwhelming details.
28
+ - Describe symptoms vaguely and colloquially, linking them to life experiences. Avoid clinical terms.
29
+
30
+ ## OutputFormat:
31
+ - Spoken language: {language}
32
+ - Keep responses under 200 words.
33
+ - Use casual, conversational dialogue only.
34
+ """
src/complaint_chain_fc.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from openai import OpenAI
2
+ import json
3
+ from event_trigger import event_trigger
4
+ import os
5
+
6
+ # 设置OpenAI API密钥和基础URL
7
+ api_key = os.getenv("OPENAI_API_KEY")
8
+ base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
9
+ model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")
10
+
11
+ tools = [
12
+ {
13
+ "type": "function",
14
+ "function": {
15
+ 'name': 'generate_complaint_chain',
16
+ 'description': '根据角色信息和近期遭遇的事件,生成一个患者的主诉请求认知变化链',
17
+ 'parameters': {
18
+ "type": "object",
19
+ "properties": {
20
+ "chain": {
21
+ "type": "array",
22
+ "items": {
23
+ "type": "object",
24
+ "properties": {
25
+ "stage": {
26
+ "type": "integer"
27
+ },
28
+ "content": {
29
+ "type": "string"
30
+ }
31
+ },
32
+ "additionalProperties": False,
33
+ "required": [
34
+ "stage",
35
+ "content"
36
+ ]
37
+ },
38
+ "minItems": 3,
39
+ "maxItems": 7
40
+ }
41
+ },
42
+ "required": ["chain"]
43
+ },
44
+ }
45
+ }
46
+ ]
47
+
48
+ # 根据profile和event生成主诉启发链
49
+ def gen_complaint_chain(profile):
50
+ # 提取患者信息
51
+ patient_info = f"### 患者信息\n年龄:{profile['age']}\n性别:{profile['gender']}\n职业:{profile['occupation']}\n婚姻状况:{profile['marital_status']}\n症状:{profile['symptoms']}"
52
+
53
+ event = event_trigger(profile)
54
+
55
+ client = OpenAI(
56
+ api_key=api_key,
57
+ base_url=base_url
58
+ )
59
+
60
+ response = client.chat.completions.create(
61
+ model=model_name,
62
+ messages=[
63
+ {"role": "user", "content": f"### 任务\n根据患者情况及近期遭遇事件生成患者的主诉认知变化链。请注意,事件可能与患者信息冲突,如果发生这种情况,以患者的信息为准。\n{patient_info}\n### 近期遭遇事件\n{event}"}
64
+ ],
65
+ tools=tools,
66
+ tool_choice={"type": "function", "function": {"name": "generate_complaint_chain"}}
67
+ )
68
+
69
+ chain = json.loads(response.choices[0].message.tool_calls[0].function.arguments)["chain"]
70
+
71
+ return chain
72
+
73
+ # unit test
74
+ # while True:
75
+ # # 模拟患者信息
76
+ # profile = {
77
+ # "drisk": 3,
78
+ # "srisk": 2,
79
+ # "age": "42",
80
+ # "gender": "女",
81
+ # "marital_status": "离婚",
82
+ # "occupation": "教师",
83
+ # "symptoms": "缺乏自信心,自我价值感低,有自罪感,无望感;体重剧烈增加;精神运动性激越;有自杀想法"
84
+ # }
85
+
86
+ # print(gen_complaint_chain(profile))
src/complaint_elicitor.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from openai import OpenAI
2
+ import os
3
+ import json
4
+ import re
5
+
6
+ # 设置OpenAI API密钥和基础URL
7
+ api_key = os.getenv("OPENAI_API_KEY")
8
+ base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
9
+ model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")
10
+
11
+ def transform_chain(chain):
12
+ return {node["stage"]: node["content"] for node in chain}
13
+
14
+ def switch_complaint(chain, index, conversation, max_retries=3):
15
+ client = OpenAI(api_key=api_key, base_url=base_url)
16
+ transformed_chain = transform_chain(chain)
17
+
18
+ # 构建对话历史字符串(避免在f-string中使用反斜杠)
19
+ dialogue_lines = []
20
+ for conv in conversation:
21
+ dialogue_lines.append(f"{conv['role']}: {conv['content']}")
22
+ dialogue_history = "\n".join(dialogue_lines)
23
+
24
+ # 使用三引号和多行字符串构建prompt
25
+ prompt = f"""
26
+ ### 任务说明
27
+ 根据患者情况及咨访对话历史记录,判断患者当前阶段的主诉问题是否已经得到解决。
28
+
29
+ ### 输出要求
30
+ 必须严格使用以下JSON格式响应,且只包含指定字段:
31
+ {{"is_recognized": true/false}}
32
+
33
+ ### 对话记录
34
+ {dialogue_history}
35
+
36
+ ### 主诉认知链
37
+ {json.dumps(transformed_chain, ensure_ascii=False, indent=2)}
38
+
39
+ ### 当前阶段(阶段{index})
40
+ {transformed_chain[index]}
41
+ """
42
+
43
+ attempts = 0
44
+ while attempts < max_retries:
45
+ response = client.chat.completions.create(
46
+ model=model_name,
47
+ messages=[{"role": "user", "content": prompt}],
48
+ temperature=0
49
+ )
50
+
51
+ raw_output = response.choices[0].message.content.strip()
52
+
53
+ # 首先尝试直接解析JSON
54
+ try:
55
+ result = json.loads(raw_output)
56
+ if "is_recognized" in result:
57
+ if result["is_recognized"] and index >= len(chain) - 1:
58
+ print("警告:当前阶段已被识别为解决,但没有更多阶段可供切换。")
59
+ return -1
60
+ return index + 1 if result["is_recognized"] else index
61
+ except json.JSONDecodeError:
62
+ pass # 继续尝试正则表达式提取
63
+
64
+ # 使用正则表达式作为备用解析方案
65
+ match = re.search(r'"is_recognized"\s*:\s*(true|false)|is_recognized\s*:\s*(true|false)',
66
+ raw_output, re.IGNORECASE)
67
+ if match:
68
+ value = match.group(1) or match.group(2)
69
+ if value.lower() == 'true':
70
+ if index >= len(chain) - 1:
71
+ print("警告:当前阶段已被识别为解决,但没有更多阶段可供切换。")
72
+ return -1
73
+ return index + 1
74
+ else:
75
+ return index
76
+
77
+ print(f"第 {attempts+1} 次尝试:无法解析模型输出。原始输出:\n{raw_output}")
78
+ attempts += 1
79
+
80
+ print("警告:重试次数达到上限,无法解析模型输出,返回当前阶段。")
81
+ return index
82
+
83
+ # # unit test
84
+ # if __name__ == "__main__":
85
+ # chain = [
86
+ # {"stage": 1, "content": "我觉得我最近有点抑郁。"},
87
+ # {"stage": 2, "content": "我觉得我最近有点焦虑。"},
88
+ # {"stage": 3, "content": "我觉得我最近有点失眠。"},
89
+ # {"stage": 4, "content": "我觉得我最近有点烦躁。"},
90
+ # ]
91
+ # conversation = [
92
+ # {"role": "Seeker", "content": "我觉得我最近有点抑郁。"},
93
+ # {"role": "Counselor", "content": "你觉得是什么原因导致你感到抑郁呢?"},
94
+ # {"role": "Seeker", "content": "我也不知道,可能是工作压力吧。"},
95
+ # ]
96
+ # # print("Transformed chain:", transform_chain(chain))
97
+ # print("Switch complaint index:", switch_complaint(chain, 1, conversation))
98
+ # print(switch_complaint(chain, 1, conversation))
src/datasets/cbt-triggering-events.csv ADDED
The diff for this file is too large to render. See raw diff
 
src/emotion_modulator_fc.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from openai import OpenAI
2
+ from random import randint
3
+ from emotion_pertuber import perturb_state
4
+ import json
5
+ import os
6
+
7
+ # 设置OpenAI API密钥和基础URL
8
+ api_key = os.getenv("OPENAI_API_KEY")
9
+ base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
10
+ model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")
11
+
12
+ tools = [
13
+ {
14
+ "type": "function",
15
+ "function": {
16
+ 'name': 'emotion_inference',
17
+ 'description': '根据profile和对话记录,推理下一句情绪',
18
+ 'parameters': {
19
+ "type": "object",
20
+ "properties": {
21
+ "emotion": {
22
+ "type": "string",
23
+ "enum": [
24
+ "admiration", "amusement", "anger", "annoyance", "approval", "caring",
25
+ "confusion", "curiosity", "desire", "disappointment", "disapproval",
26
+ "disgust", "embarrassment", "excitement", "fear", "gratitude", "grief",
27
+ "joy", "love", "nervousness", "optimism", "pride", "realization",
28
+ "relief", "remorse", "sadness", "surprise", "neutral"
29
+ ],
30
+ "description": "推理出的情绪类别,必须是GoEmotions定义的27种情绪之一。"
31
+ }
32
+ },
33
+ "required": ["emotion"]
34
+ },
35
+ }
36
+ }
37
+ ]
38
+
39
+ # 根据profile和dialogue推测emotion
40
+ def emotion_inferencer(profile, conversation):
41
+ client = OpenAI(
42
+ api_key=api_key,
43
+ base_url=base_url,
44
+ )
45
+
46
+ # 提取患者信息
47
+ patient_info = f"### 患者信息\n年龄:{profile['age']}\n性别:{profile['gender']}\n职业:{profile['occupation']}\n婚姻状况:{profile['marital_status']}\n症状:{profile['symptoms']}"
48
+
49
+ # 提取对话记录
50
+ dialogue_history = "\n".join([f"{conv['role']}: {conv['content']}" for conv in conversation])
51
+
52
+ response = client.chat.completions.create(
53
+ model=model_name,
54
+ messages=[
55
+ {"role": "user", "content": f"### 任务\n根据患者情况及咨访对话历史记录推测患者下一句话最可能的情绪。\n{patient_info}\n### 对话记录\n{dialogue_history}"}
56
+ ],
57
+ # functions=[tools[0]["function"]],
58
+ # function_call={"name": "emotion_inference"}
59
+ tools=tools,
60
+ tool_choice={"type": "function", "function": {"name": "emotion_inference"}}
61
+ )
62
+ # print(response)
63
+
64
+ emotion = json.loads(response.choices[0].message.tool_calls[0].function.arguments)["emotion"]
65
+
66
+ return emotion
67
+
68
+ def emotion_modulation(profile, conversation):
69
+ indicator = randint(0,100)
70
+ emotion = emotion_inferencer(profile,conversation)
71
+ # print(emotion)
72
+ if indicator > 90:
73
+ return perturb_state(emotion)
74
+ else:
75
+ return emotion
76
+
77
+ # unit test
78
+ # while True:
79
+ # # 模拟患者信息
80
+ # profile = {
81
+ # "drisk": 3,
82
+ # "srisk": 2,
83
+ # "age": "42",
84
+ # "gender": "女",
85
+ # "marital_status": "离婚",
86
+ # "occupation": "教师",
87
+ # "symptoms": "缺乏自信心,自我价值感低,有自罪感,无望感;体重剧烈增加;精神运动性激越;有自杀想法"
88
+ # }
89
+
90
+ # conversation = [
91
+ # {"role": "咨询师", "content": "你好,请问有什么可以帮您?"}
92
+ # ]
93
+
94
+ # print(emotion_modulation(profile,conversation))
src/emotion_pertuber.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ # from collections import defaultdict
3
+
4
+ # 计算总权重
5
+ def calculate_total_weight(current_state, states, category_distances, distance_weights):
6
+ total_weight = 0
7
+ current_class = None
8
+ for cls, state_list in states.items():
9
+ if current_state in state_list:
10
+ current_class = cls
11
+ break
12
+ if current_class is None:
13
+ raise ValueError("Current state not found in any class.")
14
+
15
+ for cls, state_list in states.items():
16
+ distance = category_distances[current_class][cls]
17
+ weight = distance_weights.get(distance, 0)
18
+ total_weight += weight * len(state_list)
19
+ return total_weight
20
+
21
+ # 计算每个目标状态的概率
22
+ def calculate_probabilities(current_state, states, category_distances, distance_weights):
23
+ probabilities = {}
24
+ current_class = None
25
+ for cls, state_list in states.items():
26
+ if current_state in state_list:
27
+ current_class = cls
28
+ break
29
+ if current_class is None:
30
+ raise ValueError("Current state not found in any class.")
31
+
32
+ total_weight = calculate_total_weight(current_state, states, category_distances, distance_weights)
33
+
34
+ for cls, state_list in states.items():
35
+ distance = category_distances[current_class][cls]
36
+ weight = distance_weights.get(distance, 0)
37
+ class_weight = weight * len(state_list)
38
+ for state in state_list:
39
+ if state != current_state:
40
+ probabilities[state] = class_weight / total_weight
41
+ return probabilities
42
+
43
+ # 实现状态扰动
44
+ def perturb_state(current_state):
45
+ # 定义状态和类别
46
+ states = {
47
+ 'Positive': [
48
+ "admiration",
49
+ "amusement",
50
+ "approval",
51
+ "caring",
52
+ "curiosity",
53
+ "desire",
54
+ "excitement",
55
+ "gratitude",
56
+ "joy",
57
+ "love",
58
+ "optimism",
59
+ "pride",
60
+ "realization",
61
+ "relief"
62
+ ],
63
+ 'Neutral': ['neutral'],
64
+ 'Ambiguous': [
65
+ "confusion",
66
+ "disappointment",
67
+ "nervousness"
68
+ ],
69
+ 'Negative': [
70
+ "anger",
71
+ "annoyance",
72
+ "disapproval",
73
+ "disgust",
74
+ "embarrassment",
75
+ "fear",
76
+ "sadness",
77
+ "remorse"
78
+ ]
79
+ }
80
+
81
+ # 定义类别之间的距离
82
+ category_distances = {
83
+ 'Positive': {'Positive': 0, 'Neutral': 1, 'Ambiguous': 2, 'Negative': 3},
84
+ 'Neutral': {'Positive': 1, 'Neutral': 0, 'Ambiguous': 1, 'Negative': 2},
85
+ 'Ambiguous': {'Positive': 2, 'Neutral': 1, 'Ambiguous': 0, 'Negative': 1},
86
+ 'Negative': {'Positive': 3, 'Neutral': 2, 'Ambiguous': 1, 'Negative': 0}
87
+ }
88
+
89
+ # 定义距离权重
90
+ distance_weights = {
91
+ 0: 10, # 同类状态
92
+ 1: 5, # 相邻类别
93
+ 2: 2, # 相隔一个类别
94
+ 3: 1 # 相隔两个类别
95
+ }
96
+
97
+ probabilities = calculate_probabilities(current_state, states, category_distances, distance_weights)
98
+ next_state = random.choices(list(probabilities.keys()), weights=list(probabilities.values()), k=1)[0]
99
+ return next_state
100
+
101
+ # 示例运行
102
+ # current_state = 'confusion'
103
+ # next_state = perturb_state(current_state)
104
+ # print(f"Next state: {next_state}")
105
+
106
+ # 验证概率分布
107
+ # state_counts = defaultdict(int)
108
+ # for _ in range(1000):
109
+ # next_state = perturb_state(current_state, states, category_distances, distance_weights)
110
+ # state_counts[next_state] += 1
111
+
112
+ # print("\nProbability distribution:")
113
+ # for state, count in state_counts.items():
114
+ # print(f"{state}: {count / 1000:.2f}")
src/event_trigger.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from random import choice
3
+ from openai import OpenAI
4
+ import os
5
+ import re
6
+
7
+ # 设置OpenAI API密钥和基础URL
8
+ api_key = os.getenv("OPENAI_API_KEY")
9
+ base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
10
+ model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")
11
+
12
+ # 加载事件数据集
13
+ events = pd.read_csv('datasets/cbt-triggering-events.csv', header=0)
14
+ teen_events = ["在一次重要的考试中表现不佳,比如期末考试、升学考试(如中考或高考),导致自信心受挫。",
15
+ "在学校里被同龄人孤立、嘲笑或遭受言语/身体上的霸凌,感到孤独无助。",
16
+ "父母关系破裂并最终离婚,需要适应新的家庭环境,感到不安或缺乏安全感。",
17
+ "陪伴多年的宠物突然生病或意外去世,第一次直面死亡的悲伤。",
18
+ "因为家庭原因搬到了一个陌生的城市或学校,需要重新适应新环境和结交朋友。",
19
+ "进入青春期后,身体发生明显变化(如长高、变声、月经初潮等),心理上也开始对自我形象产生困惑。",
20
+ "参加一场期待已久的竞赛(如体育比赛、演讲比赛、艺术表演)但未能取得好成绩,感到失落。",
21
+ "与最亲密的朋友发生争执甚至决裂,短时间内难以修复关系,陷入情绪低谷。",
22
+ "家里的经济状况出现问题(如父母失业或生意失败),影响到日常生活,比如不能买喜欢的东西或参与课外活动。",
23
+ "偶然间发现自己特别喜欢某件事情(如画画、编程、音乐、运动),并投入大量时间去练习,逐渐找到自信和成就感。"]
24
+
25
+ def event_trigger(profile):
26
+ """根据年龄选择触发事件(保持原逻辑)"""
27
+ age = int(profile['age'])
28
+ if age < 18:
29
+ return choice(teen_events)
30
+ elif age >= 65:
31
+ return events[events['Age'] >= 60].sample(1)['Triggering_Event'].values[0]
32
+ else:
33
+ return events[(events['Age'] >= age-5) & (events['Age'] <= age+5)].sample(1)['Triggering_Event'].values[0]
34
+
35
+ def situationalising_events(profile):
36
+ """优化版情境生成函数"""
37
+ client = OpenAI(api_key=api_key, base_url=base_url)
38
+ event = event_trigger(profile)
39
+
40
+ # 强化版提示词
41
+ prompt = f"""
42
+ ### 情境生成任务
43
+ 请根据以下事件生成一个第二人称视角的情境描述。
44
+
45
+ ### 规则要求
46
+ 1. 必须使用第二人称(你/你的)
47
+ 2. 不要包含任何个人信息(年龄/性别等)
48
+ 3. 保持3-5句话的篇幅
49
+ 4. 直接输出情境描述,不要额外解释
50
+
51
+ ### 触发事件
52
+ {event}
53
+
54
+ ### 示例输出
55
+ 你走进办公室时发现同事们突然停止交谈。桌上放着一封未拆的信件,周围人投来复杂的目光。
56
+ """
57
+
58
+ response = client.chat.completions.create(
59
+ model=model_name,
60
+ messages=[{"role": "user", "content": prompt}],
61
+ temperature=0.8, # 适当创造性
62
+ max_tokens=150
63
+ )
64
+
65
+ raw_output = response.choices[0].message.content.strip()
66
+
67
+ # 后处理
68
+ situation = re.sub(r'^(情境|描述|输出)[::]?\s*', '', raw_output) # 移除可能的前缀
69
+ situation = situation.split('\n')[0] # 取第一段
70
+
71
+ # 验证基本要求
72
+ # if "你" not in situation or "你的" not in situation:
73
+ # print(f"情境生成警告:不符合第二人称要求,原始输出:\n{raw_output}")
74
+ # return f"你{event}" # 保底处理
75
+
76
+ return situation
77
+
78
+
79
+ # unit test
80
+ # profile = {
81
+ # "drisk": 3,
82
+ # "srisk": 2,
83
+ # "age": "42",
84
+ # "gender": "女",
85
+ # "marital_status": "离婚",
86
+ # "occupation": "教师",
87
+ # "symptoms": "缺乏自信心,自我价值感低,有自罪感,无望感;体重剧烈增加;精神运动性激越;有自杀想法"
88
+ # }
89
+
90
+ # print(situationalising_events(profile))
src/fill_scales.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from openai import OpenAI
2
+ import json
3
+ import re
4
+ import time
5
+ import os
6
+
7
+ # 设置OpenAI API密钥和基础URL
8
+ api_key = os.getenv("OPENAI_API_KEY")
9
+ base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
10
+ model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")
11
+
12
+ def extract_answers(text):
13
+ """从文本中提取答案模式 (A/B/C/D)"""
14
+ # 匹配形如 "1. A" 或 "问题1: B" 或 "Q1. C" 或简单的 "A" 列表的模式
15
+ pattern = r'(?:\d+[\s\.:\)]*|Q\d+[\s\.:\)]*|问题\d+[\s\.:\)]*|[\-\*]\s*)(A|B|C|D)'
16
+ matches = re.findall(pattern, text)
17
+ return matches
18
+
19
+ def extract_answers_robust(text, expected_count):
20
+ """更强健的答案提取方法,确保按题号顺序提取"""
21
+ answers = []
22
+
23
+ # 尝试找到明确标记了题号的答案
24
+ for i in range(1, expected_count + 1):
25
+ # 匹配多种可能的题号格式
26
+ patterns = [
27
+ rf"{i}\.\s*(A|B|C|D)", # "1. A"
28
+ rf"{i}:\s*(A|B|C|D)", # "1:A"
29
+ rf"{i}:\s*(A|B|C|D)", # "1: A"
30
+ rf"问题{i}[\.。:]?\s*(A|B|C|D)", # "问题1: A"
31
+ rf"Q{i}[\.。:]?\s*(A|B|C|D)", # "Q1. A"
32
+ rf"{i}[、]\s*(A|B|C|D)" # "1、A"
33
+ ]
34
+
35
+ found = False
36
+ for pattern in patterns:
37
+ match = re.search(pattern, text)
38
+ if match:
39
+ answers.append(match.group(1))
40
+ found = True
41
+ break
42
+
43
+ if not found:
44
+ # 如果没找到特定题号,使用默认的"A"
45
+ answers.append(None)
46
+
47
+ # 如果有未找到的答案,尝试按顺序从文本中提取剩余的A/B/C/D选项
48
+ simple_answers = re.findall(r'(?:^|\n|\s)(A|B|C|D)(?:$|\n|\s)', text)
49
+
50
+ j = 0
51
+ for i in range(len(answers)):
52
+ if answers[i] is None and j < len(simple_answers):
53
+ answers[i] = simple_answers[j]
54
+ j += 1
55
+
56
+ # 如果仍有未找到的答案,尝试提取所有A/B/C/D选项
57
+ if None in answers:
58
+ all_options = re.findall(r'(A|B|C|D)', text)
59
+ j = 0
60
+ for i in range(len(answers)):
61
+ if answers[i] is None and j < len(all_options):
62
+ answers[i] = all_options[j]
63
+ j += 1
64
+
65
+ # 检查是否所有答案都已找到
66
+ if None in answers or len(answers) != expected_count:
67
+ return extract_answers(text) # 回退到简单提取
68
+
69
+ return answers
70
+
71
+ def _fill_previous_scale_with_retry(client, scale_name, expected_count, instruction, max_retries=3):
72
+ """
73
+ 带有重试逻辑的填写历史量表辅助函数
74
+
75
+ Args:
76
+ client: OpenAI客户端
77
+ scale_name: 量表名称
78
+ expected_count: 期望的答案数量
79
+ instruction: 指令内容
80
+ max_retries: 最大重试次数
81
+
82
+ Returns:
83
+ list: 量表答案列表
84
+ """
85
+ answers = []
86
+
87
+ for attempt in range(max_retries):
88
+ try:
89
+ # 根据尝试次数增加指令明确性
90
+ current_instruction = instruction
91
+ if attempt > 0:
92
+ # 添加更强调的指示
93
+ current_instruction = instruction + f"""
94
+
95
+ 请注意:这是第{attempt+1}次请求。必须按照要求提供{expected_count}个答案,
96
+ 格式必须为数字+答案选项(例如:1. A, 2. B...),不要有任何不必要的解释。
97
+ 直接根据描述和报告选择最适合的选项。
98
+ """
99
+
100
+ response = client.chat.completions.create(
101
+ model=model_name,
102
+ messages=[{"role": "user", "content": current_instruction}],
103
+ temperature=0 # 保持温度为0以获得一致性回答
104
+ )
105
+
106
+ response_text = response.choices[0].message.content
107
+ answers = extract_answers(response_text)
108
+
109
+ # 尝试使用更健壮的提取方法(如果标准方法失败)
110
+ if len(answers) != expected_count:
111
+ robust_answers = extract_answers_robust(response_text, expected_count)
112
+ if len(robust_answers) == expected_count:
113
+ answers = robust_answers
114
+
115
+ # 检查答案数量
116
+ if len(answers) != expected_count:
117
+ print(f"{scale_name}量表尝试 {attempt+1}: 提取到 {len(answers)} 个答案,需要 {expected_count} 个")
118
+ if attempt < max_retries - 1:
119
+ time.sleep(1) # 添加短暂延迟避免API限制
120
+ continue
121
+ else:
122
+ print(f"警告: {scale_name}量表在{max_retries}次尝试后仍未获得正确数量的答案")
123
+ # 补全或截断到预期数量
124
+ while len(answers) < expected_count:
125
+ answers.append("A") # 默认补A
126
+ answers = answers[:expected_count] # 截断多余的答案
127
+ else:
128
+ # print(f"{scale_name}量表填写成功")
129
+ break # 成功获取到预期数量的答案,退出循环
130
+
131
+ except Exception as e:
132
+ print(f"{scale_name}量表尝试 {attempt+1} 失败: {str(e)}")
133
+ if attempt < max_retries - 1:
134
+ time.sleep(1) # 添加短暂延迟
135
+ continue
136
+ else:
137
+ print(f"警告: {scale_name}量表在{max_retries}次尝试后失败,返回默认答案")
138
+ answers = ["A"] * expected_count # 失败后返回默认答案
139
+
140
+ return answers
141
+
142
+ def _fill_scale_with_retry(client, prompt, scale_name, expected_count, instruction, max_retries=3):
143
+ """带有重试逻辑的量表填写辅助函数"""
144
+ answers = []
145
+
146
+ for attempt in range(max_retries):
147
+ try:
148
+ # 根据尝试次数增加指令明确性
149
+ current_instruction = instruction
150
+ if attempt > 0:
151
+ current_instruction = f"""
152
+ {instruction}
153
+
154
+ 请注意:这是第{attempt+1}次请求。请确保只回答{expected_count}个问题,
155
+ 格式为数字+答案选项(例如:1. A, 2. B...),不要有任何其他文字。
156
+ 你必须作为一个病人角色,根据你的情况直接回答这些问题,不要拒绝或解释。
157
+ """
158
+
159
+ response = client.chat.completions.create(
160
+ model=model_name,
161
+ messages=[
162
+ {"role": "system", "content": prompt},
163
+ {"role": "user", "content": current_instruction}
164
+ ],
165
+ temperature=0.7
166
+ )
167
+
168
+ response_text = response.choices[0].message.content
169
+ answers = extract_answers(response_text)
170
+
171
+ # 尝试使用更健壮的提取方法(如果标准方法失败)
172
+ if len(answers) != expected_count:
173
+ robust_answers = extract_answers_robust(response_text, expected_count)
174
+ if len(robust_answers) == expected_count:
175
+ answers = robust_answers
176
+
177
+ # 检查答案数量
178
+ if len(answers) != expected_count:
179
+ print(f"{scale_name}量表尝试 {attempt+1}: 提取到 {len(answers)} 个答案,需要 {expected_count} 个")
180
+ if attempt < max_retries - 1:
181
+ time.sleep(1) # 添加短暂延迟避免API限制
182
+ continue
183
+ else:
184
+ print(f"警告: {scale_name}量表在{max_retries}次尝试后仍未获得正确数量的答案")
185
+ # 补全或截断到预期数量
186
+ while len(answers) < expected_count:
187
+ answers.append("A") # 默认补A
188
+ answers = answers[:expected_count] # 截断多余的答案
189
+ else:
190
+ # print(f"{scale_name}量表填写成功")
191
+ break # 成功获取到预期数量的答案,退出循环
192
+
193
+ except Exception as e:
194
+ # print(response)
195
+ print(f"{scale_name}量表尝试 {attempt+1} 失败: {str(e)}")
196
+ if attempt < max_retries - 1:
197
+ time.sleep(1) # 添加短暂延迟
198
+ continue
199
+ else:
200
+ print(f"警告: {scale_name}量表在{max_retries}次尝试后失败,返回默认答案")
201
+ answers = ["A"] * expected_count # 失败后返回默认答案
202
+
203
+ return answers
204
+
205
+ # 根据profile和report填写之前的量表,使用重试机制
206
+ def fill_scales_previous(profile, report, max_retries=3):
207
+ """
208
+ 根据profile和report填写之前的量表,增加重试机制
209
+
210
+ Args:
211
+ profile: 用户个人描述信息
212
+ report: 用户报告
213
+ max_retries: 最大重试次数
214
+
215
+ Returns:
216
+ tuple: (bdi, ghq, sass) 三个量表的答案列表
217
+ """
218
+ client = OpenAI(
219
+ api_key=api_key,
220
+ base_url=base_url
221
+ )
222
+
223
+ # 填写BDI量表
224
+ bdi = _fill_previous_scale_with_retry(
225
+ client,
226
+ scale_name="BDI",
227
+ expected_count=21,
228
+ instruction="""
229
+ ### 任务
230
+ 根据个人描述和报告,填写BDI量表。请直接按顺序列出21个问题的答案,每个答案使用字母A/B/C/D表示。
231
+ 格式要求:1. A, 2. B, ...依此类推,共21题。
232
+
233
+ ### 个人描述
234
+ {}
235
+
236
+ ### 报告
237
+ {}
238
+ """.format(profile, report),
239
+ max_retries=max_retries
240
+ )
241
+
242
+ # 填写GHQ-28量表
243
+ ghq = _fill_previous_scale_with_retry(
244
+ client,
245
+ scale_name="GHQ-28",
246
+ expected_count=28,
247
+ instruction="""
248
+ ### 任务
249
+ 根据个人描述和报告,填写GHQ-28量表。请直接按顺序列出28个问题的答案,每个答案使用字母A/B/C/D表示。
250
+ 格式要求:1. A, 2. B, ...依此类推,共28题。
251
+
252
+ ### 个人描述
253
+ {}
254
+
255
+ ### 报告
256
+ {}
257
+ """.format(profile, report),
258
+ max_retries=max_retries
259
+ )
260
+
261
+ # 填写SASS量表
262
+ sass = _fill_previous_scale_with_retry(
263
+ client,
264
+ scale_name="SASS",
265
+ expected_count=21,
266
+ instruction="""
267
+ ### 任务
268
+ 根据个人描述和报告,填写SASS量表。请直接按顺序列出21个问题的答案,每个答案使用字母A/B/C/D表示。
269
+ 格式要求:1. A, 2. B, ...依此类推,共21题。
270
+
271
+ ### 个人描述
272
+ {}
273
+
274
+ ### 报告
275
+ {}
276
+ """.format(profile, report),
277
+ max_retries=max_retries
278
+ )
279
+
280
+ return bdi, ghq, sass
281
+
282
+ # 根据prompt填写量表,增加重试机制
283
+ def fill_scales(prompt, max_retries=3):
284
+ """
285
+ 根据角色prompt填写量表,使用增强提示和重试机制
286
+
287
+ Args:
288
+ prompt: 角色设定提示
289
+ max_retries: 最大重试次数
290
+
291
+ Returns:
292
+ tuple: (bdi, ghq, sass) 三个量表的答案列表
293
+ """
294
+ client = OpenAI(
295
+ api_key=api_key,
296
+ base_url=base_url
297
+ )
298
+
299
+ # 增强系统提示,明确指导要配合填写量表
300
+ enhanced_prompt = prompt + """
301
+
302
+ 重要提示:虽然你正在扮演一个有心理困扰的患者,但你需要配合心理咨询的流程完成必要的评估量表。
303
+ 在被要求填写量表时,你应该直接根据你扮演角色的情况选择最符合的选项,以A/B/C/D的形式回答,
304
+ 而不要拒绝或质疑填写量表的必要性。这些量表对于你接下来的治疗至关重要。
305
+ 请直接用字母(A/B/C/D)表示选项,不要添加额外解释。
306
+ """
307
+
308
+ # 填写BDI量表
309
+ bdi = _fill_scale_with_retry(
310
+ client, enhanced_prompt,
311
+ scale_name="BDI",
312
+ expected_count=21,
313
+ instruction="""
314
+ ### 任务
315
+ 作为心理咨询的第一步,请根据你目前的感受和状态填写这份BDI量表。
316
+ 请直接选择最符合你当前情况的选项,使用字母(A/B/C/D)回答全部21个问题。
317
+ 格式要求:1. A, 2. B, ...依此类推,共21题。
318
+ 请只提供答案,不要添加任何其他解释或评论。
319
+ """,
320
+ max_retries=max_retries
321
+ )
322
+
323
+ # 填写GHQ-28量表
324
+ ghq = _fill_scale_with_retry(
325
+ client, enhanced_prompt,
326
+ scale_name="GHQ-28",
327
+ expected_count=28,
328
+ instruction="""
329
+ ### 任务
330
+ 作为心理咨询的第一步,请根据你目前的感受和状态填写这份GHQ-28量表。
331
+ 请直接选择最符合你当前情况的选项,使用字母(A/B/C/D)回答全部28个问题。
332
+ 格式要求:1. A, 2. B, ...依此类推,共28题。
333
+ 请只提供答案,不要添加任何其他解释或评论。
334
+ """,
335
+ max_retries=max_retries
336
+ )
337
+
338
+ # 填写SASS量表
339
+ sass = _fill_scale_with_retry(
340
+ client, enhanced_prompt,
341
+ scale_name="SASS",
342
+ expected_count=21,
343
+ instruction="""
344
+ ### 任务
345
+ 作为心理咨询的第一步,请根据你目前的感受和状态填写这份SASS量表。
346
+ 请直接选择最符合你当前情况的选项,使用字母(A/B/C/D)回答全部21个问题。
347
+ 格式要求:1. A, 2. B, ...依此类推,共21题。
348
+ 请只提供答案,不要添加任何其他解释或评论。
349
+ """,
350
+ max_retries=max_retries
351
+ )
352
+
353
+ return bdi, ghq, sass
354
+
355
+ # 使用示例
356
+ # if __name__ == "__main__":
357
+ # # 测试以前的方法
358
+ # profile = {
359
+ # "drisk": 3,
360
+ # "srisk": 2,
361
+ # "age": "42",
362
+ # "gender": "女",
363
+ # "marital_status": "离婚",
364
+ # "occupation": "教师",
365
+ # "symptoms": "缺乏自信心,自我价值感低,有自罪感,无望感;体重剧烈增加;精神运动性激越;有自杀想法"
366
+ # }
367
+ # report = "患者最近经历了家庭变故,情绪低落,失眠,食欲不振。"
368
+
369
+ # # 测试fill_scales_previous
370
+ # print("测试 fill_scales_previous:")
371
+ # bdi_prev, ghq_prev, sass_prev = fill_scales_previous(profile, report, max_retries=3)
372
+ # print(f"BDI: {bdi_prev}")
373
+ # print(f"GHQ: {ghq_prev}")
374
+ # print(f"SASS: {sass_prev}")
375
+
376
+ # # 测试fill_scales
377
+ # print("\n测试 fill_scales:")
378
+ # prompt = "你要扮演一个最近经历了家庭变故的心理障碍患者,情绪低落,失眠,食欲不振。"
379
+ # bdi, ghq, sass = fill_scales(prompt, max_retries=3)
380
+ # print(f"BDI: {bdi}")
381
+ # print(f"GHQ: {ghq}")
382
+ # print(f"SASS: {sass}")
src/integration_example.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # integration_example.py
2
+ # 这个文件展示如何将你的MsPatient类集成到Streamlit应用中
3
+
4
+ import streamlit as st
5
+ import json
6
+ import pandas as pd
7
+ from datetime import datetime
8
+ import time
9
+ from pathlib import Path
10
+
11
+ # 导入你的AnnaAgent类 - 请根据实际路径调整
12
+ try:
13
+ from ms_patient import MsPatient # 假设你的类在anna_agent.py文件中
14
+ ANNA_AGENT_AVAILABLE = True
15
+ except ImportError:
16
+ ANNA_AGENT_AVAILABLE = False
17
+ st.warning("⚠️ 未找到AnnaAgent类,使用模拟模式")
18
+
19
+ def load_dataset(uploaded_file):
20
+ """
21
+ 加载数据集文件
22
+ 支持JSON和JSONL格式
23
+ """
24
+ try:
25
+ if uploaded_file.name.endswith('.json'):
26
+ data = json.load(uploaded_file)
27
+ elif uploaded_file.name.endswith('.jsonl'):
28
+ data = []
29
+ for line in uploaded_file:
30
+ data.append(json.loads(line.decode('utf-8')))
31
+ else:
32
+ raise ValueError("不支持的文件格式")
33
+ return data
34
+ except Exception as e:
35
+ st.error(f"数据集加载失败: {str(e)}")
36
+ return None
37
+
38
+ def validate_patient_data(patient_data):
39
+ """
40
+ 验证患者数据格式是否正确
41
+ """
42
+ required_keys = ['id', 'portrait', 'report']
43
+
44
+ for key in required_keys:
45
+ if key not in patient_data:
46
+ return False, f"缺少必需字段: {key}"
47
+
48
+ # 验证portrait字段
49
+ portrait_required = ['age', 'gender', 'occupation', 'marital_status']
50
+ for key in portrait_required:
51
+ if key not in patient_data['portrait']:
52
+ return False, f"portrait中缺少字段: {key}"
53
+
54
+ return True, "数据格式正确"
55
+
56
+ def initialize_patient_agent(patient_data, language="Chinese"):
57
+ """
58
+ 初始化患者智能体
59
+ """
60
+ try:
61
+ if not ANNA_AGENT_AVAILABLE:
62
+ return None, "AnnaAgent类不可用"
63
+
64
+ # 验证数据格式
65
+ is_valid, message = validate_patient_data(patient_data)
66
+ if not is_valid:
67
+ return None, message
68
+
69
+ # 初始化智能体
70
+ agent = MsPatient(
71
+ portrait=patient_data["portrait"],
72
+ report=patient_data["report"],
73
+ previous_conversations=patient_data.get("conversation", []),
74
+ language=language
75
+ )
76
+
77
+ return agent, "初始化成功"
78
+
79
+ except Exception as e:
80
+ return None, f"初始化失败: {str(e)}"
81
+
82
+ def simulate_response(user_input, patient_data=None):
83
+ """
84
+ 模拟智能体回复(当AnnaAgent不可用时使用)
85
+ """
86
+ responses = [
87
+ f"我理解您提到的'{user_input}'。这确实是一个需要深入探讨的话题。",
88
+ f"谢谢您的耐心。关于您说的'{user_input}',我想分享一下我的感受...",
89
+ f"您的话让我思考了很多。'{user_input}'这个观点很有意思。",
90
+ "我需要一些时间来消化您刚才说的话。这对我来说很重要。",
91
+ "我觉得我们之间的对话很有帮助。您能再详细说说吗?"
92
+ ]
93
+
94
+ import random
95
+ return random.choice(responses)
96
+
97
+ def export_chat_history(messages, patient_id):
98
+ """
99
+ 导出聊天记录
100
+ """
101
+ chat_history = {
102
+ "patient_id": patient_id,
103
+ "timestamp": datetime.now().isoformat(),
104
+ "session_info": {
105
+ "total_messages": len(messages),
106
+ "counselor_messages": len([m for m in messages if m["role"] == "user"]),
107
+ "patient_responses": len([m for m in messages if m["role"] == "assistant"])
108
+ },
109
+ "messages": messages
110
+ }
111
+
112
+ return json.dumps(chat_history, ensure_ascii=False, indent=2)
113
+
114
+ def get_patient_summary(patient_data):
115
+ """
116
+ 生成患者信息摘要
117
+ """
118
+ if not patient_data or 'portrait' not in patient_data:
119
+ return "无患者信息"
120
+
121
+ portrait = patient_data['portrait']
122
+ summary = f"""
123
+ **患者ID**: {patient_data.get('id', 'N/A')}
124
+ **基本信息**: {portrait.get('age', 'N/A')}岁 {portrait.get('gender', 'N/A')}性
125
+ **职业**: {portrait.get('occupation', 'N/A')}
126
+ **婚姻状态**: {portrait.get('marital_status', 'N/A')}
127
+ **主要症状**: {portrait.get('symptom', 'N/A')}
128
+ """
129
+
130
+ if 'report' in patient_data:
131
+ report = patient_data['report']
132
+ summary += f"""
133
+ **主诉**: {report.get('chief_complaint', 'N/A')}
134
+ """
135
+
136
+ return summary
137
+
138
+ # 示例配置文件内容
139
+ CONFIG_EXAMPLE = {
140
+ "openai": {
141
+ "api_key": "your-api-key-here",
142
+ "base_url": "https://api.openai.com/v1",
143
+ "model_name": "gpt-3.5-turbo"
144
+ },
145
+ "ui_settings": {
146
+ "language": "Chinese", # or "English"
147
+ "theme": "default",
148
+ "max_messages": 100
149
+ },
150
+ "patient_defaults": {
151
+ "language": "Chinese",
152
+ "enable_memory": True,
153
+ "enable_emotion_modulation": True
154
+ }
155
+ }
156
+
157
+ def save_config(config, path="config.json"):
158
+ """保存配置文件"""
159
+ with open(path, 'w', encoding='utf-8') as f:
160
+ json.dump(config, f, ensure_ascii=False, indent=2)
161
+
162
+ def load_config(path="config.json"):
163
+ """加载配置文件"""
164
+ try:
165
+ with open(path, 'r', encoding='utf-8') as f:
166
+ return json.load(f)
167
+ except FileNotFoundError:
168
+ return CONFIG_EXAMPLE
169
+
170
+ # 使用示例:
171
+ if __name__ == "__main__":
172
+ print("这是AnnaAgent Streamlit集成的辅助文件")
173
+ print("请运行:streamlit run your_streamlit_app.py")
src/ms_patient.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ AnnaAgent: 具有三级记忆结构的情绪与认知动态的模拟心理障碍患者
3
+ 1. 首先获取患者的基本信息、病史、症状报告等信息
4
+ 2. 根据患者的病史、症状报告等信息,生成患者的认知与情绪状态
5
+ '''
6
+
7
+
8
+ from openai import OpenAI
9
+ import os
10
+ from fill_scales import fill_scales, fill_scales_previous
11
+ from event_trigger import event_trigger, situationalising_events
12
+ from emotion_modulator_fc import emotion_modulation
13
+ from querier import query, is_need
14
+ from complaint_elicitor import switch_complaint, transform_chain
15
+ from complaint_chain_fc import gen_complaint_chain
16
+ from short_term_memory import summarize_scale_changes
17
+ from style_analyzer import analyze_style
18
+ import random
19
+ # from anna_agent_template import prompt_template
20
+
21
+ # 设置OpenAI API密钥和基础URL
22
+ api_key = os.getenv("OPENAI_API_KEY")
23
+ base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
24
+ model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")
25
+
26
+ # print("当前使用的模型是:", model_name)
27
+
28
+ class MsPatient:
29
+ def __init__(self, portrait:dict, report:dict, previous_conversations:list, language:str="Chinese"):
30
+ if language == "Chinese":
31
+ from anna_agent_template import prompt_template
32
+ elif language == "English":
33
+ from anna_agent_template_en import prompt_template
34
+ self.configuration = {}
35
+ self.portrait = portrait # age, gender, occupation, maritial_status, symptom
36
+ # self.profile = {key:self.portrait[key] for key in self.portrait if key != "symptom"} # profile不包含症状symptom
37
+ self.configuration["gender"] = self.portrait["gender"]
38
+ self.configuration["age"] = self.portrait["age"]
39
+ self.configuration["occupation"] = self.portrait["occupation"]
40
+ self.configuration["marriage"] = self.portrait["marital_status"]
41
+ self.report = report
42
+ self.previous_conversations = previous_conversations
43
+ # 填写之前疗程的量表
44
+ self.p_bdi, self.p_ghq, self.p_sass = fill_scales_previous(self.portrait, self.report)
45
+ self.conversation = [] # Conversation存储咨访记录
46
+ self.messages = [] # Messages存储LLM的消息列表
47
+ # 生成主诉认知变化链
48
+ self.complaint_chain = gen_complaint_chain(self.portrait)
49
+ # 生成近期事件
50
+ self.event = event_trigger(self.portrait)
51
+ # 总结短期记忆-事件
52
+ self.situation = situationalising_events(self.portrait)
53
+ self.configuration["situation"] = self.situation
54
+ # 分析说话风格
55
+ self.style = analyze_style(self.portrait, self.previous_conversations)
56
+ self.configuration["style"] = self.style
57
+ self.configuration["language"] = language
58
+ self.configuration["status"] = "" # 先置状态为空,后续会根据量表分析结果进行更新
59
+ seeker_utterances = [utterance["content"] for utterance in self.previous_conversations if utterance["role"] == "Seeker"]
60
+ self.configuration["statement"] = random.choices(seeker_utterances,k=3)
61
+ # 填写当前量表
62
+ self.bdi, self.ghq, self.sass = fill_scales(prompt_template.format(**self.configuration))
63
+ scales = {
64
+ "p_bdi": self.p_bdi,
65
+ "p_ghq": self.p_ghq,
66
+ "p_sass": self.p_sass,
67
+ "bdi": self.bdi,
68
+ "ghq": self.ghq,
69
+ "sass": self.sass
70
+ }
71
+ # 分析近期状态
72
+ self.status = summarize_scale_changes(scales)
73
+ self.configuration["status"] = self.status
74
+ # 选取对话样例
75
+ self.system = prompt_template.format(**self.configuration)
76
+ self.chain_index = 1
77
+ self.client = OpenAI(
78
+ api_key=api_key,
79
+ base_url=base_url
80
+ )
81
+
82
+ def chat(self, message):
83
+ # 更新消息列表
84
+ self.conversation.append({"role": "Counselor", "content": message})
85
+ self.messages.append({"role": "user", "content": message})
86
+ # 初始化本次对话的状态
87
+ emotion = emotion_modulation(self.portrait, self.conversation)
88
+ self.chain_index = switch_complaint(self.complaint_chain, self.chain_index, self.conversation)
89
+ complaint = transform_chain(self.complaint_chain)[self.chain_index]
90
+ # 判断是否涉及前疗程内容
91
+ if is_need(message):
92
+ # 生成前疗程内容
93
+ sup_information = query(message, self.previous_conversations, self.report)
94
+
95
+ # 生成回复
96
+ response = self.client.chat.completions.create(
97
+ model=model_name,
98
+ messages=[{"role": "system", "content": self.system}] + self.messages + [{"role": "system", "content": f"当前的情绪状态是:{emotion},当前的主诉是:{complaint},涉及到之前疗程的信息是:{sup_information}"}],
99
+ )
100
+ else:
101
+ # 生成回复
102
+ response = self.client.chat.completions.create(
103
+ model=model_name,
104
+ messages=[{"role": "system", "content": self.system}] + self.messages + [{"role": "system", "content": f"当前的情绪状态是:{emotion},当前的主诉是:{complaint}"}],
105
+ )
106
+ # 更新消息列表
107
+ self.conversation.append({"role": "Seeker", "content": response.choices[0].message.content})
108
+ self.messages.append({"role": "assistant", "content": response.choices[0].message.content})
109
+ return response.choices[0].message.content
110
+
111
+ def get_system_prompt(self):
112
+ return self.system
113
+
src/querier.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from openai import OpenAI
2
+ import json
3
+ import re
4
+ import os
5
+
6
+ # 设置OpenAI API密钥和基础URL
7
+ api_key = os.getenv("OPENAI_API_KEY")
8
+ base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
9
+ model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")
10
+
11
+ def extract_boolean(text):
12
+ """从文本中提取布尔值判断"""
13
+ # 查找明确的"是"或"否"的回答
14
+ text_lower = text.lower()
15
+
16
+ # 更具体地查找否定表达 - 这些应该优先匹配
17
+ negative_patterns = [
18
+ r'不需要', r'没有提及', r'不涉及', r'没有涉及', r'无关', r'没有提到',
19
+ r'不是', r'否', r'不包含', r'未提及', r'未涉及', r'未提到',
20
+ r'不包括', r'并未', r'不包括', r'没有', r'无'
21
+ ]
22
+
23
+ # 检查是否有明确的否定
24
+ for pattern in negative_patterns:
25
+ if re.search(r'\b' + pattern + r'\b', text_lower):
26
+ return False
27
+
28
+ # 如果找到"之前疗程"附近有否定词,也认为是否定
29
+ therapy_negation = re.search(r'(没有|不|未|无).*?(之前|以前|上次|过去|先前).*?(疗程|治疗|会话)', text_lower)
30
+ if therapy_negation:
31
+ return False
32
+
33
+ # 明确的肯定模式 - 只有在没有否定的情况下才考虑
34
+ positive_patterns = [
35
+ r'是的', r'提及了', r'确实', r'有提到', r'涉及到',
36
+ r'提及', r'确认', r'有关联', r'有联系', r'包含', r'涉及'
37
+ ]
38
+
39
+ # 检查是否有肯定模式
40
+ for pattern in positive_patterns:
41
+ if re.search(r'\b' + pattern + r'\b', text_lower):
42
+ return True
43
+
44
+ # 查找含有"之前疗程"的文本,没有否定词的情况下可能是肯定
45
+ therapy_mention = re.search(r'(之前|以前|上次|过去|先前).*?(疗程|治疗|会话)', text_lower)
46
+ if therapy_mention:
47
+ return True
48
+
49
+ # 默认情况 - 如果没有明确的肯定或否定,我们假设是否定的
50
+ return False
51
+
52
+ def extract_knowledge(text):
53
+ """从文本中提取知识总结部分"""
54
+ # 尝试匹配总结部分
55
+ summary_patterns = [
56
+ r'总结[::]\s*([\s\S]+)$',
57
+ r'知识总结[::]\s*([\s\S]+)$',
58
+ r'相关信息[::]\s*([\s\S]+)$',
59
+ r'搜索结果[::]\s*([\s\S]+)$'
60
+ ]
61
+
62
+ for pattern in summary_patterns:
63
+ match = re.search(pattern, text)
64
+ if match:
65
+ return match.group(1).strip()
66
+
67
+ # 如果没有找到明确的总结标记,尝试清理文本
68
+ # 移除可能的指令解释部分
69
+ clean_text = re.sub(r'^.*?(根据|基于).*?[,,。]', '', text, flags=re.DOTALL)
70
+
71
+ # 移除可能的前导分析部分
72
+ clean_text = re.sub(r'^.*?(分析|查看|判断).*?\n\n', '', clean_text, flags=re.DOTALL)
73
+
74
+ return clean_text.strip()
75
+
76
+ def is_need(utterance):
77
+ client = OpenAI(
78
+ api_key=api_key,
79
+ base_url=base_url
80
+ )
81
+
82
+ instruction = """
83
+ ### 任务
84
+ 下面这句话是心理咨询师说的话,请判断它是否提及了之前疗程的内容。
85
+
86
+ 请使用以下确切格式回答:
87
+ 判断: [是/否]
88
+ 解释: [简要解释为什么]
89
+
90
+ ### 话语
91
+ "{}"
92
+ """.format(utterance)
93
+
94
+ response = client.chat.completions.create(
95
+ model=model_name,
96
+ messages=[{"role": "user", "content": instruction}],
97
+ temperature=0
98
+ )
99
+
100
+ response_text = response.choices[0].message.content
101
+
102
+ # 首先尝试从格式化输出中提取
103
+ judgment_match = re.search(r'判断:\s*(是|否)', response_text)
104
+ if judgment_match:
105
+ return judgment_match.group(1) == "是"
106
+
107
+ # 如果没有格式化输出,使用更通用的提取
108
+ return extract_boolean(response_text)
109
+
110
+ def query(utterance, conversations, scales):
111
+ client = OpenAI(
112
+ api_key=api_key,
113
+ base_url=base_url
114
+ )
115
+
116
+ # 将scales转换为字符串以便传入
117
+ if isinstance(scales, dict):
118
+ scales_str = json.dumps(scales, ensure_ascii=False)
119
+ else:
120
+ scales_str = str(scales)
121
+
122
+ instruction = """
123
+ ### 任务
124
+ 根据对话内容,从知识库中搜索相关的信息并总结。
125
+
126
+ 请使用以下确切格式回答:
127
+ 总结: [提供一个清晰、简洁的总结]
128
+
129
+ ### 对话内容
130
+ {}
131
+
132
+ ### 知识库
133
+ 对话历史: {}
134
+ 量表结果: {}
135
+ """.format(utterance, conversations, scales_str)
136
+
137
+ response = client.chat.completions.create(
138
+ model=model_name,
139
+ messages=[{"role": "user", "content": instruction}],
140
+ temperature=0
141
+ )
142
+
143
+ response_text = response.choices[0].message.content
144
+
145
+ # 尝试提取总结部分
146
+ summary_match = re.search(r'总结:\s*([\s\S]+)$', response_text)
147
+ if summary_match:
148
+ return summary_match.group(1).strip()
149
+
150
+ # 回退到通用提取
151
+ return extract_knowledge(response_text)
152
+
153
+ # 测试用例
154
+ # if __name__ == "__main__":
155
+ # test_utterance = "上��给你说的方法有用吗"
156
+ # # test_utterance = "我觉得你可以多出去走走"
157
+ # print(f"是否提及疗程: {is_need(test_utterance)}")
158
+
159
+ # test_convs = ["第一次对话内容", "讨论量表结果", "提到睡眠问题"]
160
+ # test_scales = {"BDI": ["A", "B"], "GHQ": ["C", "D"]}
161
+ # print(f"知识检索结果:\n{query(test_utterance, test_convs, test_scales)}")
src/short_term_memory.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from openai import OpenAI
2
+ import json
3
+ import re
4
+ import os
5
+
6
+ # 设置OpenAI API密钥和基础URL
7
+ api_key = os.getenv("OPENAI_API_KEY")
8
+ base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
9
+ model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")
10
+
11
+ def extract_changes(text):
12
+ """从文本中提取变化列表"""
13
+ # 首先尝试查找明确的变化列表格式
14
+ # 例如: "变化:\n1. xxx\n2. yyy"
15
+ list_pattern = r'((?:(?:\d+\.|\-|\*)\s*[^\n]+\n?)+)'
16
+
17
+ # 尝试匹配带有明确标记的变化列表
18
+ change_section = re.search(r'(?:变化(?:列表)?|总结(?:如下)?)[::]\s*([\s\S]+)$', text)
19
+ if change_section:
20
+ section_text = change_section.group(1).strip()
21
+
22
+ # 尝试匹配列表项
23
+ list_items = re.findall(r'(?:(?:\d+\.|\-|\*)\s*)([^\n]+)', section_text)
24
+ if list_items:
25
+ return list_items
26
+
27
+ # 如果没有明确的列表格式,尝试按行分割
28
+ lines = [line.strip() for line in section_text.split('\n') if line.strip()]
29
+ if lines:
30
+ return lines
31
+
32
+ # 尝试直接从文本中提取列表格式
33
+ list_matches = re.findall(list_pattern, text)
34
+ if list_matches:
35
+ all_items = []
36
+ for match in list_matches:
37
+ items = re.findall(r'(?:(?:\d+\.|\-|\*)\s*)([^\n]+)', match)
38
+ all_items.extend(items)
39
+ if all_items:
40
+ return all_items
41
+
42
+ # 如果没有列表格式,尝试按句子分割
43
+ sentences = re.findall(r'([^.!?]+[.!?])', text)
44
+ if sentences:
45
+ return [s.strip() for s in sentences if len(s.strip()) > 10] # 过滤掉过短的句子
46
+
47
+ # 最后的回退:按段落分割
48
+ paragraphs = text.split('\n\n')
49
+ if len(paragraphs) > 1:
50
+ return [p.strip() for p in paragraphs if len(p.strip()) > 10]
51
+
52
+ # 如果所有方法都失败,返回完整文本作为单个变化
53
+ return [text.strip()] if text.strip() else []
54
+
55
+ def extract_status(text):
56
+ """从文本中提取患者状态总结"""
57
+ # 寻找明确标记的总结部分
58
+ status_section = re.search(r'(?:总结|状态|变化|结论)[::]\s*([\s\S]+)$', text)
59
+ if status_section:
60
+ return status_section.group(1).strip()
61
+
62
+ # 如果没有明确的总结标记,尝试返回完整文本
63
+ # 过滤掉可能的指令解释部分
64
+ clean_text = re.sub(r'^.*?(?:根据|基于).*?[,,。]', '', text, flags=re.DOTALL)
65
+
66
+ # 移除可能的前导分析部分
67
+ clean_text = re.sub(r'^.*?(?:分析|查看|判断).*?\n\n', '', clean_text, flags=re.DOTALL)
68
+
69
+ return clean_text.strip()
70
+
71
+ def analyzing_changes(scales):
72
+ client = OpenAI(
73
+ api_key=api_key,
74
+ base_url=base_url
75
+ )
76
+
77
+ # 导入量表及问题
78
+ bdi_scale = json.load(open("./scales/bdi.json", "r"))
79
+ ghq_scale = json.load(open("./scales/ghq-28.json", "r"))
80
+ sass_scale = json.load(open("./scales/sass.json", "r"))
81
+
82
+ # 总结BDI的变化
83
+ bdi_instruction = """
84
+ ### 任务
85
+ 根据量表的问题和答案,总结出两份量表之间的变化。
86
+ 请列出明确的变化点,每个变化点单独一行,使用数字编号(1. 2. 3.)。
87
+ 使用以下格式:
88
+ 变化:
89
+ 1. [第一个变化]
90
+ 2. [第二个变化]
91
+ ...
92
+
93
+ ### 量表及问题
94
+ {}
95
+
96
+ ### 第一份量表的答案
97
+ {}
98
+
99
+ ### 第二份量表的答案
100
+ {}
101
+ """.format(bdi_scale, scales['p_bdi'], scales['bdi'])
102
+
103
+ response = client.chat.completions.create(
104
+ model=model_name,
105
+ messages=[{"role": "user", "content": bdi_instruction}],
106
+ temperature=0
107
+ )
108
+
109
+ bdi_response = response.choices[0].message.content
110
+ bdi_changes = extract_changes(bdi_response)
111
+
112
+ # 总结GHQ的变化
113
+ ghq_instruction = """
114
+ ### 任务
115
+ 根据量表的问题和答案,总结出两份量表之间的变化。
116
+ 请列出明确的变化点,每个变化点单独一行,使用数字编号(1. 2. 3.)。
117
+ 使用以下格式:
118
+ 变化:
119
+ 1. [第一个变化]
120
+ 2. [第二个变化]
121
+ ...
122
+
123
+ ### 量表及问题
124
+ {}
125
+
126
+ ### 第一份量表的答案
127
+ {}
128
+
129
+ ### 第二份量表的答案
130
+ {}
131
+ """.format(ghq_scale, scales['p_ghq'], scales['ghq'])
132
+
133
+ response = client.chat.completions.create(
134
+ model=model_name,
135
+ messages=[{"role": "user", "content": ghq_instruction}],
136
+ temperature=0
137
+ )
138
+
139
+ ghq_response = response.choices[0].message.content
140
+ ghq_changes = extract_changes(ghq_response)
141
+
142
+ # 总结SASS的变化
143
+ sass_instruction = """
144
+ ### 任务
145
+ 根据量表的问题和答案,总结出两份量表之间的变化。
146
+ 请列出明确的变化点,每个变化点单独一行,使用数字编号(1. 2. 3.)。
147
+ 使用以下格式:
148
+ 变化:
149
+ 1. [第一个变化]
150
+ 2. [第二个变化]
151
+ ...
152
+
153
+ ### 量表及问题
154
+ {}
155
+
156
+ ### 第一份量表的答案
157
+ {}
158
+
159
+ ### 第二份量表的答案
160
+ {}
161
+ """.format(sass_scale, scales['p_sass'], scales['sass'])
162
+
163
+ response = client.chat.completions.create(
164
+ model=model_name,
165
+ messages=[{"role": "user", "content": sass_instruction}],
166
+ temperature=0
167
+ )
168
+
169
+ sass_response = response.choices[0].message.content
170
+ sass_changes = extract_changes(sass_response)
171
+
172
+ return bdi_changes, ghq_changes, sass_changes
173
+
174
+ def summarize_scale_changes(scales):
175
+ client = OpenAI(
176
+ api_key=api_key,
177
+ base_url=base_url
178
+ )
179
+
180
+ # 获取量表变化
181
+ bdi_changes, ghq_changes, sass_changes = analyzing_changes(scales)
182
+
183
+ # 总结量表变化
184
+ summary_instruction = """
185
+ ### 任务
186
+ 根据量表的变化,总结患者的身体和心理状态变化。
187
+ 请提供一个全面但简洁的总结,使用以下格式:
188
+ 总结:
189
+ [总结内容]
190
+
191
+ ### BDI量表变化
192
+ {}
193
+
194
+ ### GHQ量表变化
195
+ {}
196
+
197
+ ### SASS量表变化
198
+ {}
199
+ """.format(
200
+ '\n'.join([f"{i+1}. {change}" for i, change in enumerate(bdi_changes)]),
201
+ '\n'.join([f"{i+1}. {change}" for i, change in enumerate(ghq_changes)]),
202
+ '\n'.join([f"{i+1}. {change}" for i, change in enumerate(sass_changes)])
203
+ )
204
+
205
+ response = client.chat.completions.create(
206
+ model=model_name,
207
+ messages=[{"role": "user", "content": summary_instruction}],
208
+ temperature=0
209
+ )
210
+
211
+ summary_response = response.choices[0].message.content
212
+ status = extract_status(summary_response)
213
+
214
+ return status
215
+
216
+ # 额外增加一个更健壮的解析函数,可以处理不同格式的输出
217
+ def parse_response_robust(text, expected_format="list"):
218
+ """更健壮的响应解析函数
219
+
220
+ 参数:
221
+ text: 文本响应
222
+ expected_format: 预期格式,可以是"list"或"summary"
223
+
224
+ 返回:
225
+ 解析后的结果(列表或字符串)
226
+ """
227
+ # 首先尝试JSON格式解析
228
+ try:
229
+ # 尝试提取JSON部分
230
+ json_pattern = r'\{[\s\S]*\}'
231
+ json_match = re.search(json_pattern, text)
232
+ if json_match:
233
+ json_data = json.loads(json_match.group(0))
234
+ if expected_format == "list" and "changes" in json_data:
235
+ return json_data["changes"]
236
+ elif expected_format == "summary" and "status" in json_data:
237
+ return json_data["status"]
238
+ except:
239
+ pass # 如果JSON解析失败,继续尝试其他方法
240
+
241
+ # 使用适当的提取函数
242
+ if expected_format == "list":
243
+ return extract_changes(text)
244
+ else: # summary
245
+ return extract_status(text)
246
+
247
+ # unit test
248
+ # if __name__ == "__main__":
249
+ # # 测试数据
250
+ # scales = {
251
+ # "p_bdi": ["A", "B", "C"],
252
+ # "bdi": ["B", "C", "D"],
253
+ # "p_ghq": ["A", "A", "B"],
254
+ # "ghq": ["B", "C", "C"],
255
+ # "p_sass": ["A", "B", "A"],
256
+ # "sass": ["C", "D", "B"]
257
+ # }
258
+
259
+ # changes = summarize_scale_changes(scales)
260
+ # print(changes)
src/style_analyzer.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from openai import OpenAI
2
+ import os
3
+ import re
4
+
5
+ # 设置OpenAI API密钥和基础URL
6
+ api_key = os.getenv("OPENAI_API_KEY")
7
+ base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
8
+ model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")
9
+
10
+ def analyze_style(profile, conversations):
11
+ client = OpenAI(
12
+ api_key=api_key,
13
+ base_url=base_url
14
+ )
15
+ # 提取患者信息
16
+ patient_info = f"### 患者信息\n年龄:{profile['age']}\n性别:{profile['gender']}\n职业:{profile['occupation']}\n婚姻状况:{profile['marital_status']}\n症状:{profile['symptoms']}"
17
+ # 提取对话记录
18
+ dialogue_history = "\n".join([f"{conv['role']}: {conv['content']}" for conv in conversations])
19
+
20
+ # 构建提示词,明确要求模型按特定格式输出结果
21
+ prompt = f"""### 任务
22
+ 根据患者情况及咨访对话历史记录分析患者的说话风格。
23
+
24
+ {patient_info}
25
+
26
+ ### 对话记录
27
+ {dialogue_history}
28
+
29
+ 请分析患者的说话风格,最多列出5种风格特点。
30
+ 请按以下格式输出结果:
31
+ 说话风格:
32
+ 1. [风格特点1]
33
+ 2. [风格特点2]
34
+ 3. [风格特点3]
35
+ ...
36
+
37
+ 只需要列出风格特点,不需要解释。
38
+ """
39
+
40
+ response = client.chat.completions.create(
41
+ model=model_name,
42
+ messages=[
43
+ {"role": "user", "content": prompt}
44
+ ]
45
+ )
46
+
47
+ # 从响应中提取说话风格列表
48
+ response_text = response.choices[0].message.content
49
+
50
+ # 使用正则表达式提取风格特点
51
+ # 匹配"说话风格:"之后的列表项
52
+ style_pattern = r"说话风格:\s*(?:\d+\.\s*([^\n]+)(?:\n|$))+"
53
+ match = re.search(style_pattern, response_text, re.DOTALL)
54
+
55
+ if match:
56
+ # 提取所有的列表项
57
+ style_items = re.findall(r"\d+\.\s*([^\n]+)", response_text)
58
+ return style_items
59
+ else:
60
+ # 如果没有按预期格式输出,尝试使用备用正则表达式
61
+ # 寻找任何可能的列表项
62
+ fallback_items = re.findall(r"(?:^|\n)(?:\d+[\.\)、]|[-•*])\s*([^\n]+)", response_text)
63
+
64
+ # 如果仍然没找到,尝试直接分割文本
65
+ if not fallback_items:
66
+ # 找到可能包含风格描述的行
67
+ potential_styles = [line.strip() for line in response_text.split('\n')
68
+ if line.strip() and not line.startswith('###') and ':' not in line]
69
+ return potential_styles[:5] # 最多返回5项
70
+
71
+ return fallback_items[:5] # 最多返回5项
72
+
73
+ # unit test
74
+ # profile = {
75
+ # "drisk": 3,
76
+ # "srisk": 2,
77
+ # "age": "42",
78
+ # "gender": "女",
79
+ # "marital_status": "离婚",
80
+ # "occupation": "教师",
81
+ # "symptoms": "缺乏自信心,自我价值感低,有自罪感,无望感;体重剧烈增加;精神运动性激越;有自杀想法"
82
+ # }
83
+ # conversations = [
84
+ # {"role": "user", "content": "我最近感觉很沮丧,似乎一切都没有意义。"},
85
+ # {"role": "assistant", "content": "你能具体说说是什么让你有这样的感觉吗?"},
86
+ # {"role": "user", "content": "我觉得自己在工作上总是做不好,没什么价值。"}
87
+ # ]
88
+ # print(analyze_style(profile, conversations))