mgbam commited on
Commit
63a6acf
Β·
verified Β·
1 Parent(s): d042f84

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +28 -20
agent.py CHANGED
@@ -15,11 +15,11 @@ from langchain_core.tools import tool
15
  from langgraph.prebuilt import ToolExecutor
16
  from langgraph.graph import StateGraph, END
17
 
18
- # ── Logging Configuration ─────────────────────────────────────────────────────
19
  logger = logging.getLogger(__name__)
20
  logging.basicConfig(level=logging.INFO)
21
 
22
- # ── Environment Variables ─────────────────────────────────────────────────────
23
  UMLS_API_KEY = os.getenv("UMLS_API_KEY")
24
  GROQ_API_KEY = os.getenv("GROQ_API_KEY")
25
  TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
@@ -28,7 +28,7 @@ if not all([UMLS_API_KEY, GROQ_API_KEY, TAVILY_API_KEY]):
28
  logger.error("Missing one or more required API keys: UMLS_API_KEY, GROQ_API_KEY, TAVILY_API_KEY")
29
  raise RuntimeError("Missing required API keys")
30
 
31
- # ── Agent Configuration ───────────────────────────────────────────────────────
32
  AGENT_MODEL_NAME = "llama3-70b-8192"
33
  AGENT_TEMPERATURE = 0.1
34
  MAX_SEARCH_RESULTS = 3
@@ -39,7 +39,7 @@ class ClinicalPrompts:
39
  [SYSTEM PROMPT CONTENT HERE]
40
  """
41
 
42
- # ── Helper Functions ──────────────────────────────────────────────────────────
43
  UMLS_AUTH_ENDPOINT = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
44
  RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"
45
  OPENFDA_API_BASE = "https://api.fda.gov/drug/label.json"
@@ -174,7 +174,7 @@ def format_patient_data_for_prompt(data: Dict[str, Any]) -> str:
174
  lines.append(f"**{title}:** {value}")
175
  return "\n".join(lines)
176
 
177
- # ── Tool Input Schemas ────────────────────────────────────────────────────────
178
  class LabOrderInput(BaseModel):
179
  test_name: str = Field(...)
180
  reason: str = Field(...)
@@ -197,7 +197,7 @@ class FlagRiskInput(BaseModel):
197
  risk_description: str = Field(...)
198
  urgency: str = Field("High")
199
 
200
- # ── Tool Implementations ──────────────────────────────────────────────────────
201
  @tool("order_lab_test", args_schema=LabOrderInput)
202
  def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
203
  """
@@ -287,12 +287,12 @@ def flag_risk(risk_description: str, urgency: str = "High") -> str:
287
  search_tool = TavilySearchResults(max_results=MAX_SEARCH_RESULTS, name="tavily_search_results")
288
  all_tools = [order_lab_test, prescribe_medication, check_drug_interactions, flag_risk, search_tool]
289
 
290
- # ── LLM & Tool Executor ──────────────────────────────────────────────────────
291
  llm = ChatGroq(temperature=AGENT_TEMPERATURE, model=AGENT_MODEL_NAME)
292
  model_with_tools = llm.bind_tools(all_tools)
293
  tool_executor = ToolExecutor(all_tools)
294
 
295
- # ── State Definition ──────────────────────────────────────────────────────────
296
  class AgentState(TypedDict):
297
  messages: List[Any]
298
  patient_data: Optional[Dict[str, Any]]
@@ -301,19 +301,18 @@ class AgentState(TypedDict):
301
  done: Optional[bool]
302
  iterations: Optional[int]
303
 
304
- # Helper to propagate state fields
305
  def propagate_state(new: Dict[str, Any], old: Dict[str, Any]) -> Dict[str, Any]:
306
  for key in ["iterations", "done", "patient_data", "summary", "interaction_warnings"]:
307
  if key in old and key not in new:
308
  new[key] = old[key]
309
  return new
310
 
311
- # ── Graph Nodes ───────────────────────────────────────────────────────────────
312
  def agent_node(state: AgentState) -> Dict[str, Any]:
313
- # Check for termination
314
  if state.get("done", False):
315
  return state
316
- msgs = state["messages"]
317
  if not msgs or not isinstance(msgs[0], SystemMessage):
318
  msgs = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + msgs
319
  logger.info(f"Invoking LLM with {len(msgs)} messages")
@@ -329,7 +328,7 @@ def agent_node(state: AgentState) -> Dict[str, Any]:
329
  def tool_node(state: AgentState) -> Dict[str, Any]:
330
  if state.get("done", False):
331
  return state
332
- last = state["messages"][-1]
333
  if not isinstance(last, AIMessage) or not getattr(last, "tool_calls", None):
334
  logger.warning("tool_node invoked without pending tool_calls")
335
  new_state = {"messages": []}
@@ -387,7 +386,7 @@ def reflection_node(state: AgentState) -> Dict[str, Any]:
387
  new_state = {"messages": []}
388
  return propagate_state(new_state, state)
389
  triggering = None
390
- for msg in reversed(state["messages"]):
391
  if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None):
392
  triggering = msg
393
  break
@@ -409,27 +408,35 @@ def reflection_node(state: AgentState) -> Dict[str, Any]:
409
  new_state = {"messages": [AIMessage(content=f"Error during reflection: {e}")]}
410
  return propagate_state(new_state, state)
411
 
412
- # ── Routing Functions ────────────────────────────────────────────────────────
413
  def should_continue(state: AgentState) -> str:
414
  state.setdefault("iterations", 0)
415
  state["iterations"] += 1
416
  logger.info(f"Iteration count: {state['iterations']}")
 
417
  if state["iterations"] >= 4:
 
 
 
 
418
  state["done"] = True
419
  return "end_conversation_turn"
420
- last = state["messages"][-1] if state["messages"] else None
421
  if not isinstance(last, AIMessage):
422
  state["done"] = True
423
  return "end_conversation_turn"
424
  if getattr(last, "tool_calls", None):
425
  return "continue_tools"
426
- state["done"] = True
427
- return "end_conversation_turn"
 
 
 
428
 
429
  def after_tools_router(state: AgentState) -> str:
430
  return "reflection" if state.get("interaction_warnings") else "agent"
431
 
432
- # ── ClinicalAgent ────────────────────────────────────────────────────────────
433
  class ClinicalAgent:
434
  def __init__(self):
435
  logger.info("Building ClinicalAgent workflow")
@@ -452,7 +459,8 @@ class ClinicalAgent:
452
 
453
  def invoke_turn(self, state: Dict[str, Any]) -> Dict[str, Any]:
454
  try:
455
- result = self.graph_app.invoke(state, {"recursion_limit": 15})
 
456
  result.setdefault("summary", state.get("summary"))
457
  result.setdefault("interaction_warnings", None)
458
  return result
 
15
  from langgraph.prebuilt import ToolExecutor
16
  from langgraph.graph import StateGraph, END
17
 
18
+ # ── Logging Configuration ──────────────────────────────────────────────
19
  logger = logging.getLogger(__name__)
20
  logging.basicConfig(level=logging.INFO)
21
 
22
+ # ── Environment Variables ──────────────────────────────────────────────
23
  UMLS_API_KEY = os.getenv("UMLS_API_KEY")
24
  GROQ_API_KEY = os.getenv("GROQ_API_KEY")
25
  TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
 
28
  logger.error("Missing one or more required API keys: UMLS_API_KEY, GROQ_API_KEY, TAVILY_API_KEY")
29
  raise RuntimeError("Missing required API keys")
30
 
31
+ # ── Agent Configuration ──────────────────────────────────────────────
32
  AGENT_MODEL_NAME = "llama3-70b-8192"
33
  AGENT_TEMPERATURE = 0.1
34
  MAX_SEARCH_RESULTS = 3
 
39
  [SYSTEM PROMPT CONTENT HERE]
40
  """
41
 
42
+ # ── Helper Functions ─────────────────────────────────────────────────────
43
  UMLS_AUTH_ENDPOINT = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
44
  RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"
45
  OPENFDA_API_BASE = "https://api.fda.gov/drug/label.json"
 
174
  lines.append(f"**{title}:** {value}")
175
  return "\n".join(lines)
176
 
177
+ # ── Tool Input Schemas ─────────────────────────────────────────────────────
178
  class LabOrderInput(BaseModel):
179
  test_name: str = Field(...)
180
  reason: str = Field(...)
 
197
  risk_description: str = Field(...)
198
  urgency: str = Field("High")
199
 
200
+ # ── Tool Implementations ───────────────────────────────────────────────────
201
  @tool("order_lab_test", args_schema=LabOrderInput)
202
  def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
203
  """
 
287
  search_tool = TavilySearchResults(max_results=MAX_SEARCH_RESULTS, name="tavily_search_results")
288
  all_tools = [order_lab_test, prescribe_medication, check_drug_interactions, flag_risk, search_tool]
289
 
290
+ # ── LLM & Tool Executor ───────────────────────────────────────────────────
291
  llm = ChatGroq(temperature=AGENT_TEMPERATURE, model=AGENT_MODEL_NAME)
292
  model_with_tools = llm.bind_tools(all_tools)
293
  tool_executor = ToolExecutor(all_tools)
294
 
295
+ # ── State Definition ─────────────────────────────────────────────────────
296
  class AgentState(TypedDict):
297
  messages: List[Any]
298
  patient_data: Optional[Dict[str, Any]]
 
301
  done: Optional[bool]
302
  iterations: Optional[int]
303
 
304
+ # Helper to propagate state fields between nodes
305
  def propagate_state(new: Dict[str, Any], old: Dict[str, Any]) -> Dict[str, Any]:
306
  for key in ["iterations", "done", "patient_data", "summary", "interaction_warnings"]:
307
  if key in old and key not in new:
308
  new[key] = old[key]
309
  return new
310
 
311
+ # ── Graph Nodes ─────────────────────────────────────────────────────────
312
  def agent_node(state: AgentState) -> Dict[str, Any]:
 
313
  if state.get("done", False):
314
  return state
315
+ msgs = state.get("messages", [])
316
  if not msgs or not isinstance(msgs[0], SystemMessage):
317
  msgs = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + msgs
318
  logger.info(f"Invoking LLM with {len(msgs)} messages")
 
328
  def tool_node(state: AgentState) -> Dict[str, Any]:
329
  if state.get("done", False):
330
  return state
331
+ last = state.get("messages", [])[-1]
332
  if not isinstance(last, AIMessage) or not getattr(last, "tool_calls", None):
333
  logger.warning("tool_node invoked without pending tool_calls")
334
  new_state = {"messages": []}
 
386
  new_state = {"messages": []}
387
  return propagate_state(new_state, state)
388
  triggering = None
389
+ for msg in reversed(state.get("messages", [])):
390
  if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None):
391
  triggering = msg
392
  break
 
408
  new_state = {"messages": [AIMessage(content=f"Error during reflection: {e}")]}
409
  return propagate_state(new_state, state)
410
 
411
+ # ── Routing Functions ────────────────────────────────────────────────────
412
  def should_continue(state: AgentState) -> str:
413
  state.setdefault("iterations", 0)
414
  state["iterations"] += 1
415
  logger.info(f"Iteration count: {state['iterations']}")
416
+ # When iterations exceed threshold, force a final message and mark done.
417
  if state["iterations"] >= 4:
418
+ state.setdefault("messages", []).append(AIMessage(content="Final output: consultation complete."))
419
+ state["done"] = True
420
+ return "end_conversation_turn"
421
+ if not state.get("messages"):
422
  state["done"] = True
423
  return "end_conversation_turn"
424
+ last = state["messages"][-1]
425
  if not isinstance(last, AIMessage):
426
  state["done"] = True
427
  return "end_conversation_turn"
428
  if getattr(last, "tool_calls", None):
429
  return "continue_tools"
430
+ if "consultation complete" in last.content.lower():
431
+ state["done"] = True
432
+ return "end_conversation_turn"
433
+ state["done"] = False
434
+ return "agent"
435
 
436
  def after_tools_router(state: AgentState) -> str:
437
  return "reflection" if state.get("interaction_warnings") else "agent"
438
 
439
+ # ── ClinicalAgent ─────────────────────────────────────────────────────────
440
  class ClinicalAgent:
441
  def __init__(self):
442
  logger.info("Building ClinicalAgent workflow")
 
459
 
460
  def invoke_turn(self, state: Dict[str, Any]) -> Dict[str, Any]:
461
  try:
462
+ # Increase the recursion_limit as a temporary workaround if needed.
463
+ result = self.graph_app.invoke(state, {"recursion_limit": 100})
464
  result.setdefault("summary", state.get("summary"))
465
  result.setdefault("interaction_warnings", None)
466
  return result