Spaces:
Paused
Paused
File size: 5,253 Bytes
a53f1d8 b8e326d a53f1d8 b8e326d a53f1d8 b8e326d a53f1d8 b8e326d a53f1d8 b8e326d a53f1d8 b8e326d a53f1d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
# file: retriever.py
import faiss
import numpy as np
import torch
import re
from collections import defaultdict
from rank_bm25 import BM25Okapi
def tokenize_vi_for_bm25_setup(text):
"""Tokenize tiếng Việt đơn giản cho BM25."""
text = text.lower()
text = re.sub(r'[^\w\s]', '', text)
return text.split()
def search_relevant_laws(
query_text,
embedding_model,
faiss_index,
chunks_data,
bm25_model,
k=5,
initial_k_multiplier=10,
rrf_k_constant=60
):
"""
Thực hiện tìm kiếm lai (Hybrid Search) kết hợp Semantic Search (FAISS) và Keyword Search (BM25),
sau đó kết hợp kết quả bằng Reciprocal Rank Fusion (RRF) và tăng cường bằng metadata.
"""
if k <= 0:
print("Lỗi: k (số lượng kết quả) phải là số dương.")
return []
print(f"\n🔎 Đang tìm kiếm (Hybrid) cho truy vấn: '{query_text}'")
query_lower = query_text.lower()
# Phân tích query
fine_keywords = r'tiền|phạt|bao nhiêu đồng|bao nhiêu tiền|mức phạt|xử phạt hành chính'
points_keywords = r'điểm|trừ điểm|mấy điểm|trừ bao nhiêu điểm|bằng lái|gplx'
query_mentions_fine = bool(re.search(fine_keywords, query_lower))
query_mentions_points = bool(re.search(points_keywords, query_lower))
needs_specific_metadata_filter = query_mentions_fine or query_mentions_points
print(f" Phân tích query: Đề cập tiền phạt? {query_mentions_fine}, Đề cập điểm trừ? {query_mentions_points}")
num_vectors_in_index = faiss_index.ntotal
if num_vectors_in_index == 0:
print("Lỗi: FAISS index rỗng.")
return []
num_candidates_each_retriever = min(k * initial_k_multiplier, num_vectors_in_index)
# === 1. Semantic Search (FAISS) ===
try:
query_embedding_tensor = embedding_model.encode([query_text], convert_to_tensor=True, device=embedding_model.device)
query_embedding_np = query_embedding_tensor.cpu().numpy().astype('float32')
faiss.normalize_L2(query_embedding_np)
semantic_scores_raw, semantic_indices_raw = faiss_index.search(query_embedding_np, num_candidates_each_retriever)
except Exception as e:
print(f"Lỗi khi tìm kiếm ngữ nghĩa (FAISS): {e}")
semantic_indices_raw = np.array([[]], dtype=int)
# === 2. Keyword Search (BM25) ===
try:
tokenized_query_bm25 = tokenize_vi_for_bm25_setup(query_text)
all_bm25_scores = bm25_model.get_scores(tokenized_query_bm25)
bm25_results_with_indices = [{'index': i, 'score': score} for i, score in enumerate(all_bm25_scores) if score > 0]
bm25_results_with_indices.sort(key=lambda x: x['score'], reverse=True)
top_bm25_results = bm25_results_with_indices[:num_candidates_each_retriever]
except Exception as e:
print(f"Lỗi khi tìm kiếm từ khóa (BM25): {e}")
top_bm25_results = []
# === 3. Result Fusion (RRF) ===
rrf_scores = defaultdict(float)
all_retrieved_indices_set = set()
if semantic_indices_raw.size > 0:
for rank, doc_idx in enumerate(semantic_indices_raw[0]):
if 0 <= doc_idx < num_vectors_in_index:
rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank)
all_retrieved_indices_set.add(doc_idx)
for rank, item in enumerate(top_bm25_results):
doc_idx = item['index']
rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank)
all_retrieved_indices_set.add(doc_idx)
fused_initial_results = [{'index': doc_idx, 'fused_score': rrf_scores[doc_idx]} for doc_idx in all_retrieved_indices_set]
fused_initial_results.sort(key=lambda x: x['fused_score'], reverse=True)
# === 4. Lọc và Tái xếp hạng cuối cùng ===
final_processed_results = []
num_to_process_metadata = min(len(fused_initial_results), num_candidates_each_retriever * 2)
for rank_idx, res_item in enumerate(fused_initial_results[:num_to_process_metadata]):
try:
result_index = res_item['index']
base_score_from_fusion = res_item['fused_score']
original_chunk = chunks_data[result_index]
original_metadata = original_chunk.get('metadata', {})
# Thêm logic xử lý metadata boosting ở đây nếu cần...
# Hiện tại, chỉ trả về kết quả đã fusion.
# Bạn có thể copy lại toàn bộ logic boosting từ script gốc vào đây.
final_score_calculated = base_score_from_fusion # (Thêm boosting vào đây)
final_processed_results.append({
"rank_after_fusion": rank_idx + 1,
"index": int(result_index),
"final_score": final_score_calculated,
"text": original_chunk.get('text', '*Không có text*'),
"metadata": original_metadata
})
except Exception as e:
print(f"Lỗi khi xử lý ứng viên tại chỉ số {res_item.get('index')}: {e}")
final_processed_results.sort(key=lambda x: x["final_score"], reverse=True)
return final_processed_results[:k] |