mgbam commited on
Commit
cb940e4
Β·
verified Β·
1 Parent(s): c81a632

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +10 -2
agent.py CHANGED
@@ -299,6 +299,7 @@ class AgentState(TypedDict):
299
  summary: Optional[str]
300
  interaction_warnings: Optional[List[str]]
301
  done: Optional[bool]
 
302
 
303
  # ── Graph Nodes ───────────────────────────────────────────────────────────────
304
  def agent_node(state: AgentState) -> Dict[str, Any]:
@@ -325,7 +326,7 @@ def tool_node(state: AgentState) -> Dict[str, Any]:
325
  med = call["args"].get("medication_name", "").lower()
326
  if not any(
327
  c["name"] == "check_drug_interactions" and
328
- c["args"].get("potential_prescription","").lower() == med
329
  for c in calls
330
  ):
331
  logger.warning(f"Blocking prescribe_medication for '{med}' without interaction check")
@@ -337,7 +338,7 @@ def tool_node(state: AgentState) -> Dict[str, Any]:
337
  call["args"].setdefault("current_medications", pd.get("medications", {}).get("current", []))
338
  call["args"].setdefault("allergies", pd.get("allergies", []))
339
  messages: List[ToolMessage] = []
340
- warnings: List[str] = []
341
  try:
342
  responses = tool_executor.batch(to_execute, return_exceptions=True)
343
  for call, resp in zip(to_execute, responses):
@@ -388,6 +389,13 @@ def reflection_node(state: AgentState) -> Dict[str, Any]:
388
 
389
  # ── Routing Functions ────────────────────────────────────────────────────────
390
  def should_continue(state: AgentState) -> str:
 
 
 
 
 
 
 
391
  last = state["messages"][-1] if state["messages"] else None
392
  if not isinstance(last, AIMessage):
393
  state["done"] = True
 
299
  summary: Optional[str]
300
  interaction_warnings: Optional[List[str]]
301
  done: Optional[bool]
302
+ iterations: Optional[int]
303
 
304
  # ── Graph Nodes ───────────────────────────────────────────────────────────────
305
  def agent_node(state: AgentState) -> Dict[str, Any]:
 
326
  med = call["args"].get("medication_name", "").lower()
327
  if not any(
328
  c["name"] == "check_drug_interactions" and
329
+ c["args"].get("potential_prescription", "").lower() == med
330
  for c in calls
331
  ):
332
  logger.warning(f"Blocking prescribe_medication for '{med}' without interaction check")
 
338
  call["args"].setdefault("current_medications", pd.get("medications", {}).get("current", []))
339
  call["args"].setdefault("allergies", pd.get("allergies", []))
340
  messages: List[ToolMessage] = []
341
+ warnings: List[str] = []
342
  try:
343
  responses = tool_executor.batch(to_execute, return_exceptions=True)
344
  for call, resp in zip(to_execute, responses):
 
389
 
390
  # ── Routing Functions ────────────────────────────────────────────────────────
391
  def should_continue(state: AgentState) -> str:
392
+ # Initialize or increment the iteration counter
393
+ state.setdefault("iterations", 0)
394
+ state["iterations"] += 1
395
+ # Force termination after a set number of iterations to prevent infinite loops.
396
+ if state["iterations"] >= 4:
397
+ state["done"] = True
398
+ return "end_conversation_turn"
399
  last = state["messages"][-1] if state["messages"] else None
400
  if not isinstance(last, AIMessage):
401
  state["done"] = True