# env variables needed: HF_TOKEN, OPENAI_API_KEY, BRAVE_SEARCH_API_KEY
import os
import json
from typing import Literal

from langchain_openai import ChatOpenAI
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
from langgraph.graph import MessagesState, StateGraph, START, END
from langchain_community.tools import BraveSearch

from .prompt import system_prompt
from .custom_tools import (multiply, add, subtract, divide, modulus, power,
                           query_image, automatic_speech_recognition, get_webpage_content,
                           python_repl_tool, get_youtube_transcript)


class LangGraphAgent:
    def __init__(self,
                 model_name="gpt-4.1-nano",
                 show_tools_desc=True,
                 show_prompt=True):

        # =========== LLM definition ===========
        if model_name.startswith('o'):
            # reasoning model (no temperature setting)
            llm = ChatOpenAI(model=model_name)  # needs OPENAI_API_KEY in env
        else:
            llm = ChatOpenAI(model=model_name, temperature=0)
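            # temperature=0 keeps the output (mostly) deterministic, which suits the
            # exact-match "FINAL ANSWER:" post-processing done in __call__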
print(f"LangGraphAgent initialized with model \"{model_name}\"")
# =========== Augment the LLM with tools ===========
community_tools = [
BraveSearch.from_api_key( # Web search (more performant than DuckDuckGo)
api_key=os.getenv("BRAVE_SEARCH_API_KEY"), # needs BRAVE_SEARCH_API_KEY in env
search_kwargs={"count": 5}),
]
custom_tools = [
multiply, add, subtract, divide, modulus, power, # Basic arithmetic
query_image, # Ask anything about an image using a VLM
automatic_speech_recognition, # Transcribe an audio file to text
get_webpage_content, # Load a web page and return it to markdown
python_repl_tool, # Python code interpreter
get_youtube_transcript, # Get the transcript of a YouTube video
]
tools = community_tools + custom_tools
tools_by_name = {tool.name: tool for tool in tools}
llm_with_tools = llm.bind_tools(tools)
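        # bind_tools() attaches the tools' JSON schemas to every LLM request, so the
        # model can answer with structured tool calls instead of plain text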

        # =========== Agent definition ===========

        # Nodes
        def llm_call(state: MessagesState):
            """LLM decides whether to call a tool or not"""
            return {
                "messages": [
                    llm_with_tools.invoke(
                        [SystemMessage(content=system_prompt)] + state["messages"]
                    )
                ]
            }

        def tool_node(state: MessagesState):
            """Performs the tool calls requested by the last LLM message"""
            result = []
            for tool_call in state["messages"][-1].tool_calls:
                tool = tools_by_name[tool_call["name"]]
                observation = tool.invoke(tool_call["args"])
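                # tool_call_id ties the observation back to the request that produced it,
                # so the LLM can tell which result answers which call on the next turn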
                result.append(ToolMessage(content=observation, tool_call_id=tool_call["id"]))
            return {"messages": result}

        # Conditional edge: route to the tool node or end, based on whether the LLM made a tool call
        def should_continue(state: MessagesState) -> Literal["Action", END]:
            """Decide whether to continue the loop or stop, based on whether the LLM made a tool call"""
            last_message = state["messages"][-1]
            # if the LLM made a tool call, route to the tool node to perform the action
            if last_message.tool_calls:
                return "Action"
            # otherwise stop (reply to the user)
            return END

        # Build workflow
        agent_builder = StateGraph(MessagesState)

        # Add nodes
        agent_builder.add_node("llm_call", llm_call)
        agent_builder.add_node("environment", tool_node)

        # Add edges to connect nodes
        agent_builder.add_edge(START, "llm_call")
        agent_builder.add_conditional_edges(
            "llm_call",
            should_continue,
            {
                # name returned by should_continue : name of the next node to visit
                "Action": "environment",
                END: END,
            },
        )
        agent_builder.add_edge("environment", "llm_call")

        # Compile the agent
        self.agent = agent_builder.compile()

        if show_tools_desc:
            for i, tool in enumerate(llm_with_tools.kwargs['tools']):
                print("\n" + "=" * 30 + f" Tool {i + 1} " + "=" * 30)
                print(json.dumps(tool[tool['type']], indent=4))

        if show_prompt:
            print("\n" + "=" * 30 + " System prompt " + "=" * 30)
            print(system_prompt)

    def __call__(self, question: str) -> str:
        print("\n\n" + "*" * 50)
        print(f"Agent received question: {question}")
        print("*" * 50)

        # Invoke the agent (recursion_limit bounds the number of steps before hitting a stop condition)
        messages = [HumanMessage(content=question)]
        messages = self.agent.invoke({"messages": messages},
                                     {"recursion_limit": 30})
        for m in messages["messages"]:
            m.pretty_print()

        # post-process the response (keep only what follows "FINAL ANSWER:" for the exact match)
        response = str(messages["messages"][-1].content)
        if "FINAL ANSWER:" in response:
            response = response.split("FINAL ANSWER:")[-1].strip()
        else:
            print('Could not split response on "FINAL ANSWER:"')

        print("\n\n" + "-" * 50)
        print(f"Agent returning with answer: {response}")
        return response
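

# Minimal usage sketch (assumes the env variables above are set and that this module is
# imported as part of its package, so that the relative imports resolve):
#
#     from agents.langgraph_agent import LangGraphAgent  # hypothetical import path
#     agent = LangGraphAgent(model_name="gpt-4.1-nano")
#     answer = agent("What is 17 multiplied by 23?")  # expected: "391"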