# SynapseAI / agent.py
# Author: mgbam — commit 5fb0df4 ("Update agent.py", verified); original size 9.63 kB.
import json
import logging
import os
import re
import traceback
from functools import lru_cache
from typing import Any, Dict, List, Optional, TypedDict

import requests
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import tool
from langchain_groq import ChatGroq
from langgraph.graph import END, StateGraph
from langgraph.prebuilt import ToolExecutor, ToolInvocation
# ── Logging Configuration ──────────────────────────────────────────────
# Configure the root logger at INFO, then grab this module's named logger;
# the two calls are independent of each other.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ── Environment Variables ──────────────────────────────────────────────
UMLS_API_KEY = os.getenv("UMLS_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
# NOTE(review): UMLS_API_KEY is required here but not read anywhere else in
# this file — presumably reserved for a UMLS tool still to be added.

# Fail fast at import time. Naming the absent (or empty) keys saves a round
# of guessing when the deployment is misconfigured.
_missing_keys = [
    key_name
    for key_name, key_value in (
        ("UMLS_API_KEY", UMLS_API_KEY),
        ("GROQ_API_KEY", GROQ_API_KEY),
        ("TAVILY_API_KEY", TAVILY_API_KEY),
    )
    if not key_value
]
if _missing_keys:
    logger.error("Missing required API keys: %s", ", ".join(_missing_keys))
    raise RuntimeError(f"Missing API keys: {', '.join(_missing_keys)}")
# ── Agent Configuration ──────────────────────────────────────────────
class ClinicalPrompts:
    """Namespace holding the prompt text sent to the clinical LLM."""

    # NOTE(review): the body below is a placeholder. agent_node treats a reply
    # containing "consultation complete" as the end of the consultation, so the
    # real prompt presumably needs to tell the model to emit that phrase.
    SYSTEM_PROMPT = (
        "\n"
        "You are SynapseAI, an expert AI clinical assistant engaged in an interactive consultation...\n"
        "[SYSTEM PROMPT CONTENT HERE]\n"
    )
# Number of agent_node turns allowed; the turn that reaches this count closes
# the consultation instead of calling the LLM.
MAX_ITERATIONS = 4
# Groq model id shared by the agent and the safety-reflection call. It can be
# overridden through the AGENT_MODEL_NAME environment variable because hosted
# model ids get retired over time; an unset or empty variable keeps the default.
AGENT_MODEL_NAME = os.getenv("AGENT_MODEL_NAME") or "llama3-70b-8192"
# Low sampling temperature for consistent agent replies.
AGENT_TEMPERATURE = 0.1
# ── State Definition ─────────────────────────────────────────────────
# Shared state passed between graph nodes. `messages` carries no reducer
# annotation, so under LangGraph's default (last-value) channel semantics the
# list a node returns replaces the stored history rather than extending it.
AgentState = TypedDict(
    "AgentState",
    {
        "messages": List[Any],  # conversation history (LangChain message objects)
        "patient_data": Optional[Dict[str, Any]],
        "summary": Optional[str],  # set by the safety-reflection node
        "interaction_warnings": Optional[List[str]],  # parsed from tool outputs
        "done": bool,  # True once the consultation should stop
        "iterations": int,  # agent turns taken so far
    },
)
def propagate_state(new: Dict[str, Any], old: Dict[str, Any]) -> Dict[str, Any]:
    """Return a fresh dict containing *old* overlaid with *new*.

    Keys present in both take the value from *new*. Neither argument is
    modified.
    """
    merged = dict(old)
    merged.update(new)
    return merged
# ── Core Agent Node ──────────────────────────────────────────────────
def agent_node(state: AgentState) -> Dict[str, Any]:
    """Run one clinical-agent turn and return the updated state.

    Increments ``iterations``. When the count reaches ``MAX_ITERATIONS`` the
    consultation is closed with a fixed message instead of calling the LLM.
    Otherwise the conversation is sent to the Groq chat model, prefixed with
    the system prompt unless it already starts with a ``SystemMessage``. The
    executor's tools are bound to the model so it can request tool calls.

    ``AgentState.messages`` has no reducer, so the returned ``messages`` is
    the full history plus the new reply, not the reply alone.

    Args:
        state: Current graph state. It is copied and never mutated.

    Returns:
        The merged state. ``done`` is True when the iteration limit is
        reached, the reply contains "consultation complete", or the LLM call
        failed.
    """
    state = dict(state)  # shallow copy: never mutate the caller's mapping

    if state.get("done", False):
        return state

    iterations = state.get("iterations", 0) + 1
    state["iterations"] = iterations
    history = list(state.get("messages") or [])

    if iterations >= MAX_ITERATIONS:
        # The closing message and done=True must override the copied state,
        # so they go on the `new` side of propagate_state.
        return propagate_state({
            "messages": history + [
                AIMessage(content="Consultation concluded. Maximum iterations reached.")
            ],
            "done": True,
        }, state)

    prompt = history
    if not prompt or not isinstance(prompt[0], SystemMessage):
        prompt = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + history

    try:
        llm = ChatGroq(temperature=AGENT_TEMPERATURE, model=AGENT_MODEL_NAME)
        # If no tools are bound, the model can never emit `tool_calls` and the
        # graph's agent -> tools edge would be unreachable.
        bound_tools = getattr(tool_executor, "tools", None)
        if bound_tools:
            llm = llm.bind_tools(bound_tools)
        llm_response = llm.invoke(prompt)

        reply_text = llm_response.content
        if not isinstance(reply_text, str):  # content may be a list of parts
            reply_text = str(reply_text)
        return propagate_state({
            "messages": history + [llm_response],
            "done": "consultation complete" in reply_text.lower(),
        }, state)
    except Exception as e:
        logger.exception("Agent error: %s", e)
        return propagate_state({
            "messages": history + [AIMessage(content=f"System Error: {str(e)}")],
            "done": True,
        }, state)
# ── Tool Handling Nodes ──────────────────────────────────────────────
# Tools the agent may call; append further tool instances to this list.
_AVAILABLE_TOOLS = [
    TavilySearchResults(max_results=3),
]
tool_executor = ToolExecutor(_AVAILABLE_TOOLS)
def tool_node(state: AgentState) -> Dict[str, Any]:
    """Execute the tool calls requested by the latest agent message.

    Each call is run through ``tool_executor``. Its result, or the error it
    raised, is recorded as a JSON-encoded ``ToolMessage``. The returned state
    holds the full history plus those tool messages, together with any
    interaction warnings parsed from them.

    Returns *state* unchanged when the history is empty or the last message
    is not an ``AIMessage`` carrying tool calls.
    """
    state = dict(state)
    history = list(state.get("messages") or [])
    last_message = history[-1] if history else None
    if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
        return state

    outputs = []
    for tool_call in last_message.tool_calls:
        try:
            # ToolExecutor expects a ToolInvocation (.tool / .tool_input), not
            # the raw LangChain tool-call dict ({"name", "args", "id"}).
            output = tool_executor.invoke(
                ToolInvocation(tool=tool_call["name"], tool_input=tool_call["args"])
            )
            # default=str stops a non-JSON result from being reported as a tool error.
            content = json.dumps(output, default=str)
        except Exception as e:
            logger.exception("Tool error: %s", e)
            content = json.dumps({"error": str(e)})
        outputs.append(
            ToolMessage(
                content=content,
                tool_call_id=tool_call["id"],
                name=tool_call["name"],
            )
        )

    return propagate_state({
        "messages": history + outputs,
        "interaction_warnings": detect_interaction_warnings(outputs),
    }, state)
def detect_interaction_warnings(tool_messages: List["ToolMessage"]) -> List[str]:
    """Collect interaction warnings reported in tool outputs.

    A tool message contributes warnings only when its content decodes to a
    JSON object with ``"status": "warning"``. That object's ``"warnings"``
    entry may be a list (each item is stringified) or a single string.

    Content that is not JSON or not text is skipped. So is JSON that is not
    an object, such as the list of results a search tool returns.

    Args:
        tool_messages: Objects exposing a ``content`` attribute, normally
            ``ToolMessage`` instances.

    Returns:
        The warnings in message order; empty when none were reported.
    """
    warnings: List[str] = []
    for msg in tool_messages:
        try:
            payload = json.loads(msg.content)
        except (json.JSONDecodeError, TypeError):
            continue
        if not isinstance(payload, dict) or payload.get("status") != "warning":
            continue
        reported = payload.get("warnings", [])
        if isinstance(reported, str):
            # A bare string is a single warning; extend() would split it into characters.
            warnings.append(reported)
        elif isinstance(reported, (list, tuple)):
            warnings.extend(str(item) for item in reported)
    return warnings
# ── Safety Reflection Node ───────────────────────────────────────────
def reflection_node(state: AgentState) -> Dict[str, Any]:
    """Ask the LLM for safety recommendations on pending interaction warnings.

    Returns *state* unchanged when ``interaction_warnings`` is empty or
    missing. Otherwise it appends the model's reply, or an error notice if
    the call fails, to the full message history. It also sets ``summary``.
    """
    warnings = state.get("interaction_warnings") or []
    if not warnings:
        return state

    # str() each entry so a non-string warning cannot break the join.
    warning_text = "\n".join(str(w) for w in warnings)
    prompt = (
        "Analyze these clinical warnings:\n"
        f"{warning_text}\n"
        "Provide concise safety recommendations:"
    )
    history = list(state.get("messages") or [])

    try:
        reflection = ChatGroq(
            temperature=0.0,  # Strict safety mode
            model=AGENT_MODEL_NAME,
        ).invoke([HumanMessage(content=prompt)])
        return propagate_state({
            "messages": history + [reflection],
            "summary": f"Safety Review:\n{reflection.content}",
        }, state)
    except Exception as e:
        logger.exception("Reflection error: %s", e)
        return propagate_state({
            "messages": history + [AIMessage(content=f"Safety review unavailable: {str(e)}")],
            "summary": "Failed safety review",
        }, state)
# ── State Routing Logic ──────────────────────────────────────────────
def route_state(state: AgentState) -> str:
    """Choose the name of the next workflow node for *state*.

    Precedence:
    1. A finished consultation goes to "end".
    2. Pending interaction warnings go to "reflection".
    3. An agent message that requests tools goes to "tools".
    4. Anything else goes back to "agent".

    NOTE(review): ClinicalAgent's edges in this file do not reference this
    router. reflection_node leaves ``interaction_warnings`` set, so routing
    with this function after reflection would select "reflection" again.
    """
    if state.get("done", False):
        return "end"
    if state.get("interaction_warnings"):
        return "reflection"

    history = state.get("messages", [])
    latest = history[-1] if history else None
    requests_tools = isinstance(latest, AIMessage) and bool(latest.tool_calls)
    return "tools" if requests_tools else "agent"
# ── Workflow Construction ────────────────────────────────────────────
class ClinicalAgent:
    """Compiled LangGraph workflow linking the agent, tool and reflection nodes.

    Flow: agent -> END, or agent -> tools -> [reflection ->] agent -> ...
    """

    # Worst-case step budget. One agent turn can take three graph steps
    # (agent, tools, reflection), and the loop is capped by MAX_ITERATIONS.
    # The old MAX_ITERATIONS + 2 could abort a legitimate run early.
    RECURSION_LIMIT = 3 * MAX_ITERATIONS + 2

    def __init__(self):
        self.workflow = StateGraph(AgentState)

        # Define nodes
        self.workflow.add_node("agent", agent_node)
        self.workflow.add_node("tools", tool_node)
        self.workflow.add_node("reflection", reflection_node)

        # Configure edges
        self.workflow.set_entry_point("agent")
        self.workflow.add_conditional_edges(
            "agent",
            self._route_after_agent,
            {"tools": "tools", "end": END},
        )
        self.workflow.add_conditional_edges(
            "tools",
            self._route_after_tools,
            {"reflection": "reflection", "agent": "agent"},
        )
        self.workflow.add_edge("reflection", "agent")
        self.app = self.workflow.compile()

    @staticmethod
    def _route_after_agent(state) -> str:
        """Go to "tools" when the last message requests tool calls, otherwise "end".

        A finished state always ends. An empty history, or a last message
        with no ``tool_calls`` attribute (any non-AI message), also ends
        rather than raising.
        """
        if state.get("done", False):
            return "end"
        messages = state.get("messages") or []
        last = messages[-1] if messages else None
        return "tools" if getattr(last, "tool_calls", None) else "end"

    @staticmethod
    def _route_after_tools(state) -> str:
        """Go to "reflection" when tool outputs raised warnings, otherwise back to "agent"."""
        return "reflection" if state.get("interaction_warnings") else "agent"

    def consult(self, initial_state: Dict) -> Dict:
        """Run the full consultation workflow from *initial_state*.

        Returns:
            The final graph state. On failure it returns a dict with
            ``error``, ``trace`` (formatted traceback) and ``done=True``
            instead of raising.
        """
        try:
            return self.app.invoke(
                initial_state,
                {"recursion_limit": self.RECURSION_LIMIT},
            )
        except Exception as e:
            logger.error(f"Consultation failed: {str(e)}")
            return {
                "error": str(e),
                "trace": traceback.format_exc(),
                "done": True,
            }
# ── Example Usage ────────────────────────────────────────────────────
if __name__ == "__main__":

    def _to_jsonable(obj: Any) -> Any:
        """json.dumps fallback: render message objects as {"type", "content"} dicts."""
        if hasattr(obj, "content"):
            return {"type": getattr(obj, "type", type(obj).__name__), "content": obj.content}
        return str(obj)

    agent = ClinicalAgent()
    initial_state = {
        "messages": [HumanMessage(content="Patient presents with chest pain")],
        "patient_data": {
            "age": 45,
            "vitals": {"bp": "150/95", "hr": 110},
        },
        "done": False,
        "iterations": 0,
    }
    result = agent.consult(initial_state)
    # The final state holds LangChain message objects, which json cannot
    # encode natively. Without the fallback this print raises TypeError.
    print("Final State:", json.dumps(result, indent=2, default=_to_jsonable))