AnnaAgent-Demo / src /fill_scales.py
sci-m-wang's picture
Upload 14 files
1d4c295 verified
raw
history blame
14.8 kB
from openai import OpenAI
import json
import re
import time
import os
# 设置OpenAI API密钥和基础URL
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")
def extract_answers(text):
"""从文本中提取答案模式 (A/B/C/D)"""
# 匹配形如 "1. A" 或 "问题1: B" 或 "Q1. C" 或简单的 "A" 列表的模式
pattern = r'(?:\d+[\s\.:\)]*|Q\d+[\s\.:\)]*|问题\d+[\s\.:\)]*|[\-\*]\s*)(A|B|C|D)'
matches = re.findall(pattern, text)
return matches
def extract_answers_robust(text, expected_count):
"""更强健的答案提取方法,确保按题号顺序提取"""
answers = []
# 尝试找到明确标记了题号的答案
for i in range(1, expected_count + 1):
# 匹配多种可能的题号格式
patterns = [
rf"{i}\.\s*(A|B|C|D)", # "1. A"
rf"{i}:\s*(A|B|C|D)", # "1:A"
rf"{i}:\s*(A|B|C|D)", # "1: A"
rf"问题{i}[\.。:]?\s*(A|B|C|D)", # "问题1: A"
rf"Q{i}[\.。:]?\s*(A|B|C|D)", # "Q1. A"
rf"{i}[、]\s*(A|B|C|D)" # "1、A"
]
found = False
for pattern in patterns:
match = re.search(pattern, text)
if match:
answers.append(match.group(1))
found = True
break
if not found:
# 如果没找到特定题号,使用默认的"A"
answers.append(None)
# 如果有未找到的答案,尝试按顺序从文本中提取剩余的A/B/C/D选项
simple_answers = re.findall(r'(?:^|\n|\s)(A|B|C|D)(?:$|\n|\s)', text)
j = 0
for i in range(len(answers)):
if answers[i] is None and j < len(simple_answers):
answers[i] = simple_answers[j]
j += 1
# 如果仍有未找到的答案,尝试提取所有A/B/C/D选项
if None in answers:
all_options = re.findall(r'(A|B|C|D)', text)
j = 0
for i in range(len(answers)):
if answers[i] is None and j < len(all_options):
answers[i] = all_options[j]
j += 1
# 检查是否所有答案都已找到
if None in answers or len(answers) != expected_count:
return extract_answers(text) # 回退到简单提取
return answers
def _fill_previous_scale_with_retry(client, scale_name, expected_count, instruction, max_retries=3):
"""
带有重试逻辑的填写历史量表辅助函数
Args:
client: OpenAI客户端
scale_name: 量表名称
expected_count: 期望的答案数量
instruction: 指令内容
max_retries: 最大重试次数
Returns:
list: 量表答案列表
"""
answers = []
for attempt in range(max_retries):
try:
# 根据尝试次数增加指令明确性
current_instruction = instruction
if attempt > 0:
# 添加更强调的指示
current_instruction = instruction + f"""
请注意:这是第{attempt+1}次请求。必须按照要求提供{expected_count}个答案,
格式必须为数字+答案选项(例如:1. A, 2. B...),不要有任何不必要的解释。
直接根据描述和报告选择最适合的选项。
"""
response = client.chat.completions.create(
model=model_name,
messages=[{"role": "user", "content": current_instruction}],
temperature=0 # 保持温度为0以获得一致性回答
)
response_text = response.choices[0].message.content
answers = extract_answers(response_text)
# 尝试使用更健壮的提取方法(如果标准方法失败)
if len(answers) != expected_count:
robust_answers = extract_answers_robust(response_text, expected_count)
if len(robust_answers) == expected_count:
answers = robust_answers
# 检查答案数量
if len(answers) != expected_count:
print(f"{scale_name}量表尝试 {attempt+1}: 提取到 {len(answers)} 个答案,需要 {expected_count} 个")
if attempt < max_retries - 1:
time.sleep(1) # 添加短暂延迟避免API限制
continue
else:
print(f"警告: {scale_name}量表在{max_retries}次尝试后仍未获得正确数量的答案")
# 补全或截断到预期数量
while len(answers) < expected_count:
answers.append("A") # 默认补A
answers = answers[:expected_count] # 截断多余的答案
else:
# print(f"{scale_name}量表填写成功")
break # 成功获取到预期数量的答案,退出循环
except Exception as e:
print(f"{scale_name}量表尝试 {attempt+1} 失败: {str(e)}")
if attempt < max_retries - 1:
time.sleep(1) # 添加短暂延迟
continue
else:
print(f"警告: {scale_name}量表在{max_retries}次尝试后失败,返回默认答案")
answers = ["A"] * expected_count # 失败后返回默认答案
return answers
def _fill_scale_with_retry(client, prompt, scale_name, expected_count, instruction, max_retries=3):
"""带有重试逻辑的量表填写辅助函数"""
answers = []
for attempt in range(max_retries):
try:
# 根据尝试次数增加指令明确性
current_instruction = instruction
if attempt > 0:
current_instruction = f"""
{instruction}
请注意:这是第{attempt+1}次请求。请确保只回答{expected_count}个问题,
格式为数字+答案选项(例如:1. A, 2. B...),不要有任何其他文字。
你必须作为一个病人角色,根据你的情况直接回答这些问题,不要拒绝或解释。
"""
response = client.chat.completions.create(
model=model_name,
messages=[
{"role": "system", "content": prompt},
{"role": "user", "content": current_instruction}
],
temperature=0.7
)
response_text = response.choices[0].message.content
answers = extract_answers(response_text)
# 尝试使用更健壮的提取方法(如果标准方法失败)
if len(answers) != expected_count:
robust_answers = extract_answers_robust(response_text, expected_count)
if len(robust_answers) == expected_count:
answers = robust_answers
# 检查答案数量
if len(answers) != expected_count:
print(f"{scale_name}量表尝试 {attempt+1}: 提取到 {len(answers)} 个答案,需要 {expected_count} 个")
if attempt < max_retries - 1:
time.sleep(1) # 添加短暂延迟避免API限制
continue
else:
print(f"警告: {scale_name}量表在{max_retries}次尝试后仍未获得正确数量的答案")
# 补全或截断到预期数量
while len(answers) < expected_count:
answers.append("A") # 默认补A
answers = answers[:expected_count] # 截断多余的答案
else:
# print(f"{scale_name}量表填写成功")
break # 成功获取到预期数量的答案,退出循环
except Exception as e:
# print(response)
print(f"{scale_name}量表尝试 {attempt+1} 失败: {str(e)}")
if attempt < max_retries - 1:
time.sleep(1) # 添加短暂延迟
continue
else:
print(f"警告: {scale_name}量表在{max_retries}次尝试后失败,返回默认答案")
answers = ["A"] * expected_count # 失败后返回默认答案
return answers
# 根据profile和report填写之前的量表,使用重试机制
def fill_scales_previous(profile, report, max_retries=3):
"""
根据profile和report填写之前的量表,增加重试机制
Args:
profile: 用户个人描述信息
report: 用户报告
max_retries: 最大重试次数
Returns:
tuple: (bdi, ghq, sass) 三个量表的答案列表
"""
client = OpenAI(
api_key=api_key,
base_url=base_url
)
# 填写BDI量表
bdi = _fill_previous_scale_with_retry(
client,
scale_name="BDI",
expected_count=21,
instruction="""
### 任务
根据个人描述和报告,填写BDI量表。请直接按顺序列出21个问题的答案,每个答案使用字母A/B/C/D表示。
格式要求:1. A, 2. B, ...依此类推,共21题。
### 个人描述
{}
### 报告
{}
""".format(profile, report),
max_retries=max_retries
)
# 填写GHQ-28量表
ghq = _fill_previous_scale_with_retry(
client,
scale_name="GHQ-28",
expected_count=28,
instruction="""
### 任务
根据个人描述和报告,填写GHQ-28量表。请直接按顺序列出28个问题的答案,每个答案使用字母A/B/C/D表示。
格式要求:1. A, 2. B, ...依此类推,共28题。
### 个人描述
{}
### 报告
{}
""".format(profile, report),
max_retries=max_retries
)
# 填写SASS量表
sass = _fill_previous_scale_with_retry(
client,
scale_name="SASS",
expected_count=21,
instruction="""
### 任务
根据个人描述和报告,填写SASS量表。请直接按顺序列出21个问题的答案,每个答案使用字母A/B/C/D表示。
格式要求:1. A, 2. B, ...依此类推,共21题。
### 个人描述
{}
### 报告
{}
""".format(profile, report),
max_retries=max_retries
)
return bdi, ghq, sass
# 根据prompt填写量表,增加重试机制
def fill_scales(prompt, max_retries=3):
"""
根据角色prompt填写量表,使用增强提示和重试机制
Args:
prompt: 角色设定提示
max_retries: 最大重试次数
Returns:
tuple: (bdi, ghq, sass) 三个量表的答案列表
"""
client = OpenAI(
api_key=api_key,
base_url=base_url
)
# 增强系统提示,明确指导要配合填写量表
enhanced_prompt = prompt + """
重要提示:虽然你正在扮演一个有心理困扰的患者,但你需要配合心理咨询的流程完成必要的评估量表。
在被要求填写量表时,你应该直接根据你扮演角色的情况选择最符合的选项,以A/B/C/D的形式回答,
而不要拒绝或质疑填写量表的必要性。这些量表对于你接下来的治疗至关重要。
请直接用字母(A/B/C/D)表示选项,不要添加额外解释。
"""
# 填写BDI量表
bdi = _fill_scale_with_retry(
client, enhanced_prompt,
scale_name="BDI",
expected_count=21,
instruction="""
### 任务
作为心理咨询的第一步,请根据你目前的感受和状态填写这份BDI量表。
请直接选择最符合你当前情况的选项,使用字母(A/B/C/D)回答全部21个问题。
格式要求:1. A, 2. B, ...依此类推,共21题。
请只提供答案,不要添加任何其他解释或评论。
""",
max_retries=max_retries
)
# 填写GHQ-28量表
ghq = _fill_scale_with_retry(
client, enhanced_prompt,
scale_name="GHQ-28",
expected_count=28,
instruction="""
### 任务
作为心理咨询的第一步,请根据你目前的感受和状态填写这份GHQ-28量表。
请直接选择最符合你当前情况的选项,使用字母(A/B/C/D)回答全部28个问题。
格式要求:1. A, 2. B, ...依此类推,共28题。
请只提供答案,不要添加任何其他解释或评论。
""",
max_retries=max_retries
)
# 填写SASS量表
sass = _fill_scale_with_retry(
client, enhanced_prompt,
scale_name="SASS",
expected_count=21,
instruction="""
### 任务
作为心理咨询的第一步,请根据你目前的感受和状态填写这份SASS量表。
请直接选择最符合你当前情况的选项,使用字母(A/B/C/D)回答全部21个问题。
格式要求:1. A, 2. B, ...依此类推,共21题。
请只提供答案,不要添加任何其他解释或评论。
""",
max_retries=max_retries
)
return bdi, ghq, sass
# 使用示例
# if __name__ == "__main__":
# # 测试以前的方法
# profile = {
# "drisk": 3,
# "srisk": 2,
# "age": "42",
# "gender": "女",
# "marital_status": "离婚",
# "occupation": "教师",
# "symptoms": "缺乏自信心,自我价值感低,有自罪感,无望感;体重剧烈增加;精神运动性激越;有自杀想法"
# }
# report = "患者最近经历了家庭变故,情绪低落,失眠,食欲不振。"
# # 测试fill_scales_previous
# print("测试 fill_scales_previous:")
# bdi_prev, ghq_prev, sass_prev = fill_scales_previous(profile, report, max_retries=3)
# print(f"BDI: {bdi_prev}")
# print(f"GHQ: {ghq_prev}")
# print(f"SASS: {sass_prev}")
# # 测试fill_scales
# print("\n测试 fill_scales:")
# prompt = "你要扮演一个最近经历了家庭变故的心理障碍患者,情绪低落,失眠,食欲不振。"
# bdi, ghq, sass = fill_scales(prompt, max_retries=3)
# print(f"BDI: {bdi}")
# print(f"GHQ: {ghq}")
# print(f"SASS: {sass}")