# env variables needed: HF_TOKEN, OPENAI_API_KEY, BRAVE_SEARCH_API_KEY
import os
import json
from typing import Literal

from langchain_openai import ChatOpenAI
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
from langchain_community.tools import BraveSearch
from langgraph.graph import MessagesState, StateGraph, START, END

from .prompt import system_prompt
from .custom_tools import (multiply, add, subtract, divide, modulus, power,
                           query_image, automatic_speech_recognition,
                           get_webpage_content, python_repl_tool,
                           get_youtube_transcript)


class LangGraphAgent:
    """Tool-using agent built as a LangGraph state machine.

    The compiled graph has two nodes: ``llm_call`` (the LLM, which may
    emit tool calls) and ``environment`` (executes the requested tools).
    A conditional edge loops between them until the LLM replies without
    a tool call, at which point the graph ends and the final message is
    post-processed and returned.
    """

    def __init__(self, model_name="gpt-4.1-nano", show_tools_desc=True, show_prompt=True):
        """Build the LLM, bind the tools, and compile the agent graph.

        Args:
            model_name: OpenAI model identifier.  Names starting with "o"
                are treated as reasoning models, which reject a
                temperature setting.
            show_tools_desc: if True, print the JSON schema of every
                bound tool at startup.
            show_prompt: if True, print the system prompt at startup.
        """
        # =========== LLM definition ===========
        if model_name.startswith("o"):
            # Reasoning models do not accept a temperature parameter.
            llm = ChatOpenAI(model=model_name)  # needs OPENAI_API_KEY in env
        else:
            llm = ChatOpenAI(model=model_name, temperature=0)
        print(f"LangGraphAgent initialized with model \"{model_name}\"")

        # =========== Augment the LLM with tools ===========
        community_tools = [
            # Web search (more performant than DuckDuckGo).
            # NOTE(review): os.getenv returns None when BRAVE_SEARCH_API_KEY
            # is unset; BraveSearch will then fail at call time, not here.
            BraveSearch.from_api_key(
                api_key=os.getenv("BRAVE_SEARCH_API_KEY"),
                search_kwargs={"count": 5},
            ),
        ]
        custom_tools = [
            multiply, add, subtract, divide, modulus, power,  # basic arithmetic
            query_image,                   # ask anything about an image using a VLM
            automatic_speech_recognition,  # transcribe an audio file to text
            get_webpage_content,           # load a web page and return it as markdown
            python_repl_tool,              # Python code interpreter
            get_youtube_transcript,        # get the transcript of a YouTube video
        ]
        tools = community_tools + custom_tools
        tools_by_name = {tool.name: tool for tool in tools}
        llm_with_tools = llm.bind_tools(tools)

        # =========== Agent definition ===========
        # Nodes
        def llm_call(state: MessagesState):
            """LLM decides whether to call a tool or not."""
            messages = [SystemMessage(content=system_prompt)] + state["messages"]
            return {"messages": [llm_with_tools.invoke(messages)]}

        def tool_node(state: dict):
            """Performs the tool calls requested by the last LLM message."""
            result = []
            for tool_call in state["messages"][-1].tool_calls:
                tool = tools_by_name[tool_call["name"]]
                observation = tool.invoke(tool_call["args"])
                result.append(
                    ToolMessage(content=observation, tool_call_id=tool_call["id"])
                )
            return {"messages": result}

        # Conditional edge function to route to the tool node or end based
        # upon whether the LLM made a tool call.
        # FIX: the return annotation previously listed "environment", but the
        # function returns "Action" (which the add_conditional_edges mapping
        # below translates into the "environment" node).
        def should_continue(state: MessagesState) -> Literal["Action", END]:
            """Decide whether to continue the loop or stop, based upon
            whether the LLM made a tool call."""
            last_message = state["messages"][-1]
            # If the LLM makes a tool call, then perform an action
            if last_message.tool_calls:
                return "Action"
            # Otherwise, we stop (reply to the user)
            return END

        # Build workflow
        agent_builder = StateGraph(MessagesState)

        # Add nodes
        agent_builder.add_node("llm_call", llm_call)
        agent_builder.add_node("environment", tool_node)

        # Add edges to connect nodes
        agent_builder.add_edge(START, "llm_call")
        agent_builder.add_conditional_edges(
            "llm_call",
            should_continue,
            {
                # Name returned by should_continue : name of next node to visit
                "Action": "environment",
                END: END,
            },
        )
        agent_builder.add_edge("environment", "llm_call")

        # Compile the agent
        self.agent = agent_builder.compile()

        if show_tools_desc:
            for i, tool in enumerate(llm_with_tools.kwargs["tools"]):
                print("\n" + "=" * 30 + f" Tool {i + 1} " + "=" * 30)
                print(json.dumps(tool[tool["type"]], indent=4))

        if show_prompt:
            print("\n" + "=" * 30 + " System prompt " + "=" * 30)
            print(system_prompt)

    def __call__(self, question: str) -> str:
        """Run the agent on *question* and return its final answer.

        The raw reply is post-processed: only what follows the last
        "FINAL ANSWER:" marker is kept (for exact-match grading).  When
        the marker is absent, str.split returns the whole string, so the
        full (stripped) reply is returned.
        """
        print("\n\n" + "*" * 50)
        print(f"Agent received question: {question}")
        print("*" * 50)

        # Invoke; recursion_limit is the maximum number of steps before
        # hitting a stop condition.
        messages = [HumanMessage(content=question)]
        messages = self.agent.invoke({"messages": messages}, {"recursion_limit": 30})
        for m in messages["messages"]:
            m.pretty_print()

        # Post-process the response (keep only what's after "FINAL ANSWER:"
        # for the exact match).
        # FIX: the previous bare `except:` was dead code — str.split never
        # raises, and when the marker is absent split() yields the whole
        # string, so this line is safe unconditionally.
        response = str(messages["messages"][-1].content)
        response = response.split("FINAL ANSWER:")[-1].strip()

        print("\n\n" + "-" * 50)
        print(f"Agent returning with answer: {response}")
        return response