File size: 3,754 Bytes
be6ede2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e06373
be6ede2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e06373
be6ede2
 
 
4e06373
be6ede2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import uuid
import chromadb
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
import gradio as gr

# --- Module-level pipeline setup: everything below runs once, at import ---

# Embedding model handed to the vector store as its embedding function.
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Persistent ChromaDB client rooted at ./chroma_db, wrapped in a LangChain
# Chroma vector store bound to the "text_collection" collection.
chroma_client = chromadb.PersistentClient(path="./chroma_db")
vectorstore = Chroma(
    client=chroma_client,
    collection_name="text_collection",
    embedding_function=embedding_model,
)

# Two-stage retrieval: the base retriever pulls candidates from the vector
# store, then the cross-encoder compressor reranks them and keeps top_n.
# The numbers here (k=10 candidates, top_n=5 kept) are only defaults:
# search_similar_texts() overwrites both settings on every call.
reranker = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
compressor = CrossEncoderReranker(model=reranker, top_n=5)
retriever = vectorstore.as_retriever(search_kwargs={"k": 10})  # Default: 2 * top_n candidates
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=retriever
)

def add_text_to_db(text):
    """
    Store one text snippet in the vector database under a fresh UUID.

    The snippet is passed to the vector store as the text itself and is also
    recorded under the ``"text"`` metadata key.

    Args:
        text (str): The snippet to store; must contain at least one
            non-whitespace character.

    Returns:
        str: A success message containing the generated ID, or an error
        message when ``text`` is empty or whitespace-only.
    """
    has_content = bool(text) and bool(text.strip())
    if not has_content:
        return "Error: Text cannot be empty."

    # A random UUID4 serves as the document's unique identifier.
    new_id = str(uuid.uuid4())

    # One-element batches: the store's API takes parallel lists.
    payload = {
        "texts": [text],
        "metadatas": [{"text": text}],
        "ids": [new_id],
    }
    vectorstore.add_texts(**payload)

    return f"Text added successfully with ID: {new_id}"

def search_similar_texts(query, k):
    """
    Search the vector database for texts similar to ``query`` and rerank them.

    ``max(2 * k, 10)`` candidates are requested from the vector store and the
    cross-encoder reranker is asked to keep the best ``k`` of them.

    Args:
        query (str): The search query; must contain at least one
            non-whitespace character.
        k (int): Number of results to return. Must be a positive integer.
            Whole-number floats such as ``5.0`` are accepted and converted;
            booleans are rejected.

    Returns:
        str: Formatted search results without similarity scores,
        ``"No results found."`` when nothing is retrieved, or an
        ``"Error: ..."`` message when an argument is invalid.
    """
    if not query or not query.strip():
        return "Error: Query cannot be empty."

    # A numeric UI widget may deliver a whole-number float; normalise it.
    if isinstance(k, float) and k.is_integer():
        k = int(k)
    # bool is a subclass of int, so it has to be rejected explicitly;
    # otherwise k=True would silently behave like k=1.
    if isinstance(k, bool) or not isinstance(k, int) or k < 1:
        return "Error: k must be a positive integer."

    # NOTE: these assignments reconfigure the shared module-level retriever
    # and compressor, so the settings persist after this call returns.
    retriever.search_kwargs["k"] = max(k * 2, 10)  # Retrieve 2k or at least 10
    compressor.top_n = k  # Rerank to top k
    docs = compression_retriever.get_relevant_documents(query)

    if not docs:
        return "No results found."

    # Format results without similarity scores.
    results = []
    for position, doc in enumerate(docs[:k], start=1):  # Return at most k
        text = doc.metadata.get("text")
        if text is None:
            # Documents stored without the "text" metadata key still carry
            # their content in page_content.
            text = doc.page_content or "No text available"
        results.append(f"Result {position}:\nText: {text}\n")

    return "\n".join(results)

# Gradio interface: a single page with two columns side by side in one row,
# one for adding text to the database and one for searching it.
with gr.Blocks() as demo:
    gr.Markdown("# Semantic Search Pipeline")

    with gr.Row():
        # First column: submit a new text and show the confirmation message.
        with gr.Column():
            gr.Markdown("## Add Text to Database")
            text_input = gr.Textbox(label="Enter text to add")
            add_button = gr.Button("Add Text")
            add_output = gr.Textbox(label="Result")

        # Second column: enter a query and the number of results wanted.
        with gr.Column():
            gr.Markdown("## Search Similar Texts")
            query_input = gr.Textbox(label="Enter search query")
            # precision=0 asks Gradio for a whole number; search_similar_texts
            # validates k as a positive integer.
            k_input = gr.Number(label="Number of results (k)", value=5, precision=0)
            search_button = gr.Button("Search")
            search_output = gr.Textbox(label="Search Results")

    # Button actions: each click passes the listed input values to the
    # handler and writes its returned string into the output textbox.
    add_button.click(
        fn=add_text_to_db,
        inputs=text_input,
        outputs=add_output
    )
    search_button.click(
        fn=search_similar_texts,
        inputs=[query_input, k_input],
        outputs=search_output
    )

# Launch the Gradio app only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()