"""Gradio app for adding texts to a Chroma vector store and searching them.

Search is two-stage: a dense retrieval pass with a sentence-transformers
embedding model fetches a candidate pool, then a cross-encoder reranker
re-scores each (query, text) pair and results are filtered by a threshold.
"""

import uuid

import chromadb
import gradio as gr
import torch
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma

# Size of the candidate pool handed to the reranker: k times this multiplier,
# but never fewer than _MIN_CANDIDATES, so reranking has room to reorder.
_CANDIDATE_MULTIPLIER = 2
_MIN_CANDIDATES = 10

# Set device to GPU if available, else CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Bi-encoder used by Chroma to embed stored texts and incoming queries.
embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": device},
)

# Chroma client persisted under ./chroma_db so the collection survives restarts.
chroma_client = chromadb.PersistentClient(path="./chroma_db")
vectorstore = Chroma(
    client=chroma_client,
    collection_name="text_collection",
    embedding_function=embedding_model,
)

# Cross-encoder that scores (query, text) pairs jointly for reranking.
reranker = HuggingFaceCrossEncoder(
    model_name="BAAI/bge-reranker-base",
    model_kwargs={"device": device},
)


def add_text_to_db(text):
    """Add a piece of text to the vector database.

    Args:
        text (str): The text to add. Whitespace-only input is rejected.

    Returns:
        str: Confirmation message containing the generated ID, or an error
        message if the text is empty.
    """
    if not text or not text.strip():
        return "Error: Text cannot be empty."

    doc_id = str(uuid.uuid4())
    # The text is mirrored into metadata["text"]; search_similar_texts reads
    # it from there first, so keep writing it for consistency with records
    # already in the collection.
    vectorstore.add_texts(
        texts=[text],
        metadatas=[{"text": text}],
        ids=[doc_id],
    )
    return f"Text added successfully with ID: {doc_id}"


def search_similar_texts(query, k, threshold):
    """Retrieve texts similar to *query*, rerank them, and filter by score.

    A candidate pool of max(2 * k, 10) documents is fetched by embedding
    similarity, every candidate is scored against the query by the
    cross-encoder in a single batch, and the top k by reranker score that
    also meet *threshold* are returned.

    Args:
        query (str): The search query.
        k (int): Number of results to return (integral floats such as 5.0
            are accepted).
        threshold (float): Minimum reranker score, between 0 and 1.
            NOTE(review): assumes the reranker emits scores in [0, 1]
            (sentence-transformers applies a sigmoid to single-label
            cross-encoders by default) -- confirm for the installed version.

    Returns:
        str: Formatted results with scores, "No such record." when nothing
        qualifies, or an error message for invalid input.
    """
    if not query or not query.strip():
        return "Error: Query cannot be empty."
    # Number widgets / callers may hand over 5.0 instead of 5.
    if isinstance(k, float) and k.is_integer():
        k = int(k)
    if not isinstance(k, int) or k < 1:
        return "Error: k must be a positive integer."
    # The chained comparison also rejects NaN, which `t < 0 or t > 1` accepts.
    if not isinstance(threshold, (int, float)) or not 0 <= threshold <= 1:
        return "Error: Threshold must be a number between 0 and 1."

    fetch_k = max(k * _CANDIDATE_MULTIPLIER, _MIN_CANDIDATES)
    docs = vectorstore.similarity_search(query, k=fetch_k)
    if not docs:
        # Also avoids calling the reranker with an empty batch.
        return "No such record."

    # Prefer the metadata copy written by add_text_to_db, but fall back to the
    # document body so documents without that key are scored on real text.
    texts = [doc.metadata.get("text", doc.page_content) for doc in docs]

    # One batched call: the API takes a list of (query, text) pairs and
    # returns one score per pair, in order.
    scores = [float(s) for s in reranker.score([(query, t) for t in texts])]

    # Stable descending sort by reranker score.
    ranked = sorted(zip(texts, scores), key=lambda item: item[1], reverse=True)

    # Ranked descending, so the entries meeting the threshold form a prefix
    # and the 1-based rank numbers stay contiguous.
    results = [
        f"Result {rank}:\nText: {text}\nScore: {score:.4f}\n"
        for rank, (text, score) in enumerate(ranked[:k], start=1)
        if score >= threshold
    ]
    if not results:
        return "No such record."
    return "\n".join(results)


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Semantic Search Pipeline")
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Add Text to Database")
            text_input = gr.Textbox(label="Enter text to add")
            add_button = gr.Button("Add Text")
            add_output = gr.Textbox(label="Result")
        with gr.Column():
            gr.Markdown("## Search Similar Texts")
            query_input = gr.Textbox(label="Enter search query")
            k_input = gr.Number(
                label="Number of results (k)",
                value=5,
                precision=0,
                minimum=1,
            )
            threshold_input = gr.Number(
                label="Similarity threshold (0 to 1)",
                value=0.5,
                minimum=0,
                maximum=1,
            )
            search_button = gr.Button("Search")
            search_output = gr.Textbox(label="Search Results")

    # Button actions (event listeners must be registered inside the Blocks context)
    add_button.click(
        fn=add_text_to_db,
        inputs=text_input,
        outputs=add_output,
    )
    search_button.click(
        fn=search_similar_texts,
        inputs=[query_input, k_input, threshold_input],
        outputs=search_output,
    )

# Launch Gradio app
if __name__ == "__main__":
    demo.launch()