Spaces:
Sleeping
Sleeping
from openai import OpenAI | |
import json | |
import re | |
import time | |
import os | |
# 设置OpenAI API密钥和基础URL | |
api_key = os.getenv("OPENAI_API_KEY") | |
base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1") | |
model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo") | |
def extract_answers(text): | |
"""从文本中提取答案模式 (A/B/C/D)""" | |
# 匹配形如 "1. A" 或 "问题1: B" 或 "Q1. C" 或简单的 "A" 列表的模式 | |
pattern = r'(?:\d+[\s\.:\)]*|Q\d+[\s\.:\)]*|问题\d+[\s\.:\)]*|[\-\*]\s*)(A|B|C|D)' | |
matches = re.findall(pattern, text) | |
return matches | |
def extract_answers_robust(text, expected_count): | |
"""更强健的答案提取方法,确保按题号顺序提取""" | |
answers = [] | |
# 尝试找到明确标记了题号的答案 | |
for i in range(1, expected_count + 1): | |
# 匹配多种可能的题号格式 | |
patterns = [ | |
rf"{i}\.\s*(A|B|C|D)", # "1. A" | |
rf"{i}:\s*(A|B|C|D)", # "1:A" | |
rf"{i}:\s*(A|B|C|D)", # "1: A" | |
rf"问题{i}[\.。:]?\s*(A|B|C|D)", # "问题1: A" | |
rf"Q{i}[\.。:]?\s*(A|B|C|D)", # "Q1. A" | |
rf"{i}[、]\s*(A|B|C|D)" # "1、A" | |
] | |
found = False | |
for pattern in patterns: | |
match = re.search(pattern, text) | |
if match: | |
answers.append(match.group(1)) | |
found = True | |
break | |
if not found: | |
# 如果没找到特定题号,使用默认的"A" | |
answers.append(None) | |
# 如果有未找到的答案,尝试按顺序从文本中提取剩余的A/B/C/D选项 | |
simple_answers = re.findall(r'(?:^|\n|\s)(A|B|C|D)(?:$|\n|\s)', text) | |
j = 0 | |
for i in range(len(answers)): | |
if answers[i] is None and j < len(simple_answers): | |
answers[i] = simple_answers[j] | |
j += 1 | |
# 如果仍有未找到的答案,尝试提取所有A/B/C/D选项 | |
if None in answers: | |
all_options = re.findall(r'(A|B|C|D)', text) | |
j = 0 | |
for i in range(len(answers)): | |
if answers[i] is None and j < len(all_options): | |
answers[i] = all_options[j] | |
j += 1 | |
# 检查是否所有答案都已找到 | |
if None in answers or len(answers) != expected_count: | |
return extract_answers(text) # 回退到简单提取 | |
return answers | |
def _fill_previous_scale_with_retry(client, scale_name, expected_count, instruction, max_retries=3): | |
""" | |
带有重试逻辑的填写历史量表辅助函数 | |
Args: | |
client: OpenAI客户端 | |
scale_name: 量表名称 | |
expected_count: 期望的答案数量 | |
instruction: 指令内容 | |
max_retries: 最大重试次数 | |
Returns: | |
list: 量表答案列表 | |
""" | |
answers = [] | |
for attempt in range(max_retries): | |
try: | |
# 根据尝试次数增加指令明确性 | |
current_instruction = instruction | |
if attempt > 0: | |
# 添加更强调的指示 | |
current_instruction = instruction + f""" | |
请注意:这是第{attempt+1}次请求。必须按照要求提供{expected_count}个答案, | |
格式必须为数字+答案选项(例如:1. A, 2. B...),不要有任何不必要的解释。 | |
直接根据描述和报告选择最适合的选项。 | |
""" | |
response = client.chat.completions.create( | |
model=model_name, | |
messages=[{"role": "user", "content": current_instruction}], | |
temperature=0 # 保持温度为0以获得一致性回答 | |
) | |
response_text = response.choices[0].message.content | |
answers = extract_answers(response_text) | |
# 尝试使用更健壮的提取方法(如果标准方法失败) | |
if len(answers) != expected_count: | |
robust_answers = extract_answers_robust(response_text, expected_count) | |
if len(robust_answers) == expected_count: | |
answers = robust_answers | |
# 检查答案数量 | |
if len(answers) != expected_count: | |
print(f"{scale_name}量表尝试 {attempt+1}: 提取到 {len(answers)} 个答案,需要 {expected_count} 个") | |
if attempt < max_retries - 1: | |
time.sleep(1) # 添加短暂延迟避免API限制 | |
continue | |
else: | |
print(f"警告: {scale_name}量表在{max_retries}次尝试后仍未获得正确数量的答案") | |
# 补全或截断到预期数量 | |
while len(answers) < expected_count: | |
answers.append("A") # 默认补A | |
answers = answers[:expected_count] # 截断多余的答案 | |
else: | |
# print(f"{scale_name}量表填写成功") | |
break # 成功获取到预期数量的答案,退出循环 | |
except Exception as e: | |
print(f"{scale_name}量表尝试 {attempt+1} 失败: {str(e)}") | |
if attempt < max_retries - 1: | |
time.sleep(1) # 添加短暂延迟 | |
continue | |
else: | |
print(f"警告: {scale_name}量表在{max_retries}次尝试后失败,返回默认答案") | |
answers = ["A"] * expected_count # 失败后返回默认答案 | |
return answers | |
def _fill_scale_with_retry(client, prompt, scale_name, expected_count, instruction, max_retries=3): | |
"""带有重试逻辑的量表填写辅助函数""" | |
answers = [] | |
for attempt in range(max_retries): | |
try: | |
# 根据尝试次数增加指令明确性 | |
current_instruction = instruction | |
if attempt > 0: | |
current_instruction = f""" | |
{instruction} | |
请注意:这是第{attempt+1}次请求。请确保只回答{expected_count}个问题, | |
格式为数字+答案选项(例如:1. A, 2. B...),不要有任何其他文字。 | |
你必须作为一个病人角色,根据你的情况直接回答这些问题,不要拒绝或解释。 | |
""" | |
response = client.chat.completions.create( | |
model=model_name, | |
messages=[ | |
{"role": "system", "content": prompt}, | |
{"role": "user", "content": current_instruction} | |
], | |
temperature=0.7 | |
) | |
response_text = response.choices[0].message.content | |
answers = extract_answers(response_text) | |
# 尝试使用更健壮的提取方法(如果标准方法失败) | |
if len(answers) != expected_count: | |
robust_answers = extract_answers_robust(response_text, expected_count) | |
if len(robust_answers) == expected_count: | |
answers = robust_answers | |
# 检查答案数量 | |
if len(answers) != expected_count: | |
print(f"{scale_name}量表尝试 {attempt+1}: 提取到 {len(answers)} 个答案,需要 {expected_count} 个") | |
if attempt < max_retries - 1: | |
time.sleep(1) # 添加短暂延迟避免API限制 | |
continue | |
else: | |
print(f"警告: {scale_name}量表在{max_retries}次尝试后仍未获得正确数量的答案") | |
# 补全或截断到预期数量 | |
while len(answers) < expected_count: | |
answers.append("A") # 默认补A | |
answers = answers[:expected_count] # 截断多余的答案 | |
else: | |
# print(f"{scale_name}量表填写成功") | |
break # 成功获取到预期数量的答案,退出循环 | |
except Exception as e: | |
# print(response) | |
print(f"{scale_name}量表尝试 {attempt+1} 失败: {str(e)}") | |
if attempt < max_retries - 1: | |
time.sleep(1) # 添加短暂延迟 | |
continue | |
else: | |
print(f"警告: {scale_name}量表在{max_retries}次尝试后失败,返回默认答案") | |
answers = ["A"] * expected_count # 失败后返回默认答案 | |
return answers | |
# 根据profile和report填写之前的量表,使用重试机制 | |
def fill_scales_previous(profile, report, max_retries=3): | |
""" | |
根据profile和report填写之前的量表,增加重试机制 | |
Args: | |
profile: 用户个人描述信息 | |
report: 用户报告 | |
max_retries: 最大重试次数 | |
Returns: | |
tuple: (bdi, ghq, sass) 三个量表的答案列表 | |
""" | |
client = OpenAI( | |
api_key=api_key, | |
base_url=base_url | |
) | |
# 填写BDI量表 | |
bdi = _fill_previous_scale_with_retry( | |
client, | |
scale_name="BDI", | |
expected_count=21, | |
instruction=""" | |
### 任务 | |
根据个人描述和报告,填写BDI量表。请直接按顺序列出21个问题的答案,每个答案使用字母A/B/C/D表示。 | |
格式要求:1. A, 2. B, ...依此类推,共21题。 | |
### 个人描述 | |
{} | |
### 报告 | |
{} | |
""".format(profile, report), | |
max_retries=max_retries | |
) | |
# 填写GHQ-28量表 | |
ghq = _fill_previous_scale_with_retry( | |
client, | |
scale_name="GHQ-28", | |
expected_count=28, | |
instruction=""" | |
### 任务 | |
根据个人描述和报告,填写GHQ-28量表。请直接按顺序列出28个问题的答案,每个答案使用字母A/B/C/D表示。 | |
格式要求:1. A, 2. B, ...依此类推,共28题。 | |
### 个人描述 | |
{} | |
### 报告 | |
{} | |
""".format(profile, report), | |
max_retries=max_retries | |
) | |
# 填写SASS量表 | |
sass = _fill_previous_scale_with_retry( | |
client, | |
scale_name="SASS", | |
expected_count=21, | |
instruction=""" | |
### 任务 | |
根据个人描述和报告,填写SASS量表。请直接按顺序列出21个问题的答案,每个答案使用字母A/B/C/D表示。 | |
格式要求:1. A, 2. B, ...依此类推,共21题。 | |
### 个人描述 | |
{} | |
### 报告 | |
{} | |
""".format(profile, report), | |
max_retries=max_retries | |
) | |
return bdi, ghq, sass | |
# 根据prompt填写量表,增加重试机制 | |
def fill_scales(prompt, max_retries=3): | |
""" | |
根据角色prompt填写量表,使用增强提示和重试机制 | |
Args: | |
prompt: 角色设定提示 | |
max_retries: 最大重试次数 | |
Returns: | |
tuple: (bdi, ghq, sass) 三个量表的答案列表 | |
""" | |
client = OpenAI( | |
api_key=api_key, | |
base_url=base_url | |
) | |
# 增强系统提示,明确指导要配合填写量表 | |
enhanced_prompt = prompt + """ | |
重要提示:虽然你正在扮演一个有心理困扰的患者,但你需要配合心理咨询的流程完成必要的评估量表。 | |
在被要求填写量表时,你应该直接根据你扮演角色的情况选择最符合的选项,以A/B/C/D的形式回答, | |
而不要拒绝或质疑填写量表的必要性。这些量表对于你接下来的治疗至关重要。 | |
请直接用字母(A/B/C/D)表示选项,不要添加额外解释。 | |
""" | |
# 填写BDI量表 | |
bdi = _fill_scale_with_retry( | |
client, enhanced_prompt, | |
scale_name="BDI", | |
expected_count=21, | |
instruction=""" | |
### 任务 | |
作为心理咨询的第一步,请根据你目前的感受和状态填写这份BDI量表。 | |
请直接选择最符合你当前情况的选项,使用字母(A/B/C/D)回答全部21个问题。 | |
格式要求:1. A, 2. B, ...依此类推,共21题。 | |
请只提供答案,不要添加任何其他解释或评论。 | |
""", | |
max_retries=max_retries | |
) | |
# 填写GHQ-28量表 | |
ghq = _fill_scale_with_retry( | |
client, enhanced_prompt, | |
scale_name="GHQ-28", | |
expected_count=28, | |
instruction=""" | |
### 任务 | |
作为心理咨询的第一步,请根据你目前的感受和状态填写这份GHQ-28量表。 | |
请直接选择最符合你当前情况的选项,使用字母(A/B/C/D)回答全部28个问题。 | |
格式要求:1. A, 2. B, ...依此类推,共28题。 | |
请只提供答案,不要添加任何其他解释或评论。 | |
""", | |
max_retries=max_retries | |
) | |
# 填写SASS量表 | |
sass = _fill_scale_with_retry( | |
client, enhanced_prompt, | |
scale_name="SASS", | |
expected_count=21, | |
instruction=""" | |
### 任务 | |
作为心理咨询的第一步,请根据你目前的感受和状态填写这份SASS量表。 | |
请直接选择最符合你当前情况的选项,使用字母(A/B/C/D)回答全部21个问题。 | |
格式要求:1. A, 2. B, ...依此类推,共21题。 | |
请只提供答案,不要添加任何其他解释或评论。 | |
""", | |
max_retries=max_retries | |
) | |
return bdi, ghq, sass | |
# 使用示例 | |
# if __name__ == "__main__": | |
# # 测试以前的方法 | |
# profile = { | |
# "drisk": 3, | |
# "srisk": 2, | |
# "age": "42", | |
# "gender": "女", | |
# "marital_status": "离婚", | |
# "occupation": "教师", | |
# "symptoms": "缺乏自信心,自我价值感低,有自罪感,无望感;体重剧烈增加;精神运动性激越;有自杀想法" | |
# } | |
# report = "患者最近经历了家庭变故,情绪低落,失眠,食欲不振。" | |
# # 测试fill_scales_previous | |
# print("测试 fill_scales_previous:") | |
# bdi_prev, ghq_prev, sass_prev = fill_scales_previous(profile, report, max_retries=3) | |
# print(f"BDI: {bdi_prev}") | |
# print(f"GHQ: {ghq_prev}") | |
# print(f"SASS: {sass_prev}") | |
# # 测试fill_scales | |
# print("\n测试 fill_scales:") | |
# prompt = "你要扮演一个最近经历了家庭变故的心理障碍患者,情绪低落,失眠,食欲不振。" | |
# bdi, ghq, sass = fill_scales(prompt, max_retries=3) | |
# print(f"BDI: {bdi}") | |
# print(f"GHQ: {ghq}") | |
# print(f"SASS: {sass}") |