# Semantic-Search / app.py
# Author: Y-Mangoes — commit f27158e ("Update app.py")
import uuid
import chromadb
import torch
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
import gradio as gr
# Choose the compute device shared by both models: the CUDA GPU when
# PyTorch reports one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Sentence-transformers model used by the vectorstore to embed both the
# stored texts and incoming queries.
embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": device},
)

# On-disk Chroma database (./chroma_db), exposed through LangChain's
# vectorstore wrapper bound to the embedding model above.
chroma_client = chromadb.PersistentClient(path="./chroma_db")
vectorstore = Chroma(
    client=chroma_client,
    collection_name="text_collection",
    embedding_function=embedding_model,
)

# Cross-encoder used to rescore (query, text) pairs after retrieval.
reranker = HuggingFaceCrossEncoder(
    model_name="BAAI/bge-reranker-base",
    model_kwargs={"device": device},
)
def add_text_to_db(text):
    """
    Store one piece of text in the vector database.

    The text is passed to ``vectorstore.add_texts`` under a freshly
    generated UUID4 identifier. A copy of the raw text is also kept in the
    document metadata under the ``"text"`` key.

    Args:
        text (str): The text to store.

    Returns:
        str: A success message containing the new document ID, or an error
        message when ``text`` is empty or whitespace-only.
    """
    is_blank = not text or not text.strip()
    if is_blank:
        return "Error: Text cannot be empty."

    record_id = str(uuid.uuid4())
    vectorstore.add_texts(
        texts=[text],
        metadatas=[{"text": text}],
        ids=[record_id],
    )
    return "Text added successfully with ID: " + record_id
def search_similar_texts(query, k, threshold):
    """
    Search the vector database for texts similar to ``query`` and rerank them.

    Candidates are first retrieved by embedding similarity. The search
    fetches ``max(2 * k, 10)`` documents so the reranker has room to reorder
    them. All candidates are then scored in one batched cross-encoder call.
    Only the top ``k`` reranked results whose score is at least
    ``threshold`` are returned.

    Args:
        query (str): The search query.
        k (int): Number of results to return. Whole-number floats
            (e.g. ``5.0``) are accepted and converted to ``int``.
        threshold (float): Minimum reranker score, between 0 and 1 inclusive.

    Returns:
        str: Formatted results with their scores, "No such record." when
        nothing passes the threshold, or an "Error: ..." message for
        invalid input.
    """
    if not query or not query.strip():
        return "Error: Query cannot be empty."
    # Numeric widgets may deliver floats; treat whole-number floats as ints.
    if isinstance(k, float) and k.is_integer():
        k = int(k)
    if not isinstance(k, int) or k < 1:
        return "Error: k must be a positive integer."
    # The chained comparison also rejects NaN, which `< 0 or > 1` let through.
    if not isinstance(threshold, (int, float)) or not 0 <= threshold <= 1:
        return "Error: Threshold must be a number between 0 and 1."

    # Over-fetch candidates so the cross-encoder has something to reorder.
    # This is the same call a similarity-type retriever makes internally.
    docs = vectorstore.similarity_search(query, k=max(k * 2, 10))
    if not docs:
        return "No such record."

    # Prefer the copy stored in metadata (as add_text_to_db writes it), but
    # fall back to the document body instead of scoring a placeholder.
    texts = [doc.metadata.get("text", doc.page_content) for doc in docs]

    # One batched reranker call for all (query, text) pairs rather than a
    # separate model invocation per candidate.
    raw_scores = reranker.score([(query, text) for text in texts])
    scored = sorted(
        ((text, float(score)) for text, score in zip(texts, raw_scores)),
        key=lambda item: item[1],
        reverse=True,
    )

    # Results are sorted best-first, so anything below the threshold sits at
    # the tail of the top-k slice and the numbering stays contiguous.
    results = [
        f"Result {rank}:\nText: {text}\nScore: {score:.4f}\n"
        for rank, (text, score) in enumerate(scored[:k], start=1)
        if score >= threshold
    ]
    if not results:
        return "No such record."
    return "\n".join(results)
# Gradio UI: one row with an "add text" column on the left and a
# "search" column on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Semantic Search Pipeline")
    with gr.Row():
        # Left column: insert new texts into the vector database.
        with gr.Column():
            gr.Markdown("## Add Text to Database")
            text_input = gr.Textbox(label="Enter text to add")
            add_button = gr.Button("Add Text")
            add_output = gr.Textbox(label="Result")
        # Right column: query the database and show reranked matches.
        with gr.Column():
            gr.Markdown("## Search Similar Texts")
            query_input = gr.Textbox(label="Enter search query")
            # precision=0 asks Gradio to round k to a whole number.
            k_input = gr.Number(label="Number of results (k)", value=5, precision=0)
            threshold_input = gr.Number(
                label="Similarity threshold (0 to 1)",
                value=0.5,
                minimum=0,
                maximum=1,
            )
            search_button = gr.Button("Search")
            search_output = gr.Textbox(label="Search Results")

    # Connect each button to its handler; each handler returns one string.
    add_button.click(
        add_text_to_db,
        inputs=[text_input],
        outputs=[add_output],
    )
    search_button.click(
        search_similar_texts,
        inputs=[query_input, k_input, threshold_input],
        outputs=[search_output],
    )

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()