gaur3009 committed on
Commit 69374eb · verified · 1 Parent(s): bea9184

Update rag.py

Files changed (1)
  1. rag.py +20 -27
rag.py CHANGED
@@ -1,35 +1,28 @@
-from langchain_community.vectorstores import FAISS
-from langchain_community.embeddings import HuggingFaceEmbeddings
-from langchain_text_splitters import RecursiveCharacterTextSplitter
+from sentence_transformers import SentenceTransformer
+import faiss
+import numpy as np
+
+# load model only once
+embedder = SentenceTransformer('all-MiniLM-L6-v2')
 
 class VectorStore:
     def __init__(self):
-        self.vectorstore = None
-        self.embedder = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
-        self.text_splitter = RecursiveCharacterTextSplitter(
-            chunk_size=500,
-            chunk_overlap=50
-        )
+        self.texts = []
+        self.embeddings = []
+        self.index = None
 
     def add_texts(self, texts):
-        if not texts:
-            return
-
-        # Split and add texts
-        if self.vectorstore is None:
-            self.vectorstore = FAISS.from_texts(
-                self.text_splitter.split_text("\n\n".join(texts)),
-                self.embedder
-            )
-        else:
-            self.vectorstore.add_texts(
-                self.text_splitter.split_text("\n\n".join(texts))
-            )
+        """Add list of texts to the store."""
+        new_embeds = embedder.encode(texts)
+        self.texts.extend(texts)
+        self.embeddings.extend(new_embeds)
+        self.index = faiss.IndexFlatL2(new_embeds.shape[1])
+        self.index.add(np.array(self.embeddings))
 
     def retrieve(self, query, top_k=3):
-        if self.vectorstore is None:
+        """Return top-k relevant texts for the query."""
+        if not self.index:
             return []
-        return [
-            doc.page_content
-            for doc in self.vectorstore.similarity_search(query, k=top_k)
-        ]
+        query_embed = embedder.encode([query])
+        D, I = self.index.search(np.array(query_embed), k=top_k)
+        return [self.texts[i] for i in I[0]]
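
For reference, a minimal usage sketch of the rewritten class, assuming rag.py is importable as-is and that sentence-transformers, faiss-cpu, and numpy are installed; the sample strings and the query are illustrative only, not part of the commit:

    # usage_sketch.py -- not part of the commit, just a hedged example
    from rag import VectorStore

    store = VectorStore()
    store.add_texts([
        "FAISS is a library for efficient similarity search.",          # illustrative document
        "Sentence-Transformers produces dense sentence embeddings.",    # illustrative document
    ])

    # top_k=1 keeps k no larger than the number of stored texts
    results = store.retrieve("What does FAISS do?", top_k=1)
    print(results)  # most likely prints the FAISS sentence (nearest by L2 distance)

Note that with only two stored texts the default top_k=3 would ask FAISS for more neighbours than exist, so the example passes top_k=1 explicitly.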