Y-Mangoes commited on
Commit
16bd142
·
verified ·
1 Parent(s): cee4fc8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -15
app.py CHANGED
@@ -3,8 +3,6 @@ import chromadb
3
  import torch
4
  from langchain.vectorstores import Chroma
5
  from langchain.embeddings import HuggingFaceEmbeddings
6
- from langchain.retrievers import ContextualCompressionRetriever
7
- from langchain.retrievers.document_compressors import CrossEncoderReranker
8
  from langchain_community.cross_encoders import HuggingFaceCrossEncoder
9
  import gradio as gr
10
 
@@ -31,11 +29,6 @@ reranker = HuggingFaceCrossEncoder(
31
  model_name="BAAI/bge-reranker-base",
32
  model_kwargs={"device": device}
33
  )
34
- compressor = CrossEncoderReranker(model=reranker, top_n=5)
35
- retriever = vectorstore.as_retriever(search_kwargs={"k": 10}) # Retrieve 2k initially
36
- compression_retriever = ContextualCompressionRetriever(
37
- base_compressor=compressor, base_retriever=retriever
38
- )
39
 
40
  def add_text_to_db(text):
41
  """
@@ -84,20 +77,30 @@ def search_similar_texts(query, k, threshold):
84
  if not isinstance(threshold, (int, float)) or threshold < 0 or threshold > 1:
85
  return "Error: Threshold must be a number between 0 and 1."
86
 
87
- # Retrieve and rerank
88
- retriever.search_kwargs["k"] = max(k * 2, 10) # Retrieve 2k or at least 10
89
- compressor.top_n = k # Rerank to top k
90
- docs = compression_retriever.get_relevant_documents(query)
91
 
92
  if not docs:
93
  return "No such record."
94
 
95
- # Filter results by threshold
96
- results = []
97
- for i, doc in enumerate(docs[:k]): # Ensure at most k results
98
  text = doc.metadata.get("text", "No text available")
99
- score = doc.metadata.get("score", 0.0) # Reranker score
 
 
 
 
 
 
 
 
 
 
100
  if score >= threshold:
 
101
  results.append(f"Result {i+1}:\nText: {text}\nScore: {score:.4f}\n")
102
 
103
  if not results:
 
3
  import torch
4
  from langchain.vectorstores import Chroma
5
  from langchain.embeddings import HuggingFaceEmbeddings
 
 
6
  from langchain_community.cross_encoders import HuggingFaceCrossEncoder
7
  import gradio as gr
8
 
 
29
  model_name="BAAI/bge-reranker-base",
30
  model_kwargs={"device": device}
31
  )
 
 
 
 
 
32
 
33
  def add_text_to_db(text):
34
  """
 
77
  if not isinstance(threshold, (int, float)) or threshold < 0 or threshold > 1:
78
  return "Error: Threshold must be a number between 0 and 1."
79
 
80
+ # Retrieve initial documents
81
+ retriever = vectorstore.as_retriever(search_kwargs={"k": max(k * 2, 10)})
82
+ docs = retriever.get_relevant_documents(query)
 
83
 
84
  if not docs:
85
  return "No such record."
86
 
87
+ # Compute reranker scores
88
+ scored_docs = []
89
+ for doc in docs:
90
  text = doc.metadata.get("text", "No text available")
91
+ # Prepare input for reranker: list of [query, document] pairs
92
+ score = reranker.predict([[query, text]])[0]
93
+ doc.metadata["score"] = float(score)
94
+ scored_docs.append((doc, score))
95
+
96
+ # Sort by score in descending order
97
+ scored_docs.sort(key=lambda x: x[1], reverse=True)
98
+
99
+ # Filter by threshold and limit to k
100
+ results = []
101
+ for i, (doc, score) in enumerate(scored_docs[:k]):
102
  if score >= threshold:
103
+ text = doc.metadata.get("text", "No text available")
104
  results.append(f"Result {i+1}:\nText: {text}\nScore: {score:.4f}\n")
105
 
106
  if not results: