EtienneB commited on
Commit
f00550f
·
1 Parent(s): 294f2c0

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +24 -7
agent.py CHANGED
@@ -1,9 +1,11 @@
1
  import os
2
 
3
  from dotenv import load_dotenv
 
4
  from langchain_core.messages import HumanMessage, SystemMessage
5
  from langchain_core.tools import tool
6
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
 
7
  from langgraph.graph import START, MessagesState, StateGraph
8
  from langgraph.prebuilt import ToolNode, tools_condition
9
 
@@ -41,6 +43,16 @@ YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma sepa
41
  Your answer should only start with "FINAL ANSWER: ", then follows with the answer.
42
  """
43
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  def build_graph():
46
  """Build the graph"""
@@ -71,17 +83,22 @@ def build_graph():
71
  """Assistant node"""
72
  return {"messages": [llm_with_tools.invoke([system_prompt] + state["messages"])]}
73
 
74
- builder = StateGraph(MessagesState)
 
 
 
 
 
75
 
76
- # Nodes
 
77
  builder.add_node("assistant", assistant)
78
  builder.add_node("tools", ToolNode(tools))
79
-
80
- # Edges
81
- builder.add_edge(START, "assistant")
82
  builder.add_conditional_edges("assistant", tools_condition)
83
  builder.add_edge("tools", "assistant")
84
-
85
  # Compile graph
86
  return builder.compile()
87
 
 
1
  import os
2
 
3
  from dotenv import load_dotenv
4
+ from langchain_community.vectorstores import Chroma
5
  from langchain_core.messages import HumanMessage, SystemMessage
6
  from langchain_core.tools import tool
7
+ from langchain_huggingface import (ChatHuggingFace, HuggingFaceEmbeddings,
8
+ HuggingFaceEndpoint)
9
  from langgraph.graph import START, MessagesState, StateGraph
10
  from langgraph.prebuilt import ToolNode, tools_condition
11
 
 
43
  Your answer should only start with "FINAL ANSWER: ", then follows with the answer.
44
  """
45
 
46
+ # System message
47
+ sys_msg = SystemMessage(content=system_prompt)
48
+
49
+ # Embeddings + Chroma Vector Store
50
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
51
+ vector_store = Chroma(
52
+ collection_name="langgraph-documents",
53
+ embedding_function=embeddings,
54
+ persist_directory="chroma_db" # Use a persistent directory
55
+ )
56
 
57
  def build_graph():
58
  """Build the graph"""
 
83
  """Assistant node"""
84
  return {"messages": [llm_with_tools.invoke([system_prompt] + state["messages"])]}
85
 
86
+ def retriever(state: MessagesState):
87
+ similar = vector_store.similarity_search(state["messages"][0].content)
88
+ if similar:
89
+ example_msg = HumanMessage(content=f"Here is a similar question:\n\n{similar[0].page_content}")
90
+ return {"messages": [sys_msg] + state["messages"] + [example_msg]}
91
+ return {"messages": [sys_msg] + state["messages"]}
92
 
93
+ builder = StateGraph(MessagesState)
94
+ builder.add_node("retriever", retriever)
95
  builder.add_node("assistant", assistant)
96
  builder.add_node("tools", ToolNode(tools))
97
+ builder.add_edge(START, "retriever")
98
+ builder.add_edge("retriever", "assistant")
 
99
  builder.add_conditional_edges("assistant", tools_condition)
100
  builder.add_edge("tools", "assistant")
101
+
102
  # Compile graph
103
  return builder.compile()
104