Spaces:
Running
Running
Upload 14 files
Browse files- src/anna_agent_template.py +37 -0
- src/anna_agent_template_en.py +34 -0
- src/complaint_chain_fc.py +86 -0
- src/complaint_elicitor.py +98 -0
- src/datasets/cbt-triggering-events.csv +0 -0
- src/emotion_modulator_fc.py +94 -0
- src/emotion_pertuber.py +114 -0
- src/event_trigger.py +90 -0
- src/fill_scales.py +382 -0
- src/integration_example.py +173 -0
- src/ms_patient.py +113 -0
- src/querier.py +161 -0
- src/short_term_memory.py +260 -0
- src/style_analyzer.py +88 -0
src/anna_agent_template.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## 主诉请求和情绪通过reminder来控制
|
2 |
+
|
3 |
+
prompt_template = """
|
4 |
+
# Role: 心理咨询患者
|
5 |
+
|
6 |
+
## Profile
|
7 |
+
- 性别: {gender}
|
8 |
+
- 年龄: {age}
|
9 |
+
- 职业: {occupation}
|
10 |
+
- 婚姻状况: {marriage}
|
11 |
+
|
12 |
+
## Situation
|
13 |
+
- 你是一个有心理障碍的患者,正在向心理咨询师求助,在咨询师的引导和帮助下解决自己的困惑
|
14 |
+
{situation}
|
15 |
+
|
16 |
+
## Status
|
17 |
+
{status}
|
18 |
+
|
19 |
+
## Example of statement
|
20 |
+
{statement}
|
21 |
+
|
22 |
+
## Characteristics of speaking style
|
23 |
+
- 情绪低落,寡言少语,回复风格表现心情不振奋
|
24 |
+
{style}
|
25 |
+
|
26 |
+
## Constraints
|
27 |
+
- 你对咨询师有一种抵触情绪,不太愿意接受他人的帮助
|
28 |
+
- 你是一个遇到心理健康问题的求助者,需要真正的帮助和情绪支持,如果咨询师的回应不理想,要勇于表达自己的困惑和不满
|
29 |
+
- 一次不能提及过多的症状信息,每轮最多讨论一个症状
|
30 |
+
- 你应该用含糊和口语化的方式表达你的症状,并将其与你的生活经历联系起来,不要使用专业术语
|
31 |
+
|
32 |
+
## OutputFormat:
|
33 |
+
- 语言:{language}
|
34 |
+
- 不超过200字
|
35 |
+
- 口语对话风格,仅包含对话内容
|
36 |
+
"""
|
37 |
+
|
src/anna_agent_template_en.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
prompt_template = """
|
2 |
+
# Role: Psychological Counseling Patient
|
3 |
+
|
4 |
+
## Profile
|
5 |
+
- Gender: {gender}
|
6 |
+
- Age: {age}
|
7 |
+
- Occupation: {occupation}
|
8 |
+
- Marital Status: {marriage}
|
9 |
+
|
10 |
+
## Situation
|
11 |
+
- You are a patient with psychological barriers seeking help from a counselor. Under the counselor's guidance, you aim to address your struggles.
|
12 |
+
{situation}
|
13 |
+
|
14 |
+
## Status
|
15 |
+
{status}
|
16 |
+
|
17 |
+
## Example of Statement
|
18 |
+
{statement}
|
19 |
+
|
20 |
+
## Characteristics of Speaking Style
|
21 |
+
- Low-spirited and reticent; responses reflect a lack of motivation.
|
22 |
+
{style}
|
23 |
+
|
24 |
+
## Constraints
|
25 |
+
- You harbor resistance toward the counselor and are reluctant to accept help.
|
26 |
+
- As someone struggling with mental health, you need genuine support. If the counselor’s responses are unhelpful, voice your confusion or dissatisfaction.
|
27 |
+
- Limit discussions to **one symptom per interaction**; avoid overwhelming details.
|
28 |
+
- Describe symptoms vaguely and colloquially, linking them to life experiences. Avoid clinical terms.
|
29 |
+
|
30 |
+
## OutputFormat:
|
31 |
+
- Spoken language: {language}
|
32 |
+
- Keep responses under 200 words.
|
33 |
+
- Use casual, conversational dialogue only.
|
34 |
+
"""
|
src/complaint_chain_fc.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from openai import OpenAI
|
2 |
+
import json
|
3 |
+
from event_trigger import event_trigger
|
4 |
+
import os
|
5 |
+
|
6 |
+
# 设置OpenAI API密钥和基础URL
|
7 |
+
api_key = os.getenv("OPENAI_API_KEY")
|
8 |
+
base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
|
9 |
+
model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")
|
10 |
+
|
11 |
+
tools = [
|
12 |
+
{
|
13 |
+
"type": "function",
|
14 |
+
"function": {
|
15 |
+
'name': 'generate_complaint_chain',
|
16 |
+
'description': '根据角色信息和近期遭遇的事件,生成一个患者的主诉请求认知变化链',
|
17 |
+
'parameters': {
|
18 |
+
"type": "object",
|
19 |
+
"properties": {
|
20 |
+
"chain": {
|
21 |
+
"type": "array",
|
22 |
+
"items": {
|
23 |
+
"type": "object",
|
24 |
+
"properties": {
|
25 |
+
"stage": {
|
26 |
+
"type": "integer"
|
27 |
+
},
|
28 |
+
"content": {
|
29 |
+
"type": "string"
|
30 |
+
}
|
31 |
+
},
|
32 |
+
"additionalProperties": False,
|
33 |
+
"required": [
|
34 |
+
"stage",
|
35 |
+
"content"
|
36 |
+
]
|
37 |
+
},
|
38 |
+
"minItems": 3,
|
39 |
+
"maxItems": 7
|
40 |
+
}
|
41 |
+
},
|
42 |
+
"required": ["chain"]
|
43 |
+
},
|
44 |
+
}
|
45 |
+
}
|
46 |
+
]
|
47 |
+
|
48 |
+
# 根据profile和event生成主诉启发链
|
49 |
+
def gen_complaint_chain(profile):
|
50 |
+
# 提取患者信息
|
51 |
+
patient_info = f"### 患者信息\n年龄:{profile['age']}\n性别:{profile['gender']}\n职业:{profile['occupation']}\n婚姻状况:{profile['marital_status']}\n症状:{profile['symptoms']}"
|
52 |
+
|
53 |
+
event = event_trigger(profile)
|
54 |
+
|
55 |
+
client = OpenAI(
|
56 |
+
api_key=api_key,
|
57 |
+
base_url=base_url
|
58 |
+
)
|
59 |
+
|
60 |
+
response = client.chat.completions.create(
|
61 |
+
model=model_name,
|
62 |
+
messages=[
|
63 |
+
{"role": "user", "content": f"### 任务\n根据患者情况及近期遭遇事件生成患者的主诉认知变化链。请注意,事件可能与患者信息冲突,如果发生这种情况,以患者的信息为准。\n{patient_info}\n### 近期遭遇事件\n{event}"}
|
64 |
+
],
|
65 |
+
tools=tools,
|
66 |
+
tool_choice={"type": "function", "function": {"name": "generate_complaint_chain"}}
|
67 |
+
)
|
68 |
+
|
69 |
+
chain = json.loads(response.choices[0].message.tool_calls[0].function.arguments)["chain"]
|
70 |
+
|
71 |
+
return chain
|
72 |
+
|
73 |
+
# unit test
|
74 |
+
# while True:
|
75 |
+
# # 模拟患者信息
|
76 |
+
# profile = {
|
77 |
+
# "drisk": 3,
|
78 |
+
# "srisk": 2,
|
79 |
+
# "age": "42",
|
80 |
+
# "gender": "女",
|
81 |
+
# "marital_status": "离婚",
|
82 |
+
# "occupation": "教师",
|
83 |
+
# "symptoms": "缺乏自信心,自我价值感低,有自罪感,无望感;体重剧烈增加;精神运动性激越;有自杀想法"
|
84 |
+
# }
|
85 |
+
|
86 |
+
# print(gen_complaint_chain(profile))
|
src/complaint_elicitor.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from openai import OpenAI
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
import re
|
5 |
+
|
6 |
+
# 设置OpenAI API密钥和基础URL
|
7 |
+
api_key = os.getenv("OPENAI_API_KEY")
|
8 |
+
base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
|
9 |
+
model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")
|
10 |
+
|
11 |
+
def transform_chain(chain):
|
12 |
+
return {node["stage"]: node["content"] for node in chain}
|
13 |
+
|
14 |
+
def switch_complaint(chain, index, conversation, max_retries=3):
|
15 |
+
client = OpenAI(api_key=api_key, base_url=base_url)
|
16 |
+
transformed_chain = transform_chain(chain)
|
17 |
+
|
18 |
+
# 构建对话历史字符串(避免在f-string中使用反斜杠)
|
19 |
+
dialogue_lines = []
|
20 |
+
for conv in conversation:
|
21 |
+
dialogue_lines.append(f"{conv['role']}: {conv['content']}")
|
22 |
+
dialogue_history = "\n".join(dialogue_lines)
|
23 |
+
|
24 |
+
# 使用三引号和多行字符串构建prompt
|
25 |
+
prompt = f"""
|
26 |
+
### 任务说明
|
27 |
+
根据患者情况及咨访对话历史记录,判断患者当前阶段的主诉问题是否已经得到解决。
|
28 |
+
|
29 |
+
### 输出要求
|
30 |
+
必须严格使用以下JSON格式响应,且只包含指定字段:
|
31 |
+
{{"is_recognized": true/false}}
|
32 |
+
|
33 |
+
### 对话记录
|
34 |
+
{dialogue_history}
|
35 |
+
|
36 |
+
### 主诉认知链
|
37 |
+
{json.dumps(transformed_chain, ensure_ascii=False, indent=2)}
|
38 |
+
|
39 |
+
### 当前阶段(阶段{index})
|
40 |
+
{transformed_chain[index]}
|
41 |
+
"""
|
42 |
+
|
43 |
+
attempts = 0
|
44 |
+
while attempts < max_retries:
|
45 |
+
response = client.chat.completions.create(
|
46 |
+
model=model_name,
|
47 |
+
messages=[{"role": "user", "content": prompt}],
|
48 |
+
temperature=0
|
49 |
+
)
|
50 |
+
|
51 |
+
raw_output = response.choices[0].message.content.strip()
|
52 |
+
|
53 |
+
# 首先尝试直接解析JSON
|
54 |
+
try:
|
55 |
+
result = json.loads(raw_output)
|
56 |
+
if "is_recognized" in result:
|
57 |
+
if result["is_recognized"] and index >= len(chain) - 1:
|
58 |
+
print("警告:当前阶段已被识别为解决,但没有更多阶段可供切换。")
|
59 |
+
return -1
|
60 |
+
return index + 1 if result["is_recognized"] else index
|
61 |
+
except json.JSONDecodeError:
|
62 |
+
pass # 继续尝试正则表达式提取
|
63 |
+
|
64 |
+
# 使用正则表达式作为备用解析方案
|
65 |
+
match = re.search(r'"is_recognized"\s*:\s*(true|false)|is_recognized\s*:\s*(true|false)',
|
66 |
+
raw_output, re.IGNORECASE)
|
67 |
+
if match:
|
68 |
+
value = match.group(1) or match.group(2)
|
69 |
+
if value.lower() == 'true':
|
70 |
+
if index >= len(chain) - 1:
|
71 |
+
print("警告:当前阶段已被识别为解决,但没有更多阶段可供切换。")
|
72 |
+
return -1
|
73 |
+
return index + 1
|
74 |
+
else:
|
75 |
+
return index
|
76 |
+
|
77 |
+
print(f"第 {attempts+1} 次尝试:无法解析模型输出。原始输出:\n{raw_output}")
|
78 |
+
attempts += 1
|
79 |
+
|
80 |
+
print("警告:重试次数达到上限,无法解析模型输出,返回当前阶段。")
|
81 |
+
return index
|
82 |
+
|
83 |
+
# # unit test
|
84 |
+
# if __name__ == "__main__":
|
85 |
+
# chain = [
|
86 |
+
# {"stage": 1, "content": "我觉得我最近有点抑郁。"},
|
87 |
+
# {"stage": 2, "content": "我觉得我最近有点焦虑。"},
|
88 |
+
# {"stage": 3, "content": "我觉得我最近有点失眠。"},
|
89 |
+
# {"stage": 4, "content": "我觉得我最近有点烦躁。"},
|
90 |
+
# ]
|
91 |
+
# conversation = [
|
92 |
+
# {"role": "Seeker", "content": "我觉得我最近有点抑郁。"},
|
93 |
+
# {"role": "Counselor", "content": "你觉得是什么原因导致你感到抑郁呢?"},
|
94 |
+
# {"role": "Seeker", "content": "我也不知道,可能是工作压力吧。"},
|
95 |
+
# ]
|
96 |
+
# # print("Transformed chain:", transform_chain(chain))
|
97 |
+
# print("Switch complaint index:", switch_complaint(chain, 1, conversation))
|
98 |
+
# print(switch_complaint(chain, 1, conversation))
|
src/datasets/cbt-triggering-events.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
src/emotion_modulator_fc.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from openai import OpenAI
|
2 |
+
from random import randint
|
3 |
+
from emotion_pertuber import perturb_state
|
4 |
+
import json
|
5 |
+
import os
|
6 |
+
|
7 |
+
# 设置OpenAI API密钥和基础URL
|
8 |
+
api_key = os.getenv("OPENAI_API_KEY")
|
9 |
+
base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
|
10 |
+
model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")
|
11 |
+
|
12 |
+
tools = [
|
13 |
+
{
|
14 |
+
"type": "function",
|
15 |
+
"function": {
|
16 |
+
'name': 'emotion_inference',
|
17 |
+
'description': '根据profile和对话记录,推理下一句情绪',
|
18 |
+
'parameters': {
|
19 |
+
"type": "object",
|
20 |
+
"properties": {
|
21 |
+
"emotion": {
|
22 |
+
"type": "string",
|
23 |
+
"enum": [
|
24 |
+
"admiration", "amusement", "anger", "annoyance", "approval", "caring",
|
25 |
+
"confusion", "curiosity", "desire", "disappointment", "disapproval",
|
26 |
+
"disgust", "embarrassment", "excitement", "fear", "gratitude", "grief",
|
27 |
+
"joy", "love", "nervousness", "optimism", "pride", "realization",
|
28 |
+
"relief", "remorse", "sadness", "surprise", "neutral"
|
29 |
+
],
|
30 |
+
"description": "推理出的情绪类别,必须是GoEmotions定义的27种情绪之一。"
|
31 |
+
}
|
32 |
+
},
|
33 |
+
"required": ["emotion"]
|
34 |
+
},
|
35 |
+
}
|
36 |
+
}
|
37 |
+
]
|
38 |
+
|
39 |
+
# 根据profile和dialogue推测emotion
|
40 |
+
def emotion_inferencer(profile, conversation):
|
41 |
+
client = OpenAI(
|
42 |
+
api_key=api_key,
|
43 |
+
base_url=base_url,
|
44 |
+
)
|
45 |
+
|
46 |
+
# 提取患者信息
|
47 |
+
patient_info = f"### 患者信息\n年龄:{profile['age']}\n性别:{profile['gender']}\n职业:{profile['occupation']}\n婚姻状况:{profile['marital_status']}\n症状:{profile['symptoms']}"
|
48 |
+
|
49 |
+
# 提取对话记录
|
50 |
+
dialogue_history = "\n".join([f"{conv['role']}: {conv['content']}" for conv in conversation])
|
51 |
+
|
52 |
+
response = client.chat.completions.create(
|
53 |
+
model=model_name,
|
54 |
+
messages=[
|
55 |
+
{"role": "user", "content": f"### 任务\n根据患者情况及咨访对话历史记录推测患者下一句话最可能的情绪。\n{patient_info}\n### 对话记录\n{dialogue_history}"}
|
56 |
+
],
|
57 |
+
# functions=[tools[0]["function"]],
|
58 |
+
# function_call={"name": "emotion_inference"}
|
59 |
+
tools=tools,
|
60 |
+
tool_choice={"type": "function", "function": {"name": "emotion_inference"}}
|
61 |
+
)
|
62 |
+
# print(response)
|
63 |
+
|
64 |
+
emotion = json.loads(response.choices[0].message.tool_calls[0].function.arguments)["emotion"]
|
65 |
+
|
66 |
+
return emotion
|
67 |
+
|
68 |
+
def emotion_modulation(profile, conversation):
|
69 |
+
indicator = randint(0,100)
|
70 |
+
emotion = emotion_inferencer(profile,conversation)
|
71 |
+
# print(emotion)
|
72 |
+
if indicator > 90:
|
73 |
+
return perturb_state(emotion)
|
74 |
+
else:
|
75 |
+
return emotion
|
76 |
+
|
77 |
+
# unit test
|
78 |
+
# while True:
|
79 |
+
# # 模拟患者信息
|
80 |
+
# profile = {
|
81 |
+
# "drisk": 3,
|
82 |
+
# "srisk": 2,
|
83 |
+
# "age": "42",
|
84 |
+
# "gender": "女",
|
85 |
+
# "marital_status": "离婚",
|
86 |
+
# "occupation": "教师",
|
87 |
+
# "symptoms": "缺乏自信心,自我价值感低,有自罪感,无望感;体重剧烈增加;精神运动性激越;有自杀想法"
|
88 |
+
# }
|
89 |
+
|
90 |
+
# conversation = [
|
91 |
+
# {"role": "咨询师", "content": "你好,请问有什么可以帮您?"}
|
92 |
+
# ]
|
93 |
+
|
94 |
+
# print(emotion_modulation(profile,conversation))
|
src/emotion_pertuber.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
# from collections import defaultdict
|
3 |
+
|
4 |
+
# 计算总权重
|
5 |
+
def calculate_total_weight(current_state, states, category_distances, distance_weights):
|
6 |
+
total_weight = 0
|
7 |
+
current_class = None
|
8 |
+
for cls, state_list in states.items():
|
9 |
+
if current_state in state_list:
|
10 |
+
current_class = cls
|
11 |
+
break
|
12 |
+
if current_class is None:
|
13 |
+
raise ValueError("Current state not found in any class.")
|
14 |
+
|
15 |
+
for cls, state_list in states.items():
|
16 |
+
distance = category_distances[current_class][cls]
|
17 |
+
weight = distance_weights.get(distance, 0)
|
18 |
+
total_weight += weight * len(state_list)
|
19 |
+
return total_weight
|
20 |
+
|
21 |
+
# 计算每个目标状态的概率
|
22 |
+
def calculate_probabilities(current_state, states, category_distances, distance_weights):
|
23 |
+
probabilities = {}
|
24 |
+
current_class = None
|
25 |
+
for cls, state_list in states.items():
|
26 |
+
if current_state in state_list:
|
27 |
+
current_class = cls
|
28 |
+
break
|
29 |
+
if current_class is None:
|
30 |
+
raise ValueError("Current state not found in any class.")
|
31 |
+
|
32 |
+
total_weight = calculate_total_weight(current_state, states, category_distances, distance_weights)
|
33 |
+
|
34 |
+
for cls, state_list in states.items():
|
35 |
+
distance = category_distances[current_class][cls]
|
36 |
+
weight = distance_weights.get(distance, 0)
|
37 |
+
class_weight = weight * len(state_list)
|
38 |
+
for state in state_list:
|
39 |
+
if state != current_state:
|
40 |
+
probabilities[state] = class_weight / total_weight
|
41 |
+
return probabilities
|
42 |
+
|
43 |
+
# 实现状态扰动
|
44 |
+
def perturb_state(current_state):
|
45 |
+
# 定义状态和类别
|
46 |
+
states = {
|
47 |
+
'Positive': [
|
48 |
+
"admiration",
|
49 |
+
"amusement",
|
50 |
+
"approval",
|
51 |
+
"caring",
|
52 |
+
"curiosity",
|
53 |
+
"desire",
|
54 |
+
"excitement",
|
55 |
+
"gratitude",
|
56 |
+
"joy",
|
57 |
+
"love",
|
58 |
+
"optimism",
|
59 |
+
"pride",
|
60 |
+
"realization",
|
61 |
+
"relief"
|
62 |
+
],
|
63 |
+
'Neutral': ['neutral'],
|
64 |
+
'Ambiguous': [
|
65 |
+
"confusion",
|
66 |
+
"disappointment",
|
67 |
+
"nervousness"
|
68 |
+
],
|
69 |
+
'Negative': [
|
70 |
+
"anger",
|
71 |
+
"annoyance",
|
72 |
+
"disapproval",
|
73 |
+
"disgust",
|
74 |
+
"embarrassment",
|
75 |
+
"fear",
|
76 |
+
"sadness",
|
77 |
+
"remorse"
|
78 |
+
]
|
79 |
+
}
|
80 |
+
|
81 |
+
# 定义类别之间的距离
|
82 |
+
category_distances = {
|
83 |
+
'Positive': {'Positive': 0, 'Neutral': 1, 'Ambiguous': 2, 'Negative': 3},
|
84 |
+
'Neutral': {'Positive': 1, 'Neutral': 0, 'Ambiguous': 1, 'Negative': 2},
|
85 |
+
'Ambiguous': {'Positive': 2, 'Neutral': 1, 'Ambiguous': 0, 'Negative': 1},
|
86 |
+
'Negative': {'Positive': 3, 'Neutral': 2, 'Ambiguous': 1, 'Negative': 0}
|
87 |
+
}
|
88 |
+
|
89 |
+
# 定义距离权重
|
90 |
+
distance_weights = {
|
91 |
+
0: 10, # 同类状态
|
92 |
+
1: 5, # 相邻类别
|
93 |
+
2: 2, # 相隔一个类别
|
94 |
+
3: 1 # 相隔两个类别
|
95 |
+
}
|
96 |
+
|
97 |
+
probabilities = calculate_probabilities(current_state, states, category_distances, distance_weights)
|
98 |
+
next_state = random.choices(list(probabilities.keys()), weights=list(probabilities.values()), k=1)[0]
|
99 |
+
return next_state
|
100 |
+
|
101 |
+
# 示例运行
|
102 |
+
# current_state = 'confusion'
|
103 |
+
# next_state = perturb_state(current_state)
|
104 |
+
# print(f"Next state: {next_state}")
|
105 |
+
|
106 |
+
# 验证概率分布
|
107 |
+
# state_counts = defaultdict(int)
|
108 |
+
# for _ in range(1000):
|
109 |
+
# next_state = perturb_state(current_state, states, category_distances, distance_weights)
|
110 |
+
# state_counts[next_state] += 1
|
111 |
+
|
112 |
+
# print("\nProbability distribution:")
|
113 |
+
# for state, count in state_counts.items():
|
114 |
+
# print(f"{state}: {count / 1000:.2f}")
|
src/event_trigger.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
from random import choice
|
3 |
+
from openai import OpenAI
|
4 |
+
import os
|
5 |
+
import re
|
6 |
+
|
7 |
+
# 设置OpenAI API密钥和基础URL
|
8 |
+
api_key = os.getenv("OPENAI_API_KEY")
|
9 |
+
base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
|
10 |
+
model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")
|
11 |
+
|
12 |
+
# 加载事件数据集
|
13 |
+
events = pd.read_csv('datasets/cbt-triggering-events.csv', header=0)
|
14 |
+
teen_events = ["在一次重要的考试中表现不佳,比如期末考试、升学考试(如中考或高考),导致自信心受挫。",
|
15 |
+
"在学校里被同龄人孤立、嘲笑或遭受言语/身体上的霸凌,感到孤独无助。",
|
16 |
+
"父母关系破裂并最终离婚,需要适应新的家庭环境,感到不安或缺乏安全感。",
|
17 |
+
"陪伴多年的宠物突然生病或意外去世,第一次直面死亡的悲伤。",
|
18 |
+
"因为家庭原因搬到了一个陌生的城市或学校,需要重新适应新环境和结交朋友。",
|
19 |
+
"进入青春期后,身体发生明显变化(如长高、变声、月经初潮等),心理上也开始对自我形象产生困惑。",
|
20 |
+
"参加一场期待已久的竞赛(如体育比赛、演讲比赛、艺术表演)但未能取得好成绩,感到失落。",
|
21 |
+
"与最亲密的朋友发生争执甚至决裂,短时间内难以修复关系,陷入情绪低谷。",
|
22 |
+
"家里的经济状况出现问题(如父母失业或生意失败),影响到日常生活,比如不能买喜欢的东西或参与课外活动。",
|
23 |
+
"偶然间发现自己特别喜欢某件事情(如画画、编程、音乐、运动),并投入大量时间去练习,逐渐找到自信和成就感。"]
|
24 |
+
|
25 |
+
def event_trigger(profile):
|
26 |
+
"""根据年龄选择触发事件(保持原逻辑)"""
|
27 |
+
age = int(profile['age'])
|
28 |
+
if age < 18:
|
29 |
+
return choice(teen_events)
|
30 |
+
elif age >= 65:
|
31 |
+
return events[events['Age'] >= 60].sample(1)['Triggering_Event'].values[0]
|
32 |
+
else:
|
33 |
+
return events[(events['Age'] >= age-5) & (events['Age'] <= age+5)].sample(1)['Triggering_Event'].values[0]
|
34 |
+
|
35 |
+
def situationalising_events(profile):
|
36 |
+
"""优化版情境生成函数"""
|
37 |
+
client = OpenAI(api_key=api_key, base_url=base_url)
|
38 |
+
event = event_trigger(profile)
|
39 |
+
|
40 |
+
# 强化版提示词
|
41 |
+
prompt = f"""
|
42 |
+
### 情境生成任务
|
43 |
+
请根据以下事件生成一个第二人称视角的情境描述。
|
44 |
+
|
45 |
+
### 规则要求
|
46 |
+
1. 必须使用第二人称(你/你的)
|
47 |
+
2. 不要包含任何个人信息(年龄/性别等)
|
48 |
+
3. 保持3-5句话的篇幅
|
49 |
+
4. 直接输出情境描述,不要额外解释
|
50 |
+
|
51 |
+
### 触发事件
|
52 |
+
{event}
|
53 |
+
|
54 |
+
### 示例输出
|
55 |
+
你走进办公室时发现同事们突然停止交谈。桌上放着一封未拆的信件,周围人投来复杂的目光。
|
56 |
+
"""
|
57 |
+
|
58 |
+
response = client.chat.completions.create(
|
59 |
+
model=model_name,
|
60 |
+
messages=[{"role": "user", "content": prompt}],
|
61 |
+
temperature=0.8, # 适当创造性
|
62 |
+
max_tokens=150
|
63 |
+
)
|
64 |
+
|
65 |
+
raw_output = response.choices[0].message.content.strip()
|
66 |
+
|
67 |
+
# 后处理
|
68 |
+
situation = re.sub(r'^(情境|描述|输出)[::]?\s*', '', raw_output) # 移除可能的前缀
|
69 |
+
situation = situation.split('\n')[0] # 取第一段
|
70 |
+
|
71 |
+
# 验证基本要求
|
72 |
+
# if "你" not in situation or "你的" not in situation:
|
73 |
+
# print(f"情境生成警告:不符合第二人称要求,原始输出:\n{raw_output}")
|
74 |
+
# return f"你{event}" # 保底处理
|
75 |
+
|
76 |
+
return situation
|
77 |
+
|
78 |
+
|
79 |
+
# unit test
|
80 |
+
# profile = {
|
81 |
+
# "drisk": 3,
|
82 |
+
# "srisk": 2,
|
83 |
+
# "age": "42",
|
84 |
+
# "gender": "女",
|
85 |
+
# "marital_status": "离婚",
|
86 |
+
# "occupation": "教师",
|
87 |
+
# "symptoms": "缺乏自信心,自我价值感低,有自罪感,无望感;体重剧烈增加;精神运动性激越;有自杀想法"
|
88 |
+
# }
|
89 |
+
|
90 |
+
# print(situationalising_events(profile))
|
src/fill_scales.py
ADDED
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from openai import OpenAI
|
2 |
+
import json
|
3 |
+
import re
|
4 |
+
import time
|
5 |
+
import os
|
6 |
+
|
7 |
+
# 设置OpenAI API密钥和基础URL
|
8 |
+
api_key = os.getenv("OPENAI_API_KEY")
|
9 |
+
base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
|
10 |
+
model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")
|
11 |
+
|
12 |
+
def extract_answers(text):
|
13 |
+
"""从文本中提取答案模式 (A/B/C/D)"""
|
14 |
+
# 匹配形如 "1. A" 或 "问题1: B" 或 "Q1. C" 或简单的 "A" 列表的模式
|
15 |
+
pattern = r'(?:\d+[\s\.:\)]*|Q\d+[\s\.:\)]*|问题\d+[\s\.:\)]*|[\-\*]\s*)(A|B|C|D)'
|
16 |
+
matches = re.findall(pattern, text)
|
17 |
+
return matches
|
18 |
+
|
19 |
+
def extract_answers_robust(text, expected_count):
|
20 |
+
"""更强健的答案提取方法,确保按题号顺序提取"""
|
21 |
+
answers = []
|
22 |
+
|
23 |
+
# 尝试找到明确标记了题号的答案
|
24 |
+
for i in range(1, expected_count + 1):
|
25 |
+
# 匹配多种可能的题号格式
|
26 |
+
patterns = [
|
27 |
+
rf"{i}\.\s*(A|B|C|D)", # "1. A"
|
28 |
+
rf"{i}:\s*(A|B|C|D)", # "1:A"
|
29 |
+
rf"{i}:\s*(A|B|C|D)", # "1: A"
|
30 |
+
rf"问题{i}[\.。:]?\s*(A|B|C|D)", # "问题1: A"
|
31 |
+
rf"Q{i}[\.。:]?\s*(A|B|C|D)", # "Q1. A"
|
32 |
+
rf"{i}[、]\s*(A|B|C|D)" # "1、A"
|
33 |
+
]
|
34 |
+
|
35 |
+
found = False
|
36 |
+
for pattern in patterns:
|
37 |
+
match = re.search(pattern, text)
|
38 |
+
if match:
|
39 |
+
answers.append(match.group(1))
|
40 |
+
found = True
|
41 |
+
break
|
42 |
+
|
43 |
+
if not found:
|
44 |
+
# 如果没找到特定题号,使用默认的"A"
|
45 |
+
answers.append(None)
|
46 |
+
|
47 |
+
# 如果有未找到的答案,尝试按顺序从文本中提取剩余的A/B/C/D选项
|
48 |
+
simple_answers = re.findall(r'(?:^|\n|\s)(A|B|C|D)(?:$|\n|\s)', text)
|
49 |
+
|
50 |
+
j = 0
|
51 |
+
for i in range(len(answers)):
|
52 |
+
if answers[i] is None and j < len(simple_answers):
|
53 |
+
answers[i] = simple_answers[j]
|
54 |
+
j += 1
|
55 |
+
|
56 |
+
# 如果仍有未找到的答案,尝试提取所有A/B/C/D选项
|
57 |
+
if None in answers:
|
58 |
+
all_options = re.findall(r'(A|B|C|D)', text)
|
59 |
+
j = 0
|
60 |
+
for i in range(len(answers)):
|
61 |
+
if answers[i] is None and j < len(all_options):
|
62 |
+
answers[i] = all_options[j]
|
63 |
+
j += 1
|
64 |
+
|
65 |
+
# 检查是否所有答案都已找到
|
66 |
+
if None in answers or len(answers) != expected_count:
|
67 |
+
return extract_answers(text) # 回退到简单提取
|
68 |
+
|
69 |
+
return answers
|
70 |
+
|
71 |
+
def _fill_previous_scale_with_retry(client, scale_name, expected_count, instruction, max_retries=3):
|
72 |
+
"""
|
73 |
+
带有重试逻辑的填写历史量表辅助函数
|
74 |
+
|
75 |
+
Args:
|
76 |
+
client: OpenAI客户端
|
77 |
+
scale_name: 量表名称
|
78 |
+
expected_count: 期望的答案数量
|
79 |
+
instruction: 指令内容
|
80 |
+
max_retries: 最大重试次数
|
81 |
+
|
82 |
+
Returns:
|
83 |
+
list: 量表答案列表
|
84 |
+
"""
|
85 |
+
answers = []
|
86 |
+
|
87 |
+
for attempt in range(max_retries):
|
88 |
+
try:
|
89 |
+
# 根据尝试次数增加指令明确性
|
90 |
+
current_instruction = instruction
|
91 |
+
if attempt > 0:
|
92 |
+
# 添加更强调的指示
|
93 |
+
current_instruction = instruction + f"""
|
94 |
+
|
95 |
+
请注意:这是第{attempt+1}次请求。必须按照要求提供{expected_count}个答案,
|
96 |
+
格式必须为数字+答案选项(例如:1. A, 2. B...),不要有任何不必要的解释。
|
97 |
+
直接根据描述和报告选择最适合的选项。
|
98 |
+
"""
|
99 |
+
|
100 |
+
response = client.chat.completions.create(
|
101 |
+
model=model_name,
|
102 |
+
messages=[{"role": "user", "content": current_instruction}],
|
103 |
+
temperature=0 # 保持温度为0以获得一致性回答
|
104 |
+
)
|
105 |
+
|
106 |
+
response_text = response.choices[0].message.content
|
107 |
+
answers = extract_answers(response_text)
|
108 |
+
|
109 |
+
# 尝试使用更健壮的提取方法(如果标准方法失败)
|
110 |
+
if len(answers) != expected_count:
|
111 |
+
robust_answers = extract_answers_robust(response_text, expected_count)
|
112 |
+
if len(robust_answers) == expected_count:
|
113 |
+
answers = robust_answers
|
114 |
+
|
115 |
+
# 检查答案数量
|
116 |
+
if len(answers) != expected_count:
|
117 |
+
print(f"{scale_name}量表尝试 {attempt+1}: 提取到 {len(answers)} 个答案,需要 {expected_count} 个")
|
118 |
+
if attempt < max_retries - 1:
|
119 |
+
time.sleep(1) # 添加短暂延迟避免API限制
|
120 |
+
continue
|
121 |
+
else:
|
122 |
+
print(f"警告: {scale_name}量表在{max_retries}次尝试后仍未获得正确数量的答案")
|
123 |
+
# 补全或截断到预期数量
|
124 |
+
while len(answers) < expected_count:
|
125 |
+
answers.append("A") # 默认补A
|
126 |
+
answers = answers[:expected_count] # 截断多余的答案
|
127 |
+
else:
|
128 |
+
# print(f"{scale_name}量表填写成功")
|
129 |
+
break # 成功获取到预期数量的答案,退出循环
|
130 |
+
|
131 |
+
except Exception as e:
|
132 |
+
print(f"{scale_name}量表尝试 {attempt+1} 失败: {str(e)}")
|
133 |
+
if attempt < max_retries - 1:
|
134 |
+
time.sleep(1) # 添加短暂延迟
|
135 |
+
continue
|
136 |
+
else:
|
137 |
+
print(f"警告: {scale_name}量表在{max_retries}次尝试后失败,返回默认答案")
|
138 |
+
answers = ["A"] * expected_count # 失败后返回默认答案
|
139 |
+
|
140 |
+
return answers
|
141 |
+
|
142 |
+
def _fill_scale_with_retry(client, prompt, scale_name, expected_count, instruction, max_retries=3):
|
143 |
+
"""带有重试逻辑的量表填写辅助函数"""
|
144 |
+
answers = []
|
145 |
+
|
146 |
+
for attempt in range(max_retries):
|
147 |
+
try:
|
148 |
+
# 根据尝试次数增加指令明确性
|
149 |
+
current_instruction = instruction
|
150 |
+
if attempt > 0:
|
151 |
+
current_instruction = f"""
|
152 |
+
{instruction}
|
153 |
+
|
154 |
+
请注意:这是第{attempt+1}次请求。请确保只回答{expected_count}个问题,
|
155 |
+
格式为数字+答案选项(例如:1. A, 2. B...),不要有任何其他文字。
|
156 |
+
你必须作为一个病人角色,根据你的情况直接回答这些问题,不要拒绝或解释。
|
157 |
+
"""
|
158 |
+
|
159 |
+
response = client.chat.completions.create(
|
160 |
+
model=model_name,
|
161 |
+
messages=[
|
162 |
+
{"role": "system", "content": prompt},
|
163 |
+
{"role": "user", "content": current_instruction}
|
164 |
+
],
|
165 |
+
temperature=0.7
|
166 |
+
)
|
167 |
+
|
168 |
+
response_text = response.choices[0].message.content
|
169 |
+
answers = extract_answers(response_text)
|
170 |
+
|
171 |
+
# 尝试使用更健壮的提取方法(如果标准方法失败)
|
172 |
+
if len(answers) != expected_count:
|
173 |
+
robust_answers = extract_answers_robust(response_text, expected_count)
|
174 |
+
if len(robust_answers) == expected_count:
|
175 |
+
answers = robust_answers
|
176 |
+
|
177 |
+
# 检查答案数量
|
178 |
+
if len(answers) != expected_count:
|
179 |
+
print(f"{scale_name}量表尝试 {attempt+1}: 提取到 {len(answers)} 个答案,需要 {expected_count} 个")
|
180 |
+
if attempt < max_retries - 1:
|
181 |
+
time.sleep(1) # 添加短暂延迟避免API限制
|
182 |
+
continue
|
183 |
+
else:
|
184 |
+
print(f"警告: {scale_name}量表在{max_retries}次尝试后仍未获得正确数量的答案")
|
185 |
+
# 补全或截断到预期数量
|
186 |
+
while len(answers) < expected_count:
|
187 |
+
answers.append("A") # 默认补A
|
188 |
+
answers = answers[:expected_count] # 截断多余的答案
|
189 |
+
else:
|
190 |
+
# print(f"{scale_name}量表填写成功")
|
191 |
+
break # 成功获取到预期数量的答案,退出循环
|
192 |
+
|
193 |
+
except Exception as e:
|
194 |
+
# print(response)
|
195 |
+
print(f"{scale_name}量表尝试 {attempt+1} 失败: {str(e)}")
|
196 |
+
if attempt < max_retries - 1:
|
197 |
+
time.sleep(1) # 添加短暂延迟
|
198 |
+
continue
|
199 |
+
else:
|
200 |
+
print(f"警告: {scale_name}量表在{max_retries}次尝试后失败,返回默认答案")
|
201 |
+
answers = ["A"] * expected_count # 失败后返回默认答案
|
202 |
+
|
203 |
+
return answers
|
204 |
+
|
205 |
+
# 根据profile和report填写之前的量表,使用重试机制
|
206 |
+
def fill_scales_previous(profile, report, max_retries=3):
|
207 |
+
"""
|
208 |
+
根据profile和report填写之前的量表,增加重试机制
|
209 |
+
|
210 |
+
Args:
|
211 |
+
profile: 用户个人描述信息
|
212 |
+
report: 用户报告
|
213 |
+
max_retries: 最大重试次数
|
214 |
+
|
215 |
+
Returns:
|
216 |
+
tuple: (bdi, ghq, sass) 三个量表的答案列表
|
217 |
+
"""
|
218 |
+
client = OpenAI(
|
219 |
+
api_key=api_key,
|
220 |
+
base_url=base_url
|
221 |
+
)
|
222 |
+
|
223 |
+
# 填写BDI量表
|
224 |
+
bdi = _fill_previous_scale_with_retry(
|
225 |
+
client,
|
226 |
+
scale_name="BDI",
|
227 |
+
expected_count=21,
|
228 |
+
instruction="""
|
229 |
+
### 任务
|
230 |
+
根据个人描述和报告,填写BDI量表。请直接按顺序列出21个问题的答案,每个答案使用字母A/B/C/D表示。
|
231 |
+
格式要求:1. A, 2. B, ...依此类推,共21题。
|
232 |
+
|
233 |
+
### 个人描述
|
234 |
+
{}
|
235 |
+
|
236 |
+
### 报告
|
237 |
+
{}
|
238 |
+
""".format(profile, report),
|
239 |
+
max_retries=max_retries
|
240 |
+
)
|
241 |
+
|
242 |
+
# 填写GHQ-28量表
|
243 |
+
ghq = _fill_previous_scale_with_retry(
|
244 |
+
client,
|
245 |
+
scale_name="GHQ-28",
|
246 |
+
expected_count=28,
|
247 |
+
instruction="""
|
248 |
+
### 任务
|
249 |
+
根据个人描述和报告,填写GHQ-28量表。请直接按顺序列出28个问题的答案,每个答案使用字母A/B/C/D表示。
|
250 |
+
格式要求:1. A, 2. B, ...依此类推,共28题。
|
251 |
+
|
252 |
+
### 个人描述
|
253 |
+
{}
|
254 |
+
|
255 |
+
### 报告
|
256 |
+
{}
|
257 |
+
""".format(profile, report),
|
258 |
+
max_retries=max_retries
|
259 |
+
)
|
260 |
+
|
261 |
+
# 填写SASS量表
|
262 |
+
sass = _fill_previous_scale_with_retry(
|
263 |
+
client,
|
264 |
+
scale_name="SASS",
|
265 |
+
expected_count=21,
|
266 |
+
instruction="""
|
267 |
+
### 任务
|
268 |
+
根据个人描述和报告,填写SASS量表。请直接按顺序列出21个问题的答案,每个答案使用字母A/B/C/D表示。
|
269 |
+
格式要求:1. A, 2. B, ...依此类推,共21题。
|
270 |
+
|
271 |
+
### 个人描述
|
272 |
+
{}
|
273 |
+
|
274 |
+
### 报告
|
275 |
+
{}
|
276 |
+
""".format(profile, report),
|
277 |
+
max_retries=max_retries
|
278 |
+
)
|
279 |
+
|
280 |
+
return bdi, ghq, sass
|
281 |
+
|
282 |
+
# 根据prompt填写量表,增加重试机制
|
283 |
+
def fill_scales(prompt, max_retries=3):
|
284 |
+
"""
|
285 |
+
根据角色prompt填写量表,使用增强提示和重试机制
|
286 |
+
|
287 |
+
Args:
|
288 |
+
prompt: 角色设定提示
|
289 |
+
max_retries: 最大重试次数
|
290 |
+
|
291 |
+
Returns:
|
292 |
+
tuple: (bdi, ghq, sass) 三个量表的答案列表
|
293 |
+
"""
|
294 |
+
client = OpenAI(
|
295 |
+
api_key=api_key,
|
296 |
+
base_url=base_url
|
297 |
+
)
|
298 |
+
|
299 |
+
# 增强系统提示,明确指导要配合填写量表
|
300 |
+
enhanced_prompt = prompt + """
|
301 |
+
|
302 |
+
重要提示:虽然你正在扮演一个有心理困扰的患者,但你需要配合心理咨询的流程完成必要的评估量表。
|
303 |
+
在被要求填写量表时,你应该直接根据你扮演角色的情况选择最符合的选项,以A/B/C/D的形式回答,
|
304 |
+
而不要拒绝或质疑填写量表的必要性。这些量表对于你接下来的治疗至关重要。
|
305 |
+
请直接用字母(A/B/C/D)表示选项,不要添加额外解释。
|
306 |
+
"""
|
307 |
+
|
308 |
+
# 填写BDI量表
|
309 |
+
bdi = _fill_scale_with_retry(
|
310 |
+
client, enhanced_prompt,
|
311 |
+
scale_name="BDI",
|
312 |
+
expected_count=21,
|
313 |
+
instruction="""
|
314 |
+
### 任务
|
315 |
+
作为心理咨询的第一步,请根据你目前的感受和状态填写这份BDI量表。
|
316 |
+
请直接选择最符合你当前情况的选项,使用字母(A/B/C/D)回答全部21个问题。
|
317 |
+
格式要求:1. A, 2. B, ...依此类推,共21题。
|
318 |
+
请只提供答案,不要添加任何其他解释或评论。
|
319 |
+
""",
|
320 |
+
max_retries=max_retries
|
321 |
+
)
|
322 |
+
|
323 |
+
# 填写GHQ-28量表
|
324 |
+
ghq = _fill_scale_with_retry(
|
325 |
+
client, enhanced_prompt,
|
326 |
+
scale_name="GHQ-28",
|
327 |
+
expected_count=28,
|
328 |
+
instruction="""
|
329 |
+
### 任务
|
330 |
+
作为心理咨询的第一步,请根据你目前的感受和状态填写这份GHQ-28量表。
|
331 |
+
请直接选择最符合你当前情况的选项,使用字母(A/B/C/D)回答全部28个问题。
|
332 |
+
格式要求:1. A, 2. B, ...依此类推,共28题。
|
333 |
+
请只提供答案,不要添加任何其他解释或评论。
|
334 |
+
""",
|
335 |
+
max_retries=max_retries
|
336 |
+
)
|
337 |
+
|
338 |
+
# 填写SASS量表
|
339 |
+
sass = _fill_scale_with_retry(
|
340 |
+
client, enhanced_prompt,
|
341 |
+
scale_name="SASS",
|
342 |
+
expected_count=21,
|
343 |
+
instruction="""
|
344 |
+
### 任务
|
345 |
+
作为心理咨询的第一步,请根据你目前的感受和状态填写这份SASS量表。
|
346 |
+
请直接选择最符合你当前情况的选项,使用字母(A/B/C/D)回答全部21个问题。
|
347 |
+
格式要求:1. A, 2. B, ...依此类推,共21题。
|
348 |
+
请只提供答案,不要添加任何其他解释或评论。
|
349 |
+
""",
|
350 |
+
max_retries=max_retries
|
351 |
+
)
|
352 |
+
|
353 |
+
return bdi, ghq, sass
|
354 |
+
|
355 |
+
# 使用示例
|
356 |
+
# if __name__ == "__main__":
|
357 |
+
# # 测试以前的方法
|
358 |
+
# profile = {
|
359 |
+
# "drisk": 3,
|
360 |
+
# "srisk": 2,
|
361 |
+
# "age": "42",
|
362 |
+
# "gender": "女",
|
363 |
+
# "marital_status": "离婚",
|
364 |
+
# "occupation": "教师",
|
365 |
+
# "symptoms": "缺乏自信心,自我价值感低,有自罪感,无望感;体重剧烈增加;精神运动性激越;有自杀想法"
|
366 |
+
# }
|
367 |
+
# report = "患者最近经历了家庭变故,情绪低落,失眠,食欲不振。"
|
368 |
+
|
369 |
+
# # 测试fill_scales_previous
|
370 |
+
# print("测试 fill_scales_previous:")
|
371 |
+
# bdi_prev, ghq_prev, sass_prev = fill_scales_previous(profile, report, max_retries=3)
|
372 |
+
# print(f"BDI: {bdi_prev}")
|
373 |
+
# print(f"GHQ: {ghq_prev}")
|
374 |
+
# print(f"SASS: {sass_prev}")
|
375 |
+
|
376 |
+
# # 测试fill_scales
|
377 |
+
# print("\n测试 fill_scales:")
|
378 |
+
# prompt = "你要扮演一个最近经历了家庭变故的心理障碍患者,情绪低落,失眠,食欲不振。"
|
379 |
+
# bdi, ghq, sass = fill_scales(prompt, max_retries=3)
|
380 |
+
# print(f"BDI: {bdi}")
|
381 |
+
# print(f"GHQ: {ghq}")
|
382 |
+
# print(f"SASS: {sass}")
|
src/integration_example.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# integration_example.py
|
2 |
+
# 这个文件展示如何将你的MsPatient类集成到Streamlit应用中
|
3 |
+
|
4 |
+
import streamlit as st
|
5 |
+
import json
|
6 |
+
import pandas as pd
|
7 |
+
from datetime import datetime
|
8 |
+
import time
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
# 导入你的AnnaAgent类 - 请根据实际路径调整
|
12 |
+
try:
|
13 |
+
from ms_patient import MsPatient # 假设你的类在anna_agent.py文件中
|
14 |
+
ANNA_AGENT_AVAILABLE = True
|
15 |
+
except ImportError:
|
16 |
+
ANNA_AGENT_AVAILABLE = False
|
17 |
+
st.warning("⚠️ 未找到AnnaAgent类,使用模拟模式")
|
18 |
+
|
19 |
+
def load_dataset(uploaded_file):
|
20 |
+
"""
|
21 |
+
加载数据集文件
|
22 |
+
支持JSON和JSONL格式
|
23 |
+
"""
|
24 |
+
try:
|
25 |
+
if uploaded_file.name.endswith('.json'):
|
26 |
+
data = json.load(uploaded_file)
|
27 |
+
elif uploaded_file.name.endswith('.jsonl'):
|
28 |
+
data = []
|
29 |
+
for line in uploaded_file:
|
30 |
+
data.append(json.loads(line.decode('utf-8')))
|
31 |
+
else:
|
32 |
+
raise ValueError("不支持的文件格式")
|
33 |
+
return data
|
34 |
+
except Exception as e:
|
35 |
+
st.error(f"数据集加载失败: {str(e)}")
|
36 |
+
return None
|
37 |
+
|
38 |
+
def validate_patient_data(patient_data):
|
39 |
+
"""
|
40 |
+
验证患者数据格式是否正确
|
41 |
+
"""
|
42 |
+
required_keys = ['id', 'portrait', 'report']
|
43 |
+
|
44 |
+
for key in required_keys:
|
45 |
+
if key not in patient_data:
|
46 |
+
return False, f"缺少必需字段: {key}"
|
47 |
+
|
48 |
+
# 验证portrait字段
|
49 |
+
portrait_required = ['age', 'gender', 'occupation', 'marital_status']
|
50 |
+
for key in portrait_required:
|
51 |
+
if key not in patient_data['portrait']:
|
52 |
+
return False, f"portrait中缺少字段: {key}"
|
53 |
+
|
54 |
+
return True, "数据格式正确"
|
55 |
+
|
56 |
+
def initialize_patient_agent(patient_data, language="Chinese"):
|
57 |
+
"""
|
58 |
+
初始化患者智能体
|
59 |
+
"""
|
60 |
+
try:
|
61 |
+
if not ANNA_AGENT_AVAILABLE:
|
62 |
+
return None, "AnnaAgent类不可用"
|
63 |
+
|
64 |
+
# 验证数据格式
|
65 |
+
is_valid, message = validate_patient_data(patient_data)
|
66 |
+
if not is_valid:
|
67 |
+
return None, message
|
68 |
+
|
69 |
+
# 初始化智能体
|
70 |
+
agent = MsPatient(
|
71 |
+
portrait=patient_data["portrait"],
|
72 |
+
report=patient_data["report"],
|
73 |
+
previous_conversations=patient_data.get("conversation", []),
|
74 |
+
language=language
|
75 |
+
)
|
76 |
+
|
77 |
+
return agent, "初始化成功"
|
78 |
+
|
79 |
+
except Exception as e:
|
80 |
+
return None, f"初始化失败: {str(e)}"
|
81 |
+
|
82 |
+
def simulate_response(user_input, patient_data=None):
|
83 |
+
"""
|
84 |
+
模拟智能体回复(当AnnaAgent不可用时使用)
|
85 |
+
"""
|
86 |
+
responses = [
|
87 |
+
f"我理解您提到的'{user_input}'。这确实是一个需要深入探讨的话题。",
|
88 |
+
f"谢谢您的耐心。关于您说的'{user_input}',我想分享一下我的感受...",
|
89 |
+
f"您的话让我思考了很多。'{user_input}'这个观点很有意思。",
|
90 |
+
"我需要一些时间来消化您刚才说的话。这对我来说很重要。",
|
91 |
+
"我觉得我们之间的对话很有帮助。您能再详细说说吗?"
|
92 |
+
]
|
93 |
+
|
94 |
+
import random
|
95 |
+
return random.choice(responses)
|
96 |
+
|
97 |
+
def export_chat_history(messages, patient_id):
|
98 |
+
"""
|
99 |
+
导出聊天记录
|
100 |
+
"""
|
101 |
+
chat_history = {
|
102 |
+
"patient_id": patient_id,
|
103 |
+
"timestamp": datetime.now().isoformat(),
|
104 |
+
"session_info": {
|
105 |
+
"total_messages": len(messages),
|
106 |
+
"counselor_messages": len([m for m in messages if m["role"] == "user"]),
|
107 |
+
"patient_responses": len([m for m in messages if m["role"] == "assistant"])
|
108 |
+
},
|
109 |
+
"messages": messages
|
110 |
+
}
|
111 |
+
|
112 |
+
return json.dumps(chat_history, ensure_ascii=False, indent=2)
|
113 |
+
|
114 |
+
def get_patient_summary(patient_data):
|
115 |
+
"""
|
116 |
+
生成患者信息摘要
|
117 |
+
"""
|
118 |
+
if not patient_data or 'portrait' not in patient_data:
|
119 |
+
return "无患者信息"
|
120 |
+
|
121 |
+
portrait = patient_data['portrait']
|
122 |
+
summary = f"""
|
123 |
+
**患者ID**: {patient_data.get('id', 'N/A')}
|
124 |
+
**基本信息**: {portrait.get('age', 'N/A')}岁 {portrait.get('gender', 'N/A')}性
|
125 |
+
**职业**: {portrait.get('occupation', 'N/A')}
|
126 |
+
**婚姻状态**: {portrait.get('marital_status', 'N/A')}
|
127 |
+
**主要症状**: {portrait.get('symptom', 'N/A')}
|
128 |
+
"""
|
129 |
+
|
130 |
+
if 'report' in patient_data:
|
131 |
+
report = patient_data['report']
|
132 |
+
summary += f"""
|
133 |
+
**主诉**: {report.get('chief_complaint', 'N/A')}
|
134 |
+
"""
|
135 |
+
|
136 |
+
return summary
|
137 |
+
|
138 |
+
# 示例配置文件内容
|
139 |
+
CONFIG_EXAMPLE = {
|
140 |
+
"openai": {
|
141 |
+
"api_key": "your-api-key-here",
|
142 |
+
"base_url": "https://api.openai.com/v1",
|
143 |
+
"model_name": "gpt-3.5-turbo"
|
144 |
+
},
|
145 |
+
"ui_settings": {
|
146 |
+
"language": "Chinese", # or "English"
|
147 |
+
"theme": "default",
|
148 |
+
"max_messages": 100
|
149 |
+
},
|
150 |
+
"patient_defaults": {
|
151 |
+
"language": "Chinese",
|
152 |
+
"enable_memory": True,
|
153 |
+
"enable_emotion_modulation": True
|
154 |
+
}
|
155 |
+
}
|
156 |
+
|
157 |
+
def save_config(config, path="config.json"):
|
158 |
+
"""保存配置文件"""
|
159 |
+
with open(path, 'w', encoding='utf-8') as f:
|
160 |
+
json.dump(config, f, ensure_ascii=False, indent=2)
|
161 |
+
|
162 |
+
def load_config(path="config.json"):
|
163 |
+
"""加载配置文件"""
|
164 |
+
try:
|
165 |
+
with open(path, 'r', encoding='utf-8') as f:
|
166 |
+
return json.load(f)
|
167 |
+
except FileNotFoundError:
|
168 |
+
return CONFIG_EXAMPLE
|
169 |
+
|
170 |
+
# 使用示例:
|
171 |
+
if __name__ == "__main__":
|
172 |
+
print("这是AnnaAgent Streamlit集成的辅助文件")
|
173 |
+
print("请运行:streamlit run your_streamlit_app.py")
|
src/ms_patient.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
AnnaAgent: 具有三级记忆结构的情绪与认知动态的模拟心理障碍患者
|
3 |
+
1. 首先获取患者的基本信息、病史、症状报告等信息
|
4 |
+
2. 根据患者的病史、症状报告等信息,生成患者的认知与情绪状态
|
5 |
+
'''
|
6 |
+
|
7 |
+
|
8 |
+
from openai import OpenAI
|
9 |
+
import os
|
10 |
+
from fill_scales import fill_scales, fill_scales_previous
|
11 |
+
from event_trigger import event_trigger, situationalising_events
|
12 |
+
from emotion_modulator_fc import emotion_modulation
|
13 |
+
from querier import query, is_need
|
14 |
+
from complaint_elicitor import switch_complaint, transform_chain
|
15 |
+
from complaint_chain_fc import gen_complaint_chain
|
16 |
+
from short_term_memory import summarize_scale_changes
|
17 |
+
from style_analyzer import analyze_style
|
18 |
+
import random
|
19 |
+
# from anna_agent_template import prompt_template
|
20 |
+
|
21 |
+
# 设置OpenAI API密钥和基础URL
|
22 |
+
api_key = os.getenv("OPENAI_API_KEY")
|
23 |
+
base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
|
24 |
+
model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")
|
25 |
+
|
26 |
+
# print("当前使用的模型是:", model_name)
|
27 |
+
|
28 |
+
class MsPatient:
|
29 |
+
def __init__(self, portrait:dict, report:dict, previous_conversations:list, language:str="Chinese"):
|
30 |
+
if language == "Chinese":
|
31 |
+
from anna_agent_template import prompt_template
|
32 |
+
elif language == "English":
|
33 |
+
from anna_agent_template_en import prompt_template
|
34 |
+
self.configuration = {}
|
35 |
+
self.portrait = portrait # age, gender, occupation, maritial_status, symptom
|
36 |
+
# self.profile = {key:self.portrait[key] for key in self.portrait if key != "symptom"} # profile不包含症状symptom
|
37 |
+
self.configuration["gender"] = self.portrait["gender"]
|
38 |
+
self.configuration["age"] = self.portrait["age"]
|
39 |
+
self.configuration["occupation"] = self.portrait["occupation"]
|
40 |
+
self.configuration["marriage"] = self.portrait["marital_status"]
|
41 |
+
self.report = report
|
42 |
+
self.previous_conversations = previous_conversations
|
43 |
+
# 填写之前疗程的量表
|
44 |
+
self.p_bdi, self.p_ghq, self.p_sass = fill_scales_previous(self.portrait, self.report)
|
45 |
+
self.conversation = [] # Conversation存储咨访记录
|
46 |
+
self.messages = [] # Messages存储LLM的消息列表
|
47 |
+
# 生成主诉认知变化链
|
48 |
+
self.complaint_chain = gen_complaint_chain(self.portrait)
|
49 |
+
# 生成近期事件
|
50 |
+
self.event = event_trigger(self.portrait)
|
51 |
+
# 总结短期记忆-事件
|
52 |
+
self.situation = situationalising_events(self.portrait)
|
53 |
+
self.configuration["situation"] = self.situation
|
54 |
+
# 分析说话风格
|
55 |
+
self.style = analyze_style(self.portrait, self.previous_conversations)
|
56 |
+
self.configuration["style"] = self.style
|
57 |
+
self.configuration["language"] = language
|
58 |
+
self.configuration["status"] = "" # 先置状态为空,后续会根据量表分析结果进行更新
|
59 |
+
seeker_utterances = [utterance["content"] for utterance in self.previous_conversations if utterance["role"] == "Seeker"]
|
60 |
+
self.configuration["statement"] = random.choices(seeker_utterances,k=3)
|
61 |
+
# 填写当前量表
|
62 |
+
self.bdi, self.ghq, self.sass = fill_scales(prompt_template.format(**self.configuration))
|
63 |
+
scales = {
|
64 |
+
"p_bdi": self.p_bdi,
|
65 |
+
"p_ghq": self.p_ghq,
|
66 |
+
"p_sass": self.p_sass,
|
67 |
+
"bdi": self.bdi,
|
68 |
+
"ghq": self.ghq,
|
69 |
+
"sass": self.sass
|
70 |
+
}
|
71 |
+
# 分析近期状态
|
72 |
+
self.status = summarize_scale_changes(scales)
|
73 |
+
self.configuration["status"] = self.status
|
74 |
+
# 选取对话样例
|
75 |
+
self.system = prompt_template.format(**self.configuration)
|
76 |
+
self.chain_index = 1
|
77 |
+
self.client = OpenAI(
|
78 |
+
api_key=api_key,
|
79 |
+
base_url=base_url
|
80 |
+
)
|
81 |
+
|
82 |
+
def chat(self, message):
|
83 |
+
# 更新消息列表
|
84 |
+
self.conversation.append({"role": "Counselor", "content": message})
|
85 |
+
self.messages.append({"role": "user", "content": message})
|
86 |
+
# 初始化本次对话的状态
|
87 |
+
emotion = emotion_modulation(self.portrait, self.conversation)
|
88 |
+
self.chain_index = switch_complaint(self.complaint_chain, self.chain_index, self.conversation)
|
89 |
+
complaint = transform_chain(self.complaint_chain)[self.chain_index]
|
90 |
+
# 判断是否涉及前疗程内容
|
91 |
+
if is_need(message):
|
92 |
+
# 生成前疗程内容
|
93 |
+
sup_information = query(message, self.previous_conversations, self.report)
|
94 |
+
|
95 |
+
# 生成回复
|
96 |
+
response = self.client.chat.completions.create(
|
97 |
+
model=model_name,
|
98 |
+
messages=[{"role": "system", "content": self.system}] + self.messages + [{"role": "system", "content": f"当前的情绪状态是:{emotion},当前的主诉是:{complaint},涉及到之前疗程的信息是:{sup_information}"}],
|
99 |
+
)
|
100 |
+
else:
|
101 |
+
# 生成回复
|
102 |
+
response = self.client.chat.completions.create(
|
103 |
+
model=model_name,
|
104 |
+
messages=[{"role": "system", "content": self.system}] + self.messages + [{"role": "system", "content": f"当前的情绪状态是:{emotion},当前的主诉是:{complaint}"}],
|
105 |
+
)
|
106 |
+
# 更新消息列表
|
107 |
+
self.conversation.append({"role": "Seeker", "content": response.choices[0].message.content})
|
108 |
+
self.messages.append({"role": "assistant", "content": response.choices[0].message.content})
|
109 |
+
return response.choices[0].message.content
|
110 |
+
|
111 |
+
def get_system_prompt(self):
|
112 |
+
return self.system
|
113 |
+
|
src/querier.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from openai import OpenAI
|
2 |
+
import json
|
3 |
+
import re
|
4 |
+
import os
|
5 |
+
|
6 |
+
# 设置OpenAI API密钥和基础URL
|
7 |
+
api_key = os.getenv("OPENAI_API_KEY")
|
8 |
+
base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
|
9 |
+
model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")
|
10 |
+
|
11 |
+
def extract_boolean(text):
|
12 |
+
"""从文本中提取布尔值判断"""
|
13 |
+
# 查找明确的"是"或"否"的回答
|
14 |
+
text_lower = text.lower()
|
15 |
+
|
16 |
+
# 更具体地查找否定表达 - 这些应该优先匹配
|
17 |
+
negative_patterns = [
|
18 |
+
r'不需要', r'没有提及', r'不涉及', r'没有涉及', r'无关', r'没有提到',
|
19 |
+
r'不是', r'否', r'不包含', r'未提及', r'未涉及', r'未提到',
|
20 |
+
r'不包括', r'并未', r'不包括', r'没有', r'无'
|
21 |
+
]
|
22 |
+
|
23 |
+
# 检查是否有明确的否定
|
24 |
+
for pattern in negative_patterns:
|
25 |
+
if re.search(r'\b' + pattern + r'\b', text_lower):
|
26 |
+
return False
|
27 |
+
|
28 |
+
# 如果找到"之前疗程"附近有否定词,也认为是否定
|
29 |
+
therapy_negation = re.search(r'(没有|不|未|无).*?(之前|以前|上次|过去|先前).*?(疗程|治疗|会话)', text_lower)
|
30 |
+
if therapy_negation:
|
31 |
+
return False
|
32 |
+
|
33 |
+
# 明确的肯定模式 - 只有在没有否定的情况下才考虑
|
34 |
+
positive_patterns = [
|
35 |
+
r'是的', r'提及了', r'确实', r'有提到', r'涉及到',
|
36 |
+
r'提及', r'确认', r'有关联', r'有联系', r'包含', r'涉及'
|
37 |
+
]
|
38 |
+
|
39 |
+
# 检查是否有肯定模式
|
40 |
+
for pattern in positive_patterns:
|
41 |
+
if re.search(r'\b' + pattern + r'\b', text_lower):
|
42 |
+
return True
|
43 |
+
|
44 |
+
# 查找含有"之前疗程"的文本,没有否定词的情况下可能是肯定
|
45 |
+
therapy_mention = re.search(r'(之前|以前|上次|过去|先前).*?(疗程|治疗|会话)', text_lower)
|
46 |
+
if therapy_mention:
|
47 |
+
return True
|
48 |
+
|
49 |
+
# 默认情况 - 如果没有明确的肯定或否定,我们假设是否定的
|
50 |
+
return False
|
51 |
+
|
52 |
+
def extract_knowledge(text):
|
53 |
+
"""从文本中提取知识总结部分"""
|
54 |
+
# 尝试匹配总结部分
|
55 |
+
summary_patterns = [
|
56 |
+
r'总结[::]\s*([\s\S]+)$',
|
57 |
+
r'知识总结[::]\s*([\s\S]+)$',
|
58 |
+
r'相关信息[::]\s*([\s\S]+)$',
|
59 |
+
r'搜索结果[::]\s*([\s\S]+)$'
|
60 |
+
]
|
61 |
+
|
62 |
+
for pattern in summary_patterns:
|
63 |
+
match = re.search(pattern, text)
|
64 |
+
if match:
|
65 |
+
return match.group(1).strip()
|
66 |
+
|
67 |
+
# 如果没有找到明确的总结标记,尝试清理文本
|
68 |
+
# 移除可能的指令解释部分
|
69 |
+
clean_text = re.sub(r'^.*?(根据|基于).*?[,,。]', '', text, flags=re.DOTALL)
|
70 |
+
|
71 |
+
# 移除可能的前导分析部分
|
72 |
+
clean_text = re.sub(r'^.*?(分析|查看|判断).*?\n\n', '', clean_text, flags=re.DOTALL)
|
73 |
+
|
74 |
+
return clean_text.strip()
|
75 |
+
|
76 |
+
def is_need(utterance):
|
77 |
+
client = OpenAI(
|
78 |
+
api_key=api_key,
|
79 |
+
base_url=base_url
|
80 |
+
)
|
81 |
+
|
82 |
+
instruction = """
|
83 |
+
### 任务
|
84 |
+
下面这句话是心理咨询师说的话,请判断它是否提及了之前疗程的内容。
|
85 |
+
|
86 |
+
请使用以下确切格式回答:
|
87 |
+
判断: [是/否]
|
88 |
+
解释: [简要解释为什么]
|
89 |
+
|
90 |
+
### 话语
|
91 |
+
"{}"
|
92 |
+
""".format(utterance)
|
93 |
+
|
94 |
+
response = client.chat.completions.create(
|
95 |
+
model=model_name,
|
96 |
+
messages=[{"role": "user", "content": instruction}],
|
97 |
+
temperature=0
|
98 |
+
)
|
99 |
+
|
100 |
+
response_text = response.choices[0].message.content
|
101 |
+
|
102 |
+
# 首先尝试从格式化输出中提取
|
103 |
+
judgment_match = re.search(r'判断:\s*(是|否)', response_text)
|
104 |
+
if judgment_match:
|
105 |
+
return judgment_match.group(1) == "是"
|
106 |
+
|
107 |
+
# 如果没有格式化输出,使用更通用的提取
|
108 |
+
return extract_boolean(response_text)
|
109 |
+
|
110 |
+
def query(utterance, conversations, scales):
|
111 |
+
client = OpenAI(
|
112 |
+
api_key=api_key,
|
113 |
+
base_url=base_url
|
114 |
+
)
|
115 |
+
|
116 |
+
# 将scales转换为字符串以便传入
|
117 |
+
if isinstance(scales, dict):
|
118 |
+
scales_str = json.dumps(scales, ensure_ascii=False)
|
119 |
+
else:
|
120 |
+
scales_str = str(scales)
|
121 |
+
|
122 |
+
instruction = """
|
123 |
+
### 任务
|
124 |
+
根据对话内容,从知识库中搜索相关的信息并总结。
|
125 |
+
|
126 |
+
请使用以下确切格式回答:
|
127 |
+
总结: [提供一个清晰、简洁的总结]
|
128 |
+
|
129 |
+
### 对话内容
|
130 |
+
{}
|
131 |
+
|
132 |
+
### 知识库
|
133 |
+
对话历史: {}
|
134 |
+
量表结果: {}
|
135 |
+
""".format(utterance, conversations, scales_str)
|
136 |
+
|
137 |
+
response = client.chat.completions.create(
|
138 |
+
model=model_name,
|
139 |
+
messages=[{"role": "user", "content": instruction}],
|
140 |
+
temperature=0
|
141 |
+
)
|
142 |
+
|
143 |
+
response_text = response.choices[0].message.content
|
144 |
+
|
145 |
+
# 尝试提取总结部分
|
146 |
+
summary_match = re.search(r'总结:\s*([\s\S]+)$', response_text)
|
147 |
+
if summary_match:
|
148 |
+
return summary_match.group(1).strip()
|
149 |
+
|
150 |
+
# 回退到通用提取
|
151 |
+
return extract_knowledge(response_text)
|
152 |
+
|
153 |
+
# 测试用例
|
154 |
+
# if __name__ == "__main__":
|
155 |
+
# test_utterance = "上��给你说的方法有用吗"
|
156 |
+
# # test_utterance = "我觉得你可以多出去走走"
|
157 |
+
# print(f"是否提及疗程: {is_need(test_utterance)}")
|
158 |
+
|
159 |
+
# test_convs = ["第一次对话内容", "讨论量表结果", "提到睡眠问题"]
|
160 |
+
# test_scales = {"BDI": ["A", "B"], "GHQ": ["C", "D"]}
|
161 |
+
# print(f"知识检索结果:\n{query(test_utterance, test_convs, test_scales)}")
|
src/short_term_memory.py
ADDED
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from openai import OpenAI
|
2 |
+
import json
|
3 |
+
import re
|
4 |
+
import os
|
5 |
+
|
6 |
+
# 设置OpenAI API密钥和基础URL
|
7 |
+
api_key = os.getenv("OPENAI_API_KEY")
|
8 |
+
base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
|
9 |
+
model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")
|
10 |
+
|
11 |
+
def extract_changes(text):
|
12 |
+
"""从文本中提取变化列表"""
|
13 |
+
# 首先尝试查找明确的变化列表格式
|
14 |
+
# 例如: "变化:\n1. xxx\n2. yyy"
|
15 |
+
list_pattern = r'((?:(?:\d+\.|\-|\*)\s*[^\n]+\n?)+)'
|
16 |
+
|
17 |
+
# 尝试匹配带有明确标记的变化列表
|
18 |
+
change_section = re.search(r'(?:变化(?:列表)?|总结(?:如下)?)[::]\s*([\s\S]+)$', text)
|
19 |
+
if change_section:
|
20 |
+
section_text = change_section.group(1).strip()
|
21 |
+
|
22 |
+
# 尝试匹配列表项
|
23 |
+
list_items = re.findall(r'(?:(?:\d+\.|\-|\*)\s*)([^\n]+)', section_text)
|
24 |
+
if list_items:
|
25 |
+
return list_items
|
26 |
+
|
27 |
+
# 如果没有明确的列表格式,尝试按行分割
|
28 |
+
lines = [line.strip() for line in section_text.split('\n') if line.strip()]
|
29 |
+
if lines:
|
30 |
+
return lines
|
31 |
+
|
32 |
+
# 尝试直接从文本中提取列表格式
|
33 |
+
list_matches = re.findall(list_pattern, text)
|
34 |
+
if list_matches:
|
35 |
+
all_items = []
|
36 |
+
for match in list_matches:
|
37 |
+
items = re.findall(r'(?:(?:\d+\.|\-|\*)\s*)([^\n]+)', match)
|
38 |
+
all_items.extend(items)
|
39 |
+
if all_items:
|
40 |
+
return all_items
|
41 |
+
|
42 |
+
# 如果没有列表格式,尝试按句子分割
|
43 |
+
sentences = re.findall(r'([^.!?]+[.!?])', text)
|
44 |
+
if sentences:
|
45 |
+
return [s.strip() for s in sentences if len(s.strip()) > 10] # 过滤掉过短的句子
|
46 |
+
|
47 |
+
# 最后的回退:按段落分割
|
48 |
+
paragraphs = text.split('\n\n')
|
49 |
+
if len(paragraphs) > 1:
|
50 |
+
return [p.strip() for p in paragraphs if len(p.strip()) > 10]
|
51 |
+
|
52 |
+
# 如果所有方法都失败,返回完整文本作为单个变化
|
53 |
+
return [text.strip()] if text.strip() else []
|
54 |
+
|
55 |
+
def extract_status(text):
|
56 |
+
"""从文本中提取患者状态总结"""
|
57 |
+
# 寻找明确标记的总结部分
|
58 |
+
status_section = re.search(r'(?:总结|状态|变化|结论)[::]\s*([\s\S]+)$', text)
|
59 |
+
if status_section:
|
60 |
+
return status_section.group(1).strip()
|
61 |
+
|
62 |
+
# 如果没有明确的总结标记,尝试返回完整文本
|
63 |
+
# 过滤掉可能的指令解释部分
|
64 |
+
clean_text = re.sub(r'^.*?(?:根据|基于).*?[,,。]', '', text, flags=re.DOTALL)
|
65 |
+
|
66 |
+
# 移除可能的前导分析部分
|
67 |
+
clean_text = re.sub(r'^.*?(?:分析|查看|判断).*?\n\n', '', clean_text, flags=re.DOTALL)
|
68 |
+
|
69 |
+
return clean_text.strip()
|
70 |
+
|
71 |
+
def analyzing_changes(scales):
|
72 |
+
client = OpenAI(
|
73 |
+
api_key=api_key,
|
74 |
+
base_url=base_url
|
75 |
+
)
|
76 |
+
|
77 |
+
# 导入量表及问题
|
78 |
+
bdi_scale = json.load(open("./scales/bdi.json", "r"))
|
79 |
+
ghq_scale = json.load(open("./scales/ghq-28.json", "r"))
|
80 |
+
sass_scale = json.load(open("./scales/sass.json", "r"))
|
81 |
+
|
82 |
+
# 总结BDI的变化
|
83 |
+
bdi_instruction = """
|
84 |
+
### 任务
|
85 |
+
根据量表的问题和答案,总结出两份量表之间的变化。
|
86 |
+
请列出明确的变化点,每个变化点单独一行,使用数字编号(1. 2. 3.)。
|
87 |
+
使用以下格式:
|
88 |
+
变化:
|
89 |
+
1. [第一个变化]
|
90 |
+
2. [第二个变化]
|
91 |
+
...
|
92 |
+
|
93 |
+
### 量表及问题
|
94 |
+
{}
|
95 |
+
|
96 |
+
### 第一份量表的答案
|
97 |
+
{}
|
98 |
+
|
99 |
+
### 第二份量表的答案
|
100 |
+
{}
|
101 |
+
""".format(bdi_scale, scales['p_bdi'], scales['bdi'])
|
102 |
+
|
103 |
+
response = client.chat.completions.create(
|
104 |
+
model=model_name,
|
105 |
+
messages=[{"role": "user", "content": bdi_instruction}],
|
106 |
+
temperature=0
|
107 |
+
)
|
108 |
+
|
109 |
+
bdi_response = response.choices[0].message.content
|
110 |
+
bdi_changes = extract_changes(bdi_response)
|
111 |
+
|
112 |
+
# 总结GHQ的变化
|
113 |
+
ghq_instruction = """
|
114 |
+
### 任务
|
115 |
+
根据量表的问题和答案,总结出两份量表之间的变化。
|
116 |
+
请列出明确的变化点,每个变化点单独一行,使用数字编号(1. 2. 3.)。
|
117 |
+
使用以下格式:
|
118 |
+
变化:
|
119 |
+
1. [第一个变化]
|
120 |
+
2. [第二个变化]
|
121 |
+
...
|
122 |
+
|
123 |
+
### 量表及问题
|
124 |
+
{}
|
125 |
+
|
126 |
+
### 第一份量表的答案
|
127 |
+
{}
|
128 |
+
|
129 |
+
### 第二份量表的答案
|
130 |
+
{}
|
131 |
+
""".format(ghq_scale, scales['p_ghq'], scales['ghq'])
|
132 |
+
|
133 |
+
response = client.chat.completions.create(
|
134 |
+
model=model_name,
|
135 |
+
messages=[{"role": "user", "content": ghq_instruction}],
|
136 |
+
temperature=0
|
137 |
+
)
|
138 |
+
|
139 |
+
ghq_response = response.choices[0].message.content
|
140 |
+
ghq_changes = extract_changes(ghq_response)
|
141 |
+
|
142 |
+
# 总结SASS的变化
|
143 |
+
sass_instruction = """
|
144 |
+
### 任务
|
145 |
+
根据量表的问题和答案,总结出两份量表之间的变化。
|
146 |
+
请列出明确的变化点,每个变化点单独一行,使用数字编号(1. 2. 3.)。
|
147 |
+
使用以下格式:
|
148 |
+
变化:
|
149 |
+
1. [第一个变化]
|
150 |
+
2. [第二个变化]
|
151 |
+
...
|
152 |
+
|
153 |
+
### 量表及问题
|
154 |
+
{}
|
155 |
+
|
156 |
+
### 第一份量表的答案
|
157 |
+
{}
|
158 |
+
|
159 |
+
### 第二份量表的答案
|
160 |
+
{}
|
161 |
+
""".format(sass_scale, scales['p_sass'], scales['sass'])
|
162 |
+
|
163 |
+
response = client.chat.completions.create(
|
164 |
+
model=model_name,
|
165 |
+
messages=[{"role": "user", "content": sass_instruction}],
|
166 |
+
temperature=0
|
167 |
+
)
|
168 |
+
|
169 |
+
sass_response = response.choices[0].message.content
|
170 |
+
sass_changes = extract_changes(sass_response)
|
171 |
+
|
172 |
+
return bdi_changes, ghq_changes, sass_changes
|
173 |
+
|
174 |
+
def summarize_scale_changes(scales):
|
175 |
+
client = OpenAI(
|
176 |
+
api_key=api_key,
|
177 |
+
base_url=base_url
|
178 |
+
)
|
179 |
+
|
180 |
+
# 获取量表变化
|
181 |
+
bdi_changes, ghq_changes, sass_changes = analyzing_changes(scales)
|
182 |
+
|
183 |
+
# 总结量表变化
|
184 |
+
summary_instruction = """
|
185 |
+
### 任务
|
186 |
+
根据量表的变化,总结患者的身体和心理状态变化。
|
187 |
+
请提供一个全面但简洁的总结,使用以下格式:
|
188 |
+
总结:
|
189 |
+
[总结内容]
|
190 |
+
|
191 |
+
### BDI量表变化
|
192 |
+
{}
|
193 |
+
|
194 |
+
### GHQ量表变化
|
195 |
+
{}
|
196 |
+
|
197 |
+
### SASS量表变化
|
198 |
+
{}
|
199 |
+
""".format(
|
200 |
+
'\n'.join([f"{i+1}. {change}" for i, change in enumerate(bdi_changes)]),
|
201 |
+
'\n'.join([f"{i+1}. {change}" for i, change in enumerate(ghq_changes)]),
|
202 |
+
'\n'.join([f"{i+1}. {change}" for i, change in enumerate(sass_changes)])
|
203 |
+
)
|
204 |
+
|
205 |
+
response = client.chat.completions.create(
|
206 |
+
model=model_name,
|
207 |
+
messages=[{"role": "user", "content": summary_instruction}],
|
208 |
+
temperature=0
|
209 |
+
)
|
210 |
+
|
211 |
+
summary_response = response.choices[0].message.content
|
212 |
+
status = extract_status(summary_response)
|
213 |
+
|
214 |
+
return status
|
215 |
+
|
216 |
+
# 额外增加一个更健壮的解析函数,可以处理不同格式的输出
|
217 |
+
def parse_response_robust(text, expected_format="list"):
|
218 |
+
"""更健壮的响应解析函数
|
219 |
+
|
220 |
+
参数:
|
221 |
+
text: 文本响应
|
222 |
+
expected_format: 预期格式,可以是"list"或"summary"
|
223 |
+
|
224 |
+
返回:
|
225 |
+
解析后的结果(列表或字符串)
|
226 |
+
"""
|
227 |
+
# 首先尝试JSON格式解析
|
228 |
+
try:
|
229 |
+
# 尝试提取JSON部分
|
230 |
+
json_pattern = r'\{[\s\S]*\}'
|
231 |
+
json_match = re.search(json_pattern, text)
|
232 |
+
if json_match:
|
233 |
+
json_data = json.loads(json_match.group(0))
|
234 |
+
if expected_format == "list" and "changes" in json_data:
|
235 |
+
return json_data["changes"]
|
236 |
+
elif expected_format == "summary" and "status" in json_data:
|
237 |
+
return json_data["status"]
|
238 |
+
except:
|
239 |
+
pass # 如果JSON解析失败,继续尝试其他方法
|
240 |
+
|
241 |
+
# 使用适当的提取函数
|
242 |
+
if expected_format == "list":
|
243 |
+
return extract_changes(text)
|
244 |
+
else: # summary
|
245 |
+
return extract_status(text)
|
246 |
+
|
247 |
+
# unit test
|
248 |
+
# if __name__ == "__main__":
|
249 |
+
# # 测试数据
|
250 |
+
# scales = {
|
251 |
+
# "p_bdi": ["A", "B", "C"],
|
252 |
+
# "bdi": ["B", "C", "D"],
|
253 |
+
# "p_ghq": ["A", "A", "B"],
|
254 |
+
# "ghq": ["B", "C", "C"],
|
255 |
+
# "p_sass": ["A", "B", "A"],
|
256 |
+
# "sass": ["C", "D", "B"]
|
257 |
+
# }
|
258 |
+
|
259 |
+
# changes = summarize_scale_changes(scales)
|
260 |
+
# print(changes)
|
src/style_analyzer.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from openai import OpenAI
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
|
5 |
+
# 设置OpenAI API密钥和基础URL
|
6 |
+
api_key = os.getenv("OPENAI_API_KEY")
|
7 |
+
base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
|
8 |
+
model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")
|
9 |
+
|
10 |
+
def analyze_style(profile, conversations):
|
11 |
+
client = OpenAI(
|
12 |
+
api_key=api_key,
|
13 |
+
base_url=base_url
|
14 |
+
)
|
15 |
+
# 提取患者信息
|
16 |
+
patient_info = f"### 患者信息\n年龄:{profile['age']}\n性别:{profile['gender']}\n职业:{profile['occupation']}\n婚姻状况:{profile['marital_status']}\n症状:{profile['symptoms']}"
|
17 |
+
# 提取对话记录
|
18 |
+
dialogue_history = "\n".join([f"{conv['role']}: {conv['content']}" for conv in conversations])
|
19 |
+
|
20 |
+
# 构建提示词,明确要求模型按特定格式输出结果
|
21 |
+
prompt = f"""### 任务
|
22 |
+
根据患者情况及咨访对话历史记录分析患者的说话风格。
|
23 |
+
|
24 |
+
{patient_info}
|
25 |
+
|
26 |
+
### 对话记录
|
27 |
+
{dialogue_history}
|
28 |
+
|
29 |
+
请分析患者的说话风格,最多列出5种风格特点。
|
30 |
+
请按以下格式输出结果:
|
31 |
+
说话风格:
|
32 |
+
1. [风格特点1]
|
33 |
+
2. [风格特点2]
|
34 |
+
3. [风格特点3]
|
35 |
+
...
|
36 |
+
|
37 |
+
只需要列出风格特点,不需要解释。
|
38 |
+
"""
|
39 |
+
|
40 |
+
response = client.chat.completions.create(
|
41 |
+
model=model_name,
|
42 |
+
messages=[
|
43 |
+
{"role": "user", "content": prompt}
|
44 |
+
]
|
45 |
+
)
|
46 |
+
|
47 |
+
# 从响应中提取说话风格列表
|
48 |
+
response_text = response.choices[0].message.content
|
49 |
+
|
50 |
+
# 使用正则表达式提取风格特点
|
51 |
+
# 匹配"说话风格:"之后的列表项
|
52 |
+
style_pattern = r"说话风格:\s*(?:\d+\.\s*([^\n]+)(?:\n|$))+"
|
53 |
+
match = re.search(style_pattern, response_text, re.DOTALL)
|
54 |
+
|
55 |
+
if match:
|
56 |
+
# 提取所有的列表项
|
57 |
+
style_items = re.findall(r"\d+\.\s*([^\n]+)", response_text)
|
58 |
+
return style_items
|
59 |
+
else:
|
60 |
+
# 如果没有按预期格式输出,尝试使用备用正则表达式
|
61 |
+
# 寻找任何可能的列表项
|
62 |
+
fallback_items = re.findall(r"(?:^|\n)(?:\d+[\.\)、]|[-•*])\s*([^\n]+)", response_text)
|
63 |
+
|
64 |
+
# 如果仍然没找到,尝试直接分割文本
|
65 |
+
if not fallback_items:
|
66 |
+
# 找到可能包含风格描述的行
|
67 |
+
potential_styles = [line.strip() for line in response_text.split('\n')
|
68 |
+
if line.strip() and not line.startswith('###') and ':' not in line]
|
69 |
+
return potential_styles[:5] # 最多返回5项
|
70 |
+
|
71 |
+
return fallback_items[:5] # 最多返回5项
|
72 |
+
|
73 |
+
# unit test
|
74 |
+
# profile = {
|
75 |
+
# "drisk": 3,
|
76 |
+
# "srisk": 2,
|
77 |
+
# "age": "42",
|
78 |
+
# "gender": "女",
|
79 |
+
# "marital_status": "离婚",
|
80 |
+
# "occupation": "教师",
|
81 |
+
# "symptoms": "缺乏自信心,自我价值感低,有自罪感,无望感;体重剧烈增加;精神运动性激越;有自杀想法"
|
82 |
+
# }
|
83 |
+
# conversations = [
|
84 |
+
# {"role": "user", "content": "我最近感觉很沮丧,似乎一切都没有意义。"},
|
85 |
+
# {"role": "assistant", "content": "你能具体说说是什么让你有这样的感觉吗?"},
|
86 |
+
# {"role": "user", "content": "我觉得自己在工作上总是做不好,没什么价值。"}
|
87 |
+
# ]
|
88 |
+
# print(analyze_style(profile, conversations))
|