# SynapseAI / app.py
# Last commit: "Update app.py" (4258926, verified) by mgbam
# File size at that commit: 34.3 kB
import json
import operator
import re
from typing import Annotated, Any, Dict, List, Optional, TypedDict

import streamlit as st
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import tool
from langchain_groq import ChatGroq
from langgraph.checkpoint.memory import MemorySaver  # For state persistence (optional but good)
from langgraph.graph import END, StateGraph
from langgraph.prebuilt import ToolExecutor, ToolInvocation
# --- Configuration & Constants ---
class ClinicalAppSettings:
    """Static configuration for the Streamlit page and the Groq-hosted LLM."""

    # Streamlit page setup
    APP_TITLE: str = "SynapseAI: Interactive Clinical Decision Support"
    PAGE_LAYOUT: str = "wide"

    # LLM settings (passed to ChatGroq)
    MODEL_NAME: str = "llama3-70b-8192"
    TEMPERATURE: float = 0.1

    # Upper bound on results returned by the Tavily search tool
    MAX_SEARCH_RESULTS: int = 3
class ClinicalPrompts:
    """Prompt text used by the LangGraph agent.

    SYSTEM_PROMPT is inserted as a SystemMessage by ``agent_node`` whenever the
    conversation history does not already contain one. It names the tools the
    model may call (``order_lab_test``, ``prescribe_medication``,
    ``check_drug_interactions``, ``flag_risk``, ``tavily_search_results``) and
    describes the fenced ```json block that the chat UI tries to parse out of
    AI replies (keys such as ``assessment``, ``rationale_summary`` and
    ``interaction_check_summary`` are read back by the display code).
    """

    # System prompt for the multi-turn consultation flow. The text is sent to the
    # LLM verbatim, so any wording change here changes model behavior.
    SYSTEM_PROMPT = """
You are SynapseAI, an expert AI clinical assistant engaged in an interactive consultation.
Your goal is to support healthcare professionals by analyzing patient data, providing differential diagnoses, suggesting evidence-based management plans, and identifying risks according to current standards of care.
**Core Directives for this Conversation:**
1. **Analyze Sequentially:** Process information turn-by-turn. You will receive initial patient data, and potentially follow-up messages or results from tools you requested. Base your responses on the *entire* conversation history.
2. **Seek Clarity:** If the provided information is insufficient or ambiguous for a safe assessment, CLEARLY STATE what specific additional information or clarification is needed. Do NOT guess or make unsafe assumptions.
3. **Structured Assessment (When Ready):** When you have sufficient information and have performed necessary checks (like interactions), provide a comprehensive assessment using the following JSON structure. Only output this structure when you believe you have a complete initial analysis or plan. Do NOT output incomplete JSON.
```json
{
"assessment": "Concise summary of the patient's presentation and key findings based on the conversation.",
"differential_diagnosis": [
{"diagnosis": "Primary Diagnosis", "likelihood": "High/Medium/Low", "rationale": "Supporting evidence from conversation..."},
{"diagnosis": "Alternative Diagnosis 1", "likelihood": "Medium/Low", "rationale": "Supporting/Refuting evidence..."},
{"diagnosis": "Alternative Diagnosis 2", "likelihood": "Low", "rationale": "Why it's less likely but considered..."}
],
"risk_assessment": {
"identified_red_flags": ["List any triggered red flags"],
"immediate_concerns": ["Specific urgent issues (e.g., sepsis risk, ACS rule-out)"],
"potential_complications": ["Possible future issues"]
},
"recommended_plan": {
"investigations": ["List specific lab tests or imaging needed. Use 'order_lab_test' tool."],
"therapeutics": ["Suggest specific treatments/prescriptions. Use 'prescribe_medication' tool. MUST check interactions first."],
"consultations": ["Recommend specialist consultations."],
"patient_education": ["Key points for patient communication."]
},
"rationale_summary": "Justification for assessment/plan. **Crucially, if relevant (e.g., ACS, sepsis, common infections), use 'tavily_search_results' to find and cite current clinical practice guidelines (e.g., 'latest ACC/AHA chest pain guidelines 202X', 'Surviving Sepsis Campaign guidelines') supporting your recommendations.**",
"interaction_check_summary": "Summary of findings from 'check_drug_interactions' if performed."
}
```
4. **Safety First - Interactions:** BEFORE suggesting a new prescription via `prescribe_medication`, you MUST FIRST use `check_drug_interactions`. Report the findings. If interactions exist, modify the plan or state the contraindication.
5. **Safety First - Red Flags:** Use the `flag_risk` tool IMMEDIATELY if critical red flags requiring urgent action are identified at any point.
6. **Tool Use:** Employ tools (`order_lab_test`, `prescribe_medication`, `check_drug_interactions`, `flag_risk`, `tavily_search_results`) logically within the conversational flow. Wait for tool results before proceeding if the result is needed for the next step (e.g., wait for interaction check before confirming prescription).
7. **Evidence & Guidelines:** Actively use `tavily_search_results` not just for general knowledge, but specifically to query for and incorporate **current clinical practice guidelines** relevant to the patient's presentation (e.g., chest pain, shortness of breath, suspected infection). Summarize findings in the `rationale_summary` when providing the structured output.
8. **Conciseness:** Be medically accurate and concise. Use standard terminology. Respond naturally in conversation until ready for the full structured JSON output.
"""
# --- Mock knowledge bases (demo data only, not a real drug database) ---

# Drug-drug interactions keyed by a pair of lowercase drug names.
# check_drug_interactions looks each pair up in both orders, so a pair is stored once.
MOCK_INTERACTION_DB = {
    ("lisinopril", "spironolactone"):
        "High risk of hyperkalemia. Monitor potassium closely.",
    ("warfarin", "amiodarone"):
        "Increased bleeding risk. Monitor INR frequently and adjust Warfarin dose.",
    ("simvastatin", "clarithromycin"):
        "Increased risk of myopathy/rhabdomyolysis. Avoid combination or use lower statin dose.",
    ("aspirin", "ibuprofen"):
        "Concurrent use may decrease Aspirin's cardioprotective effect. Potential for increased GI bleeding.",
}

# Allergy class (lowercase) -> lowercase drug names considered cross-reactive with it.
ALLERGY_INTERACTIONS = {
    "penicillin": ["amoxicillin", "ampicillin", "piperacillin"],
    "sulfa": ["sulfamethoxazole", "sulfasalazine"],
    # Cross-reactivity example for NSAIDs
    "aspirin": ["ibuprofen", "naproxen"],
}
def parse_bp(bp_string: str) -> Optional[tuple[int, int]]:
    """Parse a blood-pressure reading such as ``"155/90"`` into ``(systolic, diastolic)``.

    Leading whitespace and spaces around the slash are tolerated; anything after
    the diastolic value (e.g. a trailing ``" mmHg"``) is ignored.

    Args:
        bp_string: free-text BP entry from the intake form.

    Returns:
        ``(systolic, diastolic)`` as ints, or ``None`` when the input is not a
        string or does not start with a ``SYS/DIA`` pair of 1-3 digit numbers.
    """
    if not isinstance(bp_string, str):
        return None
    # Leading \s* so user-typed input like " 120/80" still parses (re.match anchors at index 0).
    match = re.match(r"\s*(\d{1,3})\s*/\s*(\d{1,3})", bp_string)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))
def check_red_flags(patient_data: dict) -> List[str]:
    """Return rule-based red-flag warnings for the structured intake data.

    Checks reported symptoms (the associated-symptoms list *and* the chief
    complaint text), fever / tachycardia / blood-pressure thresholds, and the PMH
    text for a prior MI combined with current chest pain. Missing sections and
    non-numeric vitals are skipped instead of raising.

    Args:
        patient_data: intake dict with optional ``hpi``, ``vitals`` and ``pmh``
            sections, shaped like the one ``main`` builds from the sidebar form.

    Returns:
        Human-readable flag strings in a fixed check order; empty if none trigger.
    """
    flags: List[str] = []
    patient_data = patient_data or {}
    hpi = patient_data.get("hpi") or {}
    vitals = patient_data.get("vitals") or {}
    history_lower = str((patient_data.get("pmh") or {}).get("conditions") or "").lower()

    symptoms_lower = [str(symptom).lower() for symptom in (hpi.get("symptoms") or [])]
    # The intake form records e.g. "Chest pain" as the chief complaint, not as an
    # associated symptom, so the complaint text must be searched as well.
    chief_complaint_lower = str(hpi.get("chief_complaint") or "").lower()

    def _reported(phrase: str) -> bool:
        return phrase in symptoms_lower or phrase in chief_complaint_lower

    chest_pain = _reported("chest pain")
    if chest_pain:
        flags.append("Red Flag: Chest Pain reported.")
    if _reported("shortness of breath"):
        flags.append("Red Flag: Shortness of Breath reported.")
    if _reported("severe headache"):
        flags.append("Red Flag: Severe Headache reported.")

    def _numeric_vital(key: str):
        # Ignore missing/None/non-numeric entries rather than failing the comparison.
        value = vitals.get(key)
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return value
        return None

    temp_c = _numeric_vital("temp_c")
    if temp_c is not None and temp_c >= 38.5:
        flags.append(f"Red Flag: Fever ({temp_c}°C).")
    hr_bpm = _numeric_vital("hr_bpm")
    if hr_bpm is not None and hr_bpm >= 120:
        flags.append(f"Red Flag: Tachycardia ({hr_bpm} bpm).")

    bp_raw = vitals.get("bp_mmhg")
    bp = parse_bp(str(bp_raw)) if bp_raw else None
    if bp:
        if bp[0] >= 180 or bp[1] >= 110:
            flags.append(f"Red Flag: Hypertensive Urgency/Emergency (BP: {bp_raw} mmHg).")
        if bp[0] <= 90 or bp[1] <= 60:
            flags.append(f"Red Flag: Hypotension (BP: {bp_raw} mmHg).")

    # Word-boundary match so "history of migraine" is not mistaken for a prior MI.
    if chest_pain and re.search(r"\b(?:mi|n?stemi|myocardial infarction)\b", history_lower):
        flags.append("Red Flag: History of MI with current Chest Pain.")
    return flags
# --- Tool Definitions ---
# Each tool pairs a Pydantic (v1) input schema, passed as args_schema, with a @tool function.
# NOTE(review): the Field descriptions and the tool function docstrings are presumably what
# langchain turns into the tool spec sent to the LLM via bind_tools; treat their wording as
# runtime behavior, not just documentation.
class LabOrderInput(BaseModel):
    # Arguments accepted by the order_lab_test tool.
    test_name: str = Field(
        default=...,
        description="Specific name of the lab test or panel (e.g., 'CBC', 'BMP', 'Troponin I', 'Urinalysis').",
    )
    reason: str = Field(
        default=...,
        description="Clinical justification for ordering the test (e.g., 'Rule out infection', 'Assess renal function', 'Evaluate for ACS').",
    )
    priority: str = Field(
        default="Routine",
        description="Priority of the test (e.g., 'STAT', 'Routine').",
    )
@tool("order_lab_test", args_schema=LabOrderInput)
def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
    """Orders a specific lab test with clinical justification and priority."""
    # Mock tool: nothing is sent to an ordering system. The JSON payload
    # (status / message / details) is what the chat UI renders as the tool result.
    payload = {
        "status": "success",
        "message": f"Lab Ordered: {test_name} ({priority})",
        "details": f"Reason: {reason}",
    }
    return json.dumps(payload)
class PrescriptionInput(BaseModel):
    # Arguments accepted by the prescribe_medication tool. Only `duration` is optional.
    medication_name: str = Field(default=..., description="Name of the medication.")
    dosage: str = Field(
        default=...,
        description="Dosage amount and unit (e.g., '500 mg', '10 mg').",
    )
    route: str = Field(
        default=...,
        description="Route of administration (e.g., 'PO', 'IV', 'IM', 'Topical').",
    )
    frequency: str = Field(
        default=...,
        description="How often the medication should be taken (e.g., 'BID', 'QDaily', 'Q4-6H PRN').",
    )
    duration: str = Field(
        default="As directed",
        description="Duration of treatment (e.g., '7 days', '1 month', 'Until follow-up').",
    )
    reason: str = Field(default=..., description="Clinical indication for the prescription.")
@tool("prescribe_medication", args_schema=PrescriptionInput)
def prescribe_medication(
    medication_name: str,
    dosage: str,
    route: str,
    frequency: str,
    duration: str = "As directed",
    reason: str = "",
) -> str:
    """Prescribes a medication with detailed instructions and clinical indication."""
    # Mock tool: returns a JSON payload for the chat UI; no real prescription is created.
    # The interaction check is expected to have been requested via check_drug_interactions
    # before (or alongside) this call; tool_node enforces that rule.
    #
    # `duration` mirrors the PrescriptionInput default. langchain's tool input parsing
    # appears to forward only the arguments the LLM actually supplied, so without a
    # Python-level default an omitted duration raised TypeError. `reason` gets a
    # default only so the signature stays valid (a non-default parameter cannot follow
    # a default one); the schema still requires it.
    return json.dumps({
        "status": "success",
        "message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}",
        "details": f"Duration: {duration}. Reason: {reason}",
    })
class InteractionCheckInput(BaseModel):
    # Arguments accepted by the check_drug_interactions tool. tool_node overwrites
    # current_medications and allergies with the patient's recorded values.
    potential_prescription: str = Field(
        default=...,
        description="The name of the NEW medication being considered.",
    )
    current_medications: List[str] = Field(
        default=...,
        description="List of the patient's CURRENT medication names.",
    )
    allergies: List[str] = Field(
        default=...,
        description="List of the patient's known allergies.",
    )
@tool("check_drug_interactions", args_schema=InteractionCheckInput)
def check_drug_interactions(potential_prescription: str, current_medications: List[str], allergies: List[str]) -> str:
    """Checks for potential drug-drug and drug-allergy interactions BEFORE prescribing."""
    # Mock checker backed by ALLERGY_INTERACTIONS and MOCK_INTERACTION_DB.
    # Returns JSON {"status": "warning"|"clear", "message": str, "warnings": [str, ...]}.
    warnings = []
    potential_med_lower = potential_prescription.strip().lower()
    current_meds_lower = [str(med).strip().lower() for med in (current_medications or [])]

    # Allergy entries may carry reaction notes (the intake form's default is
    # "Penicillin (rash)"). Strip parenthesised text so "penicillin (rash)" matches
    # both the drug name and the ALLERGY_INTERACTIONS key "penicillin".
    allergies_lower = [
        cleaned
        for cleaned in (re.sub(r"\([^)]*\)", "", str(a)).strip().lower() for a in (allergies or []))
        if cleaned
    ]

    for allergy in allergies_lower:
        if allergy == potential_med_lower:
            warnings.append(f"CRITICAL ALLERGY: Patient allergic to {allergy}. Cannot prescribe {potential_prescription}.")
            continue
        if potential_med_lower in ALLERGY_INTERACTIONS.get(allergy, []):
            warnings.append(f"POTENTIAL CROSS-ALLERGY: Patient allergic to {allergy}. High risk with {potential_prescription}.")

    for current_med in current_meds_lower:
        # MOCK_INTERACTION_DB stores each pair once, so try both orders.
        note = MOCK_INTERACTION_DB.get((current_med, potential_med_lower)) or MOCK_INTERACTION_DB.get(
            (potential_med_lower, current_med)
        )
        if note:
            warnings.append(f"Interaction: {potential_prescription.capitalize()} with {current_med.capitalize()} - {note}")

    status = "warning" if warnings else "clear"
    message = f"Interaction check for {potential_prescription}: {len(warnings)} potential issue(s) found." if warnings else f"No major interactions identified for {potential_prescription}."
    return json.dumps({"status": status, "message": message, "warnings": warnings})
class FlagRiskInput(BaseModel):
    # Arguments accepted by the flag_risk tool; `urgency` is optional.
    risk_description: str = Field(
        default=...,
        description="Specific critical risk identified (e.g., 'Suspected Sepsis', 'Acute Coronary Syndrome', 'Stroke Alert').",
    )
    urgency: str = Field(
        default="High",
        description="Urgency level (e.g., 'Critical', 'High', 'Moderate').",
    )
@tool("flag_risk", args_schema=FlagRiskInput)
def flag_risk(risk_description: str, urgency: str = "High") -> str:
    """Flags a critical risk identified during analysis for immediate attention."""
    # Side effect: renders the alert on the Streamlit page as soon as the tool runs.
    # `urgency` defaults to "High" to match FlagRiskInput. langchain's tool input parsing
    # appears to pass only the arguments the LLM supplied, so the schema default alone
    # did not reach this function. The icon must be a real emoji: Streamlit rejects
    # anything else (the previous value was mis-encoded text).
    st.error(f"🚨 **{urgency.upper()} RISK FLAGGED by AI:** {risk_description}", icon="🚨")
    return json.dumps({"status": "flagged", "message": f"Risk '{risk_description}' flagged with {urgency} urgency."})
# --- Web search tool ---
# Registered under the name "tavily_search_results", which is how the system prompt refers to it.
search_tool = TavilySearchResults(
    name="tavily_search_results",
    max_results=ClinicalAppSettings.MAX_SEARCH_RESULTS,
)
# --- LangGraph Setup ---
class AgentState(TypedDict):
    """Graph state shared by the agent and tool nodes.

    ``messages`` is annotated with ``operator.add`` so the messages each node
    returns are accumulated onto the history rather than replacing it.
    """

    # Conversation history (Human, AI and Tool messages), accumulated via operator.add.
    messages: Annotated[list[Any], operator.add]
    # Structured patient intake data; may be None.
    # NOTE: further keys (e.g. 'interaction_check_needed_for') could be added here if needed.
    patient_data: Optional[dict]
# --- Tool registry ---
# The same list is bound to the LLM (model.bind_tools) and handed to the executor
# that runs the calls in tool_node.
tools = [
    order_lab_test,
    prescribe_medication,
    check_drug_interactions,
    flag_risk,
    search_tool,
]
tool_executor = ToolExecutor(tools=tools)
# --- LLM client ---
model = ChatGroq(
    model=ClinicalAppSettings.MODEL_NAME,
    temperature=ClinicalAppSettings.TEMPERATURE,
)
# Tool-aware variant: bind_tools advertises the tool schemas so the LLM can emit tool calls.
model_with_tools = model.bind_tools(tools)
# --- Graph Nodes ---
def agent_node(state: AgentState):
    """Call the tool-bound LLM on the conversation and return its reply.

    The messages sent to the model are a per-call copy of the history:
      * the formatted ``patient_data`` is appended to the first HumanMessage
        (unless that message already carries it), and
      * the system prompt is prepended when no SystemMessage is present.
    Neither augmentation is written back to the graph state.

    Returns:
        ``{"messages": [response]}``, a one-message update appended to the history.
    """
    print("---AGENT NODE---")
    current_messages = list(state["messages"])
    patient_data = state.get("patient_data")

    if patient_data:
        # Inject on every call, not only when the history holds a single message:
        # the UI always sends at least the intake prompt plus the user's first chat
        # message, so a length-1 condition meant the model never saw the intake data.
        for index, message in enumerate(current_messages):
            if isinstance(message, HumanMessage):
                if "**Initial Patient Data:**" not in str(message.content):
                    formatted_data = format_patient_data_for_prompt(patient_data)
                    current_messages[index] = HumanMessage(
                        content=f"{message.content}\n\n**Initial Patient Data:**\n{formatted_data}"
                    )
                break

    if not any(isinstance(m, SystemMessage) for m in current_messages):
        current_messages = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + current_messages

    response = model_with_tools.invoke(current_messages)
    print(f"Agent response: {response}")
    return {"messages": [response]}
# 2. Tool Node: Executes tools called by the Agent
def tool_node(state: AgentState):
    """Execute the tool calls in the latest AI message and return their results.

    Safety rule: a ``prescribe_medication`` call runs only if a
    ``check_drug_interactions`` call for the same medication (case-insensitive)
    was requested in this turn or in any earlier AI turn. Otherwise it is blocked
    and answered with an error ToolMessage. Interaction checks always receive the
    patient's recorded medication names and allergies, overriding whatever the
    LLM passed.

    Returns:
        ``{"messages": [...]}`` with one ToolMessage per tool call, in call order,
        so every tool_call id gets an answer (including blocked or failed calls).
        Returns ``{}`` when the last message has no tool calls.
    """
    print("---TOOL NODE---")
    history = state["messages"]
    last_message = history[-1]
    if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
        print("No tool calls in last message.")
        return {}  # Should not happen if routing is correct, but safety check

    def _med_key(name) -> str:
        # Normalise medication names so "Aspirin " and "aspirin" compare equal.
        return str(name or "").strip().lower()

    # Medications whose interaction check has been requested so far, this turn included.
    checked_meds = {
        _med_key((call.get("args") or {}).get("potential_prescription"))
        for message in history
        if isinstance(message, AIMessage)
        for call in (message.tool_calls or [])
        if call.get("name") == "check_drug_interactions"
    }
    checked_meds.discard("")

    patient_data = state.get("patient_data") or {}
    patient_meds = (patient_data.get("medications") or {}).get("names_only") or []
    patient_allergies = patient_data.get("allergies") or []

    results: Dict[str, ToolMessage] = {}
    runnable_calls = []
    for call in last_message.tool_calls:
        name = call["name"]
        # Copy the args so the AIMessage kept in history still shows what the LLM requested.
        args = dict(call.get("args") or {})
        if name == "prescribe_medication":
            med_name = args.get("medication_name")
            if _med_key(med_name) not in checked_meds:
                st.error(f"**Safety Violation:** AI attempted to prescribe '{med_name}' without first requesting `check_drug_interactions`. Prescription blocked for this turn.")
                results[call["id"]] = ToolMessage(
                    content=json.dumps({"status": "error", "message": f"Interaction check for {med_name} must be requested *before or alongside* the prescription call."}),
                    tool_call_id=call["id"],
                    name=name,
                )
                continue
        elif name == "check_drug_interactions":
            args["current_medications"] = patient_meds
            args["allergies"] = patient_allergies
            print(f"Augmented interaction check args: {args}")
        runnable_calls.append({**call, "args": args})

    if runnable_calls:
        # ToolExecutor takes ToolInvocation objects, not raw tool-call dicts.
        invocations = [ToolInvocation(tool=c["name"], tool_input=c["args"]) for c in runnable_calls]
        # return_exceptions=True: a failing tool becomes an error result for the LLM
        # instead of aborting the whole graph run.
        responses = tool_executor.batch(invocations, return_exceptions=True)
        for call, resp in zip(runnable_calls, responses):
            if isinstance(resp, Exception):
                content = json.dumps({"status": "error", "message": f"Tool '{call['name']}' failed: {resp}"})
            else:
                content = str(resp)
            results[call["id"]] = ToolMessage(content=content, tool_call_id=call["id"], name=call["name"])

    tool_messages = [results[call["id"]] for call in last_message.tool_calls if call["id"] in results]
    print(f"Tool results: {tool_messages}")
    return {"messages": tool_messages}
# --- Graph Edges (Routing Logic) ---
def should_continue(state: AgentState) -> str:
    """Choose the next hop after the agent node.

    Returns ``"continue_tools"`` when the newest message is an AIMessage that
    requested tool calls, otherwise ``"end_conversation_turn"`` (the AI answered
    directly or is finished).
    """
    latest = state['messages'][-1]
    wants_tools = isinstance(latest, AIMessage) and bool(latest.tool_calls)
    route = "continue_tools" if wants_tools else "end_conversation_turn"
    print(f"Routing: {route}")
    return route
# --- Graph Definition ---
def _build_workflow() -> StateGraph:
    """Assemble the agent <-> tools loop.

    Entry is "agent". After each agent step, ``should_continue`` routes to "tools"
    when tool calls were requested, otherwise to END. "tools" always hands control
    back to "agent".
    """
    graph = StateGraph(AgentState)
    graph.add_node("agent", agent_node)
    graph.add_node("tools", tool_node)
    graph.set_entry_point("agent")
    graph.add_conditional_edges(
        "agent",
        should_continue,
        {
            "continue_tools": "tools",
            "end_conversation_turn": END,
        },
    )
    graph.add_edge("tools", "agent")
    return graph


workflow = _build_workflow()
# Compiled once at import time; main() caches it in st.session_state.
# For state persistence across runs, compile with a checkpointer instead:
#     app = workflow.compile(checkpointer=MemorySaver())
app = workflow.compile()
# --- Helper Function to Format Patient Data ---
def format_patient_data_for_prompt(data: dict) -> str:
    """Format the patient dictionary into readable Markdown-ish text for the LLM.

    Dict sections become a bold ``**Title:**`` header followed by `` - Key: value``
    lines. Lists are joined with ", " (at the top level and inside sections).
    Scalars are rendered inline. Only genuinely empty values are omitted
    (``None``, empty strings and empty containers); numeric zeros such as
    ``age=0`` or ``pain_scale=0`` are kept. A section whose values are all
    empty is left out entirely.

    Args:
        data: patient intake dict; ``None`` or ``{}`` yields ``""``.

    Returns:
        The formatted text with surrounding whitespace stripped.
    """

    def _is_missing(value: object) -> bool:
        return value is None or (isinstance(value, (str, list, tuple, dict, set)) and not value)

    def _render(value: object) -> str:
        if isinstance(value, (list, tuple)):
            return ", ".join(map(str, value))
        return str(value)

    def _title(key: object) -> str:
        return str(key).replace('_', ' ').title()

    lines = []
    for key, value in (data or {}).items():
        if isinstance(value, dict):
            items = [f" - {_title(sub_key)}: {_render(sub_value)}"
                     for sub_key, sub_value in value.items() if not _is_missing(sub_value)]
            if items:
                lines.append(f"**{_title(key)}:**")
                lines.extend(items)
        elif not _is_missing(value):
            lines.append(f"**{_title(key)}:** {_render(value)}")
    return "\n".join(lines).strip()
# --- Streamlit UI (Conversational) ---
def _init_session_state() -> None:
    """Create the session-state keys the UI relies on (no-op for keys that already exist)."""
    defaults = {
        "messages": [],                  # full conversation history (Human, AI, Tool messages)
        "patient_data": None,            # structured intake data; None until the form is submitted
        "initial_analysis_done": False,
        "red_flags": [],                 # rule-based flags from the last intake submission
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
    if "graph_app" not in st.session_state:
        st.session_state.graph_app = app  # compiled LangGraph workflow


def _render_intake_sidebar() -> None:
    """Render the patient intake form; on submit, store the patient data and restart the chat."""
    with st.sidebar:
        st.header("📄 Patient Intake Form")
        # Demographics
        age = st.number_input("Age", min_value=0, max_value=120, value=55, key="age_input")
        sex = st.selectbox("Biological Sex", ["Male", "Female", "Other/Prefer not to say"], key="sex_input")
        # HPI
        chief_complaint = st.text_input("Chief Complaint", "Chest pain", key="cc_input")
        hpi_details = st.text_area("Detailed HPI", "55 y/o male presents with substernal chest pain started 2 hours ago...", key="hpi_input")
        symptoms = st.multiselect("Associated Symptoms", ["Nausea", "Diaphoresis", "Shortness of Breath", "Dizziness", "Palpitations", "Fever", "Cough"], default=["Nausea", "Diaphoresis"], key="sym_input")
        # History
        pmh = st.text_area("Past Medical History (PMH)", "Hypertension (HTN), Hyperlipidemia (HLD), Type 2 Diabetes Mellitus (DM2)", key="pmh_input")
        psh = st.text_area("Past Surgical History (PSH)", "Appendectomy (2005)", key="psh_input")
        # Meds & Allergies
        current_meds_str = st.text_area("Current Medications (name, dose, freq)", "Lisinopril 10mg daily\nMetformin 1000mg BID\nAtorvastatin 40mg daily\nAspirin 81mg daily", key="meds_input")
        allergies_str = st.text_area("Allergies (comma separated)", "Penicillin (rash)", key="allergy_input")
        # Social/Family
        social_history = st.text_area("Social History (SH)", "Smoker (1 ppd x 30 years), occasional alcohol.", key="sh_input")
        family_history = st.text_area("Family History (FHx)", "Father had MI at age 60. Mother has HTN.", key="fhx_input")
        # Vitals/Exam
        col1, col2 = st.columns(2)
        with col1:
            temp_c = st.number_input("Temp (°C)", 35.0, 42.0, 36.8, format="%.1f", key="temp_input")
            hr_bpm = st.number_input("HR (bpm)", 30, 250, 95, key="hr_input")
            rr_rpm = st.number_input("RR (rpm)", 5, 50, 18, key="rr_input")
        with col2:
            bp_mmhg = st.text_input("BP (SYS/DIA)", "155/90", key="bp_input")
            spo2_percent = st.number_input("SpO2 (%)", 70, 100, 96, key="spo2_input")
            pain_scale = st.slider("Pain (0-10)", 0, 10, 8, key="pain_input")
        exam_notes = st.text_area("Brief Physical Exam Notes", "Awake, alert, oriented x3...", key="exam_input")

        if st.button("Start/Update Consultation", key="start_button"):
            current_meds_list = [med.strip() for med in current_meds_str.split('\n') if med.strip()]
            # Medication "name" = the leading alphabetic token, lower-cased for matching
            # against the mock interaction database (still basic parsing).
            current_med_names = []
            for med in current_meds_list:
                match = re.match(r"^\s*([a-zA-Z\-]+)", med)
                if match:
                    current_med_names.append(match.group(1).lower())
            allergies_list = [a.strip().lower() for a in allergies_str.split(',') if a.strip()]

            st.session_state.patient_data = {
                "demographics": {"age": age, "sex": sex},
                "hpi": {"chief_complaint": chief_complaint, "details": hpi_details, "symptoms": symptoms},
                "pmh": {"conditions": pmh}, "psh": {"procedures": psh},
                "medications": {"current": current_meds_list, "names_only": current_med_names},
                "allergies": allergies_list,
                "social_history": {"details": social_history}, "family_history": {"details": family_history},
                "vitals": {"temp_c": temp_c, "hr_bpm": hr_bpm, "bp_mmhg": bp_mmhg, "rr_rpm": rr_rpm, "spo2_percent": spo2_percent, "pain_scale": pain_scale},
                "exam_findings": {"notes": exam_notes},
            }
            # Persist the initial (client-side) red flags: anything drawn in this run is
            # discarded by the st.rerun() below, so they are rendered from session state.
            st.session_state.red_flags = check_red_flags(st.session_state.patient_data)

            # Brief opening message; the full intake data lives in state and is attached by agent_node.
            initial_prompt = f"Analyze the following patient case:\nChief Complaint: {chief_complaint}\nSummary: {age} y/o {sex} presenting with..."
            st.session_state.messages = [HumanMessage(content=initial_prompt)]
            st.session_state.initial_analysis_done = False  # Reset analysis state
            st.success("Patient data loaded. Ready for analysis.")
            st.rerun()  # Refresh main area to show chat


def _render_structured_output(structured_output: dict) -> None:
    """Render the structured JSON assessment requested by the system prompt."""
    st.subheader("📊 AI Analysis & Recommendations")
    st.write(f"**Assessment:** {structured_output.get('assessment', 'N/A')}")

    ddx = structured_output.get("differential_diagnosis") or []
    if isinstance(ddx, list) and ddx:
        with st.expander("Differential Diagnosis", expanded=True):
            for item in ddx:
                if isinstance(item, dict):
                    st.markdown(f"- **{item.get('diagnosis', 'N/A')}** ({item.get('likelihood', 'N/A')}): {item.get('rationale', '')}")
                else:
                    st.markdown(f"- {item}")

    for key, title in (("risk_assessment", "Risk Assessment"), ("recommended_plan", "Recommended Plan")):
        section = structured_output.get(key)
        if section:
            with st.expander(title):
                # st.json expects structured data; plain text goes through st.write.
                if isinstance(section, (dict, list)):
                    st.json(section)
                else:
                    st.write(section)

    with st.expander("Rationale & Guideline Check"):
        st.write(structured_output.get("rationale_summary", "N/A"))
    if structured_output.get("interaction_check_summary"):
        with st.expander("Interaction Check"):
            st.write(structured_output.get("interaction_check_summary"))


def _render_ai_message(msg: AIMessage) -> None:
    """Render an AI turn: free text, an optional ```json assessment block, and requested tool calls."""
    with st.chat_message("assistant"):
        content = msg.content if isinstance(msg.content, str) else str(msg.content)
        try:
            structured_output = None
            json_match = re.search(r"```json\n(\{.*?\})\n```", content, re.DOTALL)
            if json_match:
                structured_output = json.loads(json_match.group(1))
                # Show any prose around the JSON block first.
                non_json_content = content.replace(json_match.group(0), "").strip()
                if non_json_content:
                    st.markdown(non_json_content)
                    st.divider()
            elif content.strip().startswith("{") and content.strip().endswith("}"):
                # The whole message may be bare JSON.
                structured_output = json.loads(content)
            else:
                st.markdown(content)
            if isinstance(structured_output, dict) and structured_output:
                _render_structured_output(structured_output)
        except json.JSONDecodeError:
            st.markdown(content)  # Display raw if JSON parsing fails
        if msg.tool_calls:
            with st.expander("🛠️ AI requested actions", expanded=False):
                for tc in msg.tool_calls:
                    st.code(f"{tc['name']}(args={tc['args']})", language="python")


def _render_tool_message(msg: ToolMessage) -> None:
    """Render a tool result; JSON objects with a ``status`` get colour-coded display."""
    with st.chat_message("tool", avatar="🛠️"):
        try:
            tool_data = json.loads(msg.content)
        except (json.JSONDecodeError, TypeError):
            tool_data = None
        if not isinstance(tool_data, dict):
            # Non-JSON output, or JSON that is not an object (e.g. "[]" from an empty search).
            st.info(f"Tool Result ({msg.name}): {msg.content}")
            return
        status = tool_data.get("status", "info")
        message = tool_data.get("message", msg.content)
        details = tool_data.get("details")
        warnings = tool_data.get("warnings")
        if status in ("success", "clear", "flagged"):
            st.success(f"Tool Result ({msg.name}): {message}", icon="✅" if status != "flagged" else "🚨")
        elif status == "warning":
            st.warning(f"Tool Result ({msg.name}): {message}", icon="⚠️")
            if warnings:
                for warn in warnings:
                    st.caption(f"- {warn}")
        else:  # Error or unknown status
            st.error(f"Tool Result ({msg.name}): {message}", icon="❌")
        if details:
            st.caption(f"Details: {details}")


def _run_consultation_turn(prompt: str) -> None:
    """Append the user's message, run the graph once (streamed), store the new messages, rerun."""
    st.session_state.messages.append(HumanMessage(content=prompt))
    with st.chat_message("user"):
        st.markdown(prompt)

    current_state = AgentState(
        messages=st.session_state.messages,
        patient_data=st.session_state.patient_data,
    )

    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        full_response = ""
        new_messages = []
        try:
            # Single execution: each streamed event maps node name -> that node's state update.
            # Because `messages` accumulates via operator.add, the final history is the input
            # history plus every node's messages in stream order. (Re-invoking the graph to
            # obtain the final state would repeat every LLM and tool call.)
            for event in st.session_state.graph_app.stream(current_state, {"recursion_limit": 15}):
                for node_name, update in event.items():
                    if node_name not in ("agent", "tools"):
                        continue  # e.g. a whole-state "__end__" snapshot in some langgraph versions
                    produced = (update or {}).get("messages", [])
                    new_messages.extend(produced)
                    if node_name == "agent":
                        for ai_msg in produced:
                            if isinstance(ai_msg, AIMessage) and ai_msg.content:
                                text = ai_msg.content if isinstance(ai_msg.content, str) else str(ai_msg.content)
                                full_response = f"{full_response}\n\n{text}" if full_response else text
                                message_placeholder.markdown(full_response + "▌")  # typing indicator
            message_placeholder.markdown(full_response)
            st.session_state.messages = list(st.session_state.messages) + new_messages
        except Exception as e:
            # Partial output is discarded: an AI tool call without its ToolMessage would
            # leave the history in an inconsistent state for the next turn.
            st.error(f"An error occurred during analysis: {e}")
            st.session_state.messages.append(AIMessage(content=f"Sorry, an error occurred: {e}"))
    # Rerun to display the updated chat history through the normal rendering path.
    st.rerun()


def main():
    """Streamlit entry point: intake sidebar, persisted red flags, chat history and chat input."""
    st.set_page_config(page_title=ClinicalAppSettings.APP_TITLE, layout=ClinicalAppSettings.PAGE_LAYOUT)
    st.title(f"🩺 {ClinicalAppSettings.APP_TITLE}")
    st.caption(f"Interactive Assistant | Powered by Langchain/LangGraph & Groq ({ClinicalAppSettings.MODEL_NAME})")

    _init_session_state()
    _render_intake_sidebar()

    # --- Main Chat Interface Area ---
    st.header("💬 Clinical Consultation")
    if st.session_state.red_flags:
        st.warning("**Initial Red Flags Detected:**")
        for flag in st.session_state.red_flags:
            st.warning(f"- {flag}")

    for msg in st.session_state.messages:
        if isinstance(msg, HumanMessage):
            with st.chat_message("user"):
                st.markdown(msg.content)
        elif isinstance(msg, AIMessage):
            _render_ai_message(msg)
        elif isinstance(msg, ToolMessage):
            _render_tool_message(msg)

    if prompt := st.chat_input("Your message or follow-up query..."):
        if not st.session_state.patient_data:
            st.warning("Please load patient data using the sidebar first.")
        else:
            _run_consultation_turn(prompt)

    # Disclaimer
    st.markdown("---")
    st.warning("**Disclaimer:** SynapseAI is for clinical decision support...")


if __name__ == "__main__":
    main()