Spaces:
Running
on
T4
Running
on
T4
File size: 4,336 Bytes
137c471 ae874c6 137c471 d2c728b 712bf59 137c471 0ddb79b 137c471 0a928d9 137c471 d2c728b 137c471 d2c728b 137c471 d2c728b 137c471 d2c728b 137c471 712bf59 d2c728b 137c471 d2c728b 137c471 ae874c6 137c471 ae874c6 137c471 ae874c6 137c471 ae874c6 137c471 ae874c6 137c471 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import gradio as gr
import sys
from utils.retriever import get_context, get_vectorstore
# Connect to the vector store once, at import time, so every request reuses
# the same module-level `vectorstore` object.
print("Initializing vector store connection...", flush=True)
try:
    vectorstore = get_vectorstore()
    print("Vector store connection initialized successfully", flush=True)
except Exception as exc:
    # Report the failure, then re-raise so the app does not start without
    # a working vector store.
    print(f"Failed to initialize vector store: {exc}", flush=True)
    raise
# ---------------------------------------------------------------------
# MCP - returns raw dictionary format
# ---------------------------------------------------------------------
def create_metadata_dict(sources_input, sources_value):
    """
    Build the metadata filter dictionary from the two filter text boxes.

    String inputs are stripped of surrounding whitespace before use, so a
    blank-looking field (e.g. a single space) is treated as empty and a
    padded entry such as " country " becomes the key "country".

    Args:
        sources_input: Metadata key to filter on (e.g. "country").
        sources_value: Value the metadata key is matched against.

    Returns:
        A single-entry dict ``{sources_input: sources_value}`` when both
        inputs are non-empty, otherwise ``None`` (no filtering).
    """
    # Only strings are trimmed; any other value is passed through untouched.
    key = sources_input.strip() if isinstance(sources_input, str) else sources_input
    value = sources_value.strip() if isinstance(sources_value, str) else sources_value

    if key and value:
        return {key: value}
    return None
def retrieve(
    query: str,
    collection_name: str = None,
    filter_metadata: dict = None,
) -> list:
    """
    Retrieve semantically similar documents from the vector database for MCP clients.

    Args:
        query (str): The search query text
        collection_name (str): Name of the collection to search (optional)
        filter_metadata (dict): Metadata filter as a {key: value} mapping, e.g. {"country": "Ecuador"} (optional)

    Returns:
        list: Raw results exactly as returned by get_context (presumably
            dictionaries with document content, metadata, and scores;
            confirm against utils.retriever)
    """
    # NOTE(review): when called from the UI, an empty "Collection Name" box
    # arrives as "" rather than None - confirm get_context treats both alike.
    # Delegate to the retriever using the module-level vector store and hand
    # the results back unmodified.
    results = get_context(
        vectorstore=vectorstore,
        query=query,
        collection_name=collection_name,
        filter_metadata=filter_metadata,
    )
    return results
# Build the Gradio Blocks app: a two-column layout with the query and filter
# inputs on the left and the retrieved context on the right.
with gr.Blocks() as ui:
    gr.Markdown("# ChatFed Retrieval/Reranker Module")
    gr.Markdown("Retrieves semantically similar documents from vector database and reranks. Intended for use in RAG pipelines as an MCP server with other ChatFed modules.")

    with gr.Row():
        # Left column: search inputs and the submit button.
        with gr.Column():
            query_input = gr.Textbox(
                label="Query",
                lines=2,
                placeholder="Enter your search query here",
                info="The query to search for in the vector database",
            )
            collection_name = gr.Textbox(
                label="Collection Name (optional)",
                lines=1,
                placeholder="EUDR, Humboldt",
                info="Name of the collection",
            )
            # Metadata key for the optional filter.
            sources_input = gr.Textbox(
                label="Sources Filter key to be looked in metadata (optional)",
                lines=1,
                placeholder="country",
                info="Filter by document source type (leave empty for all)",
            )
            # Metadata value for the optional filter.
            sources_value = gr.Textbox(
                label="Value in filter to be looked for(optional)",
                lines=1,
                placeholder="Ecuador, Guatemala",
                info="Filter by document subtype (leave empty for all)",
            )
            # Holds the filter dict built from the two boxes above between
            # the two steps of the submit chain.
            filter_metadata_state = gr.State(None)
            submit_btn = gr.Button("Submit", variant="primary")

        # Right column: plain-text view of whatever retrieve() returns.
        # Output needs to be in json format to be added as tool in HuggingChat
        with gr.Column():
            output = gr.Text(label="Retrieved Context", lines=10, show_copy_button=True)

    # Submit runs in two steps. An earlier single-step handler that called
    # retrieve directly (with api_name="retrieve") has been superseded by
    # this chain.
    submit_btn.click(
        # Step 1: turn the filter key/value boxes into a dict (or None) and
        # store it in the state component.
        fn=create_metadata_dict,
        inputs=[sources_input, sources_value],
        outputs=filter_metadata_state,
        queue=False,
    ).then(
        # Step 2: run the retrieval with the freshly stored filter.
        fn=retrieve,
        inputs=[query_input, collection_name, filter_metadata_state],
        outputs=output,
        queue=False,
    )
# Start the Gradio server when this file is run as a script. The
# mcp_server option is currently commented out, so MCP is not enabled here.
if __name__ == "__main__":
    ui.launch(
        server_name="0.0.0.0",
        server_port=7860,
        # mcp_server=True,
        show_error=True,
    )