Spaces:
Sleeping
Sleeping
import streamlit as st | |
from openai import OpenAI | |
import time | |
import json # 导入json库 | |
import os # 导入os库用于读取环境变量 | |
# ======================================================================= | |
# 1. 导入您的类和数据 | |
# ======================================================================= | |
# 请将 'your_agent_file' 替换为包含 MsPatient 类的实际文件名 | |
from ms_patient import MsPatient | |
# 从本地JSON文件加载数据集的函数 | |
def load_data_from_json(filepath="/app/src/CPsyCounS-3134.json"): | |
""" | |
从本地的JSON文件加载数据集。 | |
请确保您已将数据集文件上传到与此应用相同的目录中。 | |
""" | |
try: | |
with open(filepath, 'r', encoding='utf-8') as f: | |
# 假设JSON文件的根是一个包含病人记录的列表 | |
return json.load(f) | |
except FileNotFoundError: | |
st.error(f"错误:找不到数据文件 '{filepath}'。请确保您已将该文件上传到Hugging Face Space。") | |
return [] | |
except json.JSONDecodeError: | |
st.error(f"错误:无法解析 '{filepath}'。请检查文件是否为有效的JSON格式。") | |
return [] | |
except Exception as e: | |
st.error(f"加载数据时发生未知错误: {e}") | |
return [] | |
# 加载数据 | |
ALL_PATIENTS = load_data_from_json() | |
# ======================================================================= | |
# 2. Streamlit 应用界面 | |
# ======================================================================= | |
# --- 页面配置 --- | |
st.set_page_config( | |
page_title="与Anna对话", | |
page_icon="�", | |
layout="wide" | |
) | |
# --- 自定义CSS样式 --- | |
st.markdown(""" | |
<style> | |
/* 主聊天容器 */ | |
.st-emotion-cache-1y4p8pa { | |
padding-top: 2rem; | |
} | |
/* 聊天消息 */ | |
.st-chat-message { | |
border-radius: 0.8rem; | |
padding: 0.9rem 1.2rem; | |
box-shadow: 0 2px 5px rgba(0,0,0,0.05); | |
background-color: #ffffff; | |
} | |
.st-chat-message[data-testid="chat-message-container-user"] { | |
background-color: #dcf8c6; | |
} | |
/* 侧边栏 */ | |
.st-sidebar { | |
background-color: #f8f9fa; | |
border-right: 1px solid #e9ecef; | |
} | |
.st-sidebar h2 { | |
color: #343a40; | |
} | |
.st-expanderHeader { | |
font-size: 1.1rem; | |
font-weight: 600; | |
} | |
</style> | |
""", unsafe_allow_html=True) | |
# --- 初始化 Session State 和 OpenAI Client --- | |
# 仅在会话状态中不存在时,才从环境变量初始化客户端 | |
if "openai_client" not in st.session_state: | |
api_key = os.getenv("OPENAI_API_KEY") | |
if not api_key: | |
st.session_state.openai_client = None | |
st.session_state.model_name = None | |
else: | |
try: | |
st.session_state.openai_client = OpenAI(api_key=api_key, base_url=os.getenv("OPENAI_BASE_URL")) | |
st.session_state.model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-4o-mini") | |
except Exception as e: | |
st.error(f"初始化OpenAI客户端失败: {e}") | |
st.session_state.openai_client = None | |
st.session_state.model_name = None | |
if "patient_agent" not in st.session_state: | |
st.session_state.patient_agent = None | |
if "messages" not in st.session_state: | |
st.session_state.messages = [] | |
if "selected_patient_id" not in st.session_state: | |
st.session_state.selected_patient_id = None | |
# --- 侧边栏 --- | |
with st.sidebar: | |
st.title("👩 AnnaAgent 设置") | |
st.markdown("---") | |
# 病人选择 | |
if not ALL_PATIENTS: | |
st.error("无法加载病人数据。请检查JSON文件是否已上传且格式正确。") | |
else: | |
patient_options = {p["id"]: f"{p['portrait']['gender']},{p['portrait']['age']}岁 - {p['portrait']['symptoms']}" for p in ALL_PATIENTS} | |
selected_id = st.selectbox( | |
"选择一位病人进行对话", | |
options=list(patient_options.keys()), | |
format_func=lambda x: patient_options[x] | |
) | |
# 当选择的病人变化时,重置状态 | |
if st.session_state.selected_patient_id != selected_id: | |
st.session_state.selected_patient_id = selected_id | |
selected_patient_data = next((p for p in ALL_PATIENTS if p["id"] == selected_id), None) | |
with st.spinner("正在生成病人角色..."): | |
st.session_state.patient_agent = MsPatient( | |
portrait=selected_patient_data["portrait"], | |
report=selected_patient_data["report"], | |
previous_conversations=selected_patient_data["conversation"], | |
language="Chinese" | |
) | |
st.session_state.messages = [{"role": "assistant", "content": "你好,医生..."}] | |
st.rerun() | |
# 显示病人信息 | |
if st.session_state.patient_agent: | |
st.markdown("---") | |
st.subheader("病人信息") | |
agent = st.session_state.patient_agent | |
st.info(f""" | |
**基本情况**: {agent.portrait['gender']}, {agent.portrait['age']}岁, {agent.portrait['occupation']}, {agent.portrait['marital_status']} | |
**近期状态**: {agent.status} | |
""") | |
with st.expander("查看完整系统提示 (System Prompt)"): | |
st.code(agent.get_system_prompt(), language='markdown') | |
# --- 主聊天界面 --- | |
st.title("💬 与 Anna 对话") | |
st.caption("这是一个模拟心理咨询的AI Agent。由 `MsPatient` 类驱动。") | |
# 检查API Key是否已在后台设置 | |
if not st.session_state.openai_client: | |
st.error("后台未设置 OPENAI_API_KEY。请在Hugging Face Space的'Settings' -> 'Secrets'中进行设置后刷新页面。") | |
else: | |
# 显示聊天记录 | |
for message in st.session_state.messages: | |
avatar = "👩" if message["role"] == "assistant" else "🧑⚕️" | |
with st.chat_message(message["role"], avatar=avatar): | |
st.markdown(message["content"]) | |
# 获取用户输入 | |
if prompt := st.chat_input("请输入您想说的话..."): | |
st.session_state.messages.append({"role": "user", "content": prompt}) | |
with st.chat_message("user", avatar="🧑⚕️"): | |
st.markdown(prompt) | |
if st.session_state.patient_agent: | |
with st.chat_message("assistant", avatar="👩"): | |
with st.spinner("Anna正在思考..."): | |
response = st.session_state.patient_agent.chat(prompt) | |
st.markdown(response) | |
st.session_state.messages.append({"role": "assistant", "content": response}) | |
else: | |
st.warning("请先在左侧选择一位病人。") |