sci-m-wang commited on
Commit
aaf12aa
·
verified ·
1 Parent(s): 32cfa60

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +149 -344
src/streamlit_app.py CHANGED
@@ -1,376 +1,181 @@
1
  import streamlit as st
2
- import json
3
- import pandas as pd
4
- from datetime import datetime
5
  import time
6
- from pathlib import Path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- # 导入你的MsPatient类和辅助函数
9
- try:
10
- from ms_patient import MsPatient # 请根据实际路径调整
11
- ANNA_AGENT_AVAILABLE = True
12
- except ImportError:
13
- ANNA_AGENT_AVAILABLE = False
14
-
15
- # 导入辅助函数
16
- from integration_example import (
17
- load_dataset, validate_patient_data, initialize_patient_agent,
18
- simulate_response, export_chat_history, get_patient_summary
19
- )
20
 
21
- # 页面配置
 
 
 
 
22
  st.set_page_config(
23
- page_title="AnnaAgent - 心理咨询智能体",
24
- page_icon="🧠",
25
- layout="wide",
26
- initial_sidebar_state="expanded"
27
  )
28
 
29
- # 自定义CSS样式
30
  st.markdown("""
31
  <style>
32
- .main-header {
33
- font-size: 2.5rem;
34
- font-weight: 700;
35
- color: #2E86AB;
36
- text-align: center;
37
- margin-bottom: 2rem;
38
- text-shadow: 2px 2px 4px rgba(0,0,0,0.1);
39
- }
40
-
41
- .chat-container {
42
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
43
- border-radius: 15px;
44
- padding: 20px;
45
- margin: 10px 0;
46
- box-shadow: 0 8px 16px rgba(0,0,0,0.2);
47
- }
48
-
49
- .user-message {
50
- background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%);
51
- color: white;
52
- padding: 15px;
53
- border-radius: 15px;
54
- margin: 10px 0;
55
- box-shadow: 0 4px 8px rgba(0,0,0,0.1);
56
- float: right;
57
- clear: both;
58
- max-width: 70%;
59
- margin-left: 30%;
60
  }
61
-
62
- .assistant-message {
63
- background: linear-gradient(135deg, #fa709a 0%, #fee140 100%);
64
- color: #333;
65
- padding: 15px;
66
- border-radius: 15px;
67
- margin: 10px 0;
68
- box-shadow: 0 4px 8px rgba(0,0,0,0.1);
69
- float: left;
70
- clear: both;
71
- max-width: 70%;
72
- margin-right: 30%;
73
  }
74
-
75
- .sidebar-info {
76
- background: #f0f2f6;
77
- padding: 15px;
78
- border-radius: 10px;
79
- margin: 10px 0;
80
- border-left: 4px solid #2E86AB;
81
  }
82
-
83
- .metrics-container {
84
- display: flex;
85
- justify-content: space-around;
86
- margin: 20px 0;
87
  }
88
-
89
- .metric-card {
90
- background: white;
91
- padding: 20px;
92
- border-radius: 10px;
93
- text-align: center;
94
- box-shadow: 0 4px 8px rgba(0,0,0,0.1);
95
- min-width: 120px;
96
  }
97
-
98
- .stButton > button {
99
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
100
- color: white;
101
- border: none;
102
- border-radius: 25px;
103
- padding: 0.5rem 2rem;
104
  font-weight: 600;
105
- transition: all 0.3s ease;
106
- }
107
-
108
- .stButton > button:hover {
109
- transform: translateY(-2px);
110
- box-shadow: 0 6px 12px rgba(0,0,0,0.2);
111
  }
112
  </style>
113
  """, unsafe_allow_html=True)
114
 
115
- # 初始化session state
116
- if 'messages' not in st.session_state:
117
- st.session_state.messages = []
118
- if 'patient_agent' not in st.session_state:
119
- st.session_state.patient_agent = None
120
- if 'current_patient_data' not in st.session_state:
121
- st.session_state.current_patient_data = None
122
 
123
- # 标题
124
- st.markdown('<h1 class="main-header">🧠 AnnaAgent 心理咨询智能体</h1>', unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
125
 
126
- # 侧边栏 - 患者信息和控制
127
  with st.sidebar:
128
- st.markdown("## 📋 患者信息配置")
129
-
130
- # 数据集加载
131
- st.markdown("### 📁 数据集加载")
132
- uploaded_file = st.file_uploader(
133
- "上传Anna-CPsyCounD数据集文件",
134
- type=['json', 'jsonl'],
135
- help="请上传包含患者数据的JSON文件"
136
- )
137
-
138
- # 示例数据选择器(如果没有上传文件)
139
- if uploaded_file is None:
140
- st.markdown("### 🎭 示例患者选择")
141
- # 这里提供一些示例数据
142
- example_patients = {
143
- "示例患者1 - 抑郁症状": {
144
- "id": "example_001",
145
- "portrait": {
146
- "age": 28,
147
- "gender": "女",
148
- "occupation": "软件工程师",
149
- "marital_status": "未婚",
150
- "symptom": "抑郁情绪、失眠、食欲不振"
151
- },
152
- "report": {
153
- "chief_complaint": "最近几个月感到情绪低落",
154
- "history": "工作压力大,经常加班",
155
- "mental_status": "情绪低落,思维迟缓"
156
- },
157
- "conversation": []
158
- },
159
- "示例患者2 - 焦虑症状": {
160
- "id": "example_002",
161
- "portrait": {
162
- "age": 35,
163
- "gender": "男",
164
- "occupation": "销售经理",
165
- "marital_status": "已婚",
166
- "symptom": "焦虑、心慌、担心"
167
- },
168
- "report": {
169
- "chief_complaint": "经常感到紧张焦虑",
170
- "history": "业绩压力,家庭责任重",
171
- "mental_status": "焦虑不安,注意力不集中"
172
- },
173
- "conversation": []
174
- }
175
- }
176
-
177
- selected_example = st.selectbox(
178
- "选择一个示例患者",
179
- options=list(example_patients.keys())
180
- )
181
-
182
- if st.button("🚀 启动会话", type="primary"):
183
- patient_data = example_patients[selected_example]
184
- st.session_state.current_patient_data = patient_data
185
-
186
- # 初始化患者智能体
187
- if ANNA_AGENT_AVAILABLE:
188
- agent, message = initialize_patient_agent(patient_data)
189
- if agent:
190
- st.session_state.patient_agent = agent
191
- st.success(f"✅ 患者智能体已启动!{message}")
192
- else:
193
- st.error(f"❌ 启动失败: {message}")
194
- st.session_state.patient_agent = None
195
  else:
196
- st.warning("⚠️ 使用模拟模式")
197
- st.session_state.patient_agent = None
198
-
199
- st.rerun()
200
-
201
- else:
202
- # 处理上传的文件
203
- try:
204
- data = load_dataset(uploaded_file)
205
- if data and isinstance(data, list):
206
- patients_df = pd.DataFrame(data)
207
- st.markdown("### 👥 患者列表")
208
-
209
- # 显示患者信息表格
210
- display_df = patients_df[['id']].copy()
211
- if 'portrait' in patients_df.columns:
212
- display_df['年龄'] = patients_df['portrait'].apply(lambda x: x.get('age', 'N/A') if isinstance(x, dict) else 'N/A')
213
- display_df['性别'] = patients_df['portrait'].apply(lambda x: x.get('gender', 'N/A') if isinstance(x, dict) else 'N/A')
214
- display_df['症状'] = patients_df['portrait'].apply(lambda x: x.get('symptom', 'N/A') if isinstance(x, dict) else 'N/A')
215
-
216
- selected_id = st.selectbox("选择患者ID", options=patients_df['id'].tolist())
217
-
218
- if st.button("🚀 启动会话", type="primary"):
219
- selected_patient = patients_df[patients_df['id'] == selected_id].iloc[0].to_dict()
220
- st.session_state.current_patient_data = selected_patient
221
-
222
- # 初始化患者智能体
223
- if ANNA_AGENT_AVAILABLE:
224
- agent, message = initialize_patient_agent(selected_patient)
225
- if agent:
226
- st.session_state.patient_agent = agent
227
- st.success(f"✅ 患者智能体已启动!{message}")
228
- else:
229
- st.error(f"❌ 启动失败: {message}")
230
- st.session_state.patient_agent = None
231
- else:
232
- st.warning("⚠️ 使用模拟模式")
233
- st.session_state.patient_agent = None
234
-
235
- st.rerun()
236
-
237
- except Exception as e:
238
- st.error(f"❌ 文件加载失败: {str(e)})")
239
-
240
- # 分隔线
241
  st.markdown("---")
242
-
243
- # 当前患者信息显示
244
- if st.session_state.current_patient_data:
245
- st.markdown("### 👤 当前患者信息")
246
- patient_info = st.session_state.current_patient_data
247
-
248
- if 'portrait' in patient_info:
249
- portrait = patient_info['portrait']
250
- st.markdown(f"""
251
- <div class="sidebar-info">
252
- <strong>基本信息:</strong><br>
253
- • 年龄: {portrait.get('age', 'N/A')}<br>
254
- 性别: {portrait.get('gender', 'N/A')}<br>
255
- 职业: {portrait.get('occupation', 'N/A')}<br>
256
- 婚姻状态: {portrait.get('marital_status', 'N/A')}<br>
257
- 主要症状: {portrait.get('symptom', 'N/A')}
258
- </div>
259
- """, unsafe_allow_html=True)
260
-
261
- # 控制按钮
262
- st.markdown("### 🎮 会话控制")
263
-
264
- col1, col2 = st.columns(2)
265
- with col1:
266
- if st.button("🔄 重置会话"):
267
- st.session_state.messages = []
268
- st.rerun()
269
-
270
- with col2:
271
- if st.button("💾 导出记录"):
272
- if st.session_state.messages:
273
- chat_history = {
274
- "patient_id": st.session_state.current_patient_data.get('id', 'unknown') if st.session_state.current_patient_data else 'unknown',
275
- "timestamp": datetime.now().isoformat(),
276
- "messages": st.session_state.messages
277
- }
278
- st.download_button(
279
- label="📥 下载聊天记录",
280
- data=json.dumps(chat_history, ensure_ascii=False, indent=2),
281
- file_name=f"chat_history_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
282
- mime="application/json"
283
  )
 
 
284
 
285
- # 主聊天界面
286
- if st.session_state.current_patient_data is None:
287
- st.markdown("""
288
- <div style='text-align: center; padding: 50px;'>
289
- <h3>🎯 欢迎使用AnnaAgent心理咨询智能体</h3>
290
- <p>请在左侧面板选择或上传患者数据以开始对话</p>
291
- </div>
292
- """, unsafe_allow_html=True)
293
- else:
294
- # 显示聊天历史
295
- chat_container = st.container()
296
-
297
- with chat_container:
298
- for message in st.session_state.messages:
299
- if message["role"] == "user":
300
- st.markdown(f"""
301
- <div class="user-message">
302
- <strong>咨询师:</strong> {message["content"]}
303
- </div>
304
- """, unsafe_allow_html=True)
305
- else:
306
- st.markdown(f"""
307
- <div class="assistant-message">
308
- <strong>来访者:</strong> {message["content"]}
309
- </div>
310
- """, unsafe_allow_html=True)
311
-
312
- # 输入框
313
- with st.form(key="chat_form", clear_on_submit=True):
314
- col1, col2 = st.columns([4, 1])
315
-
316
- with col1:
317
- user_input = st.text_area(
318
- "输入您的话语:",
319
- placeholder="请输入您想对来访者说的话...",
320
- height=100,
321
- key="user_input"
322
- )
323
-
324
- with col2:
325
- st.markdown("<br>", unsafe_allow_html=True) # 增加一些间距
326
- submit_button = st.form_submit_button("💬 发送", type="primary")
327
-
328
- if submit_button and user_input:
329
- # 添加用户消息到聊天历史
330
- st.session_state.messages.append({"role": "user", "content": user_input})
331
 
332
- # 这里调用AnnaAgent生成回复
333
- with st.spinner("🤔 来访者正在思考回复..."):
334
- try:
335
- # 模拟调用(你需要替换为实际的调用)
336
- # if st.session_state.patient_agent:
337
- # response = st.session_state.patient_agent.chat(user_input)
338
- # else:
339
- # response = "抱歉,智能体尚未初始化。"
340
-
341
- # 临时模拟回复
342
- time.sleep(1) # 模拟思考时间
343
- response = f"谢谢您的关心。您说'{user_input}',这让我想到..."
344
-
345
- # 添加助手回复到聊天历史
346
- st.session_state.messages.append({"role": "assistant", "content": response})
347
-
348
- except Exception as e:
349
- st.error(f"❌ 生成回复时出错: {str(e)}")
350
- st.session_state.messages.append({
351
- "role": "assistant",
352
- "content": "抱歉,我现在无法回复。请稍后再试。"
353
- })
354
-
355
- st.rerun()
356
 
357
- # 底部信息
358
- st.markdown("---")
359
- col1, col2, col3 = st.columns(3)
360
 
361
- with col1:
362
- if st.session_state.messages:
363
- st.metric("💬 对话轮数", len(st.session_state.messages) // 2)
364
 
365
- with col2:
366
- if st.session_state.current_patient_data:
367
- st.metric("👤 当前患者", st.session_state.current_patient_data.get('id', 'Unknown'))
368
 
369
- with col3:
370
- st.metric("🕐 会话时间", f"{datetime.now().strftime('%H:%M:%S')}")
 
 
 
371
 
372
- st.markdown("""
373
- <div style='text-align: center; color: #666; margin-top: 20px;'>
374
- <small>AnnaAgent - 基于大语言模型的心理咨询智能体 | 仅供学术研究使用</small>
375
- </div>
376
- """, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from openai import OpenAI
 
 
3
  import time
4
+ from datasets import load_dataset # 导入Hugging Face datasets库
5
+
6
+ # =======================================================================
7
+ # 1. 导入您的类和数据
8
+ # =======================================================================
9
+
10
+ # 请将 'your_agent_file' 替换为包含 MsPatient 类的实际文件名
11
+ from ms_patient import MsPatient
12
+
13
+ # 使用Streamlit的缓存功能来加载和缓存数据集
14
+ @st.cache_data
15
+ def load_hf_dataset():
16
+ """
17
+ 从Hugging Face Hub加载并缓存数据集。
18
+ 这将返回一个字典列表,每个字典代表一个病人数据。
19
+ """
20
+ try:
21
+ # 加载'train'分割部分的数据
22
+ dataset = load_dataset("sci-m-wang/Anna-CPsyCounD", split='train')
23
+ # 转换为pandas DataFrame再转为字典列表,方便处理
24
+ return dataset.to_pandas().to_dict('records')
25
+ except Exception as e:
26
+ st.error(f"从Hugging Face加载数据集失败: {e}")
27
+ return []
28
+
29
+ # 加载数据
30
+ ALL_PATIENTS = load_hf_dataset()
31
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ # =======================================================================
34
+ # 2. Streamlit 应用界面
35
+ # =======================================================================
36
+
37
+ # --- 页面配置 ---
38
  st.set_page_config(
39
+ page_title="与Anna对话",
40
+ page_icon="",
41
+ layout="wide"
 
42
  )
43
 
44
+ # --- 自定义CSS样式 ---
45
  st.markdown("""
46
  <style>
47
+ /* 主聊天容器 */
48
+ .st-emotion-cache-1y4p8pa {
49
+ padding-top: 2rem;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  }
51
+ /* 聊天消息 */
52
+ .st-chat-message {
53
+ border-radius: 0.8rem;
54
+ padding: 0.9rem 1.2rem;
55
+ box-shadow: 0 2px 5px rgba(0,0,0,0.05);
56
+ background-color: #ffffff;
 
 
 
 
 
 
57
  }
58
+ .st-chat-message[data-testid="chat-message-container-user"] {
59
+ background-color: #dcf8c6;
 
 
 
 
 
60
  }
61
+ /* 侧边栏 */
62
+ .st-sidebar {
63
+ background-color: #f8f9fa;
64
+ border-right: 1px solid #e9ecef;
 
65
  }
66
+ .st-sidebar h2 {
67
+ color: #343a40;
 
 
 
 
 
 
68
  }
69
+ .st-expanderHeader {
70
+ font-size: 1.1rem;
 
 
 
 
 
71
  font-weight: 600;
 
 
 
 
 
 
72
  }
73
  </style>
74
  """, unsafe_allow_html=True)
75
 
 
 
 
 
 
 
 
76
 
77
+ # --- 初始化 Session State ---
78
+ if "patient_agent" not in st.session_state:
79
+ st.session_state.patient_agent = None
80
+ if "messages" not in st.session_state:
81
+ st.session_state.messages = []
82
+ if "selected_patient_id" not in st.session_state:
83
+ st.session_state.selected_patient_id = None
84
+ if "openai_client" not in st.session_state:
85
+ st.session_state.openai_client = None
86
+ if "model_name" not in st.session_state:
87
+ st.session_state.model_name = "gpt-4o-mini" # 默认模型
88
 
89
+ # --- 侧边栏 ---
90
  with st.sidebar:
91
+ st.title("👩 AnnaAgent 设置")
92
+ st.markdown("---")
93
+
94
+ # API Key 输入
95
+ with st.expander("🔑 API 设置", expanded=True):
96
+ api_key = st.text_input("输入您的 OpenAI API Key", type="password", help="您的API Key将仅用于本次会话,不会被储存。")
97
+ base_url = st.text_input("API Base URL (可选)", value="https://api.openai.com/v1")
98
+ model_name = st.text_input("模型名称", value=st.session_state.model_name)
99
+
100
+ if st.button("连接模型"):
101
+ if api_key:
102
+ try:
103
+ st.session_state.openai_client = OpenAI(api_key=api_key, base_url=base_url)
104
+ st.session_state.model_name = model_name
105
+ st.success("连接成功!")
106
+ if st.session_state.patient_agent:
107
+ st.session_state.patient_agent.client = st.session_state.openai_client
108
+ except Exception as e:
109
+ st.error(f"连接失败: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  else:
111
+ st.warning("请输入API Key。")
112
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  st.markdown("---")
114
+
115
+ # 病人选择
116
+ if not ALL_PATIENTS:
117
+ st.error("无法加载病人数据,请检查网络连接或数据集名称。")
118
+ else:
119
+ patient_options = {p["id"]: f"{p['portrait']['gender']},{p['portrait']['age']}岁 - {p['portrait']['symptom']}" for p in ALL_PATIENTS}
120
+ selected_id = st.selectbox(
121
+ "选择一位病人进行对话",
122
+ options=list(patient_options.keys()),
123
+ format_func=lambda x: patient_options[x]
124
+ )
125
+
126
+ # 当选择的病人变化时,重置状态
127
+ if st.session_state.selected_patient_id != selected_id:
128
+ st.session_state.selected_patient_id = selected_id
129
+ selected_patient_data = next((p for p in ALL_PATIENTS if p["id"] == selected_id), None)
130
+
131
+ with st.spinner("正在生成病人角色..."):
132
+ st.session_state.patient_agent = MsPatient(
133
+ portrait=selected_patient_data["portrait"],
134
+ report=selected_patient_data["report"],
135
+ previous_conversations=selected_patient_data["conversation"],
136
+ language="Chinese",
137
+ client=st.session_state.openai_client
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  )
139
+ st.session_state.messages = [{"role": "assistant", "content": "你好,医生..."}]
140
+ st.rerun()
141
 
142
+ # 显示病人信息
143
+ if st.session_state.patient_agent:
144
+ st.markdown("---")
145
+ st.subheader("病人信息")
146
+ agent = st.session_state.patient_agent
147
+ st.info(f"""
148
+ **基本情况**: {agent.portrait['gender']}, {agent.portrait['age']}岁, {agent.portrait['occupation']}, {agent.portrait['marital_status']}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
+ **近期状态**: {agent.status}
151
+ """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
+ with st.expander("查看完整系统提示 (System Prompt)"):
154
+ st.code(agent.get_system_prompt(), language='markdown')
 
155
 
 
 
 
156
 
157
+ # --- 主聊天界面 ---
158
+ st.title("💬 与 Anna 对话")
159
+ st.caption("这是一个模拟心理咨询的AI Agent。由 `MsPatient` 类驱动。")
160
 
161
+ # 显示聊天记录
162
+ for message in st.session_state.messages:
163
+ avatar = "👩" if message["role"] == "assistant" else "🧑‍⚕️"
164
+ with st.chat_message(message["role"], avatar=avatar):
165
+ st.markdown(message["content"])
166
 
167
+ # 获取用户输入
168
+ if prompt := st.chat_input("请输入您想说的话..."):
169
+ st.session_state.messages.append({"role": "user", "content": prompt})
170
+ with st.chat_message("user", avatar="🧑‍⚕️"):
171
+ st.markdown(prompt)
172
+
173
+ if st.session_state.patient_agent:
174
+ with st.chat_message("assistant", avatar="👩"):
175
+ with st.spinner("Anna正在思考..."):
176
+ response = st.session_state.patient_agent.chat(prompt)
177
+ st.markdown(response)
178
+
179
+ st.session_state.messages.append({"role": "assistant", "content": response})
180
+ else:
181
+ st.warning("请先在左侧选择一位病人并配置API Key。")