import os
import pickle

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer, CrossEncoder
from typing import Any, Dict, List, Optional


class VectorStore:
    """FAISS-backed vector store with cross-encoder reranking."""

    def __init__(self,
                 embedding_dir: str = "data/embeddings",
                 model_name: str = "BAAI/bge-small-en-v1.5",
                 reranker_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        self.embedding_dir = embedding_dir
        self.index = None
        self.chunk_ids = []
        self.chunks = {}

        # Bi-encoder for embedding queries, cross-encoder for reranking results.
        self.model = SentenceTransformer(model_name)
        self.reranker = CrossEncoder(reranker_name)

        self.load_or_create_index()

    def load_or_create_index(self) -> None:
        """Load a persisted index if one exists; otherwise build it from embeddings."""
        index_path = os.path.join(self.embedding_dir, 'faiss_index.pkl')

        if os.path.exists(index_path):
            with open(index_path, 'rb') as f:
                data = pickle.load(f)
            # The index is stored as a serialized byte array (see create_index),
            # so it has to be restored into a live FAISS index here.
            self.index = faiss.deserialize_index(data['index'])
            self.chunk_ids = data['chunk_ids']
            self.chunks = data['chunks']
            print(f"Loaded existing index with {len(self.chunk_ids)} chunks")
        else:
            embeddings_path = os.path.join(self.embedding_dir, 'embeddings.pkl')
            if os.path.exists(embeddings_path):
                self.create_index()
            else:
                print("No embeddings found. Please run the chunker first.")

    def create_index(self) -> None:
        """Create a FAISS index from precomputed chunk embeddings."""
        embeddings_path = os.path.join(self.embedding_dir, 'embeddings.pkl')

        with open(embeddings_path, 'rb') as f:
            embedding_map = pickle.load(f)

        chunk_ids = list(embedding_map.keys())
        embeddings = np.array([embedding_map[chunk_id]['embedding'] for chunk_id in chunk_ids])
        chunks = {chunk_id: embedding_map[chunk_id]['chunk'] for chunk_id in chunk_ids}

        # Exact L2 search; FAISS expects float32 vectors.
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings.astype(np.float32))

        self.index = index
        self.chunk_ids = chunk_ids
        self.chunks = chunks

        # Serialize the index into a byte array before pickling so the payload
        # stays portable across FAISS versions (raw index objects are not
        # reliably picklable).
        with open(os.path.join(self.embedding_dir, 'faiss_index.pkl'), 'wb') as f:
            pickle.dump({
                'index': faiss.serialize_index(index),
                'chunk_ids': chunk_ids,
                'chunks': chunks
            }, f)

        print(f"Created index with {len(chunk_ids)} chunks")

    def search(self,
               query: str,
               k: int = 5,
               filter_categories: Optional[List[str]] = None,
               rerank: bool = True) -> List[Dict[str, Any]]:
        """Vector search over the index, with optional category filtering and reranking."""
        if self.index is None:
            print("No index available. Please create an index first.")
            return []

        query_embedding = self.model.encode([query])[0]

        # Over-fetch (2x) so that category filtering still leaves enough candidates.
        D, I = self.index.search(np.array([query_embedding]).astype(np.float32),
                                 min(k * 2, len(self.chunk_ids)))

        results = []
        for i, idx in enumerate(I[0]):
            chunk_id = self.chunk_ids[idx]
            chunk = self.chunks[chunk_id]

            if filter_categories and not any(cat in chunk.get('categories', []) for cat in filter_categories):
                continue

            results.append({
                'chunk_id': chunk_id,
                'score': float(D[0][i]),  # L2 distance: lower is better
                'chunk': chunk
            })

        if rerank and results:
            # Re-score each (query, passage) pair with the cross-encoder and
            # sort by that score instead of raw distance.
            pairs = [(query, result['chunk']['content']) for result in results]
            rerank_scores = self.reranker.predict(pairs)

            for i, score in enumerate(rerank_scores):
                results[i]['rerank_score'] = float(score)

            results = sorted(results, key=lambda x: x['rerank_score'], reverse=True)

        return results[:k]

    def hybrid_search(self,
                      query: str,
                      k: int = 5,
                      filter_categories: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """Combine vector search with simple keyword matching, then rerank the union."""
        vector_results = self.search(query, k=k, filter_categories=filter_categories, rerank=False)

        # Naive keyword scoring: count query-term occurrences in title + content.
        keywords = query.lower().split()
        keyword_scores = {}

        for chunk_id, chunk in self.chunks.items():
            if filter_categories and not any(cat in chunk.get('categories', []) for cat in filter_categories):
                continue

            content = (chunk['title'] + " " + chunk['content']).lower()
            keyword_scores[chunk_id] = sum(content.count(keyword) for keyword in keywords)

        keyword_results = sorted(
            [{'chunk_id': chunk_id, 'score': score, 'chunk': self.chunks[chunk_id]}
             for chunk_id, score in keyword_scores.items() if score > 0],
            key=lambda x: x['score'],
            reverse=True
        )[:k]

        # Merge the two result lists, keeping the first occurrence of each chunk
        # (vector hits take precedence over keyword hits).
        seen_ids = set()
        combined_results = []

        for result in vector_results + keyword_results:
            if result['chunk_id'] not in seen_ids:
                combined_results.append(result)
                seen_ids.add(result['chunk_id'])

        combined_results = combined_results[:k]

        # The raw scores are not comparable (L2 distance vs. term counts), so
        # let the cross-encoder produce the final ordering.
        if combined_results:
            pairs = [(query, result['chunk']['content']) for result in combined_results]
            rerank_scores = self.reranker.predict(pairs)

            for i, score in enumerate(rerank_scores):
                combined_results[i]['rerank_score'] = float(score)

            combined_results = sorted(combined_results, key=lambda x: x['rerank_score'], reverse=True)

        return combined_results
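

# Minimal usage sketch (assumes the chunker has already written
# data/embeddings/embeddings.pkl; the query string below is an illustrative
# placeholder, not a value from the real dataset):
if __name__ == "__main__":
    store = VectorStore()
    for hit in store.hybrid_search("how do I configure the embedding model", k=3):
        print(hit['chunk_id'], f"rerank={hit.get('rerank_score', 0.0):.3f}")
        print(hit['chunk']['content'][:200])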