|
import os |
|
import re |
|
import json |
|
import logging |
|
import traceback |
|
from functools import lru_cache |
|
from typing import List, Dict, Any, Optional, TypedDict |
|
|
|
import requests |
|
from langchain_groq import ChatGroq |
|
from langchain_community.tools.tavily_search import TavilySearchResults |
|
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage |
|
from langchain_core.pydantic_v1 import BaseModel, Field |
|
from langchain_core.tools import tool |
|
from langgraph.prebuilt import ToolExecutor |
|
from langgraph.graph import StateGraph, END |
|
|
|
|
|
logger = logging.getLogger(__name__)

logging.basicConfig(level=logging.INFO)


# API credentials are read once at import time.
UMLS_API_KEY = os.getenv("UMLS_API_KEY")

GROQ_API_KEY = os.getenv("GROQ_API_KEY")

TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")


# Fail fast at import time when credentials are absent, naming the missing
# keys so the operator knows exactly which variables to set (the original
# message was generic and required guesswork).
_missing_keys = [
    name
    for name, value in {
        "UMLS_API_KEY": UMLS_API_KEY,
        "GROQ_API_KEY": GROQ_API_KEY,
        "TAVILY_API_KEY": TAVILY_API_KEY,
    }.items()
    if not value
]

if _missing_keys:
    logger.error("Missing required API keys: %s", ", ".join(_missing_keys))
    raise RuntimeError(f"Missing API keys: {', '.join(_missing_keys)}")
|
|
|
|
|
class ClinicalPrompts:
    """Container for the prompt templates that steer the consultation LLM."""

    # System prompt injected as the first message of every consultation
    # (see agent_node); the real prompt body is elided here.
    SYSTEM_PROMPT = """
    You are SynapseAI, an expert AI clinical assistant engaged in an interactive consultation...
    [SYSTEM PROMPT CONTENT HERE]
    """


# Hard cap on agent turns per consultation (enforced in agent_node).
MAX_ITERATIONS = 4

# Groq-hosted model used for both the agent and the safety reflection.
AGENT_MODEL_NAME = "llama3-70b-8192"

# Low temperature keeps clinical output conservative and reproducible.
AGENT_TEMPERATURE = 0.1
|
|
|
|
|
class AgentState(TypedDict):
    """Shared LangGraph state threaded through every workflow node."""

    # Conversation history. NOTE(review): no Annotated reducer is declared,
    # so a node's returned "messages" value REPLACES this list wholesale
    # rather than accumulating — confirm that is intended.
    messages: List[Any]
    # Structured patient record supplied by the caller (schema not enforced here).
    patient_data: Optional[Dict[str, Any]]
    # Safety-review summary, populated by reflection_node.
    summary: Optional[str]
    # Warnings extracted from tool outputs by detect_interaction_warnings.
    interaction_warnings: Optional[List[str]]
    # True once the consultation should stop.
    done: bool
    # Number of agent turns taken so far (incremented in agent_node).
    iterations: int
|
|
|
def propagate_state(new: Dict[str, Any], old: Dict[str, Any]) -> Dict[str, Any]:
    """Merge *new* state changes onto *old*; keys in *new* win on conflict."""
    merged = dict(old)
    merged.update(new)
    return merged
|
|
|
|
|
def agent_node(state: AgentState) -> Dict[str, Any]:
    """Run one LLM turn of the consultation.

    Increments the iteration counter, ensures the system prompt leads the
    conversation, invokes the Groq chat model, and marks the consultation
    done when the model signals completion or the iteration budget is spent.

    Args:
        state: Current workflow state (messages, flags, counters).

    Returns:
        Merged state with the model's reply appended to the message history
        and updated ``done`` / ``iterations`` values.
    """
    state = dict(state)  # shallow copy so the caller's mapping is untouched

    # A finished consultation passes through unchanged.
    if state.get("done", False):
        return state

    iterations = state.get("iterations", 0) + 1
    state["iterations"] = iterations

    messages = state.get("messages", [])
    # Prepend the system prompt exactly once.
    if not messages or not isinstance(messages[0], SystemMessage):
        messages = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + messages

    # Enforce the iteration budget. BUG FIX: the original spread ``**state``
    # AFTER the new keys, so the stale ``done=False`` and old messages
    # overwrote the termination signal; propagate_state applies new-over-old.
    if iterations >= MAX_ITERATIONS:
        closing = AIMessage(content="Consultation concluded. Maximum iterations reached.")
        return propagate_state({
            "messages": messages + [closing],
            "done": True,
        }, state)

    try:
        llm_response = ChatGroq(
            temperature=AGENT_TEMPERATURE,
            model=AGENT_MODEL_NAME,
        ).invoke(messages)

        # AgentState declares no reducer for "messages", so a node's return
        # value replaces the list; append to the history instead of
        # discarding the prior conversation.
        return propagate_state({
            "messages": messages + [llm_response],
            "done": "consultation complete" in llm_response.content.lower(),
        }, state)

    except Exception as e:
        logger.error("Agent error: %s", e)
        return propagate_state({
            "messages": messages + [AIMessage(content=f"System Error: {str(e)}")],
            "done": True,
        }, state)
|
|
|
|
|
# Shared executor for every tool the agent may call; currently only web
# search via Tavily, capped at 3 results per query.
# NOTE(review): ToolExecutor has been removed from recent langgraph releases
# (superseded by ToolNode) — confirm the pinned langgraph version still
# ships it.
tool_executor = ToolExecutor([
    TavilySearchResults(max_results=3),

])
|
|
|
def tool_node(state: AgentState) -> Dict[str, Any]:
    """Execute every tool call attached to the last agent message.

    Each call is dispatched through ``tool_executor``; failures are captured
    as JSON error payloads so one bad tool cannot abort the consultation.

    Args:
        state: Current workflow state.

    Returns:
        Merged state with one ToolMessage per tool call appended to the
        message history, plus any interaction warnings extracted from the
        tool outputs. Returns the state unchanged when there is nothing
        to execute.
    """
    state = dict(state)
    # ROBUSTNESS: the original did state["messages"][-1], which raises
    # KeyError/IndexError on a missing or empty history.
    messages = state.get("messages", [])
    if not messages:
        return state

    last_message = messages[-1]
    if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
        return state

    outputs = []
    for tool_call in last_message.tool_calls:
        try:
            result = tool_executor.invoke(tool_call)
            payload = json.dumps(result)
        except Exception as e:
            logger.error("Tool error: %s", e)
            payload = json.dumps({"error": str(e)})
        outputs.append(
            ToolMessage(
                content=payload,
                tool_call_id=tool_call["id"],
                name=tool_call["name"],
            )
        )

    # Append rather than replace so the conversation history survives;
    # AgentState declares no reducer for "messages".
    return propagate_state({
        "messages": messages + outputs,
        "interaction_warnings": detect_interaction_warnings(outputs),
    }, state)
|
|
|
def detect_interaction_warnings(tool_messages: List[ToolMessage]) -> List[str]:
    """Collect interaction warnings embedded in tool output payloads.

    Each message body is expected to be JSON; payloads whose ``status`` is
    ``"warning"`` contribute their ``warnings`` list. Bodies that are not
    valid JSON are skipped silently.
    """
    collected: List[str] = []
    for message in tool_messages:
        try:
            payload = json.loads(message.content)
        except json.JSONDecodeError:
            continue
        if payload.get("status") == "warning":
            collected.extend(payload.get("warnings", []))
    return collected
|
|
|
|
|
def reflection_node(state: AgentState) -> Dict[str, Any]:
    """Run a zero-temperature safety review over any interaction warnings.

    No-op when the state carries no warnings. On success the review reply
    is appended to the history and a summary is recorded; on failure a
    placeholder message and summary are recorded instead.

    Args:
        state: Current workflow state.

    Returns:
        Merged state with the safety review appended, or the input state
        unchanged when there is nothing to review.
    """
    warnings = state.get("interaction_warnings") or []
    if not warnings:
        return state

    messages = state.get("messages", [])

    prompt = f"""Analyze these clinical warnings:
{chr(10).join(warnings)}

Provide concise safety recommendations:"""

    try:
        # Temperature 0 for a deterministic safety assessment.
        reflection = ChatGroq(
            temperature=0.0,
            model=AGENT_MODEL_NAME
        ).invoke([HumanMessage(content=prompt)])

        # BUG FIX: the original returned "messages": [reflection], which
        # replaced the whole conversation history (AgentState has no
        # message reducer); append instead.
        return propagate_state({
            "messages": messages + [reflection],
            "summary": f"Safety Review:\n{reflection.content}"
        }, state)

    except Exception as e:
        logger.error("Reflection error: %s", e)
        return propagate_state({
            "messages": messages + [AIMessage(content=f"Safety review unavailable: {str(e)}")],
            "summary": "Failed safety review"
        }, state)
|
|
|
|
|
def route_state(state: AgentState) -> str:
    """Pick the next workflow node for the current state.

    Priority: finished consultations end, outstanding interaction warnings
    go to reflection, a trailing agent message with tool calls goes to
    tools, and everything else loops back to the agent.
    """
    if state.get("done", False):
        return "end"

    if state.get("interaction_warnings"):
        return "reflection"

    history = state.get("messages", [])
    last = history[-1] if history else None
    if isinstance(last, AIMessage) and last.tool_calls:
        return "tools"

    return "agent"
|
|
|
|
|
class ClinicalAgent:
    """Compiled LangGraph workflow driving the clinical consultation.

    Graph shape: agent -> (tools | END); tools -> (reflection | agent);
    reflection -> agent.
    """

    def __init__(self):
        self.workflow = StateGraph(AgentState)

        self.workflow.add_node("agent", agent_node)
        self.workflow.add_node("tools", tool_node)
        self.workflow.add_node("reflection", reflection_node)

        self.workflow.set_entry_point("agent")

        # BUG FIX: the original lambda indexed state.get("messages")[-1]
        # unconditionally (TypeError/IndexError on a missing or empty
        # history) and ignored the "done" flag set by agent_node.
        self.workflow.add_conditional_edges(
            "agent",
            self._route_from_agent,
            {"tools": "tools", "end": END}
        )

        self.workflow.add_conditional_edges(
            "tools",
            lambda state: "reflection" if state.get("interaction_warnings") else "agent",
            {"reflection": "reflection", "agent": "agent"}
        )

        self.workflow.add_edge("reflection", "agent")

        self.app = self.workflow.compile()

    @staticmethod
    def _route_from_agent(state: Dict) -> str:
        """Return "tools" when the last agent message requests tool calls, else "end"."""
        if state.get("done"):
            return "end"
        messages = state.get("messages") or []
        last = messages[-1] if messages else None
        if isinstance(last, AIMessage) and last.tool_calls:
            return "tools"
        return "end"

    def consult(self, initial_state: Dict) -> Dict:
        """Execute the full consultation workflow.

        Args:
            initial_state: Seed AgentState mapping (messages, patient_data,
                done, iterations).

        Returns:
            The final graph state, or an error dict ({"error", "trace",
            "done"}) when the run raises.
        """
        try:
            # Each agent iteration can traverse up to three nodes
            # (agent -> tools -> reflection), so budget accordingly;
            # the original MAX_ITERATIONS + 2 could trip the recursion
            # limit before agent_node's own cap was reached.
            return self.app.invoke(
                initial_state,
                {"recursion_limit": MAX_ITERATIONS * 3 + 2}
            )
        except Exception as e:
            logger.error("Consultation failed: %s", e)
            return {
                "error": str(e),
                "trace": traceback.format_exc(),
                "done": True
            }
|
|
|
|
|
if __name__ == "__main__":
    agent = ClinicalAgent()

    # Minimal demo state: one presenting complaint plus vitals.
    initial_state = {
        "messages": [HumanMessage(content="Patient presents with chest pain")],
        "patient_data": {
            "age": 45,
            "vitals": {"bp": "150/95", "hr": 110}
        },
        "done": False,
        "iterations": 0
    }

    result = agent.consult(initial_state)
    # BUG FIX: the final state contains LangChain message objects, which are
    # not JSON-serializable — without default=str, json.dumps raises
    # TypeError and the consultation result is never printed.
    print("Final State:", json.dumps(result, indent=2, default=str))