File size: 906 Bytes
d5a33e6
69374eb
 
 
 
d5a33e6
69374eb
db7ceef
 
 
69374eb
d5a33e6
69374eb
db7ceef
 
69374eb
 
 
d5a33e6
 
 
db7ceef
 
69374eb
d5a33e6
db7ceef
69374eb
d5a33e6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# rag.py
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np

# Shared sentence-embedding model: constructed a single time when this module is
# imported, then reused by every VectorStore for both indexing and querying.
embedder = SentenceTransformer(model_name_or_path='all-MiniLM-L6-v2')

class VectorStore:
    """In-memory semantic search over a growing collection of texts.

    Texts are embedded with the module-level ``embedder`` and indexed in a
    FAISS ``IndexFlatL2`` (exact nearest-neighbour search by L2 distance).

    Attributes:
        texts: stored texts in insertion order; position i is index row i.
        embeddings: per-text embedding vectors, parallel to ``texts``.
        index: the FAISS index, or ``None`` until the first non-empty add.
    """

    def __init__(self):
        self.texts = []
        self.embeddings = []
        self.index = None

    def add_texts(self, texts):
        """Embed *texts* and append them to the store.

        Args:
            texts: iterable of strings to store. A single string is treated
                as a one-element list; an empty iterable is a no-op.
        """
        if isinstance(texts, str):
            texts = [texts]
        # Materialize so a generator is not exhausted by encode() before we
        # record the texts themselves.
        texts = list(texts)
        if not texts:
            # encode([]) does not yield an (n, dim) matrix; nothing to index.
            return
        # FAISS requires a C-contiguous float32 matrix of shape (n, dim).
        new_embeds = np.ascontiguousarray(embedder.encode(texts), dtype=np.float32)
        if self.index is None:
            self.index = faiss.IndexFlatL2(new_embeds.shape[1])
        # Append only the new rows instead of rebuilding the whole index on
        # every call. Index first so a failed add leaves texts unchanged.
        self.index.add(new_embeds)
        self.texts.extend(texts)
        self.embeddings.extend(new_embeds)

    def retrieve(self, query, top_k=3):
        """Return up to *top_k* stored texts closest to *query*, nearest first.

        Args:
            query: the query string.
            top_k: maximum number of texts to return.

        Returns:
            Stored texts ordered by increasing L2 distance to the query.
            Empty if nothing has been stored or ``top_k`` is not positive.
        """
        if self.index is None or self.index.ntotal == 0:
            return []
        # Never ask for more neighbours than exist; FAISS would pad with -1.
        k = min(top_k, self.index.ntotal)
        if k <= 0:
            return []
        query_embed = np.ascontiguousarray(embedder.encode([query]), dtype=np.float32)
        _, ids = self.index.search(query_embed, k=k)
        # Skip any -1 padding defensively: a negative index would otherwise
        # wrap around and return the last stored text as a bogus match.
        return [self.texts[i] for i in ids[0] if i >= 0]