gaur3009 commited on
Commit
d5a33e6
Β·
verified Β·
1 Parent(s): daac110

Update rag.py

Browse files
Files changed (1) hide show
  1. rag.py +9 -25
rag.py CHANGED
@@ -1,45 +1,29 @@
 
1
  from sentence_transformers import SentenceTransformer
2
  import faiss
3
  import numpy as np
4
 
5
- # Load model only once
6
  embedder = SentenceTransformer('all-MiniLM-L6-v2')
7
- DIMENSION = 384 # Fixed dimension for all-MiniLM-L6-v2
8
 
9
  class VectorStore:
10
  def __init__(self):
11
  self.texts = []
 
12
  self.index = None
13
- self.embeddings = None
14
 
15
  def add_texts(self, texts):
16
  """Add list of texts to the store."""
17
- if not texts:
18
- return
19
-
20
  new_embeds = embedder.encode(texts)
21
-
22
- # Initialize index if needed
23
- if self.index is None:
24
- self.index = faiss.IndexFlatL2(DIMENSION)
25
- self.embeddings = new_embeds
26
- else:
27
- self.embeddings = np.vstack([self.embeddings, new_embeds])
28
-
29
- # Rebuild index with all embeddings
30
- self.index.reset()
31
- self.index.add(self.embeddings.astype('float32'))
32
  self.texts.extend(texts)
 
 
 
33
 
34
  def retrieve(self, query, top_k=3):
35
  """Return top-k relevant texts for the query."""
36
- if not self.has_data():
37
  return []
38
-
39
  query_embed = embedder.encode([query])
40
- _, I = self.index.search(query_embed.astype('float32'), top_k)
41
- return [self.texts[i] for i in I[0] if i < len(self.texts)]
42
-
43
- def has_data(self):
44
- """Check if we have any data stored"""
45
- return self.index is not None and self.index.ntotal > 0
 
1
+ # rag.py
2
  from sentence_transformers import SentenceTransformer
3
  import faiss
4
  import numpy as np
5
 
6
+ # load model only once
7
  embedder = SentenceTransformer('all-MiniLM-L6-v2')
 
8
 
9
  class VectorStore:
10
  def __init__(self):
11
  self.texts = []
12
+ self.embeddings = []
13
  self.index = None
 
14
 
15
  def add_texts(self, texts):
16
  """Add list of texts to the store."""
 
 
 
17
  new_embeds = embedder.encode(texts)
 
 
 
 
 
 
 
 
 
 
 
18
  self.texts.extend(texts)
19
+ self.embeddings.extend(new_embeds)
20
+ self.index = faiss.IndexFlatL2(new_embeds.shape[1])
21
+ self.index.add(np.array(self.embeddings))
22
 
23
  def retrieve(self, query, top_k=3):
24
  """Return top-k relevant texts for the query."""
25
+ if not self.index:
26
  return []
 
27
  query_embed = embedder.encode([query])
28
+ D, I = self.index.search(np.array(query_embed), k=top_k)
29
+ return [self.texts[i] for i in I[0]]