File size: 6,673 Bytes
d720282
aaf12aa
32cfa60
aaf12aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32cfa60
 
aaf12aa
 
 
 
 
32cfa60
aaf12aa
 
 
32cfa60
 
aaf12aa
32cfa60
 
aaf12aa
 
 
32cfa60
aaf12aa
 
 
 
 
 
32cfa60
aaf12aa
 
32cfa60
aaf12aa
 
 
 
32cfa60
aaf12aa
 
32cfa60
aaf12aa
 
32cfa60
 
 
 
 
 
aaf12aa
 
 
 
 
 
 
 
 
 
 
32cfa60
aaf12aa
32cfa60
aaf12aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32cfa60
aaf12aa
 
32cfa60
aaf12aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32cfa60
aaf12aa
 
32cfa60
aaf12aa
 
 
 
 
 
 
32cfa60
aaf12aa
 
32cfa60
aaf12aa
 
32cfa60
 
aaf12aa
 
 
32cfa60
aaf12aa
 
 
 
 
d720282
aaf12aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import streamlit as st
from openai import OpenAI
import time
from datasets import load_dataset # 导入Hugging Face datasets库

# =======================================================================
# 1. 导入您的类和数据
# =======================================================================

# MsPatient is imported from the ms_patient module; adjust this import if the class lives elsewhere.
from ms_patient import MsPatient 

# Only successful loads are cached: the cached helper lets failures propagate,
# so a transient error is never memoized by st.cache_data.
@st.cache_data
def _fetch_patient_records():
    """Load the dataset's 'train' split and return it as a list of row dicts.

    Any exception raised by ``load_dataset`` or the pandas conversion is
    allowed to propagate so that no failed result is stored in the cache.
    """
    # Load the 'train' split of the dataset.
    dataset = load_dataset("sci-m-wang/Anna-CPsyCounD", split='train')
    # Convert to a pandas DataFrame, then to a list of dicts for easy handling.
    return dataset.to_pandas().to_dict('records')


def load_hf_dataset():
    """Return the patient records from the Hugging Face Hub.

    Returns:
        A list of dicts, one per dataset row. When loading fails, the error
        is shown with ``st.error`` and an empty list is returned; because the
        failure is not cached, the next script run retries the load.
    """
    try:
        return _fetch_patient_records()
    except Exception as e:
        # UI boundary: report the problem instead of crashing the app.
        st.error(f"从Hugging Face加载数据集失败: {e}")
        return []

# =======================================================================
# 2. Streamlit app UI
# =======================================================================

# --- Page configuration ---
# Configured before the dataset load below: load_hf_dataset() can emit
# Streamlit elements (st.error on failure, and presumably the cache spinner),
# and Streamlit releases that require set_page_config to be the first command
# reject the opposite order.
st.set_page_config(
    page_title="与Anna对话",
    # The previous icon was a U+FFFD replacement character (encoding damage);
    # the patient avatar emoji used elsewhere in the app is used instead.
    page_icon="👩",
    layout="wide"
)

# Load the patient records (an empty list when loading failed).
ALL_PATIENTS = load_hf_dataset()

# --- Custom CSS styles ---
# The stylesheet lives in a named constant, separate from the call that
# injects it into the page as raw HTML.
_CUSTOM_CSS = """
<style>
    /* 主聊天容器 */
    .st-emotion-cache-1y4p8pa { 
        padding-top: 2rem;
    }
    /* 聊天消息 */
    .st-chat-message {
        border-radius: 0.8rem;
        padding: 0.9rem 1.2rem;
        box-shadow: 0 2px 5px rgba(0,0,0,0.05);
        background-color: #ffffff;
    }
    .st-chat-message[data-testid="chat-message-container-user"] {
        background-color: #dcf8c6;
    }
    /* 侧边栏 */
    .st-sidebar {
        background-color: #f8f9fa;
        border-right: 1px solid #e9ecef;
    }
    .st-sidebar h2 {
        color: #343a40;
    }
    .st-expanderHeader {
        font-size: 1.1rem;
        font-weight: 600;
    }
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)


# --- Initialize session state ---
# Each key is seeded only when it is missing, so values that are already
# present in the session are left untouched.
_SESSION_DEFAULTS = (
    ("patient_agent", None),
    ("messages", []),
    ("selected_patient_id", None),
    ("openai_client", None),
    ("model_name", "gpt-4o-mini"),  # default model
)
for _state_key, _state_default in _SESSION_DEFAULTS:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

# --- Sidebar ---
# Holds the API settings, the patient selector and a summary of the active
# patient agent.
with st.sidebar:
    st.title("👩 AnnaAgent 设置")
    st.markdown("---")

    # API key input
    with st.expander("🔑 API 设置", expanded=True):
        api_key = st.text_input("输入您的 OpenAI API Key", type="password", help="您的API Key将仅用于本次会话,不会被储存。")
        base_url = st.text_input("API Base URL (可选)", value="https://api.openai.com/v1")
        model_name = st.text_input("模型名称", value=st.session_state.model_name)

        if st.button("连接模型"):
            # Pasted credentials frequently carry stray surrounding whitespace.
            api_key = api_key.strip()
            if api_key:
                try:
                    st.session_state.openai_client = OpenAI(
                        api_key=api_key,
                        # The field is labelled optional: a blank value becomes
                        # None so the client falls back to its default endpoint
                        # instead of being given an empty URL.
                        base_url=base_url.strip() or None,
                    )
                    st.session_state.model_name = model_name
                    st.success("连接成功!")
                    # Hand the new client to an agent that already exists.
                    if st.session_state.patient_agent:
                        st.session_state.patient_agent.client = st.session_state.openai_client
                except Exception as e:
                    st.error(f"连接失败: {e}")
            else:
                st.warning("请输入API Key。")

    st.markdown("---")

    # Patient selection
    if not ALL_PATIENTS:
        st.error("无法加载病人数据,请检查网络连接或数据集名称。")
    else:
        # Map each patient id to the label shown in the select box.
        patient_options = {p["id"]: f"{p['portrait']['gender']}{p['portrait']['age']}岁 - {p['portrait']['symptom']}" for p in ALL_PATIENTS}
        selected_id = st.selectbox(
            "选择一位病人进行对话",
            options=list(patient_options.keys()),
            format_func=lambda x: patient_options[x]
        )

        # Rebuild the agent and reset the chat when the selection changes.
        if st.session_state.selected_patient_id != selected_id:
            selected_patient_data = next((p for p in ALL_PATIENTS if p["id"] == selected_id), None)

            try:
                with st.spinner("正在生成病人角色..."):
                    # NOTE(review): the client is None until the user connects;
                    # whether MsPatient tolerates that is not visible here.
                    new_agent = MsPatient(
                        portrait=selected_patient_data["portrait"],
                        report=selected_patient_data["report"],
                        previous_conversations=selected_patient_data["conversation"],
                        language="Chinese",
                        client=st.session_state.openai_client
                    )
            except Exception as e:
                # Do not record the selection as applied: clearing the state
                # makes the next script run retry, and prevents an agent built
                # for a different patient from staying active.
                st.session_state.patient_agent = None
                st.session_state.selected_patient_id = None
                st.session_state.messages = []
                st.error(f"生成病人角色失败: {e}")
            else:
                # Commit the switch only once the agent was built successfully.
                st.session_state.patient_agent = new_agent
                st.session_state.selected_patient_id = selected_id
                st.session_state.messages = [{"role": "assistant", "content": "你好,医生..."}]
                st.rerun()

    # Show information about the active patient
    if st.session_state.patient_agent:
        st.markdown("---")
        st.subheader("病人信息")
        agent = st.session_state.patient_agent
        st.info(f"""
        **基本情况**: {agent.portrait['gender']}, {agent.portrait['age']}岁, {agent.portrait['occupation']}, {agent.portrait['marital_status']}
        
        **近期状态**: {agent.status}
        """)

        with st.expander("查看完整系统提示 (System Prompt)"):
            st.code(agent.get_system_prompt(), language='markdown')


# --- Main chat UI ---
st.title("💬 与 Anna 对话")
st.caption("这是一个模拟心理咨询的AI Agent。由 `MsPatient` 类驱动。")

# Render the conversation stored in session state.
for message in st.session_state.messages:
    avatar = "👩" if message["role"] == "assistant" else "🧑‍⚕️"
    with st.chat_message(message["role"], avatar=avatar):
        st.markdown(message["content"])

# Handle a newly submitted user message.
if prompt := st.chat_input("请输入您想说的话..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user", avatar="🧑‍⚕️"):
        st.markdown(prompt)

    agent = st.session_state.patient_agent
    # The warning asks for both a patient and an API key, so both are checked
    # before the agent is asked to reply.
    if not agent or st.session_state.openai_client is None:
        st.warning("请先在左侧选择一位病人并配置API Key。")
    else:
        response = None
        with st.chat_message("assistant", avatar="👩"):
            with st.spinner("Anna正在思考..."):
                try:
                    response = agent.chat(prompt)
                except Exception as e:
                    # UI boundary: show the failure in the chat instead of
                    # aborting the script run with a traceback.
                    st.error(f"生成回复失败: {e}")
                else:
                    st.markdown(response)

        # Only replies that were actually produced are stored in the history.
        if response is not None:
            st.session_state.messages.append({"role": "assistant", "content": response})