Spaces:
Running
Running
Update agent.py
Browse files
agent.py
CHANGED
@@ -1087,41 +1087,21 @@ def tools_condition(state: dict) -> str:
|
|
1087 |
|
1088 |
|
1089 |
from langgraph.graph import StateGraph
|
|
|
|
|
1090 |
|
1091 |
-
|
1092 |
-
def build_graph() -> StateGraph:
|
1093 |
-
builder = StateGraph(AgentState)
|
1094 |
-
|
1095 |
-
from langchain_core.messages import SystemMessage, HumanMessage
|
1096 |
-
#from langgraph.graph import StateGraph, ToolNode
|
1097 |
-
from langchain_core.runnables import RunnableLambda
|
1098 |
-
from some_module import vector_store # Make sure this is defined/imported
|
1099 |
-
|
1100 |
-
|
1101 |
llm = get_llm(provider, model_config)
|
1102 |
|
1103 |
-
|
1104 |
-
|
1105 |
-
wiki_search,
|
1106 |
-
calculator,
|
1107 |
-
web_search,
|
1108 |
-
arxiv_search,
|
1109 |
-
get_youtube_transcript,
|
1110 |
-
extract_video_id,
|
1111 |
-
analyze_attachment,
|
1112 |
-
wikidata_query
|
1113 |
-
]
|
1114 |
|
1115 |
global tool_map
|
1116 |
tool_map = {t.name: t for t in tools}
|
1117 |
|
1118 |
-
# Step 3: Bind tools to LLM
|
1119 |
llm_with_tools = llm.bind_tools(tools)
|
1120 |
-
|
1121 |
-
# Step 4: Build stateful graph logic
|
1122 |
sys_msg = SystemMessage(content="You are a helpful assistant.")
|
1123 |
|
1124 |
-
# --- Define nodes ---
|
1125 |
retriever = RunnableLambda(lambda state: {
|
1126 |
**state,
|
1127 |
"retrieved_docs": vector_store.similarity_search(state["input"])
|
@@ -1129,31 +1109,27 @@ def build_graph() -> StateGraph:
|
|
1129 |
|
1130 |
assistant = RunnableLambda(lambda state: {
|
1131 |
**state,
|
1132 |
-
"messages":
|
1133 |
})
|
1134 |
|
1135 |
-
call_llm =
|
1136 |
|
1137 |
-
|
1138 |
builder.add_node("retriever", retriever)
|
1139 |
builder.add_node("assistant", assistant)
|
1140 |
builder.add_node("call_llm", call_llm)
|
1141 |
builder.add_node("call_tool", tool_dispatcher)
|
1142 |
|
1143 |
builder.set_entry_point("retriever")
|
1144 |
-
|
1145 |
builder.add_edge("retriever", "assistant")
|
1146 |
builder.add_edge("assistant", "call_llm")
|
1147 |
-
|
1148 |
builder.add_conditional_edges("call_llm", should_call_tool, {
|
1149 |
"call_tool": "call_tool",
|
1150 |
"end": None
|
1151 |
})
|
1152 |
-
|
1153 |
builder.add_edge("call_tool", "call_llm")
|
1154 |
|
1155 |
-
|
1156 |
-
return graph
|
1157 |
|
1158 |
|
1159 |
|
|
|
1087 |
|
1088 |
|
1089 |
from langgraph.graph import StateGraph
|
1090 |
+
from langchain_core.messages import SystemMessage
|
1091 |
+
from langchain_core.runnables import RunnableLambda
|
1092 |
|
1093 |
+
def build_graph(vector_store, provider: str, model_config: dict) -> StateGraph:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1094 |
llm = get_llm(provider, model_config)
|
1095 |
|
1096 |
+
tools = [wiki_search, calculator, web_search, arxiv_search,
|
1097 |
+
get_youtube_transcript, extract_video_id, analyze_attachment, wikidata_query]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1098 |
|
1099 |
global tool_map
|
1100 |
tool_map = {t.name: t for t in tools}
|
1101 |
|
|
|
1102 |
llm_with_tools = llm.bind_tools(tools)
|
|
|
|
|
1103 |
sys_msg = SystemMessage(content="You are a helpful assistant.")
|
1104 |
|
|
|
1105 |
retriever = RunnableLambda(lambda state: {
|
1106 |
**state,
|
1107 |
"retrieved_docs": vector_store.similarity_search(state["input"])
|
|
|
1109 |
|
1110 |
assistant = RunnableLambda(lambda state: {
|
1111 |
**state,
|
1112 |
+
"messages": [sys_msg] + state["messages"]
|
1113 |
})
|
1114 |
|
1115 |
+
call_llm = llm_with_tools
|
1116 |
|
1117 |
+
builder = StateGraph(AgentState)
|
1118 |
builder.add_node("retriever", retriever)
|
1119 |
builder.add_node("assistant", assistant)
|
1120 |
builder.add_node("call_llm", call_llm)
|
1121 |
builder.add_node("call_tool", tool_dispatcher)
|
1122 |
|
1123 |
builder.set_entry_point("retriever")
|
|
|
1124 |
builder.add_edge("retriever", "assistant")
|
1125 |
builder.add_edge("assistant", "call_llm")
|
|
|
1126 |
builder.add_conditional_edges("call_llm", should_call_tool, {
|
1127 |
"call_tool": "call_tool",
|
1128 |
"end": None
|
1129 |
})
|
|
|
1130 |
builder.add_edge("call_tool", "call_llm")
|
1131 |
|
1132 |
+
return builder.compile()
|
|
|
1133 |
|
1134 |
|
1135 |
|