Update rag_engine.py
rag_engine.py  CHANGED  (+9 -9)
@@ -232,7 +232,7 @@ def get_embedding(text):
         return np.zeros((1, 384), dtype=np.float32)
 
 @st.cache_data(ttl=900)
-def retrieve_passages(query, faiss_index, text_chunks, metadata_dict, top_k=5, similarity_threshold=0.5):
+def retrieve_passages(query, _faiss_index, _text_chunks, _metadata_dict, top_k=5, similarity_threshold=0.5):
     """Retrieve top-k most relevant passages using FAISS with metadata."""
     try:
         print(f"\n🔍 Retrieving passages for query: {query}")
@@ -241,7 +241,7 @@ def retrieve_passages(query, faiss_index, text_chunks, metadata_dict, top_k=5, s
         query_embedding = get_embedding(query)
 
         # Search in FAISS index
-        distances, indices = faiss_index.search(query_embedding, top_k * 2)
+        distances, indices = _faiss_index.search(query_embedding, top_k * 2)
 
         print(f"Found {len(distances[0])} potential matches")
         retrieved_passages = []
@@ -251,8 +251,8 @@ def retrieve_passages(query, faiss_index, text_chunks, metadata_dict, top_k=5, s
         # Process results
         for dist, idx in zip(distances[0], indices[0]):
             print(f"Distance: {dist:.4f}, Index: {idx}")
-            if idx in text_chunks and dist >= similarity_threshold:
-                title_with_txt, author, text = text_chunks[idx]
+            if idx in _text_chunks and dist >= similarity_threshold:
+                title_with_txt, author, text = _text_chunks[idx]
 
                 # Clean title
                 clean_title = title_with_txt.replace(".txt", "") if title_with_txt.endswith(".txt") else title_with_txt
@@ -263,7 +263,7 @@ def retrieve_passages(query, faiss_index, text_chunks, metadata_dict, top_k=5, s
                     continue
 
                 # Get metadata
-                metadata_entry = metadata_dict.get(clean_title, {})
+                metadata_entry = _metadata_dict.get(clean_title, {})
                 author = metadata_entry.get("Author", "Unknown")
                 publisher = metadata_entry.get("Publisher", "Unknown")
 
@@ -389,9 +389,9 @@ def process_query(query, top_k=5, word_limit=100):
     # Get relevant passages
     retrieved_context, retrieved_sources = retrieve_passages(
         query,
-        faiss_index,
-        text_chunks,
-        metadata_dict,
+        _faiss_index=faiss_index,
+        _text_chunks=text_chunks,
+        _metadata_dict=metadata_dict,
         top_k=top_k
     )
 
@@ -405,7 +405,7 @@ def process_query(query, top_k=5, word_limit=100):
     else:
         llm_answer_with_rag = "⚠️ No relevant context found."
 
-    # Clean up
+    # Clean up
     del retrieved_context, retrieved_sources
     gc.collect()
 
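A note on the pattern used here: Streamlit's @st.cache_data builds its cache key by hashing the decorated function's arguments, and it skips any parameter whose name starts with an underscore. Renaming faiss_index, text_chunks, and metadata_dict to _faiss_index, _text_chunks, and _metadata_dict therefore stops the decorator from trying to hash the FAISS index and the large chunk/metadata objects (which would otherwise raise UnhashableParamError), while query and top_k still drive the cache key. The sketch below shows the same convention in isolation; it is an illustrative example, not code from this repo, and the names cached_search and demo_index are assumptions.

import faiss
import numpy as np
import streamlit as st

@st.cache_data(ttl=900)
def cached_search(query, _index, top_k=5):
    # Only `query` and `top_k` are hashed into the cache key; `_index` is
    # excluded because of the leading underscore, so an unhashable FAISS
    # index can be passed in without errors.
    query_vec = np.zeros((1, 384), dtype=np.float32)  # placeholder for a real get_embedding(query)
    distances, indices = _index.search(query_vec, top_k)
    return distances, indices

# Usage: pass the unhashable object via the underscored name, as the updated
# call site in process_query now does.
demo_index = faiss.IndexFlatL2(384)
demo_index.add(np.random.rand(100, 384).astype(np.float32))
dists, ids = cached_search("example query", _index=demo_index)

One trade-off to keep in mind: because the underscored arguments are ignored by the cache, a rebuilt index or updated chunks will not invalidate existing entries on their own; here the ttl=900 on the decorator bounds how stale a cached result can get.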