|
|
|
|
|
import os |
|
import json |
|
|
|
from typing import Literal |
|
from langchain_openai import ChatOpenAI |
|
from langgraph.graph import MessagesState |
|
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage |
|
from langgraph.graph import StateGraph, START, END |
|
from langchain_community.tools import BraveSearch |
|
|
|
from .prompt import system_prompt |
|
from .custom_tools import (multiply, add, subtract, divide, modulus, power, |
|
query_image, automatic_speech_recognition, get_webpage_content, python_repl_tool, |
|
get_youtube_transcript) |
|
|
|
|
|
class LangGraphAgent:
    """ReAct-style tool-calling agent built on LangGraph.

    Binds an OpenAI chat model to a set of tools (Brave web search, math
    helpers, image/audio/webpage/YouTube utilities, a Python REPL) and
    compiles a ``llm_call -> environment -> llm_call`` loop that runs until
    the model stops requesting tool calls. Call the instance with a question
    string to get the final answer text.
    """

    def __init__(self,
                 model_name="gpt-4.1-nano",
                 show_tools_desc=True,
                 show_prompt=True):
        """Build and compile the agent graph.

        Args:
            model_name: OpenAI model identifier. Names starting with 'o'
                (reasoning models) are created without a temperature, since
                those models reject that parameter.
            show_tools_desc: if True, print each bound tool's JSON schema.
            show_prompt: if True, print the system prompt at startup.
        """
        # Reasoning models ("o1", "o3", ...) do not accept `temperature`.
        if model_name.startswith('o'):
            llm = ChatOpenAI(model=model_name)
        else:
            llm = ChatOpenAI(model=model_name, temperature=0)
        print(f"LangGraphAgent initialized with model \"{model_name}\"")

        community_tools = [
            BraveSearch.from_api_key(
                api_key=os.getenv("BRAVE_SEARCH_API_KEY"),
                search_kwargs={"count": 5}),
        ]
        custom_tools = [
            multiply, add, subtract, divide, modulus, power,
            query_image,
            automatic_speech_recognition,
            get_webpage_content,
            python_repl_tool,
            get_youtube_transcript,
        ]

        tools = community_tools + custom_tools
        tools_by_name = {tool.name: tool for tool in tools}
        llm_with_tools = llm.bind_tools(tools)

        def llm_call(state: MessagesState):
            """LLM decides whether to call a tool or not."""
            return {
                "messages": [
                    llm_with_tools.invoke(
                        [SystemMessage(content=system_prompt)]
                        + state["messages"]
                    )
                ]
            }

        def tool_node(state: MessagesState):
            """Execute every tool call requested by the last LLM message."""
            result = []
            for tool_call in state["messages"][-1].tool_calls:
                tool = tools_by_name[tool_call["name"]]
                observation = tool.invoke(tool_call["args"])
                result.append(ToolMessage(content=observation,
                                          tool_call_id=tool_call["id"]))
            return {"messages": result}

        # Annotation matches the actual return values ("Action" or END);
        # the original declared "environment", which this function never returns.
        def should_continue(state: MessagesState) -> Literal["Action", END]:
            """Route to the tool node if the LLM made tool calls, else stop."""
            last_message = state["messages"][-1]
            if last_message.tool_calls:
                return "Action"
            return END

        agent_builder = StateGraph(MessagesState)

        agent_builder.add_node("llm_call", llm_call)
        agent_builder.add_node("environment", tool_node)

        agent_builder.add_edge(START, "llm_call")
        agent_builder.add_conditional_edges(
            "llm_call",
            should_continue,
            {
                # "Action" routes back into the tool-execution node.
                "Action": "environment",
                END: END,
            },
        )
        agent_builder.add_edge("environment", "llm_call")

        self.agent = agent_builder.compile()

        if show_tools_desc:
            for i, tool in enumerate(llm_with_tools.kwargs['tools']):
                print("\n" + "="*30 + f" Tool {i+1} " + "="*30)
                print(json.dumps(tool[tool['type']], indent=4))

        if show_prompt:
            # Plain string: the original used an f-string with no placeholders.
            print("\n" + "="*30 + " System prompt " + "="*30)
            print(system_prompt)

    def __call__(self, question: str) -> str:
        """Run the agent on *question*; return the text after 'FINAL ANSWER:'.

        If the marker is absent, the whole (stripped) response is returned.
        """
        print("\n\n"+"*"*50)
        print(f"Agent received question: {question}")
        print("*"*50)

        messages = [HumanMessage(content=question)]
        messages = self.agent.invoke({"messages": messages},
                                     {"recursion_limit": 30})
        for m in messages["messages"]:
            m.pretty_print()

        response = str(messages["messages"][-1].content)
        # str.split with a non-empty separator on a str cannot raise, so the
        # original bare `try/except` around this line was dead code; when the
        # marker is missing, [-1] is simply the whole string.
        response = response.split("FINAL ANSWER:")[-1].strip()
        print("\n\n"+"-"*50)
        print(f"Agent returning with answer: {response}")
        return response