Update buffalo_rag/vector_store/db.py
Browse files
buffalo_rag/vector_store/db.py
CHANGED
@@ -17,21 +17,15 @@ class VectorStore:
|
|
17 |
self.chunk_ids = []
|
18 |
self.chunks = {}
|
19 |
|
20 |
-
# Load embedding model
|
21 |
self.model = SentenceTransformer(model_name)
|
22 |
-
|
23 |
-
# Load reranker model
|
24 |
self.reranker = CrossEncoder(reranker_name)
|
25 |
|
26 |
-
# Load or create index
|
27 |
self.load_or_create_index()
|
28 |
|
29 |
def load_or_create_index(self) -> None:
|
30 |
-
"""Load existing index or create a new one."""
|
31 |
index_path = os.path.join(self.embedding_dir, 'faiss_index.pkl')
|
32 |
|
33 |
if os.path.exists(index_path):
|
34 |
-
# Load existing index
|
35 |
with open(index_path, 'rb') as f:
|
36 |
data = pickle.load(f)
|
37 |
self.index = data['index']
|
@@ -39,7 +33,6 @@ class VectorStore:
|
|
39 |
self.chunks = data['chunks']
|
40 |
print(f"Loaded existing index with {len(self.chunk_ids)} chunks")
|
41 |
else:
|
42 |
-
# Create new index
|
43 |
embeddings_path = os.path.join(self.embedding_dir, 'embeddings.pkl')
|
44 |
if os.path.exists(embeddings_path):
|
45 |
self.create_index()
|
@@ -53,22 +46,18 @@ class VectorStore:
|
|
53 |
with open(embeddings_path, 'rb') as f:
|
54 |
embedding_map = pickle.load(f)
|
55 |
|
56 |
-
# Extract embeddings and chunk IDs
|
57 |
chunk_ids = list(embedding_map.keys())
|
58 |
embeddings = np.array([embedding_map[chunk_id]['embedding'] for chunk_id in chunk_ids])
|
59 |
chunks = {chunk_id: embedding_map[chunk_id]['chunk'] for chunk_id in chunk_ids}
|
60 |
|
61 |
-
# Create FAISS index
|
62 |
dimension = embeddings.shape[1]
|
63 |
index = faiss.IndexFlatL2(dimension)
|
64 |
index.add(embeddings.astype(np.float32))
|
65 |
|
66 |
-
# Save index and metadata
|
67 |
self.index = index
|
68 |
self.chunk_ids = chunk_ids
|
69 |
self.chunks = chunks
|
70 |
|
71 |
-
# Save to disk
|
72 |
with open(os.path.join(self.embedding_dir, 'faiss_index.pkl'), 'wb') as f:
|
73 |
pickle.dump({
|
74 |
'index': index,
|
@@ -83,24 +72,20 @@ class VectorStore:
|
|
83 |
k: int = 5,
|
84 |
filter_categories: Optional[List[str]] = None,
|
85 |
rerank: bool = True) -> List[Dict[str, Any]]:
|
86 |
-
|
87 |
if self.index is None:
|
88 |
print("No index available. Please create an index first.")
|
89 |
return []
|
90 |
|
91 |
-
# Create query embedding
|
92 |
query_embedding = self.model.encode([query])[0]
|
93 |
|
94 |
-
# Search index
|
95 |
D, I = self.index.search(np.array([query_embedding]).astype(np.float32), min(k * 2, len(self.chunk_ids)))
|
96 |
|
97 |
-
# Get results
|
98 |
results = []
|
99 |
for i, idx in enumerate(I[0]):
|
100 |
chunk_id = self.chunk_ids[idx]
|
101 |
chunk = self.chunks[chunk_id]
|
102 |
|
103 |
-
# Apply category filter if specified
|
104 |
if filter_categories and not any(cat in chunk.get('categories', []) for cat in filter_categories):
|
105 |
continue
|
106 |
|
@@ -111,22 +96,16 @@ class VectorStore:
|
|
111 |
}
|
112 |
results.append(result)
|
113 |
|
114 |
-
# Rerank results if requested
|
115 |
if rerank and results:
|
116 |
-
# Prepare pairs for reranking
|
117 |
pairs = [(query, result['chunk']['content']) for result in results]
|
118 |
-
|
119 |
-
# Get reranking scores
|
120 |
rerank_scores = self.reranker.predict(pairs)
|
121 |
|
122 |
-
# Update scores and sort
|
123 |
for i, score in enumerate(rerank_scores):
|
124 |
results[i]['rerank_score'] = float(score)
|
125 |
|
126 |
-
# Sort by rerank score
|
127 |
results = sorted(results, key=lambda x: x['rerank_score'], reverse=True)
|
128 |
|
129 |
-
# Limit to k results
|
130 |
results = results[:k]
|
131 |
|
132 |
return results
|
@@ -135,29 +114,22 @@ class VectorStore:
|
|
135 |
query: str,
|
136 |
k: int = 5,
|
137 |
filter_categories: Optional[List[str]] = None) -> List[Dict[str, Any]]:
|
138 |
-
"""Combine dense vector search with BM25-style keyword matching."""
|
139 |
-
# Get vector search results
|
140 |
vector_results = self.search(query, k=k, filter_categories=filter_categories, rerank=False)
|
141 |
|
142 |
-
# Simple keyword matching (simulating BM25)
|
143 |
keywords = query.lower().split()
|
144 |
-
|
145 |
-
# Score all chunks by keyword presence
|
146 |
keyword_scores = {}
|
|
|
147 |
for chunk_id, chunk_data in self.chunks.items():
|
148 |
chunk = chunk_data
|
149 |
content = (chunk['title'] + " " + chunk['content']).lower()
|
150 |
|
151 |
-
# Count keyword matches
|
152 |
score = sum(content.count(keyword) for keyword in keywords)
|
153 |
|
154 |
-
# Apply category filter if specified
|
155 |
if filter_categories and not any(cat in chunk.get('categories', []) for cat in filter_categories):
|
156 |
continue
|
157 |
|
158 |
keyword_scores[chunk_id] = score
|
159 |
|
160 |
-
# Get top keyword matches
|
161 |
keyword_results = sorted(
|
162 |
[{'chunk_id': chunk_id, 'score': score, 'chunk': self.chunks[chunk_id]}
|
163 |
for chunk_id, score in keyword_scores.items() if score > 0],
|
@@ -165,49 +137,28 @@ class VectorStore:
|
|
165 |
reverse=True
|
166 |
)[:k]
|
167 |
|
168 |
-
# Combine results (remove duplicates)
|
169 |
seen_ids = set()
|
170 |
combined_results = []
|
171 |
|
172 |
-
# Add vector results first
|
173 |
for result in vector_results:
|
174 |
combined_results.append(result)
|
175 |
seen_ids.add(result['chunk_id'])
|
176 |
|
177 |
-
# Add keyword results if not already added
|
178 |
for result in keyword_results:
|
179 |
if result['chunk_id'] not in seen_ids:
|
180 |
combined_results.append(result)
|
181 |
seen_ids.add(result['chunk_id'])
|
182 |
|
183 |
-
# Limit to k results
|
184 |
combined_results = combined_results[:k]
|
185 |
|
186 |
-
# Rerank final results
|
187 |
if combined_results:
|
188 |
-
# Prepare pairs for reranking
|
189 |
pairs = [(query, result['chunk']['content']) for result in combined_results]
|
190 |
|
191 |
-
# Get reranking scores
|
192 |
rerank_scores = self.reranker.predict(pairs)
|
193 |
|
194 |
-
# Update scores and sort
|
195 |
for i, score in enumerate(rerank_scores):
|
196 |
combined_results[i]['rerank_score'] = float(score)
|
197 |
|
198 |
-
# Sort by rerank score
|
199 |
combined_results = sorted(combined_results, key=lambda x: x['rerank_score'], reverse=True)
|
200 |
|
201 |
-
return combined_results
|
202 |
-
|
203 |
-
# Example usage
|
204 |
-
if __name__ == "__main__":
|
205 |
-
vector_store = VectorStore()
|
206 |
-
results = vector_store.hybrid_search("How do I apply for OPT?")
|
207 |
-
|
208 |
-
print(f"Found {len(results)} results")
|
209 |
-
for i, result in enumerate(results[:3]):
|
210 |
-
print(f"Result {i+1}: {result['chunk']['title']}")
|
211 |
-
print(f"Score: {result.get('rerank_score', result['score'])}")
|
212 |
-
print(f"Content: {result['chunk']['content'][:100]}...")
|
213 |
-
print()
|
|
|
17 |
self.chunk_ids = []
|
18 |
self.chunks = {}
|
19 |
|
|
|
20 |
self.model = SentenceTransformer(model_name)
|
|
|
|
|
21 |
self.reranker = CrossEncoder(reranker_name)
|
22 |
|
|
|
23 |
self.load_or_create_index()
|
24 |
|
25 |
def load_or_create_index(self) -> None:
|
|
|
26 |
index_path = os.path.join(self.embedding_dir, 'faiss_index.pkl')
|
27 |
|
28 |
if os.path.exists(index_path):
|
|
|
29 |
with open(index_path, 'rb') as f:
|
30 |
data = pickle.load(f)
|
31 |
self.index = data['index']
|
|
|
33 |
self.chunks = data['chunks']
|
34 |
print(f"Loaded existing index with {len(self.chunk_ids)} chunks")
|
35 |
else:
|
|
|
36 |
embeddings_path = os.path.join(self.embedding_dir, 'embeddings.pkl')
|
37 |
if os.path.exists(embeddings_path):
|
38 |
self.create_index()
|
|
|
46 |
with open(embeddings_path, 'rb') as f:
|
47 |
embedding_map = pickle.load(f)
|
48 |
|
|
|
49 |
chunk_ids = list(embedding_map.keys())
|
50 |
embeddings = np.array([embedding_map[chunk_id]['embedding'] for chunk_id in chunk_ids])
|
51 |
chunks = {chunk_id: embedding_map[chunk_id]['chunk'] for chunk_id in chunk_ids}
|
52 |
|
|
|
53 |
dimension = embeddings.shape[1]
|
54 |
index = faiss.IndexFlatL2(dimension)
|
55 |
index.add(embeddings.astype(np.float32))
|
56 |
|
|
|
57 |
self.index = index
|
58 |
self.chunk_ids = chunk_ids
|
59 |
self.chunks = chunks
|
60 |
|
|
|
61 |
with open(os.path.join(self.embedding_dir, 'faiss_index.pkl'), 'wb') as f:
|
62 |
pickle.dump({
|
63 |
'index': index,
|
|
|
72 |
k: int = 5,
|
73 |
filter_categories: Optional[List[str]] = None,
|
74 |
rerank: bool = True) -> List[Dict[str, Any]]:
|
75 |
+
|
76 |
if self.index is None:
|
77 |
print("No index available. Please create an index first.")
|
78 |
return []
|
79 |
|
|
|
80 |
query_embedding = self.model.encode([query])[0]
|
81 |
|
|
|
82 |
D, I = self.index.search(np.array([query_embedding]).astype(np.float32), min(k * 2, len(self.chunk_ids)))
|
83 |
|
|
|
84 |
results = []
|
85 |
for i, idx in enumerate(I[0]):
|
86 |
chunk_id = self.chunk_ids[idx]
|
87 |
chunk = self.chunks[chunk_id]
|
88 |
|
|
|
89 |
if filter_categories and not any(cat in chunk.get('categories', []) for cat in filter_categories):
|
90 |
continue
|
91 |
|
|
|
96 |
}
|
97 |
results.append(result)
|
98 |
|
|
|
99 |
if rerank and results:
|
|
|
100 |
pairs = [(query, result['chunk']['content']) for result in results]
|
101 |
+
|
|
|
102 |
rerank_scores = self.reranker.predict(pairs)
|
103 |
|
|
|
104 |
for i, score in enumerate(rerank_scores):
|
105 |
results[i]['rerank_score'] = float(score)
|
106 |
|
|
|
107 |
results = sorted(results, key=lambda x: x['rerank_score'], reverse=True)
|
108 |
|
|
|
109 |
results = results[:k]
|
110 |
|
111 |
return results
|
|
|
114 |
query: str,
|
115 |
k: int = 5,
|
116 |
filter_categories: Optional[List[str]] = None) -> List[Dict[str, Any]]:
|
|
|
|
|
117 |
vector_results = self.search(query, k=k, filter_categories=filter_categories, rerank=False)
|
118 |
|
|
|
119 |
keywords = query.lower().split()
|
|
|
|
|
120 |
keyword_scores = {}
|
121 |
+
|
122 |
for chunk_id, chunk_data in self.chunks.items():
|
123 |
chunk = chunk_data
|
124 |
content = (chunk['title'] + " " + chunk['content']).lower()
|
125 |
|
|
|
126 |
score = sum(content.count(keyword) for keyword in keywords)
|
127 |
|
|
|
128 |
if filter_categories and not any(cat in chunk.get('categories', []) for cat in filter_categories):
|
129 |
continue
|
130 |
|
131 |
keyword_scores[chunk_id] = score
|
132 |
|
|
|
133 |
keyword_results = sorted(
|
134 |
[{'chunk_id': chunk_id, 'score': score, 'chunk': self.chunks[chunk_id]}
|
135 |
for chunk_id, score in keyword_scores.items() if score > 0],
|
|
|
137 |
reverse=True
|
138 |
)[:k]
|
139 |
|
|
|
140 |
seen_ids = set()
|
141 |
combined_results = []
|
142 |
|
|
|
143 |
for result in vector_results:
|
144 |
combined_results.append(result)
|
145 |
seen_ids.add(result['chunk_id'])
|
146 |
|
|
|
147 |
for result in keyword_results:
|
148 |
if result['chunk_id'] not in seen_ids:
|
149 |
combined_results.append(result)
|
150 |
seen_ids.add(result['chunk_id'])
|
151 |
|
|
|
152 |
combined_results = combined_results[:k]
|
153 |
|
|
|
154 |
if combined_results:
|
|
|
155 |
pairs = [(query, result['chunk']['content']) for result in combined_results]
|
156 |
|
|
|
157 |
rerank_scores = self.reranker.predict(pairs)
|
158 |
|
|
|
159 |
for i, score in enumerate(rerank_scores):
|
160 |
combined_results[i]['rerank_score'] = float(score)
|
161 |
|
|
|
162 |
combined_results = sorted(combined_results, key=lambda x: x['rerank_score'], reverse=True)
|
163 |
|
164 |
+
return combined_results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|