Spaces:
Running
Running
Update agent.py
Browse files
agent.py
CHANGED
@@ -145,7 +145,7 @@ def wiki_search(query: str) -> str:
|
|
145 |
f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
|
146 |
for doc in search_docs
|
147 |
])
|
148 |
-
return
|
149 |
|
150 |
|
151 |
|
@@ -175,7 +175,7 @@ def web_search(query: str) -> str:
|
|
175 |
f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
|
176 |
for doc in search_docs
|
177 |
])
|
178 |
-
return
|
179 |
|
180 |
@tool
|
181 |
def arvix_search(query: str) -> str:
|
@@ -189,7 +189,7 @@ def arvix_search(query: str) -> str:
|
|
189 |
f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
|
190 |
for doc in search_docs
|
191 |
])
|
192 |
-
return
|
193 |
|
194 |
|
195 |
|
@@ -414,26 +414,6 @@ question_retriever_tool = create_retriever_tool(
|
|
414 |
)
|
415 |
|
416 |
|
417 |
-
# -------------------------------
|
418 |
-
# Step 6: Create LangChain Tools
|
419 |
-
# -------------------------------
|
420 |
-
calc_tool = calculator
|
421 |
-
file_tool = analyze_attachment
|
422 |
-
web_tool = web_search
|
423 |
-
wiki_tool = wiki_search
|
424 |
-
arvix_tool = arvix_search
|
425 |
-
youtube_tool = get_youtube_transcript
|
426 |
-
video_tool = extract_video_id
|
427 |
-
analyze_tool = analyze_attachment
|
428 |
-
wikiq_tool = wikidata_query
|
429 |
-
|
430 |
-
|
431 |
-
# -------------------------------
|
432 |
-
# Step 7: Create the Planner-Agent Logic
|
433 |
-
# -------------------------------
|
434 |
-
|
435 |
-
# Define the tools (as you've already done)
|
436 |
-
tools = [wiki_tool, calc_tool, file_tool, web_tool, arvix_tool, youtube_tool, video_tool, analyze_tool, wikiq_tool]
|
437 |
|
438 |
# Define the LLM before using it
|
439 |
#llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo") # or "gpt-3.5-turbo" "gpt-4"
|
@@ -693,8 +673,7 @@ model_config = {
|
|
693 |
}
|
694 |
|
695 |
def build_graph(provider, model_config):
|
696 |
-
#
|
697 |
-
|
698 |
def get_llm(provider: str, config: dict):
|
699 |
if provider == "huggingface":
|
700 |
from langchain_huggingface import HuggingFaceEndpoint
|
@@ -707,24 +686,53 @@ def build_graph(provider, model_config):
|
|
707 |
)
|
708 |
else:
|
709 |
raise ValueError(f"Unsupported provider: {provider}")
|
710 |
-
|
711 |
-
|
712 |
llm = get_llm(provider, model_config)
|
713 |
-
return llm
|
714 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
715 |
llm_with_tools = llm.bind_tools(tools)
|
716 |
|
|
|
|
|
717 |
|
718 |
|
719 |
-
sys_msg = SystemMessage(content="You are a helpful assistant.")
|
720 |
-
|
721 |
-
def assistant(state: MessagesState):
|
722 |
-
return {"messages": [llm_with_tools.invoke(state["messages"])]}
|
723 |
|
|
|
|
|
|
|
|
|
|
|
724 |
def retriever(state: MessagesState):
|
725 |
user_query = state["messages"][0].content
|
726 |
similar_docs = vector_store.similarity_search(user_query)
|
727 |
-
|
728 |
if not similar_docs:
|
729 |
wiki_result = wiki_tool.run(user_query)
|
730 |
return {
|
@@ -742,18 +750,30 @@ def build_graph(provider, model_config):
|
|
742 |
HumanMessage(content=f"Reference:\n\n{similar_docs[0].page_content}")
|
743 |
]
|
744 |
}
|
745 |
-
|
|
|
|
|
|
|
|
|
|
|
746 |
def tools_condition(state: MessagesState) -> str:
|
747 |
if "use tool" in state["messages"][-1].content.lower():
|
748 |
return "tools"
|
749 |
else:
|
750 |
-
return END
|
751 |
-
|
|
|
752 |
builder = StateGraph(MessagesState)
|
|
|
|
|
753 |
builder.add_node("retriever", retriever)
|
754 |
builder.add_node("assistant", assistant)
|
755 |
builder.add_node("tools", ToolNode(tools))
|
|
|
|
|
756 |
builder.set_entry_point("retriever")
|
|
|
|
|
757 |
builder.add_edge("retriever", "assistant")
|
758 |
builder.add_conditional_edges("assistant", tools_condition)
|
759 |
builder.add_edge("tools", "assistant")
|
|
|
145 |
f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
|
146 |
for doc in search_docs
|
147 |
])
|
148 |
+
return formatted_search_docs
|
149 |
|
150 |
|
151 |
|
|
|
175 |
f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
|
176 |
for doc in search_docs
|
177 |
])
|
178 |
+
return formatted_search_docs
|
179 |
|
180 |
@tool
|
181 |
def arvix_search(query: str) -> str:
|
|
|
189 |
f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
|
190 |
for doc in search_docs
|
191 |
])
|
192 |
+
return formatted_search_docs
|
193 |
|
194 |
|
195 |
|
|
|
414 |
)
|
415 |
|
416 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
417 |
|
418 |
# Define the LLM before using it
|
419 |
#llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo") # or "gpt-3.5-turbo" "gpt-4"
|
|
|
673 |
}
|
674 |
|
675 |
def build_graph(provider, model_config):
|
676 |
+
# Step 1: Initialize the LLM
|
|
|
677 |
def get_llm(provider: str, config: dict):
|
678 |
if provider == "huggingface":
|
679 |
from langchain_huggingface import HuggingFaceEndpoint
|
|
|
686 |
)
|
687 |
else:
|
688 |
raise ValueError(f"Unsupported provider: {provider}")
|
689 |
+
|
|
|
690 |
llm = get_llm(provider, model_config)
|
|
|
691 |
|
692 |
+
# -------------------------------
|
693 |
+
# Step 6: Define LangChain Tools
|
694 |
+
# -------------------------------
|
695 |
+
calc_tool = calculator # Math operations tool
|
696 |
+
web_tool = web_search # Web search tool
|
697 |
+
wiki_tool = wiki_search # Wikipedia search tool
|
698 |
+
arvix_tool = arvix_search # Arxiv search tool
|
699 |
+
youtube_tool = get_youtube_transcript # YouTube transcript extraction
|
700 |
+
video_tool = extract_video_id # Video ID extraction tool
|
701 |
+
analyze_tool = analyze_attachment # File analysis tool
|
702 |
+
wikiq_tool = wikidata_query # Wikidata query tool
|
703 |
+
|
704 |
+
# -------------------------------
|
705 |
+
# Step 7: Create the Planner-Agent Logic
|
706 |
+
# -------------------------------
|
707 |
+
# Define tools list
|
708 |
+
tools = [
|
709 |
+
wiki_tool,
|
710 |
+
calc_tool,
|
711 |
+
web_tool,
|
712 |
+
arvix_tool,
|
713 |
+
youtube_tool,
|
714 |
+
video_tool,
|
715 |
+
analyze_tool,
|
716 |
+
wikiq_tool
|
717 |
+
]
|
718 |
+
|
719 |
+
# Step 8: Bind tools to the LLM
|
720 |
llm_with_tools = llm.bind_tools(tools)
|
721 |
|
722 |
+
# Return the LLM with tools bound
|
723 |
+
return llm_with_tools
|
724 |
|
725 |
|
|
|
|
|
|
|
|
|
726 |
|
727 |
+
|
728 |
+
# Initialize system message
|
729 |
+
sys_msg = SystemMessage(content="You are a helpful assistant.")
|
730 |
+
|
731 |
+
# Define the retriever function
|
732 |
def retriever(state: MessagesState):
|
733 |
user_query = state["messages"][0].content
|
734 |
similar_docs = vector_store.similarity_search(user_query)
|
735 |
+
|
736 |
if not similar_docs:
|
737 |
wiki_result = wiki_tool.run(user_query)
|
738 |
return {
|
|
|
750 |
HumanMessage(content=f"Reference:\n\n{similar_docs[0].page_content}")
|
751 |
]
|
752 |
}
|
753 |
+
|
754 |
+
# Define the assistant function
|
755 |
+
def assistant(state: MessagesState):
|
756 |
+
return {"messages": [llm_with_tools.invoke(state["messages"])]}
|
757 |
+
|
758 |
+
# Define condition for tools usage
|
759 |
def tools_condition(state: MessagesState) -> str:
|
760 |
if "use tool" in state["messages"][-1].content.lower():
|
761 |
return "tools"
|
762 |
else:
|
763 |
+
return "END"
|
764 |
+
|
765 |
+
# Initialize the StateGraph
|
766 |
builder = StateGraph(MessagesState)
|
767 |
+
|
768 |
+
# Add nodes to the graph
|
769 |
builder.add_node("retriever", retriever)
|
770 |
builder.add_node("assistant", assistant)
|
771 |
builder.add_node("tools", ToolNode(tools))
|
772 |
+
|
773 |
+
# Set the entry point
|
774 |
builder.set_entry_point("retriever")
|
775 |
+
|
776 |
+
# Define edges
|
777 |
builder.add_edge("retriever", "assistant")
|
778 |
builder.add_conditional_edges("assistant", tools_condition)
|
779 |
builder.add_edge("tools", "assistant")
|