Update app.py

app.py CHANGED

@@ -1,116 +1,117 @@
  import streamlit as st
  from langchain_groq import ChatGroq
  from langchain_community.tools.tavily_search import TavilySearchResults
- from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
  from langchain_core.prompts import ChatPromptTemplate
- from langchain_core.output_parsers import StrOutputParser
  from langchain_core.pydantic_v1 import BaseModel, Field
  from langchain_core.tools import tool
- from …
  import json
- import re

- # --- Configuration & Constants ---
  class ClinicalAppSettings:
-     APP_TITLE = "SynapseAI: …
      PAGE_LAYOUT = "wide"
-     MODEL_NAME = "llama3-70b-8192"
      TEMPERATURE = 0.1
      MAX_SEARCH_RESULTS = 3

  class ClinicalPrompts:
      SYSTEM_PROMPT = """
-     You are SynapseAI, an expert AI clinical assistant …
-     Your …

-     **Core Directives:**
-     1. ** …
-     2. ** …
      ```json
      {
-         "assessment": "Concise summary of the patient's presentation and key findings.",
          "differential_diagnosis": [
-             {"diagnosis": "Primary Diagnosis", "likelihood": "High/Medium/Low", "rationale": "Supporting evidence..."},
              {"diagnosis": "Alternative Diagnosis 1", "likelihood": "Medium/Low", "rationale": "Supporting/Refuting evidence..."},
              {"diagnosis": "Alternative Diagnosis 2", "likelihood": "Low", "rationale": "Why it's less likely but considered..."}
          ],
          "risk_assessment": {
-             "identified_red_flags": ["List any triggered red flags …"],
-             "immediate_concerns": ["Specific urgent issues …"],
-             "potential_complications": ["Possible future issues …"]
          },
          "recommended_plan": {
-             "investigations": ["List specific lab tests or imaging …"],
-             "therapeutics": ["Suggest specific treatments …"],
-             "consultations": ["Recommend specialist consultations …"],
              "patient_education": ["Key points for patient communication."]
          },
-         "rationale_summary": "…",
-         "interaction_check_summary": "Summary of findings from …"
      }
      ```
-     …
      """

- # --- Mock Data / Helpers ---
- # …
  MOCK_INTERACTION_DB = {
-     ("…
-     ("…
-     ("…
-     ("…
  }

  ALLERGY_INTERACTIONS = {
-     "…
-     "…
-     "…
  }

  def parse_bp(bp_string: str) -> Optional[tuple[int, int]]:
-     """Parses BP string like '120/80' into (systolic, diastolic) integers."""
      match = re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", bp_string)
-     if match:
-         return int(match.group(1)), int(match.group(2))
      return None

  def check_red_flags(patient_data: dict) -> List[str]:
-     """Checks patient data against predefined red flags."""
      flags = []
      symptoms = patient_data.get("hpi", {}).get("symptoms", [])
      vitals = patient_data.get("vitals", {})
      history = patient_data.get("pmh", {}).get("conditions", "")

-     if "…
-     if "shortness of breath" in [s.lower() for s in symptoms]: flags.append("Red Flag: Shortness of Breath reported.")
-     if "severe headache" in [s.lower() for s in symptoms]: flags.append("Red Flag: Severe Headache reported.")
-     if "sudden vision loss" in [s.lower() for s in symptoms]: flags.append("Red Flag: Sudden Vision Loss reported.")
-     if "weakness on one side" in [s.lower() for s in symptoms]: flags.append("Red Flag: Unilateral Weakness reported (potential stroke).")

-     # Vital Sign Flags (add more checks as needed)
-     if "temp_c" in vitals and vitals["temp_c"] >= 38.5: flags.append(f"Red Flag: Fever (Temperature: {vitals['temp_c']}°C).")
-     if "hr_bpm" in vitals and vitals["hr_bpm"] >= 120: flags.append(f"Red Flag: Tachycardia (Heart Rate: {vitals['hr_bpm']} bpm).")
-     if "rr_rpm" in vitals and vitals["rr_rpm"] >= 24: flags.append(f"Red Flag: Tachypnea (Respiratory Rate: {vitals['rr_rpm']} rpm).")
-     if "spo2_percent" in vitals and vitals["spo2_percent"] <= 92: flags.append(f"Red Flag: Hypoxia (SpO2: {vitals['spo2_percent']}%).")
      if "bp_mmhg" in vitals:
          bp = parse_bp(vitals["bp_mmhg"])
-         if bp:
-             …

-     # History Flags (Simple examples)
-     if "history of mi" in history.lower() and "chest pain" in [s.lower() for s in symptoms]: flags.append("Red Flag: History of MI with current Chest Pain.")

      return flags

- # --- Enhanced Tool Definitions ---

- # …
  class LabOrderInput(BaseModel):
      test_name: str = Field(..., description="Specific name of the lab test or panel (e.g., 'CBC', 'BMP', 'Troponin I', 'Urinalysis').")
      reason: str = Field(..., description="Clinical justification for ordering the test (e.g., 'Rule out infection', 'Assess renal function', 'Evaluate for ACS').")

@@ -119,11 +120,7 @@ class LabOrderInput(BaseModel):
@tool("order_lab_test", args_schema=LabOrderInput)
|
| 120 |
def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
|
| 121 |
"""Orders a specific lab test with clinical justification and priority."""
|
| 122 |
-
return json.dumps({
|
| 123 |
-
"status": "success",
|
| 124 |
-
"message": f"Lab Ordered: {test_name} ({priority})",
|
| 125 |
-
"details": f"Reason: {reason}"
|
| 126 |
-
})
|
| 127 |
|
| 128 |
  class PrescriptionInput(BaseModel):
      medication_name: str = Field(..., description="Name of the medication.")

@@ -136,12 +133,8 @@ class PrescriptionInput(BaseModel):
@tool("prescribe_medication", args_schema=PrescriptionInput)
|
| 137 |
def prescribe_medication(medication_name: str, dosage: str, route: str, frequency: str, duration: str, reason: str) -> str:
|
| 138 |
"""Prescribes a medication with detailed instructions and clinical indication."""
|
| 139 |
-
#
|
| 140 |
-
return json.dumps({
|
| 141 |
-
"status": "success",
|
| 142 |
-
"message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}",
|
| 143 |
-
"details": f"Duration: {duration}. Reason: {reason}"
|
| 144 |
-
})
|
| 145 |
|
| 146 |
  class InteractionCheckInput(BaseModel):
      potential_prescription: str = Field(..., description="The name of the NEW medication being considered.")

@@ -153,35 +146,34 @@ def check_drug_interactions(potential_prescription: str, current_medications: Li…
"""Checks for potential drug-drug and drug-allergy interactions BEFORE prescribing."""
|
| 154 |
warnings = []
|
| 155 |
potential_med_lower = potential_prescription.lower()
|
|
|
|
|
|
|
| 156 |
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
allergy_lower = allergy.lower()
|
| 160 |
-
# Simple direct check
|
| 161 |
-
if allergy_lower == potential_med_lower:
|
| 162 |
warnings.append(f"CRITICAL ALLERGY: Patient allergic to {allergy}. Cannot prescribe {potential_prescription}.")
|
| 163 |
continue
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
for cross_reactant in ALLERGY_INTERACTIONS[allergy_lower]:
|
| 167 |
if cross_reactant.lower() == potential_med_lower:
|
| 168 |
warnings.append(f"POTENTIAL CROSS-ALLERGY: Patient allergic to {allergy}. High risk with {potential_prescription}.")
|
| 169 |
|
| 170 |
-
# Check Drug-Drug Interactions (using simplified mock data)
|
| 171 |
-
current_meds_lower = [med.lower() for med in current_medications]
|
| 172 |
for current_med in current_meds_lower:
|
| 173 |
-
# Check pairs in both orders
|
| 174 |
pair1 = (current_med, potential_med_lower)
|
| 175 |
pair2 = (potential_med_lower, current_med)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
if pair1 in MOCK_INTERACTION_DB:
|
| 177 |
-
warnings.append(f"Interaction
|
| 178 |
elif pair2 in MOCK_INTERACTION_DB:
|
| 179 |
-
warnings.append(f"Interaction
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
|
| 181 |
-
if not warnings:
|
| 182 |
-
return json.dumps({"status": "clear", "message": f"No major interactions identified for {potential_prescription} with current meds/allergies.", "warnings": []})
|
| 183 |
-
else:
|
| 184 |
-
return json.dumps({"status": "warning", "message": f"Potential interactions identified for {potential_prescription}.", "warnings": warnings})
|
| 185 |
|
| 186 |
  class FlagRiskInput(BaseModel):
      risk_description: str = Field(..., description="Specific critical risk identified (e.g., 'Suspected Sepsis', 'Acute Coronary Syndrome', 'Stroke Alert').")

@@ -190,372 +182,415 @@ class FlagRiskInput(BaseModel):
@tool("flag_risk", args_schema=FlagRiskInput)
|
| 191 |
def flag_risk(risk_description: str, urgency: str) -> str:
|
| 192 |
"""Flags a critical risk identified during analysis for immediate attention."""
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
"message": f"Risk '{risk_description}' flagged with {urgency} urgency."
|
| 197 |
-
})
|
| 198 |
-
|
| 199 |
|
| 200 |
# Initialize Search Tool
|
| 201 |
-
search_tool = TavilySearchResults(max_results=ClinicalAppSettings.MAX_SEARCH_RESULTS)
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
      try:
-         # …
-         json_match = re.search(r"```json\n(\{.*?\})\n```", ai_response.content, re.DOTALL)
          if json_match:
-             …
          else:
-             # …
-             …
      except json.JSONDecodeError:
-         st.…
-         st.code(ai_response.content, language=None)  # Display raw if not JSON
-         response_content = {"assessment": ai_response.content, "error": "Output format incorrect"}

-     # Extract tool calls separately
-     if ai_response.tool_calls:
-         tool_calls = ai_response.tool_calls

-     …

-     if not selected_tool:
-         return json.dumps({"status": "error", "message": f"Unknown tool: {tool_name}"})

-     …
-         return selected_tool.invoke(tool_args)
-     except Exception as e:
-         st.error(f"Error executing tool '{tool_name}': {str(e)}")
-         return json.dumps({"status": "error", "message": f"Failed to execute {tool_name}: {str(e)}"})

| 302 |
-
# --- Streamlit UI ---
|
| 303 |
-
def main():
|
| 304 |
-
st.set_page_config(page_title=ClinicalAppSettings.APP_TITLE, layout=ClinicalAppSettings.PAGE_LAYOUT)
|
| 305 |
-
st.title(f"🩺 {ClinicalAppSettings.APP_TITLE}")
|
| 306 |
-
st.caption(f"Powered by Langchain & Groq ({ClinicalAppSettings.MODEL_NAME})")
|
| 307 |
-
|
| 308 |
-
# Initialize Agent in session state
|
| 309 |
-
if 'agent' not in st.session_state:
|
| 310 |
-
st.session_state.agent = ClinicalAgent()
|
| 311 |
-
if 'analysis_complete' not in st.session_state:
|
| 312 |
-
st.session_state.analysis_complete = False
|
| 313 |
-
if 'analysis_result' not in st.session_state:
|
| 314 |
-
st.session_state.analysis_result = None
|
| 315 |
-
if 'tool_call_results' not in st.session_state:
|
| 316 |
-
st.session_state.tool_call_results = []
|
| 317 |
-
if 'red_flags' not in st.session_state:
|
| 318 |
-
st.session_state.red_flags = []
|
| 319 |
-
|
| 320 |
-
# --- Patient Data Input Sidebar ---
|
| 321 |
-
with st.sidebar:
|
| 322 |
-
st.header("📄 Patient Intake Form")
|
| 323 |
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
family_history = st.text_area("Family History (FHx)", "Father had MI at age 60. Mother has HTN.")
|
| 349 |
-
|
| 350 |
-
# Review of Systems (ROS) - Simplified
|
| 351 |
-
# st.subheader("Review of Systems (ROS)") # Keep UI cleaner for now
|
| 352 |
-
# ros_constitutional = st.checkbox("ROS: Constitutional (Fever, Chills, Weight loss)")
|
| 353 |
-
# ros_cardiac = st.checkbox("ROS: Cardiac (Chest pain, Palpitations)", value=True) # Pre-check based on HPI
|
| 354 |
-
|
| 355 |
-
# Vitals & Basic Exam
|
| 356 |
-
st.subheader("Vitals & Exam Findings")
|
| 357 |
-
col1, col2 = st.columns(2)
|
| 358 |
-
with col1:
|
| 359 |
-
temp_c = st.number_input("Temperature (°C)", 35.0, 42.0, 36.8, format="%.1f")
|
| 360 |
-
hr_bpm = st.number_input("Heart Rate (bpm)", 30, 250, 95)
|
| 361 |
-
rr_rpm = st.number_input("Respiratory Rate (rpm)", 5, 50, 18)
|
| 362 |
-
with col2:
|
| 363 |
-
bp_mmhg = st.text_input("Blood Pressure (SYS/DIA)", "155/90")
|
| 364 |
-
spo2_percent = st.number_input("SpO2 (%)", 70, 100, 96)
|
| 365 |
-
pain_scale = st.slider("Pain (0-10)", 0, 10, 8)
|
| 366 |
-
exam_notes = st.text_area("Brief Physical Exam Notes", "Awake, alert, oriented x3. Mild distress. Lungs clear. Cardiac exam: Regular rhythm, no murmurs/gallops. Abdomen soft. No edema.")
|
| 367 |
-
|
| 368 |
-
# Clean medication list and allergies for processing
|
| 369 |
-
current_meds_list = [med.strip() for med in current_meds.split('\n') if med.strip()]
|
| 370 |
-
current_med_names = [med.split(' ')[0].strip() for med in current_meds_list] # Simplified name extraction
|
| 371 |
-
allergies_list = [a.strip() for a in allergies.split(',') if a.strip()]
|
| 372 |
-
|
| 373 |
-
# Compile Patient Data Dictionary
|
| 374 |
-
patient_data = {
|
| 375 |
-
"demographics": {"age": age, "sex": sex},
|
| 376 |
-
"hpi": {"chief_complaint": chief_complaint, "details": hpi_details, "symptoms": symptoms},
|
| 377 |
-
"pmh": {"conditions": pmh},
|
| 378 |
-
"psh": {"procedures": psh},
|
| 379 |
-
"medications": {"current": current_meds_list, "names_only": current_med_names},
|
| 380 |
-
"allergies": allergies_list,
|
| 381 |
-
"social_history": {"details": social_history},
|
| 382 |
-
"family_history": {"details": family_history},
|
| 383 |
-
# "ros": {"constitutional": ros_constitutional, "cardiac": ros_cardiac}, # Add if using ROS inputs
|
| 384 |
-
"vitals": {
|
| 385 |
-
"temp_c": temp_c, "hr_bpm": hr_bpm, "bp_mmhg": bp_mmhg,
|
| 386 |
-
"rr_rpm": rr_rpm, "spo2_percent": spo2_percent, "pain_scale": pain_scale
|
| 387 |
-
},
|
| 388 |
-
"exam_findings": {"notes": exam_notes}
|
| 389 |
-
}
|
| 390 |
|
| 391 |
-
# --- Main Analysis Area ---
|
| 392 |
-
st.header("🤖 AI Clinical Analysis")
|
| 393 |
-
|
| 394 |
-
# Action Button
|
| 395 |
-
if st.button("Analyze Patient Data", type="primary", use_container_width=True):
|
| 396 |
-
st.session_state.analysis_complete = False
|
| 397 |
-
st.session_state.analysis_result = None
|
| 398 |
-
st.session_state.tool_call_results = []
|
| 399 |
-
st.session_state.red_flags = []
|
| 400 |
-
|
| 401 |
-
# 1. Initial Red Flag Check (Client-side before LLM)
|
| 402 |
-
st.session_state.red_flags = check_red_flags(patient_data)
|
| 403 |
-
if st.session_state.red_flags:
|
| 404 |
-
st.warning("**Initial Red Flags Detected:**")
|
| 405 |
-
for flag in st.session_state.red_flags:
|
| 406 |
-
st.warning(f"- {flag}")
|
| 407 |
-
st.warning("Proceeding with AI analysis, but these require immediate attention.")
|
| 408 |
-
|
| 409 |
-
# 2. Call AI Agent
|
| 410 |
-
with st.spinner("SynapseAI is processing the case... Please wait."):
|
| 411 |
-
analysis_output, tool_calls = st.session_state.agent.analyze(patient_data)
|
| 412 |
-
|
| 413 |
-
if analysis_output:
|
| 414 |
-
st.session_state.analysis_result = analysis_output
|
| 415 |
-
st.session_state.analysis_complete = True
|
| 416 |
-
|
| 417 |
-
# 3. Process any Tool Calls requested by the AI
|
| 418 |
-
if tool_calls:
|
| 419 |
-
st.info(f"AI recommended {len(tool_calls)} action(s). Executing...")
|
| 420 |
-
tool_results = []
|
| 421 |
-
with st.spinner("Executing recommended actions..."):
|
| 422 |
-
for call in tool_calls:
|
| 423 |
-
st.write(f"⚙️ Requesting: `{call['name']}` with args `{call['args']}`")
|
| 424 |
-
# Pass patient context if needed (e.g., for interaction check)
|
| 425 |
-
if call['name'] == 'check_drug_interactions':
|
| 426 |
-
call['args']['current_medications'] = patient_data['medications']['names_only']
|
| 427 |
-
call['args']['allergies'] = patient_data['allergies']
|
| 428 |
-
elif call['name'] == 'prescribe_medication':
|
| 429 |
-
# Pre-flight check: Ensure interaction check was requested *before* this prescribe call
|
| 430 |
-
interaction_check_requested = any(tc['name'] == 'check_drug_interactions' and tc['args'].get('potential_prescription') == call['args'].get('medication_name') for tc in tool_calls)
|
| 431 |
-
if not interaction_check_requested:
|
| 432 |
-
st.error(f"**Safety Violation:** AI attempted to prescribe '{call['args'].get('medication_name')}' without requesting `check_drug_interactions` first. Prescription blocked.")
|
| 433 |
-
tool_results.append({"tool_call_id": call['id'], "name": call['name'], "output": json.dumps({"status":"error", "message": "Interaction check not performed prior to prescription attempt."})})
|
| 434 |
-
continue # Skip this tool call
|
| 435 |
-
|
| 436 |
-
result = st.session_state.agent.process_tool_call(call)
|
| 437 |
-
tool_results.append({"tool_call_id": call['id'], "name": call['name'], "output": result}) # Store result with ID
|
| 438 |
-
|
| 439 |
-
# Display tool result immediately
|
| 440 |
-
try:
|
| 441 |
-
result_data = json.loads(result)
|
| 442 |
-
if result_data.get("status") == "success" or result_data.get("status") == "clear" or result_data.get("status") == "flagged":
|
| 443 |
-
st.success(f"✅ Action `{call['name']}`: {result_data.get('message')}", icon="✅")
|
| 444 |
-
if result_data.get("details"): st.caption(f"Details: {result_data.get('details')}")
|
| 445 |
-
elif result_data.get("status") == "warning":
|
| 446 |
-
st.warning(f"⚠️ Action `{call['name']}`: {result_data.get('message')}", icon="⚠️")
|
| 447 |
-
if result_data.get("warnings"):
|
| 448 |
-
for warn in result_data["warnings"]: st.caption(f"- {warn}")
|
| 449 |
-
else:
|
| 450 |
-
st.error(f"❌ Action `{call['name']}`: {result_data.get('message')}", icon="❌")
|
| 451 |
-
except json.JSONDecodeError:
|
| 452 |
-
st.error(f"Tool `{call['name']}` returned non-JSON: {result}") # Fallback for non-JSON results
|
| 453 |
-
|
| 454 |
-
st.session_state.tool_call_results = tool_results
|
| 455 |
-
# Optionally: Send results back to LLM for final summary (requires multi-turn agent)
|
| 456 |
-
else:
|
| 457 |
-
st.error("Analysis failed. Please check the input data or try again.")
|
| 458 |
-
|
| 459 |
-
# --- Display Analysis Results ---
|
| 460 |
-
if st.session_state.analysis_complete and st.session_state.analysis_result:
|
| 461 |
-
st.divider()
|
| 462 |
-
st.header("📊 Analysis & Recommendations")
|
| 463 |
-
|
| 464 |
-
res = st.session_state.analysis_result
|
| 465 |
-
|
| 466 |
-
# Layout columns for better readability
|
| 467 |
-
col_assessment, col_plan = st.columns(2)
|
| 468 |
-
|
| 469 |
-
with col_assessment:
|
| 470 |
-
st.subheader("📋 Assessment")
|
| 471 |
-
st.write(res.get("assessment", "N/A"))
|
| 472 |
-
|
| 473 |
-
st.subheader("🤔 Differential Diagnosis")
|
| 474 |
-
ddx = res.get("differential_diagnosis", [])
|
| 475 |
-
if ddx:
|
| 476 |
-
for item in ddx:
|
| 477 |
-
likelihood = item.get('likelihood', 'Unknown').capitalize()
|
| 478 |
-
icon = "🥇" if likelihood=="High" else ("🥈" if likelihood=="Medium" else "🥉")
|
| 479 |
-
with st.expander(f"{icon} {item.get('diagnosis', 'Unknown Diagnosis')} ({likelihood} Likelihood)", expanded=(likelihood=="High")):
|
| 480 |
-
st.write(f"**Rationale:** {item.get('rationale', 'N/A')}")
|
| 481 |
-
else:
|
| 482 |
-
st.info("No differential diagnosis provided.")
|
| 483 |
-
|
| 484 |
-
st.subheader("🚨 Risk Assessment")
|
| 485 |
-
risk = res.get("risk_assessment", {})
|
| 486 |
-
flags = risk.get("identified_red_flags", []) + [f.replace("Red Flag: ", "") for f in st.session_state.red_flags] # Combine AI and initial flags
|
| 487 |
-
if flags:
|
| 488 |
-
st.warning(f"**Identified Red Flags:** {', '.join(flags)}")
|
| 489 |
-
else:
|
| 490 |
-
st.success("No immediate red flags identified by AI in this analysis.")
|
| 491 |
-
|
| 492 |
-
if risk.get("immediate_concerns"):
|
| 493 |
-
st.warning(f"**Immediate Concerns:** {', '.join(risk.get('immediate_concerns'))}")
|
| 494 |
-
if risk.get("potential_complications"):
|
| 495 |
-
st.info(f"**Potential Complications:** {', '.join(risk.get('potential_complications'))}")
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
with col_plan:
|
| 499 |
-
st.subheader("📝 Recommended Plan")
|
| 500 |
-
plan = res.get("recommended_plan", {})
|
| 501 |
-
|
| 502 |
-
st.markdown("**Investigations:**")
|
| 503 |
-
if plan.get("investigations"):
|
| 504 |
-
st.markdown("\n".join([f"- {inv}" for inv in plan.get("investigations")]))
|
| 505 |
-
else: st.markdown("_None suggested._")
|
| 506 |
-
|
| 507 |
-
st.markdown("**Therapeutics:**")
|
| 508 |
-
if plan.get("therapeutics"):
|
| 509 |
-
st.markdown("\n".join([f"- {thx}" for thx in plan.get("therapeutics")]))
|
| 510 |
-
else: st.markdown("_None suggested._")
|
| 511 |
-
|
| 512 |
-
st.markdown("**Consultations:**")
|
| 513 |
-
if plan.get("consultations"):
|
| 514 |
-
st.markdown("\n".join([f"- {con}" for con in plan.get("consultations")]))
|
| 515 |
-
else: st.markdown("_None suggested._")
|
| 516 |
-
|
| 517 |
-
st.markdown("**Patient Education:**")
|
| 518 |
-
if plan.get("patient_education"):
|
| 519 |
-
st.markdown("\n".join([f"- {edu}" for edu in plan.get("patient_education")]))
|
| 520 |
-
else: st.markdown("_None specified._")
|
| 521 |
-
|
| 522 |
-
# Display Rationale and Interaction Summary below the columns
|
| 523 |
-
st.subheader("🧠 AI Rationale & Checks")
|
| 524 |
-
with st.expander("Show AI Reasoning Summary", expanded=False):
|
| 525 |
-
st.write(res.get("rationale_summary", "No rationale summary provided."))
|
| 526 |
-
|
| 527 |
-
interaction_summary = res.get("interaction_check_summary", "")
|
| 528 |
-
if interaction_summary: # Only show if interaction check was relevant/performed
|
| 529 |
-
with st.expander("Drug Interaction Check Summary", expanded=True):
|
| 530 |
-
st.write(interaction_summary)
|
| 531 |
-
# Also show detailed results from the tool call itself if available
|
| 532 |
-
for tool_res in st.session_state.tool_call_results:
|
| 533 |
-
if tool_res['name'] == 'check_drug_interactions':
|
| 534 |
-
try:
|
| 535 |
-
data = json.loads(tool_res['output'])
|
| 536 |
-
if data.get('warnings'):
|
| 537 |
-
st.warning("Interaction Details:")
|
| 538 |
-
for warn in data['warnings']:
|
| 539 |
-
st.caption(f"- {warn}")
|
| 540 |
-
else:
|
| 541 |
-
st.success("Interaction Details: " + data.get('message', 'Check complete.'))
|
| 542 |
-
except: pass # Ignore parsing errors here
|
| 543 |
-
|
| 544 |
-
# Display raw JSON if needed for debugging
|
| 545 |
-
with st.expander("Show Raw AI Output (JSON)"):
|
| 546 |
-
st.json(res)
|
| 547 |
-
|
| 548 |
-
st.divider()
|
| 549 |
-
st.success("Analysis Complete.")
|
| 550 |
|
| 551 |
# Disclaimer
|
| 552 |
st.markdown("---")
|
| 553 |
-
st.warning(
|
| 554 |
-
"""**Disclaimer:** SynapseAI is an AI assistant for clinical decision support and does not replace professional medical judgment.
|
| 555 |
-
All outputs should be critically reviewed by a qualified healthcare provider before making any clinical decisions.
|
| 556 |
-
Verify all information, especially dosages and interactions, independently."""
|
| 557 |
-
)
|
| 558 |
-
|
| 559 |
|
| 560 |
if __name__ == "__main__":
|
| 561 |
main()

  import streamlit as st
  from langchain_groq import ChatGroq
  from langchain_community.tools.tavily_search import TavilySearchResults
+ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage
  from langchain_core.prompts import ChatPromptTemplate
  from langchain_core.pydantic_v1 import BaseModel, Field
  from langchain_core.tools import tool
+ from langgraph.prebuilt import ToolExecutor
+ from langgraph.graph import StateGraph, END
+ from langgraph.checkpoint.memory import MemorySaver  # For state persistence (optional but good)
+
+ from typing import Optional, List, Dict, Any, TypedDict, Annotated
  import json
+ import re
+ import operator

+ # --- Configuration & Constants --- (Keep previous ones like ClinicalAppSettings)
  class ClinicalAppSettings:
+     APP_TITLE = "SynapseAI: Interactive Clinical Decision Support"
      PAGE_LAYOUT = "wide"
+     MODEL_NAME = "llama3-70b-8192"
      TEMPERATURE = 0.1
      MAX_SEARCH_RESULTS = 3

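Not spelled out in the diff: ChatGroq and TavilySearchResults pick up their credentials from the GROQ_API_KEY and TAVILY_API_KEY environment variables. A minimal sketch of the configuration this file assumes; reading them from Streamlit secrets is only one option, not something the commit prescribes:

```python
import os
import streamlit as st

# GROQ_API_KEY is read by langchain_groq.ChatGroq, TAVILY_API_KEY by TavilySearchResults.
# Any mechanism that sets these variables before the model/tool objects are created works.
os.environ.setdefault("GROQ_API_KEY", st.secrets.get("GROQ_API_KEY", ""))
os.environ.setdefault("TAVILY_API_KEY", st.secrets.get("TAVILY_API_KEY", ""))
```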
  class ClinicalPrompts:
+     # UPDATED SYSTEM PROMPT FOR CONVERSATIONAL FLOW & GUIDELINES
      SYSTEM_PROMPT = """
+     You are SynapseAI, an expert AI clinical assistant engaged in an interactive consultation.
+     Your goal is to support healthcare professionals by analyzing patient data, providing differential diagnoses, suggesting evidence-based management plans, and identifying risks according to current standards of care.

+     **Core Directives for this Conversation:**
+     1. **Analyze Sequentially:** Process information turn-by-turn. You will receive initial patient data, and potentially follow-up messages or results from tools you requested. Base your responses on the *entire* conversation history.
+     2. **Seek Clarity:** If the provided information is insufficient or ambiguous for a safe assessment, CLEARLY STATE what specific additional information or clarification is needed. Do NOT guess or make unsafe assumptions.
+     3. **Structured Assessment (When Ready):** When you have sufficient information and have performed necessary checks (like interactions), provide a comprehensive assessment using the following JSON structure. Only output this structure when you believe you have a complete initial analysis or plan. Do NOT output incomplete JSON.
      ```json
      {
+         "assessment": "Concise summary of the patient's presentation and key findings based on the conversation.",
          "differential_diagnosis": [
+             {"diagnosis": "Primary Diagnosis", "likelihood": "High/Medium/Low", "rationale": "Supporting evidence from conversation..."},
              {"diagnosis": "Alternative Diagnosis 1", "likelihood": "Medium/Low", "rationale": "Supporting/Refuting evidence..."},
              {"diagnosis": "Alternative Diagnosis 2", "likelihood": "Low", "rationale": "Why it's less likely but considered..."}
          ],
          "risk_assessment": {
+             "identified_red_flags": ["List any triggered red flags"],
+             "immediate_concerns": ["Specific urgent issues (e.g., sepsis risk, ACS rule-out)"],
+             "potential_complications": ["Possible future issues"]
          },
          "recommended_plan": {
+             "investigations": ["List specific lab tests or imaging needed. Use 'order_lab_test' tool."],
+             "therapeutics": ["Suggest specific treatments/prescriptions. Use 'prescribe_medication' tool. MUST check interactions first."],
+             "consultations": ["Recommend specialist consultations."],
              "patient_education": ["Key points for patient communication."]
          },
+         "rationale_summary": "Justification for assessment/plan. **Crucially, if relevant (e.g., ACS, sepsis, common infections), use 'tavily_search_results' to find and cite current clinical practice guidelines (e.g., 'latest ACC/AHA chest pain guidelines 202X', 'Surviving Sepsis Campaign guidelines') supporting your recommendations.**",
+         "interaction_check_summary": "Summary of findings from 'check_drug_interactions' if performed."
      }
      ```
+     4. **Safety First - Interactions:** BEFORE suggesting a new prescription via `prescribe_medication`, you MUST FIRST use `check_drug_interactions`. Report the findings. If interactions exist, modify the plan or state the contraindication.
+     5. **Safety First - Red Flags:** Use the `flag_risk` tool IMMEDIATELY if critical red flags requiring urgent action are identified at any point.
+     6. **Tool Use:** Employ tools (`order_lab_test`, `prescribe_medication`, `check_drug_interactions`, `flag_risk`, `tavily_search_results`) logically within the conversational flow. Wait for tool results before proceeding if the result is needed for the next step (e.g., wait for interaction check before confirming prescription).
+     7. **Evidence & Guidelines:** Actively use `tavily_search_results` not just for general knowledge, but specifically to query for and incorporate **current clinical practice guidelines** relevant to the patient's presentation (e.g., chest pain, shortness of breath, suspected infection). Summarize findings in the `rationale_summary` when providing the structured output.
+     8. **Conciseness:** Be medically accurate and concise. Use standard terminology. Respond naturally in conversation until ready for the full structured JSON output.
      """

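Directive 3 has the model wrap its structured assessment in a ```json fence; the chat-rendering code further down extracts it with a regex. A small illustrative sketch of that round trip (the sample reply string is made up):

```python
import json
import re

sample_reply = 'Preliminary thoughts...\n```json\n{"assessment": "Chest pain, rule out ACS"}\n```'

match = re.search(r"```json\n(\{.*?\})\n```", sample_reply, re.DOTALL)
if match:
    structured = json.loads(match.group(1))  # dict with the keys defined in the prompt
    print(structured["assessment"])
```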
+ # --- Mock Data / Helpers --- (Keep previous ones like MOCK_INTERACTION_DB, ALLERGY_INTERACTIONS, parse_bp, check_red_flags)
+ # (Include the helper functions from the previous response here)
  MOCK_INTERACTION_DB = {
+     ("lisinopril", "spironolactone"): "High risk of hyperkalemia. Monitor potassium closely.",
+     ("warfarin", "amiodarone"): "Increased bleeding risk. Monitor INR frequently and adjust Warfarin dose.",
+     ("simvastatin", "clarithromycin"): "Increased risk of myopathy/rhabdomyolysis. Avoid combination or use lower statin dose.",
+     ("aspirin", "ibuprofen"): "Concurrent use may decrease Aspirin's cardioprotective effect. Potential for increased GI bleeding."
  }

  ALLERGY_INTERACTIONS = {
+     "penicillin": ["amoxicillin", "ampicillin", "piperacillin"],
+     "sulfa": ["sulfamethoxazole", "sulfasalazine"],
+     "aspirin": ["ibuprofen", "naproxen"]  # Cross-reactivity example for NSAIDs
  }

  def parse_bp(bp_string: str) -> Optional[tuple[int, int]]:
      match = re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", bp_string)
+     if match: return int(match.group(1)), int(match.group(2))
      return None

  def check_red_flags(patient_data: dict) -> List[str]:
      flags = []
      symptoms = patient_data.get("hpi", {}).get("symptoms", [])
      vitals = patient_data.get("vitals", {})
      history = patient_data.get("pmh", {}).get("conditions", "")
+     symptoms_lower = [s.lower() for s in symptoms]
+
+     if "chest pain" in symptoms_lower: flags.append("Red Flag: Chest Pain reported.")
+     if "shortness of breath" in symptoms_lower: flags.append("Red Flag: Shortness of Breath reported.")
+     if "severe headache" in symptoms_lower: flags.append("Red Flag: Severe Headache reported.")
+     # Add other symptom checks...

+     if "temp_c" in vitals and vitals["temp_c"] >= 38.5: flags.append(f"Red Flag: Fever ({vitals['temp_c']}°C).")
+     if "hr_bpm" in vitals and vitals["hr_bpm"] >= 120: flags.append(f"Red Flag: Tachycardia ({vitals['hr_bpm']} bpm).")
      if "bp_mmhg" in vitals:
          bp = parse_bp(vitals["bp_mmhg"])
+         if bp and (bp[0] >= 180 or bp[1] >= 110): flags.append(f"Red Flag: Hypertensive Urgency/Emergency (BP: {vitals['bp_mmhg']} mmHg).")
+         if bp and (bp[0] <= 90 or bp[1] <= 60): flags.append(f"Red Flag: Hypotension (BP: {vitals['bp_mmhg']} mmHg).")
+         # Add other vital checks...

+     if "history of mi" in history.lower() and "chest pain" in symptoms_lower: flags.append("Red Flag: History of MI with current Chest Pain.")
+     # Add other history checks...
      return flags

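A quick sketch of how these helpers behave on a hand-built patient dictionary (the values are illustrative only):

```python
example_patient = {
    "hpi": {"symptoms": ["Chest pain", "Nausea"]},
    "vitals": {"temp_c": 37.0, "hr_bpm": 125, "bp_mmhg": "185/100"},
    "pmh": {"conditions": "History of MI"},
}

print(parse_bp("185/100"))            # -> (185, 100)
for flag in check_red_flags(example_patient):
    print(flag)                       # chest pain, tachycardia, hypertensive urgency, MI + chest pain flags
```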
+ # --- Enhanced Tool Definitions --- (Keep previous Pydantic models and @tool functions)
+ # (Include LabOrderInput, PrescriptionInput, InteractionCheckInput, FlagRiskInput
+ # and the corresponding @tool functions: order_lab_test, prescribe_medication,
+ # check_drug_interactions, flag_risk from the previous response here)
+
  class LabOrderInput(BaseModel):
      test_name: str = Field(..., description="Specific name of the lab test or panel (e.g., 'CBC', 'BMP', 'Troponin I', 'Urinalysis').")
      reason: str = Field(..., description="Clinical justification for ordering the test (e.g., 'Rule out infection', 'Assess renal function', 'Evaluate for ACS').")
  …
  @tool("order_lab_test", args_schema=LabOrderInput)
  def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
      """Orders a specific lab test with clinical justification and priority."""
+     return json.dumps({"status": "success", "message": f"Lab Ordered: {test_name} ({priority})", "details": f"Reason: {reason}"})

  class PrescriptionInput(BaseModel):
      medication_name: str = Field(..., description="Name of the medication.")
  …
  @tool("prescribe_medication", args_schema=PrescriptionInput)
  def prescribe_medication(medication_name: str, dosage: str, route: str, frequency: str, duration: str, reason: str) -> str:
      """Prescribes a medication with detailed instructions and clinical indication."""
+     # NOTE: Interaction check should have been done *before* calling this via a separate tool call
+     return json.dumps({"status": "success", "message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}", "details": f"Duration: {duration}. Reason: {reason}"})

  class InteractionCheckInput(BaseModel):
      potential_prescription: str = Field(..., description="The name of the NEW medication being considered.")
  …
      """Checks for potential drug-drug and drug-allergy interactions BEFORE prescribing."""
      warnings = []
      potential_med_lower = potential_prescription.lower()
+     current_meds_lower = [med.lower() for med in current_medications]
+     allergies_lower = [a.lower() for a in allergies]

+     for allergy in allergies_lower:
+         if allergy == potential_med_lower:
              warnings.append(f"CRITICAL ALLERGY: Patient allergic to {allergy}. Cannot prescribe {potential_prescription}.")
              continue
+         if allergy in ALLERGY_INTERACTIONS:
+             for cross_reactant in ALLERGY_INTERACTIONS[allergy]:
                  if cross_reactant.lower() == potential_med_lower:
                      warnings.append(f"POTENTIAL CROSS-ALLERGY: Patient allergic to {allergy}. High risk with {potential_prescription}.")

      for current_med in current_meds_lower:
          pair1 = (current_med, potential_med_lower)
          pair2 = (potential_med_lower, current_med)
+         # Normalize keys for lookup if necessary (e.g., if DB keys are canonical names)
+         key1 = tuple(sorted(pair1))
+         key2 = tuple(sorted(pair2))  # Although redundant if always sorted
+
          if pair1 in MOCK_INTERACTION_DB:
+             warnings.append(f"Interaction: {potential_prescription.capitalize()} with {current_med.capitalize()} - {MOCK_INTERACTION_DB[pair1]}")
          elif pair2 in MOCK_INTERACTION_DB:
+             warnings.append(f"Interaction: {potential_prescription.capitalize()} with {current_med.capitalize()} - {MOCK_INTERACTION_DB[pair2]}")
+
+     status = "warning" if warnings else "clear"
+     message = f"Interaction check for {potential_prescription}: {len(warnings)} potential issue(s) found." if warnings else f"No major interactions identified for {potential_prescription}."
+     return json.dumps({"status": status, "message": message, "warnings": warnings})

  class FlagRiskInput(BaseModel):
      risk_description: str = Field(..., description="Specific critical risk identified (e.g., 'Suspected Sepsis', 'Acute Coronary Syndrome', 'Stroke Alert').")
  …
  @tool("flag_risk", args_schema=FlagRiskInput)
  def flag_risk(risk_description: str, urgency: str) -> str:
      """Flags a critical risk identified during analysis for immediate attention."""
+     # Display in Streamlit immediately
+     st.error(f"🚨 **{urgency.upper()} RISK FLAGGED by AI:** {risk_description}", icon="🚨")
+     return json.dumps({"status": "flagged", "message": f"Risk '{risk_description}' flagged with {urgency} urgency."})

  # Initialize Search Tool
+ search_tool = TavilySearchResults(max_results=ClinicalAppSettings.MAX_SEARCH_RESULTS, name="tavily_search_results")

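Since every action is a LangChain @tool, it can also be exercised directly outside the graph, which is convenient for sanity-checking the mock interaction data. A hedged sketch (the medication and allergy values are illustrative):

```python
result = check_drug_interactions.invoke({
    "potential_prescription": "Amoxicillin",
    "current_medications": ["lisinopril 10mg daily"],
    "allergies": ["penicillin"],
})
print(json.loads(result)["warnings"])
# -> ['POTENTIAL CROSS-ALLERGY: Patient allergic to penicillin. High risk with Amoxicillin.']
```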
+ # --- LangGraph Setup ---
+
+ # Define the state structure
+ class AgentState(TypedDict):
+     messages: Annotated[list[Any], operator.add]  # Accumulates messages (Human, AI, Tool)
+     patient_data: Optional[dict]  # Holds the structured patient data (can be updated if needed)
+     # Potentially add other state elements like 'interaction_check_needed_for': Optional[str]
+
+ # Define Tools and Tool Executor
+ tools = [
+     order_lab_test,
+     prescribe_medication,
+     check_drug_interactions,
+     flag_risk,
+     search_tool
+ ]
+ tool_executor = ToolExecutor(tools)
+
+ # Define the Agent Model
+ model = ChatGroq(
+     temperature=ClinicalAppSettings.TEMPERATURE,
+     model=ClinicalAppSettings.MODEL_NAME
+ )
+ model_with_tools = model.bind_tools(tools)  # Bind tools for the LLM to know about them

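The Annotated[..., operator.add] reducer is what lets each node return only its new messages: LangGraph concatenates them onto the existing list rather than replacing it. A tiny sketch of that accumulation, shown outside the graph just to illustrate the reducer semantics:

```python
state: AgentState = {"messages": [HumanMessage(content="Analyze this case...")], "patient_data": None}

# A node returning {"messages": [ai_reply]} is merged by the reducer, i.e. effectively:
ai_reply = AIMessage(content="I need the latest vitals before I can assess further.")
state["messages"] = operator.add(state["messages"], [ai_reply])
assert len(state["messages"]) == 2
```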
+ # --- Graph Nodes ---
+
+ # 1. Agent Node: Calls the LLM
+ def agent_node(state: AgentState):
+     """Invokes the LLM to decide the next action or response."""
+     print("---AGENT NODE---")
+     # Make sure patient data is included in the first message if not already there
+     # This is a basic way; more robust would be merging patient_data into context
+     current_messages = state['messages']
+     if len(current_messages) == 1 and isinstance(current_messages[0], HumanMessage) and state.get('patient_data'):
+         # Augment the first human message with formatted patient data
+         formatted_data = format_patient_data_for_prompt(state['patient_data'])  # Need this helper function
+         current_messages = [
+             SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT),  # Ensure system prompt is first
+             HumanMessage(content=f"{current_messages[0].content}\n\n**Initial Patient Data:**\n{formatted_data}")
+         ]
+     elif not any(isinstance(m, SystemMessage) for m in current_messages):
+         # Add system prompt if missing
+         current_messages = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + current_messages
+
+     response = model_with_tools.invoke(current_messages)
+     print(f"Agent response: {response}")
+     return {"messages": [response]}

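What model_with_tools.invoke returns is a single AIMessage; when the model wants an action, that message carries tool_calls rather than (or alongside) plain text, which is exactly what the routing function below keys on. A small sketch with a made-up prompt:

```python
reply = model_with_tools.invoke([
    SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT),
    HumanMessage(content="55 y/o male, chest pain, on lisinopril. Needs troponin."),
])
if reply.tool_calls:                      # e.g. [{"name": "order_lab_test", "args": {...}, "id": "..."}]
    print([tc["name"] for tc in reply.tool_calls])
else:
    print(reply.content)                  # plain conversational turn
```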
+ # 2. Tool Node: Executes tools called by the Agent
+ def tool_node(state: AgentState):
+     """Executes tools called by the LLM and returns results."""
+     print("---TOOL NODE---")
+     last_message = state['messages'][-1]
+     if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
+         print("No tool calls in last message.")
+         return {}  # Should not happen if routing is correct, but safety check
+
+     tool_calls = last_message.tool_calls
+     tool_messages = []
+
+     # Safety Check: Ensure interaction check happens *before* prescribing the *same* drug
+     prescribe_calls = {call['args'].get('medication_name'): call['id'] for call in tool_calls if call['name'] == 'prescribe_medication'}
+     interaction_check_calls = {call['args'].get('potential_prescription'): call['id'] for call in tool_calls if call['name'] == 'check_drug_interactions'}
+
+     for med_name, prescribe_call_id in prescribe_calls.items():
+         if med_name not in interaction_check_calls:
+             st.error(f"**Safety Violation:** AI attempted to prescribe '{med_name}' without requesting `check_drug_interactions` in the *same turn*. Prescription blocked for this turn.")
+             # Create an error ToolMessage to send back to the LLM
+             error_msg = ToolMessage(
+                 content=json.dumps({"status": "error", "message": f"Interaction check for {med_name} must be requested *before or alongside* the prescription call."}),
+                 tool_call_id=prescribe_call_id
+             )
+             tool_messages.append(error_msg)
+             # Remove the invalid prescribe call to prevent execution
+             tool_calls = [call for call in tool_calls if call['id'] != prescribe_call_id]
+
+     # Add patient context to interaction checks if needed
+     patient_meds = state.get("patient_data", {}).get("medications", {}).get("names_only", [])
+     patient_allergies = state.get("patient_data", {}).get("allergies", [])
+     for call in tool_calls:
+         if call['name'] == 'check_drug_interactions':
+             call['args']['current_medications'] = patient_meds
+             call['args']['allergies'] = patient_allergies
+             print(f"Augmented interaction check args: {call['args']}")
+
+     # Execute remaining valid tool calls
+     if tool_calls:
+         responses = tool_executor.batch(tool_calls)
+         # Responses is a list of tool outputs corresponding to tool_calls
+         # We need to create ToolMessage objects
+         tool_messages.extend([
+             ToolMessage(content=str(resp), tool_call_id=call['id'])
+             for call, resp in zip(tool_calls, responses)
+         ])
+         print(f"Tool results: {tool_messages}")
+
+     return {"messages": tool_messages}

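For reference, each entry in AIMessage.tool_calls is a plain dict with name, args and id, and the ToolMessage sent back must echo that id so the model can pair request and result. A sketch with made-up values:

```python
example_call = {"name": "order_lab_test", "args": {"test_name": "Troponin I", "reason": "Evaluate for ACS"}, "id": "call_abc123"}

# After running the tool, the result is wrapped so the LLM can match it to its request:
result_json = order_lab_test.invoke(example_call["args"])
reply = ToolMessage(content=result_json, tool_call_id=example_call["id"])
```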
+ # --- Graph Edges (Routing Logic) ---
+ def should_continue(state: AgentState) -> str:
+     """Determines whether to continue the loop or end."""
+     last_message = state['messages'][-1]
+     # If the LLM made tool calls, we execute them
+     if isinstance(last_message, AIMessage) and last_message.tool_calls:
+         print("Routing: continue_tools")
+         return "continue_tools"
+     # Otherwise, we end the loop (AI provided a direct answer or finished)
+     else:
+         print("Routing: end_conversation_turn")
+         return "end_conversation_turn"
+
+ # --- Graph Definition ---
+ workflow = StateGraph(AgentState)
+
+ # Add nodes
+ workflow.add_node("agent", agent_node)
+ workflow.add_node("tools", tool_node)
+
+ # Define entry point
+ workflow.set_entry_point("agent")
+
+ # Add conditional edges
+ workflow.add_conditional_edges(
+     "agent",              # Source node
+     should_continue,      # Function to decide the route
+     {
+         "continue_tools": "tools",        # If tool calls exist, go to tools node
+         "end_conversation_turn": END      # Otherwise, end the graph iteration
+     }
+ )
+
+ # Add edge from tools back to agent
+ workflow.add_edge("tools", "agent")
+
+ # Compile the graph
+ # memory = MemorySaverInMemory()  # Optional: for persisting state across runs
+ # app = workflow.compile(checkpointer=memory)
+ app = workflow.compile()
+
+ # --- Helper Function to Format Patient Data ---
+ def format_patient_data_for_prompt(data: dict) -> str:
+     """Formats the patient dictionary into a readable string for the LLM."""
+     prompt_str = ""
+     for key, value in data.items():
+         if isinstance(value, dict):
+             section_title = key.replace('_', ' ').title()
+             prompt_str += f"**{section_title}:**\n"
+             for sub_key, sub_value in value.items():
+                 if sub_value:
+                     prompt_str += f"  - {sub_key.replace('_', ' ').title()}: {sub_value}\n"
+         elif isinstance(value, list) and value:
+             prompt_str += f"**{key.replace('_', ' ').title()}:** {', '.join(map(str, value))}\n"
+         elif value:
+             prompt_str += f"**{key.replace('_', ' ').title()}:** {value}\n"
+     return prompt_str.strip()

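Outside Streamlit the compiled graph can be driven directly: one consultation turn is an invoke call with the accumulated messages plus the patient dictionary. A minimal sketch, where the patient_data content stands in for whatever the intake form produced:

```python
initial_state = AgentState(
    messages=[HumanMessage(content="Analyze the following patient case: 55 y/o male with chest pain...")],
    patient_data={"hpi": {"symptoms": ["Chest pain"]}, "medications": {"names_only": []}, "allergies": []},
)

final_state = app.invoke(initial_state, {"recursion_limit": 15})
print(final_state["messages"][-1].content)   # last AI turn, possibly containing the ```json block
```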
# --- Streamlit UI (Modified for Conversation) ---
|
| 355 |
+
def main():
|
| 356 |
+
st.set_page_config(page_title=ClinicalAppSettings.APP_TITLE, layout=ClinicalAppSettings.PAGE_LAYOUT)
|
| 357 |
+
st.title(f"🩺 {ClinicalAppSettings.APP_TITLE}")
|
| 358 |
+
st.caption(f"Interactive Assistant | Powered by Langchain/LangGraph & Groq ({ClinicalAppSettings.MODEL_NAME})")
|
| 359 |
+
|
| 360 |
+
# Initialize session state for conversation
|
| 361 |
+
if "messages" not in st.session_state:
|
| 362 |
+
st.session_state.messages = [] # Store entire conversation history (Human, AI, Tool)
|
| 363 |
+
if "patient_data" not in st.session_state:
|
| 364 |
+
st.session_state.patient_data = None
|
| 365 |
+
if "initial_analysis_done" not in st.session_state:
|
| 366 |
+
st.session_state.initial_analysis_done = False
|
| 367 |
+
if "graph_app" not in st.session_state:
|
| 368 |
+
st.session_state.graph_app = app # Store compiled graph
|
| 369 |
+
|
| 370 |
+
# --- Patient Data Input Sidebar --- (Similar to before)
|
| 371 |
+
with st.sidebar:
|
| 372 |
+
st.header("📄 Patient Intake Form")
|
| 373 |
+
# ... (Keep the input fields exactly as in the previous example) ...
|
| 374 |
+
# Demographics
|
| 375 |
+
age = st.number_input("Age", min_value=0, max_value=120, value=55, key="age_input")
|
| 376 |
+
sex = st.selectbox("Biological Sex", ["Male", "Female", "Other/Prefer not to say"], key="sex_input")
|
| 377 |
+
# HPI
|
| 378 |
+
chief_complaint = st.text_input("Chief Complaint", "Chest pain", key="cc_input")
|
| 379 |
+
hpi_details = st.text_area("Detailed HPI", "55 y/o male presents with substernal chest pain started 2 hours ago...", key="hpi_input")
|
| 380 |
+
symptoms = st.multiselect("Associated Symptoms", ["Nausea", "Diaphoresis", "Shortness of Breath", "Dizziness", "Palpitations", "Fever", "Cough"], default=["Nausea", "Diaphoresis"], key="sym_input")
|
| 381 |
+
# History
|
| 382 |
+
pmh = st.text_area("Past Medical History (PMH)", "Hypertension (HTN), Hyperlipidemia (HLD), Type 2 Diabetes Mellitus (DM2)", key="pmh_input")
|
| 383 |
+
psh = st.text_area("Past Surgical History (PSH)", "Appendectomy (2005)", key="psh_input")
|
| 384 |
+
# Meds & Allergies
|
| 385 |
+
current_meds_str = st.text_area("Current Medications (name, dose, freq)", "Lisinopril 10mg daily\nMetformin 1000mg BID\nAtorvastatin 40mg daily\nAspirin 81mg daily", key="meds_input")
|
| 386 |
+
allergies_str = st.text_area("Allergies (comma separated)", "Penicillin (rash)", key="allergy_input")
|
| 387 |
+
# Social/Family
|
| 388 |
+
social_history = st.text_area("Social History (SH)", "Smoker (1 ppd x 30 years), occasional alcohol.", key="sh_input")
|
| 389 |
+
family_history = st.text_area("Family History (FHx)", "Father had MI at age 60. Mother has HTN.", key="fhx_input")
|
| 390 |
+
# Vitals/Exam
|
| 391 |
+
col1, col2 = st.columns(2)
|
| 392 |
+
with col1:
|
| 393 |
+
temp_c = st.number_input("Temp (°C)", 35.0, 42.0, 36.8, format="%.1f", key="temp_input")
|
| 394 |
+
hr_bpm = st.number_input("HR (bpm)", 30, 250, 95, key="hr_input")
|
| 395 |
+
rr_rpm = st.number_input("RR (rpm)", 5, 50, 18, key="rr_input")
|
| 396 |
+
with col2:
|
| 397 |
+
bp_mmhg = st.text_input("BP (SYS/DIA)", "155/90", key="bp_input")
|
| 398 |
+
spo2_percent = st.number_input("SpO2 (%)", 70, 100, 96, key="spo2_input")
|
| 399 |
+
pain_scale = st.slider("Pain (0-10)", 0, 10, 8, key="pain_input")
|
| 400 |
+
exam_notes = st.text_area("Brief Physical Exam Notes", "Awake, alert, oriented x3...", key="exam_input")
|
| 401 |
+
|
| 402 |
+
# Compile Patient Data Dictionary on button press
|
| 403 |
+
if st.button("Start/Update Consultation", key="start_button"):
|
| 404 |
+
current_meds_list = [med.strip() for med in current_meds_str.split('\n') if med.strip()]
|
| 405 |
+
current_med_names = []
|
| 406 |
+
# Improved parsing for names (still basic, assumes name is first word)
|
| 407 |
+
for med in current_meds_list:
|
| 408 |
+
match = re.match(r"^\s*([a-zA-Z\-]+)", med)
|
| 409 |
+
if match:
|
| 410 |
+
current_med_names.append(match.group(1).lower()) # Use lower case for matching
|
| 411 |
+
|
| 412 |
+
allergies_list = [a.strip().lower() for a in allergies_str.split(',') if a.strip()] # Lowercase allergies
|
| 413 |
+
|
| 414 |
+
st.session_state.patient_data = {
|
| 415 |
+
"demographics": {"age": age, "sex": sex},
|
| 416 |
+
"hpi": {"chief_complaint": chief_complaint, "details": hpi_details, "symptoms": symptoms},
|
| 417 |
+
"pmh": {"conditions": pmh}, "psh": {"procedures": psh},
|
| 418 |
+
"medications": {"current": current_meds_list, "names_only": current_med_names},
|
| 419 |
+
"allergies": allergies_list,
|
| 420 |
+
"social_history": {"details": social_history}, "family_history": {"details": family_history},
|
| 421 |
+
"vitals": { "temp_c": temp_c, "hr_bpm": hr_bpm, "bp_mmhg": bp_mmhg, "rr_rpm": rr_rpm, "spo2_percent": spo2_percent, "pain_scale": pain_scale},
|
| 422 |
+
"exam_findings": {"notes": exam_notes}
|
| 423 |
+
}
|
| 424 |
+
|
| 425 |
+
# Initial Red Flag Check (Client-side)
|
| 426 |
+
red_flags = check_red_flags(st.session_state.patient_data)
|
| 427 |
+
if red_flags:
|
| 428 |
+
st.warning("**Initial Red Flags Detected:**")
|
| 429 |
+
for flag in red_flags: st.warning(f"- {flag}")
|
| 430 |
+
|
| 431 |
+
# Prepare initial message for the graph
|
| 432 |
+
initial_prompt = f"Analyze the following patient case:\nChief Complaint: {chief_complaint}\nSummary: {age} y/o {sex} presenting with..." # Keep it brief, full data is in state
|
| 433 |
+
st.session_state.messages = [HumanMessage(content=initial_prompt)]
|
| 434 |
+
st.session_state.initial_analysis_done = False # Reset analysis state
|
| 435 |
+
st.success("Patient data loaded. Ready for analysis.")
|
| 436 |
+
st.rerun() # Refresh main area to show chat
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
# --- Main Chat Interface Area ---
|
| 440 |
+
st.header("💬 Clinical Consultation")
|
| 441 |
+
|
| 442 |
+
# Display chat messages
|
| 443 |
+
for msg in st.session_state.messages:
|
| 444 |
+
if isinstance(msg, HumanMessage):
|
| 445 |
+
with st.chat_message("user"):
|
| 446 |
+
st.markdown(msg.content)
|
| 447 |
+
elif isinstance(msg, AIMessage):
|
| 448 |
+
with st.chat_message("assistant"):
|
| 449 |
+
# Check for structured JSON output
|
| 450 |
+
structured_output = None
|
| 451 |
try:
|
| 452 |
+
# Try to find JSON block first
|
| 453 |
+
json_match = re.search(r"```json\n(\{.*?\})\n```", msg.content, re.DOTALL)
|
|
|
|
| 454 |
if json_match:
|
| 455 |
+
structured_output = json.loads(json_match.group(1))
|
| 456 |
+
# Display non-JSON parts if any
|
| 457 |
+
non_json_content = msg.content.replace(json_match.group(0), "").strip()
|
| 458 |
+
if non_json_content:
|
| 459 |
+
st.markdown(non_json_content)
|
| 460 |
+
st.divider() # Separate text from structured output visually
|
| 461 |
+
elif msg.content.strip().startswith("{") and msg.content.strip().endswith("}"):
|
| 462 |
+
# Maybe the whole message is JSON
|
| 463 |
+
structured_output = json.loads(msg.content)
|
| 464 |
else:
|
| 465 |
+
# No JSON found, display raw content
|
| 466 |
+
st.markdown(msg.content)
|
| 467 |
+
|
| 468 |
+
if structured_output:
|
| 469 |
+
# Display the structured data nicely (reuse parts of previous UI display logic)
|
| 470 |
+
st.subheader("📊 AI Analysis & Recommendations")
|
| 471 |
+
# ... (Add logic here to display assessment, ddx, plan etc. from structured_output)
|
| 472 |
+
# Example:
|
| 473 |
+
st.write(f"**Assessment:** {structured_output.get('assessment', 'N/A')}")
|
| 474 |
+
# Display DDx, Plan etc. using expanders or tabs
|
| 475 |
+
# ...
|
| 476 |
+
# Display Rationale & Interaction Summary
|
| 477 |
+
with st.expander("Rationale & Guideline Check"):
|
| 478 |
+
st.write(structured_output.get("rationale_summary", "N/A"))
|
| 479 |
+
if structured_output.get("interaction_check_summary"):
|
| 480 |
+
with st.expander("Interaction Check"):
|
| 481 |
+
st.write(structured_output.get("interaction_check_summary"))
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
except json.JSONDecodeError:
|
| 485 |
+
st.markdown(msg.content) # Display raw if JSON parsing fails
|
| 486 |
+
|
| 487 |
+
# Display tool calls if any were made in this AI turn
|
| 488 |
+
if msg.tool_calls:
|
| 489 |
+
with st.expander("🛠️ AI requested actions", expanded=False):
|
| 490 |
+
for tc in msg.tool_calls:
|
| 491 |
+
st.code(f"{tc['name']}(args={tc['args']})", language="python")
|
| 492 |
+
|
| 493 |
+
elif isinstance(msg, ToolMessage):
|
| 494 |
+
with st.chat_message("tool", avatar="🛠️"):
|
| 495 |
+
try:
|
| 496 |
+
tool_data = json.loads(msg.content)
|
| 497 |
+
status = tool_data.get("status", "info")
|
| 498 |
+
message = tool_data.get("message", msg.content)
|
| 499 |
+
details = tool_data.get("details")
|
| 500 |
+
warnings = tool_data.get("warnings")
|
| 501 |
+
|
| 502 |
+
if status == "success" or status == "clear" or status == "flagged":
|
| 503 |
+
st.success(f"Tool Result ({msg.name}): {message}", icon="✅" if status != "flagged" else "🚨")
|
| 504 |
+
elif status == "warning":
|
| 505 |
+
st.warning(f"Tool Result ({msg.name}): {message}", icon="⚠️")
|
| 506 |
+
if warnings:
|
| 507 |
+
for warn in warnings: st.caption(f"- {warn}")
|
| 508 |
+
else: # Error or unknown status
|
| 509 |
+
st.error(f"Tool Result ({msg.name}): {message}", icon="❌")
|
| 510 |
+
|
| 511 |
+
if details: st.caption(f"Details: {details}")
|
| 512 |
+
|
| 513 |
except json.JSONDecodeError:
|
| 514 |
+
st.info(f"Tool Result ({msg.name}): {msg.content}") # Display raw if not JSON
|
|
|
|
|
|
|
| 515 |
|
|
|
|
|
|
|
|
|
|
| 516 |
|
| 517 |
+
# Chat input for user
|
| 518 |
+
if prompt := st.chat_input("Your message or follow-up query..."):
|
| 519 |
+
if not st.session_state.patient_data:
|
| 520 |
+
st.warning("Please load patient data using the sidebar first.")
|
| 521 |
+
else:
|
| 522 |
+
# Add user message to state
|
| 523 |
+
st.session_state.messages.append(HumanMessage(content=prompt))
|
| 524 |
+
with st.chat_message("user"):
|
| 525 |
+
st.markdown(prompt)
|
| 526 |
+
|
| 527 |
+
# Prepare state for graph invocation
|
| 528 |
+
current_state = AgentState(
|
| 529 |
+
messages=st.session_state.messages,
|
| 530 |
+
patient_data=st.session_state.patient_data
|
| 531 |
+
)
|
| 532 |
+
|
| 533 |
+
# Stream graph execution
|
| 534 |
+
with st.chat_message("assistant"):
|
| 535 |
+
message_placeholder = st.empty()
|
| 536 |
+
full_response = ""
|
| 537 |
+
|
| 538 |
+
# Use stream to get intermediate steps (optional but good for UX)
|
| 539 |
+
# This shows AI thinking and tool calls/results progressively
|
| 540 |
+
try:
|
| 541 |
+
for event in st.session_state.graph_app.stream(current_state, {"recursion_limit": 15}):
|
| 542 |
+
# event is a dictionary, keys are node names
|
| 543 |
+
if "agent" in event:
|
| 544 |
+
ai_msg = event["agent"]["messages"][-1] # Get the latest AI message
|
| 545 |
+
if isinstance(ai_msg, AIMessage):
|
| 546 |
+
full_response += ai_msg.content # Append content for final display
|
| 547 |
+
message_placeholder.markdown(full_response + "▌") # Show typing indicator
|
| 548 |
|
| 549 |
+
# Display tool calls as they happen (optional)
|
| 550 |
+
# if ai_msg.tool_calls:
|
| 551 |
+
# st.info(f"Requesting tools: {[tc['name'] for tc in ai_msg.tool_calls]}")
|
| 552 |
|
| 553 |
+
elif "tools" in event:
|
| 554 |
+
# Display tool results as they come back (optional, already handled by message display loop)
|
| 555 |
+
pass
|
| 556 |
+
# tool_msgs = event["tools"]["messages"]
|
| 557 |
+
# for tool_msg in tool_msgs:
|
| 558 |
+
# st.info(f"Tool {tool_msg.name} result received.")
|
| 559 |
|
|
|
|
|
|
|
| 560 |
|
| 561 |
+
# Final display after streaming
|
| 562 |
+
message_placeholder.markdown(full_response)
| 564 |
|
| 565 |
+
# Update session state with the final messages from the graph run
|
| 566 |
+
# The graph state itself isn't directly accessible after streaming finishes easily this way
|
| 567 |
+
# We need to get the final state if we used invoke, or reconstruct from stream events
|
| 568 |
+
# A simpler way for now: just append the *last* AI message and any Tool messages from the stream
|
| 569 |
+
# This assumes the stream provides the final state implicitly. For robust state, use invoke or checkpointer.
|
| 570 |
+
|
| 571 |
+
# A more robust way: invoke and get final state
|
| 572 |
+
# final_state = st.session_state.graph_app.invoke(current_state, {"recursion_limit": 15})
|
| 573 |
+
# st.session_state.messages = final_state['messages']
|
| 574 |
+
# --- Let's stick to appending for simplicity in this example ---
|
| 575 |
+
# Find the last AI message and tool messages from the stream (needs careful event parsing)
|
| 576 |
+
# Or, re-run invoke non-streamed just to get final state (less efficient)
|
| 577 |
+
final_state_capture = st.session_state.graph_app.invoke(current_state, {"recursion_limit": 15})
|
| 578 |
+
st.session_state.messages = final_state_capture['messages']
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
except Exception as e:
|
| 582 |
+
st.error(f"An error occurred during analysis: {e}")
|
| 583 |
+
# Attempt to add the error message to the history
|
| 584 |
+
st.session_state.messages.append(AIMessage(content=f"Sorry, an error occurred: {e}"))
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
# Rerun to display the updated chat history correctly
|
| 588 |
+
st.rerun()
|
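As the comments above acknowledge, calling invoke after the stream replays the whole turn just to capture the final state; collecting the streamed node outputs (or using the MemorySaver checkpointer imported at the top) would avoid the duplicate LLM and tool calls. A sketch of the event-collecting alternative, assuming the same event layout used in the loop above:

```python
collected = []
for event in st.session_state.graph_app.stream(current_state, {"recursion_limit": 15}):
    for node_output in event.values():            # {"agent": {...}} or {"tools": {...}}
        collected.extend(node_output.get("messages", []))
st.session_state.messages.extend(collected)       # one pass, no second invoke
```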
  # Disclaimer
  st.markdown("---")
+ st.warning("**Disclaimer:** SynapseAI is for clinical decision support...")  # Keep disclaimer

  if __name__ == "__main__":
      main()