juribot-backend / utils /AdvancedLawRetriever.py
entidi2608's picture
Initial backend deployment
a6fd1a3
raw
history blame
12.3 kB
import logging
from typing import List, Dict, Any, Optional
import weaviate
import weaviate.classes.query as wvc_query
from concurrent.futures import ThreadPoolExecutor
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from utils.process_data import infer_field, infer_entity_type
from utils.synonym_map import rewrite_query_with_legal_synonyms
import prompt_templete
logger = logging.getLogger(__name__)
class AdvancedLawRetriever(BaseRetriever):
client: weaviate.WeaviateClient
collection_name: str
llm: Any
reranker: Any
embeddings_model: Any
default_k: int = 5
initial_k: int = 15 # Lấy nhiều ứng viên ban đầu
hybrid_search_alpha: float = 0.5
doc_type_boost: float = 0.4
class ConfigDict:
arbitrary_types_allowed = True
# === CÁC HÀM HELPER ===
def _extract_searchable_keywords_with_llm(self, question: str) -> List[str]:
"""Sử dụng LLM để trích xuất các cụm từ khóa tìm kiếm hiệu quả."""
keyword_extraction_prompt = ChatPromptTemplate.from_template(prompt_templete.KEYWORD_EXTRACTION_PROMPT)
keyword_chain = keyword_extraction_prompt | self.llm | StrOutputParser() | (lambda text: [k.strip() for k in text.strip().split("\n") if k.strip()])
try:
keywords = keyword_chain.invoke({"question": question})
# Luôn bao gồm cả câu hỏi gốc đã được viết lại làm một truy vấn để không mất ngữ cảnh
return [question] + keywords
except Exception as e:
logger.error(f"Failed to extract keywords: {e}")
return [question]
def _extract_and_build_filters(self, filters_dict: Dict[str, Any]) -> Optional[wvc_query.Filter]:
"""
CẢI TIẾN: Hàm này CHỈ nhận một dict và xây dựng đối tượng Filter.
Nó không còn nhiệm vụ suy luận nữa.
"""
if not filters_dict:
return None
filter_conditions = []
for key, value in filters_dict.items():
if value is None:
continue
# Logic xây dựng Filter
if key == "entity_type" and isinstance(value, list) and value:
filter_conditions.append(wvc_query.Filter.by_property(key).contains_any(value))
elif isinstance(value, str):
filter_conditions.append(wvc_query.Filter.by_property(key).equal(value))
# Thêm các điều kiện khác nếu cần
if not filter_conditions:
return None
return wvc_query.Filter.all_of(filter_conditions) if len(filter_conditions) > 1 else filter_conditions[0]
def _perform_hybrid_search(self, query: str, k: int, where_filter: Optional[wvc_query.Filter]) -> List[Document]:
# ... (giữ nguyên logic) ...
try:
collection = self.client.collections.get(self.collection_name)
query_vector = self.embeddings_model.embed_query(query)
response = collection.query.hybrid(query=query, vector=query_vector, limit=k, alpha=self.hybrid_search_alpha, filters=where_filter, return_metadata=wvc_query.MetadataQuery(score=True))
docs = [Document(page_content=obj.properties.pop('text', ''), metadata={**obj.properties, 'hybrid_score': obj.metadata.score if obj.metadata else 0}) for obj in response.objects]
return docs
except Exception: return []
# === HÀM CHÍNH _get_relevant_documents ===
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
# =================================================================
# PHASE 0: PREPARATION - Chuẩn bị và làm giàu truy vấn
# =================================================================
# 0.1. Đảm bảo an toàn cho input
safe_query = str(query)
logger.info(f"--- Starting Advanced Retrieval (FINAL) for Original Query: '{safe_query}' ---")
# 0.2. Trích xuất thông tin và ý định từ câu hỏi gốc
query_info = self._extract_query_info_with_intent(safe_query)
inferred_field = query_info.get("base_filters", {}).get("field")
preferred_doc_type = query_info.get("preferred_doc_type")
# 0.3. "Dịch" câu hỏi sang ngôn ngữ pháp lý bằng từ điển
rewritten_query = rewrite_query_with_legal_synonyms(safe_query, field=inferred_field)
if safe_query != rewritten_query:
logger.info(f"Query after Synonym Rewriting: '{rewritten_query}'")
# 0.4. Trích xuất từ khóa "vàng" bằng LLM từ câu hỏi đã được viết lại
search_terms = self._extract_searchable_keywords_with_llm(rewritten_query)
logger.info(f"Extracted {len(search_terms)} searchable terms: {search_terms}")
# 0.5. Xây dựng bộ lọc Weaviate từ thông tin đã trích xuất
base_weaviate_filter = self._extract_and_build_filters(query_info["base_filters"])
# =================================================================
# PHASE 1: RETRIEVAL - Truy xuất dữ liệu có fallback
# =================================================================
def run_search_tasks(filters: Optional[wvc_query.Filter]) -> List[Document]:
"""Hàm nội bộ để thực hiện tìm kiếm song song."""
docs = []
with ThreadPoolExecutor(max_workers=len(search_terms) or 1) as executor:
futures = [executor.submit(self._perform_hybrid_search, term, self.initial_k, filters) for term in search_terms]
for future in futures:
try: docs.extend(future.result())
except Exception as e: logger.error(f"A search task failed: {e}")
return docs
logger.info(f"--- Attempt 1: Searching with inferred filters: {base_weaviate_filter} ---")
retrieved_docs = run_search_tasks(base_weaviate_filter)
# Lọc trùng lặp
unique_docs_dict = {doc.page_content: doc for doc in retrieved_docs if isinstance(doc.page_content, str)}
# Cơ chế Fallback
if len(unique_docs_dict) < self.default_k and base_weaviate_filter is not None:
logger.warning("Initial search yielded few results. Retrying without any filters (fallback)...")
fallback_docs = run_search_tasks(None)
for doc in fallback_docs:
if isinstance(doc.page_content, str) and doc.page_content not in unique_docs_dict:
unique_docs_dict[doc.page_content] = doc
candidate_docs_list = list(unique_docs_dict.values())
# =================================================================
# PHASE 2: REFINEMENT - Tinh chỉnh, ưu tiên và xếp hạng kết quả
# =================================================================
# 2.1. Intent-based Boosting: Tăng điểm dựa trên loại văn bản ưu tiên
final_candidates_for_rerank = candidate_docs_list
if preferred_doc_type:
logger.info(f"Applying INTENT-BASED BOOST for preferred type: '{preferred_doc_type}'")
docs_with_scores = []
for doc in candidate_docs_list:
score = doc.metadata.get('hybrid_score', 0.5)
if doc.metadata.get("loai_van_ban") == preferred_doc_type:
score += self.doc_type_boost
else:
score -= 0.05 # Giảm nhẹ điểm của các loại không ưu tiên
docs_with_scores.append((doc, score))
docs_with_scores.sort(key=lambda x: x[1], reverse=True)
final_candidates_for_rerank = [doc for doc, score in docs_with_scores]
logger.info(f"Found {len(final_candidates_for_rerank)} candidates for re-ranking.")
if not final_candidates_for_rerank: return []
# 2.2. Cross-Encoder Re-ranking với Structured Context
logger.info("Applying Cross-Encoder re-ranking with STRUCTURED CONTEXT...")
docs_for_reranking = []
for doc in final_candidates_for_rerank:
# Tạo chuỗi context giàu thông tin
structured_content = (
f"Loại văn bản: {doc.metadata.get('loai_van_ban', 'N/A')}. "
f"Lĩnh vực: {doc.metadata.get('field', 'N/A')}. "
f"Đối tượng: {doc.metadata.get('entity_type', 'N/A')}.\n"
f"Nội dung trích từ {doc.metadata.get('title', 'N/A')}: {doc.page_content}"
)
docs_for_reranking.append({"original_doc": doc, "structured_content": structured_content})
contents_to_rank = [item["structured_content"] for item in docs_for_reranking]
try:
# Sử dụng câu hỏi đã được viết lại để có ngữ cảnh tốt nhất
ranked_results_info = self.reranker.rank(rewritten_query, contents_to_rank, return_documents=False, top_k=self.default_k * 2) # Lấy nhiều hơn một chút
except Exception as e:
logger.error(f"Failed to re-rank with custom structured content: {e}. Falling back to default re-ranking.")
# Fallback về cách re-rank mặc định nếu có lỗi
reranked_docs = self.reranker.compress_documents(final_candidates_for_rerank, rewritten_query)
return reranked_docs[:self.default_k]
# Lấy lại các Document gốc theo thứ tự đã được re-rank
final_reranked_docs = []
for rank_info in ranked_results_info:
original_doc = docs_for_reranking[rank_info['corpus_id']]["original_doc"]
original_doc.metadata['rerank_score'] = rank_info['score']
final_reranked_docs.append(original_doc)
# 2.3. Log và Trả về kết quả cuối cùng
logger.info(f"--- Re-ranked down to {len(final_reranked_docs)} documents. Final results: ---")
for i, doc in enumerate(final_reranked_docs[:self.default_k]):
score_str = f"{doc.metadata.get('rerank_score', 0.0):.4f}"
logger.info(f" - RANK #{i+1} | ReRank Score: {score_str} | Source: {doc.metadata.get('source')}")
logger.info(f" CONTENT: {doc.page_content[:400]}...") # Log dài hơn
logger.info("-" * 25)
return final_reranked_docs[:self.default_k]
def _extract_query_info_with_intent(self, query: str) -> Dict[str, Any]:
"""
Trích xuất filter và xác định ý định của câu hỏi để ưu tiên loại văn bản.
"""
info = {"base_filters": {}, "preferred_doc_type": None}
query_lower = query.lower()
# 1. Suy luận field và entity
inferred_field = infer_field(query, None)
if inferred_field and inferred_field != "khac":
info["base_filters"]["field"] = inferred_field
inferred_entities = infer_entity_type(query, inferred_field)
if inferred_entities:
info["base_filters"]["entity_type"] = inferred_entities
# 2. XÁC ĐỊNH Ý ĐỊNH -> ƯU TIÊN LOẠI VĂN BẢN
# Nếu câu hỏi về MỨC PHẠT, ưu tiên tuyệt đối NGHỊ ĐỊNH
if any(kw in query_lower for kw in ["phạt bao nhiêu", "mức xử phạt", "tiền phạt", "xử phạt"]):
info["preferred_doc_type"] = "NGHỊ ĐỊNH"
logger.info("Intent detected: Sanction/Penalty -> Preferring 'NGHỊ ĐỊNH'.")
# Nếu câu hỏi về NGUYÊN TẮC CHUNG, QUYỀN, NGHĨA VỤ, ưu tiên LUẬT
elif any(kw in query_lower for kw in ["nguyên tắc", "quyền và nghĩa vụ", "cấm", "được phép", "khái niệm", "định nghĩa"]):
info["preferred_doc_type"] = "LUẬT"
logger.info("Intent detected: General Rule/Definition -> Preferring 'LUẬT'.")
logger.info(f"Extracted query info: {info}")
return info