Update agent.py
agent.py
CHANGED
@@ -15,11 +15,11 @@ from langchain_core.tools import tool
 from langgraph.prebuilt import ToolExecutor
 from langgraph.graph import StateGraph, END
 
-# ── Logging Configuration
+# ── Logging Configuration ──────────────────────────────────────────────
 logger = logging.getLogger(__name__)
 logging.basicConfig(level=logging.INFO)
 
-# ── Environment Variables
+# ── Environment Variables ──────────────────────────────────────────────
 UMLS_API_KEY = os.getenv("UMLS_API_KEY")
 GROQ_API_KEY = os.getenv("GROQ_API_KEY")
 TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
@@ -28,7 +28,7 @@ if not all([UMLS_API_KEY, GROQ_API_KEY, TAVILY_API_KEY]):
     logger.error("Missing one or more required API keys: UMLS_API_KEY, GROQ_API_KEY, TAVILY_API_KEY")
     raise RuntimeError("Missing required API keys")
 
-# ── Agent Configuration
+# ── Agent Configuration ────────────────────────────────────────────────
 AGENT_MODEL_NAME = "llama3-70b-8192"
 AGENT_TEMPERATURE = 0.1
 MAX_SEARCH_RESULTS = 3
@@ -39,7 +39,7 @@ class ClinicalPrompts:
     [SYSTEM PROMPT CONTENT HERE]
     """
 
-# ── Helper Functions
+# ── Helper Functions ───────────────────────────────────────────────────
 UMLS_AUTH_ENDPOINT = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
 RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"
 OPENFDA_API_BASE = "https://api.fda.gov/drug/label.json"
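For reference, RXNORM_API_BASE points at NLM's RxNav service; a minimal lookup against it might look like the sketch below (the helper name, timeout, and error handling are illustrative and not part of this file):

    import requests
    from typing import Optional

    def lookup_rxcui(drug_name: str) -> Optional[str]:
        # Resolve a drug name to its RxNorm concept identifier (RxCUI).
        resp = requests.get(f"{RXNORM_API_BASE}/rxcui.json", params={"name": drug_name}, timeout=10)
        resp.raise_for_status()
        ids = resp.json().get("idGroup", {}).get("rxnormId", [])
        return ids[0] if ids else None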
@@ -174,7 +174,7 @@ def format_patient_data_for_prompt(data: Dict[str, Any]) -> str:
         lines.append(f"**{title}:** {value}")
     return "\n".join(lines)
 
-# ── Tool Input Schemas
+# ── Tool Input Schemas ─────────────────────────────────────────────────
 class LabOrderInput(BaseModel):
     test_name: str = Field(...)
     reason: str = Field(...)
@@ -197,7 +197,7 @@ class FlagRiskInput(BaseModel):
     risk_description: str = Field(...)
     urgency: str = Field("High")
 
-# ── Tool Implementations
+# ── Tool Implementations ───────────────────────────────────────────────
 @tool("order_lab_test", args_schema=LabOrderInput)
 def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
     """
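As a quick illustration of what the args_schema wiring gives you (assuming a recent langchain-core where tools implement the Runnable interface; the argument values are made up):

    # The decorator turns order_lab_test into a tool whose inputs are
    # validated against LabOrderInput before the function body runs.
    print(order_lab_test.name)   # "order_lab_test"
    print(order_lab_test.args)   # JSON schema derived from LabOrderInput
    result = order_lab_test.invoke({"test_name": "CBC", "reason": "Anemia workup"})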
@@ -287,12 +287,12 @@ def flag_risk(risk_description: str, urgency: str = "High") -> str:
 search_tool = TavilySearchResults(max_results=MAX_SEARCH_RESULTS, name="tavily_search_results")
 all_tools = [order_lab_test, prescribe_medication, check_drug_interactions, flag_risk, search_tool]
 
-# ── LLM & Tool Executor
+# ── LLM & Tool Executor ────────────────────────────────────────────────
 llm = ChatGroq(temperature=AGENT_TEMPERATURE, model=AGENT_MODEL_NAME)
 model_with_tools = llm.bind_tools(all_tools)
 tool_executor = ToolExecutor(all_tools)
 
-# ── State Definition
+# ── State Definition ───────────────────────────────────────────────────
 class AgentState(TypedDict):
     messages: List[Any]
     patient_data: Optional[Dict[str, Any]]
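For context on the objects defined here: model_with_tools produces AIMessage objects whose tool_calls entries are what tool_node later executes. A rough sketch (requires GROQ_API_KEY to be set; the prompt text is invented):

    from langchain_core.messages import HumanMessage

    ai_msg = model_with_tools.invoke([HumanMessage(content="Order a CBC for suspected anemia.")])
    for call in getattr(ai_msg, "tool_calls", []) or []:
        # Each entry is a dict like {"name": "order_lab_test", "args": {...}, "id": "..."}
        print(call["name"], call["args"])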
@@ -301,19 +301,18 @@ class AgentState(TypedDict):
     done: Optional[bool]
     iterations: Optional[int]
 
-# Helper to propagate state fields
+# Helper to propagate state fields between nodes
 def propagate_state(new: Dict[str, Any], old: Dict[str, Any]) -> Dict[str, Any]:
     for key in ["iterations", "done", "patient_data", "summary", "interaction_warnings"]:
         if key in old and key not in new:
             new[key] = old[key]
     return new
 
-# ── Graph Nodes
+# ── Graph Nodes ────────────────────────────────────────────────────────
 def agent_node(state: AgentState) -> Dict[str, Any]:
-    # Check for termination
     if state.get("done", False):
         return state
-    msgs = state["messages"]
+    msgs = state.get("messages", [])
     if not msgs or not isinstance(msgs[0], SystemMessage):
         msgs = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + msgs
     logger.info(f"Invoking LLM with {len(msgs)} messages")
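The switch from state["messages"] to state.get("messages", []) matters on the very first turn, when the incoming state may not carry a message list yet; roughly:

    state = {}                                    # no "messages" key seeded yet
    msgs = state.get("messages", [])              # [] instead of a KeyError
    msgs = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + msgs  # prompt still prepended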
@@ -329,7 +328,7 @@ def agent_node(state: AgentState) -> Dict[str, Any]:
 def tool_node(state: AgentState) -> Dict[str, Any]:
     if state.get("done", False):
         return state
-    last = state["messages"][-1]
+    last = state.get("messages", [])[-1]
     if not isinstance(last, AIMessage) or not getattr(last, "tool_calls", None):
         logger.warning("tool_node invoked without pending tool_calls")
         new_state = {"messages": []}
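One caveat the diff does not address: state.get("messages", [])[-1] still raises IndexError when the history is empty. A defensive variant (not part of this commit) would be:

    msgs = state.get("messages", [])
    last = msgs[-1] if msgs else None   # avoid IndexError on an empty history
    if last is None or not isinstance(last, AIMessage) or not getattr(last, "tool_calls", None):
        logger.warning("tool_node invoked without pending tool_calls")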
@@ -387,7 +386,7 @@ def reflection_node(state: AgentState) -> Dict[str, Any]:
         new_state = {"messages": []}
         return propagate_state(new_state, state)
     triggering = None
-    for msg in reversed(state["messages"]):
+    for msg in reversed(state.get("messages", [])):
         if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None):
             triggering = msg
             break
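The reversed scan for the triggering AIMessage could equivalently be written with next(); shown only for comparison, not as a suggested change:

    triggering = next(
        (m for m in reversed(state.get("messages", []))
         if isinstance(m, AIMessage) and getattr(m, "tool_calls", None)),
        None,
    )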
@@ -409,27 +408,35 @@
         new_state = {"messages": [AIMessage(content=f"Error during reflection: {e}")]}
         return propagate_state(new_state, state)
 
-# ── Routing Functions
+# ── Routing Functions ──────────────────────────────────────────────────
 def should_continue(state: AgentState) -> str:
     state.setdefault("iterations", 0)
     state["iterations"] += 1
     logger.info(f"Iteration count: {state['iterations']}")
+    # When iterations exceed threshold, force a final message and mark done.
     if state["iterations"] >= 4:
+        state.setdefault("messages", []).append(AIMessage(content="Final output: consultation complete."))
+        state["done"] = True
+        return "end_conversation_turn"
+    if not state.get("messages"):
         state["done"] = True
         return "end_conversation_turn"
-    last = state["messages"][-1]
+    last = state["messages"][-1]
     if not isinstance(last, AIMessage):
         state["done"] = True
         return "end_conversation_turn"
     if getattr(last, "tool_calls", None):
         return "continue_tools"
-
-
+    if "consultation complete" in last.content.lower():
+        state["done"] = True
+        return "end_conversation_turn"
+    state["done"] = False
+    return "agent"
 
 def after_tools_router(state: AgentState) -> str:
     return "reflection" if state.get("interaction_warnings") else "agent"
 
-# ── ClinicalAgent
+# ── ClinicalAgent ──────────────────────────────────────────────────────
 class ClinicalAgent:
     def __init__(self):
         logger.info("Building ClinicalAgent workflow")
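A rough sense of the new routing behaviour (the states below are hand-built and assume AIMessage from langchain_core.messages; they are not test fixtures from the repository):

    s = {"messages": [AIMessage(content="Assessment done. Consultation complete.")], "iterations": 0}
    print(should_continue(s))   # "end_conversation_turn" via the "consultation complete" sentinel

    s = {"messages": [AIMessage(content="", tool_calls=[{"name": "order_lab_test", "args": {}, "id": "1"}])],
         "iterations": 0}
    print(should_continue(s))   # "continue_tools"

    s = {"messages": [], "iterations": 3}
    print(should_continue(s))   # "end_conversation_turn": the cap appends a final message and sets done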
@@ -452,7 +459,8 @@
 
     def invoke_turn(self, state: Dict[str, Any]) -> Dict[str, Any]:
         try:
-
+            # Increase the recursion_limit as a temporary workaround if needed.
+            result = self.graph_app.invoke(state, {"recursion_limit": 100})
             result.setdefault("summary", state.get("summary"))
             result.setdefault("interaction_warnings", None)
             return result
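For completeness, a caller-side sketch of the class as it stands after this change (the HumanMessage import and the patient_data payload are illustrative):

    from langchain_core.messages import HumanMessage

    agent = ClinicalAgent()
    final_state = agent.invoke_turn({
        "messages": [HumanMessage(content="Start a consult for a 54-year-old with chest pain.")],
        "patient_data": None,
    })
    print(final_state["messages"][-1].content)

Raising recursion_limit only hides a loop that fails to converge; the practical guard is the iteration cap added to should_continue above.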