import json
from typing import Literal
from langchain_openai import ChatOpenAI
from langgraph.graph import MessagesState
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
from langgraph.graph import StateGraph, START, END
from langchain.agents import load_tools
from langchain_community.tools.riza.command import ExecPython
from .prompt import system_prompt
from .custom_tools import (multiply, add, subtract, divide, modulus, power,
query_image, automatic_speech_recognition)


class LangGraphAgent:
def __init__(self,
model_name="gpt-4.1-nano",
show_tools_desc=True,
show_prompt=True):
# =========== LLM definition ===========
llm = ChatOpenAI(model=model_name, temperature=0) # needs OPENAI_API_KEY
print(f"LangGraphAgent initialized with model \"{model_name}\"")
# =========== Augment the LLM with tools ===========
community_tool_names = [
"ddg-search", # DuckDuckGo search
"wikipedia",
]
community_tools = load_tools(community_tool_names)
        # Riza code interpreter (needs RIZA_API_KEY); not supported by load_tools,
        # so it is added manually with a custom runtime that includes basic packages (pandas, numpy, etc.)
        community_tools += [ExecPython(runtime_revision_id='01JT97GJ20BC83Y75WMAS364ZT')]
custom_tools = [
multiply, add, subtract, divide, modulus, power, # basic arithmetic
query_image, # Ask anything about an image using a VLM
automatic_speech_recognition, # Transcribe an audio file to text
]
tools = community_tools + custom_tools
tools_by_name = {tool.name: tool for tool in tools}
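        # bind the tool schemas to the LLM so it can decide when to emit structured tool calls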
llm_with_tools = llm.bind_tools(tools)
# =========== Agent definition ===========
# Nodes
def llm_call(state: MessagesState):
"""LLM decides whether to call a tool or not"""
            return {
                "messages": [
                    llm_with_tools.invoke(
                        [SystemMessage(content=system_prompt)] + state["messages"]
                    )
                ]
            }
def tool_node(state: dict):
"""Performs the tool call"""
result = []
for tool_call in state["messages"][-1].tool_calls:
tool = tools_by_name[tool_call["name"]]
observation = tool.invoke(tool_call["args"])
result.append(ToolMessage(content=observation, tool_call_id=tool_call["id"]))
return {"messages": result}
        # Conditional edge function: route to the tool node or end, based on whether the LLM made a tool call
        def should_continue(state: MessagesState) -> Literal["Action", END]:
            """Decide whether to continue the loop or stop, based on whether the LLM made a tool call"""
            messages = state["messages"]
            last_message = messages[-1]
            # If the LLM made a tool call, perform an action
            if last_message.tool_calls:
                return "Action"
            # Otherwise, stop (reply to the user)
            return END
# Build workflow
agent_builder = StateGraph(MessagesState)
# Add nodes
agent_builder.add_node("llm_call", llm_call)
agent_builder.add_node("environment", tool_node)
# Add edges to connect nodes
agent_builder.add_edge(START, "llm_call")
agent_builder.add_conditional_edges(
"llm_call",
should_continue,
{
# Name returned by should_continue : Name of next node to visit
"Action": "environment",
END: END,
},
)
agent_builder.add_edge("environment", "llm_call")
# Compile the agent
self.agent = agent_builder.compile()
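        # The compiled agent loops llm_call -> environment (tool execution) -> llm_call
        # until the LLM replies without requesting a tool call, then ends.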
if show_tools_desc:
for i, tool in enumerate(llm_with_tools.kwargs['tools']):
print("\n" + "="*30 + f" Tool {i+1} " + "="*30)
print(json.dumps(tool[tool['type']], indent=4))
if show_prompt:
print("\n" + "="*30 + f" System prompt " + "="*30)
print(system_prompt)

    def __call__(self, question: str) -> str:
print("\n\n"+"*"*50)
print(f"Agent received question: {question}")
print("*"*50)
# Invoke
messages = [HumanMessage(content=question)]
messages = self.agent.invoke({"messages": messages})
for m in messages["messages"]:
m.pretty_print()
        # post-process the response (keep only what's after "FINAL ANSWER:" for the exact match)
        response = str(messages["messages"][-1].content)
        if "FINAL ANSWER:" in response:
            response = response.split("FINAL ANSWER:")[-1].strip()
        else:
            print('Could not find "FINAL ANSWER:" in the response')
print("\n\n"+"-"*50)
print(f"Agent returning with answer: {response}")
return response
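

# --- Minimal usage sketch (illustrative, not part of the original module) ---
# Assumes OPENAI_API_KEY (and RIZA_API_KEY for the Riza code interpreter) are set in
# the environment. Because of the relative imports above, run this as part of its
# package (e.g. `python -m <package>.<module>`) rather than as a standalone script.
if __name__ == "__main__":
    agent = LangGraphAgent(show_tools_desc=False, show_prompt=False)
    answer = agent("What is 2 to the power of 10, minus 24?")
    print(answer)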