juribot-backend / utils /AdvancedLawRetriever.py
entidi2608's picture
update: retriever
25e6e74
import logging
from typing import List, Dict, Any, Optional
import weaviate
import weaviate.classes.query as wvc_query
from concurrent.futures import ThreadPoolExecutor
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from utils.process_data import infer_field, infer_entity_type
from utils.synonym_map import rewrite_query_with_legal_synonyms
import prompt_templete
logger = logging.getLogger(__name__)
class AdvancedLawRetriever(BaseRetriever):
client: weaviate.WeaviateClient
collection_name: str
llm: Any
reranker: Any
embeddings_model: Any
default_k: int = 3
initial_k: int = 10 # Lấy nhiều ứng viên ban đầu
hybrid_search_alpha: float = 0.5
doc_type_boost: float = 0.4
class ConfigDict:
arbitrary_types_allowed = True
# === CÁC HÀM HELPER ===
def _extract_searchable_keywords_with_llm(self, question: str) -> List[str]:
"""Sử dụng LLM để trích xuất các cụm từ khóa tìm kiếm hiệu quả."""
keyword_extraction_prompt = ChatPromptTemplate.from_template(prompt_templete.KEYWORD_EXTRACTION_PROMPT)
keyword_chain = keyword_extraction_prompt | self.llm | StrOutputParser() | (lambda text: [k.strip() for k in text.strip().split("\n") if k.strip()])
try:
keywords = keyword_chain.invoke({"question": question})
# Luôn bao gồm cả câu hỏi gốc đã được viết lại làm một truy vấn để không mất ngữ cảnh
return [question] + keywords
except Exception as e:
logger.error(f"Failed to extract keywords: {e}")
return [question]
def _extract_and_build_filters(self, filters_dict: Dict[str, Any]) -> Optional[wvc_query.Filter]:
"""
CẢI TIẾN: Hàm này CHỈ nhận một dict và xây dựng đối tượng Filter.
Nó không còn nhiệm vụ suy luận nữa.
"""
if not filters_dict:
return None
filter_conditions = []
for key, value in filters_dict.items():
if value is None:
continue
# Logic xây dựng Filter
if key == "entity_type" and isinstance(value, list) and value:
filter_conditions.append(wvc_query.Filter.by_property(key).contains_any(value))
elif isinstance(value, str):
filter_conditions.append(wvc_query.Filter.by_property(key).equal(value))
# Thêm các điều kiện khác nếu cần
if not filter_conditions:
return None
return wvc_query.Filter.all_of(filter_conditions) if len(filter_conditions) > 1 else filter_conditions[0]
def _perform_hybrid_search(self, query: str, k: int, where_filter: Optional[wvc_query.Filter]) -> List[Document]:
# ... (giữ nguyên logic) ...
try:
collection = self.client.collections.get(self.collection_name)
query_vector = self.embeddings_model.embed_query(query)
response = collection.query.hybrid(query=query, vector=query_vector, limit=k, alpha=self.hybrid_search_alpha, filters=where_filter, return_metadata=wvc_query.MetadataQuery(score=True))
docs = [Document(page_content=obj.properties.pop('text', ''), metadata={**obj.properties, 'hybrid_score': obj.metadata.score if obj.metadata else 0}) for obj in response.objects]
return docs
except Exception: return []
# === HÀM CHÍNH _get_relevant_documents ===
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
# =================================================================
# PHASE 0: PREPARATION - Chuẩn bị và làm giàu truy vấn
# =================================================================
# 0.1. Đảm bảo an toàn cho input
safe_query = str(query)
logger.info(f"--- Starting Advanced Retrieval (FINAL) for Original Query: '{safe_query}' ---")
# 0.2. Trích xuất thông tin và ý định từ câu hỏi gốc
query_info = self._extract_query_info_with_intent(safe_query)
inferred_field = query_info.get("base_filters", {}).get("field")
preferred_doc_type = query_info.get("preferred_doc_type")
# 0.3. "Dịch" câu hỏi sang ngôn ngữ pháp lý bằng từ điển
rewritten_query = rewrite_query_with_legal_synonyms(safe_query, field=inferred_field)
if safe_query != rewritten_query:
logger.info(f"Query after Synonym Rewriting: '{rewritten_query}'")
# 0.4. Trích xuất từ khóa "vàng" bằng LLM từ câu hỏi đã được viết lại
search_terms = self._extract_searchable_keywords_with_llm(rewritten_query)
logger.info(f"Extracted {len(search_terms)} searchable terms: {search_terms}")
# 0.5. Xây dựng bộ lọc Weaviate từ thông tin đã trích xuất
base_weaviate_filter = self._extract_and_build_filters(query_info["base_filters"])
# =================================================================
# PHASE 1: RETRIEVAL - Truy xuất dữ liệu có fallback
# =================================================================
def run_search_tasks(filters: Optional[wvc_query.Filter]) -> List[Document]:
"""Hàm nội bộ để thực hiện tìm kiếm song song."""
docs = []
with ThreadPoolExecutor(max_workers=len(search_terms) or 1) as executor:
futures = [executor.submit(self._perform_hybrid_search, term, self.initial_k, filters) for term in search_terms]
for future in futures:
try: docs.extend(future.result())
except Exception as e: logger.error(f"A search task failed: {e}")
return docs
logger.info(f"--- Attempt 1: Searching with inferred filters: {base_weaviate_filter} ---")
retrieved_docs = run_search_tasks(base_weaviate_filter)
# Lọc trùng lặp
unique_docs_dict = {doc.page_content: doc for doc in retrieved_docs if isinstance(doc.page_content, str)}
# Cơ chế Fallback
if len(unique_docs_dict) < self.default_k and base_weaviate_filter is not None:
logger.warning("Initial search yielded few results. Retrying without any filters (fallback)...")
fallback_docs = run_search_tasks(None)
for doc in fallback_docs:
if isinstance(doc.page_content, str) and doc.page_content not in unique_docs_dict:
unique_docs_dict[doc.page_content] = doc
candidate_docs_list = list(unique_docs_dict.values())
# =================================================================
# PHASE 2: REFINEMENT - Tinh chỉnh, ưu tiên và xếp hạng kết quả
# =================================================================
# 2.1. Intent-based Boosting: Tăng điểm dựa trên loại văn bản ưu tiên
final_candidates_for_rerank = candidate_docs_list
if preferred_doc_type:
logger.info(f"Applying INTENT-BASED BOOST for preferred type: '{preferred_doc_type}'")
docs_with_scores = []
for doc in candidate_docs_list:
score = doc.metadata.get('hybrid_score', 0.5)
if doc.metadata.get("loai_van_ban") == preferred_doc_type:
score += self.doc_type_boost
else:
score -= 0.05 # Giảm nhẹ điểm của các loại không ưu tiên
docs_with_scores.append((doc, score))
docs_with_scores.sort(key=lambda x: x[1], reverse=True)
final_candidates_for_rerank = [doc for doc, score in docs_with_scores]
logger.info(f"Found {len(final_candidates_for_rerank)} candidates for re-ranking.")
if not final_candidates_for_rerank: return []
# 2.2. Cross-Encoder Re-ranking với Structured Context
logger.info("Applying Cross-Encoder re-ranking with STRUCTURED CONTEXT...")
docs_for_reranking = []
for doc in final_candidates_for_rerank:
# Tạo chuỗi context giàu thông tin
structured_content = (
f"Loại văn bản: {doc.metadata.get('loai_van_ban', 'N/A')}. "
f"Lĩnh vực: {doc.metadata.get('field', 'N/A')}. "
f"Đối tượng: {doc.metadata.get('entity_type', 'N/A')}.\n"
f"Nội dung trích từ {doc.metadata.get('title', 'N/A')}: {doc.page_content}"
)
docs_for_reranking.append({"original_doc": doc, "structured_content": structured_content})
contents_to_rank = [item["structured_content"] for item in docs_for_reranking]
try:
# Kiểm tra nếu không có ứng viên nào để re-rank
if not final_candidates_for_rerank:
logger.warning("No candidates found for re-ranking. Returning empty list.")
return []
logger.info(f"Applying Cross-Encoder re-ranking to {len(final_candidates_for_rerank)} candidates...")
# SỬ DỤNG PHƯƠNG THỨC .compress_documents()
# Đây là phương thức chính của CrossEncoderReranker trong LangChain.
# Nó nhận danh sách Document và một chuỗi query, sau đó trả về
# danh sách Document đã được sắp xếp lại và lọc theo top_n.
reranked_docs = self.reranker.compress_documents(
documents=final_candidates_for_rerank,
query=rewritten_query
)
# `reranked_docs` bây giờ là một danh sách các tài liệu tốt nhất,
# đã được lọc theo `top_n` bạn đã cấu hình trong get_reranker_compressor (mặc định là 4).
logger.info(f"Re-ranking complete. Found {len(reranked_docs)} relevant documents.")
# Thêm log chi tiết để debug
# ---- LOGGING ĐỂ DEBUG ----
logger.info("--- Top documents after re-ranking ---")
for i, doc in enumerate(reranked_docs):
doc_info = {
"content_preview": doc.page_content[:200] + "...",
"metadata": doc.metadata,
# CrossEncoderReranker sẽ thêm điểm số vào metadata
"relevance_score": doc.metadata.get('relevance_score')
}
logger.info(f"Doc {i+1}: {doc_info}")
# ---- KẾT THÚC LOGGING ----
return reranked_docs
except Exception as e:
import traceback
logger.error(f"An unexpected error occurred during re-ranking: {e}\n{traceback.format_exc()}")
# Fallback an toàn: Trả về kết quả từ VectorDB mà không re-rank,
# lấy theo số lượng top_k của retriever ban đầu.
logger.warning("Falling back to returning top results from vector search without re-ranking.")
# Giả sử bạn có self.default_k được cấu hình
return final_candidates_for_rerank[:getattr(self, 'default_k', 4)]
# Lấy lại các Document gốc theo thứ tự đã được re-rank
final_reranked_docs = []
for rank_info in ranked_results_info:
original_doc = docs_for_reranking[rank_info['corpus_id']]["original_doc"]
original_doc.metadata['rerank_score'] = rank_info['score']
final_reranked_docs.append(original_doc)
# 2.3. Log và Trả về kết quả cuối cùng
logger.info(f"--- Re-ranked down to {len(final_reranked_docs)} documents. Final results: ---")
for i, doc in enumerate(final_reranked_docs[:self.default_k]):
score_str = f"{doc.metadata.get('rerank_score', 0.0):.4f}"
logger.info(f" - RANK #{i+1} | ReRank Score: {score_str} | Source: {doc.metadata.get('source')}")
logger.info(f" CONTENT: {doc.page_content[:400]}...") # Log dài hơn
logger.info("-" * 25)
return final_reranked_docs[:self.default_k]
def _extract_query_info_with_intent(self, query: str) -> Dict[str, Any]:
"""
Trích xuất filter và xác định ý định của câu hỏi để ưu tiên loại văn bản.
"""
info = {"base_filters": {}, "preferred_doc_type": None}
query_lower = query.lower()
intent_found = False
# --- Giai đoạn 1: Xác định ý định và loại văn bản ưu tiên ---
# Đặt các intent cụ thể và dễ xung đột lên đầu.
# INTENT 1: Dữ kiện cụ thể (Diện tích, Dân số) - Rất cụ thể
if not intent_found and (
any(kw in query_lower for kw in ["diện tích", "dân số", "số lượng"])
and any(loc_kw in query_lower for loc_kw in ["tỉnh", "thành phố", "huyện", "xã", "việt nam"])
):
info["preferred_doc_type"] = "NGHỊ QUYẾT"
logger.info("Intent detected: Specific Factual Data (Area, Population) -> Preferring 'NGHỊ QUYẾT'.")
intent_found = True
# Với intent này, chúng ta không áp dụng bộ lọc entity_type tự động.
# INTENT 2: Mức phạt / Chế tài - Rất cụ thể
if not intent_found and any(kw in query_lower for kw in [
"phạt bao nhiêu", "mức xử phạt", "tiền phạt", "xử phạt", "phạt tiền",
"mức phạt", "số tiền phạt", "bị phạt", "chế phạt", "tước bằng",
"tạm giữ phương tiện", "tịch thu", "vi phạm"
]):
info["preferred_doc_type"] = "NGHỊ ĐỊNH"
logger.info("Intent detected: Sanction/Penalty -> Preferring 'NGHỊ ĐỊNH'.")
intent_found = True
# INTENT 3: Thủ tục hành chính - Cụ thể
if not intent_found and any(kw in query_lower for kw in [
"thủ tục", "hồ sơ", "đăng ký", "cấp phép", "giấy phép", "chứng nhận",
"xin phép", "nộp hồ sơ", "thời hạn giải quyết", "lệ phí", "phí", "thẩm quyền",
"cơ quan nào", "nộp ở đâu", "ai cấp", "quy trình", "trình tự"
]):
info["preferred_doc_type"] = "THÔNG TƯ"
logger.info("Intent detected: Administrative Procedure -> Preferring 'THÔNG TƯ'.")
intent_found = True
# INTENT 4: Biểu mẫu / Tiêu chuẩn kỹ thuật - Cụ thể
if not intent_found and any(kw in query_lower for kw in [
"biểu mẫu", "mẫu đơn", "form", "danh mục", "bảng biểu", "định mức",
"khung", "tiêu chuẩn kỹ thuật", "quy chuẩn", "đơn giá"
]):
info["preferred_doc_type"] = "THÔNG TƯ"
logger.info("Intent detected: Forms/Standards/Rates -> Preferring 'THÔNG TƯ'.")
intent_found = True
# INTENT 5: Tổ chức bộ máy cụ thể (Sáp nhập, thành lập) - Chung hơn
# Phải đứng trước intent về Nguyên tắc chung để bắt các case "sáp nhập đơn vị hành chính"
if not intent_found and any(kw in query_lower for kw in [
"bộ máy", "cơ cấu tổ chức", "chức năng", "nhiệm vụ cụ thể",
"thành lập", "giải thể", "sáp nhập", "chia tách", "tái cơ cấu",
"bổ nhiệm", "miễn nhiệm"
]):
info["preferred_doc_type"] = "NGHỊ ĐỊNH" # Có thể là NGHỊ QUYẾT nữa
logger.info("Intent detected: Specific Organization Structure -> Preferring 'NGHỊ ĐỊNH'.")
intent_found = True
# INTENT 6: Nguyên tắc chung / Định nghĩa / Cấu trúc cơ bản - Rất chung
# Đặt cuối cùng để nó chỉ bắt các câu hỏi chung chung thực sự.
if not intent_found and any(kw in query_lower for kw in [
"nguyên tắc", "quyền và nghĩa vụ", "quyền", "nghĩa vụ", "trách nhiệm",
"khái niệm", "định nghĩa", "là gì", "hiểu như thế nào", "nghiêm cấm",
"đơn vị hành chính", "cấp hành chính", "bao nhiêu cấp", "phân cấp",
"cấu trúc nhà nước", "hệ thống chính quyền"
]):
info["preferred_doc_type"] = "LUẬT"
logger.info("Intent detected: General Principle/Definition/Rights/Administrative Structure -> Preferring 'LUẬT'.")
intent_found = True
# --- Giai đoạn 2: Áp dụng bộ lọc (filters) một cách có kiểm soát ---
# 2.1. Lọc theo lĩnh vực (field) - Có thể áp dụng cho hầu hết các trường hợp
inferred_field = infer_field(query, None)
if inferred_field and inferred_field != "khac":
info["base_filters"]["field"] = inferred_field
logger.info(f"Applying filter: field = {inferred_field}")
# 2.2. Lọc theo loại thực thể (entity_type) - CHỈ áp dụng khi thực sự cần thiết
# Chúng ta KHÔNG áp dụng bộ lọc này cho các câu hỏi về dữ kiện địa lý
# hoặc các câu hỏi chung chung có thể không có entity_type rõ ràng.
# Ví dụ, chỉ áp dụng cho các câu hỏi về thủ tục, tổ chức, chế tài...
# Lấy lại intent đã xác định ở trên để quyết định
current_intent_log = logger.handlers[0].records[-1].msg if logger.handlers and logger.handlers[0].records else ""
# Chỉ áp dụng entity filter cho các intent này
applicable_intents_for_entity_filter = [
"Sanction/Penalty",
"Administrative Procedure",
"Specific Organization Structure"
]
if any(intent_name in current_intent_log for intent_name in applicable_intents_for_entity_filter):
inferred_entities = infer_entity_type(query, inferred_field)
if inferred_entities:
info["base_filters"]["entity_type"] = inferred_entities
logger.info(f"Applying filter: entity_type = {inferred_entities}")
else:
logger.info("Skipping entity_type filter for the detected intent.")
logger.info(f"Final extracted query info: {info}")
return info