# agent.py — LangGraph-based GAIA agent: an LLM node and a tool-execution node
# wired into a reasoning loop, with an OpenRouter-backed ChatOpenAI model.
import os
from langchain_openai import ChatOpenAI
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, ToolMessage
from langgraph.graph.message import add_messages
from langgraph.graph import MessagesState
from langgraph.graph import StateGraph, START, END
from typing import TypedDict, Annotated, Literal
from langchain_community.tools import BraveSearch # web search
from langchain_experimental.tools.python.tool import PythonAstREPLTool # for logic/math problems
from tools import (calculator_basic, datetime_tools, transcribe_audio, transcribe_youtube, query_image, webpage_content, read_excel)
from prompt import system_prompt
from langchain_core.runnables import RunnableConfig # for LangSmith tracking
# LangSmith credentials to observe the agent.
# NOTE(review): these two names are assigned here but never referenced in this
# file — presumably the LangSmith SDK picks the values up from the environment
# itself; confirm before removing.
langsmith_api_key = os.getenv("LANGSMITH_API_KEY")
langsmith_tracing = os.getenv("LANGSMITH_TRACING")
def get_llm():
    """Return a ChatOpenAI client for OpenRouter, trying each configured key in turn.

    Keys are read from OPENROUTER_API_KEY and OPENROUTER_API_KEY_1; unset keys
    are skipped. Raises RuntimeError if no key yields a usable client.
    """
    candidate_keys = (os.getenv("OPENROUTER_API_KEY"), os.getenv("OPENROUTER_API_KEY_1"))
    failure = None
    for candidate in candidate_keys:
        if not candidate:
            # Unset/empty env var — nothing to try with this slot.
            continue
        try:
            client = ChatOpenAI(
                base_url="https://openrouter.ai/api/v1",
                api_key=candidate,
                model="qwen/qwen3-coder:free",
                temperature=1,
            )
            # Optionally, test the key with a trivial call to ensure it's valid
            # client.invoke([SystemMessage(content="ping")])
            return client
        except Exception as err:
            # Remember the failure and fall through to the next key.
            failure = err
    raise RuntimeError(f"All OpenRouter API keys failed: {failure}")
# Module-level LLM client (first working OpenRouter key).
llm = get_llm()

# Sandboxed Python REPL for logic/math problems the LLM wants to compute.
python_tool = PythonAstREPLTool()

# Brave web search; key comes from the BRAVE_SEARCH_API env var.
search_tool = BraveSearch.from_api_key(
    api_key=os.getenv("BRAVE_SEARCH_API"),
    search_kwargs={"count": 4},  # returns the 4 best results and their URL
    description="Web search using Brave"
)

# Community tools plus the project's custom tools (audio/video transcription,
# image querying, web page reading, Excel parsing, calculator, date/time).
community_tools = [search_tool, python_tool]
custom_tools = calculator_basic + datetime_tools + [transcribe_audio, transcribe_youtube, query_image, webpage_content, read_excel]
tools = community_tools + custom_tools

# LLM instance that can emit tool calls for any of the tools above.
llm_with_tools = llm.bind_tools(tools)

# Prepare tools by name so tool_node can look up a tool from a tool_call dict.
tools_by_name = {tool.name: tool for tool in tools}
class MessagesState(TypedDict):  # creates the state (is like the agent's memory at any moment)
    # NOTE(review): this redefinition shadows the MessagesState imported from
    # langgraph.graph above; both expose only a `messages` channel, so behavior
    # should be identical, but consider dropping one to avoid confusion.
    # The add_messages reducer appends new messages to the history instead of
    # overwriting it.
    messages: Annotated[list[AnyMessage], add_messages]
# LLM node
def llm_call(state: MessagesState):
    """Run the tool-bound LLM on the system prompt plus the conversation so far."""
    prompt = [SystemMessage(content=system_prompt)] + state["messages"]
    response = llm_with_tools.invoke(prompt)
    # add_messages appends this to the agent's history.
    return {"messages": [response]}
# Tool node
def tool_node(state: MessagesState):
    """Executes the tools"""
    # The last message holds the tool calls the LLM decided to make; run each
    # one and wrap its output in a ToolMessage tied back to the originating
    # call id. add_messages appends these results to the agent's history.
    latest = state["messages"][-1]
    tool_messages = [
        ToolMessage(
            content=tools_by_name[call["name"]].invoke(call["args"]),
            tool_call_id=call["id"],
        )
        for call in latest.tool_calls
    ]
    return {"messages": tool_messages}
# Conditional edge function to route to the tool node or end based upon whether the LLM made a tool call
def should_continue(state: MessagesState) -> Literal["Action", END]:
    """Route to "Action" when the last LLM message requested tools, else END."""
    latest = state["messages"][-1]
    # Pending tool calls mean another tool-execution round; otherwise reply.
    return "Action" if latest.tool_calls else END
# Build workflow: llm_call <-> environment loop until the LLM stops calling tools.
builder = StateGraph(MessagesState)

# Add nodes
builder.add_node("llm_call", llm_call)
builder.add_node("environment", tool_node)

# Add edges to connect nodes
builder.add_edge(START, "llm_call")
builder.add_conditional_edges(
    "llm_call",
    should_continue,
    {"Action": "environment",  # name returned by should_continue : Name of the next node
     END: END}
)
# If tool calls -> "Action" -> environment (executes the tool)
# If no tool calls -> END
builder.add_edge("environment", "llm_call")  # after running the tools go back to the LLM for another round of reasoning

gaia_agent = builder.compile()  # converts my builder into a runnable agent by using gaia_agent.invoke()
# Wrapper class to initialize and call the LangGraph agent with a user question
class LangGraphAgent:
    """Callable wrapper that builds and runs the LangGraph agent per question.

    Alternates between two OpenRouter API keys (OPENROUTER_API_KEY on even
    calls, OPENROUTER_API_KEY_1 on odd calls) to spread free-tier rate limits.
    """

    def __init__(self):
        print("LangGraphAgent initialized.")
        self.question_count = 0  # number of questions processed so far

    def __call__(self, question: str) -> str:
        """Answer one question, returning the text after 'FINAL ANSWER:' if present.

        Builds a fresh graph bound to this call's API key, invokes it on the
        question, and extracts the final answer from the last message.
        """
        # Alternate keys on successive calls: even counts use OPENROUTER_API_KEY,
        # odd counts use OPENROUTER_API_KEY_1.
        key_index = self.question_count % 2  # 0 -> primary key, 1 -> secondary key
        api_key = os.getenv("OPENROUTER_API_KEY") if key_index == 0 else os.getenv("OPENROUTER_API_KEY_1")

        # Fresh LLM instance bound to the selected key for this question only.
        current_llm = ChatOpenAI(
            base_url="https://openrouter.ai/api/v1",
            api_key=api_key,
            model="qwen/qwen3-coder:free",
            temperature=1,
        )
        current_llm_with_tools = current_llm.bind_tools(tools)

        self.question_count += 1
        # BUG FIX: the key number was previously derived from the *incremented*
        # counter, so the log always reported the other key. Use key_index,
        # captured before the increment.
        print(f"Running LangGraphAgent with input: {question[:150]}... (Using API key {key_index + 1})")

        # Per-question LLM node: identical to the module-level llm_call, but
        # bound to this call's API key.
        def custom_llm_call(state: MessagesState):
            return {
                "messages": [
                    current_llm_with_tools.invoke(
                        [SystemMessage(content=system_prompt)] + state["messages"]
                    )
                ]
            }

        # Rebuild the workflow around the per-question LLM node; the tool node
        # and routing logic are shared with the module-level graph.
        custom_builder = StateGraph(MessagesState)
        custom_builder.add_node("llm_call", custom_llm_call)
        custom_builder.add_node("environment", tool_node)
        custom_builder.add_edge(START, "llm_call")
        custom_builder.add_conditional_edges(
            "llm_call",
            should_continue,
            {"Action": "environment", END: END},
        )
        custom_builder.add_edge("environment", "llm_call")
        custom_agent = custom_builder.compile()

        # Prepare the initial state and config.
        input_state = {"messages": [HumanMessage(content=question)]}
        # BUG FIX: RunnableConfig is a TypedDict — the old code nested everything
        # under a single "config" key, so run_name, tags, metadata and
        # recursion_limit were all silently ignored. Pass the fields directly.
        # (The former "tracing": True entry was dropped: it is not a
        # RunnableConfig field; LangSmith tracing is enabled via the
        # LANGSMITH_TRACING env var read at module level.)
        config = RunnableConfig(
            run_name="GAIA Agent",
            tags=["gaia", "langgraph", "agent"],
            metadata={"user_input": question},
            recursion_limit=30,
        )

        # Run the agent and pull the final message's content.
        result = custom_agent.invoke(input_state, config)
        final_response = result["messages"][-1].content
        try:
            # The system prompt asks the model to end with "FINAL ANSWER: ...";
            # return only that trailing part when the marker is present.
            return final_response.split("FINAL ANSWER:")[-1].strip()
        except Exception:
            # content was not a plain string (e.g. a list of content parts).
            print("Could not split on 'FINAL ANSWER:'")
            return final_response