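"""FAISS-backed vector store for chunk retrieval.

Combines a SentenceTransformer bi-encoder for dense search, a CrossEncoder for
reranking, and simple keyword matching for hybrid search over pre-computed
chunk embeddings produced by the chunking step.
"""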
import os
import pickle
import json
import numpy as np
from typing import List, Dict, Any, Optional, Tuple
import faiss
from tqdm import tqdm
from sentence_transformers import SentenceTransformer, CrossEncoder


class VectorStore:
    def __init__(self,
                 embedding_dir: str = "data/embeddings",
                 model_name: str = "BAAI/bge-small-en-v1.5",
                 reranker_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
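        """Load the bi-encoder, the cross-encoder reranker, and the FAISS index."""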
        self.embedding_dir = embedding_dir
        self.index = None
        self.chunk_ids = []
        self.chunks = {}
        # Load embedding model
        self.model = SentenceTransformer(model_name)
        # Load reranker model
        self.reranker = CrossEncoder(reranker_name)
        # Load or create index
        self.load_or_create_index()

    def load_or_create_index(self) -> None:
        """Load existing index or create a new one."""
        index_path = os.path.join(self.embedding_dir, 'faiss_index.pkl')
        if os.path.exists(index_path):
            # Load existing index
            with open(index_path, 'rb') as f:
                data = pickle.load(f)
            self.index = data['index']
            self.chunk_ids = data['chunk_ids']
            self.chunks = data['chunks']
            print(f"Loaded existing index with {len(self.chunk_ids)} chunks")
        else:
            # Create new index
            embeddings_path = os.path.join(self.embedding_dir, 'embeddings.pkl')
            if os.path.exists(embeddings_path):
                self.create_index()
            else:
                print("No embeddings found. Please run the chunker first.")

    def create_index(self) -> None:
        """Create FAISS index from embeddings."""
        embeddings_path = os.path.join(self.embedding_dir, 'embeddings.pkl')
        with open(embeddings_path, 'rb') as f:
            embedding_map = pickle.load(f)
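        # embedding_map is assumed to be {chunk_id: {'embedding': vector, 'chunk': metadata}},
        # i.e. the layout the chunker writes (inferred from the lookups below).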
        # Extract embeddings and chunk IDs
        chunk_ids = list(embedding_map.keys())
        embeddings = np.array([embedding_map[chunk_id]['embedding'] for chunk_id in chunk_ids])
        chunks = {chunk_id: embedding_map[chunk_id]['chunk'] for chunk_id in chunk_ids}
        # Create FAISS index
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings.astype(np.float32))
        # Save index and metadata
        self.index = index
        self.chunk_ids = chunk_ids
        self.chunks = chunks
        # Save to disk
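        # Note: some FAISS builds cannot pickle a raw Index object directly; if this
        # dump fails, faiss.serialize_index / faiss.deserialize_index can be used to
        # convert the index to and from a picklable numpy array instead.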
        with open(os.path.join(self.embedding_dir, 'faiss_index.pkl'), 'wb') as f:
            pickle.dump({
                'index': index,
                'chunk_ids': chunk_ids,
                'chunks': chunks
            }, f)
        print(f"Created index with {len(chunk_ids)} chunks")

    def search(self,
               query: str,
               k: int = 5,
               filter_categories: Optional[List[str]] = None,
               rerank: bool = True) -> List[Dict[str, Any]]:
        """Search for relevant chunks."""
        if self.index is None:
            print("No index available. Please create an index first.")
            return []
        # Create query embedding
        query_embedding = self.model.encode([query])[0]
        # Search index
        D, I = self.index.search(np.array([query_embedding]).astype(np.float32), min(k * 2, len(self.chunk_ids)))
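        # D holds squared L2 distances (lower = more similar) and I holds row positions
        # into chunk_ids; 2*k candidates are fetched to leave headroom for the category
        # filter and the reranking step below.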
        # Get results
        results = []
        for i, idx in enumerate(I[0]):
            chunk_id = self.chunk_ids[idx]
            chunk = self.chunks[chunk_id]
            # Apply category filter if specified
            if filter_categories and not any(cat in chunk.get('categories', []) for cat in filter_categories):
                continue
            result = {
                'chunk_id': chunk_id,
                'score': float(D[0][i]),
                'chunk': chunk
            }
            results.append(result)
        # Rerank results if requested
        if rerank and results:
            # Prepare pairs for reranking
            pairs = [(query, result['chunk']['content']) for result in results]
            # Get reranking scores
            rerank_scores = self.reranker.predict(pairs)
            # Update scores and sort
            for i, score in enumerate(rerank_scores):
                results[i]['rerank_score'] = float(score)
            # Sort by rerank score
            results = sorted(results, key=lambda x: x['rerank_score'], reverse=True)
        # Limit to k results
        results = results[:k]
        return results

    def hybrid_search(self,
                      query: str,
                      k: int = 5,
                      filter_categories: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """Combine dense vector search with BM25-style keyword matching."""
        # Get vector search results
        vector_results = self.search(query, k=k, filter_categories=filter_categories, rerank=False)
        # Simple keyword matching (simulating BM25)
        keywords = query.lower().split()
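        # The keyword scoring below is plain term-frequency counting over title + content;
        # unlike true BM25 it applies no IDF weighting or length normalization.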
        # Score all chunks by keyword presence
        keyword_scores = {}
        for chunk_id, chunk in self.chunks.items():
            # Apply category filter if specified
            if filter_categories and not any(cat in chunk.get('categories', []) for cat in filter_categories):
                continue
            content = (chunk['title'] + " " + chunk['content']).lower()
            # Count keyword matches
            score = sum(content.count(keyword) for keyword in keywords)
            keyword_scores[chunk_id] = score
        # Get top keyword matches
        keyword_results = sorted(
            [{'chunk_id': chunk_id, 'score': score, 'chunk': self.chunks[chunk_id]}
             for chunk_id, score in keyword_scores.items() if score > 0],
            key=lambda x: x['score'],
            reverse=True
        )[:k]
        # Combine results (remove duplicates)
        seen_ids = set()
        combined_results = []
        # Add vector results first
        for result in vector_results:
            combined_results.append(result)
            seen_ids.add(result['chunk_id'])
        # Add keyword results if not already added
        for result in keyword_results:
            if result['chunk_id'] not in seen_ids:
                combined_results.append(result)
                seen_ids.add(result['chunk_id'])
        # Limit to k results
        combined_results = combined_results[:k]
        # Rerank final results
        if combined_results:
            # Prepare pairs for reranking
            pairs = [(query, result['chunk']['content']) for result in combined_results]
            # Get reranking scores
            rerank_scores = self.reranker.predict(pairs)
            # Update scores and sort
            for i, score in enumerate(rerank_scores):
                combined_results[i]['rerank_score'] = float(score)
            # Sort by rerank score
            combined_results = sorted(combined_results, key=lambda x: x['rerank_score'], reverse=True)
        return combined_results


# Example usage
if __name__ == "__main__":
    vector_store = VectorStore()
    results = vector_store.hybrid_search("How do I apply for OPT?")
    print(f"Found {len(results)} results")
    for i, result in enumerate(results[:3]):
        print(f"Result {i+1}: {result['chunk']['title']}")
        print(f"Score: {result.get('rerank_score', result['score'])}")
        print(f"Content: {result['chunk']['content'][:100]}...")
        print()