File size: 4,458 Bytes
be6ede2
 
ac90524
be6ede2
 
 
 
 
ac90524
 
 
 
be6ede2
ac90524
 
 
 
be6ede2
 
 
 
 
 
 
 
 
 
ac90524
 
 
 
be6ede2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cee4fc8
be6ede2
 
cee4fc8
be6ede2
 
 
 
cee4fc8
be6ede2
 
cee4fc8
be6ede2
 
 
 
 
 
 
cee4fc8
 
 
16bd142
 
 
be6ede2
 
cee4fc8
be6ede2
16bd142
 
 
be6ede2
6325e03
f27158e
16bd142
 
 
 
 
 
 
 
 
cee4fc8
16bd142
cee4fc8
 
 
 
be6ede2
cee4fc8
be6ede2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cee4fc8
be6ede2
 
 
 
 
 
 
 
 
 
 
cee4fc8
be6ede2
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import uuid
import chromadb
import torch
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
import gradio as gr

# Run the models on the GPU when PyTorch can see one; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Embedding model handed to the vector store below as its embedding function.
embedding_model = HuggingFaceEmbeddings(
    model_kwargs={"device": device},
    model_name="sentence-transformers/all-MiniLM-L6-v2",
)

# On-disk ChromaDB client (data lives under ./chroma_db) wrapped by the
# LangChain vector store that the rest of this module talks to.
chroma_client = chromadb.PersistentClient(path="./chroma_db")
vectorstore = Chroma(
    embedding_function=embedding_model,
    collection_name="text_collection",
    client=chroma_client,
)

# Cross-encoder used to rescore (query, text) pairs after retrieval.
reranker = HuggingFaceCrossEncoder(
    model_kwargs={"device": device},
    model_name="BAAI/bge-reranker-base",
)

def add_text_to_db(text):
    """
    Store one piece of text in the vector database under a fresh UUID.

    The text is also copied into the entry's metadata under the "text" key.

    Args:
        text (str): The text to store.

    Returns:
        str: A success message containing the generated ID, or an error
            message when the text is missing, empty or only whitespace.
    """
    # Guard clause: nothing to store for None, "" or whitespace-only input.
    if not (text and text.strip()):
        return "Error: Text cannot be empty."

    # A random UUID4 serves as the document's identifier in the store.
    doc_id = str(uuid.uuid4())

    vectorstore.add_texts(texts=[text], metadatas=[{"text": text}], ids=[doc_id])

    return f"Text added successfully with ID: {doc_id}"

def search_similar_texts(query, k, threshold):
    """
    Search the vector database for texts similar to a query and rerank them.

    Retrieves max(2 * k, 10) candidates from the vector store, scores every
    (query, candidate) pair with the reranker in one batched call, sorts the
    candidates by score in descending order, and reports at most k of them
    whose score is at least `threshold`.

    Args:
        query (str): The search query.
        k (int): Maximum number of results to return.
        threshold (float): Minimum reranker score (0 to 1) a result must
            reach to be included.

    Returns:
        str: Formatted search results with their scores, "No such record."
            when nothing qualifies, or an "Error: ..." message when an
            argument is invalid.
    """
    if not query or not query.strip():
        return "Error: Query cannot be empty."

    if not isinstance(k, int) or k < 1:
        return "Error: k must be a positive integer."

    # The chained comparison also rejects NaN, which would otherwise pass
    # both "< 0" and "> 1" checks and silently filter out every result.
    if not isinstance(threshold, (int, float)) or not 0 <= threshold <= 1:
        return "Error: Threshold must be a number between 0 and 1."

    # Over-fetch so the reranker sees more candidates than will be shown.
    retriever = vectorstore.as_retriever(search_kwargs={"k": max(k * 2, 10)})
    docs = retriever.get_relevant_documents(query)

    if not docs:
        return "No such record."

    texts = [doc.metadata.get("text", "No text available") for doc in docs]

    # score() expects a list of (query, passage) pairs. Passing all pairs at
    # once runs a single batched call instead of one call per document.
    raw_scores = reranker.score([(query, text) for text in texts])
    scores = [float(raw) for raw in raw_scores]

    # Keep the score on each retrieved document, as plain Python floats.
    for doc, score in zip(docs, scores):
        doc.metadata["score"] = score

    # sorted() is stable, so equally scored texts keep their retrieval order.
    ranked = sorted(zip(texts, scores), key=lambda pair: pair[1], reverse=True)

    # Only the k best candidates are considered; of those, keep the ones that
    # reach the threshold.
    results = [
        f"Result {rank}:\nText: {text}\nScore: {score:.4f}\n"
        for rank, (text, score) in enumerate(ranked[:k], start=1)
        if score >= threshold
    ]

    if not results:
        return "No such record."

    return "\n".join(results)

# Gradio interface: a single page with two side-by-side columns, one for
# adding texts to the database and one for searching them.
with gr.Blocks() as demo:
    gr.Markdown("# Semantic Search Pipeline")
    
    with gr.Row():
        # First column: submit a new text via add_text_to_db.
        with gr.Column():
            gr.Markdown("## Add Text to Database")
            text_input = gr.Textbox(label="Enter text to add")
            add_button = gr.Button("Add Text")
            add_output = gr.Textbox(label="Result")
            
        # Second column: query the database via search_similar_texts.
        with gr.Column():
            gr.Markdown("## Search Similar Texts")
            query_input = gr.Textbox(label="Enter search query")
            # precision=0 requests whole-number values; search_similar_texts
            # rejects any k that is not an int.
            k_input = gr.Number(label="Number of results (k)", value=5, precision=0)
            # Bounds mirror the 0..1 range that search_similar_texts validates.
            threshold_input = gr.Number(label="Similarity threshold (0 to 1)", value=0.5, minimum=0, maximum=1)
            search_button = gr.Button("Search")
            search_output = gr.Textbox(label="Search Results")
    
    # Button actions: each click passes the listed input values to the handler
    # and shows the returned string in the corresponding output textbox.
    add_button.click(
        fn=add_text_to_db,
        inputs=text_input,
        outputs=add_output
    )
    # Inputs are passed positionally in this order: query, k, threshold.
    search_button.click(
        fn=search_similar_texts,
        inputs=[query_input, k_input, threshold_input],
        outputs=search_output
    )

# Launch the Gradio app only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()