# Author: guillaumefrd — commit ca8728d ("clean")
# env variable needed: HF_TOKEN, OPENAI_API_KEY, BRAVE_SEARCH_API_KEY
import os
import json
from typing import Literal
from langchain_openai import ChatOpenAI
from langgraph.graph import MessagesState
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
from langgraph.graph import StateGraph, START, END
from langchain_community.tools import BraveSearch
from .prompt import system_prompt
from .custom_tools import (multiply, add, subtract, divide, modulus, power,
query_image, automatic_speech_recognition, get_webpage_content, python_repl_tool,
get_youtube_transcript)
class LangGraphAgent:
    """ReAct-style agent built on LangGraph.

    The compiled graph has two nodes — an LLM node that may emit tool calls
    and a tool-execution node ("environment") — connected in a loop until the
    LLM answers without requesting a tool.
    """

    def __init__(self,
                 model_name="gpt-4.1-nano",
                 show_tools_desc=True,
                 show_prompt=True):
        """Build and compile the agent graph.

        Args:
            model_name: OpenAI model identifier. Names starting with 'o' are
                treated as reasoning models, which reject a temperature setting.
            show_tools_desc: if True, print each bound tool's JSON schema.
            show_prompt: if True, print the system prompt.
        """
        # =========== LLM definition ===========
        if model_name.startswith('o'):
            # reasoning model (no temperature setting)
            llm = ChatOpenAI(model=model_name)  # needs OPENAI_API_KEY in env
        else:
            llm = ChatOpenAI(model=model_name, temperature=0)
        print(f"LangGraphAgent initialized with model \"{model_name}\"")

        # =========== Augment the LLM with tools ===========
        community_tools = [
            BraveSearch.from_api_key(  # Web search (more performant than DuckDuckGo)
                api_key=os.getenv("BRAVE_SEARCH_API_KEY"),  # needs BRAVE_SEARCH_API_KEY in env
                search_kwargs={"count": 5}),
        ]
        custom_tools = [
            multiply, add, subtract, divide, modulus, power,  # Basic arithmetic
            query_image,  # Ask anything about an image using a VLM
            automatic_speech_recognition,  # Transcribe an audio file to text
            get_webpage_content,  # Load a web page and return it to markdown
            python_repl_tool,  # Python code interpreter
            get_youtube_transcript,  # Get the transcript of a YouTube video
        ]
        tools = community_tools + custom_tools
        tools_by_name = {tool.name: tool for tool in tools}
        llm_with_tools = llm.bind_tools(tools)

        # =========== Agent definition ===========
        # Nodes
        def llm_call(state: MessagesState):
            """LLM decides whether to call a tool or not."""
            return {
                "messages": [
                    llm_with_tools.invoke(
                        [
                            SystemMessage(
                                content=system_prompt
                            )
                        ]
                        + state["messages"]
                    )
                ]
            }

        def tool_node(state: MessagesState):
            """Performs the tool call(s) requested by the last AI message."""
            result = []
            for tool_call in state["messages"][-1].tool_calls:
                tool = tools_by_name[tool_call["name"]]
                observation = tool.invoke(tool_call["args"])
                result.append(ToolMessage(content=observation, tool_call_id=tool_call["id"]))
            return {"messages": result}

        # Conditional edge function to route to the tool node or end based upon whether the LLM made a tool call
        def should_continue(state: MessagesState) -> Literal["Action", END]:
            """Decide if we should continue the loop or stop based upon whether the LLM made a tool call.

            Returns the edge label "Action" (mapped to the "environment" node
            below) or END.
            """
            messages = state["messages"]
            last_message = messages[-1]
            # If the LLM makes a tool call, then perform an action
            if last_message.tool_calls:
                return "Action"
            # Otherwise, we stop (reply to the user)
            return END

        # Build workflow
        agent_builder = StateGraph(MessagesState)
        # Add nodes
        agent_builder.add_node("llm_call", llm_call)
        agent_builder.add_node("environment", tool_node)
        # Add edges to connect nodes
        agent_builder.add_edge(START, "llm_call")
        agent_builder.add_conditional_edges(
            "llm_call",
            should_continue,
            {
                # Name returned by should_continue : Name of next node to visit
                "Action": "environment",
                END: END,
            },
        )
        agent_builder.add_edge("environment", "llm_call")
        # Compile the agent
        self.agent = agent_builder.compile()

        if show_tools_desc:
            for i, tool in enumerate(llm_with_tools.kwargs['tools']):
                print("\n" + "="*30 + f" Tool {i+1} " + "="*30)
                print(json.dumps(tool[tool['type']], indent=4))
        if show_prompt:
            print("\n" + "="*30 + " System prompt " + "="*30)
            print(system_prompt)

    def __call__(self, question: str) -> str:
        """Run the agent on a question and return the post-processed answer.

        The raw model output is expected to contain a "FINAL ANSWER:" marker;
        only the text after the last occurrence is returned (stripped). If the
        marker is absent, the whole response is returned stripped, since
        str.split then yields the full string as its only element.
        """
        print("\n\n" + "*"*50)
        print(f"Agent received question: {question}")
        print("*"*50)
        # Invoke
        messages = [HumanMessage(content=question)]
        messages = self.agent.invoke({"messages": messages},
                                     {"recursion_limit": 30})  # maximum number of steps before hitting a stop condition
        for m in messages["messages"]:
            m.pretty_print()
        # post-process the response (keep only what's after "FINAL ANSWER:" for the exact match)
        response = str(messages["messages"][-1].content)
        # str.split never raises here, so no try/except is needed: with no
        # marker present, split returns [response] and we strip the whole text.
        response = response.split("FINAL ANSWER:")[-1].strip()
        print("\n\n" + "-"*50)
        print(f"Agent returning with answer: {response}")
        return response