gaur3009 committed on
Commit
34a9313
·
verified ·
1 Parent(s): 7474f2b

Update rag.py

Browse files
Files changed (1) hide show
  1. rag.py +45 -12
rag.py CHANGED
@@ -1,9 +1,9 @@
1
- # rag.py
2
  from sentence_transformers import SentenceTransformer
3
  import faiss
4
  import numpy as np
 
5
 
6
- # load model only once
7
  embedder = SentenceTransformer('all-MiniLM-L6-v2')
8
 
9
  class VectorStore:
@@ -11,19 +11,52 @@ class VectorStore:
11
  self.texts = []
12
  self.embeddings = []
13
  self.index = None
 
14
 
15
  def add_texts(self, texts):
16
- """Add list of texts to the store."""
17
- new_embeds = embedder.encode(texts)
18
- self.texts.extend(texts)
 
 
 
 
 
 
 
 
 
 
 
19
  self.embeddings.extend(new_embeds)
20
- self.index = faiss.IndexFlatL2(new_embeds.shape[1])
21
- self.index.add(np.array(self.embeddings))
 
 
 
 
 
 
 
22
 
23
  def retrieve(self, query, top_k=3):
24
- """Return top-k relevant texts for the query."""
25
- if not self.index:
26
- return []
 
 
27
  query_embed = embedder.encode([query])
28
- D, I = self.index.search(np.array(query_embed), k=top_k)
29
- return [self.texts[i] for i in I[0]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from sentence_transformers import SentenceTransformer
2
  import faiss
3
  import numpy as np
4
+ import hashlib
5
 
6
+ # Load model once
7
  embedder = SentenceTransformer('all-MiniLM-L6-v2')
8
 
9
  class VectorStore:
 
11
  self.texts = []
12
  self.embeddings = []
13
  self.index = None
14
+ self.text_hashes = set()
15
 
16
  def add_texts(self, texts):
17
+ """Add list of texts to the store, avoiding duplicates"""
18
+ new_texts = []
19
+ for text in texts:
20
+ text_hash = hashlib.md5(text.encode()).hexdigest()
21
+ if text_hash not in self.text_hashes:
22
+ new_texts.append(text)
23
+ self.text_hashes.add(text_hash)
24
+
25
+ if not new_texts:
26
+ return
27
+
28
+ # Encode new texts
29
+ new_embeds = embedder.encode(new_texts)
30
+ self.texts.extend(new_texts)
31
  self.embeddings.extend(new_embeds)
32
+
33
+ # Update FAISS index
34
+ if self.index is None:
35
+ self.index = faiss.IndexFlatL2(new_embeds[0].shape[0])
36
+
37
+ # Convert to numpy array and add to index
38
+ embeds_array = np.array(self.embeddings).astype('float32')
39
+ self.index.reset()
40
+ self.index.add(embeds_array)
41
 
42
  def retrieve(self, query, top_k=3):
43
+ """Return top-k relevant texts and their indices"""
44
+ if not self.index or not self.texts:
45
+ return [], []
46
+
47
+ # Encode query
48
  query_embed = embedder.encode([query])
49
+ query_array = np.array(query_embed).astype('float32')
50
+
51
+ # Search
52
+ distances, indices = self.index.search(query_array, k=min(top_k, len(self.texts)))
53
+
54
+ # Return texts and indices
55
+ return [self.texts[i] for i in indices[0]], indices[0].tolist()
56
+
57
+ def clear(self):
58
+ """Clear the vector store"""
59
+ self.texts = []
60
+ self.embeddings = []
61
+ self.index = None
62
+ self.text_hashes = set()