Update agent.py
Browse files
agent.py
CHANGED
@@ -154,7 +154,7 @@ def check_red_flags(patient_data: Dict[str, Any]) -> List[str]:
|
|
154 |
flags.append(f"Red Flag: Hypertensive urgency/emergency ({sys}/{dia} mmHg).")
|
155 |
if sys <= 90 or dia <= 60:
|
156 |
flags.append(f"Red Flag: Hypotension ({sys}/{dia} mmHg).")
|
157 |
-
return list(dict.fromkeys(flags))
|
158 |
|
159 |
def format_patient_data_for_prompt(data: Dict[str, Any]) -> str:
|
160 |
"""Format patient_data dict into a markdown-like prompt section."""
|
@@ -301,6 +301,13 @@ class AgentState(TypedDict):
|
|
301 |
done: Optional[bool]
|
302 |
iterations: Optional[int]
|
303 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
304 |
# ββ Graph Nodes βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
305 |
def agent_node(state: AgentState) -> Dict[str, Any]:
|
306 |
msgs = state["messages"]
|
@@ -309,16 +316,19 @@ def agent_node(state: AgentState) -> Dict[str, Any]:
|
|
309 |
logger.info(f"Invoking LLM with {len(msgs)} messages")
|
310 |
try:
|
311 |
response = model_with_tools.invoke(msgs)
|
312 |
-
|
|
|
313 |
except Exception as e:
|
314 |
logger.exception("Error in agent_node")
|
315 |
-
|
|
|
316 |
|
317 |
def tool_node(state: AgentState) -> Dict[str, Any]:
|
318 |
last = state["messages"][-1]
|
319 |
if not isinstance(last, AIMessage) or not getattr(last, "tool_calls", None):
|
320 |
logger.warning("tool_node invoked without pending tool_calls")
|
321 |
-
|
|
|
322 |
calls = last.tool_calls
|
323 |
blocked_ids = set()
|
324 |
for call in calls:
|
@@ -360,20 +370,23 @@ def tool_node(state: AgentState) -> Dict[str, Any]:
|
|
360 |
tool_call_id=call["id"],
|
361 |
name=call["name"]
|
362 |
))
|
363 |
-
|
|
|
364 |
|
365 |
def reflection_node(state: AgentState) -> Dict[str, Any]:
|
366 |
warns = state.get("interaction_warnings")
|
367 |
if not warns:
|
368 |
logger.warning("reflection_node called without warnings")
|
369 |
-
|
|
|
370 |
triggering = None
|
371 |
for msg in reversed(state["messages"]):
|
372 |
if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None):
|
373 |
triggering = msg
|
374 |
break
|
375 |
if not triggering:
|
376 |
-
|
|
|
377 |
prompt = (
|
378 |
"You are SynapseAI, performing a focused safety review of the following plan:\n\n"
|
379 |
f"{triggering.content}\n\n"
|
@@ -382,17 +395,20 @@ def reflection_node(state: AgentState) -> Dict[str, Any]:
|
|
382 |
)
|
383 |
try:
|
384 |
resp = llm.invoke([SystemMessage(content="Safety reflection"), HumanMessage(content=prompt)])
|
385 |
-
|
|
|
386 |
except Exception as e:
|
387 |
logger.exception("Error during reflection")
|
388 |
-
|
|
|
389 |
|
390 |
# ββ Routing Functions ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
391 |
def should_continue(state: AgentState) -> str:
|
392 |
# Initialize or increment the iteration counter
|
393 |
state.setdefault("iterations", 0)
|
394 |
state["iterations"] += 1
|
395 |
-
|
|
|
396 |
if state["iterations"] >= 4:
|
397 |
state["done"] = True
|
398 |
return "end_conversation_turn"
|
|
|
154 |
flags.append(f"Red Flag: Hypertensive urgency/emergency ({sys}/{dia} mmHg).")
|
155 |
if sys <= 90 or dia <= 60:
|
156 |
flags.append(f"Red Flag: Hypotension ({sys}/{dia} mmHg).")
|
157 |
+
return list(dict.fromkeys(flags))
|
158 |
|
159 |
def format_patient_data_for_prompt(data: Dict[str, Any]) -> str:
|
160 |
"""Format patient_data dict into a markdown-like prompt section."""
|
|
|
301 |
done: Optional[bool]
|
302 |
iterations: Optional[int]
|
303 |
|
304 |
+
# Helper to propagate state fields
|
305 |
+
def propagate_state(new: Dict[str, Any], old: Dict[str, Any]) -> Dict[str, Any]:
|
306 |
+
for key in ["iterations", "done", "patient_data", "summary", "interaction_warnings"]:
|
307 |
+
if key in old and key not in new:
|
308 |
+
new[key] = old[key]
|
309 |
+
return new
|
310 |
+
|
311 |
# ββ Graph Nodes βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
312 |
def agent_node(state: AgentState) -> Dict[str, Any]:
|
313 |
msgs = state["messages"]
|
|
|
316 |
logger.info(f"Invoking LLM with {len(msgs)} messages")
|
317 |
try:
|
318 |
response = model_with_tools.invoke(msgs)
|
319 |
+
new_state = {"messages": [response]}
|
320 |
+
return propagate_state(new_state, state)
|
321 |
except Exception as e:
|
322 |
logger.exception("Error in agent_node")
|
323 |
+
new_state = {"messages": [AIMessage(content=f"Error: {e}")]}
|
324 |
+
return propagate_state(new_state, state)
|
325 |
|
326 |
def tool_node(state: AgentState) -> Dict[str, Any]:
|
327 |
last = state["messages"][-1]
|
328 |
if not isinstance(last, AIMessage) or not getattr(last, "tool_calls", None):
|
329 |
logger.warning("tool_node invoked without pending tool_calls")
|
330 |
+
new_state = {"messages": []}
|
331 |
+
return propagate_state(new_state, state)
|
332 |
calls = last.tool_calls
|
333 |
blocked_ids = set()
|
334 |
for call in calls:
|
|
|
370 |
tool_call_id=call["id"],
|
371 |
name=call["name"]
|
372 |
))
|
373 |
+
new_state = {"messages": messages, "interaction_warnings": warnings or None}
|
374 |
+
return propagate_state(new_state, state)
|
375 |
|
376 |
def reflection_node(state: AgentState) -> Dict[str, Any]:
|
377 |
warns = state.get("interaction_warnings")
|
378 |
if not warns:
|
379 |
logger.warning("reflection_node called without warnings")
|
380 |
+
new_state = {"messages": []}
|
381 |
+
return propagate_state(new_state, state)
|
382 |
triggering = None
|
383 |
for msg in reversed(state["messages"]):
|
384 |
if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None):
|
385 |
triggering = msg
|
386 |
break
|
387 |
if not triggering:
|
388 |
+
new_state = {"messages": [AIMessage(content="Internal Error: reflection context missing.")]}
|
389 |
+
return propagate_state(new_state, state)
|
390 |
prompt = (
|
391 |
"You are SynapseAI, performing a focused safety review of the following plan:\n\n"
|
392 |
f"{triggering.content}\n\n"
|
|
|
395 |
)
|
396 |
try:
|
397 |
resp = llm.invoke([SystemMessage(content="Safety reflection"), HumanMessage(content=prompt)])
|
398 |
+
new_state = {"messages": [AIMessage(content=resp.content)]}
|
399 |
+
return propagate_state(new_state, state)
|
400 |
except Exception as e:
|
401 |
logger.exception("Error during reflection")
|
402 |
+
new_state = {"messages": [AIMessage(content=f"Error during reflection: {e}")]}
|
403 |
+
return propagate_state(new_state, state)
|
404 |
|
405 |
# ββ Routing Functions ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
406 |
def should_continue(state: AgentState) -> str:
|
407 |
# Initialize or increment the iteration counter
|
408 |
state.setdefault("iterations", 0)
|
409 |
state["iterations"] += 1
|
410 |
+
logger.info(f"Iteration count: {state['iterations']}")
|
411 |
+
# Force termination after a set number of iterations
|
412 |
if state["iterations"] >= 4:
|
413 |
state["done"] = True
|
414 |
return "end_conversation_turn"
|