from openai import OpenAI import os import json import re # 设置OpenAI API密钥和基础URL api_key = os.getenv("OPENAI_API_KEY") base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1") model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo") def transform_chain(chain): return {node["stage"]: node["content"] for node in chain} def switch_complaint(chain, index, conversation, max_retries=3): client = OpenAI(api_key=api_key, base_url=base_url) transformed_chain = transform_chain(chain) # 构建对话历史字符串(避免在f-string中使用反斜杠) dialogue_lines = [] for conv in conversation: dialogue_lines.append(f"{conv['role']}: {conv['content']}") dialogue_history = "\n".join(dialogue_lines) # 使用三引号和多行字符串构建prompt prompt = f""" ### 任务说明 根据患者情况及咨访对话历史记录,判断患者当前阶段的主诉问题是否已经得到解决。 ### 输出要求 必须严格使用以下JSON格式响应,且只包含指定字段: {{"is_recognized": true/false}} ### 对话记录 {dialogue_history} ### 主诉认知链 {json.dumps(transformed_chain, ensure_ascii=False, indent=2)} ### 当前阶段(阶段{index}) {transformed_chain[index]} """ attempts = 0 while attempts < max_retries: response = client.chat.completions.create( model=model_name, messages=[{"role": "user", "content": prompt}], temperature=0 ) raw_output = response.choices[0].message.content.strip() # 首先尝试直接解析JSON try: result = json.loads(raw_output) if "is_recognized" in result: if result["is_recognized"] and index >= len(chain) - 1: print("警告:当前阶段已被识别为解决,但没有更多阶段可供切换。") return -1 return index + 1 if result["is_recognized"] else index except json.JSONDecodeError: pass # 继续尝试正则表达式提取 # 使用正则表达式作为备用解析方案 match = re.search(r'"is_recognized"\s*:\s*(true|false)|is_recognized\s*:\s*(true|false)', raw_output, re.IGNORECASE) if match: value = match.group(1) or match.group(2) if value.lower() == 'true': if index >= len(chain) - 1: print("警告:当前阶段已被识别为解决,但没有更多阶段可供切换。") return -1 return index + 1 else: return index print(f"第 {attempts+1} 次尝试:无法解析模型输出。原始输出:\n{raw_output}") attempts += 1 print("警告:重试次数达到上限,无法解析模型输出,返回当前阶段。") return index # # unit test # if __name__ == "__main__": # chain = [ # {"stage": 1, "content": "我觉得我最近有点抑郁。"}, # {"stage": 2, "content": "我觉得我最近有点焦虑。"}, # {"stage": 3, "content": "我觉得我最近有点失眠。"}, # {"stage": 4, "content": "我觉得我最近有点烦躁。"}, # ] # conversation = [ # {"role": "Seeker", "content": "我觉得我最近有点抑郁。"}, # {"role": "Counselor", "content": "你觉得是什么原因导致你感到抑郁呢?"}, # {"role": "Seeker", "content": "我也不知道,可能是工作压力吧。"}, # ] # # print("Transformed chain:", transform_chain(chain)) # print("Switch complaint index:", switch_complaint(chain, 1, conversation)) # print(switch_complaint(chain, 1, conversation))