# gaia-agent / agent.py
import os
from typing import TypedDict, Annotated, Literal

from langchain_openai import ChatOpenAI
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, ToolMessage
from langchain_core.runnables import RunnableConfig  # for LangSmith tracing
from langchain_community.tools import BraveSearch  # web search
from langchain_experimental.tools.python.tool import PythonAstREPLTool  # for logic/math problems
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages

from tools import (calculator_basic, datetime_tools, transcribe_audio, transcribe_youtube, query_image, webpage_content, read_excel)
from prompt import system_prompt

# LangSmith to observe the agent (LangChain picks these up from the environment;
# they are read here only to confirm they are set)
langsmith_api_key = os.getenv("LANGSMITH_API_KEY")
langsmith_tracing = os.getenv("LANGSMITH_TRACING")

# gpt-4o-mini: cheaper for debugging, temperature 0 for less randomness
# o4-mini: better reasoning for benchmarking, temperature 1 (the default)
llm = ChatOpenAI(
    model="o4-mini",
    api_key=os.getenv("OPENAI_API_KEY"),
    temperature=1,
)
python_tool = PythonAstREPLTool()

search_tool = BraveSearch.from_api_key(
    api_key=os.getenv("BRAVE_SEARCH_API"),
    search_kwargs={"count": 4},  # return the 4 best results and their URLs
    description="Web search using Brave",
)
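
# Quick sanity check (hypothetical query; assumes BRAVE_SEARCH_API is set;
# BraveSearch returns its results as a JSON string):
# print(search_tool.invoke("LangGraph ReAct agent"))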
community_tools = [search_tool, python_tool]
custom_tools = calculator_basic + datetime_tools + [transcribe_audio, transcribe_youtube, query_image, webpage_content, read_excel]
tools = community_tools + custom_tools
llm_with_tools = llm.bind_tools(tools)

# Index the tools by name so tool_node can dispatch each tool call
tools_by_name = {tool.name: tool for tool in tools}


class MessagesState(TypedDict):  # the state: the agent's memory at any moment
    messages: Annotated[list[AnyMessage], add_messages]


# LLM node
def llm_call(state: MessagesState):
    """Calls the LLM (with tools bound) on the system prompt plus the conversation so far"""
    return {
        "messages": [
            llm_with_tools.invoke(
                [SystemMessage(content=system_prompt)] + state["messages"]
            )
        ]
    }


# Tool node
def tool_node(state: MessagesState):
    """Executes the tools requested by the last LLM message"""
    result = []
    for tool_call in state["messages"][-1].tool_calls:  # the list of tools the LLM decided to call
        tool = tools_by_name[tool_call["name"]]  # look up the actual tool function in the dictionary
        observation = tool.invoke(tool_call["args"])  # execute the tool
        result.append(ToolMessage(content=str(observation), tool_call_id=tool_call["id"]))  # str() because some tools return non-string objects
    return {"messages": result}  # thanks to add_messages, LangGraph automatically appends the results to the agent's message history


# Conditional edge: route to the tool node or end, depending on whether the LLM made a tool call
def should_continue(state: MessagesState) -> Literal["Action", END]:
    """Decide whether to continue the loop or stop"""
    last_message = state["messages"][-1]  # the last message (usually from the LLM)
    # If the LLM made a tool call, perform an action
    if last_message.tool_calls:
        return "Action"
    # Otherwise, stop (reply to the user)
    return END


# Build the workflow
builder = StateGraph(MessagesState)

# Add nodes
builder.add_node("llm_call", llm_call)
builder.add_node("environment", tool_node)

# Add edges to connect the nodes
builder.add_edge(START, "llm_call")
builder.add_conditional_edges(
    "llm_call",
    should_continue,
    {
        "Action": "environment",  # name returned by should_continue : name of the next node
        END: END,
    },
)
# If tool calls -> "Action" -> environment (executes the tools)
# If no tool calls -> END
builder.add_edge("environment", "llm_call")  # after running the tools, go back to the LLM for another round of reasoning

gaia_agent = builder.compile()  # turns the builder into a runnable agent, used via gaia_agent.invoke()
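
# Optional sanity check: inspect the compiled graph's topology as Mermaid source
# (get_graph().draw_mermaid() is available on compiled LangGraph graphs):
# print(gaia_agent.get_graph().draw_mermaid())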


# Wrapper class to initialize and call the LangGraph agent with a user question
class LangGraphAgent:
    def __init__(self):
        print("LangGraphAgent initialized.")

    def __call__(self, question: str) -> str:
        input_state = {"messages": [HumanMessage(content=question)]}  # prepare the initial user message
        print(f"Running LangGraphAgent with input: {question[:150]}...")

        # Tracing configuration for LangSmith; RunnableConfig is a TypedDict, so its
        # fields are passed directly. recursion_limit prevents infinite looping when
        # the LLM keeps calling tools over and over.
        config = RunnableConfig(
            run_name="GAIA Agent",
            tags=["gaia", "langgraph", "agent"],
            metadata={"user_input": question},
            recursion_limit=30,
        )

        result = gaia_agent.invoke(input_state, config)
        final_response = result["messages"][-1].content

        # Parse out only what's after "FINAL ANSWER:"; str.split never raises here,
        # so check for the marker explicitly rather than wrapping it in try/except
        if "FINAL ANSWER:" in final_response:
            return final_response.split("FINAL ANSWER:")[-1].strip()
        print("Could not find 'FINAL ANSWER:' in the response")
        return final_response
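

# Minimal local smoke test (hypothetical question; assumes OPENAI_API_KEY and the
# other tool API keys are set in the environment)
if __name__ == "__main__":
    agent = LangGraphAgent()
    print(agent("What is 17 * 23?"))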