Spaces:
Running
Running
File size: 4,458 Bytes
be6ede2 ac90524 be6ede2 ac90524 be6ede2 ac90524 be6ede2 ac90524 be6ede2 cee4fc8 be6ede2 cee4fc8 be6ede2 cee4fc8 be6ede2 cee4fc8 be6ede2 cee4fc8 16bd142 be6ede2 cee4fc8 be6ede2 16bd142 be6ede2 6325e03 f27158e 16bd142 cee4fc8 16bd142 cee4fc8 be6ede2 cee4fc8 be6ede2 cee4fc8 be6ede2 cee4fc8 be6ede2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import uuid
import chromadb
import torch
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
import gradio as gr
# Pick the compute device: CUDA when torch reports one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Sentence-embedding model handed to the vector store as its embedding function.
embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": device},
)

# Persistent Chroma client rooted at ./chroma_db, wrapped by the LangChain
# vector store bound to the "text_collection" collection.
chroma_client = chromadb.PersistentClient(path="./chroma_db")
vectorstore = Chroma(
    client=chroma_client,
    collection_name="text_collection",
    embedding_function=embedding_model,
)

# Cross-encoder that re-scores retrieved candidates against the query.
reranker = HuggingFaceCrossEncoder(
    model_name="BAAI/bge-reranker-base",
    model_kwargs={"device": device},
)
def add_text_to_db(text):
    """
    Store one text snippet in the vector database.

    The snippet is saved under a freshly generated UUID4 identifier, and a
    copy of it is kept in the entry's metadata under the "text" key.

    Args:
        text (str): The snippet to store.

    Returns:
        str: A success message containing the new ID, or an error message
            when the input is empty or whitespace-only.
    """
    has_content = bool(text) and bool(text.strip())
    if not has_content:
        return "Error: Text cannot be empty."

    # Every entry gets its own random identifier.
    new_id = str(uuid.uuid4())

    vectorstore.add_texts(texts=[text], metadatas=[{"text": text}], ids=[new_id])
    return f"Text added successfully with ID: {new_id}"
def search_similar_texts(query, k, threshold):
    """
    Retrieve candidate texts for a query, rerank them, and format the best ones.

    The vector store is asked for max(2 * k, 10) candidates. All candidates
    are then scored against the query by the cross-encoder reranker in one
    batched call, and the k highest-scoring candidates whose score is at
    least ``threshold`` are returned.

    Args:
        query (str): The search query.
        k (int): Maximum number of results to return. Whole-number floats
            such as 5.0 are accepted; booleans are rejected.
        threshold (float): Minimum reranker score, between 0 and 1 inclusive.

    Returns:
        str: Formatted results with their scores, "No such record." when
            nothing qualifies, or an "Error: ..." message for invalid input.
    """
    if not query or not query.strip():
        return "Error: Query cannot be empty."

    # Tolerate whole-number floats (e.g. 5.0) so numeric callers need not cast.
    if isinstance(k, float) and k.is_integer():
        k = int(k)
    # bool is a subclass of int, so it has to be excluded explicitly.
    if isinstance(k, bool) or not isinstance(k, int) or k < 1:
        return "Error: k must be a positive integer."

    # The chained comparison is False for NaN, so NaN is rejected as well.
    if not isinstance(threshold, (int, float)) or not (0 <= threshold <= 1):
        return "Error: Threshold must be a number between 0 and 1."

    # Over-fetch so the reranker sees more candidates than will be shown.
    retriever = vectorstore.as_retriever(search_kwargs={"k": max(k * 2, 10)})
    docs = retriever.get_relevant_documents(query)
    if not docs:
        return "No such record."

    # Score every (query, text) pair in a single reranker call instead of
    # invoking the model once per document.
    texts = [doc.metadata.get("text", "No text available") for doc in docs]
    raw_scores = reranker.score([(query, text) for text in texts])
    scores = [float(raw) for raw in raw_scores]

    for doc, score in zip(docs, scores):
        doc.metadata["score"] = score  # keep the score alongside the document

    # Highest score first; sorted() is stable, so ties keep retrieval order.
    ranked = sorted(zip(texts, scores), key=lambda pair: pair[1], reverse=True)

    # Numbering follows the position within the top-k slice.
    results = [
        f"Result {rank}:\nText: {text}\nScore: {score:.4f}\n"
        for rank, (text, score) in enumerate(ranked[:k], start=1)
        if score >= threshold
    ]
    if not results:
        return "No such record."
    return "\n".join(results)
# Gradio UI: one column for adding texts, one for searching them.
with gr.Blocks() as demo:
    gr.Markdown("# Semantic Search Pipeline")

    with gr.Row():
        # First column: insert a new text into the database.
        with gr.Column():
            gr.Markdown("## Add Text to Database")
            text_input = gr.Textbox(label="Enter text to add")
            add_button = gr.Button("Add Text")
            add_output = gr.Textbox(label="Result")

        # Second column: query the database and show the reranked matches.
        with gr.Column():
            gr.Markdown("## Search Similar Texts")
            query_input = gr.Textbox(label="Enter search query")
            k_input = gr.Number(
                label="Number of results (k)",
                value=5,
                precision=0,
            )
            threshold_input = gr.Number(
                label="Similarity threshold (0 to 1)",
                value=0.5,
                minimum=0,
                maximum=1,
            )
            search_button = gr.Button("Search")
            search_output = gr.Textbox(label="Search Results")

    # Connect each button to its handler function.
    add_button.click(fn=add_text_to_db, inputs=text_input, outputs=add_output)
    search_button.click(
        fn=search_similar_texts,
        inputs=[query_input, k_input, threshold_input],
        outputs=search_output,
    )

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()