Document_intelligence / src /retriever.py
Abhinav Gavireddi
fix: fixed bugs in UI
04db7e0
raw
history blame
3.93 kB
import os
import numpy as np
import hnswlib
from typing import List, Dict, Any
from sentence_transformers import SentenceTransformer
from rank_bm25 import BM25Okapi
from src.config import RetrieverConfig
from src.utils import logger
class Retriever:
    """
    Hybrid retriever combining BM25 sparse and dense retrieval (no Redis).

    Sparse retrieval scores whitespace-tokenized chunk narrations with BM25Okapi.
    Dense retrieval embeds the same narrations with a SentenceTransformer and
    searches them in an hnswlib cosine index.

    Initialization is best-effort. Failures are logged and the affected component
    (``bm25`` or ``embedder``/``ann``) is left as ``None``. In that case the
    corresponding ``retrieve_*`` method logs an error and returns an empty list.
    """

    def __init__(self, chunks: List[Dict[str, Any]], config: "RetrieverConfig"):
        """
        Initialize the retriever with chunks and configuration.

        Args:
            chunks (List[Dict[str, Any]]): List of chunks, where each chunk is a
                dictionary. Its 'narration' value is the text that gets indexed.
            config (RetrieverConfig): Configuration for the retriever. DENSE_MODEL
                and ANN_TOP are read here; TOP_K is read by ``retrieve``.
        """
        self.chunks = chunks
        # Kept so retrieve() honours the config this instance was built with.
        self.config = config
        self.bm25 = None
        self.embedder = None
        self.ann = None

        if not isinstance(chunks, list) or not all(isinstance(c, dict) for c in chunks):
            # Best-effort contract: log and leave the retriever disabled
            # rather than raising to the caller.
            logger.error("Chunks must be a list of dicts.")
            return
        if not chunks:
            logger.error("Retriever init failed: no chunks to index.")
            return

        texts = [self._chunk_text(c) for c in chunks]
        # Built independently so that a failure in one backend (e.g. the
        # embedding model cannot be loaded) does not disable the other.
        self._init_sparse(texts)
        self._init_dense(texts, config)

    @staticmethod
    def _chunk_text(chunk: Dict[str, Any]) -> str:
        """Return the chunk's 'narration' as a string ('' when missing or None)."""
        text = chunk.get('narration')
        return '' if text is None else str(text)

    def _init_sparse(self, texts: List[str]) -> None:
        """Build the BM25 index over whitespace-tokenized texts; leave None on failure."""
        try:
            self.bm25 = BM25Okapi([t.split() for t in texts])
        except Exception as e:
            logger.error(f"Retriever init failed (sparse): {e}")
            self.bm25 = None

    def _init_dense(self, texts: List[str], config: "RetrieverConfig") -> None:
        """Embed texts and build the hnswlib cosine index; leave None on failure."""
        try:
            embedder = SentenceTransformer(config.DENSE_MODEL)
            embeddings = np.asarray(embedder.encode(texts))
            # encode() on a non-empty list yields one row per text, so the
            # embedding width is the second dimension.
            ann = hnswlib.Index(space='cosine', dim=embeddings.shape[1])
            ann.init_index(max_elements=len(texts))
            # Item ids are positions in self.chunks, used to map hits back.
            ann.add_items(embeddings, ids=list(range(len(texts))))
            ann.set_ef(config.ANN_TOP)
        except Exception as e:
            logger.error(f"Retriever init failed (dense): {e}")
            self.embedder = None
            self.ann = None
        else:
            self.embedder = embedder
            self.ann = ann

    def retrieve_sparse(self, query: str, top_k: int) -> List[Dict[str, Any]]:
        """
        Retrieve chunks using BM25 sparse retrieval.

        Args:
            query (str): Query string (tokenized on whitespace).
            top_k (int): Number of top chunks to return. Values <= 0 yield [].

        Returns:
            List[Dict[str, Any]]: Up to ``top_k`` chunks, highest BM25 score first.
        """
        if self.bm25 is None:
            logger.error("BM25 not initialized.")
            return []
        if top_k <= 0:
            # A negative slice bound would otherwise return "all but the last N".
            return []
        try:
            scores = self.bm25.get_scores(query.split())
            top_indices = np.argsort(scores)[::-1][:top_k]
            return [self.chunks[int(i)] for i in top_indices]
        except Exception as e:
            logger.error(f"Sparse retrieval failed: {e}")
            return []

    def retrieve_dense(self, query: str, top_k: int) -> List[Dict[str, Any]]:
        """
        Retrieve chunks using dense retrieval.

        Args:
            query (str): Query string.
            top_k (int): Number of top chunks to return. It is capped at the
                number of indexed chunks. Values <= 0 yield [].

        Returns:
            List[Dict[str, Any]]: Up to ``top_k`` chunks, nearest first.
        """
        if self.ann is None or self.embedder is None:
            logger.error("Dense retriever not initialized.")
            return []
        try:
            # hnswlib's knn_query raises if asked for more neighbours than
            # there are indexed items, so cap k at the index size.
            k = min(top_k, self.ann.get_current_count())
            if k <= 0:
                return []
            q_emb = self.embedder.encode([query])[0]
            labels, _distances = self.ann.knn_query(q_emb, k=k)
            return [self.chunks[int(i)] for i in labels[0]]
        except Exception as e:
            logger.error(f"Dense retrieval failed: {e}")
            return []

    def retrieve(self, query: str, top_k: int = None) -> List[Dict[str, Any]]:
        """
        Retrieve chunks using hybrid retrieval.

        Args:
            query (str): Query string.
            top_k (int, optional): Number of top chunks to request from each
                retriever. Defaults to the TOP_K of the config given at init.

        Returns:
            List[Dict[str, Any]]: Sparse hits first, then dense hits not already
            present (deduplicated by object identity). The result may therefore
            hold up to ``2 * top_k`` chunks.
        """
        if top_k is None:
            top_k = self.config.TOP_K
        candidates = self.retrieve_sparse(query, top_k) + self.retrieve_dense(query, top_k)
        seen = set()
        combined = []
        for chunk in candidates:
            # Both retrievers return the same dict objects from self.chunks,
            # so identity is a reliable duplicate key.
            key = id(chunk)
            if key not in seen:
                seen.add(key)
                combined.append(chunk)
        return combined