File size: 5,253 Bytes
a53f1d8
b8e326d
a53f1d8
 
 
b8e326d
a53f1d8
 
 
 
 
 
 
 
b8e326d
a53f1d8
 
 
 
 
 
 
 
 
b8e326d
a53f1d8
 
b8e326d
a53f1d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8e326d
a53f1d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# file: retriever.py
import faiss
import numpy as np
import torch
import re
from collections import defaultdict
from rank_bm25 import BM25Okapi

def tokenize_vi_for_bm25_setup(text):
    """Tokenize tiếng Việt đơn giản cho BM25."""
    text = text.lower()
    text = re.sub(r'[^\w\s]', '', text)
    return text.split()

def search_relevant_laws(
        query_text,
        embedding_model,
        faiss_index,
        chunks_data,
        bm25_model,
        k=5,
        initial_k_multiplier=10,
        rrf_k_constant=60
    ):
    """
    Thực hiện tìm kiếm lai (Hybrid Search) kết hợp Semantic Search (FAISS) và Keyword Search (BM25),
    sau đó kết hợp kết quả bằng Reciprocal Rank Fusion (RRF) và tăng cường bằng metadata.
    """
    if k <= 0:
        print("Lỗi: k (số lượng kết quả) phải là số dương.")
        return []

    print(f"\n🔎 Đang tìm kiếm (Hybrid) cho truy vấn: '{query_text}'")
    query_lower = query_text.lower()

    # Phân tích query
    fine_keywords = r'tiền|phạt|bao nhiêu đồng|bao nhiêu tiền|mức phạt|xử phạt hành chính'
    points_keywords = r'điểm|trừ điểm|mấy điểm|trừ bao nhiêu điểm|bằng lái|gplx'
    query_mentions_fine = bool(re.search(fine_keywords, query_lower))
    query_mentions_points = bool(re.search(points_keywords, query_lower))
    needs_specific_metadata_filter = query_mentions_fine or query_mentions_points
    print(f"   Phân tích query: Đề cập tiền phạt? {query_mentions_fine}, Đề cập điểm trừ? {query_mentions_points}")

    num_vectors_in_index = faiss_index.ntotal
    if num_vectors_in_index == 0:
        print("Lỗi: FAISS index rỗng.")
        return []

    num_candidates_each_retriever = min(k * initial_k_multiplier, num_vectors_in_index)

    # === 1. Semantic Search (FAISS) ===
    try:
        query_embedding_tensor = embedding_model.encode([query_text], convert_to_tensor=True, device=embedding_model.device)
        query_embedding_np = query_embedding_tensor.cpu().numpy().astype('float32')
        faiss.normalize_L2(query_embedding_np)
        semantic_scores_raw, semantic_indices_raw = faiss_index.search(query_embedding_np, num_candidates_each_retriever)
    except Exception as e:
        print(f"Lỗi khi tìm kiếm ngữ nghĩa (FAISS): {e}")
        semantic_indices_raw = np.array([[]], dtype=int)

    # === 2. Keyword Search (BM25) ===
    try:
        tokenized_query_bm25 = tokenize_vi_for_bm25_setup(query_text)
        all_bm25_scores = bm25_model.get_scores(tokenized_query_bm25)
        bm25_results_with_indices = [{'index': i, 'score': score} for i, score in enumerate(all_bm25_scores) if score > 0]
        bm25_results_with_indices.sort(key=lambda x: x['score'], reverse=True)
        top_bm25_results = bm25_results_with_indices[:num_candidates_each_retriever]
    except Exception as e:
        print(f"Lỗi khi tìm kiếm từ khóa (BM25): {e}")
        top_bm25_results = []

    # === 3. Result Fusion (RRF) ===
    rrf_scores = defaultdict(float)
    all_retrieved_indices_set = set()

    if semantic_indices_raw.size > 0:
        for rank, doc_idx in enumerate(semantic_indices_raw[0]):
            if 0 <= doc_idx < num_vectors_in_index:
                rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank)
                all_retrieved_indices_set.add(doc_idx)

    for rank, item in enumerate(top_bm25_results):
        doc_idx = item['index']
        rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank)
        all_retrieved_indices_set.add(doc_idx)

    fused_initial_results = [{'index': doc_idx, 'fused_score': rrf_scores[doc_idx]} for doc_idx in all_retrieved_indices_set]
    fused_initial_results.sort(key=lambda x: x['fused_score'], reverse=True)

    # === 4. Lọc và Tái xếp hạng cuối cùng ===
    final_processed_results = []
    num_to_process_metadata = min(len(fused_initial_results), num_candidates_each_retriever * 2)

    for rank_idx, res_item in enumerate(fused_initial_results[:num_to_process_metadata]):
        try:
            result_index = res_item['index']
            base_score_from_fusion = res_item['fused_score']
            original_chunk = chunks_data[result_index]
            original_metadata = original_chunk.get('metadata', {})
            # Thêm logic xử lý metadata boosting ở đây nếu cần...
            # Hiện tại, chỉ trả về kết quả đã fusion.
            # Bạn có thể copy lại toàn bộ logic boosting từ script gốc vào đây.
            
            final_score_calculated = base_score_from_fusion # (Thêm boosting vào đây)

            final_processed_results.append({
                "rank_after_fusion": rank_idx + 1,
                "index": int(result_index),
                "final_score": final_score_calculated,
                "text": original_chunk.get('text', '*Không có text*'),
                "metadata": original_metadata
            })
        except Exception as e:
            print(f"Lỗi khi xử lý ứng viên tại chỉ số {res_item.get('index')}: {e}")

    final_processed_results.sort(key=lambda x: x["final_score"], reverse=True)
    return final_processed_results[:k]