wt002 commited on
Commit
16c9822
Β·
verified Β·
1 Parent(s): 87c8549

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +20 -10
agent.py CHANGED
@@ -17,6 +17,10 @@ from langchain_core.tools import tool
17
  from langchain.tools.retriever import create_retriever_tool
18
  from supabase.client import Client, create_client
19
  from sentence_transformers import SentenceTransformer
 
 
 
 
20
 
21
  load_dotenv()
22
 
@@ -122,14 +126,20 @@ with open("system_prompt.txt", "r", encoding="utf-8") as f:
122
  # System message
123
  sys_msg = SystemMessage(content=system_prompt)
124
 
125
- # Initialize SentenceTransformer with max_seq_length
126
- sentence_transformer = SentenceTransformer(
127
- "sentence-transformers/all-mpnet-base-v2",
128
- max_seq_length=2048 # Set max sequence length here
129
- )
 
 
 
 
 
 
130
 
131
- # Initialize embeddings with the custom SentenceTransformer model
132
- embeddings = HuggingFaceEmbeddings(model=sentence_transformer)
133
 
134
  # Initialize Supabase client
135
  supabase: Client = create_client(
@@ -137,19 +147,19 @@ supabase: Client = create_client(
137
  os.environ.get("SUPABASE_SERVICE_KEY")
138
  )
139
 
140
- # Create vector store
141
  vector_store = SupabaseVectorStore(
142
  client=supabase,
143
  embedding=embeddings,
144
  table_name="documents",
145
- query_name="match_documents_langchain",
146
  )
147
 
148
  # Create retriever tool
149
  create_retriever_tool = create_retriever_tool(
150
  retriever=vector_store.as_retriever(),
151
  name="Question Search",
152
- description="A tool to retrieve similar questions from a vector store.",
153
  )
154
 
155
 
 
17
  from langchain.tools.retriever import create_retriever_tool
18
  from supabase.client import Client, create_client
19
  from sentence_transformers import SentenceTransformer
20
+ from sentence_transformers import SentenceTransformer
21
+ from langchain.embeddings.base import Embeddings
22
+ from typing import List
23
+ import numpy as np
24
 
25
  load_dotenv()
26
 
 
126
  # System message
127
  sys_msg = SystemMessage(content=system_prompt)
128
 
129
+ # Custom embedding class
130
+ from sentence_transformers import SentenceTransformer
131
+ from langchain_huggingface import HuggingFaceEmbeddings
132
+ from supabase import create_client, Client
133
+ from langchain_community.vectorstores import SupabaseVectorStore
134
+ from langchain.tools import create_retriever_tool
135
+ import os
136
+
137
+ # Initialize SentenceTransformer and set max_seq_length
138
+ sentence_transformer = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
139
+ sentence_transformer.max_seq_length = 2048 # Set max sequence length
140
 
141
+ # Initialize embeddings with the model name (dim=768)
142
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
143
 
144
  # Initialize Supabase client
145
  supabase: Client = create_client(
 
147
  os.environ.get("SUPABASE_SERVICE_KEY")
148
  )
149
 
150
+ # Initialize Supabase vector store
151
  vector_store = SupabaseVectorStore(
152
  client=supabase,
153
  embedding=embeddings,
154
  table_name="documents",
155
+ query_name="match_documents_langchain"
156
  )
157
 
158
  # Create retriever tool
159
  create_retriever_tool = create_retriever_tool(
160
  retriever=vector_store.as_retriever(),
161
  name="Question Search",
162
+ description="A tool to retrieve similar questions from a vector store."
163
  )
164
 
165