Upload 6 files
- app.py +121 -0
- params.cfg +26 -0
- requirements.txt +121 -0
- utils/retriever.py +257 -0
- utils/utils.py +107 -0
- utils/vectorstore_interface.py +171 -0
app.py
ADDED
@@ -0,0 +1,121 @@
import gradio as gr
import sys
from utils.retriever import get_context, get_vectorstore

# Initialize vector store at startup
print("Initializing vector store connection...", flush=True)
try:
    vectorstore = get_vectorstore()
    print("Vector store connection initialized successfully", flush=True)
except Exception as e:
    print(f"Failed to initialize vector store: {e}", flush=True)
    raise

# ---------------------------------------------------------------------
# MCP - returns raw dictionary format
# ---------------------------------------------------------------------

def retrieve(
    query: str,
    reports_filter: str = "",
    sources_filter: str = "",
    subtype_filter: str = "",
    year_filter: str = ""
) -> list:
    """
    Retrieve semantically similar documents from the vector database for MCP clients.

    Args:
        query (str): The search query text
        reports_filter (str): Comma-separated list of specific report filenames (optional)
        sources_filter (str): Filter by document source type (optional)
        subtype_filter (str): Filter by document subtype (optional)
        year_filter (str): Comma-separated list of years to filter by (optional)

    Returns:
        list: List of dictionaries containing document content, metadata, and scores
    """
    # Parse filter inputs (convert empty strings to None or lists)
    reports = [r.strip() for r in reports_filter.split(",") if r.strip()] if reports_filter else []
    sources = sources_filter.strip() if sources_filter else None
    subtype = subtype_filter.strip() if subtype_filter else None
    year = [y.strip() for y in year_filter.split(",") if y.strip()] if year_filter else None

    # Call retriever function and return raw results
    results = get_context(
        vectorstore=vectorstore,
        query=query,
        reports=reports,
        sources=sources,
        subtype=subtype,
        year=year
    )

    return results


# Create the Gradio interface with Blocks to support both UI and MCP
with gr.Blocks() as ui:
    gr.Markdown("# ChatFed Retrieval/Reranker Module")
    gr.Markdown("Retrieves semantically similar documents from vector database and reranks. Intended for use in RAG pipelines as an MCP server with other ChatFed modules.")

    with gr.Row():
        with gr.Column():
            query_input = gr.Textbox(
                label="Query",
                lines=2,
                placeholder="Enter your search query here",
                info="The query to search for in the vector database"
            )
            reports_input = gr.Textbox(
                label="Reports Filter (optional)",
                lines=1,
                placeholder="report1.pdf, report2.pdf",
                info="Comma-separated list of specific report filenames to search within (leave empty for all)"
            )
            sources_input = gr.Textbox(
                label="Sources Filter (optional)",
                lines=1,
                placeholder="annual_report",
                info="Filter by document source type (leave empty for all)"
            )
            subtype_input = gr.Textbox(
                label="Subtype Filter (optional)",
                lines=1,
                placeholder="financial",
                info="Filter by document subtype (leave empty for all)"
            )
            year_input = gr.Textbox(
                label="Year Filter (optional)",
                lines=1,
                placeholder="2023, 2024",
                info="Comma-separated list of years to filter by (leave empty for all)"
            )

            submit_btn = gr.Button("Submit", variant="primary")

        # Output needs to be in json format to be added as tool in HuggingChat
        with gr.Column():
            output = gr.Text(
                label="Retrieved Context",
                lines=10,
                show_copy_button=True
            )

    # UI event handler
    submit_btn.click(
        fn=retrieve,
        inputs=[query_input, reports_input, sources_input, subtype_input, year_input],
        outputs=output,
        api_name="retrieve"
    )


# Launch with MCP server enabled
if __name__ == "__main__":
    ui.launch(
        server_name="0.0.0.0",
        server_port=7860,
        #mcp_server=True,
        show_error=True
    )
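Because the click handler registers api_name="retrieve", the same function is callable outside the UI via the Gradio client. A minimal sketch, assuming the Space is published under a hypothetical repo id (the id below is a placeholder, not from this commit):

from gradio_client import Client

# hypothetical Space id - replace with the actual deployment
client = Client("your-org/chatfed-retriever")
results = client.predict(
    query="deforestation risk in supply chains",
    reports_filter="",
    sources_filter="",
    subtype_filter="",
    year_filter="2023, 2024",
    api_name="/retrieve"
)
print(results)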
params.cfg
ADDED
@@ -0,0 +1,26 @@
[vectorstore]
# huggingface_spaces usage:
# PROVIDER = huggingface
# URL = GIZ/audit_data
# COLLECTION_NAME = docling

# direct Qdrant usage:
PROVIDER = qdrant
URL = giz-chatfed-qdrantserver.hf.space
COLLECTION_NAME = EUDR

[embeddings]
MODEL_NAME = BAAI/bge-m3
# DEVICE = cpu

[retriever]
TOP_K = 10
SCORE_THRESHOLD = 0.6

[reranker]
MODEL_NAME = cross-encoder/ms-marco-MiniLM-L-6-v2
TOP_K = 5
ENABLED = false
# use this to scale out the total docs retrieved prior to reranking (i.e. retriever TOP_K * TOP_K_SCALE_FACTOR)
TOP_K_SCALE_FACTOR = 2
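For reference, utils/retriever.py reads these values through getconfig, which wraps the standard library configparser. A minimal sketch of the same pattern, assuming params.cfg sits in the working directory:

import configparser

config = configparser.ConfigParser()
config.read("params.cfg")

provider = config.get("vectorstore", "PROVIDER")                         # "qdrant"
top_k = int(config.get("retriever", "TOP_K"))                            # 10
reranker_on = config.getboolean("reranker", "ENABLED", fallback=False)   # False
print(provider, top_k, reranker_on)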
requirements.txt
ADDED
@@ -0,0 +1,121 @@
aiofiles==23.2.1
aiohappyeyeballs==2.6.1
aiohttp==3.12.14
aiosignal==1.4.0
annotated-types==0.7.0
anyio==4.9.0
attrs==25.3.0
certifi==2025.7.14
charset-normalizer==3.4.2
click==8.2.1
contourpy==1.3.2
cycler==0.12.1
dataclasses-json==0.6.7
fastapi==0.116.1
ffmpy==0.6.1
filelock==3.18.0
fonttools==4.59.0
frozenlist==1.7.0
fsspec==2025.7.0
gradio==4.44.1
gradio_client==1.3.0
greenlet==3.2.3
grpcio==1.74.0
h11==0.16.0
h2==4.2.0
hf-xet==1.1.5
hpack==4.1.0
httpcore==1.0.9
httpx==0.28.1
httpx-sse==0.4.1
huggingface-hub==0.34.0
hyperframe==6.1.0
idna==3.10
importlib_resources==6.5.2
Jinja2==3.1.6
joblib==1.5.1
jsonpatch==1.33
jsonpointer==3.0.0
kiwisolver==1.4.8
langchain==0.3.26
langchain-community==0.3.27
langchain-core==0.3.71
langchain-text-splitters==0.3.8
langsmith==0.4.8
markdown-it-py==3.0.0
MarkupSafe==2.1.5
marshmallow==3.26.1
matplotlib==3.10.3
mdurl==0.1.2
mpmath==1.3.0
multidict==6.6.3
mypy_extensions==1.1.0
networkx==3.5
numpy==2.3.2
nvidia-cublas-cu12==12.6.4.1
nvidia-cuda-cupti-cu12==12.6.80
nvidia-cuda-nvrtc-cu12==12.6.77
nvidia-cuda-runtime-cu12==12.6.77
nvidia-cudnn-cu12==9.5.1.17
nvidia-cufft-cu12==11.3.0.4
nvidia-cufile-cu12==1.11.1.6
nvidia-curand-cu12==10.3.7.77
nvidia-cusolver-cu12==11.7.1.2
nvidia-cusparse-cu12==12.5.4.2
nvidia-cusparselt-cu12==0.6.3
nvidia-nccl-cu12==2.26.2
nvidia-nvjitlink-cu12==12.6.85
nvidia-nvtx-cu12==12.6.77
orjson==3.11.0
packaging==25.0
pandas==2.3.1
pillow==10.4.0
portalocker==3.2.0
propcache==0.3.2
protobuf==6.31.1
pydantic==2.11.7
pydantic-settings==2.10.1
pydantic_core==2.33.2
pydub==0.25.1
Pygments==2.19.2
pyparsing==3.2.3
python-dateutil==2.9.0.post0
python-dotenv==1.1.1
python-multipart==0.0.20
pytz==2025.2
PyYAML==6.0.2
qdrant-client==1.15.0
regex==2024.11.6
requests==2.32.4
requests-toolbelt==1.0.0
rich==14.1.0
ruff==0.12.5
safetensors==0.5.3
scikit-learn==1.7.1
scipy==1.16.0
semantic-version==2.10.0
sentence-transformers==5.0.0
shellingham==1.5.4
six==1.17.0
sniffio==1.3.1
SQLAlchemy==2.0.41
starlette==0.47.2
sympy==1.14.0
tenacity==9.1.2
threadpoolctl==3.6.0
tokenizers==0.21.2
tomlkit==0.12.0
torch==2.7.1
tqdm==4.67.1
transformers==4.53.3
triton==3.3.1
typer==0.16.0
typing-inspect==0.9.0
typing-inspection==0.4.1
typing_extensions==4.14.1
tzdata==2025.2
urllib3==2.5.0
uvicorn==0.35.0
websockets==12.0
yarl==1.20.1
zstandard==0.23.0
utils/retriever.py
ADDED
@@ -0,0 +1,257 @@
from typing import List, Dict, Any, Optional
from qdrant_client.http import models as rest
from langchain.schema import Document
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain.retrievers.document_compressors import CrossEncoderReranker
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('BAAI/bge-m3')
import logging
import os
from .utils import getconfig
from .vectorstore_interface import create_vectorstore, VectorStoreInterface, QdrantVectorStore
import sys

# Configure logging to be more verbose
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)

# Load configuration
config = getconfig("params.cfg")

# Retriever settings from config
RETRIEVER_TOP_K = int(config.get("retriever", "TOP_K"))
SCORE_THRESHOLD = float(config.get("retriever", "SCORE_THRESHOLD"))

# Reranker settings from config
RERANKER_ENABLED = config.getboolean("reranker", "ENABLED", fallback=False)
RERANKER_MODEL = config.get("reranker", "MODEL_NAME", fallback="cross-encoder/ms-marco-MiniLM-L-6-v2")
RERANKER_TOP_K = int(config.get("reranker", "TOP_K", fallback=5))
RERANKER_TOP_K_SCALE_FACTOR = int(config.get("reranker", "TOP_K_SCALE_FACTOR", fallback=2))

# Initialize reranker if enabled
reranker = None
if RERANKER_ENABLED:
    try:
        print(f"Starting reranker initialization with model: {RERANKER_MODEL}", flush=True)
        logging.info(f"Initializing reranker with model: {RERANKER_MODEL}")

        print("Loading HuggingFace cross encoder model", flush=True)
        # HuggingFaceCrossEncoder doesn't accept a cache_dir parameter;
        # the underlying models will use default cache locations
        cross_encoder_model = HuggingFaceCrossEncoder(model_name=RERANKER_MODEL)
        print("Cross encoder model loaded successfully", flush=True)

        print("Creating CrossEncoderReranker...", flush=True)
        reranker = CrossEncoderReranker(model=cross_encoder_model, top_n=RERANKER_TOP_K)
        print("Reranker initialized successfully", flush=True)
        logging.info("Reranker initialized successfully")
    except Exception as e:
        print(f"Failed to initialize reranker: {str(e)}", flush=True)
        logging.error(f"Failed to initialize reranker: {str(e)}")
        reranker = None
else:
    print("Reranker is disabled", flush=True)

def get_vectorstore() -> VectorStoreInterface:
    """
    Create and return a vector store connection.

    Returns:
        VectorStoreInterface instance
    """
    logging.info("Initializing vector store connection...")
    vectorstore = create_vectorstore(config)
    logging.info("Vector store connection initialized successfully")
    return vectorstore

def create_filter(
    reports: List[str] = None,
    sources: str = None,
    subtype: str = None,
    year: List[str] = None
) -> Optional[rest.Filter]:
    """
    Create a Qdrant filter based on metadata criteria.

    Args:
        reports: List of specific report filenames to filter by
        sources: Source type to filter by
        subtype: Document subtype to filter by
        year: List of years to filter by

    Returns:
        Qdrant Filter object or None if no filters specified
    """
    if not any([reports, sources, subtype, year]):
        return None

    conditions = []

    if reports and len(reports) > 0:
        logging.info(f"Defining filter for reports: {reports}")
        conditions.append(
            rest.FieldCondition(
                key="metadata.filename",
                match=rest.MatchAny(any=reports)
            )
        )
    else:
        if sources:
            logging.info(f"Defining filter for sources: {sources}")
            conditions.append(
                rest.FieldCondition(
                    key="metadata.source",
                    match=rest.MatchValue(value=sources)
                )
            )

        if subtype:
            logging.info(f"Defining filter for subtype: {subtype}")
            conditions.append(
                rest.FieldCondition(
                    key="metadata.subtype",
                    match=rest.MatchValue(value=subtype)
                )
            )

        if year and len(year) > 0:
            logging.info(f"Defining filter for years: {year}")
            conditions.append(
                rest.FieldCondition(
                    key="metadata.year",
                    match=rest.MatchAny(any=year)
                )
            )

    if conditions:
        return rest.Filter(must=conditions)
    return None

def rerank_documents(query: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Rerank documents using a cross-encoder (configured in params.cfg)

    Args:
        query: The search query
        documents: List of documents to rerank

    Returns:
        Reranked list of documents in original format
    """
    if not reranker or not documents:
        return documents

    try:
        logging.info(f"Starting reranking of {len(documents)} documents")

        # Convert to LangChain Document format using correct keys (need to review this later for portability)
        langchain_docs = []
        for doc in documents:
            # Use correct keys from the data storage test module
            content = doc.get('answer', '')
            metadata = doc.get('answer_metadata', {})

            if not content:
                logging.warning(f"Document missing content: {doc}")
                continue

            langchain_doc = Document(
                page_content=content,
                metadata=metadata
            )
            langchain_docs.append(langchain_doc)

        if not langchain_docs:
            logging.warning("No valid documents found for reranking")
            return documents

        # Rerank documents
        logging.info(f"Reranking {len(langchain_docs)} documents")
        reranked_docs = reranker.compress_documents(langchain_docs, query)

        # Convert back to original format
        result = []
        for doc in reranked_docs:
            result.append({
                'answer': doc.page_content,
                'answer_metadata': doc.metadata,
            })

        logging.info(f"Successfully reranked {len(documents)} documents to top {len(result)}")
        return result

    except Exception as e:
        logging.error(f"Error during reranking: {str(e)}")
        # Return original documents if reranking fails
        return documents

def get_context(
    vectorstore: VectorStoreInterface,
    query: str,
    reports: List[str] = None,
    sources: str = None,
    subtype: str = None,
    year: List[str] = None
) -> List[Dict[str, Any]]:
    """
    Retrieve semantically similar documents from the vector database with optional reranking.

    Args:
        vectorstore: The vector store interface to search
        query: The search query
        reports: List of specific report filenames to search within
        sources: Source type to filter by
        subtype: Document subtype to filter by
        year: List of years to filter by

    Returns:
        List of dictionaries with 'answer', 'answer_metadata', and 'score' keys
    """
    try:
        # Use a higher k for initial retrieval if reranking is enabled (more candidate docs)
        top_k = RETRIEVER_TOP_K
        if RERANKER_ENABLED and reranker:
            top_k = top_k * RERANKER_TOP_K_SCALE_FACTOR
            logging.info(f"Reranking enabled, retrieving {top_k} candidates")

        search_kwargs = {
            "model_name": config.get("embeddings", "MODEL_NAME")
        }
        #model = SentenceTransformer(config.get("embeddings", "MODEL_NAME"))
        #query_vector = model.encode(query).tolist()
        #retrieved_docs = vectorstore.search(
        #    collection_name="EUDR",
        #    query_vector=query_vector,
        #    limit=top_k,
        #    with_payload=True)
        # filter support for QdrantVectorStore
        #if isinstance(vectorstore, QdrantVectorStore):
        #    filter_obj = create_filter(reports, sources, subtype, year)
        #    if filter_obj:
        #        search_kwargs["filter"] = filter_obj

        # Perform initial retrieval
        retrieved_docs = vectorstore.search(query, top_k)

        logging.info(f"Retrieved {len(retrieved_docs)} documents for query: {query[:50]}...")

        # Apply reranking if enabled
        if RERANKER_ENABLED and reranker and retrieved_docs:
            logging.info("Applying reranking...")
            retrieved_docs = rerank_documents(query, retrieved_docs)

        # Trim to final desired k
        retrieved_docs = retrieved_docs[:RERANKER_TOP_K]

        logging.info(f"Returning {len(retrieved_docs)} final documents")
        logging.info(f"Retrieved results: {retrieved_docs}")
        return retrieved_docs

    except Exception as e:
        logging.error(f"Error during retrieval: {str(e)}")
        raise e
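With the shipped params.cfg, enabling the reranker would retrieve TOP_K * TOP_K_SCALE_FACTOR = 10 * 2 = 20 candidates and trim the result to the reranker's TOP_K = 5; note that the final trim to RERANKER_TOP_K applies even when reranking is disabled. A minimal usage sketch, assuming the Qdrant collection named in params.cfg is reachable and QDRANT_API_KEY is set in the environment:

from utils.retriever import get_context, get_vectorstore

vectorstore = get_vectorstore()
docs = get_context(
    vectorstore=vectorstore,
    query="EUDR due diligence requirements",
)
for doc in docs:
    # with the default config (reranker disabled) each dict carries a Qdrant score
    print(doc["score"], doc["answer"][:80])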
utils/utils.py
ADDED
@@ -0,0 +1,107 @@
import configparser
import logging
import os
import ast
import re
from dotenv import load_dotenv

# Local .env file
load_dotenv()

def getconfig(configfile_path: str):
    """
    Read the config file

    Params
    ----------------
    configfile_path: file path of .cfg file
    """
    config = configparser.ConfigParser()
    try:
        with open(configfile_path) as f:
            config.read_file(f)
        return config
    except FileNotFoundError:
        logging.warning("config file not found")
        raise


def get_auth(provider: str) -> dict:
    """Get authentication configuration for different providers"""
    auth_configs = {
        "huggingface": {"api_key": os.getenv("HF_TOKEN")},
        "qdrant": {"api_key": os.getenv("QDRANT_API_KEY")},
    }

    provider = provider.lower()  # Normalize to lowercase

    if provider not in auth_configs:
        raise ValueError(f"Unsupported provider: {provider}")

    auth_config = auth_configs[provider]
    api_key = auth_config.get("api_key")

    if not api_key:
        logging.warning(f"No API key found for provider '{provider}'. Please set the appropriate environment variable.")
        auth_config["api_key"] = None

    return auth_config


def process_content(content: str) -> str:
    """
    Process and clean malformed content that may contain stringified nested lists.
    The test DB on Qdrant somehow got a bit malformed in the processing - but probably good to have this anyway

    Args:
        content: Raw content from vector store

    Returns:
        Cleaned, readable text content
    """
    if not content:
        return content

    # Check if content looks like a stringified list/nested structure
    content_stripped = content.strip()
    if content_stripped.startswith('[') and content_stripped.endswith(']'):
        try:
            # Parse as literal list structure
            parsed_content = ast.literal_eval(content_stripped)

            if isinstance(parsed_content, list):
                # Flatten nested lists and extract meaningful text
                def extract_text_from_nested(obj):
                    if isinstance(obj, list):
                        text_items = []
                        for item in obj:
                            extracted = extract_text_from_nested(item)
                            if extracted and extracted.strip():
                                text_items.append(extracted)
                        return ' '.join(text_items)
                    elif isinstance(obj, str) and obj.strip():
                        return obj.strip()
                    elif isinstance(obj, dict):
                        # Handle dict structures if present
                        text_items = []
                        for key, value in obj.items():
                            if isinstance(value, str) and value.strip():
                                text_items.append(f"{key}: {value}")
                        return ' '.join(text_items)
                    else:
                        return ''

                extracted_text = extract_text_from_nested(parsed_content)

                if extracted_text and len(extracted_text.strip()) > 0:
                    # Clean up extra whitespace and format nicely
                    cleaned_text = re.sub(r'\s+', ' ', extracted_text).strip()
                    logging.debug(f"Successfully processed nested list content: {len(cleaned_text)} chars")
                    return cleaned_text
                else:
                    logging.warning("Parsed list content but no meaningful text found")
                    return content  # Return original if no meaningful text extracted

        except (ValueError, SyntaxError) as e:
            logging.debug(f"Content looks like list but failed to parse: {e}")
            # Fall through to return original content

    # For regular text content, just clean up whitespace
    return re.sub(r'\s+', ' ', content).strip()
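A quick illustration of what process_content recovers from a stringified nested list (the payload strings below are invented for the example):

from utils.utils import process_content

malformed = "[['EUDR applies to', ['cattle', 'cocoa', 'coffee']], 'and derived products']"
print(process_content(malformed))
# -> "EUDR applies to cattle cocoa coffee and derived products"

clean = "Regular   text  with   extra   whitespace"
print(process_content(clean))
# -> "Regular text with extra whitespace"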
utils/vectorstore_interface.py
ADDED
@@ -0,0 +1,171 @@
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
from gradio_client import Client
import logging
import os
from pathlib import Path
from dotenv import load_dotenv
from .utils import get_auth, process_content
load_dotenv()


class VectorStoreInterface(ABC):
    """Abstract interface for different vector store implementations."""

    @abstractmethod
    def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
        """Search for similar documents."""
        pass


class HuggingFaceSpacesVectorStore(VectorStoreInterface):
    """Vector store implementation for Hugging Face Spaces with MCP endpoints."""

    def __init__(self, url: str, collection_name: str, api_key: Optional[str] = None):
        repo_id = url

        logging.info(f"Connecting to Hugging Face Space: {repo_id}")

        if api_key:
            self.client = Client(repo_id, hf_token=api_key)
        else:
            self.client = Client(repo_id)

        self.collection_name = collection_name

    def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
        """Search using the Hugging Face Spaces MCP API."""
        try:
            # Use the /search_text endpoint as documented in the API
            result = self.client.predict(
                query=query,
                collection_name=self.collection_name,
                model_name=kwargs.get('model_name'),
                top_k=top_k,
                api_name="/search_text"
            )

            logging.info(f"Successfully retrieved {len(result) if result else 0} documents")
            return result

        except Exception as e:
            logging.error(f"Error searching Hugging Face Spaces: {str(e)}")
            raise e

class QdrantVectorStore(VectorStoreInterface):
    """Vector store implementation for direct Qdrant connection."""

    def __init__(self, url: str, collection_name: str, api_key: Optional[str] = None):
        from qdrant_client import QdrantClient
        from sentence_transformers import SentenceTransformer

        self.client = QdrantClient(
            host=url,
            # port 443 must be set explicitly for the Python client
            port=443,
            https=True,
            # api_key = QDRANT_API_KEY_READ,
            # this api_key is for write access
            api_key=api_key,
            timeout=120,
        )

        #self.client = QdrantClient(
        #    url=url,  # Use url parameter which handles full URLs with protocol
        #    api_key=api_key
        #)

        self.collection_name = collection_name
        # Initialize embedding model as None - will be loaded on first use
        self._embedding_model = None
        self._current_model_name = None

    def _get_embedding_model(self, model_name: str = None):
        """Lazily load the embedding model to avoid loading it if not needed."""
        if model_name is None:
            model_name = "BAAI/bge-m3"  # Default from config

        # Only reload if the model name changed
        if self._embedding_model is None or self._current_model_name != model_name:
            logging.info(f"Loading embedding model: {model_name}")
            from sentence_transformers import SentenceTransformer

            cache_folder = Path(os.getenv("HF_HUB_CACHE", "/tmp/hf_cache"))
            cache_folder.mkdir(parents=True, exist_ok=True)

            self._embedding_model = SentenceTransformer(
                model_name,
                cache_folder=str(cache_folder)
            )
            # self._embedding_model = SentenceTransformer(model_name)
            self._current_model_name = model_name
            logging.info(f"Successfully loaded embedding model: {model_name}")

        return self._embedding_model

    def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
        """Search using a direct Qdrant connection."""
        try:
            # Get embedding model
            model_name = kwargs.get('model_name')
            embedding_model = self._get_embedding_model(model_name)

            # Convert query to embedding
            logging.info(f"Converting query to embedding using model: {self._current_model_name}")
            query_embedding = embedding_model.encode(query).tolist()

            # Get filter from kwargs if provided
            filter_obj = kwargs.get('filter', None)

            # Perform vector search
            logging.info(f"Searching Qdrant collection '{self.collection_name}' for top {top_k} results")
            search_result = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding,
                query_filter=filter_obj,  # Add filter support
                limit=top_k,
                with_payload=True,
                with_vectors=False
            )
            logging.info(search_result)
            # Format results to match expected output format
            results = []
            for hit in search_result:
                raw_content = hit.payload.get('text', '')
                # Process content to handle malformed nested list structures
                processed_content = process_content(raw_content)

                result_dict = {
                    'answer': processed_content,
                    'answer_metadata': hit.payload.get('metadata', {}),
                    'score': hit.score
                }
                results.append(result_dict)

            logging.info(f"Successfully retrieved {len(results)} documents from Qdrant")
            return results

        except Exception as e:
            logging.error(f"Error searching Qdrant: {str(e)}")
            raise e

def create_vectorstore(config: Any) -> VectorStoreInterface:
    """Factory function to create the appropriate vector store based on configuration."""
    vectorstore_type = config.get("vectorstore", "PROVIDER")

    # Get authentication config based on provider
    auth_config = get_auth(vectorstore_type.lower())

    if vectorstore_type.lower() == "huggingface":
        url = config.get("vectorstore", "URL")
        collection_name = config.get("vectorstore", "COLLECTION_NAME")
        api_key = auth_config["api_key"]
        return HuggingFaceSpacesVectorStore(url, collection_name, api_key)

    elif vectorstore_type.lower() == "qdrant":
        url = config.get("vectorstore", "URL")  # host name without protocol; port 443 is set in the client
        collection_name = config.get("vectorstore", "COLLECTION_NAME")
        api_key = auth_config["api_key"]
        return QdrantVectorStore(url, collection_name, api_key)

    else:
        raise ValueError(f"Unsupported vector store type: {vectorstore_type}")
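A minimal sketch of the factory in isolation, assuming a params.cfg like the one above and QDRANT_API_KEY exported in the environment:

from utils.utils import getconfig
from utils.vectorstore_interface import create_vectorstore

config = getconfig("params.cfg")
store = create_vectorstore(config)  # returns QdrantVectorStore for PROVIDER = qdrant
hits = store.search("forest risk commodities", top_k=3, model_name="BAAI/bge-m3")
for hit in hits:
    print(round(hit["score"], 3), hit["answer"][:60])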