Spaces:
Sleeping
Sleeping
from openai import OpenAI | |
import os | |
import json | |
import re | |
# 设置OpenAI API密钥和基础URL | |
api_key = os.getenv("OPENAI_API_KEY") | |
base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1") | |
model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo") | |
def transform_chain(chain): | |
return {node["stage"]: node["content"] for node in chain} | |
def switch_complaint(chain, index, conversation, max_retries=3): | |
client = OpenAI(api_key=api_key, base_url=base_url) | |
transformed_chain = transform_chain(chain) | |
# 构建对话历史字符串(避免在f-string中使用反斜杠) | |
dialogue_lines = [] | |
for conv in conversation: | |
dialogue_lines.append(f"{conv['role']}: {conv['content']}") | |
dialogue_history = "\n".join(dialogue_lines) | |
# 使用三引号和多行字符串构建prompt | |
prompt = f""" | |
### 任务说明 | |
根据患者情况及咨访对话历史记录,判断患者当前阶段的主诉问题是否已经得到解决。 | |
### 输出要求 | |
必须严格使用以下JSON格式响应,且只包含指定字段: | |
{{"is_recognized": true/false}} | |
### 对话记录 | |
{dialogue_history} | |
### 主诉认知链 | |
{json.dumps(transformed_chain, ensure_ascii=False, indent=2)} | |
### 当前阶段(阶段{index}) | |
{transformed_chain[index]} | |
""" | |
attempts = 0 | |
while attempts < max_retries: | |
response = client.chat.completions.create( | |
model=model_name, | |
messages=[{"role": "user", "content": prompt}], | |
temperature=0 | |
) | |
raw_output = response.choices[0].message.content.strip() | |
# 首先尝试直接解析JSON | |
try: | |
result = json.loads(raw_output) | |
if "is_recognized" in result: | |
if result["is_recognized"] and index >= len(chain) - 1: | |
print("警告:当前阶段已被识别为解决,但没有更多阶段可供切换。") | |
return -1 | |
return index + 1 if result["is_recognized"] else index | |
except json.JSONDecodeError: | |
pass # 继续尝试正则表达式提取 | |
# 使用正则表达式作为备用解析方案 | |
match = re.search(r'"is_recognized"\s*:\s*(true|false)|is_recognized\s*:\s*(true|false)', | |
raw_output, re.IGNORECASE) | |
if match: | |
value = match.group(1) or match.group(2) | |
if value.lower() == 'true': | |
if index >= len(chain) - 1: | |
print("警告:当前阶段已被识别为解决,但没有更多阶段可供切换。") | |
return -1 | |
return index + 1 | |
else: | |
return index | |
print(f"第 {attempts+1} 次尝试:无法解析模型输出。原始输出:\n{raw_output}") | |
attempts += 1 | |
print("警告:重试次数达到上限,无法解析模型输出,返回当前阶段。") | |
return index | |
# # unit test | |
# if __name__ == "__main__": | |
# chain = [ | |
# {"stage": 1, "content": "我觉得我最近有点抑郁。"}, | |
# {"stage": 2, "content": "我觉得我最近有点焦虑。"}, | |
# {"stage": 3, "content": "我觉得我最近有点失眠。"}, | |
# {"stage": 4, "content": "我觉得我最近有点烦躁。"}, | |
# ] | |
# conversation = [ | |
# {"role": "Seeker", "content": "我觉得我最近有点抑郁。"}, | |
# {"role": "Counselor", "content": "你觉得是什么原因导致你感到抑郁呢?"}, | |
# {"role": "Seeker", "content": "我也不知道,可能是工作压力吧。"}, | |
# ] | |
# # print("Transformed chain:", transform_chain(chain)) | |
# print("Switch complaint index:", switch_complaint(chain, 1, conversation)) | |
# print(switch_complaint(chain, 1, conversation)) |