AnnaAgent-Demo / src /short_term_memory.py
sci-m-wang's picture
Upload 14 files
1d4c295 verified
raw
history blame
8 kB
from openai import OpenAI
import json
import re
import os
# 设置OpenAI API密钥和基础URL
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")
def extract_changes(text):
"""从文本中提取变化列表"""
# 首先尝试查找明确的变化列表格式
# 例如: "变化:\n1. xxx\n2. yyy"
list_pattern = r'((?:(?:\d+\.|\-|\*)\s*[^\n]+\n?)+)'
# 尝试匹配带有明确标记的变化列表
change_section = re.search(r'(?:变化(?:列表)?|总结(?:如下)?)[::]\s*([\s\S]+)$', text)
if change_section:
section_text = change_section.group(1).strip()
# 尝试匹配列表项
list_items = re.findall(r'(?:(?:\d+\.|\-|\*)\s*)([^\n]+)', section_text)
if list_items:
return list_items
# 如果没有明确的列表格式,尝试按行分割
lines = [line.strip() for line in section_text.split('\n') if line.strip()]
if lines:
return lines
# 尝试直接从文本中提取列表格式
list_matches = re.findall(list_pattern, text)
if list_matches:
all_items = []
for match in list_matches:
items = re.findall(r'(?:(?:\d+\.|\-|\*)\s*)([^\n]+)', match)
all_items.extend(items)
if all_items:
return all_items
# 如果没有列表格式,尝试按句子分割
sentences = re.findall(r'([^.!?]+[.!?])', text)
if sentences:
return [s.strip() for s in sentences if len(s.strip()) > 10] # 过滤掉过短的句子
# 最后的回退:按段落分割
paragraphs = text.split('\n\n')
if len(paragraphs) > 1:
return [p.strip() for p in paragraphs if len(p.strip()) > 10]
# 如果所有方法都失败,返回完整文本作为单个变化
return [text.strip()] if text.strip() else []
def extract_status(text):
"""从文本中提取患者状态总结"""
# 寻找明确标记的总结部分
status_section = re.search(r'(?:总结|状态|变化|结论)[::]\s*([\s\S]+)$', text)
if status_section:
return status_section.group(1).strip()
# 如果没有明确的总结标记,尝试返回完整文本
# 过滤掉可能的指令解释部分
clean_text = re.sub(r'^.*?(?:根据|基于).*?[,,。]', '', text, flags=re.DOTALL)
# 移除可能的前导分析部分
clean_text = re.sub(r'^.*?(?:分析|查看|判断).*?\n\n', '', clean_text, flags=re.DOTALL)
return clean_text.strip()
def analyzing_changes(scales):
client = OpenAI(
api_key=api_key,
base_url=base_url
)
# 导入量表及问题
bdi_scale = json.load(open("./scales/bdi.json", "r"))
ghq_scale = json.load(open("./scales/ghq-28.json", "r"))
sass_scale = json.load(open("./scales/sass.json", "r"))
# 总结BDI的变化
bdi_instruction = """
### 任务
根据量表的问题和答案,总结出两份量表之间的变化。
请列出明确的变化点,每个变化点单独一行,使用数字编号(1. 2. 3.)。
使用以下格式:
变化:
1. [第一个变化]
2. [第二个变化]
...
### 量表及问题
{}
### 第一份量表的答案
{}
### 第二份量表的答案
{}
""".format(bdi_scale, scales['p_bdi'], scales['bdi'])
response = client.chat.completions.create(
model=model_name,
messages=[{"role": "user", "content": bdi_instruction}],
temperature=0
)
bdi_response = response.choices[0].message.content
bdi_changes = extract_changes(bdi_response)
# 总结GHQ的变化
ghq_instruction = """
### 任务
根据量表的问题和答案,总结出两份量表之间的变化。
请列出明确的变化点,每个变化点单独一行,使用数字编号(1. 2. 3.)。
使用以下格式:
变化:
1. [第一个变化]
2. [第二个变化]
...
### 量表及问题
{}
### 第一份量表的答案
{}
### 第二份量表的答案
{}
""".format(ghq_scale, scales['p_ghq'], scales['ghq'])
response = client.chat.completions.create(
model=model_name,
messages=[{"role": "user", "content": ghq_instruction}],
temperature=0
)
ghq_response = response.choices[0].message.content
ghq_changes = extract_changes(ghq_response)
# 总结SASS的变化
sass_instruction = """
### 任务
根据量表的问题和答案,总结出两份量表之间的变化。
请列出明确的变化点,每个变化点单独一行,使用数字编号(1. 2. 3.)。
使用以下格式:
变化:
1. [第一个变化]
2. [第二个变化]
...
### 量表及问题
{}
### 第一份量表的答案
{}
### 第二份量表的答案
{}
""".format(sass_scale, scales['p_sass'], scales['sass'])
response = client.chat.completions.create(
model=model_name,
messages=[{"role": "user", "content": sass_instruction}],
temperature=0
)
sass_response = response.choices[0].message.content
sass_changes = extract_changes(sass_response)
return bdi_changes, ghq_changes, sass_changes
def summarize_scale_changes(scales):
client = OpenAI(
api_key=api_key,
base_url=base_url
)
# 获取量表变化
bdi_changes, ghq_changes, sass_changes = analyzing_changes(scales)
# 总结量表变化
summary_instruction = """
### 任务
根据量表的变化,总结患者的身体和心理状态变化。
请提供一个全面但简洁的总结,使用以下格式:
总结:
[总结内容]
### BDI量表变化
{}
### GHQ量表变化
{}
### SASS量表变化
{}
""".format(
'\n'.join([f"{i+1}. {change}" for i, change in enumerate(bdi_changes)]),
'\n'.join([f"{i+1}. {change}" for i, change in enumerate(ghq_changes)]),
'\n'.join([f"{i+1}. {change}" for i, change in enumerate(sass_changes)])
)
response = client.chat.completions.create(
model=model_name,
messages=[{"role": "user", "content": summary_instruction}],
temperature=0
)
summary_response = response.choices[0].message.content
status = extract_status(summary_response)
return status
# 额外增加一个更健壮的解析函数,可以处理不同格式的输出
def parse_response_robust(text, expected_format="list"):
"""更健壮的响应解析函数
参数:
text: 文本响应
expected_format: 预期格式,可以是"list"或"summary"
返回:
解析后的结果(列表或字符串)
"""
# 首先尝试JSON格式解析
try:
# 尝试提取JSON部分
json_pattern = r'\{[\s\S]*\}'
json_match = re.search(json_pattern, text)
if json_match:
json_data = json.loads(json_match.group(0))
if expected_format == "list" and "changes" in json_data:
return json_data["changes"]
elif expected_format == "summary" and "status" in json_data:
return json_data["status"]
except:
pass # 如果JSON解析失败,继续尝试其他方法
# 使用适当的提取函数
if expected_format == "list":
return extract_changes(text)
else: # summary
return extract_status(text)
# unit test
# if __name__ == "__main__":
# # 测试数据
# scales = {
# "p_bdi": ["A", "B", "C"],
# "bdi": ["B", "C", "D"],
# "p_ghq": ["A", "A", "B"],
# "ghq": ["B", "C", "C"],
# "p_sass": ["A", "B", "A"],
# "sass": ["C", "D", "B"]
# }
# changes = summarize_scale_changes(scales)
# print(changes)