# /home/user/app/agent.py
import os  # Used below to check for GOOGLE_API_KEY in the environment

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.agents import AgentExecutor, create_structured_chat_agent  # Using a more general agent
# If you want to try Gemini's native function calling (experimental; may require specific model versions):
# from langchain_google_vertexai import ChatVertexAI  # For Vertex AI deployments
# from langchain_google_genai import HarmBlockThreshold, HarmCategory  # For safety settings
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
# from langchain_community.chat_message_histories import ChatMessageHistory # Not used directly here
from tools import (
UMLSLookupTool, BioPortalLookupTool, QuantumTreatmentOptimizerTool
# GeminiTool might be redundant if the main LLM is Gemini, unless it's for a different Gemini model/task.
)
from config.settings import settings
from services.logger import app_logger
# --- Initialize LLM (Gemini) ---
# Ensure GOOGLE_API_KEY is set in the environment, or pass the key explicitly via
# google_api_key=settings.GEMINI_API_KEY (if your settings map GEMINI_API_KEY to the Google key).
try:
    llm = ChatGoogleGenerativeAI(
        model="gemini-pro",  # Or "gemini-1.5-pro-latest" if available and preferred
        temperature=0.3,
        # google_api_key=settings.GEMINI_API_KEY,  # Explicitly pass if needed
        # convert_system_message_to_human=True,  # Sometimes helpful for models not strictly adhering to the system role
        # safety_settings={  # Optional: configure safety settings
        #     HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
        #     HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
        # },
    )
    app_logger.info("ChatGoogleGenerativeAI (Gemini) initialized successfully.")
except Exception as e:
    app_logger.error(f"Failed to initialize ChatGoogleGenerativeAI: {e}", exc_info=True)
    # Raise an error that can be caught by get_agent_executor to inform the user
    raise ValueError(f"Gemini LLM initialization failed: {e}. Check API key and configurations.")
# --- Initialize Tools ---
# Ensure your tools' descriptions are very clear, especially for non-OpenAI function calling agents.
tools = [
UMLSLookupTool(),
BioPortalLookupTool(),
QuantumTreatmentOptimizerTool(),
]
app_logger.info(f"Tools initialized: {[tool.name for tool in tools]}")
# --- Agent Prompt (adapted for a general structured chat agent) ---
# Without native function calling, the prompt itself has to teach the LLM when to use a tool and
# how to format the tool call, typically as a JSON blob (a ReAct-style approach). LangChain Hub
# ships a ready-made prompt for this: hub.pull("hwchase17/structured-chat-agent"); we define a
# custom variant below so we can inject patient context and our domain rules.
#
# Gemini also supports tool calling natively via its API, and recent langchain-google-genai
# releases expose it (e.g. llm.bind_tools(tools)), which would allow a simpler agent structure.
# For now we use `create_structured_chat_agent`, which only requires a model that can follow
# instructions and emit structured output; a native tool-calling sketch follows for reference.
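# Alternative (sketch, not wired in): a native Gemini tool-calling agent. This assumes the
# installed langchain exposes `create_tool_calling_agent` and that this langchain-google-genai
# version supports `bind_tools` on ChatGoogleGenerativeAI (called internally by the constructor).
# Kept as an uncalled helper so importing this module has no extra side effects.
def _build_native_tool_calling_agent_sketch():
    """Illustrative only: build an AgentExecutor that uses Gemini's native tool calling."""
    from langchain.agents import create_tool_calling_agent

    native_prompt = ChatPromptTemplate.from_messages([
        ("system",
         "You are 'Quantum Health Navigator', an AI assistant for healthcare professionals. "
         "Patient context for this session: {patient_context}"),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ])
    native_agent = create_tool_calling_agent(llm=llm, tools=tools, prompt=native_prompt)
    return AgentExecutor(agent=native_agent, tools=tools, verbose=True, handle_parsing_errors=True)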
SYSTEM_PROMPT_TEXT = (
    "You are 'Quantum Health Navigator', a helpful AI assistant for healthcare professionals. "
    "Respond to the human as helpfully and accurately as possible. "
    "You have access to the following tools:\n\n"
    "{tools}\n\n"  # Filled in by the agent constructor with tool names, descriptions, and arguments
    'Valid "action" values: "Final Answer" or one of: {tool_names}\n\n'  # {tool_names} is required by create_structured_chat_agent
    "To use a tool, respond with a JSON blob in the following format:\n\n"
    "```json\n"
    "{{\n"
    '  "action": "tool_name",\n'
    '  "action_input": "input_to_tool"\n'  # For tools with a single string input
    # For tools with structured input (like QuantumTreatmentOptimizerTool), action_input is an object:
    # '  "action_input": {{"arg1": "value1", "arg2": "value2"}}\n'
    "}}\n"
    "```\n\n"
    "If you use a tool, the system will give you the observation from the tool. "
    "You must then respond to the human based on this observation and your knowledge. "
    "If the human asks a question that doesn't require a tool, answer directly. "
    'In both cases, deliver your final response as a JSON blob whose "action" is "Final Answer" '
    'and whose "action_input" is your answer.\n\n'
    "When asked about treatment optimization for a specific patient based on provided context, "
    "you MUST use the 'quantum_treatment_optimizer' tool. "
    "For general medical knowledge, you can answer directly or use UMLS/BioPortal. "
    "Always cite the tool you used if its output is part of your final response. "
    "Do not provide medical advice for specific patient cases without using the 'quantum_treatment_optimizer' tool.\n"
    "Patient context for this session (if provided by the user earlier): {patient_context}\n\n"
    "Begin!\n\n"
    "Previous conversation history:\n"
    "{chat_history}\n"
)
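# For reference, a well-formed tool call from the model should parse into something like the blob
# below. The argument names are illustrative only; the real schema is defined by
# QuantumTreatmentOptimizerTool's args in tools.py.
# {
#   "action": "quantum_treatment_optimizer",
#   "action_input": {"patient_data": "35M, T2 diabetes, hypertension", "current_medications": "Metformin, Lisinopril"}
# }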
prompt = ChatPromptTemplate.from_messages([
    ("system", SYSTEM_PROMPT_TEXT),  # chat_history and patient_context are embedded in the system prompt
    # The human turn carries the new question plus the agent scratchpad (the agent's thoughts and
    # tool observations, rendered as text), mirroring hwchase17/structured-chat-agent.
    ("human", "New human question: {input}\n\n{agent_scratchpad}"),
])
# `create_structured_chat_agent` requires the prompt to expose the variables 'tools', 'tool_names',
# and 'agent_scratchpad'; it fills in 'tools' and 'tool_names' itself from the tool list.
# 'input', 'chat_history', and 'patient_context' are supplied in each call to the executor's invoke().
# --- Create Agent ---
try:
    # `create_structured_chat_agent` is designed for LLMs that can follow complex instructions
    # and output structured data (like a JSON blob for tool calls) when prompted to do so.
    agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt)
    app_logger.info("Structured chat agent created successfully with Gemini.")
except Exception as e:
    app_logger.error(f"Failed to create structured chat agent: {e}", exc_info=True)
    raise ValueError(f"Agent creation failed: {e}")
# --- Create Agent Executor ---
agent_executor = AgentExecutor(
agent=agent,
tools=tools,
verbose=True,
handle_parsing_errors=True, # Important for agents that parse LLM output for tool calls
# Example: "Could not parse LLM output: `...`" - then it can retry or return this error.
# max_iterations=7, # Good to prevent overly long chains
# return_intermediate_steps=True # Useful for debugging to see thought process
)
app_logger.info("AgentExecutor created successfully.")
# --- Getter Function for Streamlit App ---
def get_agent_executor():
    """Returns the configured agent executor for Gemini."""
    # The llm and agent_executor are already initialized.
    # We check for the API key here as a safeguard, though initialization would have failed earlier.
    if not (settings.GEMINI_API_KEY or os.environ.get("GOOGLE_API_KEY")):  # Check both setting and env var
        app_logger.error("GOOGLE_API_KEY (for Gemini) not set. Agent will not function.")
        raise ValueError("Google API Key for Gemini not configured. Agent cannot be initialized.")
    return agent_executor
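# Sketch of how the Streamlit app is expected to call this module (the real app.py may differ):
#
#   executor = get_agent_executor()
#   result = executor.invoke({
#       "input": user_question,
#       "chat_history": st.session_state.get("chat_history", []),
#       "patient_context": patient_context_str,
#   })
#   answer = result.get("output", "")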
# --- Example Usage (for local testing) ---
if __name__ == "__main__":
    if not (settings.GEMINI_API_KEY or os.environ.get("GOOGLE_API_KEY")):
        print("Please set your GOOGLE_API_KEY (for Gemini) in a .env file or the environment.")
    else:
        print("Gemini Agent Test Console (type 'exit' or 'quit' to stop)")
        executor = get_agent_executor()

        # For structured chat agents, chat_history is passed in the invoke call;
        # the agent prompt includes {chat_history}.
        current_chat_history = []  # List of HumanMessage / AIMessage

        # Initial patient context (simulated for testing)
        patient_context_for_test = {
            "age": 35,
            "gender": "Male",
            "key_medical_history": "Type 2 Diabetes, Hypertension",
            "current_medications": "Metformin, Lisinopril",
        }
        context_summary_parts_test = [
            f"{k.replace('_', ' ').title()}: {v}" for k, v in patient_context_for_test.items() if v
        ]
        patient_context_str_test = "; ".join(context_summary_parts_test) if context_summary_parts_test else "None provided."

        while True:
            user_input = input("You: ")
            if user_input.lower() in ["exit", "quit"]:
                break
            try:
                response = executor.invoke({
                    "input": user_input,
                    "chat_history": current_chat_history,
                    "patient_context": patient_context_str_test,  # Passed through to the system prompt
                    # `tools` and `tool_names` are filled in by the agent constructor, not passed here
                })
                ai_output = response.get('output', "No output from agent.")
                print(f"Agent: {ai_output}")
                current_chat_history.append(HumanMessage(content=user_input))
                current_chat_history.append(AIMessage(content=ai_output))
            except Exception as e:
                print(f"Error invoking agent: {e}")
                app_logger.error(f"Error in __main__ agent test: {e}", exc_info=True)