ppsingh committed (verified)
Commit 137c471 · 1 Parent(s): e7709a4

Upload 6 files

Files changed (6):
  1. app.py +121 -0
  2. params.cfg +26 -0
  3. requirements.txt +121 -0
  4. utils/retriever.py +257 -0
  5. utils/utils.py +107 -0
  6. utils/vectorstore_interface.py +171 -0
app.py ADDED
@@ -0,0 +1,121 @@
+import gradio as gr
+from utils.retriever import get_context, get_vectorstore
+
+# Initialize the vector store at startup
+print("Initializing vector store connection...", flush=True)
+try:
+    vectorstore = get_vectorstore()
+    print("Vector store connection initialized successfully", flush=True)
+except Exception as e:
+    print(f"Failed to initialize vector store: {e}", flush=True)
+    raise
+
+# ---------------------------------------------------------------------
+# MCP - returns raw dictionary format
+# ---------------------------------------------------------------------
+
+def retrieve(
+    query: str,
+    reports_filter: str = "",
+    sources_filter: str = "",
+    subtype_filter: str = "",
+    year_filter: str = ""
+) -> list:
+    """
+    Retrieve semantically similar documents from the vector database for MCP clients.
+
+    Args:
+        query (str): The search query text
+        reports_filter (str): Comma-separated list of specific report filenames (optional)
+        sources_filter (str): Filter by document source type (optional)
+        subtype_filter (str): Filter by document subtype (optional)
+        year_filter (str): Comma-separated list of years to filter by (optional)
+
+    Returns:
+        list: List of dictionaries containing document content, metadata, and scores
+    """
+    # Parse filter inputs (convert empty strings to None or lists)
+    reports = [r.strip() for r in reports_filter.split(",") if r.strip()] if reports_filter else []
+    sources = sources_filter.strip() if sources_filter else None
+    subtype = subtype_filter.strip() if subtype_filter else None
+    year = [y.strip() for y in year_filter.split(",") if y.strip()] if year_filter else None
+
+    # Call the retriever and return the raw results
+    results = get_context(
+        vectorstore=vectorstore,
+        query=query,
+        reports=reports,
+        sources=sources,
+        subtype=subtype,
+        year=year
+    )
+
+    return results
+
+
+# Create the Gradio interface with Blocks to support both the UI and MCP
+with gr.Blocks() as ui:
+    gr.Markdown("# ChatFed Retrieval/Reranker Module")
+    gr.Markdown("Retrieves semantically similar documents from the vector database and reranks them. Intended for use in RAG pipelines as an MCP server alongside other ChatFed modules.")
+
+    with gr.Row():
+        with gr.Column():
+            query_input = gr.Textbox(
+                label="Query",
+                lines=2,
+                placeholder="Enter your search query here",
+                info="The query to search for in the vector database"
+            )
+            reports_input = gr.Textbox(
+                label="Reports Filter (optional)",
+                lines=1,
+                placeholder="report1.pdf, report2.pdf",
+                info="Comma-separated list of specific report filenames to search within (leave empty for all)"
+            )
+            sources_input = gr.Textbox(
+                label="Sources Filter (optional)",
+                lines=1,
+                placeholder="annual_report",
+                info="Filter by document source type (leave empty for all)"
+            )
+            subtype_input = gr.Textbox(
+                label="Subtype Filter (optional)",
+                lines=1,
+                placeholder="financial",
+                info="Filter by document subtype (leave empty for all)"
+            )
+            year_input = gr.Textbox(
+                label="Year Filter (optional)",
+                lines=1,
+                placeholder="2023, 2024",
+                info="Comma-separated list of years to filter by (leave empty for all)"
+            )
+
+            submit_btn = gr.Button("Submit", variant="primary")
+
+        # Output must be JSON-serializable so the endpoint can be added as a tool in HuggingChat
+        with gr.Column():
+            output = gr.Text(
+                label="Retrieved Context",
+                lines=10,
+                show_copy_button=True
+            )
+
+    # UI event handler
+    submit_btn.click(
+        fn=retrieve,
+        inputs=[query_input, reports_input, sources_input, subtype_input, year_input],
+        outputs=output,
+        api_name="retrieve"
+    )
+
+
+if __name__ == "__main__":
+    ui.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        # mcp_server=True,  # requires a newer Gradio release than the pinned 4.44.1
+        show_error=True
+    )
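Once the Space is running, the `retrieve` endpoint can be exercised from another process with gradio_client. A minimal sketch; the Space id "GIZ/chatfed-retriever" is a placeholder and not taken from this commit:

from gradio_client import Client

# Placeholder Space id; substitute the actual deployment
client = Client("GIZ/chatfed-retriever")

# Positional args mirror retrieve(): query, then the four optional filters
results = client.predict(
    "deforestation due diligence obligations",  # query
    "",            # reports_filter
    "",            # sources_filter
    "",            # subtype_filter
    "2023, 2024",  # year_filter
    api_name="/retrieve",
)
print(results)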
params.cfg ADDED
@@ -0,0 +1,26 @@
+[vectorstore]
+# huggingface_spaces usage:
+# PROVIDER = huggingface
+# URL = GIZ/audit_data
+# COLLECTION_NAME = docling
+
+# direct Qdrant usage:
+PROVIDER = qdrant
+URL = giz-chatfed-qdrantserver.hf.space
+COLLECTION_NAME = EUDR
+
+[embeddings]
+MODEL_NAME = BAAI/bge-m3
+# DEVICE = cpu
+
+[retriever]
+TOP_K = 10
+SCORE_THRESHOLD = 0.6
+
+[reranker]
+MODEL_NAME = cross-encoder/ms-marco-MiniLM-L-6-v2
+TOP_K = 5
+ENABLED = false
+# use this to scale the total docs retrieved prior to reranking (i.e. retriever TOP_K * TOP_K_SCALE_FACTOR)
+TOP_K_SCALE_FACTOR = 2
+
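For reference: with these values the retriever fetches TOP_K = 10 documents; if ENABLED were flipped to true, it would fetch 10 * 2 = 20 candidates and the reranker would keep its top 5. A small sketch of reading the file with the stdlib configparser, mirroring what utils/utils.py does:

import configparser

config = configparser.ConfigParser()
config.read("params.cfg")

top_k = config.getint("retriever", "TOP_K")                          # 10
enabled = config.getboolean("reranker", "ENABLED", fallback=False)   # False
scale = config.getint("reranker", "TOP_K_SCALE_FACTOR", fallback=2)  # 2

candidates = top_k * scale if enabled else top_k
print(candidates)  # 10 here, since ENABLED = false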
requirements.txt ADDED
@@ -0,0 +1,121 @@
+aiofiles==23.2.1
+aiohappyeyeballs==2.6.1
+aiohttp==3.12.14
+aiosignal==1.4.0
+annotated-types==0.7.0
+anyio==4.9.0
+attrs==25.3.0
+certifi==2025.7.14
+charset-normalizer==3.4.2
+click==8.2.1
+contourpy==1.3.2
+cycler==0.12.1
+dataclasses-json==0.6.7
+fastapi==0.116.1
+ffmpy==0.6.1
+filelock==3.18.0
+fonttools==4.59.0
+frozenlist==1.7.0
+fsspec==2025.7.0
+gradio==4.44.1
+gradio_client==1.3.0
+greenlet==3.2.3
+grpcio==1.74.0
+h11==0.16.0
+h2==4.2.0
+hf-xet==1.1.5
+hpack==4.1.0
+httpcore==1.0.9
+httpx==0.28.1
+httpx-sse==0.4.1
+huggingface-hub==0.34.0
+hyperframe==6.1.0
+idna==3.10
+importlib_resources==6.5.2
+Jinja2==3.1.6
+joblib==1.5.1
+jsonpatch==1.33
+jsonpointer==3.0.0
+kiwisolver==1.4.8
+langchain==0.3.26
+langchain-community==0.3.27
+langchain-core==0.3.71
+langchain-text-splitters==0.3.8
+langsmith==0.4.8
+markdown-it-py==3.0.0
+MarkupSafe==2.1.5
+marshmallow==3.26.1
+matplotlib==3.10.3
+mdurl==0.1.2
+mpmath==1.3.0
+multidict==6.6.3
+mypy_extensions==1.1.0
+networkx==3.5
+numpy==2.3.2
+nvidia-cublas-cu12==12.6.4.1
+nvidia-cuda-cupti-cu12==12.6.80
+nvidia-cuda-nvrtc-cu12==12.6.77
+nvidia-cuda-runtime-cu12==12.6.77
+nvidia-cudnn-cu12==9.5.1.17
+nvidia-cufft-cu12==11.3.0.4
+nvidia-cufile-cu12==1.11.1.6
+nvidia-curand-cu12==10.3.7.77
+nvidia-cusolver-cu12==11.7.1.2
+nvidia-cusparse-cu12==12.5.4.2
+nvidia-cusparselt-cu12==0.6.3
+nvidia-nccl-cu12==2.26.2
+nvidia-nvjitlink-cu12==12.6.85
+nvidia-nvtx-cu12==12.6.77
+orjson==3.11.0
+packaging==25.0
+pandas==2.3.1
+pillow==10.4.0
+portalocker==3.2.0
+propcache==0.3.2
+protobuf==6.31.1
+pydantic==2.11.7
+pydantic-settings==2.10.1
+pydantic_core==2.33.2
+pydub==0.25.1
+Pygments==2.19.2
+pyparsing==3.2.3
+python-dateutil==2.9.0.post0
+python-dotenv==1.1.1
+python-multipart==0.0.20
+pytz==2025.2
+PyYAML==6.0.2
+qdrant-client==1.15.0
+regex==2024.11.6
+requests==2.32.4
+requests-toolbelt==1.0.0
+rich==14.1.0
+ruff==0.12.5
+safetensors==0.5.3
+scikit-learn==1.7.1
+scipy==1.16.0
+semantic-version==2.10.0
+sentence-transformers==5.0.0
+shellingham==1.5.4
+six==1.17.0
+sniffio==1.3.1
+SQLAlchemy==2.0.41
+starlette==0.47.2
+sympy==1.14.0
+tenacity==9.1.2
+threadpoolctl==3.6.0
+tokenizers==0.21.2
+tomlkit==0.12.0
+torch==2.7.1
+tqdm==4.67.1
+transformers==4.53.3
+triton==3.3.1
+typer==0.16.0
+typing-inspect==0.9.0
+typing-inspection==0.4.1
+typing_extensions==4.14.1
+tzdata==2025.2
+urllib3==2.5.0
+uvicorn==0.35.0
+websockets==12.0
+yarl==1.20.1
+zstandard==0.23.0
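Note that torch==2.7.1 pulls in the CUDA wheels (the nvidia-* pins and triton) even on CPU-only hosts, which dominates the install size. A small stdlib-only sanity check that the active environment matches these pins (a sketch; it assumes pin names match distribution names exactly):

from importlib.metadata import version, PackageNotFoundError

with open("requirements.txt") as f:
    for line in f:
        name, _, pinned = line.strip().partition("==")
        if not name or not pinned:
            continue
        try:
            installed = version(name)
        except PackageNotFoundError:
            print(f"{name}: not installed (expected {pinned})")
            continue
        if installed != pinned:
            print(f"{name}: {installed} != {pinned}")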
utils/retriever.py ADDED
@@ -0,0 +1,257 @@
+from typing import List, Dict, Any, Optional
+from qdrant_client.http import models as rest
+from langchain.schema import Document
+from langchain_community.cross_encoders import HuggingFaceCrossEncoder
+from langchain.retrievers.document_compressors import CrossEncoderReranker
+import logging
+import sys
+
+from .utils import getconfig
+from .vectorstore_interface import create_vectorstore, VectorStoreInterface, QdrantVectorStore
+
+# Configure logging to be more verbose
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(levelname)s - %(message)s',
+    handlers=[
+        logging.StreamHandler(sys.stdout)
+    ]
+)
+
+# Load configuration
+config = getconfig("params.cfg")
+
+# Retriever settings from config
+RETRIEVER_TOP_K = int(config.get("retriever", "TOP_K"))
+SCORE_THRESHOLD = float(config.get("retriever", "SCORE_THRESHOLD"))
+
+# Reranker settings from config
+RERANKER_ENABLED = config.getboolean("reranker", "ENABLED", fallback=False)
+RERANKER_MODEL = config.get("reranker", "MODEL_NAME", fallback="cross-encoder/ms-marco-MiniLM-L-6-v2")
+RERANKER_TOP_K = int(config.get("reranker", "TOP_K", fallback=5))
+RERANKER_TOP_K_SCALE_FACTOR = int(config.get("reranker", "TOP_K_SCALE_FACTOR", fallback=2))
+
+# Initialize the reranker if enabled
+reranker = None
+if RERANKER_ENABLED:
+    try:
+        print(f"Starting reranker initialization with model: {RERANKER_MODEL}", flush=True)
+        logging.info(f"Initializing reranker with model: {RERANKER_MODEL}")
+
+        print("Loading HuggingFace cross encoder model", flush=True)
+        # HuggingFaceCrossEncoder doesn't accept a cache_dir parameter;
+        # the underlying models use the default cache locations
+        cross_encoder_model = HuggingFaceCrossEncoder(model_name=RERANKER_MODEL)
+        print("Cross encoder model loaded successfully", flush=True)
+
+        print("Creating CrossEncoderReranker...", flush=True)
+        reranker = CrossEncoderReranker(model=cross_encoder_model, top_n=RERANKER_TOP_K)
+        print("Reranker initialized successfully", flush=True)
+        logging.info("Reranker initialized successfully")
+    except Exception as e:
+        print(f"Failed to initialize reranker: {str(e)}", flush=True)
+        logging.error(f"Failed to initialize reranker: {str(e)}")
+        reranker = None
+else:
+    print("Reranker is disabled", flush=True)
+
+
+def get_vectorstore() -> VectorStoreInterface:
+    """
+    Create and return a vector store connection.
+
+    Returns:
+        VectorStoreInterface instance
+    """
+    logging.info("Initializing vector store connection...")
+    vectorstore = create_vectorstore(config)
+    logging.info("Vector store connection initialized successfully")
+    return vectorstore
+
+
+def create_filter(
+    reports: List[str] = None,
+    sources: str = None,
+    subtype: str = None,
+    year: List[str] = None
+) -> Optional[rest.Filter]:
+    """
+    Create a Qdrant filter based on metadata criteria.
+
+    Args:
+        reports: List of specific report filenames to filter by
+        sources: Source type to filter by
+        subtype: Document subtype to filter by
+        year: List of years to filter by
+
+    Returns:
+        Qdrant Filter object, or None if no filters are specified
+    """
+    if not any([reports, sources, subtype, year]):
+        return None
+
+    conditions = []
+
+    if reports:
+        logging.info(f"Defining filter for reports: {reports}")
+        conditions.append(
+            rest.FieldCondition(
+                key="metadata.filename",
+                match=rest.MatchAny(any=reports)
+            )
+        )
+    else:
+        if sources:
+            logging.info(f"Defining filter for sources: {sources}")
+            conditions.append(
+                rest.FieldCondition(
+                    key="metadata.source",
+                    match=rest.MatchValue(value=sources)
+                )
+            )
+
+        if subtype:
+            logging.info(f"Defining filter for subtype: {subtype}")
+            conditions.append(
+                rest.FieldCondition(
+                    key="metadata.subtype",
+                    match=rest.MatchValue(value=subtype)
+                )
+            )
+
+        if year:
+            logging.info(f"Defining filter for years: {year}")
+            conditions.append(
+                rest.FieldCondition(
+                    key="metadata.year",
+                    match=rest.MatchAny(any=year)
+                )
+            )
+
+    if conditions:
+        return rest.Filter(must=conditions)
+    return None
+
+
+def rerank_documents(query: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    """
+    Rerank documents using a cross-encoder (specified in params.cfg).
+
+    Args:
+        query: The search query
+        documents: List of documents to rerank
+
+    Returns:
+        Reranked list of documents in the original format
+    """
+    if not reranker or not documents:
+        return documents
+
+    try:
+        logging.info(f"Starting reranking of {len(documents)} documents")
+
+        # Convert to LangChain Document format using the expected keys
+        # (need to review this later for portability)
+        langchain_docs = []
+        for doc in documents:
+            # Use the keys produced by the vector store interface
+            content = doc.get('answer', '')
+            metadata = doc.get('answer_metadata', {})
+
+            if not content:
+                logging.warning(f"Document missing content: {doc}")
+                continue
+
+            langchain_doc = Document(
+                page_content=content,
+                metadata=metadata
+            )
+            langchain_docs.append(langchain_doc)
+
+        if not langchain_docs:
+            logging.warning("No valid documents found for reranking")
+            return documents
+
+        # Rerank documents
+        logging.info(f"Reranking {len(langchain_docs)} documents")
+        reranked_docs = reranker.compress_documents(langchain_docs, query)
+
+        # Convert back to the original format
+        # (the 'score' from the initial retrieval is dropped here)
+        result = []
+        for doc in reranked_docs:
+            result.append({
+                'answer': doc.page_content,
+                'answer_metadata': doc.metadata,
+            })
+
+        logging.info(f"Successfully reranked {len(documents)} documents to top {len(result)}")
+        return result
+
+    except Exception as e:
+        logging.error(f"Error during reranking: {str(e)}")
+        # Return the original documents if reranking fails
+        return documents
+
+
+def get_context(
+    vectorstore: VectorStoreInterface,
+    query: str,
+    reports: List[str] = None,
+    sources: str = None,
+    subtype: str = None,
+    year: List[str] = None
+) -> List[Dict[str, Any]]:
+    """
+    Retrieve semantically similar documents from the vector database with optional reranking.
+
+    Args:
+        vectorstore: The vector store interface to search
+        query: The search query
+        reports: List of specific report filenames to search within
+        sources: Source type to filter by
+        subtype: Document subtype to filter by
+        year: List of years to filter by
+
+    Returns:
+        List of dictionaries with 'answer', 'answer_metadata', and 'score' keys
+    """
+    try:
+        # Use a higher k for initial retrieval if reranking is enabled (more candidate docs)
+        top_k = RETRIEVER_TOP_K
+        if RERANKER_ENABLED and reranker:
+            top_k = top_k * RERANKER_TOP_K_SCALE_FACTOR
+            logging.info(f"Reranking enabled, retrieving {top_k} candidates")
+
+        search_kwargs = {
+            "model_name": config.get("embeddings", "MODEL_NAME")
+        }
+
+        # Metadata filter support is only available for QdrantVectorStore
+        if isinstance(vectorstore, QdrantVectorStore):
+            filter_obj = create_filter(reports, sources, subtype, year)
+            if filter_obj:
+                search_kwargs["filter"] = filter_obj
+
+        # Perform the initial retrieval
+        retrieved_docs = vectorstore.search(query, top_k, **search_kwargs)
+
+        logging.info(f"Retrieved {len(retrieved_docs)} documents for query: {query[:50]}...")
+
+        # Apply reranking if enabled
+        if RERANKER_ENABLED and reranker and retrieved_docs:
+            logging.info("Applying reranking...")
+            retrieved_docs = rerank_documents(query, retrieved_docs)
+
+            # Trim to the final desired k
+            retrieved_docs = retrieved_docs[:RERANKER_TOP_K]
+
+        logging.info(f"Returning {len(retrieved_docs)} final documents")
+        logging.debug(f"Retrieved results: {retrieved_docs}")
+        return retrieved_docs
+
+    except Exception as e:
+        logging.error(f"Error during retrieval: {str(e)}")
+        raise
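The filter builder above maps directly onto Qdrant's must-conditions; a year restriction, for instance, becomes a MatchAny over metadata.year. A short usage sketch, assuming params.cfg points at a reachable store and QDRANT_API_KEY is set:

from utils.retriever import create_filter, get_context, get_vectorstore

# A MatchAny condition on metadata.year wrapped in a must-filter
year_filter = create_filter(year=["2023", "2024"])
print(year_filter)

# End-to-end retrieval (filters apply only on the qdrant provider)
vs = get_vectorstore()
docs = get_context(vectorstore=vs, query="timber import due diligence", year=["2023"])
for doc in docs:
    # 'score' is present when reranking is disabled
    print(doc["score"], doc["answer"][:80])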
utils/utils.py ADDED
@@ -0,0 +1,107 @@
+import configparser
+import logging
+import os
+import ast
+import re
+from dotenv import load_dotenv
+
+# Local .env file
+load_dotenv()
+
+
+def getconfig(configfile_path: str):
+    """
+    Read the config file.
+
+    Params
+    ----------------
+    configfile_path: file path of the .cfg file
+    """
+    config = configparser.ConfigParser()
+    read_ok = config.read(configfile_path)
+    if not read_ok:
+        logging.warning(f"Config file not found: {configfile_path}")
+    return config
+
+
+def get_auth(provider: str) -> dict:
+    """Get the authentication configuration for different providers."""
+    auth_configs = {
+        "huggingface": {"api_key": os.getenv("HF_TOKEN")},
+        "qdrant": {"api_key": os.getenv("QDRANT_API_KEY")},
+    }
+
+    provider = provider.lower()  # Normalize to lowercase
+
+    if provider not in auth_configs:
+        raise ValueError(f"Unsupported provider: {provider}")
+
+    auth_config = auth_configs[provider]
+    api_key = auth_config.get("api_key")
+
+    if not api_key:
+        logging.warning(f"No API key found for provider '{provider}'. Please set the appropriate environment variable.")
+        auth_config["api_key"] = None
+
+    return auth_config
+
+
+def process_content(content: str) -> str:
+    """
+    Process and clean malformed content that may contain stringified nested lists.
+    The test DB on Qdrant came out slightly malformed during processing, but this guard is useful in general.
+
+    Args:
+        content: Raw content from the vector store
+
+    Returns:
+        Cleaned, readable text content
+    """
+    if not content:
+        return content
+
+    # Check whether the content looks like a stringified list/nested structure
+    content_stripped = content.strip()
+    if content_stripped.startswith('[') and content_stripped.endswith(']'):
+        try:
+            # Parse as a literal list structure
+            parsed_content = ast.literal_eval(content_stripped)
+
+            if isinstance(parsed_content, list):
+                # Flatten nested lists and extract meaningful text
+                def extract_text_from_nested(obj):
+                    if isinstance(obj, list):
+                        text_items = []
+                        for item in obj:
+                            extracted = extract_text_from_nested(item)
+                            if extracted and extracted.strip():
+                                text_items.append(extracted)
+                        return ' '.join(text_items)
+                    elif isinstance(obj, str) and obj.strip():
+                        return obj.strip()
+                    elif isinstance(obj, dict):
+                        # Handle dict structures if present
+                        text_items = []
+                        for key, value in obj.items():
+                            if isinstance(value, str) and value.strip():
+                                text_items.append(f"{key}: {value}")
+                        return ' '.join(text_items)
+                    else:
+                        return ''
+
+                extracted_text = extract_text_from_nested(parsed_content)
+
+                if extracted_text and len(extracted_text.strip()) > 0:
+                    # Clean up extra whitespace and format nicely
+                    cleaned_text = re.sub(r'\s+', ' ', extracted_text).strip()
+                    logging.debug(f"Successfully processed nested list content: {len(cleaned_text)} chars")
+                    return cleaned_text
+                else:
+                    logging.warning("Parsed list content but no meaningful text found")
+                    return content  # Return the original if no meaningful text was extracted
+
+        except (ValueError, SyntaxError) as e:
+            logging.debug(f"Content looks like a list but failed to parse: {e}")
+            # Fall through to return the original content
+
+    # For regular text content, just clean up whitespace
+    return re.sub(r'\s+', ' ', content).strip()
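To illustrate the repair path in process_content: a stringified nested list flattens to plain text, while ordinary strings only get their whitespace normalized:

from utils.utils import process_content

raw = "[['Article 3', ['Relevant products shall be deforestation-free']]]"
print(process_content(raw))
# -> Article 3 Relevant products shall be deforestation-free

print(process_content("  spaced   out\ntext "))
# -> spaced out text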
utils/vectorstore_interface.py ADDED
@@ -0,0 +1,171 @@
+from abc import ABC, abstractmethod
+from typing import List, Dict, Any, Optional
+from gradio_client import Client
+import logging
+import os
+from pathlib import Path
+from dotenv import load_dotenv
+from .utils import get_auth, process_content
+
+load_dotenv()
+
+
+class VectorStoreInterface(ABC):
+    """Abstract interface for different vector store implementations."""
+
+    @abstractmethod
+    def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
+        """Search for similar documents."""
+        pass
+
+
+class HuggingFaceSpacesVectorStore(VectorStoreInterface):
+    """Vector store implementation for Hugging Face Spaces with MCP endpoints."""
+
+    def __init__(self, url: str, collection_name: str, api_key: Optional[str] = None):
+        repo_id = url
+
+        logging.info(f"Connecting to Hugging Face Space: {repo_id}")
+
+        if api_key:
+            self.client = Client(repo_id, hf_token=api_key)
+        else:
+            self.client = Client(repo_id)
+
+        self.collection_name = collection_name
+
+    def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
+        """Search using the Hugging Face Spaces MCP API."""
+        try:
+            # Use the /search_text endpoint as documented in the API
+            result = self.client.predict(
+                query=query,
+                collection_name=self.collection_name,
+                model_name=kwargs.get('model_name'),
+                top_k=top_k,
+                api_name="/search_text"
+            )
+
+            logging.info(f"Successfully retrieved {len(result) if result else 0} documents")
+            return result
+
+        except Exception as e:
+            logging.error(f"Error searching Hugging Face Spaces: {str(e)}")
+            raise
+
+
+class QdrantVectorStore(VectorStoreInterface):
+    """Vector store implementation for a direct Qdrant connection."""
+
+    def __init__(self, url: str, collection_name: str, api_key: Optional[str] = None):
+        from qdrant_client import QdrantClient
+
+        # The Python client must use port 443 with HTTPS for the hosted server
+        self.client = QdrantClient(
+            host=url,
+            port=443,
+            https=True,
+            api_key=api_key,
+            timeout=120,
+        )
+
+        self.collection_name = collection_name
+        # Initialize the embedding model as None - it will be loaded on first use
+        self._embedding_model = None
+        self._current_model_name = None
+
+    def _get_embedding_model(self, model_name: str = None):
+        """Lazy-load the embedding model to avoid loading it if not needed."""
+        if model_name is None:
+            model_name = "BAAI/bge-m3"  # Default from config
+
+        # Only reload if the model name changed
+        if self._embedding_model is None or self._current_model_name != model_name:
+            logging.info(f"Loading embedding model: {model_name}")
+            from sentence_transformers import SentenceTransformer
+
+            cache_folder = Path(os.getenv("HF_HUB_CACHE", "/tmp/hf_cache"))
+            cache_folder.mkdir(parents=True, exist_ok=True)
+
+            self._embedding_model = SentenceTransformer(
+                model_name,
+                cache_folder=str(cache_folder)
+            )
+            self._current_model_name = model_name
+            logging.info(f"Successfully loaded embedding model: {model_name}")
+
+        return self._embedding_model
+
+    def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
+        """Search using a direct Qdrant connection."""
+        try:
+            # Get the embedding model
+            model_name = kwargs.get('model_name')
+            embedding_model = self._get_embedding_model(model_name)
+
+            # Convert the query to an embedding
+            logging.info(f"Converting query to embedding using model: {self._current_model_name}")
+            query_embedding = embedding_model.encode(query).tolist()
+
+            # Get the filter from kwargs if provided
+            filter_obj = kwargs.get('filter', None)
+
+            # Perform the vector search
+            logging.info(f"Searching Qdrant collection '{self.collection_name}' for top {top_k} results")
+            search_result = self.client.search(
+                collection_name=self.collection_name,
+                query_vector=query_embedding,
+                query_filter=filter_obj,  # Filter support
+                limit=top_k,
+                with_payload=True,
+                with_vectors=False
+            )
+            logging.debug(search_result)
+
+            # Format results to match the expected output format
+            results = []
+            for hit in search_result:
+                raw_content = hit.payload.get('text', '')
+                # Process content to handle malformed nested list structures
+                processed_content = process_content(raw_content)
+
+                result_dict = {
+                    'answer': processed_content,
+                    'answer_metadata': hit.payload.get('metadata', {}),
+                    'score': hit.score
+                }
+                results.append(result_dict)
+
+            logging.info(f"Successfully retrieved {len(results)} documents from Qdrant")
+            return results
+
+        except Exception as e:
+            logging.error(f"Error searching Qdrant: {str(e)}")
+            raise
+
+
+def create_vectorstore(config: Any) -> VectorStoreInterface:
+    """Factory function to create the appropriate vector store based on configuration."""
+    vectorstore_type = config.get("vectorstore", "PROVIDER")
+
+    # Get the authentication config for the provider
+    auth_config = get_auth(vectorstore_type.lower())
+
+    if vectorstore_type.lower() == "huggingface":
+        url = config.get("vectorstore", "URL")
+        collection_name = config.get("vectorstore", "COLLECTION_NAME")
+        api_key = auth_config["api_key"]
+        return HuggingFaceSpacesVectorStore(url, collection_name, api_key)
+
+    elif vectorstore_type.lower() == "qdrant":
+        url = config.get("vectorstore", "URL")  # Bare hostname; the client connects over HTTPS on port 443
+        collection_name = config.get("vectorstore", "COLLECTION_NAME")
+        api_key = auth_config["api_key"]
+        return QdrantVectorStore(url, collection_name, api_key)
+
+    else:
+        raise ValueError(f"Unsupported vector store type: {vectorstore_type}")
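Because create_vectorstore keys off PROVIDER, adding a backend amounts to one VectorStoreInterface subclass plus a branch in the factory. A toy in-memory store, purely illustrative and not part of this commit, shows the shape of the contract:

from typing import Any, Dict, List
from utils.vectorstore_interface import VectorStoreInterface

class InMemoryVectorStore(VectorStoreInterface):
    """Toy backend doing substring matching; handy for unit tests."""

    def __init__(self, docs: List[Dict[str, Any]]):
        self.docs = docs

    def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
        # Return documents in the same {'answer', 'answer_metadata', 'score'} shape
        hits = [d for d in self.docs if query.lower() in d["answer"].lower()]
        return hits[:top_k]

store = InMemoryVectorStore([
    {"answer": "EUDR due diligence statement", "answer_metadata": {}, "score": 1.0},
])
print(store.search("due diligence", top_k=5))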