Spaces:
Running
Running
import uuid | |
import chromadb | |
from langchain.vectorstores import Chroma | |
from langchain.embeddings import HuggingFaceEmbeddings | |
from langchain.retrievers import ContextualCompressionRetriever | |
from langchain.retrievers.document_compressors import CrossEncoderReranker | |
from langchain_community.cross_encoders import HuggingFaceCrossEncoder | |
import gradio as gr | |
# Shared sentence-embedding model; used to embed both stored texts and queries.
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# ChromaDB client persisted under ./chroma_db, exposed through LangChain's Chroma wrapper.
chroma_client = chromadb.PersistentClient(path="./chroma_db")
vectorstore = Chroma(
    client=chroma_client,
    collection_name="text_collection",
    embedding_function=embedding_model,
)

# Default pipeline sizes: the vector search over-fetches twice as many
# candidates as are finally kept, giving the reranker room to reorder them.
DEFAULT_TOP_N = 5
DEFAULT_FETCH_K = 2 * DEFAULT_TOP_N  # == 10

# Cross-encoder reranker: re-scores the retrieved candidates and keeps the best top_n.
reranker = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
compressor = CrossEncoderReranker(model=reranker, top_n=DEFAULT_TOP_N)

# Two-stage retriever: fetch DEFAULT_FETCH_K nearest neighbours, then rerank
# them down to DEFAULT_TOP_N.
retriever = vectorstore.as_retriever(search_kwargs={"k": DEFAULT_FETCH_K})
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=retriever
)
def add_text_to_db(text):
    """Store ``text`` in the vector database under a freshly generated UUID.

    Args:
        text (str): Content to embed and persist. ``None``, empty, or
            whitespace-only input is rejected without touching the database.

    Returns:
        str: A success message that includes the new document ID, or an
        error message when ``text`` is blank.
    """
    is_blank = not text or not text.strip()
    if is_blank:
        return "Error: Text cannot be empty."

    new_id = str(uuid.uuid4())
    # The raw text is duplicated into metadata under "text"; the search
    # handler reads it back from there when formatting results.
    vectorstore.add_texts(texts=[text], metadatas=[{"text": text}], ids=[new_id])
    return f"Text added successfully with ID: {new_id}"
def search_similar_texts(query, k):
    """Return the ``k`` stored texts most relevant to ``query``, formatted for display.

    Runs a two-stage search: a vector-similarity lookup fetches
    ``max(2 * k, 10)`` candidates, then the cross-encoder reranker keeps the
    best ``k`` of them.

    Args:
        query (str): The search query. ``None``, empty, or whitespace-only
            queries are rejected.
        k (int | float): Number of results to return. Integral floats such as
            ``5.0`` are accepted and converted to ``int``; ``bool``,
            non-integral, non-numeric, or values below 1 are rejected.

    Returns:
        str: Numbered results (text only, no similarity scores), or an error
        message, or "No results found." when nothing matches.
    """
    if not query or not query.strip():
        return "Error: Query cannot be empty."

    # Numeric UI widgets may deliver whole numbers as floats; accept those.
    if isinstance(k, float) and k.is_integer():
        k = int(k)
    # bool is a subclass of int, so reject it explicitly.
    if isinstance(k, bool) or not isinstance(k, int) or k < 1:
        return "Error: k must be a positive integer."

    # Build the pipeline per call instead of mutating the shared module-level
    # retriever/compressor: event handlers may run concurrently (depending on
    # the Gradio queue's concurrency settings), and one request's k must not
    # leak into another request's search.
    fetch_k = max(k * 2, 10)  # over-fetch 2k (at least 10) candidates for reranking
    pipeline = ContextualCompressionRetriever(
        base_compressor=CrossEncoderReranker(model=reranker, top_n=k),
        base_retriever=vectorstore.as_retriever(search_kwargs={"k": fetch_k}),
    )
    docs = pipeline.invoke(query)  # replaces deprecated get_relevant_documents
    if not docs:
        return "No results found."

    results = []
    for rank, doc in enumerate(docs[:k], start=1):  # never return more than k
        # Prefer the copy stored in metadata by add_text_to_db; fall back to
        # the document body for entries that were inserted another way.
        text = doc.metadata.get("text") or doc.page_content or "No text available"
        results.append(f"Result {rank}:\nText: {text}\n")
    return "\n".join(results)
# ---------------------------------------------------------------------------
# Gradio UI: two side-by-side panels, one for adding text, one for searching.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Semantic Search Pipeline")

    with gr.Row():
        # Left panel: ingest new text into the vector store.
        with gr.Column():
            gr.Markdown("## Add Text to Database")
            text_input = gr.Textbox(label="Enter text to add")
            add_button = gr.Button("Add Text")
            add_output = gr.Textbox(label="Result")

        # Right panel: query the store and display the reranked matches.
        with gr.Column():
            gr.Markdown("## Search Similar Texts")
            query_input = gr.Textbox(label="Enter search query")
            # precision=0 asks Gradio to round the value to a whole number.
            k_input = gr.Number(label="Number of results (k)", value=5, precision=0)
            search_button = gr.Button("Search")
            search_output = gr.Textbox(label="Search Results")

    # Event wiring, positional form: click(fn, inputs, outputs).
    add_button.click(add_text_to_db, text_input, add_output)
    search_button.click(search_similar_texts, [query_input, k_input], search_output)


if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    demo.launch()