Spaces:
Sleeping
Sleeping
from openai import OpenAI | |
import json | |
import re | |
import os | |
# 设置OpenAI API密钥和基础URL | |
api_key = os.getenv("OPENAI_API_KEY") | |
base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1") | |
model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo") | |
def extract_changes(text): | |
"""从文本中提取变化列表""" | |
# 首先尝试查找明确的变化列表格式 | |
# 例如: "变化:\n1. xxx\n2. yyy" | |
list_pattern = r'((?:(?:\d+\.|\-|\*)\s*[^\n]+\n?)+)' | |
# 尝试匹配带有明确标记的变化列表 | |
change_section = re.search(r'(?:变化(?:列表)?|总结(?:如下)?)[::]\s*([\s\S]+)$', text) | |
if change_section: | |
section_text = change_section.group(1).strip() | |
# 尝试匹配列表项 | |
list_items = re.findall(r'(?:(?:\d+\.|\-|\*)\s*)([^\n]+)', section_text) | |
if list_items: | |
return list_items | |
# 如果没有明确的列表格式,尝试按行分割 | |
lines = [line.strip() for line in section_text.split('\n') if line.strip()] | |
if lines: | |
return lines | |
# 尝试直接从文本中提取列表格式 | |
list_matches = re.findall(list_pattern, text) | |
if list_matches: | |
all_items = [] | |
for match in list_matches: | |
items = re.findall(r'(?:(?:\d+\.|\-|\*)\s*)([^\n]+)', match) | |
all_items.extend(items) | |
if all_items: | |
return all_items | |
# 如果没有列表格式,尝试按句子分割 | |
sentences = re.findall(r'([^.!?]+[.!?])', text) | |
if sentences: | |
return [s.strip() for s in sentences if len(s.strip()) > 10] # 过滤掉过短的句子 | |
# 最后的回退:按段落分割 | |
paragraphs = text.split('\n\n') | |
if len(paragraphs) > 1: | |
return [p.strip() for p in paragraphs if len(p.strip()) > 10] | |
# 如果所有方法都失败,返回完整文本作为单个变化 | |
return [text.strip()] if text.strip() else [] | |
def extract_status(text): | |
"""从文本中提取患者状态总结""" | |
# 寻找明确标记的总结部分 | |
status_section = re.search(r'(?:总结|状态|变化|结论)[::]\s*([\s\S]+)$', text) | |
if status_section: | |
return status_section.group(1).strip() | |
# 如果没有明确的总结标记,尝试返回完整文本 | |
# 过滤掉可能的指令解释部分 | |
clean_text = re.sub(r'^.*?(?:根据|基于).*?[,,。]', '', text, flags=re.DOTALL) | |
# 移除可能的前导分析部分 | |
clean_text = re.sub(r'^.*?(?:分析|查看|判断).*?\n\n', '', clean_text, flags=re.DOTALL) | |
return clean_text.strip() | |
def analyzing_changes(scales): | |
client = OpenAI( | |
api_key=api_key, | |
base_url=base_url | |
) | |
# 导入量表及问题 | |
bdi_scale = json.load(open("./scales/bdi.json", "r")) | |
ghq_scale = json.load(open("./scales/ghq-28.json", "r")) | |
sass_scale = json.load(open("./scales/sass.json", "r")) | |
# 总结BDI的变化 | |
bdi_instruction = """ | |
### 任务 | |
根据量表的问题和答案,总结出两份量表之间的变化。 | |
请列出明确的变化点,每个变化点单独一行,使用数字编号(1. 2. 3.)。 | |
使用以下格式: | |
变化: | |
1. [第一个变化] | |
2. [第二个变化] | |
... | |
### 量表及问题 | |
{} | |
### 第一份量表的答案 | |
{} | |
### 第二份量表的答案 | |
{} | |
""".format(bdi_scale, scales['p_bdi'], scales['bdi']) | |
response = client.chat.completions.create( | |
model=model_name, | |
messages=[{"role": "user", "content": bdi_instruction}], | |
temperature=0 | |
) | |
bdi_response = response.choices[0].message.content | |
bdi_changes = extract_changes(bdi_response) | |
# 总结GHQ的变化 | |
ghq_instruction = """ | |
### 任务 | |
根据量表的问题和答案,总结出两份量表之间的变化。 | |
请列出明确的变化点,每个变化点单独一行,使用数字编号(1. 2. 3.)。 | |
使用以下格式: | |
变化: | |
1. [第一个变化] | |
2. [第二个变化] | |
... | |
### 量表及问题 | |
{} | |
### 第一份量表的答案 | |
{} | |
### 第二份量表的答案 | |
{} | |
""".format(ghq_scale, scales['p_ghq'], scales['ghq']) | |
response = client.chat.completions.create( | |
model=model_name, | |
messages=[{"role": "user", "content": ghq_instruction}], | |
temperature=0 | |
) | |
ghq_response = response.choices[0].message.content | |
ghq_changes = extract_changes(ghq_response) | |
# 总结SASS的变化 | |
sass_instruction = """ | |
### 任务 | |
根据量表的问题和答案,总结出两份量表之间的变化。 | |
请列出明确的变化点,每个变化点单独一行,使用数字编号(1. 2. 3.)。 | |
使用以下格式: | |
变化: | |
1. [第一个变化] | |
2. [第二个变化] | |
... | |
### 量表及问题 | |
{} | |
### 第一份量表的答案 | |
{} | |
### 第二份量表的答案 | |
{} | |
""".format(sass_scale, scales['p_sass'], scales['sass']) | |
response = client.chat.completions.create( | |
model=model_name, | |
messages=[{"role": "user", "content": sass_instruction}], | |
temperature=0 | |
) | |
sass_response = response.choices[0].message.content | |
sass_changes = extract_changes(sass_response) | |
return bdi_changes, ghq_changes, sass_changes | |
def summarize_scale_changes(scales): | |
client = OpenAI( | |
api_key=api_key, | |
base_url=base_url | |
) | |
# 获取量表变化 | |
bdi_changes, ghq_changes, sass_changes = analyzing_changes(scales) | |
# 总结量表变化 | |
summary_instruction = """ | |
### 任务 | |
根据量表的变化,总结患者的身体和心理状态变化。 | |
请提供一个全面但简洁的总结,使用以下格式: | |
总结: | |
[总结内容] | |
### BDI量表变化 | |
{} | |
### GHQ量表变化 | |
{} | |
### SASS量表变化 | |
{} | |
""".format( | |
'\n'.join([f"{i+1}. {change}" for i, change in enumerate(bdi_changes)]), | |
'\n'.join([f"{i+1}. {change}" for i, change in enumerate(ghq_changes)]), | |
'\n'.join([f"{i+1}. {change}" for i, change in enumerate(sass_changes)]) | |
) | |
response = client.chat.completions.create( | |
model=model_name, | |
messages=[{"role": "user", "content": summary_instruction}], | |
temperature=0 | |
) | |
summary_response = response.choices[0].message.content | |
status = extract_status(summary_response) | |
return status | |
# 额外增加一个更健壮的解析函数,可以处理不同格式的输出 | |
def parse_response_robust(text, expected_format="list"): | |
"""更健壮的响应解析函数 | |
参数: | |
text: 文本响应 | |
expected_format: 预期格式,可以是"list"或"summary" | |
返回: | |
解析后的结果(列表或字符串) | |
""" | |
# 首先尝试JSON格式解析 | |
try: | |
# 尝试提取JSON部分 | |
json_pattern = r'\{[\s\S]*\}' | |
json_match = re.search(json_pattern, text) | |
if json_match: | |
json_data = json.loads(json_match.group(0)) | |
if expected_format == "list" and "changes" in json_data: | |
return json_data["changes"] | |
elif expected_format == "summary" and "status" in json_data: | |
return json_data["status"] | |
except: | |
pass # 如果JSON解析失败,继续尝试其他方法 | |
# 使用适当的提取函数 | |
if expected_format == "list": | |
return extract_changes(text) | |
else: # summary | |
return extract_status(text) | |
# unit test | |
# if __name__ == "__main__": | |
# # 测试数据 | |
# scales = { | |
# "p_bdi": ["A", "B", "C"], | |
# "bdi": ["B", "C", "D"], | |
# "p_ghq": ["A", "A", "B"], | |
# "ghq": ["B", "C", "C"], | |
# "p_sass": ["A", "B", "A"], | |
# "sass": ["C", "D", "B"] | |
# } | |
# changes = summarize_scale_changes(scales) | |
# print(changes) |