AnnaAgent-Demo / src /complaint_elicitor.py
sci-m-wang's picture
Upload 14 files
1d4c295 verified
raw
history blame
3.84 kB
from openai import OpenAI
import os
import json
import re
# 设置OpenAI API密钥和基础URL
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")
def transform_chain(chain):
return {node["stage"]: node["content"] for node in chain}
def switch_complaint(chain, index, conversation, max_retries=3):
client = OpenAI(api_key=api_key, base_url=base_url)
transformed_chain = transform_chain(chain)
# 构建对话历史字符串(避免在f-string中使用反斜杠)
dialogue_lines = []
for conv in conversation:
dialogue_lines.append(f"{conv['role']}: {conv['content']}")
dialogue_history = "\n".join(dialogue_lines)
# 使用三引号和多行字符串构建prompt
prompt = f"""
### 任务说明
根据患者情况及咨访对话历史记录,判断患者当前阶段的主诉问题是否已经得到解决。
### 输出要求
必须严格使用以下JSON格式响应,且只包含指定字段:
{{"is_recognized": true/false}}
### 对话记录
{dialogue_history}
### 主诉认知链
{json.dumps(transformed_chain, ensure_ascii=False, indent=2)}
### 当前阶段(阶段{index}
{transformed_chain[index]}
"""
attempts = 0
while attempts < max_retries:
response = client.chat.completions.create(
model=model_name,
messages=[{"role": "user", "content": prompt}],
temperature=0
)
raw_output = response.choices[0].message.content.strip()
# 首先尝试直接解析JSON
try:
result = json.loads(raw_output)
if "is_recognized" in result:
if result["is_recognized"] and index >= len(chain) - 1:
print("警告:当前阶段已被识别为解决,但没有更多阶段可供切换。")
return -1
return index + 1 if result["is_recognized"] else index
except json.JSONDecodeError:
pass # 继续尝试正则表达式提取
# 使用正则表达式作为备用解析方案
match = re.search(r'"is_recognized"\s*:\s*(true|false)|is_recognized\s*:\s*(true|false)',
raw_output, re.IGNORECASE)
if match:
value = match.group(1) or match.group(2)
if value.lower() == 'true':
if index >= len(chain) - 1:
print("警告:当前阶段已被识别为解决,但没有更多阶段可供切换。")
return -1
return index + 1
else:
return index
print(f"第 {attempts+1} 次尝试:无法解析模型输出。原始输出:\n{raw_output}")
attempts += 1
print("警告:重试次数达到上限,无法解析模型输出,返回当前阶段。")
return index
# # unit test
# if __name__ == "__main__":
# chain = [
# {"stage": 1, "content": "我觉得我最近有点抑郁。"},
# {"stage": 2, "content": "我觉得我最近有点焦虑。"},
# {"stage": 3, "content": "我觉得我最近有点失眠。"},
# {"stage": 4, "content": "我觉得我最近有点烦躁。"},
# ]
# conversation = [
# {"role": "Seeker", "content": "我觉得我最近有点抑郁。"},
# {"role": "Counselor", "content": "你觉得是什么原因导致你感到抑郁呢?"},
# {"role": "Seeker", "content": "我也不知道,可能是工作压力吧。"},
# ]
# # print("Transformed chain:", transform_chain(chain))
# print("Switch complaint index:", switch_complaint(chain, 1, conversation))
# print(switch_complaint(chain, 1, conversation))