File size: 12,269 Bytes
a6fd1a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
import logging
from typing import List, Dict, Any, Optional
import weaviate
import weaviate.classes.query as wvc_query
from concurrent.futures import ThreadPoolExecutor
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from utils.process_data import infer_field, infer_entity_type
from utils.synonym_map import rewrite_query_with_legal_synonyms
import prompt_templete

logger = logging.getLogger(__name__)

class AdvancedLawRetriever(BaseRetriever):
    client: weaviate.WeaviateClient
    collection_name: str
    llm: Any
    reranker: Any
    embeddings_model: Any

    default_k: int = 5
    initial_k: int = 15 # Lấy nhiều ứng viên ban đầu
    hybrid_search_alpha: float = 0.5
    doc_type_boost: float = 0.4

    class ConfigDict:
        arbitrary_types_allowed = True

    # === CÁC HÀM HELPER ===

    def _extract_searchable_keywords_with_llm(self, question: str) -> List[str]:
        """Sử dụng LLM để trích xuất các cụm từ khóa tìm kiếm hiệu quả."""
        keyword_extraction_prompt = ChatPromptTemplate.from_template(prompt_templete.KEYWORD_EXTRACTION_PROMPT)
        keyword_chain = keyword_extraction_prompt | self.llm | StrOutputParser() | (lambda text: [k.strip() for k in text.strip().split("\n") if k.strip()])
        try:
            keywords = keyword_chain.invoke({"question": question})
            # Luôn bao gồm cả câu hỏi gốc đã được viết lại làm một truy vấn để không mất ngữ cảnh
            return [question] + keywords
        except Exception as e:
            logger.error(f"Failed to extract keywords: {e}")
            return [question]

    def _extract_and_build_filters(self, filters_dict: Dict[str, Any]) -> Optional[wvc_query.Filter]:
        """
        CẢI TIẾN: Hàm này CHỈ nhận một dict và xây dựng đối tượng Filter.
        Nó không còn nhiệm vụ suy luận nữa.
        """
        if not filters_dict:
            return None

        filter_conditions = []
        for key, value in filters_dict.items():
            if value is None:
                continue

            # Logic xây dựng Filter
            if key == "entity_type" and isinstance(value, list) and value:
                filter_conditions.append(wvc_query.Filter.by_property(key).contains_any(value))
            elif isinstance(value, str):
                filter_conditions.append(wvc_query.Filter.by_property(key).equal(value))
            # Thêm các điều kiện khác nếu cần

        if not filter_conditions:
            return None

        return wvc_query.Filter.all_of(filter_conditions) if len(filter_conditions) > 1 else filter_conditions[0]


    def _perform_hybrid_search(self, query: str, k: int, where_filter: Optional[wvc_query.Filter]) -> List[Document]:
        # ... (giữ nguyên logic) ...
        try:
            collection = self.client.collections.get(self.collection_name)
            query_vector = self.embeddings_model.embed_query(query)
            response = collection.query.hybrid(query=query, vector=query_vector, limit=k, alpha=self.hybrid_search_alpha, filters=where_filter, return_metadata=wvc_query.MetadataQuery(score=True))
            docs = [Document(page_content=obj.properties.pop('text', ''), metadata={**obj.properties, 'hybrid_score': obj.metadata.score if obj.metadata else 0}) for obj in response.objects]
            return docs
        except Exception: return []



    # === HÀM CHÍNH _get_relevant_documents ===
    def _get_relevant_documents(
    self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:

        # =================================================================
        # PHASE 0: PREPARATION - Chuẩn bị và làm giàu truy vấn
        # =================================================================

        # 0.1. Đảm bảo an toàn cho input
        safe_query = str(query)
        logger.info(f"--- Starting Advanced Retrieval (FINAL) for Original Query: '{safe_query}' ---")

        # 0.2. Trích xuất thông tin và ý định từ câu hỏi gốc
        query_info = self._extract_query_info_with_intent(safe_query)
        inferred_field = query_info.get("base_filters", {}).get("field")
        preferred_doc_type = query_info.get("preferred_doc_type")

        # 0.3. "Dịch" câu hỏi sang ngôn ngữ pháp lý bằng từ điển
        rewritten_query = rewrite_query_with_legal_synonyms(safe_query, field=inferred_field)
        if safe_query != rewritten_query:
            logger.info(f"Query after Synonym Rewriting: '{rewritten_query}'")

        # 0.4. Trích xuất từ khóa "vàng" bằng LLM từ câu hỏi đã được viết lại
        search_terms = self._extract_searchable_keywords_with_llm(rewritten_query)
        logger.info(f"Extracted {len(search_terms)} searchable terms: {search_terms}")

        # 0.5. Xây dựng bộ lọc Weaviate từ thông tin đã trích xuất
        base_weaviate_filter = self._extract_and_build_filters(query_info["base_filters"])

        # =================================================================
        # PHASE 1: RETRIEVAL - Truy xuất dữ liệu có fallback
        # =================================================================

        def run_search_tasks(filters: Optional[wvc_query.Filter]) -> List[Document]:
            """Hàm nội bộ để thực hiện tìm kiếm song song."""
            docs = []
            with ThreadPoolExecutor(max_workers=len(search_terms) or 1) as executor:
                futures = [executor.submit(self._perform_hybrid_search, term, self.initial_k, filters) for term in search_terms]
                for future in futures:
                    try: docs.extend(future.result())
                    except Exception as e: logger.error(f"A search task failed: {e}")
            return docs

        logger.info(f"--- Attempt 1: Searching with inferred filters: {base_weaviate_filter} ---")
        retrieved_docs = run_search_tasks(base_weaviate_filter)

        # Lọc trùng lặp
        unique_docs_dict = {doc.page_content: doc for doc in retrieved_docs if isinstance(doc.page_content, str)}

        # Cơ chế Fallback
        if len(unique_docs_dict) < self.default_k and base_weaviate_filter is not None:
            logger.warning("Initial search yielded few results. Retrying without any filters (fallback)...")
            fallback_docs = run_search_tasks(None)
            for doc in fallback_docs:
                if isinstance(doc.page_content, str) and doc.page_content not in unique_docs_dict:
                    unique_docs_dict[doc.page_content] = doc

        candidate_docs_list = list(unique_docs_dict.values())

        # =================================================================
        # PHASE 2: REFINEMENT - Tinh chỉnh, ưu tiên và xếp hạng kết quả
        # =================================================================

        # 2.1. Intent-based Boosting: Tăng điểm dựa trên loại văn bản ưu tiên
        final_candidates_for_rerank = candidate_docs_list
        if preferred_doc_type:
            logger.info(f"Applying INTENT-BASED BOOST for preferred type: '{preferred_doc_type}'")
            docs_with_scores = []
            for doc in candidate_docs_list:
                score = doc.metadata.get('hybrid_score', 0.5)
                if doc.metadata.get("loai_van_ban") == preferred_doc_type:
                    score += self.doc_type_boost
                else:
                    score -= 0.05 # Giảm nhẹ điểm của các loại không ưu tiên
                docs_with_scores.append((doc, score))

            docs_with_scores.sort(key=lambda x: x[1], reverse=True)
            final_candidates_for_rerank = [doc for doc, score in docs_with_scores]

        logger.info(f"Found {len(final_candidates_for_rerank)} candidates for re-ranking.")
        if not final_candidates_for_rerank: return []

        # 2.2. Cross-Encoder Re-ranking với Structured Context
        logger.info("Applying Cross-Encoder re-ranking with STRUCTURED CONTEXT...")

        docs_for_reranking = []
        for doc in final_candidates_for_rerank:
            # Tạo chuỗi context giàu thông tin
            structured_content = (
                f"Loại văn bản: {doc.metadata.get('loai_van_ban', 'N/A')}. "
                f"Lĩnh vực: {doc.metadata.get('field', 'N/A')}. "
                f"Đối tượng: {doc.metadata.get('entity_type', 'N/A')}.\n"
                f"Nội dung trích từ {doc.metadata.get('title', 'N/A')}: {doc.page_content}"
            )
            docs_for_reranking.append({"original_doc": doc, "structured_content": structured_content})

        contents_to_rank = [item["structured_content"] for item in docs_for_reranking]

        try:
            # Sử dụng câu hỏi đã được viết lại để có ngữ cảnh tốt nhất
            ranked_results_info = self.reranker.rank(rewritten_query, contents_to_rank, return_documents=False, top_k=self.default_k * 2) # Lấy nhiều hơn một chút
        except Exception as e:
            logger.error(f"Failed to re-rank with custom structured content: {e}. Falling back to default re-ranking.")
            # Fallback về cách re-rank mặc định nếu có lỗi
            reranked_docs = self.reranker.compress_documents(final_candidates_for_rerank, rewritten_query)
            return reranked_docs[:self.default_k]

        # Lấy lại các Document gốc theo thứ tự đã được re-rank
        final_reranked_docs = []
        for rank_info in ranked_results_info:
            original_doc = docs_for_reranking[rank_info['corpus_id']]["original_doc"]
            original_doc.metadata['rerank_score'] = rank_info['score']
            final_reranked_docs.append(original_doc)

        # 2.3. Log và Trả về kết quả cuối cùng
        logger.info(f"--- Re-ranked down to {len(final_reranked_docs)} documents. Final results: ---")
        for i, doc in enumerate(final_reranked_docs[:self.default_k]):
            score_str = f"{doc.metadata.get('rerank_score', 0.0):.4f}"
            logger.info(f"  - RANK #{i+1} | ReRank Score: {score_str} | Source: {doc.metadata.get('source')}")
            logger.info(f"    CONTENT: {doc.page_content[:400]}...") # Log dài hơn
            logger.info("-" * 25)

        return final_reranked_docs[:self.default_k]

    def _extract_query_info_with_intent(self, query: str) -> Dict[str, Any]:
        """
        Trích xuất filter và xác định ý định của câu hỏi để ưu tiên loại văn bản.
        """
        info = {"base_filters": {}, "preferred_doc_type": None}
        query_lower = query.lower()

        # 1. Suy luận field và entity
        inferred_field = infer_field(query, None)
        if inferred_field and inferred_field != "khac":
            info["base_filters"]["field"] = inferred_field

        inferred_entities = infer_entity_type(query, inferred_field)
        if inferred_entities:
            info["base_filters"]["entity_type"] = inferred_entities

        # 2. XÁC ĐỊNH Ý ĐỊNH -> ƯU TIÊN LOẠI VĂN BẢN
        # Nếu câu hỏi về MỨC PHẠT, ưu tiên tuyệt đối NGHỊ ĐỊNH
        if any(kw in query_lower for kw in ["phạt bao nhiêu", "mức xử phạt", "tiền phạt", "xử phạt"]):
            info["preferred_doc_type"] = "NGHỊ ĐỊNH"
            logger.info("Intent detected: Sanction/Penalty -> Preferring 'NGHỊ ĐỊNH'.")
        # Nếu câu hỏi về NGUYÊN TẮC CHUNG, QUYỀN, NGHĨA VỤ, ưu tiên LUẬT
        elif any(kw in query_lower for kw in ["nguyên tắc", "quyền và nghĩa vụ", "cấm", "được phép", "khái niệm", "định nghĩa"]):
             info["preferred_doc_type"] = "LUẬT"
             logger.info("Intent detected: General Rule/Definition -> Preferring 'LUẬT'.")

        logger.info(f"Extracted query info: {info}")
        return info