File size: 5,614 Bytes
ca8728d
 
4754c75
283e426
809f87e
283e426
809f87e
 
 
 
4754c75
809f87e
 
3568413
82e5cca
 
809f87e
 
 
 
 
 
 
 
 
82e5cca
 
 
 
 
809f87e
 
 
4754c75
 
 
26aec96
283e426
3568413
4754c75
3568413
 
4754c75
 
82e5cca
3568413
 
283e426
809f87e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283e426
 
 
 
809f87e
283e426
 
 
809f87e
 
 
 
 
 
 
 
 
26aec96
 
809f87e
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# Required environment variables: HF_TOKEN, OPENAI_API_KEY, BRAVE_SEARCH_API_KEY

import os
import json

from typing import Literal
from langchain_openai import ChatOpenAI
from langgraph.graph import MessagesState
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
from langgraph.graph import StateGraph, START, END
from langchain_community.tools import BraveSearch

from .prompt import system_prompt
from .custom_tools import (multiply, add, subtract, divide, modulus, power,
    query_image, automatic_speech_recognition, get_webpage_content, python_repl_tool,
    get_youtube_transcript)


class LangGraphAgent:
    """Tool-calling agent implemented as a two-node LangGraph loop.

    The graph alternates between an LLM node (``"llm_call"``) and a
    tool-execution node (``"environment"``) until the LLM answers without
    requesting any tool call. Calling the instance with a question runs the
    graph and returns the text found after the last ``"FINAL ANSWER:"``
    marker of the final message.

    Args:
        model_name: Name of the OpenAI chat model. Names starting with ``"o"``
            are treated as reasoning models and are created without a
            temperature setting; every other model uses ``temperature=0``.
        show_tools_desc: When true, print the JSON description of every tool
            bound to the LLM.
        show_prompt: When true, print the system prompt.

    Environment:
        ``OPENAI_API_KEY`` is needed by the chat model and
        ``BRAVE_SEARCH_API_KEY`` is read here for the web-search tool.
    """

    # Marker expected in front of the answer in the model's final message.
    _FINAL_ANSWER_MARKER = "FINAL ANSWER:"

    # Maximum number of graph steps before hitting a stop condition.
    _RECURSION_LIMIT = 30

    def __init__(self,
                 model_name="gpt-4.1-nano",
                 show_tools_desc=True,
                 show_prompt=True):

        # =========== LLM definition ===========
        if model_name.startswith('o'):
            # reasoning model (no temperature setting)
            llm = ChatOpenAI(model=model_name)  # needs OPENAI_API_KEY in env
        else:
            llm = ChatOpenAI(model=model_name, temperature=0)
        print(f"LangGraphAgent initialized with model \"{model_name}\"")

        # =========== Augment the LLM with tools ===========
        community_tools = [
            BraveSearch.from_api_key(   # Web search (more performant than DuckDuckGo)
                api_key=os.getenv("BRAVE_SEARCH_API_KEY"),  # needs BRAVE_SEARCH_API_KEY in env
                search_kwargs={"count": 5}),
        ]
        custom_tools = [
            multiply, add, subtract, divide, modulus, power,  # Basic arithmetic
            query_image,  # Ask anything about an image using a VLM
            automatic_speech_recognition,  # Transcribe an audio file to text
            get_webpage_content,  # Load a web page and return it to markdown
            python_repl_tool,  # Python code interpreter
            get_youtube_transcript,  # Get the transcript of a YouTube video
        ]

        tools = community_tools + custom_tools
        tools_by_name = {tool.name: tool for tool in tools}
        llm_with_tools = llm.bind_tools(tools)

        # =========== Agent definition ===========

        # The system prompt never changes, so the message is built once.
        system_message = SystemMessage(content=system_prompt)

        # Nodes
        def llm_call(state: MessagesState):
            """LLM decides whether to call a tool or not"""
            reply = llm_with_tools.invoke([system_message] + state["messages"])
            return {"messages": [reply]}

        def tool_node(state: MessagesState):
            """Performs every tool call requested by the last message"""
            results = [
                ToolMessage(
                    content=self._run_tool_call(tools_by_name, tool_call),
                    tool_call_id=tool_call["id"],
                )
                for tool_call in state["messages"][-1].tool_calls
            ]
            return {"messages": results}

        # Conditional edge function to route to the tool node or end based
        # upon whether the LLM made a tool call. The returned names are the
        # keys of the mapping passed to add_conditional_edges below.
        def should_continue(state: MessagesState) -> str:
            """Return "Action" if the LLM made a tool call, END otherwise"""
            if state["messages"][-1].tool_calls:
                return "Action"
            # Otherwise, we stop (reply to the user)
            return END

        # Build workflow
        agent_builder = StateGraph(MessagesState)

        # Add nodes
        agent_builder.add_node("llm_call", llm_call)
        agent_builder.add_node("environment", tool_node)

        # Add edges to connect nodes
        agent_builder.add_edge(START, "llm_call")
        agent_builder.add_conditional_edges(
            "llm_call",
            should_continue,
            {
                # Name returned by should_continue : Name of next node to visit
                "Action": "environment",
                END: END,
            },
        )
        agent_builder.add_edge("environment", "llm_call")

        # Compile the agent
        self.agent = agent_builder.compile()

        if show_tools_desc:
            for number, tool in enumerate(llm_with_tools.kwargs['tools'], start=1):
                print("\n" + "="*30 + f" Tool {number} " + "="*30)
                print(json.dumps(tool[tool['type']], indent=4))

        if show_prompt:
            print("\n" + "="*30 + " System prompt " + "="*30)
            print(system_prompt)

    @staticmethod
    def _run_tool_call(tools_by_name, tool_call):
        """Run one tool call and return content suitable for a ToolMessage.

        Args:
            tools_by_name: Mapping from tool name to an object exposing
                ``invoke(args)``.
            tool_call: Mapping with at least the keys ``"name"`` and
                ``"args"``.

        Returns:
            The tool result. Strings and lists are returned unchanged; any
            other value (e.g. a number) is converted with ``str``. If the
            tool name is unknown or the tool raises, an ``"Error..."`` string
            describing the problem is returned instead of raising, so the
            failure can be sent back to the LLM.
        """
        name = tool_call["name"]
        tool = tools_by_name.get(name)
        if tool is None:
            available = ", ".join(sorted(tools_by_name))
            return f'Error: unknown tool "{name}". Available tools: {available}'

        try:
            observation = tool.invoke(tool_call["args"])
        except Exception as exc:
            # Report the failure to the LLM instead of aborting the whole run.
            # NOTE(review): if a tool ever relies on LangGraph control-flow
            # exceptions (e.g. interrupts), re-raise those here - confirm.
            return f'Error while running tool "{name}": {exc!r}'

        if isinstance(observation, (str, list)):
            return observation
        return str(observation)

    @classmethod
    def _extract_final_answer(cls, response: str) -> str:
        """Return the stripped text after the last "FINAL ANSWER:" marker.

        When the marker is absent the whole response is returned, stripped.
        """
        return response.rpartition(cls._FINAL_ANSWER_MARKER)[2].strip()

    def __call__(self, question: str) -> str:
        """Run the agent on ``question`` and return its final answer.

        Every message of the run is pretty-printed. The returned string is
        what follows the last "FINAL ANSWER:" marker in the final message
        (kept for exact-match scoring), or the whole stripped message when the
        marker is missing.
        """
        print("\n\n"+"*"*50)
        print(f"Agent received question: {question}")
        print("*"*50)

        # Invoke
        final_state = self.agent.invoke(
            {"messages": [HumanMessage(content=question)]},
            {"recursion_limit": self._RECURSION_LIMIT},
        )
        for message in final_state["messages"]:
            message.pretty_print()

        # post-process the response (keep only what's after "FINAL ANSWER:" for the exact match)
        raw_response = str(final_state["messages"][-1].content)
        if self._FINAL_ANSWER_MARKER not in raw_response:
            # Splitting a str never raises, so the missing marker is detected
            # explicitly; the full response is returned in that case.
            print('Could not split response on "FINAL ANSWER:"')
        response = self._extract_final_answer(raw_response)

        print("\n\n"+"-"*50)
        print(f"Agent returning with answer: {response}")
        return response