gaur3009 committed
Commit 5d969f7 · verified · 1 Parent(s): 1fab24f

Update rag.py

Files changed (1)
  1. rag.py +27 -20
rag.py CHANGED
@@ -1,28 +1,35 @@
-from sentence_transformers import SentenceTransformer
-import faiss
-import numpy as np
-
-# load model only once
-embedder = SentenceTransformer('all-MiniLM-L6-v2')
+from langchain_community.vectorstores import FAISS
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_text_splitters import RecursiveCharacterTextSplitter

 class VectorStore:
     def __init__(self):
-        self.texts = []
-        self.embeddings = []
-        self.index = None
+        self.vectorstore = None
+        self.embedder = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
+        self.text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=500,
+            chunk_overlap=50
+        )

     def add_texts(self, texts):
-        """Add list of texts to the store."""
-        new_embeds = embedder.encode(texts)
-        self.texts.extend(texts)
-        self.embeddings.extend(new_embeds)
-        self.index = faiss.IndexFlatL2(new_embeds.shape[1])
-        self.index.add(np.array(self.embeddings))
+        if not texts:
+            return
+
+        # Split and add texts
+        if self.vectorstore is None:
+            self.vectorstore = FAISS.from_texts(
+                self.text_splitter.split_text("\n\n".join(texts)),
+                self.embedder
+            )
+        else:
+            self.vectorstore.add_texts(
+                self.text_splitter.split_text("\n\n".join(texts))
+            )

     def retrieve(self, query, top_k=3):
-        """Return top-k relevant texts for the query."""
-        if not self.index:
+        if self.vectorstore is None:
             return []
-        query_embed = embedder.encode([query])
-        D, I = self.index.search(np.array(query_embed), k=top_k)
-        return [self.texts[i] for i in I[0]]
+        return [
+            doc.page_content
+            for doc in self.vectorstore.similarity_search(query, k=top_k)
+        ]
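
For reference, a minimal usage sketch of the updated class. It assumes rag.py is importable from the working directory and that langchain-community, sentence-transformers, and faiss-cpu are installed; the sample texts and query are illustrative only.

from rag import VectorStore  # hypothetical import path for the file in this commit

store = VectorStore()

# Texts are joined with "\n\n", split into 500-character chunks (50 overlap), and indexed.
store.add_texts([
    "FAISS is a library for efficient similarity search over dense vectors.",
    "LangChain's FAISS wrapper exposes the index through a common vector-store interface.",
])

# Returns the page_content of the top-k most similar chunks.
print(store.retrieve("What is FAISS used for?", top_k=2))

Compared with the previous version, which rebuilt a flat L2 index and re-added every stored embedding on each add_texts call, the LangChain-managed store is built once with FAISS.from_texts and then extended incrementally, and retrieve now returns the matching chunk contents rather than the original whole texts.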