Spaces:
Build error
Build error
import os | |
import numpy as np | |
import hnswlib | |
from typing import List, Dict, Any | |
from sentence_transformers import SentenceTransformer | |
from rank_bm25 import BM25Okapi | |
from src.config import RetrieverConfig | |
from src.utils import logger | |
class Retriever:
    """
    Hybrid retriever combining BM25 sparse and dense retrieval (no Redis).

    The sparse side is a ``BM25Okapi`` index over whitespace-tokenized
    ``'narration'`` text; the dense side is an ``hnswlib`` cosine index over
    ``SentenceTransformer`` embeddings of the same text. Construction is
    best-effort: if anything fails, the error is logged, ``bm25``,
    ``embedder`` and ``ann`` stay ``None``, and every retrieval method
    returns an empty list instead of raising.
    """

    def __init__(self, chunks: List[Dict[str, Any]], config: "RetrieverConfig"):
        """
        Initialize the retriever with chunks and configuration.

        Args:
            chunks (List[Dict[str, Any]]): List of chunks, where each chunk is a
                dictionary. The text indexed is the value under 'narration'
                (a missing or None value is treated as an empty string).
            config (RetrieverConfig): Configuration for the retriever. Reads
                ``DENSE_MODEL`` and ``ANN_TOP`` here; ``TOP_K`` is read later
                by ``retrieve``.
        """
        self.chunks = chunks
        # Kept so retrieve() can honor the TOP_K of *this* configuration.
        self.config = config
        # Components stay None unless the whole build below succeeds, so a
        # partially built retriever is never exposed.
        self.bm25 = None
        self.embedder = None
        self.ann = None
        try:
            if not isinstance(chunks, list) or not all(isinstance(c, dict) for c in chunks):
                raise ValueError("Chunks must be a list of dicts.")
            if not chunks:
                # An empty corpus cannot be indexed; fail with a clear message
                # instead of an opaque error from the indexing libraries.
                raise ValueError("Chunks must not be empty.")
            # `or ''` also covers chunks whose 'narration' is explicitly None.
            texts = [c.get('narration') or '' for c in chunks]
            bm25 = BM25Okapi([text.split() for text in texts])
            embedder = SentenceTransformer(config.DENSE_MODEL)
            # The embedding width is taken from the corpus embeddings rather
            # than from a separate probe encode.
            embeddings = np.asarray(embedder.encode(texts))
            ann = hnswlib.Index(space='cosine', dim=embeddings.shape[1])
            ann.init_index(max_elements=len(chunks))
            # Item ids are positions in self.chunks, so labels map straight back.
            ann.add_items(embeddings, ids=list(range(len(chunks))))
            ann.set_ef(config.ANN_TOP)
        except Exception as e:
            # Deliberate best-effort: log and leave the retriever disabled.
            logger.error(f"Retriever init failed: {e}")
            return
        self.bm25 = bm25
        self.embedder = embedder
        self.ann = ann

    def retrieve_sparse(self, query: str, top_k: int) -> List[Dict[str, Any]]:
        """
        Retrieve chunks using BM25 sparse retrieval.

        Args:
            query (str): Query string; tokenized on whitespace.
            top_k (int): Number of top chunks to return.

        Returns:
            List[Dict[str, Any]]: Up to ``top_k`` chunks ordered by descending
            BM25 score. Empty if the index is not initialized, ``top_k`` is
            not positive, or scoring fails.
        """
        if self.bm25 is None:
            logger.error("BM25 not initialized.")
            return []
        if top_k <= 0:
            # A negative slice bound would return almost the whole corpus.
            return []
        tokens = query.split()
        try:
            scores = self.bm25.get_scores(tokens)
            ranked = np.argsort(scores)[::-1][:top_k]
            return [self.chunks[int(i)] for i in ranked]
        except Exception as e:
            logger.error(f"Sparse retrieval failed: {e}")
            return []

    def retrieve_dense(self, query: str, top_k: int) -> List[Dict[str, Any]]:
        """
        Retrieve chunks using dense retrieval.

        Args:
            query (str): Query string.
            top_k (int): Number of top chunks to return.

        Returns:
            List[Dict[str, Any]]: Up to ``top_k`` chunks in the order returned
            by the ANN index. Empty if the index is not initialized,
            ``top_k`` is not positive, or the query fails.
        """
        if self.ann is None or self.embedder is None:
            logger.error("Dense retriever not initialized.")
            return []
        # Asking the index for more neighbours than it holds makes the query
        # fail, so never request more than the number of chunks.
        k = min(top_k, len(self.chunks))
        if k <= 0:
            return []
        try:
            query_vec = self.embedder.encode([query])[0]
            labels, _distances = self.ann.knn_query(query_vec, k=k)
            return [self.chunks[int(i)] for i in labels[0]]
        except Exception as e:
            logger.error(f"Dense retrieval failed: {e}")
            return []

    def retrieve(self, query: str, top_k: int = None) -> List[Dict[str, Any]]:
        """
        Retrieve chunks using hybrid retrieval.

        Runs sparse and dense retrieval with the same ``top_k`` and returns
        their union: sparse hits first, then dense hits not already present.
        The result can therefore hold up to ``2 * top_k`` chunks.

        Args:
            query (str): Query string.
            top_k (int, optional): Number of chunks requested from each
                retriever. Defaults to None, meaning ``TOP_K`` from the
                configuration given at construction (falling back to
                ``RetrieverConfig.TOP_K``).

        Returns:
            List[Dict[str, Any]]: De-duplicated list of chunks.
        """
        if top_k is None:
            # Prefer the instance configuration over the class-level default.
            top_k = getattr(getattr(self, "config", None), "TOP_K", None)
            if top_k is None:
                top_k = RetrieverConfig.TOP_K
        sparse_hits = self.retrieve_sparse(query, top_k)
        dense_hits = self.retrieve_dense(query, top_k)
        # Chunks are dicts (unhashable); both retrievers return the very same
        # objects held in self.chunks, so object identity is a valid dedupe key.
        seen_ids = set()
        combined = []
        for chunk in sparse_hits + dense_hits:
            chunk_id = id(chunk)
            if chunk_id not in seen_ids:
                seen_ids.add(chunk_id)
                combined.append(chunk)
        return combined