gaur3009 commited on
Commit
9fb5174
Β·
verified Β·
1 Parent(s): 94de8a5

Update rag.py

Browse files
Files changed (1) hide show
  1. rag.py +25 -8
rag.py CHANGED
@@ -2,27 +2,44 @@ from sentence_transformers import SentenceTransformer
2
  import faiss
3
  import numpy as np
4
 
5
- # load model only once
6
  embedder = SentenceTransformer('all-MiniLM-L6-v2')
 
7
 
8
  class VectorStore:
9
  def __init__(self):
10
  self.texts = []
11
- self.embeddings = []
12
  self.index = None
 
13
 
14
  def add_texts(self, texts):
15
  """Add list of texts to the store."""
 
 
 
16
  new_embeds = embedder.encode(texts)
 
 
 
 
 
 
 
 
 
 
 
17
  self.texts.extend(texts)
18
- self.embeddings.extend(new_embeds)
19
- self.index = faiss.IndexFlatL2(new_embeds.shape[1])
20
- self.index.add(np.array(self.embeddings))
21
 
22
  def retrieve(self, query, top_k=3):
23
  """Return top-k relevant texts for the query."""
24
- if not self.index:
25
  return []
 
26
  query_embed = embedder.encode([query])
27
- D, I = self.index.search(np.array(query_embed), k=top_k)
28
- return [self.texts[i] for i in I[0]]
 
 
 
 
 
2
  import faiss
3
  import numpy as np
4
 
5
+ # Load model only once
6
  embedder = SentenceTransformer('all-MiniLM-L6-v2')
7
+ DIMENSION = 384 # Fixed dimension for all-MiniLM-L6-v2
8
 
9
  class VectorStore:
10
  def __init__(self):
11
  self.texts = []
 
12
  self.index = None
13
+ self.embeddings = None
14
 
15
  def add_texts(self, texts):
16
  """Add list of texts to the store."""
17
+ if not texts:
18
+ return
19
+
20
  new_embeds = embedder.encode(texts)
21
+
22
+ # Initialize index if needed
23
+ if self.index is None:
24
+ self.index = faiss.IndexFlatL2(DIMENSION)
25
+ self.embeddings = new_embeds
26
+ else:
27
+ self.embeddings = np.vstack([self.embeddings, new_embeds])
28
+
29
+ # Rebuild index with all embeddings
30
+ self.index.reset()
31
+ self.index.add(self.embeddings.astype('float32'))
32
  self.texts.extend(texts)
 
 
 
33
 
34
  def retrieve(self, query, top_k=3):
35
  """Return top-k relevant texts for the query."""
36
+ if not self.has_data():
37
  return []
38
+
39
  query_embed = embedder.encode([query])
40
+ _, I = self.index.search(query_embed.astype('float32'), top_k)
41
+ return [self.texts[i] for i in I[0] if i < len(self.texts)]
42
+
43
+ def has_data(self):
44
+ """Check if we have any data stored"""
45
+ return self.index is not None and self.index.ntotal > 0