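"""Semantic-Search app.py

A small Gradio app that stores user-submitted texts in a persistent Chroma
vector database and serves semantic search: bi-encoder retrieval of candidate
documents followed by cross-encoder reranking.
"""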
import uuid
import chromadb
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
import gradio as gr
# Initialize embedding model
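# sentence-transformers/all-MiniLM-L6-v2 is a compact bi-encoder that maps each
# text to a 384-dimensional vector; a common speed/quality trade-off for retrieval.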
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# Initialize ChromaDB client and collection
chroma_client = chromadb.PersistentClient(path="./chroma_db")
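# PersistentClient stores the collection on disk under ./chroma_db, so added
# texts remain available as long as that directory is preserved.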
vectorstore = Chroma(
    client=chroma_client,
    collection_name="text_collection",
    embedding_function=embedding_model,
)
# Initialize reranker
reranker = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
compressor = CrossEncoderReranker(model=reranker, top_n=5)
retriever = vectorstore.as_retriever(search_kwargs={"k": 10})  # Default candidate pool; overridden to max(2*k, 10) per query
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=retriever
)
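# Two-stage retrieval: the vector-store retriever pulls a candidate pool by
# embedding similarity, then the bge-reranker-base cross-encoder scores each
# (query, document) pair and keeps only the top_n best matches.
#
# Example (sketch of direct use, bypassing the Gradio UI):
#   docs = compression_retriever.get_relevant_documents("what is semantic search?")
#   for doc in docs:
#       print(doc.page_content)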
def add_text_to_db(text):
    """
    Add a piece of text to the vector database.

    Args:
        text (str): The text to add.

    Returns:
        str: Confirmation message.
    """
    if not text or not text.strip():
        return "Error: Text cannot be empty."
    # Generate unique ID
    doc_id = str(uuid.uuid4())
    # Add text to vectorstore
    vectorstore.add_texts(
        texts=[text],
        metadatas=[{"text": text}],
        ids=[doc_id]
    )
    return f"Text added successfully with ID: {doc_id}"
def search_similar_texts(query, k):
    """
    Search for the top k similar texts in the vector database and rerank them.

    Args:
        query (int): The search query.
        k (int): Number of results to return.

    Returns:
        str: Formatted search results without similarity scores.
    """
    if not query or not query.strip():
        return "Error: Query cannot be empty."
    # Gradio's Number component may pass a float (e.g. 5.0), so coerce rather than
    # rejecting non-int values outright.
    if k is None or int(k) < 1:
        return "Error: k must be a positive integer."
    k = int(k)
    # Retrieve and rerank
    retriever.search_kwargs["k"] = max(k * 2, 10)  # Retrieve 2k candidates, or at least 10
    compressor.top_n = k  # Rerank down to the top k
    docs = compression_retriever.get_relevant_documents(query)
    if not docs:
        return "No results found."
    # Format results without similarity scores
    results = []
    for i, doc in enumerate(docs[:k]):  # Ensure we return at most k
        text = doc.metadata.get("text", "No text available")
        results.append(f"Result {i+1}:\nText: {text}\n")
    return "\n".join(results) or "No results found."
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Semantic Search Pipeline")
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Add Text to Database")
            text_input = gr.Textbox(label="Enter text to add")
            add_button = gr.Button("Add Text")
            add_output = gr.Textbox(label="Result")
        with gr.Column():
            gr.Markdown("## Search Similar Texts")
            query_input = gr.Textbox(label="Enter search query")
            k_input = gr.Number(label="Number of results (k)", value=5, precision=0)
            search_button = gr.Button("Search")
            search_output = gr.Textbox(label="Search Results")
    # Button actions
    add_button.click(
        fn=add_text_to_db,
        inputs=text_input,
        outputs=add_output
    )
    search_button.click(
        fn=search_similar_texts,
        inputs=[query_input, k_input],
        outputs=search_output
    )
# Launch Gradio app
if __name__ == "__main__":
    demo.launch()