mgbam commited on
Commit
d042f84
Β·
verified Β·
1 Parent(s): ec5723b

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +7 -2
agent.py CHANGED
@@ -310,6 +310,9 @@ def propagate_state(new: Dict[str, Any], old: Dict[str, Any]) -> Dict[str, Any]:
310
 
311
  # ── Graph Nodes ───────────────────────────────────────────────────────────────
312
  def agent_node(state: AgentState) -> Dict[str, Any]:
 
 
 
313
  msgs = state["messages"]
314
  if not msgs or not isinstance(msgs[0], SystemMessage):
315
  msgs = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + msgs
@@ -324,6 +327,8 @@ def agent_node(state: AgentState) -> Dict[str, Any]:
324
  return propagate_state(new_state, state)
325
 
326
  def tool_node(state: AgentState) -> Dict[str, Any]:
 
 
327
  last = state["messages"][-1]
328
  if not isinstance(last, AIMessage) or not getattr(last, "tool_calls", None):
329
  logger.warning("tool_node invoked without pending tool_calls")
@@ -374,6 +379,8 @@ def tool_node(state: AgentState) -> Dict[str, Any]:
374
  return propagate_state(new_state, state)
375
 
376
  def reflection_node(state: AgentState) -> Dict[str, Any]:
 
 
377
  warns = state.get("interaction_warnings")
378
  if not warns:
379
  logger.warning("reflection_node called without warnings")
@@ -404,11 +411,9 @@ def reflection_node(state: AgentState) -> Dict[str, Any]:
404
 
405
  # ── Routing Functions ────────────────────────────────────────────────────────
406
  def should_continue(state: AgentState) -> str:
407
- # Initialize or increment the iteration counter
408
  state.setdefault("iterations", 0)
409
  state["iterations"] += 1
410
  logger.info(f"Iteration count: {state['iterations']}")
411
- # Force termination after a set number of iterations
412
  if state["iterations"] >= 4:
413
  state["done"] = True
414
  return "end_conversation_turn"
 
310
 
311
  # ── Graph Nodes ───────────────────────────────────────────────────────────────
312
  def agent_node(state: AgentState) -> Dict[str, Any]:
313
+ # Check for termination
314
+ if state.get("done", False):
315
+ return state
316
  msgs = state["messages"]
317
  if not msgs or not isinstance(msgs[0], SystemMessage):
318
  msgs = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + msgs
 
327
  return propagate_state(new_state, state)
328
 
329
  def tool_node(state: AgentState) -> Dict[str, Any]:
330
+ if state.get("done", False):
331
+ return state
332
  last = state["messages"][-1]
333
  if not isinstance(last, AIMessage) or not getattr(last, "tool_calls", None):
334
  logger.warning("tool_node invoked without pending tool_calls")
 
379
  return propagate_state(new_state, state)
380
 
381
  def reflection_node(state: AgentState) -> Dict[str, Any]:
382
+ if state.get("done", False):
383
+ return state
384
  warns = state.get("interaction_warnings")
385
  if not warns:
386
  logger.warning("reflection_node called without warnings")
 
411
 
412
  # ── Routing Functions ────────────────────────────────────────────────────────
413
  def should_continue(state: AgentState) -> str:
 
414
  state.setdefault("iterations", 0)
415
  state["iterations"] += 1
416
  logger.info(f"Iteration count: {state['iterations']}")
 
417
  if state["iterations"] >= 4:
418
  state["done"] = True
419
  return "end_conversation_turn"