Update agent.py
agent.py CHANGED
@@ -145,7 +145,7 @@ def wiki_search(query: str) -> str:
        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
        for doc in search_docs
    ])
-   return
+   return formatted_search_docs
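This hunk and the two below make the search helpers actually return the string they assemble: each previously ended with a bare `return`, so the tool produced `None`. A minimal sketch of the completed pattern, assuming the Wikipedia helper uses `WikipediaLoader` and a `---` separator (neither is shown in the diff):

from langchain_community.document_loaders import WikipediaLoader
from langchain_core.tools import tool

@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia and return the results as <Document> blocks."""
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()  # assumed loader and limit
    formatted_search_docs = "\n\n---\n\n".join([
        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
        for doc in search_docs
    ])
    return formatted_search_docs  # the fix: a bare `return` here previously yielded None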
@@ -175,7 +175,7 @@ def web_search(query: str) -> str:
        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
        for doc in search_docs
    ])
-   return
+   return formatted_search_docs

@tool
def arvix_search(query: str) -> str:
@@ -189,7 +189,7 @@ def arvix_search(query: str) -> str:
        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
        for doc in search_docs
    ])
-   return
+   return formatted_search_docs
@@ -414,26 +414,6 @@ question_retriever_tool = create_retriever_tool(
)


- # -------------------------------
- # Step 6: Create LangChain Tools
- # -------------------------------
- calc_tool = calculator
- file_tool = analyze_attachment
- web_tool = web_search
- wiki_tool = wiki_search
- arvix_tool = arvix_search
- youtube_tool = get_youtube_transcript
- video_tool = extract_video_id
- analyze_tool = analyze_attachment
- wikiq_tool = wikidata_query
-
-
- # -------------------------------
- # Step 7: Create the Planner-Agent Logic
- # -------------------------------
-
- # Define the tools (as you've already done)
- tools = [wiki_tool, calc_tool, file_tool, web_tool, arvix_tool, youtube_tool, video_tool, analyze_tool, wikiq_tool]

# Define the LLM before using it
#llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo") # or "gpt-3.5-turbo" "gpt-4"
@@ -693,8 +673,7 @@ model_config = {
}

def build_graph(provider, model_config):
-   #
-
+   # Step 1: Initialize the LLM
    def get_llm(provider: str, config: dict):
        if provider == "huggingface":
            from langchain_huggingface import HuggingFaceEndpoint
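The body of the `HuggingFaceEndpoint` call is elided between this hunk and the next (only the closing parenthesis and the `ValueError` branch are visible). A hedged sketch of what `get_llm` plausibly looks like; the config keys (`repo_id`, `temperature`, `max_new_tokens`) are assumptions for illustration, not taken from the diff:

def get_llm(provider: str, config: dict):
    # Step 1: Initialize the LLM (mirrors the comment added in the hunk above)
    if provider == "huggingface":
        from langchain_huggingface import HuggingFaceEndpoint
        return HuggingFaceEndpoint(
            repo_id=config["repo_id"],                         # assumed key
            temperature=config.get("temperature", 0.0),        # assumed key
            max_new_tokens=config.get("max_new_tokens", 512),  # assumed key
        )
    else:
        raise ValueError(f"Unsupported provider: {provider}")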
@@ -707,24 +686,53 @@ def build_graph(provider, model_config):
            )
        else:
            raise ValueError(f"Unsupported provider: {provider}")
-
-
+
    llm = get_llm(provider, model_config)
-   return llm

+   # -------------------------------
+   # Step 6: Define LangChain Tools
+   # -------------------------------
+   calc_tool = calculator  # Math operations tool
+   web_tool = web_search  # Web search tool
+   wiki_tool = wiki_search  # Wikipedia search tool
+   arvix_tool = arvix_search  # Arxiv search tool
+   youtube_tool = get_youtube_transcript  # YouTube transcript extraction
+   video_tool = extract_video_id  # Video ID extraction tool
+   analyze_tool = analyze_attachment  # File analysis tool
+   wikiq_tool = wikidata_query  # Wikidata query tool
+
+   # -------------------------------
+   # Step 7: Create the Planner-Agent Logic
+   # -------------------------------
+   # Define tools list
+   tools = [
+       wiki_tool,
+       calc_tool,
+       web_tool,
+       arvix_tool,
+       youtube_tool,
+       video_tool,
+       analyze_tool,
+       wikiq_tool
+   ]
+
+   # Step 8: Bind tools to the LLM
    llm_with_tools = llm.bind_tools(tools)

+   # Return the LLM with tools bound
+   return llm_with_tools


-   sys_msg = SystemMessage(content="You are a helpful assistant.")
-
-   def assistant(state: MessagesState):
-       return {"messages": [llm_with_tools.invoke(state["messages"])]}

+
+   # Initialize system message
+   sys_msg = SystemMessage(content="You are a helpful assistant.")
+
+   # Define the retriever function
    def retriever(state: MessagesState):
        user_query = state["messages"][0].content
        similar_docs = vector_store.similarity_search(user_query)
-
+
        if not similar_docs:
            wiki_result = wiki_tool.run(user_query)
            return {
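As committed, `build_graph` returns the tool-bound LLM right after binding, so a caller would exercise it roughly as below. This is a usage sketch only, assuming it runs in the same module as `model_config` and the tool definitions:

from langchain_core.messages import HumanMessage

llm_with_tools = build_graph("huggingface", model_config)
response = llm_with_tools.invoke([HumanMessage(content="What is 2 + 2?")])
print(response.content)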
@@ -742,18 +750,30 @@ def build_graph(provider, model_config):
                HumanMessage(content=f"Reference:\n\n{similar_docs[0].page_content}")
            ]
        }
-
+
+   # Define the assistant function
+   def assistant(state: MessagesState):
+       return {"messages": [llm_with_tools.invoke(state["messages"])]}
+
+   # Define condition for tools usage
    def tools_condition(state: MessagesState) -> str:
        if "use tool" in state["messages"][-1].content.lower():
            return "tools"
        else:
-           return END
-
+           return "END"
+
+   # Initialize the StateGraph
    builder = StateGraph(MessagesState)
+
+   # Add nodes to the graph
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
+
+   # Set the entry point
    builder.set_entry_point("retriever")
+
+   # Define edges
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")