File size: 7,398 Bytes
9c9b3ff
 
 
 
 
38cfdef
 
 
9c9b3ff
 
 
 
38cfdef
9c9b3ff
 
38cfdef
 
 
 
 
 
0edd622
 
38cfdef
 
 
 
 
 
 
 
 
 
 
0edd622
 
38cfdef
 
 
 
 
 
0edd622
9c9b3ff
 
38cfdef
 
 
9c9b3ff
 
 
38cfdef
9c9b3ff
 
 
38cfdef
 
9c9b3ff
 
38cfdef
9c9b3ff
 
38cfdef
9c9b3ff
 
 
 
 
 
 
 
 
38cfdef
9c9b3ff
38cfdef
 
 
 
 
 
 
 
 
 
9c9b3ff
38cfdef
 
 
 
 
 
 
 
 
9c9b3ff
38cfdef
9c9b3ff
 
38cfdef
9c9b3ff
 
 
38cfdef
9c9b3ff
 
 
 
38cfdef
 
9c9b3ff
38cfdef
 
9c9b3ff
38cfdef
9c9b3ff
38cfdef
 
 
9c9b3ff
 
38cfdef
0edd622
9c9b3ff
 
0edd622
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38cfdef
0edd622
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c9b3ff
 
 
 
 
 
 
 
 
0edd622
 
 
9c9b3ff
38cfdef
 
0edd622
38cfdef
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import os

from langchain_openai import ChatOpenAI
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, ToolMessage
from langgraph.graph.message import add_messages
from langgraph.graph import MessagesState
from langgraph.graph import StateGraph, START, END
from typing import TypedDict, Annotated, Literal

from langchain_community.tools import BraveSearch  # web search
from langchain_experimental.tools.python.tool import PythonAstREPLTool  # for logic/math problems

from tools import (calculator_basic, datetime_tools, transcribe_audio, transcribe_youtube, query_image, webpage_content, read_excel)
from prompt import system_prompt

from langchain_core.runnables import RunnableConfig  # for LangSmith tracking

# LangSmith to observe the agent
# Read once at import time; both may be None when tracing is not configured.
# NOTE(review): these two names are not referenced anywhere else in this file;
# LangSmith itself reads the env vars directly — confirm before removing.
langsmith_api_key = os.getenv("LANGSMITH_API_KEY")
langsmith_tracing = os.getenv("LANGSMITH_TRACING")

def get_llm():
    """Return a ChatOpenAI client for OpenRouter, trying each configured key.

    Iterates over OPENROUTER_API_KEY and OPENROUTER_API_KEY_1 (skipping any
    that are unset) and returns the first client that constructs successfully.

    Raises:
        RuntimeError: if no key yields a usable client.
    """
    candidate_keys = (os.getenv("OPENROUTER_API_KEY"), os.getenv("OPENROUTER_API_KEY_1"))
    failure = None
    for candidate in candidate_keys:
        if not candidate:
            continue  # key missing from the environment — try the next one
        try:
            client = ChatOpenAI(
                base_url="https://openrouter.ai/api/v1",
                api_key=candidate,
                model="qwen/qwen3-coder:free",
                temperature=1,
            )
            # NOTE(review): constructing the client does not contact the API,
            # so a bad key is not detected here; uncomment the ping below to
            # actually validate the key before returning it.
            # client.invoke([SystemMessage(content="ping")])
            return client
        except Exception as exc:
            failure = exc
    raise RuntimeError(f"All OpenRouter API keys failed: {failure}")

# Module-level singletons: the base LLM and the tool belt the agent can use.
llm = get_llm()
python_tool = PythonAstREPLTool()  # sandboxed-ish Python REPL for logic/math
search_tool = BraveSearch.from_api_key(
                    api_key=os.getenv("BRAVE_SEARCH_API"),
                    search_kwargs={"count": 4},  # returns the 4 best results and their URL
                    description="Web search using Brave"
)

# Third-party/community tools vs. the project's own tools from tools.py.
community_tools = [search_tool, python_tool]
custom_tools = calculator_basic + datetime_tools + [transcribe_audio, transcribe_youtube, query_image, webpage_content, read_excel]

tools = community_tools + custom_tools
# LLM that can emit tool calls for any of the tools above.
llm_with_tools = llm.bind_tools(tools)

# Prepare tools by name so tool_node can resolve a tool call to its function.
tools_by_name = {tool.name: tool for tool in tools}

class MessagesState(TypedDict):  # creates the state (is like the agent's memory at any moment)
    # NOTE(review): this definition shadows the MessagesState imported from
    # langgraph.graph above; the local class is the one the graphs here use.
    # add_messages makes LangGraph append (not replace) messages on each update.
    messages: Annotated[list[AnyMessage], add_messages]

# LLM node
def llm_call(state: MessagesState):
    """Invoke the tool-bound LLM on the system prompt plus the chat so far."""
    prompt_messages = [SystemMessage(content=system_prompt)]
    prompt_messages.extend(state["messages"])
    response = llm_with_tools.invoke(prompt_messages)
    # add_messages appends this reply to the agent's message history.
    return {"messages": [response]}

# Tool node
def tool_node(state: MessagesState):
    """Executes the tools"""

    last_message = state["messages"][-1]
    tool_messages = []
    # The LLM's last message carries the list of tool calls it decided to make.
    for call in last_message.tool_calls:
        selected_tool = tools_by_name[call["name"]]   # resolve name -> tool object
        output = selected_tool.invoke(call["args"])   # execute the tool
        tool_messages.append(
            ToolMessage(content=output, tool_call_id=call["id"])
        )
    # Thanks to add_messages, LangGraph appends these results to the history.
    return {"messages": tool_messages}

# Conditional edge function to route to the tool node or end based upon whether the LLM made a tool call
def should_continue(state: MessagesState) -> Literal["Action", END]:
    """Decide if we should continue the loop or stop based upon whether the LLM made a tool call.

    Returns:
        "Action" to route to the tool node when the last message carries
        tool calls, otherwise END to stop and reply to the user.
    """
    last_message = state["messages"][-1]  # looks at the last message (usually from the LLM)

    # Robustness: not every message type has a tool_calls attribute
    # (e.g. ToolMessage / HumanMessage), so default to None instead of
    # raising AttributeError.
    if getattr(last_message, "tool_calls", None):
        return "Action"
    # Otherwise, we stop (reply to the user)
    return END

# Build workflow: a two-node loop (LLM <-> tools) with a conditional exit.
builder = StateGraph(MessagesState)

# Add nodes
builder.add_node("llm_call", llm_call)
builder.add_node("environment", tool_node)

# Add edges to connect nodes
builder.add_edge(START, "llm_call")
builder.add_conditional_edges(
    "llm_call",
    should_continue,
    {"Action": "environment",  # name returned by should_continue : Name of the next node
     END: END}
)
    # If tool calls -> "Action" -> environment (executes the tool)
    # If no tool calls -> END
builder.add_edge("environment", "llm_call")  # after running the tools go back to the LLM for another round of reasoning

gaia_agent = builder.compile()  # converts my builder into a runnable agent by using gaia_agent.invoke()

# Wrapper class to initialize and call the LangGraph agent with a user question
class LangGraphAgent:
    """Callable wrapper that builds and runs the LangGraph agent for one question.

    Alternates between two OpenRouter API keys across calls to spread quota:
    even-numbered questions use OPENROUTER_API_KEY, odd-numbered ones use
    OPENROUTER_API_KEY_1.
    """

    def __init__(self):
        print("LangGraphAgent initialized.")
        self.question_count = 0  # Track the number of questions processed

    def __call__(self, question: str) -> str:
        """Answer one question; returns the text after 'FINAL ANSWER:' if present."""
        # Pick the key for THIS question before touching the counter.
        # (The old comment claiming "first 50% of questions" was wrong — the
        # code alternates per question via modulo 2.)
        key_index = self.question_count % 2  # 0 -> primary key, 1 -> secondary
        api_key = os.getenv("OPENROUTER_API_KEY") if key_index == 0 else os.getenv("OPENROUTER_API_KEY_1")

        # Create a new LLM instance with the selected API key
        current_llm = ChatOpenAI(
            base_url="https://openrouter.ai/api/v1",
            api_key=api_key,
            model="qwen/qwen3-coder:free",
            temperature=1
        )

        # Bind tools to the current LLM
        current_llm_with_tools = current_llm.bind_tools(tools)

        # Increment question counter for next call
        self.question_count += 1

        # Bug fix: report the key actually used for this question (1-based).
        # Previously the index was computed AFTER the increment, so the log
        # always named the other key.
        print(f"Running LangGraphAgent with input: {question[:150]}... (Using API key {key_index + 1})")

        # LLM node closed over this question's LLM instance.
        def custom_llm_call(state: MessagesState):
            return {
                "messages": [
                    current_llm_with_tools.invoke(
                        [SystemMessage(content=system_prompt)] + state["messages"]
                    )
                ]
            }

        # Build a new workflow with the custom LLM (same shape as gaia_agent).
        custom_builder = StateGraph(MessagesState)
        custom_builder.add_node("llm_call", custom_llm_call)
        custom_builder.add_node("environment", tool_node)
        custom_builder.add_edge(START, "llm_call")
        custom_builder.add_conditional_edges(
            "llm_call",
            should_continue,
            {"Action": "environment", END: END}
        )
        custom_builder.add_edge("environment", "llm_call")
        custom_agent = custom_builder.compile()

        # Prepare the initial state and config.
        # Bug fix: RunnableConfig is a TypedDict — the old
        # RunnableConfig(config={...}) nested everything under an unknown
        # "config" key, so run_name, tags, metadata and recursion_limit were
        # silently ignored. Pass the keys directly. ("tracing" is not a
        # RunnableConfig key; LangSmith tracing is driven by the env vars.)
        input_state = {"messages": [HumanMessage(content=question)]}
        config = RunnableConfig(
            run_name="GAIA Agent",
            tags=["gaia", "langgraph", "agent"],
            metadata={"user_input": question},
            recursion_limit=30,
        )

        # Run the agent
        result = custom_agent.invoke(input_state, config)
        final_response = result["messages"][-1].content

        try:
            # Keep only the text after the last "FINAL ANSWER:" marker.
            return final_response.split("FINAL ANSWER:")[-1].strip()
        except Exception:
            # e.g. content is a list of content blocks rather than a plain str
            print("Could not split on 'FINAL ANSWER:'")
            return final_response