File size: 5,819 Bytes
2b9aa0c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
from datetime import datetime
from typing import Optional

import streamlit as st
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage

from agent import get_agent_executor
from config.settings import settings
from models import ChatMessage, ChatSession, User # Assuming User is in session_state
from models.db import get_session_context
from services.logger import app_logger
from services.metrics import log_consultation_start
# --- Page setup & access control ---
# set_page_config must be the first Streamlit call on this page.
st.set_page_config(page_title=f"Consult - {settings.APP_TITLE}", layout="wide")

# Gate the page: only authenticated users may consult the AI.
if not st.session_state.get("authenticated_user"):
    st.warning("Please log in to access the consultation page.")
    st.switch_page("app.py") # Redirect to login

# --- Initialize Agent ---
# Built once per script run; failure (e.g. missing API key) aborts the page.
try:
    agent_executor = get_agent_executor()
except ValueError as e: # Handles missing API key
    st.error(f"Could not initialize AI Agent: {e}")
    st.stop()
# --- Helper Functions ---
def load_chat_history(session_id: int) -> list:
    """Load the persisted messages of a session as LangChain message objects.

    Rows are read in timestamp order; "user" rows become HumanMessage and
    "assistant" rows become AIMessage. Rows with any other role are skipped.
    """
    # Map persisted role strings to their LangChain message classes.
    role_to_message = {
        "user": HumanMessage,
        "assistant": AIMessage,
        # Tool messages could be mapped here if stored as a distinct role:
        # "tool": lambda row: ToolMessage(content=row.content, tool_call_id=row.tool_call_id),
    }
    history = []
    with get_session_context() as db:
        rows = (
            db.query(ChatMessage)
            .filter(ChatMessage.session_id == session_id)
            .order_by(ChatMessage.timestamp)
            .all()
        )
        for row in rows:
            message_cls = role_to_message.get(row.role)
            if message_cls is not None:
                history.append(message_cls(content=row.content))
    return history
def save_chat_message(session_id: int, role: str, content: str, tool_call_id: Optional[str] = None, tool_name: Optional[str] = None) -> None:
    """Persist a single chat message to the database.

    Args:
        session_id: Primary key of the owning ChatSession.
        role: Author role stored with the row (e.g. "user", "assistant").
        content: Message text to store.
        tool_call_id: LangChain tool-call id, when the message records a tool call.
        tool_name: Name of the invoked tool, when applicable.

    Note:
        The Optional annotations previously raised NameError at import time
        because `typing.Optional` was never imported; the module-level import
        now provides it.
    """
    with get_session_context() as db:
        chat_message = ChatMessage(
            session_id=session_id,
            role=role,
            content=content,
            # Naive UTC timestamp, matching existing rows. datetime.utcnow()
            # is deprecated since Python 3.12 — migrate to
            # datetime.now(timezone.utc) together with the stored data.
            timestamp=datetime.utcnow(),
            tool_call_id=tool_call_id,
            tool_name=tool_name,
        )
        db.add(chat_message)
        # Commit explicitly; presumably get_session_context does not commit
        # on exit — TODO confirm against models.db.
        db.commit()
# --- Page Logic ---
st.title("AI Consultation Room")
st.markdown("Interact with the Quantum Health Navigator AI.")

# The login flow is expected to have placed the User object and the id of an
# active ChatSession into Streamlit session state before this page runs.
current_user: User = st.session_state.authenticated_user
chat_session_id = st.session_state.get("current_chat_session_id")
if not chat_session_id:
    st.error("No active chat session. Please re-login or contact support.")
    st.stop()

# Load initial chat history for the agent (from Langchain Message objects)
# For the agent, we need history in LangChain message format
if "agent_chat_history" not in st.session_state:
    st.session_state.agent_chat_history = load_chat_history(chat_session_id)
    if not st.session_state.agent_chat_history: # If no history, maybe add a system greeting
        # First message of a brand-new session: record the consultation metric.
        log_consultation_start()
        # You could add an initial AIMessage here if desired
        # initial_ai_message = AIMessage(content="Hello! How can I assist you today?")
        # st.session_state.agent_chat_history.append(initial_ai_message)
        # save_chat_message(chat_session_id, "assistant", initial_ai_message.content)

# Display chat messages from DB (for UI)
# Rendering happens inside the DB session so ORM attributes are still
# attached — NOTE(review): confirm rows are safe to read after session close
# before moving the loop outside the `with` block.
with get_session_context() as db:
    ui_messages = db.query(ChatMessage).filter(ChatMessage.session_id == chat_session_id).order_by(ChatMessage.timestamp).all()
    for msg in ui_messages:
        with st.chat_message(msg.role):
            st.markdown(msg.content)
# Chat input
# Handles one round trip: persist the user's prompt, invoke the agent with the
# accumulated LangChain history, then persist and render the reply (or a
# truncated error message on failure).
if prompt := st.chat_input("Ask the AI... (e.g., 'What is hypertension?' or 'Optimize treatment for patient X with diabetes')"):
    # Add user message to UI and save to DB
    with st.chat_message("user"):
        st.markdown(prompt)
    save_chat_message(chat_session_id, "user", prompt)
    # Add to agent's history (LangChain format)
    st.session_state.agent_chat_history.append(HumanMessage(content=prompt))

    # Get AI response
    with st.spinner("AI is thinking..."):
        try:
            response = agent_executor.invoke({
                "input": prompt,
                "chat_history": st.session_state.agent_chat_history
            })
            ai_response_content = response['output']

            # Display AI response in UI and save to DB
            with st.chat_message("assistant"):
                st.markdown(ai_response_content)
            save_chat_message(chat_session_id, "assistant", ai_response_content)
            # Add AI response to agent's history
            st.session_state.agent_chat_history.append(AIMessage(content=ai_response_content))
            # Note: The agent executor might make tool calls. The create_openai_functions_agent
            # and AgentExecutor handle the tool invocation and adding ToolMessages to history internally
            # before producing the final 'output'. If you need to log individual tool calls/results
            # to your DB, you might need a more custom agent loop or callbacks.
        except Exception as e:
            # Boundary handler: log, surface to the user, and keep the
            # conversation consistent by storing the error as an assistant turn.
            app_logger.error(f"Error during agent invocation: {e}")
            st.error(f"An error occurred: {e}")
            # Save error message as AI response?
            error_message = f"Sorry, I encountered an error: {str(e)[:200]}" # Truncate for DB
            with st.chat_message("assistant"): # Or a custom error role
                st.markdown(error_message)
            save_chat_message(chat_session_id, "assistant", error_message) # Or "error" role
            st.session_state.agent_chat_history.append(AIMessage(content=error_message))
    # Rerun to show the latest messages immediately (though Streamlit usually does this)
    # st.rerun() # Usually not needed with st.chat_input and context managers