0xrushi commited on
Commit
de49b6b
·
1 Parent(s): 6a3ad81
Files changed (1) hide show
  1. agent.py +19 -13
agent.py CHANGED
@@ -99,30 +99,36 @@ def build_graph():
99
  llm_with_tools = llm.bind_tools(tools)
100
 
101
  # Define nodes
102
- def assistant_node(state: MessagesState) -> dict:
103
- # Append system message for context
104
- messages = [sys_msg] + state["messages"]
105
- response = llm_with_tools.invoke(messages)
106
- return {"messages": [response]}
107
 
108
 
109
  # Retriever node returns AIMessage
110
  def retriever(state: MessagesState):
111
  query = state["messages"][-1].content
112
- similar_doc = vector_store.similarity_search(query, k=1)[0]
113
 
114
- content = similar_doc.page_content
115
- if "Final answer :" in content:
116
- answer = content.split("Final answer :")[-1].strip()
117
  else:
118
- answer = content.strip()
119
- return {"messages": [AIMessage(content=answer)]}
 
 
 
 
120
 
121
  builder = StateGraph(MessagesState)
122
  builder.add_node("retriever", retriever)
 
 
123
 
124
- builder.set_entry_point("retriever")
125
- builder.set_finish_point("retriever")
 
 
126
 
127
  # Compile graph
128
  return builder.compile()
 
99
  llm_with_tools = llm.bind_tools(tools)
100
 
101
  # Define nodes
102
+ def assistant(state: MessagesState):
103
+ return {"messages": [llm_with_tools.invoke(state["messages"])]}
104
+
 
 
105
 
106
 
107
  # Retriever node returns AIMessage
108
  def retriever(state: MessagesState):
109
  query = state["messages"][-1].content
110
+ similar_docs = vector_store.similarity_search(query, k=1)
111
 
112
+ if similar_docs:
113
+ reference = similar_docs[0].page_content
114
+ context_msg = HumanMessage(content=f"Here is a similar question and answer for reference:\n\n{reference}")
115
  else:
116
+ context_msg = HumanMessage(content="No relevant example found.")
117
+
118
+ return {
119
+ "messages": [sys_msg] + state["messages"] + [context_msg]
120
+ }
121
+
122
 
123
  builder = StateGraph(MessagesState)
124
  builder.add_node("retriever", retriever)
125
+ builder.add_node("assistant", assistant)
126
+ builder.add_node("tools", ToolNode(tools))
127
 
128
+ builder.add_edge(START, "retriever")
129
+ builder.add_edge("retriever", "assistant")
130
+ builder.add_conditional_edges("assistant", tools_condition)
131
+ builder.add_edge("tools", "assistant")
132
 
133
  # Compile graph
134
  return builder.compile()