File size: 3,842 Bytes
1d4c295
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from openai import OpenAI
import os
import json
import re

# 设置OpenAI API密钥和基础URL
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")

def transform_chain(chain):
    return {node["stage"]: node["content"] for node in chain}

def switch_complaint(chain, index, conversation, max_retries=3):
    client = OpenAI(api_key=api_key, base_url=base_url)
    transformed_chain = transform_chain(chain)
    
    # 构建对话历史字符串(避免在f-string中使用反斜杠)
    dialogue_lines = []
    for conv in conversation:
        dialogue_lines.append(f"{conv['role']}: {conv['content']}")
    dialogue_history = "\n".join(dialogue_lines)

    # 使用三引号和多行字符串构建prompt
    prompt = f"""
    ### 任务说明
    根据患者情况及咨访对话历史记录,判断患者当前阶段的主诉问题是否已经得到解决。

    ### 输出要求
    必须严格使用以下JSON格式响应,且只包含指定字段:
    {{"is_recognized": true/false}}

    ### 对话记录
    {dialogue_history}

    ### 主诉认知链
    {json.dumps(transformed_chain, ensure_ascii=False, indent=2)}

    ### 当前阶段(阶段{index}
    {transformed_chain[index]}
    """
    
    attempts = 0
    while attempts < max_retries:
        response = client.chat.completions.create(
            model=model_name,
            messages=[{"role": "user", "content": prompt}],
            temperature=0
        )

        raw_output = response.choices[0].message.content.strip()
        
        # 首先尝试直接解析JSON
        try:
            result = json.loads(raw_output)
            if "is_recognized" in result:
                if result["is_recognized"] and index >= len(chain) - 1:
                    print("警告:当前阶段已被识别为解决,但没有更多阶段可供切换。")
                    return -1
                return index + 1 if result["is_recognized"] else index
        except json.JSONDecodeError:
            pass  # 继续尝试正则表达式提取

        # 使用正则表达式作为备用解析方案
        match = re.search(r'"is_recognized"\s*:\s*(true|false)|is_recognized\s*:\s*(true|false)', 
                          raw_output, re.IGNORECASE)
        if match:
            value = match.group(1) or match.group(2)
            if value.lower() == 'true':
                if index >= len(chain) - 1:
                    print("警告:当前阶段已被识别为解决,但没有更多阶段可供切换。")
                    return -1
                return index + 1
            else:
                return index
        
        print(f"第 {attempts+1} 次尝试:无法解析模型输出。原始输出:\n{raw_output}")
        attempts += 1

    print("警告:重试次数达到上限,无法解析模型输出,返回当前阶段。")
    return index
    
# # unit test
# if __name__ == "__main__":
#     chain = [
#         {"stage": 1, "content": "我觉得我最近有点抑郁。"},
#         {"stage": 2, "content": "我觉得我最近有点焦虑。"},
#         {"stage": 3, "content": "我觉得我最近有点失眠。"},
#         {"stage": 4, "content": "我觉得我最近有点烦躁。"},
#     ]
#     conversation = [
#         {"role": "Seeker", "content": "我觉得我最近有点抑郁。"},
#         {"role": "Counselor", "content": "你觉得是什么原因导致你感到抑郁呢?"},
#         {"role": "Seeker", "content": "我也不知道,可能是工作压力吧。"},
#     ]
#     # print("Transformed chain:", transform_chain(chain))
#     print("Switch complaint index:", switch_complaint(chain, 1, conversation))
#     print(switch_complaint(chain, 1, conversation))