# utils/retriever.py — last update by ppsingh (commit 69ffcc3, verified)
from typing import List, Dict, Any, Optional
from qdrant_client.http import models as rest
from langchain.schema import Document
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain.retrievers.document_compressors import CrossEncoderReranker
from sentence_transformers import SentenceTransformer
# Eagerly loads the BGE-M3 embedding model at import time (slow, memory-heavy).
# NOTE(review): `model` is never referenced in this module — retrieval passes
# the embeddings model name from params.cfg to the vector store instead.
# Presumably imported from elsewhere or dead code; confirm before removing.
model = SentenceTransformer('BAAI/bge-m3')
import logging
import os
from .utils import getconfig
from .vectorstore_interface import create_vectorstore, VectorStoreInterface, QdrantVectorStore
import sys
# Configure logging to be more verbose
# Import-time root-logger setup: INFO level, emitted to stdout so records are
# visible in container / hosted-Space logs.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)
# Load configuration
config = getconfig("params.cfg")
# Retriever settings from config
# TOP_K and SCORE_THRESHOLD are required keys — a missing key raises at import.
RETRIEVER_TOP_K = int(config.get("retriever", "TOP_K"))
# NOTE(review): SCORE_THRESHOLD is not referenced in this module — confirm it
# is consumed downstream.
SCORE_THRESHOLD = float(config.get("retriever", "SCORE_THRESHOLD"))
# Reranker settings from config (all optional, with safe fallbacks)
RERANKER_ENABLED = config.getboolean("reranker", "ENABLED", fallback=False)
RERANKER_MODEL = config.get("reranker", "MODEL_NAME", fallback="cross-encoder/ms-marco-MiniLM-L-6-v2")
RERANKER_TOP_K = int(config.get("reranker", "TOP_K", fallback=5))
# Multiplier applied to TOP_K for the initial retrieval when reranking is on,
# so the cross-encoder has a larger candidate pool to pick the final top-k from.
RERANKER_TOP_K_SCALE_FACTOR = int(config.get("reranker", "TOP_K_SCALE_FACTOR", fallback=2))
# Initialize reranker if enabled
# Import-time, best-effort: any failure leaves `reranker` as None and retrieval
# proceeds without reranking (see get_context / rerank_documents).
reranker = None
if RERANKER_ENABLED:
    try:
        # print(..., flush=True) is used alongside logging so startup progress
        # stays visible even if the host app reconfigures the root logger.
        print(f"Starting reranker initialization with model: {RERANKER_MODEL}", flush=True)
        logging.info(f"Initializing reranker with model: {RERANKER_MODEL}")
        print("Loading HuggingFace cross encoder model", flush=True)
        # HuggingFaceCrossEncoder doesn't accept cache_dir parameter
        # The underlying models will use default cache locations
        cross_encoder_model = HuggingFaceCrossEncoder(model_name=RERANKER_MODEL)
        print("Cross encoder model loaded successfully", flush=True)
        print("Creating CrossEncoderReranker...", flush=True)
        # top_n bounds how many documents compress_documents returns.
        reranker = CrossEncoderReranker(model=cross_encoder_model, top_n=RERANKER_TOP_K)
        print("Reranker initialized successfully", flush=True)
        logging.info("Reranker initialized successfully")
    except Exception as e:
        print(f"Failed to initialize reranker: {str(e)}", flush=True)
        logging.error(f"Failed to initialize reranker: {str(e)}")
        reranker = None
else:
    print("Reranker is disabled", flush=True)
def get_vectorstore() -> VectorStoreInterface:
    """
    Build a vector store connection from the module-level configuration.

    Returns:
        A ready-to-use VectorStoreInterface instance.
    """
    logging.info("Initializing vector store connection...")
    store = create_vectorstore(config)
    logging.info("Vector store connection initialized successfully")
    return store
def create_filter(
    filter_metadata: Optional[Dict[str, Any]] = None,
) -> Optional[rest.Filter]:
    """
    Create a Qdrant payload filter from a metadata criteria mapping.

    Args:
        filter_metadata: Mapping of metadata field name -> required value.
            A string value becomes an exact-match condition; any other value
            is treated as a collection of allowed values (MatchAny).

    Returns:
        A Qdrant Filter AND-combining all conditions (``must``), or None when
        no criteria were supplied.
    """
    # Treat None and an empty mapping the same: no filtering. (The original
    # built a vacuous Filter(must=[]) for an empty dict.)
    if not filter_metadata:
        return None
    logging.info(f"Defining filters for {filter_metadata}")
    conditions = []
    for key, val in filter_metadata.items():
        # Payload fields are nested under "metadata." in the stored points.
        field_key = f"metadata.{key}"
        if isinstance(val, str):
            match = rest.MatchValue(value=val)
        else:
            # Assumes non-string values are lists of allowed values —
            # TODO confirm callers never pass a bare scalar (e.g. int) here.
            match = rest.MatchAny(any=val)
        conditions.append(rest.FieldCondition(key=field_key, match=match))
    # Avoid shadowing the builtin `filter`.
    return rest.Filter(must=conditions)
def rerank_documents(query: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Rerank candidate documents with the module-level cross-encoder reranker.

    Args:
        query: The search query.
        documents: Candidates as dicts carrying "answer" (text) and
            "answer_metadata" (dict) keys.

    Returns:
        Reranked documents in the same dict format. The input list is
        returned unchanged when the reranker is unavailable, the input is
        empty, or reranking raises.
    """
    if not reranker or not documents:
        return documents
    try:
        logging.info(f"Starting reranking of {len(documents)} documents")
        # Wrap each payload dict in a LangChain Document; entries without
        # content cannot be scored and are skipped with a warning.
        langchain_docs = []
        for doc in documents:
            content = doc.get("answer", "")
            if not content:
                logging.warning(f"Document missing content: {doc}")
                continue
            langchain_docs.append(
                Document(page_content=content, metadata=doc.get("answer_metadata", {}))
            )
        if not langchain_docs:
            logging.warning("No valid documents found for reranking")
            return documents
        logging.info(f"Reranking {len(langchain_docs)} documents")
        reranked_docs = reranker.compress_documents(langchain_docs, query)
        # Map back to the original dict shape. NOTE(review): extra keys on the
        # input dicts (e.g. "score") are not carried through — confirm callers
        # don't rely on them after reranking.
        result = [
            {"answer": d.page_content, "answer_metadata": d.metadata}
            for d in reranked_docs
        ]
        logging.info(f"Successfully reranked {len(documents)} documents to top {len(result)}")
        return result
    except Exception as e:
        logging.error(f"Error during reranking: {str(e)}")
        # Return original documents if reranking fails
        return documents
def get_context(
    vectorstore: VectorStoreInterface,
    query: str,
    collection_name: Optional[str] = None,
    filter_metadata: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
    """
    Retrieve semantically similar documents from the vector database, with
    optional cross-encoder reranking.

    Args:
        vectorstore: The vector store interface to search.
        query: The search query.
        collection_name: Collection to search; the store's default applies
            when None.
        filter_metadata: Optional mapping of metadata field -> value(s) used
            to build a Qdrant payload filter (Qdrant-backed stores only).

    Returns:
        List of dictionaries with 'answer' and 'answer_metadata' keys (plus
        any store-provided fields such as 'score' when reranking is off —
        reranking rebuilds the dicts and drops extra keys).

    Raises:
        Exception: Re-raises any error from the underlying search.
    """
    try:
        # Over-fetch candidates when reranking so the cross-encoder has a
        # larger pool to select the final top-k from.
        top_k = RETRIEVER_TOP_K
        if RERANKER_ENABLED and reranker:
            top_k = top_k * RERANKER_TOP_K_SCALE_FACTOR
            logging.info(f"Reranking enabled, retrieving {top_k} candidates")
        search_kwargs = {
            "model_name": config.get("embeddings", "MODEL_NAME")
        }
        # Metadata filtering is only supported for Qdrant-backed stores.
        if isinstance(vectorstore, QdrantVectorStore):
            filter_obj = create_filter(filter_metadata)
            if filter_obj:
                search_kwargs["filter"] = filter_obj
        # Perform initial retrieval
        retrieved_docs = vectorstore.search(query, collection_name, top_k, **search_kwargs)
        logging.info(f"Retrieved {len(retrieved_docs)} documents for query: {query[:50]}...")
        # Apply reranking if enabled, then trim to the final desired k.
        if RERANKER_ENABLED and reranker and retrieved_docs:
            logging.info("Applying reranking...")
            retrieved_docs = rerank_documents(query, retrieved_docs)
            retrieved_docs = retrieved_docs[:RERANKER_TOP_K]
            logging.info(f"Returning {len(retrieved_docs)} final documents")
        # Full result dump is noisy — keep it at DEBUG level.
        logging.debug(f"Retrieved results: {retrieved_docs}")
        return retrieved_docs
    except Exception as e:
        logging.error(f"Error during retrieval: {str(e)}")
        # Bare raise preserves the original traceback for the caller.
        raise