"""LangGraph agent backed by a Hugging Face chat model with math and search tools."""

# import libraries for langgraph, huggingface
import os
from typing import Annotated, Any, Dict, List, Optional, TypedDict

from dotenv import load_dotenv
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.graph.message import add_messages
from langchain_core.messages import SystemMessage, HumanMessage, AnyMessage, AIMessage
from langchain_core.messages.ai import subtract_usage
from langchain.tools import Tool
from langchain_core.tools import tool
from langchain_community.tools import WikipediaQueryRun
from langchain_community.utilities import WikipediaAPIWrapper
from langchain_community.utilities import SerpAPIWrapper
from langchain_community.utilities import ArxivAPIWrapper
from langchain_community.retrievers import BM25Retriever
from langgraph.prebuilt import ToolNode, tools_condition

# load environment variables (e.g. from a local .env file)
load_dotenv()
HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")


# maths tools
@tool
def add(a: int, b: int) -> int:
    """add two numbers.

    args:
        a: first int
        b: second int
    """
    return a + b


@tool
def subtract(a: int, b: int) -> int:
    """subtract two numbers (a - b).

    args:
        a: first int
        b: second int
    """
    return a - b


@tool
def multiply(a: int, b: int) -> int:
    """multiply two numbers.

    args:
        a: first int
        b: second int
    """
    return a * b


@tool
def divide(a: int, b: int) -> float:
    """divide two numbers (a / b).

    args:
        a: first int (dividend)
        b: second int (divisor, must be non-zero)
    """
    if b == 0:
        # Give the model a readable error instead of a bare ZeroDivisionError.
        raise ValueError("Cannot divide by zero.")
    return a / b


@tool
def modulus(a: int, b: int) -> int:
    """modulus remainder of two numbers (a % b).

    args:
        a: first int
        b: second int (must be non-zero)
    """
    if b == 0:
        # Same error style as `divide`, so the model sees a consistent message.
        raise ValueError("Cannot take modulus by zero.")
    return a % b


# wikipedia search tool
@tool
def search_wiki(query: str) -> Dict[str, str]:
    """search wikipedia with a query

    args:
        query: a search query
    """
    # `run` returns the result text directly; there is no `page_content`
    # attribute on the tool or on its string output.
    wiki = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
    result = wiki.run(query)
    return {"wiki_results": f"\n{result}\n"}


# internet search tool
@tool
def search_web(query: str) -> Dict[str, str]:
    """search internet with a query

    args:
        query: a search query
    """
    # SerpAPIWrapper reads its API key from the environment; `run` returns
    # the processed answer/snippet rather than a document object.
    result = SerpAPIWrapper().run(query)
    return {"web_results": f"\n{result}\n"}


# ArXiv search tool
@tool
def search_arxiv(query: str) -> Dict[str, str]:
    """search ArXiv for papers matching a query or identifier

    args:
        query: a search query or arXiv identifier
    """
    # ArxivAPIWrapper.run returns a formatted string of matching papers.
    result = ArxivAPIWrapper().run(query)
    return {"arxiv_results": f"\n{result}\n"}


# TODO: build retriever once a document source is available, e.g.
# bm25_retriever = BM25Retriever.from_documents(docs)

# load system prompt from file (path is relative to the working directory)
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()

# init system message
sys_msg = SystemMessage(content=system_prompt)

tools = [
    add,
    subtract,
    multiply,
    divide,
    modulus,
    search_wiki,
    search_web,
    search_arxiv,
]


# build graph function
def build_graph():
    """Build and compile the tool-calling agent graph.

    The graph starts at the ``assistant`` node; ``tools_condition`` routes to
    the ``tools`` node when the model's latest message requests a tool, and
    the tool results are fed back to ``assistant``. Otherwise the run ends.

    returns:
        the compiled LangGraph graph.
    """
    # llm
    llm = HuggingFaceEndpoint(
        repo_id="microsoft/Phi-4-reasoning-plus",
        huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
    )
    chat = ChatHuggingFace(llm=llm, verbose=False)

    # bind tools to llm so it can emit tool calls
    chat_with_tools = chat.bind_tools(tools)

    # graph state: the message list, merged across steps by `add_messages`
    class AgentState(TypedDict):
        messages: Annotated[list[AnyMessage], add_messages]

    def assistant(state: AgentState):
        # Prepend the system prompt on every call instead of storing it in
        # state, so it always leads the conversation and is never duplicated.
        return {
            "messages": [chat_with_tools.invoke([sys_msg] + state["messages"])],
        }

    # build graph
    builder = StateGraph(AgentState)

    # define nodes
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))

    # define edges
    builder.add_edge(START, "assistant")
    builder.add_conditional_edges(
        "assistant",
        # If the latest message requires a tool, route to tools
        # Otherwise, provide a direct response
        tools_condition,
    )
    builder.add_edge("tools", "assistant")
    return builder.compile()


if __name__ == "__main__":
    question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
    graph = build_graph()
    messages = [HumanMessage(content=question)]
    messages = graph.invoke({"messages": messages})
    for m in messages["messages"]:
        m.pretty_print()