"""Vector retrieval with optional cross-encoder reranking.

Reads settings from ``params.cfg`` (sections: ``retriever``, ``reranker``,
``embeddings``). The reranker model is loaded eagerly at import time so the
first query pays no warm-up cost; any failure during its initialization is
tolerated and reranking is simply skipped.
"""
import logging
import os
import sys
from typing import List, Dict, Any, Optional

from qdrant_client.http import models as rest
from langchain.schema import Document
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain.retrievers.document_compressors import CrossEncoderReranker
from sentence_transformers import SentenceTransformer

from .utils import getconfig
from .vectorstore_interface import create_vectorstore, VectorStoreInterface, QdrantVectorStore

# Configure logging to be verbose and write to stdout (container-friendly).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)],
)

# NOTE(review): this model is loaded eagerly but never used in this module —
# the embedding model name is passed to the vector store via search_kwargs
# in get_context(). Kept for backward compatibility in case external code
# does `from <module> import model`; confirm and remove if unused.
model = SentenceTransformer('BAAI/bge-m3')

# Load configuration
config = getconfig("params.cfg")

# Retriever settings from config
RETRIEVER_TOP_K = int(config.get("retriever", "TOP_K"))
SCORE_THRESHOLD = float(config.get("retriever", "SCORE_THRESHOLD"))

# Reranker settings from config
RERANKER_ENABLED = config.getboolean("reranker", "ENABLED", fallback=False)
RERANKER_MODEL = config.get("reranker", "MODEL_NAME", fallback="cross-encoder/ms-marco-MiniLM-L-6-v2")
RERANKER_TOP_K = int(config.get("reranker", "TOP_K", fallback=5))
RERANKER_TOP_K_SCALE_FACTOR = int(config.get("reranker", "TOP_K_SCALE_FACTOR", fallback=2))

# Initialize the reranker if enabled. Any failure leaves `reranker` as None,
# which downstream code treats as "reranking disabled" (best-effort setup).
reranker = None
if RERANKER_ENABLED:
    try:
        print(f"Starting reranker initialization with model: {RERANKER_MODEL}", flush=True)
        logging.info("Initializing reranker with model: %s", RERANKER_MODEL)
        print("Loading HuggingFace cross encoder model", flush=True)
        # HuggingFaceCrossEncoder doesn't accept a cache_dir parameter;
        # the underlying models use their default cache locations.
        cross_encoder_model = HuggingFaceCrossEncoder(model_name=RERANKER_MODEL)
        print("Cross encoder model loaded successfully", flush=True)
        print("Creating CrossEncoderReranker...", flush=True)
        reranker = CrossEncoderReranker(model=cross_encoder_model, top_n=RERANKER_TOP_K)
        print("Reranker initialized successfully", flush=True)
        logging.info("Reranker initialized successfully")
    except Exception as e:
        print(f"Failed to initialize reranker: {str(e)}", flush=True)
        logging.error("Failed to initialize reranker: %s", e)
        reranker = None
else:
    print("Reranker is disabled", flush=True)


def get_vectorstore() -> VectorStoreInterface:
    """
    Create and return a vector store connection.

    Returns:
        VectorStoreInterface instance built from the module-level config.
    """
    logging.info("Initializing vector store connection...")
    vectorstore = create_vectorstore(config)
    logging.info("Vector store connection initialized successfully")
    return vectorstore


def create_filter(
    filter_metadata: Optional[dict] = None,
) -> Optional[rest.Filter]:
    """
    Create a Qdrant filter based on metadata criteria.

    Args:
        filter_metadata: Mapping of metadata key -> value. A string value
            becomes an exact-match condition; any other value is treated as
            a collection and becomes a match-any condition. Keys are looked
            up under the ``metadata.`` payload prefix.

    Returns:
        Qdrant Filter object combining all conditions with AND (``must``),
        or None if no criteria were supplied.
    """
    if filter_metadata is None:
        return None

    conditions = []
    logging.info("Defining filters for %s", filter_metadata)
    for key, val in filter_metadata.items():
        if isinstance(val, str):
            conditions.append(
                rest.FieldCondition(
                    key=f"metadata.{key}",
                    match=rest.MatchValue(value=val),
                )
            )
        else:
            conditions.append(
                rest.FieldCondition(
                    key=f"metadata.{key}",
                    match=rest.MatchAny(any=val),
                )
            )

    # Avoid shadowing the builtin `filter`.
    qdrant_filter = rest.Filter(must=conditions)
    return qdrant_filter


def rerank_documents(query: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Rerank documents using the configured cross-encoder (see params.cfg).

    Args:
        query: The search query.
        documents: List of documents to rerank; each dict is expected to
            carry ``answer`` (text) and ``answer_metadata`` keys.

    Returns:
        Reranked list of documents in the original dict format. On any
        failure (or when reranking is disabled) the input list is returned
        unchanged.
    """
    if not reranker or not documents:
        return documents

    try:
        logging.info("Starting reranking of %d documents", len(documents))

        # Convert to LangChain Document format using the keys produced by
        # the data storage layer (review later for portability).
        langchain_docs = []
        for doc in documents:
            content = doc.get("answer", "")
            metadata = doc.get("answer_metadata", {})
            if not content:
                logging.warning("Document missing content: %s", doc)
                continue
            langchain_docs.append(Document(page_content=content, metadata=metadata))

        if not langchain_docs:
            logging.warning("No valid documents found for reranking")
            return documents

        # Rerank documents
        logging.info("Reranking %d documents", len(langchain_docs))
        reranked_docs = reranker.compress_documents(langchain_docs, query)

        # Convert back to the original dict format.
        result = [
            {
                "answer": doc.page_content,
                "answer_metadata": doc.metadata,
            }
            for doc in reranked_docs
        ]

        logging.info("Successfully reranked %d documents to top %d", len(documents), len(result))
        return result

    except Exception as e:
        logging.error("Error during reranking: %s", e)
        # Best-effort: return original documents if reranking fails.
        return documents


def get_context(
    vectorstore: VectorStoreInterface,
    query: str,
    collection_name: str = None,
    filter_metadata=None,
) -> List[Dict[str, Any]]:
    """
    Retrieve semantically similar documents from the vector database,
    with optional cross-encoder reranking.

    Args:
        vectorstore: The vector store interface to search.
        query: The search query.
        collection_name: Target collection to search in.
        filter_metadata: Optional metadata criteria passed to create_filter()
            (only applied for Qdrant-backed stores).

    Returns:
        List of dictionaries with 'answer', 'answer_metadata', and 'score' keys.

    Raises:
        Exception: Re-raises any error from the underlying search.
    """
    try:
        # Retrieve more candidates when reranking is enabled, so the
        # cross-encoder has a wider pool to choose from.
        top_k = RETRIEVER_TOP_K
        if RERANKER_ENABLED and reranker:
            top_k = top_k * RERANKER_TOP_K_SCALE_FACTOR
            logging.info("Reranking enabled, retrieving %d candidates", top_k)

        search_kwargs = {
            "model_name": config.get("embeddings", "MODEL_NAME"),
        }

        # Metadata filtering is only supported for QdrantVectorStore.
        if isinstance(vectorstore, QdrantVectorStore):
            logging.debug("Filter metadata: %s", filter_metadata)
            filter_obj = create_filter(filter_metadata)
            if filter_obj:
                search_kwargs["filter"] = filter_obj

        # Perform initial retrieval
        logging.debug("Search kwargs: %s", search_kwargs)
        retrieved_docs = vectorstore.search(query, collection_name, top_k, **search_kwargs)
        logging.info("Retrieved %d documents for query: %s...", len(retrieved_docs), query[:50])

        # Apply reranking if enabled, then trim to the final desired k.
        if RERANKER_ENABLED and reranker and retrieved_docs:
            logging.info("Applying reranking...")
            retrieved_docs = rerank_documents(query, retrieved_docs)
            retrieved_docs = retrieved_docs[:RERANKER_TOP_K]
            logging.info("Returning %d final documents", len(retrieved_docs))

        logging.info("Retrieved results: %s", retrieved_docs)
        return retrieved_docs

    except Exception as e:
        logging.error("Error during retrieval: %s", e)
        # Bare `raise` preserves the original traceback.
        raise