# Spaces: Running  (page-status banner captured along with the source; not part of the program)
import uuid | |
import chromadb | |
import torch | |
from langchain.vectorstores import Chroma | |
from langchain.embeddings import HuggingFaceEmbeddings | |
from langchain_community.cross_encoders import HuggingFaceCrossEncoder | |
import gradio as gr | |
# Run the models on the GPU when PyTorch reports one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Embedding model used by the vector store to embed stored texts and queries.
embedding_model = HuggingFaceEmbeddings(
    model_kwargs={"device": device},
    model_name="sentence-transformers/all-MiniLM-L6-v2",
)

# ChromaDB client persisted under ./chroma_db, wrapped as a LangChain
# vector store bound to the "text_collection" collection.
chroma_client = chromadb.PersistentClient(path="./chroma_db")
vectorstore = Chroma(
    collection_name="text_collection",
    embedding_function=embedding_model,
    client=chroma_client,
)

# Cross-encoder used to rescore retrieved candidates against the query.
reranker = HuggingFaceCrossEncoder(
    model_kwargs={"device": device},
    model_name="BAAI/bge-reranker-base",
)
def add_text_to_db(text):
    """
    Store one piece of text in the vector database under a fresh UUID.

    The text is also copied into the entry's metadata under the "text" key.

    Args:
        text (str): The text to add.

    Returns:
        str: A confirmation message containing the new ID, or an error
            message when the text is empty or only whitespace.
    """
    # Guard clause: reject None, "" and whitespace-only input.
    if not (text and text.strip()):
        return "Error: Text cannot be empty."

    new_id = str(uuid.uuid4())
    entry = {
        "texts": [text],
        "metadatas": [{"text": text}],
        "ids": [new_id],
    }
    vectorstore.add_texts(**entry)

    return f"Text added successfully with ID: {new_id}"
def search_similar_texts(query, k, threshold):
    """
    Search the vector database for texts similar to a query and rerank them.

    A candidate pool larger than k is fetched from the vector store, every
    candidate is scored against the query by the cross-encoder reranker, and
    the k best-scoring candidates that reach the threshold are returned.

    Args:
        query (str): The search query.
        k (int): Maximum number of results to return.
        threshold (float): Minimum reranker score (0 to 1) a result must reach.

    Returns:
        str: Formatted results with their scores, "No such record." when
            nothing qualifies, or an "Error: ..." message for invalid input.
    """
    if not query or not query.strip():
        return "Error: Query cannot be empty."
    if not isinstance(k, int) or k < 1:
        return "Error: k must be a positive integer."
    # The chained comparison is False for NaN, so NaN is rejected here instead
    # of passing validation and then silently filtering out every result.
    if not isinstance(threshold, (int, float)) or not 0 <= threshold <= 1:
        return "Error: Threshold must be a number between 0 and 1."

    # Over-fetch so the reranker has more than k candidates to choose from.
    retriever = vectorstore.as_retriever(search_kwargs={"k": max(k * 2, 10)})
    docs = retriever.get_relevant_documents(query)
    if not docs:
        return "No such record."

    # Prefer the copy stored in metadata; fall back to the document body for
    # entries that were stored without a "text" metadata field.
    texts = [
        doc.metadata.get("text") or doc.page_content or "No text available"
        for doc in docs
    ]

    # Score every (query, text) pair in one batched call instead of one model
    # pass per document; converting via float() also normalises the result
    # whether the reranker hands back an array or a lazy iterable.
    raw_scores = reranker.score([(query, text) for text in texts])
    scores = [float(value) for value in raw_scores]

    # Keep the score on each document's metadata, as before.
    for doc, score in zip(docs, scores):
        doc.metadata["score"] = score

    # Highest score first; ties keep the retriever's original order.
    ranked = sorted(zip(texts, scores), key=lambda pair: pair[1], reverse=True)

    # Limit to the top k, then drop anything below the threshold.
    results = [
        f"Result {rank}:\nText: {text}\nScore: {score:.4f}\n"
        for rank, (text, score) in enumerate(ranked[:k], start=1)
        if score >= threshold
    ]
    if not results:
        return "No such record."
    return "\n".join(results)
# Gradio interface: a single row with one column for adding texts and one
# column for searching them.
with gr.Blocks() as demo:
    gr.Markdown("# Semantic Search Pipeline")

    with gr.Row():
        # First column: add a text to the database.
        with gr.Column():
            gr.Markdown("## Add Text to Database")
            text_input = gr.Textbox(label="Enter text to add")
            add_button = gr.Button("Add Text")
            add_output = gr.Textbox(label="Result")

        # Second column: query the database.
        with gr.Column():
            gr.Markdown("## Search Similar Texts")
            query_input = gr.Textbox(label="Enter search query")
            k_input = gr.Number(
                label="Number of results (k)",
                value=5,
                precision=0,
            )
            threshold_input = gr.Number(
                label="Similarity threshold (0 to 1)",
                value=0.5,
                minimum=0,
                maximum=1,
            )
            search_button = gr.Button("Search")
            search_output = gr.Textbox(label="Search Results")

    # Event wiring; it is kept inside the Blocks context because listeners
    # are registered on the app being built.
    add_button.click(add_text_to_db, text_input, add_output)
    search_button.click(
        search_similar_texts,
        [query_input, k_input, threshold_input],
        search_output,
    )

# Start the Gradio app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()