Spaces:
Runtime error
Runtime error
Update rag.py
Browse files
rag.py
CHANGED
@@ -1,45 +1,29 @@
|
|
|
|
1 |
from sentence_transformers import SentenceTransformer
|
2 |
import faiss
|
3 |
import numpy as np
|
4 |
|
5 |
-
#
|
6 |
embedder = SentenceTransformer('all-MiniLM-L6-v2')
|
7 |
-
DIMENSION = 384 # Fixed dimension for all-MiniLM-L6-v2
|
8 |
|
9 |
class VectorStore:
|
10 |
def __init__(self):
|
11 |
self.texts = []
|
|
|
12 |
self.index = None
|
13 |
-
self.embeddings = None
|
14 |
|
15 |
def add_texts(self, texts):
|
16 |
"""Add list of texts to the store."""
|
17 |
-
if not texts:
|
18 |
-
return
|
19 |
-
|
20 |
new_embeds = embedder.encode(texts)
|
21 |
-
|
22 |
-
# Initialize index if needed
|
23 |
-
if self.index is None:
|
24 |
-
self.index = faiss.IndexFlatL2(DIMENSION)
|
25 |
-
self.embeddings = new_embeds
|
26 |
-
else:
|
27 |
-
self.embeddings = np.vstack([self.embeddings, new_embeds])
|
28 |
-
|
29 |
-
# Rebuild index with all embeddings
|
30 |
-
self.index.reset()
|
31 |
-
self.index.add(self.embeddings.astype('float32'))
|
32 |
self.texts.extend(texts)
|
|
|
|
|
|
|
33 |
|
34 |
def retrieve(self, query, top_k=3):
|
35 |
"""Return top-k relevant texts for the query."""
|
36 |
-
if not self.
|
37 |
return []
|
38 |
-
|
39 |
query_embed = embedder.encode([query])
|
40 |
-
|
41 |
-
return [self.texts[i] for i in I[0]
|
42 |
-
|
43 |
-
def has_data(self):
|
44 |
-
"""Check if we have any data stored"""
|
45 |
-
return self.index is not None and self.index.ntotal > 0
|
|
|
1 |
+
# rag.py
|
2 |
from sentence_transformers import SentenceTransformer
|
3 |
import faiss
|
4 |
import numpy as np
|
5 |
|
6 |
+
# load model only once
|
7 |
embedder = SentenceTransformer('all-MiniLM-L6-v2')
|
|
|
8 |
|
9 |
class VectorStore:
    """In-memory semantic text store backed by a flat FAISS L2 index.

    Texts are embedded with the module-level ``embedder`` and appended in
    insertion order, so row ``i`` of the index always corresponds to
    ``self.texts[i]``.
    """

    def __init__(self):
        self.texts = []       # stored texts, aligned with the index rows
        self.embeddings = []  # one embedding vector per stored text
        self.index = None     # created lazily on the first non-empty add

    def add_texts(self, texts):
        """Add list of texts to the store.

        An empty (or ``None``) input is a no-op: encoding nothing yields an
        array without a second dimension, which would otherwise crash when
        the index dimension is read.
        """
        if not texts:
            return

        # FAISS only accepts C-contiguous float32 matrices.
        new_embeds = np.ascontiguousarray(
            embedder.encode(texts), dtype="float32"
        )

        # Build the index once, then append only the new vectors instead of
        # re-indexing every stored embedding on each call.
        if self.index is None:
            self.index = faiss.IndexFlatL2(new_embeds.shape[1])
        self.index.add(new_embeds)

        # Update the bookkeeping last so a failure above cannot leave the
        # texts and the index rows out of step.
        self.texts.extend(texts)
        self.embeddings.extend(new_embeds)

    def retrieve(self, query, top_k=3):
        """Return top-k relevant texts for the query.

        Returns at most ``top_k`` texts, nearest first; fewer when the store
        holds fewer texts, and an empty list when the store is empty or
        ``top_k`` is not positive.
        """
        if self.index is None or not self.texts:
            return []

        # Never ask for more neighbours than exist: FAISS pads missing
        # results with the label -1, which would index the *last* text.
        k = min(top_k, len(self.texts))
        if k <= 0:
            return []

        query_embed = np.ascontiguousarray(
            embedder.encode([query]), dtype="float32"
        )
        _, neighbours = self.index.search(query_embed, k)

        # Defensive filter in case the index still reports padding labels.
        total = len(self.texts)
        return [self.texts[i] for i in neighbours[0] if 0 <= i < total]
|
|
|
|
|
|
|
|