"""Retrieval module: semantic search over a vector store with optional
cross-encoder reranking. All settings are read from params.cfg."""

import logging
import sys
from typing import List, Dict, Any, Optional

from qdrant_client.http import models as rest
from langchain.schema import Document
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain.retrievers.document_compressors import CrossEncoderReranker

from .utils import getconfig
from .vectorstore_interface import create_vectorstore, VectorStoreInterface, QdrantVectorStore

# Configure logging to be more verbose
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)

# Load configuration
config = getconfig("params.cfg")

# Retriever settings from config
RETRIEVER_TOP_K = int(config.get("retriever", "TOP_K"))
SCORE_THRESHOLD = float(config.get("retriever", "SCORE_THRESHOLD"))

# Reranker settings from config
RERANKER_ENABLED = config.getboolean("reranker", "ENABLED", fallback=False)
RERANKER_MODEL = config.get("reranker", "MODEL_NAME", fallback="cross-encoder/ms-marco-MiniLM-L-6-v2")
RERANKER_TOP_K = int(config.get("reranker", "TOP_K", fallback=5))
RERANKER_TOP_K_SCALE_FACTOR = int(config.get("reranker", "TOP_K_SCALE_FACTOR", fallback=2))
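
# Example (illustrative) of the matching params.cfg layout. The section and
# option names mirror the lookups above; the values shown are assumptions.
#
#   [embeddings]
#   MODEL_NAME = BAAI/bge-m3
#
#   [retriever]
#   TOP_K = 10
#   SCORE_THRESHOLD = 0.5
#
#   [reranker]
#   ENABLED = True
#   MODEL_NAME = cross-encoder/ms-marco-MiniLM-L-6-v2
#   TOP_K = 5
#   TOP_K_SCALE_FACTOR = 2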

# Initialize reranker if enabled
reranker = None
if RERANKER_ENABLED:
    try:
        logging.info(f"Initializing reranker with model: {RERANKER_MODEL}")
        # HuggingFaceCrossEncoder doesn't accept a cache_dir parameter;
        # the underlying model uses the default cache location.
        cross_encoder_model = HuggingFaceCrossEncoder(model_name=RERANKER_MODEL)
        reranker = CrossEncoderReranker(model=cross_encoder_model, top_n=RERANKER_TOP_K)
        logging.info("Reranker initialized successfully")
    except Exception as e:
        logging.error(f"Failed to initialize reranker: {e}")
        reranker = None
else:
    logging.info("Reranker is disabled")

def get_vectorstore() -> VectorStoreInterface:
    """
    Create and return a vector store connection.
    
    Returns:
        VectorStoreInterface instance
    """
    logging.info("Initializing vector store connection...")
    vectorstore = create_vectorstore(config)
    logging.info("Vector store connection initialized successfully")
    return vectorstore

def create_filter(
    filter_metadata: Optional[dict] = None,
) -> Optional[rest.Filter]:
    """
    Create a Qdrant filter from metadata criteria.

    Args:
        filter_metadata: Mapping of metadata field names to either a single
            string value (exact match) or a list of values (match any)

    Returns:
        Qdrant Filter object, or None if no filters were specified
    """
    if not filter_metadata:
        return None

    conditions = []
    logging.info(f"Defining filters for {filter_metadata}")

    for key, val in filter_metadata.items():
        if isinstance(val, str):
            conditions.append(
                rest.FieldCondition(
                    key=f"metadata.{key}",
                    match=rest.MatchValue(value=val),
                )
            )
        else:
            conditions.append(
                rest.FieldCondition(
                    key=f"metadata.{key}",
                    match=rest.MatchAny(any=val),
                )
            )
    return rest.Filter(must=conditions)
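
# Example (illustrative): restrict a search to two reports from 2023. The
# metadata keys ("filename", "year") are assumptions; use whatever keys your
# ingestion pipeline stores under each point's `metadata` payload.
#
#   filter_obj = create_filter({
#       "filename": ["report_a.pdf", "report_b.pdf"],  # list -> MatchAny
#       "year": "2023",                                 # str  -> MatchValue
#   })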


def rerank_documents(query: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Rerank documents with the cross-encoder configured in params.cfg.

    Args:
        query: The search query
        documents: List of documents to rerank

    Returns:
        Reranked list of documents in the original format
    """
    if not reranker or not documents:
        return documents
    
    try:
        logging.info(f"Starting reranking of {len(documents)} documents")
        
        # Convert to LangChain Document format using correct keys (need to review this later for portability)
        langchain_docs = []
        for doc in documents:
            # Use correct keys from the data storage test module
            content = doc.get("answer", "")
            metadata = doc.get("answer_metadata", {})
            
            if not content:
                logging.warning(f"Document missing content: {doc}")
                continue
                
            langchain_doc = Document(
                page_content=content,
                metadata=metadata
            )
            langchain_docs.append(langchain_doc)
        
        if not langchain_docs:
            logging.warning("No valid documents found for reranking")
            return documents
        
        # Rerank documents
        logging.info(f"Reranking {len(langchain_docs)} documents")
        reranked_docs = reranker.compress_documents(langchain_docs, query)
        
        # Convert back to original format
        result = []
        for doc in reranked_docs:
            result.append({
                "answer": doc.page_content,
                "answer_metadata": doc.metadata,
            })
        
        logging.info(f"Successfully reranked {len(documents)} documents to top {len(result)}")
        return result
        
    except Exception as e:
        logging.error(f"Error during reranking: {str(e)}")
        # Return original documents if reranking fails
        return documents
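
# Example (illustrative) of the expected document shape; the field names in
# 'answer_metadata' depend on your ingestion pipeline:
#
#   docs = [{"answer": "Importers must submit due diligence statements ...",
#            "answer_metadata": {"filename": "report_a.pdf", "page": 3}}]
#   top = rerank_documents("What must importers submit?", docs)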

def get_context(
    vectorstore: VectorStoreInterface,
    query: str,
    collection_name: str = None,
    filter_metadata: Optional[dict] = None,
) -> List[Dict[str, Any]]:
    """
    Retrieve semantically similar documents from the vector database with optional reranking.

    Args:
        vectorstore: The vector store interface to search
        query: The search query
        collection_name: Name of the collection to search in
        filter_metadata: Optional metadata criteria, passed to create_filter
            (Qdrant only)

    Returns:
        List of dictionaries with 'answer', 'answer_metadata', and 'score'
        keys (reranked results carry only 'answer' and 'answer_metadata')
    """
    try:
        # Retrieve more candidate docs when reranking is enabled, so the
        # reranker has a larger pool to choose the final top-k from
        top_k = RETRIEVER_TOP_K
        if RERANKER_ENABLED and reranker:
            top_k = top_k * RERANKER_TOP_K_SCALE_FACTOR
            logging.info(f"Reranking enabled, retrieving {top_k} candidates")

        search_kwargs = {
            "model_name": config.get("embeddings", "MODEL_NAME")
        }

        # Filter support is only available for QdrantVectorStore
        if isinstance(vectorstore, QdrantVectorStore):
            filter_obj = create_filter(filter_metadata)
            if filter_obj:
                search_kwargs["filter"] = filter_obj

        # Perform initial retrieval
        logging.debug(f"Search kwargs: {search_kwargs}")
        retrieved_docs = vectorstore.search(query, collection_name, top_k, **search_kwargs)

        logging.info(f"Retrieved {len(retrieved_docs)} documents for query: {query[:50]}...")

        # Apply reranking if enabled
        if RERANKER_ENABLED and reranker and retrieved_docs:
            logging.info("Applying reranking...")
            retrieved_docs = rerank_documents(query, retrieved_docs)

            # Trim to the final desired k
            retrieved_docs = retrieved_docs[:RERANKER_TOP_K]

        logging.info(f"Returning {len(retrieved_docs)} final documents")
        logging.debug(f"Retrieved results: {retrieved_docs}")
        return retrieved_docs

    except Exception as e:
        logging.error(f"Error during retrieval: {e}")
        raise
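
if __name__ == "__main__":
    # Minimal smoke test (a sketch): the collection name and filter key below
    # are assumptions; adjust them to whatever your ingestion pipeline created.
    store = get_vectorstore()
    results = get_context(
        store,
        query="What does the regulation require from importers?",
        collection_name="EUDR",
        filter_metadata={"year": "2023"},
    )
    for doc in results:
        print(doc.get("answer_metadata"), str(doc.get("answer", ""))[:80])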