Spaces:
Paused
Paused
# file: retriever.py | |
import faiss | |
import numpy as np | |
import torch | |
import re | |
from collections import defaultdict | |
from rank_bm25 import BM25Okapi | |
def tokenize_vi_for_bm25_setup(text): | |
"""Tokenize tiếng Việt đơn giản cho BM25.""" | |
text = text.lower() | |
text = re.sub(r'[^\w\s]', '', text) | |
return text.split() | |
def search_relevant_laws( | |
query_text, | |
embedding_model, | |
faiss_index, | |
chunks_data, | |
bm25_model, | |
k=5, | |
initial_k_multiplier=10, | |
rrf_k_constant=60 | |
): | |
""" | |
Thực hiện tìm kiếm lai (Hybrid Search) kết hợp Semantic Search (FAISS) và Keyword Search (BM25), | |
sau đó kết hợp kết quả bằng Reciprocal Rank Fusion (RRF) và tăng cường bằng metadata. | |
""" | |
if k <= 0: | |
print("Lỗi: k (số lượng kết quả) phải là số dương.") | |
return [] | |
print(f"\n🔎 Đang tìm kiếm (Hybrid) cho truy vấn: '{query_text}'") | |
query_lower = query_text.lower() | |
# Phân tích query | |
fine_keywords = r'tiền|phạt|bao nhiêu đồng|bao nhiêu tiền|mức phạt|xử phạt hành chính' | |
points_keywords = r'điểm|trừ điểm|mấy điểm|trừ bao nhiêu điểm|bằng lái|gplx' | |
query_mentions_fine = bool(re.search(fine_keywords, query_lower)) | |
query_mentions_points = bool(re.search(points_keywords, query_lower)) | |
needs_specific_metadata_filter = query_mentions_fine or query_mentions_points | |
print(f" Phân tích query: Đề cập tiền phạt? {query_mentions_fine}, Đề cập điểm trừ? {query_mentions_points}") | |
num_vectors_in_index = faiss_index.ntotal | |
if num_vectors_in_index == 0: | |
print("Lỗi: FAISS index rỗng.") | |
return [] | |
num_candidates_each_retriever = min(k * initial_k_multiplier, num_vectors_in_index) | |
# === 1. Semantic Search (FAISS) === | |
try: | |
query_embedding_tensor = embedding_model.encode([query_text], convert_to_tensor=True, device=embedding_model.device) | |
query_embedding_np = query_embedding_tensor.cpu().numpy().astype('float32') | |
faiss.normalize_L2(query_embedding_np) | |
semantic_scores_raw, semantic_indices_raw = faiss_index.search(query_embedding_np, num_candidates_each_retriever) | |
except Exception as e: | |
print(f"Lỗi khi tìm kiếm ngữ nghĩa (FAISS): {e}") | |
semantic_indices_raw = np.array([[]], dtype=int) | |
# === 2. Keyword Search (BM25) === | |
try: | |
tokenized_query_bm25 = tokenize_vi_for_bm25_setup(query_text) | |
all_bm25_scores = bm25_model.get_scores(tokenized_query_bm25) | |
bm25_results_with_indices = [{'index': i, 'score': score} for i, score in enumerate(all_bm25_scores) if score > 0] | |
bm25_results_with_indices.sort(key=lambda x: x['score'], reverse=True) | |
top_bm25_results = bm25_results_with_indices[:num_candidates_each_retriever] | |
except Exception as e: | |
print(f"Lỗi khi tìm kiếm từ khóa (BM25): {e}") | |
top_bm25_results = [] | |
# === 3. Result Fusion (RRF) === | |
rrf_scores = defaultdict(float) | |
all_retrieved_indices_set = set() | |
if semantic_indices_raw.size > 0: | |
for rank, doc_idx in enumerate(semantic_indices_raw[0]): | |
if 0 <= doc_idx < num_vectors_in_index: | |
rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank) | |
all_retrieved_indices_set.add(doc_idx) | |
for rank, item in enumerate(top_bm25_results): | |
doc_idx = item['index'] | |
rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank) | |
all_retrieved_indices_set.add(doc_idx) | |
fused_initial_results = [{'index': doc_idx, 'fused_score': rrf_scores[doc_idx]} for doc_idx in all_retrieved_indices_set] | |
fused_initial_results.sort(key=lambda x: x['fused_score'], reverse=True) | |
# === 4. Lọc và Tái xếp hạng cuối cùng === | |
final_processed_results = [] | |
num_to_process_metadata = min(len(fused_initial_results), num_candidates_each_retriever * 2) | |
for rank_idx, res_item in enumerate(fused_initial_results[:num_to_process_metadata]): | |
try: | |
result_index = res_item['index'] | |
base_score_from_fusion = res_item['fused_score'] | |
original_chunk = chunks_data[result_index] | |
original_metadata = original_chunk.get('metadata', {}) | |
# Thêm logic xử lý metadata boosting ở đây nếu cần... | |
# Hiện tại, chỉ trả về kết quả đã fusion. | |
# Bạn có thể copy lại toàn bộ logic boosting từ script gốc vào đây. | |
final_score_calculated = base_score_from_fusion # (Thêm boosting vào đây) | |
final_processed_results.append({ | |
"rank_after_fusion": rank_idx + 1, | |
"index": int(result_index), | |
"final_score": final_score_calculated, | |
"text": original_chunk.get('text', '*Không có text*'), | |
"metadata": original_metadata | |
}) | |
except Exception as e: | |
print(f"Lỗi khi xử lý ứng viên tại chỉ số {res_item.get('index')}: {e}") | |
final_processed_results.sort(key=lambda x: x["final_score"], reverse=True) | |
return final_processed_results[:k] |