Spaces:
Sleeping
Sleeping
from langchain.schema import HumanMessage, AIMessage, SystemMessage | |
from langchain_openai import ChatOpenAI | |
from langchain_core.messages import AnyMessage, SystemMessage | |
from langchain_core.tools import tool | |
from langchain_community.document_loaders import WikipediaLoader | |
from langchain_community.document_loaders import ArxivLoader | |
# from langchain_community.tools.tavily_search import TavilySearchResults | |
from langchain.tools.retriever import create_retriever_tool | |
from langgraph.graph.message import add_messages | |
from langgraph.graph import START, StateGraph, MessagesState, END | |
from langgraph.prebuilt import tools_condition, ToolNode | |
import os | |
from dotenv import load_dotenv | |
from typing import TypedDict, Annotated, Optional | |
from langchain_community.tools import DuckDuckGoSearchResults | |
from langchain_huggingface import ( | |
ChatHuggingFace, | |
HuggingFaceEndpoint, | |
HuggingFaceEmbeddings, | |
) | |
# Load variables from a local .env file into the process environment before
# any client below is constructed.
load_dotenv()

# Sentence-embedding model. NOTE(review): the name is spelled "embddings";
# it is kept as-is because other modules may reference it.
embddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

# DuckDuckGo-backed web search tool.
search_tool = DuckDuckGoSearchResults()
def wiki_search(query: str) -> dict[str, str]:
    """Search Wikipedia for a query and return maximum 2 results.

    Args:
        query: The search query.

    Returns:
        A dict with the single key "wiki_results" whose value is the loaded
        pages rendered as <Document> blocks separated by "---" lines.
    """
    # NOTE: when this callable is handed to LangChain as a tool, the docstring
    # above is what the model sees as the tool description.
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
        for doc in search_docs
    )
    # The function has always returned a dict; the annotation previously
    # claimed `str`, which misled readers and type checkers.
    return {"wiki_results": formatted_search_docs}
def web_search(query: str) -> dict[str, str]:
    """Search Tavily for a query and return maximum 3 results.

    Args:
        query: The search query.

    Returns:
        A dict with the single key "web_results" whose value is the search
        hits rendered as <Document> blocks separated by "---" lines, or the
        error text reported by the Tavily tool if the search failed.
    """
    # Imported locally: the module-level Tavily import is commented out, so
    # the name would otherwise be undefined when this function is called.
    from langchain_community.tools.tavily_search import TavilySearchResults

    # `invoke` takes the query as its positional input, not as `query=...`.
    search_results = TavilySearchResults(max_results=3).invoke(query)

    # The Tavily tool reports failures as a plain string instead of raising.
    if isinstance(search_results, str):
        return {"web_results": search_results}

    # Each hit is a dict (with "url" and "content"), not a Document object,
    # so there is no `.metadata` / `.page_content` to read.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{hit.get("url", "")}" page=""/>\n{hit.get("content", "")}\n</Document>'
        for hit in search_results
    )
    return {"web_results": formatted_search_docs}
def arvix_search(query: str) -> dict[str, str]:
    """Search Arxiv for a query and return maximum 3 result.

    Args:
        query: The search query.

    Returns:
        A dict with the single key "arvix_results" whose value is the loaded
        papers rendered as <Document> blocks separated by "---" lines. Each
        paper's text is truncated to its first 1000 characters.
    """
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()

    def _source_of(metadata: dict) -> str:
        # ArxivLoader documents carry "Published", "Title", "Authors" and
        # "Summary" (plus "entry_id" when full metadata is requested) but no
        # "source" key, so indexing metadata["source"] raised KeyError.
        return metadata.get("source") or metadata.get("entry_id") or metadata.get("Title", "")

    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{_source_of(doc.metadata)}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
        for doc in search_docs
    )
    return {"arvix_results": formatted_search_docs}
# Load LLM model | |
# llm = ChatOpenAI( | |
# model="gpt-4o", | |
# base_url="https://models.inference.ai.azure.com", | |
# api_key=os.environ["GITHUB_TOKEN"], | |
# temperature=0.2, | |
# max_tokens=4096, | |
# ) | |
# Chat model: Phi-3 mini reached through a Hugging Face endpoint, temperature 0.
llm = ChatHuggingFace(
    verbose=True,
    llm=HuggingFaceEndpoint(
        temperature=0,
        repo_id="microsoft/Phi-3-mini-4k-instruct",
        # huggingfacehub_api_token=os.environ["HUGGINGFACEHUB_API_TOKEN"],
    ),
)

# Tools exposed to the agent, in registration order.
# web_search is currently disabled.
tools = [arvix_search, wiki_search, search_tool]

# Graph node that executes the tool calls requested by the model.
tool_node = ToolNode(tools)

# Bind the tools to the LLM so it can emit tool calls for them.
model_with_tools = llm.bind_tools(tools)
def build_agent_workflow():
    """Assemble and compile the tool-calling agent graph.

    The graph has two nodes: "agent", which sends the conversation (prefixed
    with the system prompt) to the tool-bound model, and "tools", which runs
    whatever tool calls the model requested. Control alternates between the
    two until the model's latest message contains no tool calls.

    Returns:
        The compiled graph produced by ``StateGraph.compile()``.
    """
    # System prompt text, one literal per line of the original prompt.
    prompt_text = (
        "\n"
        "You are a helpful assistant tasked with answering questions using a set of tools.\n"
        "Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:\n"
        "FINAL ANSWER: [YOUR FINAL ANSWER].\n"
        "YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.\n"
        'Your answer should only start with "FINAL ANSWER: ", then follows with the answer. '
    )

    def should_continue(state: MessagesState):
        """Route to "tools" if the newest message has tool calls, else END."""
        return "tools" if state["messages"][-1].tool_calls else END

    def call_model(state: MessagesState):
        """Query the model with the system prompt plus the conversation."""
        conversation = [SystemMessage(content=prompt_text), *state["messages"]]
        print("Messages to LLM:", conversation)
        reply = model_with_tools.invoke(conversation)
        return {"messages": [reply]}

    # Wire the nodes: START -> agent -> (tools -> agent)* -> END
    builder = StateGraph(MessagesState)
    builder.add_node("agent", call_model)
    builder.add_node("tools", tool_node)
    builder.add_edge(START, "agent")
    builder.add_conditional_edges("agent", should_continue, ["tools", END])
    builder.add_edge("tools", "agent")
    return builder.compile()
if __name__ == "__main__":
    # Demo run: ask one question and print every message of the exchange.
    question = "Who nominated the only Featured Article on English Wikipedia about a dinosaur that was promoted in November 2016?"

    graph = build_agent_workflow()
    result = graph.invoke({"messages": [HumanMessage(content=question)]})

    for message in result["messages"]:
        message.pretty_print()