import os
from langchain_openai import ChatOpenAI
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, ToolMessage
from langgraph.graph.message import add_messages
from langgraph.graph import StateGraph, START, END
from typing import TypedDict, Annotated, Literal
from langchain_community.tools import BraveSearch  # web search
from langchain_experimental.tools.python.tool import PythonAstREPLTool  # for logic/math problems
from tools import (calculator_basic, datetime_tools, transcribe_audio, transcribe_youtube, query_image, webpage_content, read_excel)
from prompt import system_prompt
from langchain_core.runnables import RunnableConfig  # for LangSmith tracking

# LangSmith to observe the agent (tracing is driven by these environment variables;
# note we define our own MessagesState below, so langgraph's built-in one is not imported)
langsmith_api_key = os.getenv("LANGSMITH_API_KEY")
langsmith_tracing = os.getenv("LANGSMITH_TRACING")
def get_llm():
    api_keys = [os.getenv("OPENROUTER_API_KEY"), os.getenv("OPENROUTER_API_KEY_1")]
    last_exception = None
    for key in api_keys:
        if not key:
            continue
        try:
            llm = ChatOpenAI(
                base_url="https://openrouter.ai/api/v1",
                api_key=key,
                model="qwen/qwen3-coder:free",
                temperature=1
            )
            # Optionally, test the key with a trivial call to ensure it's valid
            # llm.invoke([SystemMessage(content="ping")])
            return llm
        except Exception as e:
            last_exception = e
            continue
    raise RuntimeError(f"All OpenRouter API keys failed: {last_exception}")
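
# Minimal smoke test for the key-fallback logic (a sketch, assuming at least one
# OPENROUTER_API_KEY* variable is set); uncomment to verify before a full run:
# if __name__ == "__main__":
#     print(get_llm().invoke("ping").content)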

llm = get_llm()

python_tool = PythonAstREPLTool()
search_tool = BraveSearch.from_api_key(
    api_key=os.getenv("BRAVE_SEARCH_API"),
    search_kwargs={"count": 4},  # return the top 4 results with their URLs
    description="Web search using Brave"
)

community_tools = [search_tool, python_tool]
custom_tools = calculator_basic + datetime_tools + [transcribe_audio, transcribe_youtube, query_image, webpage_content, read_excel]
tools = community_tools + custom_tools
llm_with_tools = llm.bind_tools(tools)
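# bind_tools attaches each tool's name, description, and argument schema to every
# request, so the model can reply with structured tool calls instead of plain text.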

# Prepare tools by name
tools_by_name = {tool.name: tool for tool in tools}

class MessagesState(TypedDict):  # the state: the agent's working memory at any point in the run
    messages: Annotated[list[AnyMessage], add_messages]
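# With add_messages as the reducer, a node returning {"messages": [msg]} appends
# msg to the existing history (merging by message ID) rather than overwriting it.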

# LLM node
def llm_call(state: MessagesState):
    return {
        "messages": [
            llm_with_tools.invoke(
                [SystemMessage(content=system_prompt)] + state["messages"]
            )
        ]
    }

# Tool node
def tool_node(state: MessagesState):
    """Executes the tools"""
    result = []
    for tool_call in state["messages"][-1].tool_calls:  # the list of tools the LLM decided to call
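        # Each tool_call is a dict shaped roughly like (example values are illustrative):
        # {"name": "...", "args": {...}, "id": "call_abc123", "type": "tool_call"}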
        tool = tools_by_name[tool_call["name"]]  # look up the actual tool function via the dictionary
        observation = tool.invoke(tool_call["args"])  # execute the tool
        result.append(ToolMessage(content=str(observation), tool_call_id=tool_call["id"]))  # str() since tools may return non-string results
    return {"messages": result}  # thanks to add_messages, LangGraph automatically appends the results to the agent's message history

# Conditional edge function to route to the tool node or end, based on whether the LLM made a tool call
def should_continue(state: MessagesState) -> Literal["Action", END]:
    """Decide whether to continue the loop or stop, based on whether the LLM made a tool call"""
    last_message = state["messages"][-1]  # the last message (usually from the LLM)
    # If the LLM made a tool call, perform an action
    if last_message.tool_calls:
        return "Action"
    # Otherwise, stop (reply to the user)
    return END
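
# Note: END is langgraph's termination sentinel (the string "__end__"), which is
# why should_continue can return it directly and it can serve as a routing key below.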

# Build workflow
builder = StateGraph(MessagesState)

# Add nodes
builder.add_node("llm_call", llm_call)
builder.add_node("environment", tool_node)

# Add edges to connect nodes
builder.add_edge(START, "llm_call")
builder.add_conditional_edges(
    "llm_call",
    should_continue,
    {"Action": "environment",  # name returned by should_continue : name of the next node
     END: END}
)
# If tool calls -> "Action" -> environment (executes the tool)
# If no tool calls -> END
builder.add_edge("environment", "llm_call")  # after running the tools, go back to the LLM for another round of reasoning

gaia_agent = builder.compile()  # turns the builder into a runnable agent, callable via gaia_agent.invoke()
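
# Quick sanity check (a sketch; the question is illustrative):
# result = gaia_agent.invoke({"messages": [HumanMessage(content="What is 2 + 2?")]})
# print(result["messages"][-1].content)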

# Wrapper class to initialize and call the LangGraph agent with a user question
class LangGraphAgent:
    def __init__(self):
        print("LangGraphAgent initialized.")
        self.question_count = 0  # track the number of questions processed

    def __call__(self, question: str) -> str:
        # Alternate API keys between questions to spread the load:
        # even-numbered questions use OPENROUTER_API_KEY, odd-numbered ones use OPENROUTER_API_KEY_1
        key_index = self.question_count % 2
        api_key = os.getenv("OPENROUTER_API_KEY") if key_index == 0 else os.getenv("OPENROUTER_API_KEY_1")

        # Create a new LLM instance with the selected API key
        current_llm = ChatOpenAI(
            base_url="https://openrouter.ai/api/v1",
            api_key=api_key,
            model="qwen/qwen3-coder:free",
            temperature=1
        )

        # Bind tools to the current LLM
        current_llm_with_tools = current_llm.bind_tools(tools)

        # Increment question counter for the next call
        self.question_count += 1
        print(f"Running LangGraphAgent with input: {question[:150]}... (using API key {key_index + 1})")

        # Create a custom LLM node for this specific question
        def custom_llm_call(state: MessagesState):
            return {
                "messages": [
                    current_llm_with_tools.invoke(
                        [SystemMessage(content=system_prompt)] + state["messages"]
                    )
                ]
            }

        # Build a new workflow with the custom LLM
        custom_builder = StateGraph(MessagesState)
        custom_builder.add_node("llm_call", custom_llm_call)
        custom_builder.add_node("environment", tool_node)
        custom_builder.add_edge(START, "llm_call")
        custom_builder.add_conditional_edges(
            "llm_call",
            should_continue,
            {"Action": "environment", END: END}
        )
        custom_builder.add_edge("environment", "llm_call")
        custom_agent = custom_builder.compile()

        # Prepare the initial state and config. RunnableConfig is a TypedDict, so
        # its options must be passed as top-level keys (nesting them under a "config"
        # key would silently discard them, including recursion_limit); LangSmith
        # tracing itself is switched on via the LANGSMITH_TRACING variable read above.
        input_state = {"messages": [HumanMessage(content=question)]}
        config = RunnableConfig(
            run_name="GAIA Agent",
            tags=["gaia", "langgraph", "agent"],
            metadata={"user_input": question},
            recursion_limit=30
        )

        # Run the agent
        result = custom_agent.invoke(input_state, config)
        final_response = result["messages"][-1].content

        # str.split never raises when the marker is absent, so test for it explicitly
        if "FINAL ANSWER:" in final_response:
            return final_response.split("FINAL ANSWER:")[-1].strip()
        print("Could not find 'FINAL ANSWER:' in the response")
        return final_response
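
# Example usage (a sketch; the question is illustrative):
# agent = LangGraphAgent()
# print(agent("What is the capital of France?"))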