AnnaAgent-Demo / src /streamlit_app.py
sci-m-wang's picture
Update src/streamlit_app.py
6c26d15 verified
raw
history blame
6.65 kB
import streamlit as st
from openai import OpenAI
import time
import json # 导入json库
import os # 导入os库用于读取环境变量
# =======================================================================
# 1. 导入您的类和数据
# =======================================================================
# 请将 'your_agent_file' 替换为包含 MsPatient 类的实际文件名
from ms_patient import MsPatient
# 从本地JSON文件加载数据集的函数
def load_data_from_json(filepath="/app/src/CPsyCounS-3134.json"):
"""
从本地的JSON文件加载数据集。
请确保您已将数据集文件上传到与此应用相同的目录中。
"""
try:
with open(filepath, 'r', encoding='utf-8') as f:
# 假设JSON文件的根是一个包含病人记录的列表
return json.load(f)
except FileNotFoundError:
st.error(f"错误:找不到数据文件 '{filepath}'。请确保您已将该文件上传到Hugging Face Space。")
return []
except json.JSONDecodeError:
st.error(f"错误:无法解析 '{filepath}'。请检查文件是否为有效的JSON格式。")
return []
except Exception as e:
st.error(f"加载数据时发生未知错误: {e}")
return []
# 加载数据
ALL_PATIENTS = load_data_from_json()
# =======================================================================
# 2. Streamlit 应用界面
# =======================================================================
# --- 页面配置 ---
st.set_page_config(
page_title="与Anna对话",
page_icon="�",
layout="wide"
)
# --- 自定义CSS样式 ---
st.markdown("""
<style>
/* 主聊天容器 */
.st-emotion-cache-1y4p8pa {
padding-top: 2rem;
}
/* 聊天消息 */
.st-chat-message {
border-radius: 0.8rem;
padding: 0.9rem 1.2rem;
box-shadow: 0 2px 5px rgba(0,0,0,0.05);
background-color: #ffffff;
}
.st-chat-message[data-testid="chat-message-container-user"] {
background-color: #dcf8c6;
}
/* 侧边栏 */
.st-sidebar {
background-color: #f8f9fa;
border-right: 1px solid #e9ecef;
}
.st-sidebar h2 {
color: #343a40;
}
.st-expanderHeader {
font-size: 1.1rem;
font-weight: 600;
}
</style>
""", unsafe_allow_html=True)
# --- 初始化 Session State 和 OpenAI Client ---
# 仅在会话状态中不存在时,才从环境变量初始化客户端
if "openai_client" not in st.session_state:
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
st.session_state.openai_client = None
st.session_state.model_name = None
else:
try:
st.session_state.openai_client = OpenAI(api_key=api_key, base_url=os.getenv("OPENAI_BASE_URL"))
st.session_state.model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-4o-mini")
except Exception as e:
st.error(f"初始化OpenAI客户端失败: {e}")
st.session_state.openai_client = None
st.session_state.model_name = None
if "patient_agent" not in st.session_state:
st.session_state.patient_agent = None
if "messages" not in st.session_state:
st.session_state.messages = []
if "selected_patient_id" not in st.session_state:
st.session_state.selected_patient_id = None
# --- 侧边栏 ---
with st.sidebar:
st.title("👩 AnnaAgent 设置")
st.markdown("---")
# 病人选择
if not ALL_PATIENTS:
st.error("无法加载病人数据。请检查JSON文件是否已上传且格式正确。")
else:
patient_options = {p["id"]: f"{p['portrait']['gender']}{p['portrait']['age']}岁 - {p['portrait']['symptoms']}" for p in ALL_PATIENTS}
selected_id = st.selectbox(
"选择一位病人进行对话",
options=list(patient_options.keys()),
format_func=lambda x: patient_options[x]
)
# 当选择的病人变化时,重置状态
if st.session_state.selected_patient_id != selected_id:
st.session_state.selected_patient_id = selected_id
selected_patient_data = next((p for p in ALL_PATIENTS if p["id"] == selected_id), None)
with st.spinner("正在生成病人角色..."):
st.session_state.patient_agent = MsPatient(
portrait=selected_patient_data["portrait"],
report=selected_patient_data["report"],
previous_conversations=selected_patient_data["conversation"],
language="Chinese"
)
st.session_state.messages = [{"role": "assistant", "content": "你好,医生..."}]
st.rerun()
# 显示病人信息
if st.session_state.patient_agent:
st.markdown("---")
st.subheader("病人信息")
agent = st.session_state.patient_agent
st.info(f"""
**基本情况**: {agent.portrait['gender']}, {agent.portrait['age']}岁, {agent.portrait['occupation']}, {agent.portrait['marital_status']}
**近期状态**: {agent.status}
""")
with st.expander("查看完整系统提示 (System Prompt)"):
st.code(agent.get_system_prompt(), language='markdown')
# --- 主聊天界面 ---
st.title("💬 与 Anna 对话")
st.caption("这是一个模拟心理咨询的AI Agent。由 `MsPatient` 类驱动。")
# 检查API Key是否已在后台设置
if not st.session_state.openai_client:
st.error("后台未设置 OPENAI_API_KEY。请在Hugging Face Space的'Settings' -> 'Secrets'中进行设置后刷新页面。")
else:
# 显示聊天记录
for message in st.session_state.messages:
avatar = "👩" if message["role"] == "assistant" else "🧑‍⚕️"
with st.chat_message(message["role"], avatar=avatar):
st.markdown(message["content"])
# 获取用户输入
if prompt := st.chat_input("请输入您想说的话..."):
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("user", avatar="🧑‍⚕️"):
st.markdown(prompt)
if st.session_state.patient_agent:
with st.chat_message("assistant", avatar="👩"):
with st.spinner("Anna正在思考..."):
response = st.session_state.patient_agent.chat(prompt)
st.markdown(response)
st.session_state.messages.append({"role": "assistant", "content": response})
else:
st.warning("请先在左侧选择一位病人。")