Update agent.py
Browse files
agent.py
CHANGED
@@ -310,6 +310,9 @@ def propagate_state(new: Dict[str, Any], old: Dict[str, Any]) -> Dict[str, Any]:
|
|
310 |
|
311 |
# ββ Graph Nodes βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
312 |
def agent_node(state: AgentState) -> Dict[str, Any]:
|
|
|
|
|
|
|
313 |
msgs = state["messages"]
|
314 |
if not msgs or not isinstance(msgs[0], SystemMessage):
|
315 |
msgs = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + msgs
|
@@ -324,6 +327,8 @@ def agent_node(state: AgentState) -> Dict[str, Any]:
|
|
324 |
return propagate_state(new_state, state)
|
325 |
|
326 |
def tool_node(state: AgentState) -> Dict[str, Any]:
|
|
|
|
|
327 |
last = state["messages"][-1]
|
328 |
if not isinstance(last, AIMessage) or not getattr(last, "tool_calls", None):
|
329 |
logger.warning("tool_node invoked without pending tool_calls")
|
@@ -374,6 +379,8 @@ def tool_node(state: AgentState) -> Dict[str, Any]:
|
|
374 |
return propagate_state(new_state, state)
|
375 |
|
376 |
def reflection_node(state: AgentState) -> Dict[str, Any]:
|
|
|
|
|
377 |
warns = state.get("interaction_warnings")
|
378 |
if not warns:
|
379 |
logger.warning("reflection_node called without warnings")
|
@@ -404,11 +411,9 @@ def reflection_node(state: AgentState) -> Dict[str, Any]:
|
|
404 |
|
405 |
# ββ Routing Functions ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
406 |
def should_continue(state: AgentState) -> str:
|
407 |
-
# Initialize or increment the iteration counter
|
408 |
state.setdefault("iterations", 0)
|
409 |
state["iterations"] += 1
|
410 |
logger.info(f"Iteration count: {state['iterations']}")
|
411 |
-
# Force termination after a set number of iterations
|
412 |
if state["iterations"] >= 4:
|
413 |
state["done"] = True
|
414 |
return "end_conversation_turn"
|
|
|
310 |
|
311 |
# ββ Graph Nodes βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
312 |
def agent_node(state: AgentState) -> Dict[str, Any]:
|
313 |
+
# Check for termination
|
314 |
+
if state.get("done", False):
|
315 |
+
return state
|
316 |
msgs = state["messages"]
|
317 |
if not msgs or not isinstance(msgs[0], SystemMessage):
|
318 |
msgs = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + msgs
|
|
|
327 |
return propagate_state(new_state, state)
|
328 |
|
329 |
def tool_node(state: AgentState) -> Dict[str, Any]:
|
330 |
+
if state.get("done", False):
|
331 |
+
return state
|
332 |
last = state["messages"][-1]
|
333 |
if not isinstance(last, AIMessage) or not getattr(last, "tool_calls", None):
|
334 |
logger.warning("tool_node invoked without pending tool_calls")
|
|
|
379 |
return propagate_state(new_state, state)
|
380 |
|
381 |
def reflection_node(state: AgentState) -> Dict[str, Any]:
|
382 |
+
if state.get("done", False):
|
383 |
+
return state
|
384 |
warns = state.get("interaction_warnings")
|
385 |
if not warns:
|
386 |
logger.warning("reflection_node called without warnings")
|
|
|
411 |
|
412 |
# ββ Routing Functions ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
413 |
def should_continue(state: AgentState) -> str:
|
|
|
414 |
state.setdefault("iterations", 0)
|
415 |
state["iterations"] += 1
|
416 |
logger.info(f"Iteration count: {state['iterations']}")
|
|
|
417 |
if state["iterations"] >= 4:
|
418 |
state["done"] = True
|
419 |
return "end_conversation_turn"
|