File size: 5,819 Bytes
2b9aa0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from datetime import datetime, timezone
from typing import Optional

import streamlit as st
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage

from config.settings import settings
from agent import get_agent_executor
from models import ChatMessage, ChatSession, User # Assuming User is in session_state
from models.db import get_session_context
from services.logger import app_logger
from services.metrics import log_consultation_start

st.set_page_config(layout="wide", page_title=f"Consult - {settings.APP_TITLE}")

# Access guard: only logged-in users may use this page. st.switch_page ends
# this script run and sends the user to the login page instead.
if not st.session_state.get("authenticated_user", None):
    st.warning("Please log in to access the consultation page.")
    st.switch_page("app.py")  # Redirect to login

# --- Initialize Agent ---
# get_agent_executor() is called on every script run (any caching lives inside
# that factory). A ValueError from it is shown to the user and halts the page.
try:
    agent_executor = get_agent_executor()
except ValueError as init_error:  # Handles missing API key
    st.error(f"Could not initialize AI Agent: {init_error}")
    st.stop()


# --- Helper Functions ---
def load_chat_history(session_id: int) -> list:
    """Rebuild the agent-facing conversation for one chat session.

    Reads every stored ChatMessage whose ``session_id`` matches, ordered by
    timestamp, and converts each row into the corresponding LangChain message:
    role "user" -> HumanMessage, role "assistant" -> AIMessage. Rows with any
    other role are skipped.

    Args:
        session_id: ID of the chat session whose messages should be loaded.

    Returns:
        A list of HumanMessage / AIMessage objects, oldest first; empty when the
        session has no user/assistant messages.
    """
    with get_session_context() as db:
        role_to_message = {"user": HumanMessage, "assistant": AIMessage}
        rows = (
            db.query(ChatMessage)
            .filter(ChatMessage.session_id == session_id)
            .order_by(ChatMessage.timestamp)
            .all()
        )
        # Conversion happens inside the `with`, while the DB session is open.
        # If tool rows are ever stored and should be replayed, they would need
        # ToolMessage(content=..., tool_call_id=...) here.
        history = [
            role_to_message[row.role](content=row.content)
            for row in rows
            if row.role in role_to_message
        ]
    return history

def save_chat_message(session_id: int, role: str, content: str, tool_call_id: Optional[str] = None, tool_name: Optional[str] = None):
    """Persist a single chat message for a session and commit it.

    Args:
        session_id: ID of the chat session the message belongs to.
        role: Author role; this page writes "user" and "assistant".
        content: Message text to store.
        tool_call_id: Optional tool-call identifier, stored as given.
        tool_name: Optional tool name, stored as given.
    """
    with get_session_context() as db:
        chat_message = ChatMessage(
            session_id=session_id,
            role=role,
            content=content,
            # Naive UTC: the same value datetime.utcnow() produced, so it matches
            # what earlier rows stored. It avoids the DeprecationWarning that
            # utcnow() emits on Python 3.12+.
            timestamp=datetime.now(timezone.utc).replace(tzinfo=None),
            tool_call_id=tool_call_id,
            tool_name=tool_name
        )
        db.add(chat_message)
        db.commit()

# --- Page Logic ---
st.title("AI Consultation Room")
st.markdown("Interact with the Quantum Health Navigator AI.")

current_user: User = st.session_state.authenticated_user
chat_session_id = st.session_state.get("current_chat_session_id")

if not chat_session_id:
    st.error("No active chat session. Please re-login or contact support.")
    st.stop()

# The agent needs the conversation as LangChain message objects. That history is
# cached in session_state so it survives Streamlit reruns. It must be reloaded
# whenever the active chat session changes; otherwise a new session would
# silently inherit the previous session's conversation.
_agent_history_is_stale = (
    "agent_chat_history" not in st.session_state
    or st.session_state.get("agent_chat_history_session_id") != chat_session_id
)
if _agent_history_is_stale:
    st.session_state.agent_chat_history = load_chat_history(chat_session_id)
    st.session_state.agent_chat_history_session_id = chat_session_id
    if not st.session_state.agent_chat_history:
        # No user/assistant messages are stored for this session yet.
        log_consultation_start()
        # You could add an initial AIMessage here if desired
        # initial_ai_message = AIMessage(content="Hello! How can I assist you today?")
        # st.session_state.agent_chat_history.append(initial_ai_message)
        # save_chat_message(chat_session_id, "assistant", initial_ai_message.content)


# Replay the stored conversation in the UI, reading every role straight from the DB.
with get_session_context() as db:
    ui_messages = db.query(ChatMessage).filter(ChatMessage.session_id == chat_session_id).order_by(ChatMessage.timestamp).all()
    for msg in ui_messages:
        with st.chat_message(msg.role):
            st.markdown(msg.content)

# Chat input
if prompt := st.chat_input("Ask the AI... (e.g., 'What is hypertension?' or 'Optimize treatment for patient X with diabetes')"):
    # Echo the user's message in the UI and persist it before calling the agent.
    with st.chat_message("user"):
        st.markdown(prompt)
    save_chat_message(chat_session_id, "user", prompt)

    # Snapshot the earlier turns BEFORE recording this prompt. The prompt is sent
    # to the agent separately as "input"; putting it in "chat_history" as well
    # would show the model the same question twice.
    # NOTE(review): assumes the agent prompt renders chat_history followed by
    # {input} (the usual create_openai_functions_agent layout) — confirm in agent.py.
    prior_history = list(st.session_state.agent_chat_history)

    # Record the prompt in the agent's history (LangChain format) for later turns.
    st.session_state.agent_chat_history.append(HumanMessage(content=prompt))

    # Get AI response
    with st.spinner("AI is thinking..."):
        try:
            response = agent_executor.invoke({
                "input": prompt,
                "chat_history": prior_history,
            })
            ai_response_content = response['output']

            # Display AI response in UI and save to DB
            with st.chat_message("assistant"):
                st.markdown(ai_response_content)
            save_chat_message(chat_session_id, "assistant", ai_response_content)

            # Add AI response to agent's history
            st.session_state.agent_chat_history.append(AIMessage(content=ai_response_content))

            # Note: The agent executor might make tool calls. The create_openai_functions_agent
            # and AgentExecutor handle the tool invocation and adding ToolMessages to history internally
            # before producing the final 'output'. If you need to log individual tool calls/results
            # to your DB, you might need a more custom agent loop or callbacks.

        except Exception as e:
            app_logger.error(f"Error during agent invocation: {e}")
            st.error(f"An error occurred: {e}")
            # The error text is stored as the assistant's reply so the UI, the DB and
            # the agent history stay aligned. Truncated to keep the DB row short.
            error_message = f"Sorry, I encountered an error: {str(e)[:200]}" # Truncate for DB
            with st.chat_message("assistant"): # Or a custom error role
                st.markdown(error_message)
            save_chat_message(chat_session_id, "assistant", error_message) # Or "error" role
            st.session_state.agent_chat_history.append(AIMessage(content=error_message))

    # Rerun to show the latest messages immediately (though Streamlit usually does this)
    # st.rerun() # Usually not needed with st.chat_input and context managers