|
import streamlit as st |
|
from langchain_groq import ChatGroq |
|
from langchain_community.tools.tavily_search import TavilySearchResults |
|
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage |
|
from langchain_core.prompts import ChatPromptTemplate |
|
from langchain_core.pydantic_v1 import BaseModel, Field |
|
from langchain_core.tools import tool |
|
from langgraph.prebuilt import ToolExecutor |
|
from langgraph.graph import StateGraph, END |
|
from langgraph.checkpoint.memory import MemorySaver |
|
|
|
from typing import Optional, List, Dict, Any, TypedDict, Annotated |
|
import json |
|
import re |
|
import operator |
|
|
|
|
|
class ClinicalAppSettings:
    """Central configuration constants for the Streamlit UI and the LLM agent."""
    # Browser tab / page header title.
    APP_TITLE = "SynapseAI: Interactive Clinical Decision Support"
    # Streamlit page layout mode.
    PAGE_LAYOUT = "wide"
    # Groq-hosted model identifier used for all agent turns.
    MODEL_NAME = "llama3-70b-8192"
    # Low temperature for mostly deterministic clinical output.
    TEMPERATURE = 0.1
    # Cap on Tavily web-search hits returned per query.
    MAX_SEARCH_RESULTS = 3
|
|
|
class ClinicalPrompts:
    """Holds the system prompt that defines the agent's behavior.

    The prompt instructs the model to converse turn-by-turn, enforce the
    interaction-check-before-prescription rule, and emit its final analysis
    as a fenced ```json block matching the schema embedded below (which the
    UI in main() parses and pretty-prints).
    """

    # NOTE: the JSON schema inside this string is rendered to the model
    # verbatim; keep it in sync with the parsing logic in main().
    SYSTEM_PROMPT = """
You are SynapseAI, an expert AI clinical assistant engaged in an interactive consultation.
Your goal is to support healthcare professionals by analyzing patient data, providing differential diagnoses, suggesting evidence-based management plans, and identifying risks according to current standards of care.

**Core Directives for this Conversation:**
1.  **Analyze Sequentially:** Process information turn-by-turn. You will receive initial patient data, and potentially follow-up messages or results from tools you requested. Base your responses on the *entire* conversation history.
2.  **Seek Clarity:** If the provided information is insufficient or ambiguous for a safe assessment, CLEARLY STATE what specific additional information or clarification is needed. Do NOT guess or make unsafe assumptions.
3.  **Structured Assessment (When Ready):** When you have sufficient information and have performed necessary checks (like interactions), provide a comprehensive assessment using the following JSON structure. Only output this structure when you believe you have a complete initial analysis or plan. Do NOT output incomplete JSON.
    ```json
    {
      "assessment": "Concise summary of the patient's presentation and key findings based on the conversation.",
      "differential_diagnosis": [
        {"diagnosis": "Primary Diagnosis", "likelihood": "High/Medium/Low", "rationale": "Supporting evidence from conversation..."},
        {"diagnosis": "Alternative Diagnosis 1", "likelihood": "Medium/Low", "rationale": "Supporting/Refuting evidence..."},
        {"diagnosis": "Alternative Diagnosis 2", "likelihood": "Low", "rationale": "Why it's less likely but considered..."}
      ],
      "risk_assessment": {
        "identified_red_flags": ["List any triggered red flags"],
        "immediate_concerns": ["Specific urgent issues (e.g., sepsis risk, ACS rule-out)"],
        "potential_complications": ["Possible future issues"]
      },
      "recommended_plan": {
        "investigations": ["List specific lab tests or imaging needed. Use 'order_lab_test' tool."],
        "therapeutics": ["Suggest specific treatments/prescriptions. Use 'prescribe_medication' tool. MUST check interactions first."],
        "consultations": ["Recommend specialist consultations."],
        "patient_education": ["Key points for patient communication."]
      },
      "rationale_summary": "Justification for assessment/plan. **Crucially, if relevant (e.g., ACS, sepsis, common infections), use 'tavily_search_results' to find and cite current clinical practice guidelines (e.g., 'latest ACC/AHA chest pain guidelines 202X', 'Surviving Sepsis Campaign guidelines') supporting your recommendations.**",
      "interaction_check_summary": "Summary of findings from 'check_drug_interactions' if performed."
    }
    ```
4.  **Safety First - Interactions:** BEFORE suggesting a new prescription via `prescribe_medication`, you MUST FIRST use `check_drug_interactions`. Report the findings. If interactions exist, modify the plan or state the contraindication.
5.  **Safety First - Red Flags:** Use the `flag_risk` tool IMMEDIATELY if critical red flags requiring urgent action are identified at any point.
6.  **Tool Use:** Employ tools (`order_lab_test`, `prescribe_medication`, `check_drug_interactions`, `flag_risk`, `tavily_search_results`) logically within the conversational flow. Wait for tool results before proceeding if the result is needed for the next step (e.g., wait for interaction check before confirming prescription).
7.  **Evidence & Guidelines:** Actively use `tavily_search_results` not just for general knowledge, but specifically to query for and incorporate **current clinical practice guidelines** relevant to the patient's presentation (e.g., chest pain, shortness of breath, suspected infection). Summarize findings in the `rationale_summary` when providing the structured output.
8.  **Conciseness:** Be medically accurate and concise. Use standard terminology. Respond naturally in conversation until ready for the full structured JSON output.
"""
|
|
|
|
|
|
|
# Demonstration-only drug-drug interaction lookup used by
# check_drug_interactions. Keys are (drug_a, drug_b) lowercase name pairs in
# arbitrary order (both orderings are probed at lookup time); values are
# clinician-facing warning strings.
# NOTE(review): mock data — not a substitute for a real interaction database.
MOCK_INTERACTION_DB = {
    ("lisinopril", "spironolactone"): "High risk of hyperkalemia. Monitor potassium closely.",
    ("warfarin", "amiodarone"): "Increased bleeding risk. Monitor INR frequently and adjust Warfarin dose.",
    ("simvastatin", "clarithromycin"): "Increased risk of myopathy/rhabdomyolysis. Avoid combination or use lower statin dose.",
    ("aspirin", "ibuprofen"): "Concurrent use may decrease Aspirin's cardioprotective effect. Potential for increased GI bleeding."
}

# Known allergy cross-reactivity map: allergy name (lowercase) -> drugs that
# commonly cross-react with it. Used for the cross-allergy warning pass.
ALLERGY_INTERACTIONS = {
    "penicillin": ["amoxicillin", "ampicillin", "piperacillin"],
    "sulfa": ["sulfamethoxazole", "sulfasalazine"],
    "aspirin": ["ibuprofen", "naproxen"]
}
|
|
|
def parse_bp(bp_string: str) -> Optional[tuple[int, int]]:
    """Extract (systolic, diastolic) integers from a 'SYS/DIA' string.

    Matches from the start of the string; trailing text (e.g. ' mmHg') is
    ignored. Returns None when the string does not begin with a BP reading.
    """
    parsed = re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", bp_string)
    if not parsed:
        return None
    systolic, diastolic = parsed.groups()
    return int(systolic), int(diastolic)
|
|
|
def check_red_flags(patient_data: dict) -> List[str]:
    """Rule-based screen of intake data for symptom, vital-sign, and history red flags.

    Returns a list of human-readable "Red Flag: ..." strings (empty when none fire).
    """
    alerts: List[str] = []
    reported = {s.lower() for s in patient_data.get("hpi", {}).get("symptoms", [])}
    vitals = patient_data.get("vitals", {})
    conditions = patient_data.get("pmh", {}).get("conditions", "")

    # Symptom-based flags (dict preserves insertion order of the checks).
    symptom_alerts = {
        "chest pain": "Red Flag: Chest Pain reported.",
        "shortness of breath": "Red Flag: Shortness of Breath reported.",
        "severe headache": "Red Flag: Severe Headache reported.",
    }
    for symptom, message in symptom_alerts.items():
        if symptom in reported:
            alerts.append(message)

    # Vital-sign thresholds.
    if "temp_c" in vitals and vitals["temp_c"] >= 38.5:
        alerts.append(f"Red Flag: Fever ({vitals['temp_c']}ยฐC).")
    if "hr_bpm" in vitals and vitals["hr_bpm"] >= 120:
        alerts.append(f"Red Flag: Tachycardia ({vitals['hr_bpm']} bpm).")
    if "bp_mmhg" in vitals:
        reading = parse_bp(vitals["bp_mmhg"])
        if reading:
            systolic, diastolic = reading
            if systolic >= 180 or diastolic >= 110:
                alerts.append(f"Red Flag: Hypertensive Urgency/Emergency (BP: {vitals['bp_mmhg']} mmHg).")
            if systolic <= 90 or diastolic <= 60:
                alerts.append(f"Red Flag: Hypotension (BP: {vitals['bp_mmhg']} mmHg).")

    # History + symptom combination.
    if "history of mi" in conditions.lower() and "chest pain" in reported:
        alerts.append("Red Flag: History of MI with current Chest Pain.")

    return alerts
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LabOrderInput(BaseModel):
    """Argument schema for the order_lab_test tool."""
    # NOTE(review): langchain_core.pydantic_v1 is a deprecated compatibility
    # shim in recent langchain releases — confirm against the pinned version.
    test_name: str = Field(..., description="Specific name of the lab test or panel (e.g., 'CBC', 'BMP', 'Troponin I', 'Urinalysis').")
    reason: str = Field(..., description="Clinical justification for ordering the test (e.g., 'Rule out infection', 'Assess renal function', 'Evaluate for ACS').")
    priority: str = Field("Routine", description="Priority of the test (e.g., 'STAT', 'Routine').")
|
|
|
@tool("order_lab_test", args_schema=LabOrderInput)
def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
    """Orders a specific lab test with clinical justification and priority."""
    # Mock implementation: no real order is placed; return a JSON receipt
    # that the UI renders as a tool result.
    receipt = {
        "status": "success",
        "message": f"Lab Ordered: {test_name} ({priority})",
        "details": f"Reason: {reason}",
    }
    return json.dumps(receipt)
|
|
|
class PrescriptionInput(BaseModel):
    """Argument schema for the prescribe_medication tool."""
    medication_name: str = Field(..., description="Name of the medication.")
    dosage: str = Field(..., description="Dosage amount and unit (e.g., '500 mg', '10 mg').")
    route: str = Field(..., description="Route of administration (e.g., 'PO', 'IV', 'IM', 'Topical').")
    frequency: str = Field(..., description="How often the medication should be taken (e.g., 'BID', 'QDaily', 'Q4-6H PRN').")
    duration: str = Field("As directed", description="Duration of treatment (e.g., '7 days', '1 month', 'Until follow-up').")
    reason: str = Field(..., description="Clinical indication for the prescription.")
|
|
|
@tool("prescribe_medication", args_schema=PrescriptionInput)
def prescribe_medication(medication_name: str, dosage: str, route: str, frequency: str, duration: str, reason: str) -> str:
    """Prescribes a medication with detailed instructions and clinical indication."""
    # Mock implementation: prepares (does not transmit) a prescription and
    # returns a JSON receipt for the UI. The same-turn interaction-check
    # requirement is enforced upstream in tool_node, not here.
    receipt = {
        "status": "success",
        "message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}",
        "details": f"Duration: {duration}. Reason: {reason}",
    }
    return json.dumps(receipt)
|
|
|
class InteractionCheckInput(BaseModel):
    """Argument schema for the check_drug_interactions tool.

    current_medications and allergies are overwritten in tool_node with the
    authoritative lists from intake data before execution.
    """
    potential_prescription: str = Field(..., description="The name of the NEW medication being considered.")
    current_medications: List[str] = Field(..., description="List of the patient's CURRENT medication names.")
    allergies: List[str] = Field(..., description="List of the patient's known allergies.")
|
|
|
@tool("check_drug_interactions", args_schema=InteractionCheckInput)
def check_drug_interactions(potential_prescription: str, current_medications: List[str], allergies: List[str]) -> str:
    """Checks for potential drug-drug and drug-allergy interactions BEFORE prescribing.

    Returns a JSON string with 'status' ('warning' | 'clear'), a human-readable
    'message', and a 'warnings' list describing each issue found. Matching is
    case-insensitive against the mock lookup tables.
    """
    warnings = []
    potential_med_lower = potential_prescription.lower()
    current_meds_lower = [med.lower() for med in current_medications]
    allergies_lower = [a.lower() for a in allergies]

    # Direct allergy match blocks the drug outright; otherwise consult the
    # cross-reactivity table (e.g. penicillin -> amoxicillin).
    for allergy in allergies_lower:
        if allergy == potential_med_lower:
            warnings.append(f"CRITICAL ALLERGY: Patient allergic to {allergy}. Cannot prescribe {potential_prescription}.")
            continue
        for cross_reactant in ALLERGY_INTERACTIONS.get(allergy, []):
            if cross_reactant.lower() == potential_med_lower:
                warnings.append(f"POTENTIAL CROSS-ALLERGY: Patient allergic to {allergy}. High risk with {potential_prescription}.")

    # MOCK_INTERACTION_DB keys are name pairs stored in arbitrary order, so
    # probe both orderings; report each current medication at most once.
    # (Removed dead code: the previous version also built sorted key tuples
    # key1/key2 that were never used for the lookup.)
    for current_med in current_meds_lower:
        for pair in ((current_med, potential_med_lower), (potential_med_lower, current_med)):
            if pair in MOCK_INTERACTION_DB:
                warnings.append(f"Interaction: {potential_prescription.capitalize()} with {current_med.capitalize()} - {MOCK_INTERACTION_DB[pair]}")
                break

    status = "warning" if warnings else "clear"
    message = f"Interaction check for {potential_prescription}: {len(warnings)} potential issue(s) found." if warnings else f"No major interactions identified for {potential_prescription}."
    return json.dumps({"status": status, "message": message, "warnings": warnings})
|
|
|
|
|
class FlagRiskInput(BaseModel):
    """Argument schema for the flag_risk tool."""
    risk_description: str = Field(..., description="Specific critical risk identified (e.g., 'Suspected Sepsis', 'Acute Coronary Syndrome', 'Stroke Alert').")
    urgency: str = Field("High", description="Urgency level (e.g., 'Critical', 'High', 'Moderate').")
|
|
|
@tool("flag_risk", args_schema=FlagRiskInput)
def flag_risk(risk_description: str, urgency: str) -> str:
    """Flags a critical risk identified during analysis for immediate attention."""
    # Side effect: surface the alert in the Streamlit UI immediately, outside
    # the normal chat-transcript rendering loop.
    st.error(f"๐จ **{urgency.upper()} RISK FLAGGED by AI:** {risk_description}", icon="๐จ")
    # Acknowledge the flag back to the agent as a JSON tool result.
    return json.dumps({"status": "flagged", "message": f"Risk '{risk_description}' flagged with {urgency} urgency."})
|
|
|
|
|
# Tavily web search exposed as a tool. The explicit name must match the one
# the system prompt tells the model to call ("tavily_search_results").
search_tool = TavilySearchResults(max_results=ClinicalAppSettings.MAX_SEARCH_RESULTS, name="tavily_search_results")
|
|
|
|
|
|
|
|
|
|
|
class AgentState(TypedDict):
    """Shared LangGraph state passed between the agent and tool nodes."""
    # Conversation history; Annotated with operator.add so each node's
    # returned messages are appended rather than replacing the list.
    messages: Annotated[list[Any], operator.add]
    # Structured intake data from the sidebar form (None until loaded).
    patient_data: Optional[dict]
|
|
|
|
|
|
|
# All tools the agent may call; also bound to the LLM further below so the
# model can emit matching tool calls.
tools = [
    order_lab_test,
    prescribe_medication,
    check_drug_interactions,
    flag_risk,
    search_tool
]
# NOTE(review): ToolExecutor is deprecated in recent langgraph releases
# (ToolNode is the replacement) — confirm against the pinned version.
tool_executor = ToolExecutor(tools)
|
|
|
|
|
# Groq-hosted chat model; low temperature for consistent clinical reasoning.
model = ChatGroq(
    temperature=ClinicalAppSettings.TEMPERATURE,
    model=ClinicalAppSettings.MODEL_NAME
)
# Same model with the tool schemas attached so it can request tool calls.
model_with_tools = model.bind_tools(tools)
|
|
|
|
|
|
|
|
|
def agent_node(state: AgentState):
    """Run the LLM over the conversation, injecting the system prompt as needed.

    On the very first turn (a single HumanMessage plus loaded patient data)
    the system prompt and a formatted patient-data block are prepended; on
    later turns the system prompt is prepended only if missing.
    Returns the new AI message for LangGraph to append to state.
    """
    print("---AGENT NODE---")

    history = state['messages']
    patient_data = state.get('patient_data')

    is_first_turn = (
        len(history) == 1
        and isinstance(history[0], HumanMessage)
        and bool(patient_data)
    )
    if is_first_turn:
        data_block = format_patient_data_for_prompt(patient_data)
        history = [
            SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT),
            HumanMessage(content=f"{history[0].content}\n\n**Initial Patient Data:**\n{data_block}"),
        ]
    elif not any(isinstance(msg, SystemMessage) for msg in history):
        history = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT), *history]

    reply = model_with_tools.invoke(history)
    print(f"Agent response: {reply}")
    return {"messages": [reply]}
|
|
|
|
|
def tool_node(state: AgentState):
    """Executes tools called by the LLM and returns results.

    Also enforces one safety rule: a prescribe_medication call is blocked
    unless a check_drug_interactions call for the same drug name appears in
    the same batch of tool calls. Blocked calls are answered with an error
    ToolMessage so the agent can retry correctly next turn.
    """
    print("---TOOL NODE---")
    last_message = state['messages'][-1]
    # Defensive guard: this node should only run after an AIMessage that
    # actually requested tools (see should_continue).
    if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
        print("No tool calls in last message.")
        return {}

    tool_calls = last_message.tool_calls
    tool_messages = []

    # Map medication name -> tool_call id for both call types in this turn.
    prescribe_calls = {call['args'].get('medication_name'): call['id'] for call in tool_calls if call['name'] == 'prescribe_medication'}
    interaction_check_calls = {call['args'].get('potential_prescription'): call['id'] for call in tool_calls if call['name'] == 'check_drug_interactions'}

    # Drop any prescription whose drug lacks a same-turn interaction check,
    # surface the violation in the UI, and answer the call with an error.
    for med_name, prescribe_call_id in prescribe_calls.items():
        if med_name not in interaction_check_calls:
            st.error(f"**Safety Violation:** AI attempted to prescribe '{med_name}' without requesting `check_drug_interactions` in the *same turn*. Prescription blocked for this turn.")
            error_msg = ToolMessage(
                content=json.dumps({"status": "error", "message": f"Interaction check for {med_name} must be requested *before or alongside* the prescription call."}),
                tool_call_id=prescribe_call_id
            )
            tool_messages.append(error_msg)
            # Remove the blocked call from the batch to be executed.
            tool_calls = [call for call in tool_calls if call['id'] != prescribe_call_id]

    # Inject the authoritative med/allergy lists from intake data so the LLM
    # cannot pass stale or hallucinated values to the interaction checker.
    # NOTE(review): if "patient_data" is present but None, the .get() chain
    # below raises AttributeError — confirm intake data is always loaded first.
    patient_meds = state.get("patient_data", {}).get("medications", {}).get("names_only", [])
    patient_allergies = state.get("patient_data", {}).get("allergies", [])
    for call in tool_calls:
        if call['name'] == 'check_drug_interactions':
            # In-place mutation of the AIMessage's tool_call args dict.
            call['args']['current_medications'] = patient_meds
            call['args']['allergies'] = patient_allergies
            print(f"Augmented interaction check args: {call['args']}")

    if tool_calls:
        # NOTE(review): ToolExecutor.batch historically expects ToolInvocation
        # objects; passing raw tool-call dicts depends on the pinned langgraph
        # version accepting them — verify.
        responses = tool_executor.batch(tool_calls)
        # One ToolMessage per executed call, paired positionally with its response.
        tool_messages.extend([
            ToolMessage(content=str(resp), tool_call_id=call['id'])
            for call, resp in zip(tool_calls, responses)
        ])
        print(f"Tool results: {tool_messages}")

    return {"messages": tool_messages}
|
|
|
|
|
|
|
def should_continue(state: AgentState) -> str:
    """Route to the tool node when the last AI message requested tools, else end the turn."""
    latest = state['messages'][-1]
    wants_tools = isinstance(latest, AIMessage) and bool(latest.tool_calls)
    route = "continue_tools" if wants_tools else "end_conversation_turn"
    print(f"Routing: {route}")
    return route
|
|
|
|
|
# --- LangGraph wiring: agent <-> tools loop ---
workflow = StateGraph(AgentState)

workflow.add_node("agent", agent_node)
workflow.add_node("tools", tool_node)

# Every turn starts with the agent deciding to answer or call tools.
workflow.set_entry_point("agent")

# Route to the tool node when the agent requested tools, otherwise end the turn.
workflow.add_conditional_edges(
    "agent",
    should_continue,
    {
        "continue_tools": "tools",
        "end_conversation_turn": END
    }
)

# After tools run, control always returns to the agent to interpret results.
workflow.add_edge("tools", "agent")

# NOTE(review): MemorySaver is imported at the top of the file but no
# checkpointer is passed to compile(), so graph state is not persisted
# between invocations — confirm whether that is intended.
app = workflow.compile()
|
|
|
|
|
def format_patient_data_for_prompt(data: dict) -> str:
    """Formats the patient dictionary into a readable markdown string for the LLM.

    Skips only values that are genuinely missing (None, empty string, empty
    collection). Numeric zero is kept: the previous truthiness test silently
    dropped meaningful values such as a pain scale of 0. Section headers for
    nested dicts are emitted only when they have at least one visible entry.
    """

    def _present(value: Any) -> bool:
        # Keep 0/0.0; drop None, "", empty containers, and False (matching
        # the original treatment of booleans).
        if value is None:
            return False
        if isinstance(value, bool):
            return value
        if isinstance(value, (str, list, tuple, dict, set)):
            return len(value) > 0
        return True

    lines: List[str] = []
    for key, value in data.items():
        title = key.replace('_', ' ').title()
        if isinstance(value, dict):
            section = [
                f"  - {sub_key.replace('_', ' ').title()}: {sub_value}"
                for sub_key, sub_value in value.items()
                if _present(sub_value)
            ]
            if section:  # omit headers for sections with no visible entries
                lines.append(f"**{title}:**")
                lines.extend(section)
        elif isinstance(value, list):
            if value:
                lines.append(f"**{title}:** {', '.join(map(str, value))}")
        elif _present(value):
            lines.append(f"**{title}:** {value}")
    return "\n".join(lines)
|
|
|
|
|
def main():
    """Streamlit entry point.

    Renders the patient intake sidebar, replays the conversation transcript
    (pretty-printing any structured JSON assessment and tool results), and on
    each user message runs ONE streamed LangGraph turn.

    Fixes over the previous version:
    - The tool-result icon string literal contained a newline from a garbled
      emoji (a SyntaxError); mis-encoded emoji throughout are repaired.
    - The graph was executed twice per turn (.stream() for display, then
      .invoke() on the same state to capture messages), duplicating every LLM
      call and tool execution. Messages are now collected from the single
      streamed pass.
    """
    st.set_page_config(page_title=ClinicalAppSettings.APP_TITLE, layout=ClinicalAppSettings.PAGE_LAYOUT)
    st.title(f"🩺 {ClinicalAppSettings.APP_TITLE}")
    st.caption(f"Interactive Assistant | Powered by Langchain/LangGraph & Groq ({ClinicalAppSettings.MODEL_NAME})")

    # --- Session state initialization (survives Streamlit reruns) ---
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "patient_data" not in st.session_state:
        st.session_state.patient_data = None
    if "initial_analysis_done" not in st.session_state:
        st.session_state.initial_analysis_done = False
    if "graph_app" not in st.session_state:
        st.session_state.graph_app = app

    # --- Sidebar: structured patient intake form ---
    with st.sidebar:
        st.header("📋 Patient Intake Form")

        age = st.number_input("Age", min_value=0, max_value=120, value=55, key="age_input")
        sex = st.selectbox("Biological Sex", ["Male", "Female", "Other/Prefer not to say"], key="sex_input")

        chief_complaint = st.text_input("Chief Complaint", "Chest pain", key="cc_input")
        hpi_details = st.text_area("Detailed HPI", "55 y/o male presents with substernal chest pain started 2 hours ago...", key="hpi_input")
        symptoms = st.multiselect("Associated Symptoms", ["Nausea", "Diaphoresis", "Shortness of Breath", "Dizziness", "Palpitations", "Fever", "Cough"], default=["Nausea", "Diaphoresis"], key="sym_input")

        pmh = st.text_area("Past Medical History (PMH)", "Hypertension (HTN), Hyperlipidemia (HLD), Type 2 Diabetes Mellitus (DM2)", key="pmh_input")
        psh = st.text_area("Past Surgical History (PSH)", "Appendectomy (2005)", key="psh_input")

        current_meds_str = st.text_area("Current Medications (name, dose, freq)", "Lisinopril 10mg daily\nMetformin 1000mg BID\nAtorvastatin 40mg daily\nAspirin 81mg daily", key="meds_input")
        allergies_str = st.text_area("Allergies (comma separated)", "Penicillin (rash)", key="allergy_input")

        social_history = st.text_area("Social History (SH)", "Smoker (1 ppd x 30 years), occasional alcohol.", key="sh_input")
        family_history = st.text_area("Family History (FHx)", "Father had MI at age 60. Mother has HTN.", key="fhx_input")

        col1, col2 = st.columns(2)
        with col1:
            temp_c = st.number_input("Temp (°C)", 35.0, 42.0, 36.8, format="%.1f", key="temp_input")
            hr_bpm = st.number_input("HR (bpm)", 30, 250, 95, key="hr_input")
            rr_rpm = st.number_input("RR (rpm)", 5, 50, 18, key="rr_input")
        with col2:
            bp_mmhg = st.text_input("BP (SYS/DIA)", "155/90", key="bp_input")
            spo2_percent = st.number_input("SpO2 (%)", 70, 100, 96, key="spo2_input")
            pain_scale = st.slider("Pain (0-10)", 0, 10, 8, key="pain_input")
        exam_notes = st.text_area("Brief Physical Exam Notes", "Awake, alert, oriented x3...", key="exam_input")

        if st.button("Start/Update Consultation", key="start_button"):
            # One medication per line; keep the bare drug name separately so
            # the interaction checker can match its lowercase lookup tables.
            current_meds_list = [med.strip() for med in current_meds_str.split('\n') if med.strip()]
            current_med_names = []
            for med in current_meds_list:
                match = re.match(r"^\s*([a-zA-Z\-]+)", med)
                if match:
                    current_med_names.append(match.group(1).lower())

            allergies_list = [a.strip().lower() for a in allergies_str.split(',') if a.strip()]

            st.session_state.patient_data = {
                "demographics": {"age": age, "sex": sex},
                "hpi": {"chief_complaint": chief_complaint, "details": hpi_details, "symptoms": symptoms},
                "pmh": {"conditions": pmh}, "psh": {"procedures": psh},
                "medications": {"current": current_meds_list, "names_only": current_med_names},
                "allergies": allergies_list,
                "social_history": {"details": social_history}, "family_history": {"details": family_history},
                "vitals": {"temp_c": temp_c, "hr_bpm": hr_bpm, "bp_mmhg": bp_mmhg, "rr_rpm": rr_rpm, "spo2_percent": spo2_percent, "pain_scale": pain_scale},
                "exam_findings": {"notes": exam_notes}
            }

            # Deterministic rule-based red-flag screen before any LLM call.
            red_flags = check_red_flags(st.session_state.patient_data)
            if red_flags:
                st.warning("**Initial Red Flags Detected:**")
                for flag in red_flags:
                    st.warning(f"- {flag}")

            # Reset the conversation with a fresh seed prompt; agent_node
            # attaches the system prompt and full patient data on first turn.
            initial_prompt = f"Analyze the following patient case:\nChief Complaint: {chief_complaint}\nSummary: {age} y/o {sex} presenting with..."
            st.session_state.messages = [HumanMessage(content=initial_prompt)]
            st.session_state.initial_analysis_done = False
            st.success("Patient data loaded. Ready for analysis.")
            st.rerun()

    # --- Conversation transcript ---
    st.header("💬 Clinical Consultation")

    for msg in st.session_state.messages:
        if isinstance(msg, HumanMessage):
            with st.chat_message("user"):
                st.markdown(msg.content)
        elif isinstance(msg, AIMessage):
            with st.chat_message("assistant"):
                structured_output = None
                try:
                    # Prefer a fenced ```json block; fall back to a message
                    # that is itself a bare JSON object.
                    json_match = re.search(r"```json\n(\{.*?\})\n```", msg.content, re.DOTALL)
                    if json_match:
                        structured_output = json.loads(json_match.group(1))
                        non_json_content = msg.content.replace(json_match.group(0), "").strip()
                        if non_json_content:
                            st.markdown(non_json_content)
                            st.divider()
                    elif msg.content.strip().startswith("{") and msg.content.strip().endswith("}"):
                        structured_output = json.loads(msg.content)
                    else:
                        st.markdown(msg.content)

                    if structured_output:
                        # Pretty-print the structured assessment emitted per
                        # the schema in ClinicalPrompts.SYSTEM_PROMPT.
                        st.subheader("📊 AI Analysis & Recommendations")
                        st.write(f"**Assessment:** {structured_output.get('assessment', 'N/A')}")
                        with st.expander("Rationale & Guideline Check"):
                            st.write(structured_output.get("rationale_summary", "N/A"))
                        if structured_output.get("interaction_check_summary"):
                            with st.expander("Interaction Check"):
                                st.write(structured_output.get("interaction_check_summary"))
                except json.JSONDecodeError:
                    # Malformed JSON: show the raw text rather than crash.
                    st.markdown(msg.content)

                if msg.tool_calls:
                    with st.expander("🛠️ AI requested actions", expanded=False):
                        for tc in msg.tool_calls:
                            st.code(f"{tc['name']}(args={tc['args']})", language="python")

        elif isinstance(msg, ToolMessage):
            with st.chat_message("tool", avatar="🛠️"):
                try:
                    tool_data = json.loads(msg.content)
                    status = tool_data.get("status", "info")
                    message = tool_data.get("message", msg.content)
                    details = tool_data.get("details")
                    warnings = tool_data.get("warnings")

                    # FIX: previous version had a newline inside the icon
                    # string literal (SyntaxError from a garbled emoji).
                    if status in ("success", "clear", "flagged"):
                        st.success(f"Tool Result ({msg.name}): {message}", icon="✅" if status != "flagged" else "🚨")
                    elif status == "warning":
                        st.warning(f"Tool Result ({msg.name}): {message}", icon="⚠️")
                        if warnings:
                            for warn in warnings:
                                st.caption(f"- {warn}")
                    else:
                        st.error(f"Tool Result ({msg.name}): {message}", icon="❌")

                    if details:
                        st.caption(f"Details: {details}")
                except json.JSONDecodeError:
                    # Non-JSON tool output (e.g. raw search results).
                    st.info(f"Tool Result ({msg.name}): {msg.content}")

    # --- Chat input: run one graph turn per user message ---
    if prompt := st.chat_input("Your message or follow-up query..."):
        if not st.session_state.patient_data:
            st.warning("Please load patient data using the sidebar first.")
        else:
            st.session_state.messages.append(HumanMessage(content=prompt))
            with st.chat_message("user"):
                st.markdown(prompt)

            current_state = AgentState(
                messages=st.session_state.messages,
                patient_data=st.session_state.patient_data
            )

            with st.chat_message("assistant"):
                message_placeholder = st.empty()
                full_response = ""
                try:
                    # FIX: stream ONCE and collect the messages each node
                    # emits (AgentState.messages uses operator.add, so every
                    # node returns only its new messages). The old code also
                    # called .invoke() afterwards, re-running the whole graph.
                    new_messages = []
                    for event in st.session_state.graph_app.stream(current_state, {"recursion_limit": 15}):
                        for node_output in event.values():
                            emitted = node_output.get("messages", []) if isinstance(node_output, dict) else []
                            new_messages.extend(emitted)
                            for emitted_msg in emitted:
                                if isinstance(emitted_msg, AIMessage) and emitted_msg.content:
                                    full_response += emitted_msg.content
                                    message_placeholder.markdown(full_response + "▌")

                    message_placeholder.markdown(full_response)
                    # Persist this turn's new messages (AI replies + tool results).
                    st.session_state.messages = st.session_state.messages + new_messages
                except Exception as e:
                    st.error(f"An error occurred during analysis: {e}")
                    st.session_state.messages.append(AIMessage(content=f"Sorry, an error occurred: {e}"))

            # Rerun so the transcript loop above renders the new messages.
            st.rerun()

    st.markdown("---")
    st.warning("**Disclaimer:** SynapseAI is for clinical decision support...")