AUMREDKA committed (verified)
Commit 999388b · 1 Parent(s): 3673d92

Update buffalo_rag/vector_store/db.py

Files changed (1)
  1. buffalo_rag/vector_store/db.py +4 -53
buffalo_rag/vector_store/db.py CHANGED
@@ -17,21 +17,15 @@ class VectorStore:
         self.chunk_ids = []
         self.chunks = {}
 
-        # Load embedding model
         self.model = SentenceTransformer(model_name)
-
-        # Load reranker model
         self.reranker = CrossEncoder(reranker_name)
 
-        # Load or create index
         self.load_or_create_index()
 
     def load_or_create_index(self) -> None:
-        """Load existing index or create a new one."""
         index_path = os.path.join(self.embedding_dir, 'faiss_index.pkl')
 
         if os.path.exists(index_path):
-            # Load existing index
             with open(index_path, 'rb') as f:
                 data = pickle.load(f)
                 self.index = data['index']
@@ -39,7 +33,6 @@ class VectorStore:
                 self.chunks = data['chunks']
                 print(f"Loaded existing index with {len(self.chunk_ids)} chunks")
         else:
-            # Create new index
             embeddings_path = os.path.join(self.embedding_dir, 'embeddings.pkl')
             if os.path.exists(embeddings_path):
                 self.create_index()
@@ -53,22 +46,18 @@ class VectorStore:
         with open(embeddings_path, 'rb') as f:
             embedding_map = pickle.load(f)
 
-        # Extract embeddings and chunk IDs
         chunk_ids = list(embedding_map.keys())
         embeddings = np.array([embedding_map[chunk_id]['embedding'] for chunk_id in chunk_ids])
         chunks = {chunk_id: embedding_map[chunk_id]['chunk'] for chunk_id in chunk_ids}
 
-        # Create FAISS index
         dimension = embeddings.shape[1]
         index = faiss.IndexFlatL2(dimension)
         index.add(embeddings.astype(np.float32))
 
-        # Save index and metadata
         self.index = index
         self.chunk_ids = chunk_ids
         self.chunks = chunks
 
-        # Save to disk
         with open(os.path.join(self.embedding_dir, 'faiss_index.pkl'), 'wb') as f:
             pickle.dump({
                 'index': index,
@@ -83,24 +72,20 @@ class VectorStore:
                k: int = 5,
                filter_categories: Optional[List[str]] = None,
                rerank: bool = True) -> List[Dict[str, Any]]:
-        """Search for relevant chunks."""
+
         if self.index is None:
             print("No index available. Please create an index first.")
             return []
 
-        # Create query embedding
         query_embedding = self.model.encode([query])[0]
 
-        # Search index
         D, I = self.index.search(np.array([query_embedding]).astype(np.float32), min(k * 2, len(self.chunk_ids)))
 
-        # Get results
         results = []
         for i, idx in enumerate(I[0]):
             chunk_id = self.chunk_ids[idx]
             chunk = self.chunks[chunk_id]
 
-            # Apply category filter if specified
             if filter_categories and not any(cat in chunk.get('categories', []) for cat in filter_categories):
                 continue
 
@@ -111,22 +96,16 @@ class VectorStore:
             }
             results.append(result)
 
-        # Rerank results if requested
         if rerank and results:
-            # Prepare pairs for reranking
             pairs = [(query, result['chunk']['content']) for result in results]
-
-            # Get reranking scores
+
             rerank_scores = self.reranker.predict(pairs)
 
-            # Update scores and sort
             for i, score in enumerate(rerank_scores):
                 results[i]['rerank_score'] = float(score)
 
-            # Sort by rerank score
             results = sorted(results, key=lambda x: x['rerank_score'], reverse=True)
 
-            # Limit to k results
             results = results[:k]
 
         return results
@@ -135,29 +114,22 @@ class VectorStore:
                       query: str,
                       k: int = 5,
                       filter_categories: Optional[List[str]] = None) -> List[Dict[str, Any]]:
-        """Combine dense vector search with BM25-style keyword matching."""
-        # Get vector search results
         vector_results = self.search(query, k=k, filter_categories=filter_categories, rerank=False)
 
-        # Simple keyword matching (simulating BM25)
         keywords = query.lower().split()
-
-        # Score all chunks by keyword presence
         keyword_scores = {}
+
         for chunk_id, chunk_data in self.chunks.items():
             chunk = chunk_data
             content = (chunk['title'] + " " + chunk['content']).lower()
 
-            # Count keyword matches
             score = sum(content.count(keyword) for keyword in keywords)
 
-            # Apply category filter if specified
             if filter_categories and not any(cat in chunk.get('categories', []) for cat in filter_categories):
                 continue
 
             keyword_scores[chunk_id] = score
 
-        # Get top keyword matches
         keyword_results = sorted(
             [{'chunk_id': chunk_id, 'score': score, 'chunk': self.chunks[chunk_id]}
              for chunk_id, score in keyword_scores.items() if score > 0],
@@ -165,49 +137,28 @@ class VectorStore:
             reverse=True
         )[:k]
 
-        # Combine results (remove duplicates)
         seen_ids = set()
         combined_results = []
 
-        # Add vector results first
         for result in vector_results:
            combined_results.append(result)
            seen_ids.add(result['chunk_id'])
 
-        # Add keyword results if not already added
         for result in keyword_results:
            if result['chunk_id'] not in seen_ids:
                combined_results.append(result)
                seen_ids.add(result['chunk_id'])
 
-        # Limit to k results
         combined_results = combined_results[:k]
 
-        # Rerank final results
         if combined_results:
-            # Prepare pairs for reranking
            pairs = [(query, result['chunk']['content']) for result in combined_results]
 
-            # Get reranking scores
            rerank_scores = self.reranker.predict(pairs)
 
-            # Update scores and sort
            for i, score in enumerate(rerank_scores):
                combined_results[i]['rerank_score'] = float(score)
 
-            # Sort by rerank score
            combined_results = sorted(combined_results, key=lambda x: x['rerank_score'], reverse=True)
 
-        return combined_results
-
-# Example usage
-if __name__ == "__main__":
-    vector_store = VectorStore()
-    results = vector_store.hybrid_search("How do I apply for OPT?")
-
-    print(f"Found {len(results)} results")
-    for i, result in enumerate(results[:3]):
-        print(f"Result {i+1}: {result['chunk']['title']}")
-        print(f"Score: {result.get('rerank_score', result['score'])}")
-        print(f"Content: {result['chunk']['content'][:100]}...")
-        print()
+        return combined_results
 
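For reference, the example-usage block this commit deletes exercised the class as follows. A minimal sketch, no longer part of the module after this commit: it assumes embeddings.pkl already exists under the store's embedding directory so load_or_create_index() can build or load the FAISS index, and that the default embedding and reranker models are available.

from buffalo_rag.vector_store.db import VectorStore

# Instantiate with the class defaults (embedding model, cross-encoder
# reranker, embedding directory); the FAISS index is loaded or created
# during construction.
vector_store = VectorStore()

# Dense retrieval combined with keyword matching, reranked by the
# cross-encoder.
results = vector_store.hybrid_search("How do I apply for OPT?")

print(f"Found {len(results)} results")
for i, result in enumerate(results[:3]):
    print(f"Result {i+1}: {result['chunk']['title']}")
    print(f"Score: {result.get('rerank_score', result['score'])}")
    print(f"Content: {result['chunk']['content'][:100]}...")
    print()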