import os
from typing import Any, Dict, List, Optional

import hnswlib
import numpy as np
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer

from src.config import RetrieverConfig
from src.utils import logger


class Retriever:
    """
    Hybrid retriever combining BM25 sparse and dense retrieval (no Redis).

    Sparse ranking uses ``BM25Okapi`` over the whitespace-tokenized
    ``'narration'`` text of each chunk; dense ranking uses
    ``SentenceTransformer`` embeddings stored in an ``hnswlib`` cosine index.
    Construction never raises: if any part of the setup fails, the error is
    logged and every retrieval method then returns an empty list.
    """

    def __init__(self, chunks: List[Dict[str, Any]], config: "RetrieverConfig"):
        """
        Initialize the retriever with chunks and configuration.

        Args:
            chunks (List[Dict[str, Any]]): List of chunks, where each chunk is a
                dictionary. The text under the 'narration' key is indexed; a
                missing or None narration is indexed as an empty string.
            config (RetrieverConfig): Configuration for the retriever. Reads
                ``DENSE_MODEL`` (embedding model name) and ``ANN_TOP`` (hnswlib
                ``ef`` value) here; ``TOP_K`` is read later by ``retrieve``.
        """
        self.chunks = chunks
        # Kept so retrieve() can honour this instance's TOP_K.
        self.config = config
        # Start in the "not initialized" state; the real objects are attached
        # only once *all* of them were built, so a failure half-way through
        # never leaves a partially usable retriever behind.
        self.bm25 = None
        self.embedder = None
        self.ann = None
        try:
            if not isinstance(chunks, list) or not all(isinstance(c, dict) for c in chunks):
                raise ValueError("Chunks must be a list of dicts.")

            # `or ''` also covers chunks whose narration is explicitly None.
            texts = [c.get('narration') or '' for c in chunks]

            bm25 = BM25Okapi([text.split() for text in texts])

            embedder = SentenceTransformer(config.DENSE_MODEL)
            # Probe the model once to learn the embedding dimensionality.
            dim = len(embedder.encode(["test"])[0])

            ann = hnswlib.Index(space='cosine', dim=dim)
            ann.init_index(max_elements=len(chunks))
            # Item ids are the positions in self.chunks, so labels returned by
            # knn_query can be used directly as list indices.
            ann.add_items(embedder.encode(texts), ids=list(range(len(chunks))))
            ann.set_ef(config.ANN_TOP)
        except Exception as e:
            # Deliberate best-effort: log and leave the retriever disabled
            # instead of propagating the error to the caller.
            logger.error(f"Retriever init failed: {e}")
        else:
            self.bm25 = bm25
            self.embedder = embedder
            self.ann = ann

    def retrieve_sparse(self, query: str, top_k: int) -> List[Dict[str, Any]]:
        """
        Retrieve chunks using BM25 sparse retrieval.

        Args:
            query (str): Query string; tokenized by whitespace.
            top_k (int): Maximum number of chunks to return. Values <= 0 yield
                an empty list.

        Returns:
            List[Dict[str, Any]]: Chunks ordered by descending BM25 score, or an
            empty list if BM25 is not initialized or scoring fails.
        """
        if self.bm25 is None:
            logger.error("BM25 not initialized.")
            return []
        # A negative top_k would otherwise slice from the end ([:-1] keeps all
        # but the last item). None is left alone and slices to the full ranking.
        if top_k is not None and top_k <= 0:
            return []
        try:
            scores = self.bm25.get_scores(query.split())
            ranked = np.argsort(scores)[::-1][:top_k]
            return [self.chunks[i] for i in ranked]
        except Exception as e:
            logger.error(f"Sparse retrieval failed: {e}")
            return []

    def retrieve_dense(self, query: str, top_k: int) -> List[Dict[str, Any]]:
        """
        Retrieve chunks using dense retrieval.

        Args:
            query (str): Query string; embedded with the configured model.
            top_k (int): Maximum number of chunks to return. Clamped to the
                number of chunks; values <= 0 yield an empty list.

        Returns:
            List[Dict[str, Any]]: Chunks in the order returned by the ANN index,
            or an empty list if the dense retriever is not initialized or the
            query fails.
        """
        if self.ann is None or self.embedder is None:
            logger.error("Dense retriever not initialized.")
            return []
        try:
            # hnswlib raises when asked for more neighbours than the index
            # holds, which used to turn any large top_k into an empty result.
            k = min(top_k, len(self.chunks))
            if k <= 0:
                return []
            query_vec = self.embedder.encode([query])[0]
            labels, _distances = self.ann.knn_query(query_vec, k=k)
            # labels has one row per query vector; only one query was sent.
            return [self.chunks[int(label)] for label in labels[0]]
        except Exception as e:
            logger.error(f"Dense retrieval failed: {e}")
            return []

    def retrieve(self, query: str, top_k: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Retrieve chunks using hybrid retrieval.

        Runs sparse and dense retrieval with the same ``top_k`` and returns the
        sparse hits followed by the dense hits, dropping chunks that were
        already returned. The result can therefore hold up to ``2 * top_k``
        chunks.

        Args:
            query (str): Query string.
            top_k (int, optional): Number of chunks requested from each
                retriever. Defaults to None, in which case ``TOP_K`` of the
                config given at construction is used, falling back to
                ``RetrieverConfig.TOP_K``.

        Returns:
            List[Dict[str, Any]]: De-duplicated list of retrieved chunks.
        """
        if top_k is None:
            top_k = getattr(self.config, "TOP_K", RetrieverConfig.TOP_K)

        sparse_hits = self.retrieve_sparse(query, top_k)
        dense_hits = self.retrieve_dense(query, top_k)

        # Both retrievers hand back the very objects stored in self.chunks, so
        # object identity is enough to detect a chunk returned by both (dicts
        # are unhashable and cannot go into a set themselves).
        seen_ids = set()
        merged = []
        for chunk in sparse_hits + dense_hits:
            marker = id(chunk)
            if marker in seen_ids:
                continue
            seen_ids.add(marker)
            merged.append(chunk)
        return merged