wt002 commited on
Commit
a1fdb15
·
verified ·
1 Parent(s): 43cd7ba

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +30 -46
agent.py CHANGED
@@ -20,6 +20,7 @@ from langchain_community.document_loaders import TextLoader
20
  from langchain_community.vectorstores import FAISS
21
  from langchain_openai import OpenAIEmbeddings
22
  from langchain_text_splitters import CharacterTextSplitter
 
23
 
24
 
25
  load_dotenv()
@@ -128,15 +129,20 @@ sys_msg = SystemMessage(content=system_prompt)
128
 
129
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
130
 
131
-
132
- def create_retriever_tool(persist_directory="vector_store"):
133
- documents = TextLoader("state_of_the_union.txt").load()
134
- text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
135
- texts = text_splitter.split_documents(documents)
136
- retriever = FAISS.from_documents(texts, OpenAIEmbeddings()).as_retriever()
137
-
138
- docs = retriever.invoke(sys_msg)
139
- #pretty_print_docs(docs)
 
 
 
 
 
140
 
141
 
142
  tools = [
@@ -178,47 +184,25 @@ def build_graph(provider: str = "google"):
178
  """Assistant node"""
179
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
180
 
181
- #def retriever(state: MessagesState):
182
- # """Retriever node"""
183
- # similar_question = vector_store.similarity_search(state["messages"][0].content)
184
- # example_msg = HumanMessage(
185
- # content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
186
- # )
187
- # return {"messages": [sys_msg] + state["messages"] + [example_msg]}
188
-
189
-
190
- from langchain_core.messages import AIMessage
191
-
192
  def retriever(state: MessagesState):
193
- query = state["messages"][-1].content
194
- similar_doc = vector_store.similarity_search(query, k=1)[0]
195
-
196
- content = similar_doc.page_content
197
- if "Final answer :" in content:
198
- answer = content.split("Final answer :")[-1].strip()
199
- else:
200
- answer = content.strip()
201
-
202
- return {"messages": [AIMessage(content=answer)]}
203
-
204
- # builder = StateGraph(MessagesState)
205
- #builder.add_node("retriever", retriever)
206
- #builder.add_node("assistant", assistant)
207
- #builder.add_node("tools", ToolNode(tools))
208
- #builder.add_edge(START, "retriever")
209
- #builder.add_edge("retriever", "assistant")
210
- #builder.add_conditional_edges(
211
- # "assistant",
212
- # tools_condition,
213
- #)
214
- #builder.add_edge("tools", "assistant")
215
 
216
  builder = StateGraph(MessagesState)
217
  builder.add_node("retriever", retriever)
218
-
219
- # Retriever ist Start und Endpunkt
220
- builder.set_entry_point("retriever")
221
- builder.set_finish_point("retriever")
 
 
 
 
 
222
 
223
  # Compile graph
224
  return builder.compile()
 
20
  from langchain_community.vectorstores import FAISS
21
  from langchain_openai import OpenAIEmbeddings
22
  from langchain_text_splitters import CharacterTextSplitter
23
+ from supabase.client import Client, create_client
24
 
25
 
26
  load_dotenv()
 
129
 
130
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
131
 
132
+ supabase: Client = create_client(
133
+ os.environ.get("SUPABASE_URL"),
134
+ os.environ.get("SUPABASE_SERVICE_KEY"))
135
+ vector_store = SupabaseVectorStore(
136
+ client=supabase,
137
+ embedding= embeddings,
138
+ table_name="documents",
139
+ query_name="match_documents_langchain",
140
+ )
141
+ create_retriever_tool = create_retriever_tool(
142
+ retriever=vector_store.as_retriever(),
143
+ name="Question Search",
144
+ description="A tool to retrieve similar questions from a vector store.",
145
+ )
146
 
147
 
148
  tools = [
 
184
  """Assistant node"""
185
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
186
 
 
 
 
 
 
 
 
 
 
 
 
187
  def retriever(state: MessagesState):
188
+ """Retriever node"""
189
+ similar_question = vector_store.similarity_search(state["messages"][0].content)
190
+ example_msg = HumanMessage(
191
+ content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
192
+ )
193
+ return {"messages": [sys_msg] + state["messages"] + [example_msg]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
  builder = StateGraph(MessagesState)
196
  builder.add_node("retriever", retriever)
197
+ builder.add_node("assistant", assistant)
198
+ builder.add_node("tools", ToolNode(tools))
199
+ builder.add_edge(START, "retriever")
200
+ builder.add_edge("retriever", "assistant")
201
+ builder.add_conditional_edges(
202
+ "assistant",
203
+ tools_condition,
204
+ )
205
+ builder.add_edge("tools", "assistant")
206
 
207
  # Compile graph
208
  return builder.compile()