File size: 5,204 Bytes
a5164b3
 
86166ac
64488e6
 
2dabf6a
 
7f1aa93
2dabf6a
459f011
4142e2f
86166ac
0ddf13f
0aa3b8d
2dabf6a
cfb2c00
 
 
 
 
4a311c0
e6871fc
 
ca81ac5
b5d3429
d08bd76
769f332
ca81ac5
a5164b3
ca81ac5
a5164b3
 
 
 
ca81ac5
 
916369c
0ddf13f
ca81ac5
 
64488e6
 
3bfa53a
64488e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9815922
 
64488e6
9815922
 
 
 
 
64488e6
ecf32d9
64488e6
 
 
9815922
 
64488e6
 
 
 
 
ecf32d9
 
 
 
 
 
 
 
 
 
 
9815922
ecf32d9
9815922
 
 
a38ba97
 
ecf32d9
a38ba97
 
 
3bfa53a
4a311c0
87c5184
 
 
 
 
65b3f13
27cf5cc
cfb2c00
 
 
 
 
 
60d475a
 
cfb2c00
 
60d475a
27cf5cc
87c5184
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os

from langchain_openai import ChatOpenAI
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, ToolMessage
from langgraph.graph.message import add_messages
from langgraph.graph import MessagesState
from langgraph.graph import StateGraph, START, END
from typing import TypedDict, Annotated, Literal

from langchain_community.tools import BraveSearch  # web search
from langchain_experimental.tools.python.tool import PythonAstREPLTool  # for logic/math problems

from tools import (calculator_basic, datetime_tools, transcribe_audio, transcribe_youtube, query_image, webpage_content, read_excel)
from prompt import system_prompt

from langchain_core.runnables import RunnableConfig  # for LangSmith tracking

# LangSmith to observe the agent.
# NOTE(review): these two values are read here but never referenced again in this
# file; tracing is presumably enabled by LangSmith reading the same environment
# variables itself — confirm before relying on (or removing) these assignments.
langsmith_api_key = os.getenv("LANGSMITH_API_KEY")
langsmith_tracing = os.getenv("LANGSMITH_TRACING")

# Chat model that drives the agent loop.
# gpt-4o-mini: cheaper for debugging, temperature 0 for less randomness
# o4-mini: better reasoning for benchmarking, temperature 1 (default)
llm = ChatOpenAI(
    model="o4-mini",
    api_key=os.getenv("OPENAI_API_KEY"),
    temperature=1
)

# Python REPL tool for logic/math questions.
# NOTE(review): this executes model-written Python in this process and no
# sandboxing is configured here — keep that in mind before exposing the agent
# to untrusted questions.
python_tool = PythonAstREPLTool()
# Brave web search; the API key is read from the BRAVE_SEARCH_API env variable.
search_tool = BraveSearch.from_api_key(
                    api_key=os.getenv("BRAVE_SEARCH_API"),
                    search_kwargs={"count": 4},  # asks Brave for 4 results per query (with their URLs)
                    description="Web search using Brave"
)

community_tools = [search_tool, python_tool]
# calculator_basic and datetime_tools are concatenated with a list, so they are
# presumably lists of tools themselves (defined in tools.py — verify there).
custom_tools = calculator_basic + datetime_tools + [transcribe_audio, transcribe_youtube, query_image, webpage_content, read_excel]

# Every tool the model may call; bind_tools advertises them to the LLM so it can
# emit structured tool calls.
tools = community_tools + custom_tools
llm_with_tools = llm.bind_tools(tools)

# Prepare tools by name: the lookup table the tool node uses to dispatch each
# tool call (keyed by the tool's `name`).
tools_by_name = {tool.name: tool for tool in tools}

class MessagesState(TypedDict):
    """Graph state shared by every node — the agent's memory at any moment.

    ``messages`` carries the ``add_messages`` reducer, so the message lists that
    nodes return are combined with the existing history instead of overwriting
    the key.

    NOTE(review): this class shadows the ``MessagesState`` imported from
    ``langgraph.graph`` at the top of the file.
    """

    messages: Annotated[list[AnyMessage], add_messages]

# LLM node
def llm_call(state: MessagesState):
    """Ask the tool-bound model for its next step.

    The system prompt is placed in front of the conversation stored in
    ``state["messages"]``. The model's single reply is returned as a state
    update, which the ``add_messages`` reducer adds to the history.
    """
    prompt = [SystemMessage(content=system_prompt), *state["messages"]]
    reply = llm_with_tools.invoke(prompt)
    return {"messages": [reply]}

# Tool node
def tool_node(state: MessagesState):
    """Execute the tool calls requested by the most recent LLM message.

    Each entry of ``tool_calls`` is dispatched through ``tools_by_name``. Its
    output is wrapped in a ``ToolMessage`` tied to the originating call through
    ``tool_call_id``.

    An unknown tool name, or an exception raised by a tool, is reported back to
    the model as an error ``ToolMessage`` instead of aborting the whole graph run.
    The LLM can then retry with fixed arguments or take another approach.

    Returns:
        ``{"messages": [ToolMessage, ...]}``; the ``add_messages`` reducer on
        the state appends these to the conversation history.
    """
    results = []
    for tool_call in state["messages"][-1].tool_calls:  # tools the LLM decided to call
        name = tool_call["name"]
        tool = tools_by_name.get(name)
        if tool is None:
            # Hallucinated or misspelled tool name: tell the model rather than
            # crashing the run with a KeyError.
            available = ", ".join(sorted(tools_by_name))
            observation = f"Error: unknown tool '{name}'. Available tools: {available}."
        else:
            try:
                observation = tool.invoke(tool_call["args"])  # executes the tool
            except Exception as exc:  # tool boundary: surface the failure to the LLM
                print(f"Tool '{name}' failed: {exc!r}")
                observation = f"Error: {exc!r}\nPlease fix your mistakes and try again."
        # Every requested call gets exactly one answer, linked by its id.
        results.append(ToolMessage(content=observation, tool_call_id=tool_call["id"]))
    return {"messages": results}  # add_messages appends these to the history

# Conditional edge function to route to the tool node or end based upon whether the LLM made a tool call
def should_continue(state: MessagesState) -> Literal["Action", END]:
    """Routing function for the conditional edge that follows ``llm_call``.

    The newest message is normally the LLM's reply. If it requests tool calls,
    the function returns ``"Action"`` so the tools run next. Otherwise it returns
    ``END`` and that reply is the final answer.
    """
    requested_tools = state["messages"][-1].tool_calls
    return "Action" if requested_tools else END

# Build workflow.
# Graph shape:  START -> llm_call --"Action"--> environment -> llm_call -> ...
#                           \--END--> END
builder = StateGraph(MessagesState)

# Nodes: the model step and the tool-execution step
builder.add_node("llm_call", llm_call)
builder.add_node("environment", tool_node)

# Every run starts by asking the LLM
builder.add_edge(START, "llm_call")

# The value returned by should_continue is looked up in path_map:
#   "Action" -> "environment"  (run the requested tools)
#   END      -> END            (no tool calls: the LLM reply is final)
builder.add_conditional_edges("llm_call", should_continue, path_map={"Action": "environment", END: END})

# Tool results flow back to the LLM for another round of reasoning
builder.add_edge("environment", "llm_call")

# Compile the builder into a runnable graph, used via gaia_agent.invoke(...)
gaia_agent = builder.compile()

# Wrapper class to initialize and call the LangGraph agent with a user question
class LangGraphAgent:
    """Callable wrapper around ``gaia_agent``: ``LangGraphAgent()(question) -> answer``."""

    def __init__(self):
        print("LangGraphAgent initialized.")

    def __call__(self, question: str) -> str:
        """Run the graph on ``question`` and return the parsed final answer.

        The answer is the stripped text after the last ``FINAL ANSWER:`` marker in
        the graph's final message. If the marker is missing, the whole stripped
        reply is returned. Exceptions raised by ``gaia_agent.invoke`` (for example
        when the recursion limit is hit) propagate to the caller.
        """
        input_state = {"messages": [HumanMessage(content=question)]}  # prepare the initial user message
        print(f"Running LangGraphAgent with input: {question[:150]}...")

        result = gaia_agent.invoke(input_state, self._build_config(question))
        return self._extract_final_answer(result["messages"][-1].content)

    @staticmethod
    def _build_config(question: str):
        """Build the per-run config: LangSmith trace labels plus the step cap."""
        # RunnableConfig is a TypedDict, so these must be top-level keys. Nested
        # under some other key (e.g. "config") they are simply ignored, including
        # recursion_limit.
        return RunnableConfig(
            run_name="GAIA Agent",
            tags=["gaia", "langgraph", "agent"],
            metadata={"user_input": question},
            # prevents infinite looping when the LLM keeps calling tools over and over
            recursion_limit=30,
        )

    @staticmethod
    def _extract_final_answer(content) -> str:
        """Return the stripped text after the last ``FINAL ANSWER:`` in ``content``.

        ``content`` is a message's content. It is usually a string, but LangChain
        also allows a list of strings and content-block dicts; the text parts of
        such a list are joined first. Without the marker, the whole stripped text
        is returned and a diagnostic is printed.
        """
        if isinstance(content, list):
            pieces = []
            for part in content:
                if isinstance(part, dict):
                    pieces.append(str(part.get("text", "")))  # non-text blocks contribute nothing
                else:
                    pieces.append(str(part))
            text = "".join(pieces)
        else:
            text = str(content)

        # rpartition keeps only what follows the LAST marker, like split(...)[-1].
        _, marker, answer = text.rpartition("FINAL ANSWER:")
        if not marker:
            print("Could not split on 'FINAL ANSWER:'")
            return text.strip()
        return answer.strip()