mgbam commited on
Commit
ec5723b
·
verified ·
1 Parent(s): cb940e4

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +26 -10
agent.py CHANGED
@@ -154,7 +154,7 @@ def check_red_flags(patient_data: Dict[str, Any]) -> List[str]:
154
  flags.append(f"Red Flag: Hypertensive urgency/emergency ({sys}/{dia} mmHg).")
155
  if sys <= 90 or dia <= 60:
156
  flags.append(f"Red Flag: Hypotension ({sys}/{dia} mmHg).")
157
- return list(dict.fromkeys(flags)) # dedupe, preserve order
158
 
159
  def format_patient_data_for_prompt(data: Dict[str, Any]) -> str:
160
  """Format patient_data dict into a markdown-like prompt section."""
@@ -301,6 +301,13 @@ class AgentState(TypedDict):
301
  done: Optional[bool]
302
  iterations: Optional[int]
303
 
 
 
 
 
 
 
 
304
  # ── Graph Nodes ───────────────────────────────────────────────────────────────
305
  def agent_node(state: AgentState) -> Dict[str, Any]:
306
  msgs = state["messages"]
@@ -309,16 +316,19 @@ def agent_node(state: AgentState) -> Dict[str, Any]:
309
  logger.info(f"Invoking LLM with {len(msgs)} messages")
310
  try:
311
  response = model_with_tools.invoke(msgs)
312
- return {"messages": [response]}
 
313
  except Exception as e:
314
  logger.exception("Error in agent_node")
315
- return {"messages": [AIMessage(content=f"Error: {e}")]}
 
316
 
317
  def tool_node(state: AgentState) -> Dict[str, Any]:
318
  last = state["messages"][-1]
319
  if not isinstance(last, AIMessage) or not getattr(last, "tool_calls", None):
320
  logger.warning("tool_node invoked without pending tool_calls")
321
- return {"messages": [], "interaction_warnings": None}
 
322
  calls = last.tool_calls
323
  blocked_ids = set()
324
  for call in calls:
@@ -360,20 +370,23 @@ def tool_node(state: AgentState) -> Dict[str, Any]:
360
  tool_call_id=call["id"],
361
  name=call["name"]
362
  ))
363
- return {"messages": messages, "interaction_warnings": warnings or None}
 
364
 
365
  def reflection_node(state: AgentState) -> Dict[str, Any]:
366
  warns = state.get("interaction_warnings")
367
  if not warns:
368
  logger.warning("reflection_node called without warnings")
369
- return {"messages": [], "interaction_warnings": None}
 
370
  triggering = None
371
  for msg in reversed(state["messages"]):
372
  if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None):
373
  triggering = msg
374
  break
375
  if not triggering:
376
- return {"messages": [AIMessage(content="Internal Error: reflection context missing.")], "interaction_warnings": None}
 
377
  prompt = (
378
  "You are SynapseAI, performing a focused safety review of the following plan:\n\n"
379
  f"{triggering.content}\n\n"
@@ -382,17 +395,20 @@ def reflection_node(state: AgentState) -> Dict[str, Any]:
382
  )
383
  try:
384
  resp = llm.invoke([SystemMessage(content="Safety reflection"), HumanMessage(content=prompt)])
385
- return {"messages": [AIMessage(content=resp.content)], "interaction_warnings": None}
 
386
  except Exception as e:
387
  logger.exception("Error during reflection")
388
- return {"messages": [AIMessage(content=f"Error during reflection: {e}")], "interaction_warnings": None}
 
389
 
390
  # ── Routing Functions ────────────────────────────────────────────────────────
391
  def should_continue(state: AgentState) -> str:
392
  # Initialize or increment the iteration counter
393
  state.setdefault("iterations", 0)
394
  state["iterations"] += 1
395
- # Force termination after a set number of iterations to prevent infinite loops.
 
396
  if state["iterations"] >= 4:
397
  state["done"] = True
398
  return "end_conversation_turn"
 
154
  flags.append(f"Red Flag: Hypertensive urgency/emergency ({sys}/{dia} mmHg).")
155
  if sys <= 90 or dia <= 60:
156
  flags.append(f"Red Flag: Hypotension ({sys}/{dia} mmHg).")
157
+ return list(dict.fromkeys(flags))
158
 
159
  def format_patient_data_for_prompt(data: Dict[str, Any]) -> str:
160
  """Format patient_data dict into a markdown-like prompt section."""
 
301
  done: Optional[bool]
302
  iterations: Optional[int]
303
 
304
# Helper to propagate state fields
def propagate_state(new: Dict[str, Any], old: Dict[str, Any]) -> Dict[str, Any]:
    """Carry selected bookkeeping fields from *old* forward into *new*.

    A graph node's return value only contains the keys it explicitly set;
    this copies each tracked field from the previous state into the new
    partial state unless the node already supplied it. Mutates and
    returns *new*.
    """
    carried = ("iterations", "done", "patient_data", "summary", "interaction_warnings")
    for field in carried:
        # Keep the node's own value when present; otherwise inherit.
        if field in old and field not in new:
            new[field] = old[field]
    return new
310
+
311
  # ── Graph Nodes ───────────────────────────────────────────────────────────────
312
  def agent_node(state: AgentState) -> Dict[str, Any]:
313
  msgs = state["messages"]
 
316
  logger.info(f"Invoking LLM with {len(msgs)} messages")
317
  try:
318
  response = model_with_tools.invoke(msgs)
319
+ new_state = {"messages": [response]}
320
+ return propagate_state(new_state, state)
321
  except Exception as e:
322
  logger.exception("Error in agent_node")
323
+ new_state = {"messages": [AIMessage(content=f"Error: {e}")]}
324
+ return propagate_state(new_state, state)
325
 
326
  def tool_node(state: AgentState) -> Dict[str, Any]:
327
  last = state["messages"][-1]
328
  if not isinstance(last, AIMessage) or not getattr(last, "tool_calls", None):
329
  logger.warning("tool_node invoked without pending tool_calls")
330
+ new_state = {"messages": []}
331
+ return propagate_state(new_state, state)
332
  calls = last.tool_calls
333
  blocked_ids = set()
334
  for call in calls:
 
370
  tool_call_id=call["id"],
371
  name=call["name"]
372
  ))
373
+ new_state = {"messages": messages, "interaction_warnings": warnings or None}
374
+ return propagate_state(new_state, state)
375
 
376
  def reflection_node(state: AgentState) -> Dict[str, Any]:
377
  warns = state.get("interaction_warnings")
378
  if not warns:
379
  logger.warning("reflection_node called without warnings")
380
+ new_state = {"messages": []}
381
+ return propagate_state(new_state, state)
382
  triggering = None
383
  for msg in reversed(state["messages"]):
384
  if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None):
385
  triggering = msg
386
  break
387
  if not triggering:
388
+ new_state = {"messages": [AIMessage(content="Internal Error: reflection context missing.")]}
389
+ return propagate_state(new_state, state)
390
  prompt = (
391
  "You are SynapseAI, performing a focused safety review of the following plan:\n\n"
392
  f"{triggering.content}\n\n"
 
395
  )
396
  try:
397
  resp = llm.invoke([SystemMessage(content="Safety reflection"), HumanMessage(content=prompt)])
398
+ new_state = {"messages": [AIMessage(content=resp.content)]}
399
+ return propagate_state(new_state, state)
400
  except Exception as e:
401
  logger.exception("Error during reflection")
402
+ new_state = {"messages": [AIMessage(content=f"Error during reflection: {e}")]}
403
+ return propagate_state(new_state, state)
404
 
405
  # ── Routing Functions ────────────────────────────────────────────────────────
406
  def should_continue(state: AgentState) -> str:
407
  # Initialize or increment the iteration counter
408
  state.setdefault("iterations", 0)
409
  state["iterations"] += 1
410
+ logger.info(f"Iteration count: {state['iterations']}")
411
+ # Force termination after a set number of iterations
412
  if state["iterations"] >= 4:
413
  state["done"] = True
414
  return "end_conversation_turn"