Spaces:
Sleeping
Sleeping
File size: 6,968 Bytes
d720282 aaf12aa 32cfa60 2a517fb aaf12aa fb22eb4 aaf12aa 2a517fb aaf12aa 2a517fb aaf12aa 2a517fb aaf12aa 2a517fb aaf12aa 2a517fb 32cfa60 aaf12aa 32cfa60 aaf12aa 2a517fb aaf12aa 32cfa60 aaf12aa 32cfa60 aaf12aa 32cfa60 aaf12aa 32cfa60 aaf12aa 32cfa60 aaf12aa 32cfa60 aaf12aa 32cfa60 aaf12aa 32cfa60 aaf12aa 32cfa60 aaf12aa 32cfa60 aaf12aa 32cfa60 aaf12aa 32cfa60 aaf12aa 2a517fb aaf12aa 32cfa60 aaf12aa 32cfa60 aaf12aa 32cfa60 aaf12aa 32cfa60 aaf12aa 32cfa60 aaf12aa 32cfa60 aaf12aa d720282 aaf12aa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
import streamlit as st
from openai import OpenAI
import time
import json # 导入json库
# =======================================================================
# 1. 导入您的类和数据
# =======================================================================
# 请将 'your_agent_file' 替换为包含 MsPatient 类的实际文件名
from ms_patient import MsPatient
# 从本地JSON文件加载数据集的函数
def load_data_from_json(filepath="Anna-CPsyCounD.json"):
"""
从本地的JSON文件加载数据集。
请确保您已将数据集文件上传到与此应用相同的目录中。
"""
try:
with open(filepath, 'r', encoding='utf-8') as f:
# 假设JSON文件的根是一个包含病人记录的列表
return json.load(f)
except FileNotFoundError:
st.error(f"错误:找不到数据文件 '{filepath}'。请确保您已将该文件上传到Hugging Face Space。")
return []
except json.JSONDecodeError:
st.error(f"错误:无法解析 '{filepath}'。请检查文件是否为有效的JSON格式。")
return []
except Exception as e:
st.error(f"加载数据时发生未知错误: {e}")
return []
# 加载数据
# 注意:现在每次脚本重新运行时都会从本地JSON文件加载
ALL_PATIENTS = load_data_from_json()
# =======================================================================
# 2. Streamlit 应用界面
# =======================================================================
# --- 页面配置 ---
st.set_page_config(
page_title="与Anna对话",
page_icon="�",
layout="wide"
)
# --- 自定义CSS样式 ---
st.markdown("""
<style>
/* 主聊天容器 */
.st-emotion-cache-1y4p8pa {
padding-top: 2rem;
}
/* 聊天消息 */
.st-chat-message {
border-radius: 0.8rem;
padding: 0.9rem 1.2rem;
box-shadow: 0 2px 5px rgba(0,0,0,0.05);
background-color: #ffffff;
}
.st-chat-message[data-testid="chat-message-container-user"] {
background-color: #dcf8c6;
}
/* 侧边栏 */
.st-sidebar {
background-color: #f8f9fa;
border-right: 1px solid #e9ecef;
}
.st-sidebar h2 {
color: #343a40;
}
.st-expanderHeader {
font-size: 1.1rem;
font-weight: 600;
}
</style>
""", unsafe_allow_html=True)
# --- 初始化 Session State ---
if "patient_agent" not in st.session_state:
st.session_state.patient_agent = None
if "messages" not in st.session_state:
st.session_state.messages = []
if "selected_patient_id" not in st.session_state:
st.session_state.selected_patient_id = None
if "openai_client" not in st.session_state:
st.session_state.openai_client = None
if "model_name" not in st.session_state:
st.session_state.model_name = "gpt-4o-mini" # 默认模型
# --- 侧边栏 ---
with st.sidebar:
st.title("👩 AnnaAgent 设置")
st.markdown("---")
# API Key 输入
with st.expander("🔑 API 设置", expanded=True):
api_key = st.text_input("输入您的 OpenAI API Key", type="password", help="您的API Key将仅用于本次会话,不会被储存。")
base_url = st.text_input("API Base URL (可选)", value="https://api.openai.com/v1")
model_name = st.text_input("模型名称", value=st.session_state.model_name)
if st.button("连接模型"):
if api_key:
try:
st.session_state.openai_client = OpenAI(api_key=api_key, base_url=base_url)
st.session_state.model_name = model_name
st.success("连接成功!")
if st.session_state.patient_agent:
st.session_state.patient_agent.client = st.session_state.openai_client
except Exception as e:
st.error(f"连接失败: {e}")
else:
st.warning("请输入API Key。")
st.markdown("---")
# 病人选择
if not ALL_PATIENTS:
st.error("无法加载病人数据。请检查JSON文件是否已上传且格式正确。")
else:
patient_options = {p["id"]: f"{p['portrait']['gender']},{p['portrait']['age']}岁 - {p['portrait']['symptom']}" for p in ALL_PATIENTS}
selected_id = st.selectbox(
"选择一位病人进行对话",
options=list(patient_options.keys()),
format_func=lambda x: patient_options[x]
)
# 当选择的病人变化时,重置状态
if st.session_state.selected_patient_id != selected_id:
st.session_state.selected_patient_id = selected_id
selected_patient_data = next((p for p in ALL_PATIENTS if p["id"] == selected_id), None)
with st.spinner("正在生成病人角色..."):
st.session_state.patient_agent = MsPatient(
portrait=selected_patient_data["portrait"],
report=selected_patient_data["report"],
previous_conversations=selected_patient_data["conversation"],
language="Chinese",
client=st.session_state.openai_client
)
st.session_state.messages = [{"role": "assistant", "content": "你好,医生..."}]
st.rerun()
# 显示病人信息
if st.session_state.patient_agent:
st.markdown("---")
st.subheader("病人信息")
agent = st.session_state.patient_agent
st.info(f"""
**基本情况**: {agent.portrait['gender']}, {agent.portrait['age']}岁, {agent.portrait['occupation']}, {agent.portrait['marital_status']}
**近期状态**: {agent.status}
""")
with st.expander("查看完整系统提示 (System Prompt)"):
st.code(agent.get_system_prompt(), language='markdown')
# --- 主聊天界面 ---
st.title("💬 与 Anna 对话")
st.caption("这是一个模拟心理咨询的AI Agent。由 `MsPatient` 类驱动。")
# 显示聊天记录
for message in st.session_state.messages:
avatar = "👩" if message["role"] == "assistant" else "🧑⚕️"
with st.chat_message(message["role"], avatar=avatar):
st.markdown(message["content"])
# 获取用户输入
if prompt := st.chat_input("请输入您想说的话..."):
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("user", avatar="🧑⚕️"):
st.markdown(prompt)
if st.session_state.patient_agent:
with st.chat_message("assistant", avatar="👩"):
with st.spinner("Anna正在思考..."):
response = st.session_state.patient_agent.chat(prompt)
st.markdown(response)
st.session_state.messages.append({"role": "assistant", "content": response})
else:
st.warning("请先在左侧选择一位病人并配置API Key。") |