import streamlit as st from openai import OpenAI import time from datasets import load_dataset # 导入Hugging Face datasets库 # ======================================================================= # 1. 导入您的类和数据 # ======================================================================= # 请将 'your_agent_file' 替换为包含 MsPatient 类的实际文件名 from ms_patient import MsPatient # 使用Streamlit的缓存功能来加载和缓存数据集 @st.cache_data def load_hf_dataset(): """ 从Hugging Face Hub加载并缓存数据集。 这将返回一个字典列表,每个字典代表一个病人数据。 """ try: # 加载'train'分割部分的数据 dataset = load_dataset("sci-m-wang/Anna-CPsyCounD", split='train') # 转换为pandas DataFrame再转为字典列表,方便处理 return dataset.to_pandas().to_dict('records') except Exception as e: st.error(f"从Hugging Face加载数据集失败: {e}") return [] # 加载数据 ALL_PATIENTS = load_hf_dataset() # ======================================================================= # 2. Streamlit 应用界面 # ======================================================================= # --- 页面配置 --- st.set_page_config( page_title="与Anna对话", page_icon="�", layout="wide" ) # --- 自定义CSS样式 --- st.markdown(""" """, unsafe_allow_html=True) # --- 初始化 Session State --- if "patient_agent" not in st.session_state: st.session_state.patient_agent = None if "messages" not in st.session_state: st.session_state.messages = [] if "selected_patient_id" not in st.session_state: st.session_state.selected_patient_id = None if "openai_client" not in st.session_state: st.session_state.openai_client = None if "model_name" not in st.session_state: st.session_state.model_name = "gpt-4o-mini" # 默认模型 # --- 侧边栏 --- with st.sidebar: st.title("👩 AnnaAgent 设置") st.markdown("---") # API Key 输入 with st.expander("🔑 API 设置", expanded=True): api_key = st.text_input("输入您的 OpenAI API Key", type="password", help="您的API Key将仅用于本次会话,不会被储存。") base_url = st.text_input("API Base URL (可选)", value="https://api.openai.com/v1") model_name = st.text_input("模型名称", value=st.session_state.model_name) if st.button("连接模型"): if api_key: try: st.session_state.openai_client = OpenAI(api_key=api_key, base_url=base_url) st.session_state.model_name = model_name st.success("连接成功!") if st.session_state.patient_agent: st.session_state.patient_agent.client = st.session_state.openai_client except Exception as e: st.error(f"连接失败: {e}") else: st.warning("请输入API Key。") st.markdown("---") # 病人选择 if not ALL_PATIENTS: st.error("无法加载病人数据,请检查网络连接或数据集名称。") else: patient_options = {p["id"]: f"{p['portrait']['gender']},{p['portrait']['age']}岁 - {p['portrait']['symptom']}" for p in ALL_PATIENTS} selected_id = st.selectbox( "选择一位病人进行对话", options=list(patient_options.keys()), format_func=lambda x: patient_options[x] ) # 当选择的病人变化时,重置状态 if st.session_state.selected_patient_id != selected_id: st.session_state.selected_patient_id = selected_id selected_patient_data = next((p for p in ALL_PATIENTS if p["id"] == selected_id), None) with st.spinner("正在生成病人角色..."): st.session_state.patient_agent = MsPatient( portrait=selected_patient_data["portrait"], report=selected_patient_data["report"], previous_conversations=selected_patient_data["conversation"], language="Chinese", client=st.session_state.openai_client ) st.session_state.messages = [{"role": "assistant", "content": "你好,医生..."}] st.rerun() # 显示病人信息 if st.session_state.patient_agent: st.markdown("---") st.subheader("病人信息") agent = st.session_state.patient_agent st.info(f""" **基本情况**: {agent.portrait['gender']}, {agent.portrait['age']}岁, {agent.portrait['occupation']}, {agent.portrait['marital_status']} **近期状态**: {agent.status} """) with st.expander("查看完整系统提示 (System Prompt)"): st.code(agent.get_system_prompt(), language='markdown') # --- 主聊天界面 --- st.title("💬 与 Anna 对话") st.caption("这是一个模拟心理咨询的AI Agent。由 `MsPatient` 类驱动。") # 显示聊天记录 for message in st.session_state.messages: avatar = "👩" if message["role"] == "assistant" else "🧑‍⚕️" with st.chat_message(message["role"], avatar=avatar): st.markdown(message["content"]) # 获取用户输入 if prompt := st.chat_input("请输入您想说的话..."): st.session_state.messages.append({"role": "user", "content": prompt}) with st.chat_message("user", avatar="🧑‍⚕️"): st.markdown(prompt) if st.session_state.patient_agent: with st.chat_message("assistant", avatar="👩"): with st.spinner("Anna正在思考..."): response = st.session_state.patient_agent.chat(prompt) st.markdown(response) st.session_state.messages.append({"role": "assistant", "content": response}) else: st.warning("请先在左侧选择一位病人并配置API Key。")