import logging
import sys
from typing import List, Dict, Any, Optional

from qdrant_client.http import models as rest
from langchain.schema import Document
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain.retrievers.document_compressors import CrossEncoderReranker

from .utils import getconfig
from .vectorstore_interface import create_vectorstore, VectorStoreInterface, QdrantVectorStore
# Configure logging to be more verbose
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)
# Load configuration
config = getconfig("params.cfg")

# Retriever settings from config
RETRIEVER_TOP_K = int(config.get("retriever", "TOP_K"))
SCORE_THRESHOLD = float(config.get("retriever", "SCORE_THRESHOLD"))

# Reranker settings from config
RERANKER_ENABLED = config.getboolean("reranker", "ENABLED", fallback=False)
RERANKER_MODEL = config.get("reranker", "MODEL_NAME", fallback="cross-encoder/ms-marco-MiniLM-L-6-v2")
RERANKER_TOP_K = int(config.get("reranker", "TOP_K", fallback=5))
RERANKER_TOP_K_SCALE_FACTOR = int(config.get("reranker", "TOP_K_SCALE_FACTOR", fallback=2))
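# A params.cfg shape matching the lookups above (values are illustrative
# placeholders, not the project's actual settings):
#
#   [retriever]
#   TOP_K = 10
#   SCORE_THRESHOLD = 0.5
#
#   [reranker]
#   ENABLED = True
#   MODEL_NAME = cross-encoder/ms-marco-MiniLM-L-6-v2
#   TOP_K = 5
#   TOP_K_SCALE_FACTOR = 2
#
#   [embeddings]
#   MODEL_NAME = BAAI/bge-m3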
# Initialize reranker if enabled
reranker = None
if RERANKER_ENABLED:
    try:
        logging.info(f"Initializing reranker with model: {RERANKER_MODEL}")
        # HuggingFaceCrossEncoder doesn't accept a cache_dir parameter;
        # the underlying models use their default cache locations
        cross_encoder_model = HuggingFaceCrossEncoder(model_name=RERANKER_MODEL)
        logging.info("Cross encoder model loaded successfully")
        reranker = CrossEncoderReranker(model=cross_encoder_model, top_n=RERANKER_TOP_K)
        logging.info("Reranker initialized successfully")
    except Exception as e:
        logging.error(f"Failed to initialize reranker: {e}")
        reranker = None
else:
    logging.info("Reranker is disabled")
def get_vectorstore() -> VectorStoreInterface:
    """
    Create and return a vector store connection.

    Returns:
        VectorStoreInterface instance
    """
    logging.info("Initializing vector store connection...")
    vectorstore = create_vectorstore(config)
    logging.info("Vector store connection initialized successfully")
    return vectorstore
def create_filter(
    filter_metadata: Optional[dict] = None,
) -> Optional[rest.Filter]:
    """
    Create a Qdrant filter based on metadata criteria.

    Args:
        filter_metadata: Mapping of metadata field names to required values.
            A string value becomes an exact-match condition; a list of
            values matches any of its entries.

    Returns:
        Qdrant Filter object, or None if no filters were specified
    """
    if filter_metadata is None:
        return None

    conditions = []
    logging.info(f"Defining filters for {filter_metadata}")
    for key, val in filter_metadata.items():
        if isinstance(val, str):
            # Exact match on a single value
            conditions.append(
                rest.FieldCondition(
                    key=f"metadata.{key}",
                    match=rest.MatchValue(value=val)
                )
            )
        else:
            # Match any value in the list
            conditions.append(
                rest.FieldCondition(
                    key=f"metadata.{key}",
                    match=rest.MatchAny(any=val)
                )
            )
    return rest.Filter(must=conditions)
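# Example (field names and values are illustrative, not fixed by this module):
#
#   create_filter({"source": "IPCC", "year": [2022, 2023]})
#
# builds a Filter whose must-conditions require metadata.source == "IPCC"
# and metadata.year to equal either 2022 or 2023.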
def rerank_documents(query: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Rerank documents using the cross-encoder configured in params.cfg.

    Args:
        query: The search query
        documents: List of documents to rerank

    Returns:
        Reranked list of documents in the original format
    """
    if not reranker or not documents:
        return documents

    try:
        logging.info(f"Starting reranking of {len(documents)} documents")
        # Convert to LangChain Document format using the keys produced by the
        # data storage module (TODO: review this later for portability)
        langchain_docs = []
        for doc in documents:
            content = doc.get("answer", "")
            metadata = doc.get("answer_metadata", {})
            if not content:
                logging.warning(f"Document missing content: {doc}")
                continue
            langchain_docs.append(Document(page_content=content, metadata=metadata))

        if not langchain_docs:
            logging.warning("No valid documents found for reranking")
            return documents

        # Rerank documents
        logging.info(f"Reranking {len(langchain_docs)} documents")
        reranked_docs = reranker.compress_documents(langchain_docs, query)

        # Convert back to the original format
        result = []
        for doc in reranked_docs:
            result.append({
                "answer": doc.page_content,
                "answer_metadata": doc.metadata,
            })
        logging.info(f"Successfully reranked {len(documents)} documents to top {len(result)}")
        return result
    except Exception as e:
        logging.error(f"Error during reranking: {e}")
        # Return original documents if reranking fails
        return documents
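# Input and output share the same shape, e.g. (illustrative values):
#
#   docs = [{"answer": "Operators must...", "answer_metadata": {"source": "regulation"}}]
#   rerank_documents("due diligence obligations", docs)
#
# The reranker scores each (query, answer) pair with the cross-encoder and
# returns at most top_n documents, ordered by descending relevance.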
def get_context(
    vectorstore: VectorStoreInterface,
    query: str,
    collection_name: str = None,
    filter_metadata: Optional[dict] = None,
) -> List[Dict[str, Any]]:
    """
    Retrieve semantically similar documents from the vector database,
    with optional reranking.

    Args:
        vectorstore: The vector store interface to search
        query: The search query
        collection_name: Name of the collection to search
        filter_metadata: Metadata criteria passed to create_filter()

    Returns:
        List of dictionaries with 'answer', 'answer_metadata', and 'score' keys
    """
    try:
        # Retrieve more candidates than finally needed when reranking is
        # enabled, so the reranker has a larger pool to choose from
        top_k = RETRIEVER_TOP_K
        if RERANKER_ENABLED and reranker:
            top_k = top_k * RERANKER_TOP_K_SCALE_FACTOR
            logging.info(f"Reranking enabled, retrieving {top_k} candidates")

        search_kwargs = {
            "model_name": config.get("embeddings", "MODEL_NAME")
        }

        # Filter support for QdrantVectorStore
        if isinstance(vectorstore, QdrantVectorStore):
            filter_obj = create_filter(filter_metadata)
            if filter_obj:
                search_kwargs["filter"] = filter_obj

        # Perform initial retrieval
        logging.info(f"Search kwargs: {search_kwargs}")
        retrieved_docs = vectorstore.search(query, collection_name, top_k, **search_kwargs)
        logging.info(f"Retrieved {len(retrieved_docs)} documents for query: {query[:50]}...")

        # Apply reranking if enabled, then trim to the final desired k
        if RERANKER_ENABLED and reranker and retrieved_docs:
            logging.info("Applying reranking...")
            retrieved_docs = rerank_documents(query, retrieved_docs)
            retrieved_docs = retrieved_docs[:RERANKER_TOP_K]

        logging.info(f"Returning {len(retrieved_docs)} final documents")
        logging.debug(f"Retrieved results: {retrieved_docs}")
        return retrieved_docs
    except Exception as e:
        logging.error(f"Error during retrieval: {e}")
        raise
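# Minimal usage sketch (assumes a configured params.cfg and a populated
# collection; the collection name and filter values below are placeholders):
#
#   vectorstore = get_vectorstore()
#   results = get_context(
#       vectorstore,
#       query="What are the due diligence requirements?",
#       collection_name="my_collection",
#       filter_metadata={"source": "regulation"},
#   )
#   for doc in results:
#       print(doc["answer_metadata"], doc["answer"][:100])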