# Scaper_search / rag.py
# Last commit d5a33e6 ("Update rag.py") by gaur3009; 906 bytes.
# rag.py
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
# Shared sentence-embedding model, loaded once when this module is imported.
# VectorStore.add_texts and VectorStore.retrieve both call embedder.encode,
# so every store instance reuses this single model object.
embedder = SentenceTransformer('all-MiniLM-L6-v2')
class VectorStore:
    """In-memory text store backed by a FAISS exact L2 (IndexFlatL2) index.

    Texts are embedded with the module-level ``embedder``. They are kept in
    insertion order, so FAISS row ``i`` always corresponds to
    ``self.texts[i]``.
    """

    def __init__(self):
        self.texts = []        # stored texts, in insertion order
        self.embeddings = []   # one embedding vector per entry of self.texts
        self.index = None      # faiss.IndexFlatL2, created on first non-empty add

    def add_texts(self, texts):
        """Embed *texts* and add them to the store.

        Args:
            texts: iterable of strings to store. A single ``str`` is treated
                as a one-element list rather than as a sequence of characters.

        An empty input is a no-op.
        """
        if isinstance(texts, str):
            texts = [texts]
        else:
            texts = list(texts)
        if not texts:
            # Nothing to embed; an empty batch has no embedding dimension.
            return
        # FAISS requires contiguous float32 input.
        new_embeds = np.asarray(embedder.encode(texts), dtype="float32")
        if self.index is None:
            self.index = faiss.IndexFlatL2(new_embeds.shape[1])
        # IndexFlatL2 is exact, so adding only the new rows gives the same
        # results as rebuilding the index from all embeddings, without
        # re-copying everything on each call.
        self.index.add(new_embeds)
        # Record texts only after the index accepted them, so texts and index
        # rows stay aligned even if index.add raises (e.g. dimension mismatch).
        self.texts.extend(texts)
        self.embeddings.extend(new_embeds)

    def retrieve(self, query, top_k=3):
        """Return up to *top_k* stored texts closest to *query*, nearest first.

        Args:
            query: the query string to embed and search for.
            top_k: maximum number of texts to return.

        Returns:
            A list of stored texts. It is empty when the store is empty or
            ``top_k < 1``, and shorter than *top_k* when fewer texts are stored.
        """
        if self.index is None or not self.texts or top_k < 1:
            return []
        k = min(top_k, len(self.texts))
        query_embed = np.asarray(embedder.encode([query]), dtype="float32")
        _, indices = self.index.search(query_embed, k=k)
        # FAISS pads missing results with label -1. Using that as a list index
        # would silently return the last text, so skip such entries.
        return [self.texts[i] for i in indices[0] if i >= 0]