# MedQA / agent.py
# Provenance (Hugging Face file-viewer residue, kept as a comment so the
# module parses): uploaded by "mgbam", commit 29659ac (verified), 6.29 kB.
import os
import sys
from typing import List, Union
from langchain.prompts import ChatPromptTemplate
from langchain.prompts.chat import MessagesPlaceholder
from langchain.schema import BaseMessage, AIMessage, HumanMessage, SystemMessage
from langchain.agents import AgentExecutor, create_structured_chat_agent
from langchain_google_genai import ChatGoogleGenerativeAI
from config.settings import settings
from services.logger import app_logger
from tools import BioPortalLookupTool, UMLSLookupTool, QuantumTreatmentOptimizerTool
# -----------------------------------------------------------------------------
# 1. Initialize the Gemini LLM
# -----------------------------------------------------------------------------
def _init_llm() -> ChatGoogleGenerativeAI:
    """
    Initialize the Google Gemini LLM with the configured API key.

    The key is read from ``settings.GEMINI_API_KEY`` first, falling back to
    the ``GOOGLE_API_KEY`` environment variable.

    Returns:
        A configured ``ChatGoogleGenerativeAI`` client.

    Raises:
        ValueError: If no API key is found, or if client construction fails
            (the original exception is chained as the cause).
    """
    api_key = settings.GEMINI_API_KEY or os.getenv("GOOGLE_API_KEY")
    if not api_key:
        err = "Gemini API key not found: set GEMINI_API_KEY in settings or GOOGLE_API_KEY in env"
        app_logger.error(err)
        raise ValueError(err)
    try:
        llm = ChatGoogleGenerativeAI(
            model="gemini-1.5-pro-latest",
            temperature=0.2,  # low temperature: keep clinical answers conservative
            google_api_key=api_key,
            # Gemini lacks a native system role; fold system messages into the
            # first human turn so the system prompt is not silently dropped.
            convert_system_message_to_human=True,
        )
        app_logger.info(f"Gemini LLM initialized ({llm.model})")
        return llm
    except Exception as e:
        err = f"Failed to initialize Gemini LLM: {e}"
        app_logger.error(err, exc_info=True)
        # Chain the original exception so the real cause survives in tracebacks.
        raise ValueError(err) from e
# -----------------------------------------------------------------------------
# 2. Build the structured chat prompt
# -----------------------------------------------------------------------------
def _build_prompt_template(tool_names: List[str], tools) -> ChatPromptTemplate:
    """
    Build the structured-chat prompt for the agent.

    The prompt consists of:
      - a system block describing the assistant persona, the patient context
        slot, and the JSON tool-call protocol,
      - a ``chat_history`` placeholder (list of BaseMessage),
      - the current ``{input}`` human turn,
      - an ``agent_scratchpad`` placeholder for intermediate tool exchanges.

    Note: ``{tool_names}`` and ``{tools}`` remain template variables here; the
    ``tool_names``/``tools`` arguments are not interpolated directly in this
    function — presumably they are filled in downstream by
    ``create_structured_chat_agent`` (verify against the agent constructor).
    """
    # Assemble the system instruction from ordered segments; the joined result
    # is byte-identical to the original single literal.
    segments = [
        "You are Quantum Health Navigator, an AI assistant for healthcare professionals.\n\n",
        "β€’ Disclaim: you are an AI, not a substitute for clinical judgment.\n",
        "β€’ Patient context: {patient_context}\n",
        "β€’ Available tools: {tool_names}\n",
        "{tools}\n\n",
        "To call a tool, reply *only* with a JSON code block:\n",
        '{{"action": "<tool_name>", "action_input": <input>}}\n\n',
        "After you receive the tool’s output, craft a full answer for the user, citing any tools used.",
    ]
    message_specs = [
        ("system", "".join(segments)),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]
    return ChatPromptTemplate.from_messages(message_specs)
# -----------------------------------------------------------------------------
# 3. Lazily build and return the AgentExecutor singleton
# -----------------------------------------------------------------------------
def get_agent_executor() -> AgentExecutor:
    """
    Return the process-wide AgentExecutor singleton, building it on first call.

    Wires together the Gemini LLM, the lookup/optimizer tools, and the
    structured-chat prompt into an AgentExecutor.

    Returns:
        The cached ``AgentExecutor`` instance.

    Raises:
        ValueError: Propagated from ``_init_llm`` when the LLM cannot be set up.

    NOTE(review): this lazy initialization is not thread-safe; two concurrent
    first calls could each build an executor. Acceptable for a single-threaded
    console/app entry point — add a lock if called from worker threads.
    """
    global _agent_executor_instance
    # Use .get(...) is None rather than `not in globals()`: if the global was
    # ever bound to None (e.g. by a reset), the original check would hand a
    # stale None back to callers instead of rebuilding.
    if globals().get("_agent_executor_instance") is None:
        # 3.1 Initialize LLM (raises ValueError on failure).
        llm = _init_llm()
        # 3.2 Prepare tools.
        tools_list = [
            UMLSLookupTool(),
            BioPortalLookupTool(),
            QuantumTreatmentOptimizerTool(),
        ]
        app_logger.info(f"Loaded tools: {[t.name for t in tools_list]}")
        # 3.3 Build prompt.
        prompt = _build_prompt_template(
            tool_names=[t.name for t in tools_list],
            tools=tools_list,
        )
        app_logger.info("Prompt template built")
        # 3.4 Create the structured agent.
        agent = create_structured_chat_agent(
            llm=llm,
            tools=tools_list,
            prompt=prompt,
        )
        app_logger.info("Structured chat agent created")
        # 3.5 Create the executor.
        executor = AgentExecutor(
            agent=agent,
            tools=tools_list,
            verbose=True,
            handle_parsing_errors=True,  # re-prompt on malformed tool JSON instead of crashing
            max_iterations=10,           # hard cap on tool-call loops
            # NOTE(review): "generate" is deprecated for some agent types in
            # newer langchain releases — confirm against the installed version.
            early_stopping_method="generate",
        )
        app_logger.info("AgentExecutor initialized")
        _agent_executor_instance = executor
    return _agent_executor_instance
# -----------------------------------------------------------------------------
# 4. Optional REPL for local testing
# -----------------------------------------------------------------------------
if __name__ == "__main__":
try:
executor = get_agent_executor()
except Exception as e:
print(f"❌ Initialization failed: {e}")
sys.exit(1)
# Sample patient context for testing
patient_context = (
"Age: 58; Gender: Female; Chief Complaint: Blurry vision & fatigue; "
"History: Prediabetes, mild dyslipidemia; Medications: None."
)
chat_history: List[Union[SystemMessage, HumanMessage, AIMessage]] = []
print("πŸš€ Quantum Health Navigator Console (type 'exit' to quit)")
while True:
user_input = input("πŸ‘€ You: ").strip()
if user_input.lower() in {"exit", "quit"}:
print("πŸ‘‹ Goodbye!")
break
if not user_input:
continue
try:
result = executor.invoke({
"input": user_input,
"chat_history": chat_history,
"patient_context": patient_context
})
reply = result.get("output", "")
print(f"πŸ€– Agent: {reply}\n")
# Update history
chat_history.append(HumanMessage(content=user_input))
chat_history.append(AIMessage(content=reply))
# Trim to last 20 messages
if len(chat_history) > 20:
chat_history = chat_history[-20:]
except Exception as err:
print(f"⚠️ Inference error: {err}")
app_logger.error("Runtime error in REPL", exc_info=True)