Spaces:
Runtime error
Runtime error
Update rag.py
Browse files
rag.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
-
# rag.py
|
2 |
from sentence_transformers import SentenceTransformer
|
3 |
import faiss
|
4 |
import numpy as np
|
|
|
5 |
|
6 |
-
#
|
7 |
embedder = SentenceTransformer('all-MiniLM-L6-v2')
|
8 |
|
9 |
class VectorStore:
|
@@ -11,19 +11,52 @@ class VectorStore:
|
|
11 |
self.texts = []
|
12 |
self.embeddings = []
|
13 |
self.index = None
|
|
|
14 |
|
15 |
def add_texts(self, texts):
|
16 |
-
"""Add list of texts to the store
|
17 |
-
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
self.embeddings.extend(new_embeds)
|
20 |
-
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
def retrieve(self, query, top_k=3):
|
24 |
-
"""Return top-k relevant texts
|
25 |
-
if not self.index:
|
26 |
-
return []
|
|
|
|
|
27 |
query_embed = embedder.encode([query])
|
28 |
-
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from sentence_transformers import SentenceTransformer
|
2 |
import faiss
|
3 |
import numpy as np
|
4 |
+
import hashlib
|
5 |
|
6 |
+
# Load model once
|
7 |
embedder = SentenceTransformer('all-MiniLM-L6-v2')
|
8 |
|
9 |
class VectorStore:
|
|
|
11 |
self.texts = []
|
12 |
self.embeddings = []
|
13 |
self.index = None
|
14 |
+
self.text_hashes = set()
|
15 |
|
16 |
def add_texts(self, texts):
|
17 |
+
"""Add list of texts to the store, avoiding duplicates"""
|
18 |
+
new_texts = []
|
19 |
+
for text in texts:
|
20 |
+
text_hash = hashlib.md5(text.encode()).hexdigest()
|
21 |
+
if text_hash not in self.text_hashes:
|
22 |
+
new_texts.append(text)
|
23 |
+
self.text_hashes.add(text_hash)
|
24 |
+
|
25 |
+
if not new_texts:
|
26 |
+
return
|
27 |
+
|
28 |
+
# Encode new texts
|
29 |
+
new_embeds = embedder.encode(new_texts)
|
30 |
+
self.texts.extend(new_texts)
|
31 |
self.embeddings.extend(new_embeds)
|
32 |
+
|
33 |
+
# Update FAISS index
|
34 |
+
if self.index is None:
|
35 |
+
self.index = faiss.IndexFlatL2(new_embeds[0].shape[0])
|
36 |
+
|
37 |
+
# Convert to numpy array and add to index
|
38 |
+
embeds_array = np.array(self.embeddings).astype('float32')
|
39 |
+
self.index.reset()
|
40 |
+
self.index.add(embeds_array)
|
41 |
|
42 |
def retrieve(self, query, top_k=3):
|
43 |
+
"""Return top-k relevant texts and their indices"""
|
44 |
+
if not self.index or not self.texts:
|
45 |
+
return [], []
|
46 |
+
|
47 |
+
# Encode query
|
48 |
query_embed = embedder.encode([query])
|
49 |
+
query_array = np.array(query_embed).astype('float32')
|
50 |
+
|
51 |
+
# Search
|
52 |
+
distances, indices = self.index.search(query_array, k=min(top_k, len(self.texts)))
|
53 |
+
|
54 |
+
# Return texts and indices
|
55 |
+
return [self.texts[i] for i in indices[0]], indices[0].tolist()
|
56 |
+
|
57 |
+
def clear(self):
|
58 |
+
"""Clear the vector store"""
|
59 |
+
self.texts = []
|
60 |
+
self.embeddings = []
|
61 |
+
self.index = None
|
62 |
+
self.text_hashes = set()
|