Spaces:
Runtime error
Runtime error
import logging | |
from typing import List, Dict, Any, Optional | |
import weaviate | |
import weaviate.classes.query as wvc_query | |
from concurrent.futures import ThreadPoolExecutor | |
from langchain_core.retrievers import BaseRetriever | |
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun | |
from langchain_core.documents import Document | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_core.output_parsers import StrOutputParser | |
from utils.process_data import infer_field, infer_entity_type | |
from utils.synonym_map import rewrite_query_with_legal_synonyms | |
import prompt_templete | |
logger = logging.getLogger(__name__) | |
class AdvancedLawRetriever(BaseRetriever): | |
client: weaviate.WeaviateClient | |
collection_name: str | |
llm: Any | |
reranker: Any | |
embeddings_model: Any | |
default_k: int = 3 | |
initial_k: int = 10 # Lấy nhiều ứng viên ban đầu | |
hybrid_search_alpha: float = 0.5 | |
doc_type_boost: float = 0.4 | |
class ConfigDict: | |
arbitrary_types_allowed = True | |
# === CÁC HÀM HELPER === | |
def _extract_searchable_keywords_with_llm(self, question: str) -> List[str]: | |
"""Sử dụng LLM để trích xuất các cụm từ khóa tìm kiếm hiệu quả.""" | |
keyword_extraction_prompt = ChatPromptTemplate.from_template(prompt_templete.KEYWORD_EXTRACTION_PROMPT) | |
keyword_chain = keyword_extraction_prompt | self.llm | StrOutputParser() | (lambda text: [k.strip() for k in text.strip().split("\n") if k.strip()]) | |
try: | |
keywords = keyword_chain.invoke({"question": question}) | |
# Luôn bao gồm cả câu hỏi gốc đã được viết lại làm một truy vấn để không mất ngữ cảnh | |
return [question] + keywords | |
except Exception as e: | |
logger.error(f"Failed to extract keywords: {e}") | |
return [question] | |
def _extract_and_build_filters(self, filters_dict: Dict[str, Any]) -> Optional[wvc_query.Filter]: | |
""" | |
CẢI TIẾN: Hàm này CHỈ nhận một dict và xây dựng đối tượng Filter. | |
Nó không còn nhiệm vụ suy luận nữa. | |
""" | |
if not filters_dict: | |
return None | |
filter_conditions = [] | |
for key, value in filters_dict.items(): | |
if value is None: | |
continue | |
# Logic xây dựng Filter | |
if key == "entity_type" and isinstance(value, list) and value: | |
filter_conditions.append(wvc_query.Filter.by_property(key).contains_any(value)) | |
elif isinstance(value, str): | |
filter_conditions.append(wvc_query.Filter.by_property(key).equal(value)) | |
# Thêm các điều kiện khác nếu cần | |
if not filter_conditions: | |
return None | |
return wvc_query.Filter.all_of(filter_conditions) if len(filter_conditions) > 1 else filter_conditions[0] | |
def _perform_hybrid_search(self, query: str, k: int, where_filter: Optional[wvc_query.Filter]) -> List[Document]: | |
# ... (giữ nguyên logic) ... | |
try: | |
collection = self.client.collections.get(self.collection_name) | |
query_vector = self.embeddings_model.embed_query(query) | |
response = collection.query.hybrid(query=query, vector=query_vector, limit=k, alpha=self.hybrid_search_alpha, filters=where_filter, return_metadata=wvc_query.MetadataQuery(score=True)) | |
docs = [Document(page_content=obj.properties.pop('text', ''), metadata={**obj.properties, 'hybrid_score': obj.metadata.score if obj.metadata else 0}) for obj in response.objects] | |
return docs | |
except Exception: return [] | |
# === HÀM CHÍNH _get_relevant_documents === | |
def _get_relevant_documents( | |
self, query: str, *, run_manager: CallbackManagerForRetrieverRun | |
) -> List[Document]: | |
# ================================================================= | |
# PHASE 0: PREPARATION - Chuẩn bị và làm giàu truy vấn | |
# ================================================================= | |
# 0.1. Đảm bảo an toàn cho input | |
safe_query = str(query) | |
logger.info(f"--- Starting Advanced Retrieval (FINAL) for Original Query: '{safe_query}' ---") | |
# 0.2. Trích xuất thông tin và ý định từ câu hỏi gốc | |
query_info = self._extract_query_info_with_intent(safe_query) | |
inferred_field = query_info.get("base_filters", {}).get("field") | |
preferred_doc_type = query_info.get("preferred_doc_type") | |
# 0.3. "Dịch" câu hỏi sang ngôn ngữ pháp lý bằng từ điển | |
rewritten_query = rewrite_query_with_legal_synonyms(safe_query, field=inferred_field) | |
if safe_query != rewritten_query: | |
logger.info(f"Query after Synonym Rewriting: '{rewritten_query}'") | |
# 0.4. Trích xuất từ khóa "vàng" bằng LLM từ câu hỏi đã được viết lại | |
search_terms = self._extract_searchable_keywords_with_llm(rewritten_query) | |
logger.info(f"Extracted {len(search_terms)} searchable terms: {search_terms}") | |
# 0.5. Xây dựng bộ lọc Weaviate từ thông tin đã trích xuất | |
base_weaviate_filter = self._extract_and_build_filters(query_info["base_filters"]) | |
# ================================================================= | |
# PHASE 1: RETRIEVAL - Truy xuất dữ liệu có fallback | |
# ================================================================= | |
def run_search_tasks(filters: Optional[wvc_query.Filter]) -> List[Document]: | |
"""Hàm nội bộ để thực hiện tìm kiếm song song.""" | |
docs = [] | |
with ThreadPoolExecutor(max_workers=len(search_terms) or 1) as executor: | |
futures = [executor.submit(self._perform_hybrid_search, term, self.initial_k, filters) for term in search_terms] | |
for future in futures: | |
try: docs.extend(future.result()) | |
except Exception as e: logger.error(f"A search task failed: {e}") | |
return docs | |
logger.info(f"--- Attempt 1: Searching with inferred filters: {base_weaviate_filter} ---") | |
retrieved_docs = run_search_tasks(base_weaviate_filter) | |
# Lọc trùng lặp | |
unique_docs_dict = {doc.page_content: doc for doc in retrieved_docs if isinstance(doc.page_content, str)} | |
# Cơ chế Fallback | |
if len(unique_docs_dict) < self.default_k and base_weaviate_filter is not None: | |
logger.warning("Initial search yielded few results. Retrying without any filters (fallback)...") | |
fallback_docs = run_search_tasks(None) | |
for doc in fallback_docs: | |
if isinstance(doc.page_content, str) and doc.page_content not in unique_docs_dict: | |
unique_docs_dict[doc.page_content] = doc | |
candidate_docs_list = list(unique_docs_dict.values()) | |
# ================================================================= | |
# PHASE 2: REFINEMENT - Tinh chỉnh, ưu tiên và xếp hạng kết quả | |
# ================================================================= | |
# 2.1. Intent-based Boosting: Tăng điểm dựa trên loại văn bản ưu tiên | |
final_candidates_for_rerank = candidate_docs_list | |
if preferred_doc_type: | |
logger.info(f"Applying INTENT-BASED BOOST for preferred type: '{preferred_doc_type}'") | |
docs_with_scores = [] | |
for doc in candidate_docs_list: | |
score = doc.metadata.get('hybrid_score', 0.5) | |
if doc.metadata.get("loai_van_ban") == preferred_doc_type: | |
score += self.doc_type_boost | |
else: | |
score -= 0.05 # Giảm nhẹ điểm của các loại không ưu tiên | |
docs_with_scores.append((doc, score)) | |
docs_with_scores.sort(key=lambda x: x[1], reverse=True) | |
final_candidates_for_rerank = [doc for doc, score in docs_with_scores] | |
logger.info(f"Found {len(final_candidates_for_rerank)} candidates for re-ranking.") | |
if not final_candidates_for_rerank: return [] | |
# 2.2. Cross-Encoder Re-ranking với Structured Context | |
logger.info("Applying Cross-Encoder re-ranking with STRUCTURED CONTEXT...") | |
docs_for_reranking = [] | |
for doc in final_candidates_for_rerank: | |
# Tạo chuỗi context giàu thông tin | |
structured_content = ( | |
f"Loại văn bản: {doc.metadata.get('loai_van_ban', 'N/A')}. " | |
f"Lĩnh vực: {doc.metadata.get('field', 'N/A')}. " | |
f"Đối tượng: {doc.metadata.get('entity_type', 'N/A')}.\n" | |
f"Nội dung trích từ {doc.metadata.get('title', 'N/A')}: {doc.page_content}" | |
) | |
docs_for_reranking.append({"original_doc": doc, "structured_content": structured_content}) | |
contents_to_rank = [item["structured_content"] for item in docs_for_reranking] | |
try: | |
# Kiểm tra nếu không có ứng viên nào để re-rank | |
if not final_candidates_for_rerank: | |
logger.warning("No candidates found for re-ranking. Returning empty list.") | |
return [] | |
logger.info(f"Applying Cross-Encoder re-ranking to {len(final_candidates_for_rerank)} candidates...") | |
# SỬ DỤNG PHƯƠNG THỨC .compress_documents() | |
# Đây là phương thức chính của CrossEncoderReranker trong LangChain. | |
# Nó nhận danh sách Document và một chuỗi query, sau đó trả về | |
# danh sách Document đã được sắp xếp lại và lọc theo top_n. | |
reranked_docs = self.reranker.compress_documents( | |
documents=final_candidates_for_rerank, | |
query=rewritten_query | |
) | |
# `reranked_docs` bây giờ là một danh sách các tài liệu tốt nhất, | |
# đã được lọc theo `top_n` bạn đã cấu hình trong get_reranker_compressor (mặc định là 4). | |
logger.info(f"Re-ranking complete. Found {len(reranked_docs)} relevant documents.") | |
# Thêm log chi tiết để debug | |
# ---- LOGGING ĐỂ DEBUG ---- | |
logger.info("--- Top documents after re-ranking ---") | |
for i, doc in enumerate(reranked_docs): | |
doc_info = { | |
"content_preview": doc.page_content[:200] + "...", | |
"metadata": doc.metadata, | |
# CrossEncoderReranker sẽ thêm điểm số vào metadata | |
"relevance_score": doc.metadata.get('relevance_score') | |
} | |
logger.info(f"Doc {i+1}: {doc_info}") | |
# ---- KẾT THÚC LOGGING ---- | |
return reranked_docs | |
except Exception as e: | |
import traceback | |
logger.error(f"An unexpected error occurred during re-ranking: {e}\n{traceback.format_exc()}") | |
# Fallback an toàn: Trả về kết quả từ VectorDB mà không re-rank, | |
# lấy theo số lượng top_k của retriever ban đầu. | |
logger.warning("Falling back to returning top results from vector search without re-ranking.") | |
# Giả sử bạn có self.default_k được cấu hình | |
return final_candidates_for_rerank[:getattr(self, 'default_k', 4)] | |
# Lấy lại các Document gốc theo thứ tự đã được re-rank | |
final_reranked_docs = [] | |
for rank_info in ranked_results_info: | |
original_doc = docs_for_reranking[rank_info['corpus_id']]["original_doc"] | |
original_doc.metadata['rerank_score'] = rank_info['score'] | |
final_reranked_docs.append(original_doc) | |
# 2.3. Log và Trả về kết quả cuối cùng | |
logger.info(f"--- Re-ranked down to {len(final_reranked_docs)} documents. Final results: ---") | |
for i, doc in enumerate(final_reranked_docs[:self.default_k]): | |
score_str = f"{doc.metadata.get('rerank_score', 0.0):.4f}" | |
logger.info(f" - RANK #{i+1} | ReRank Score: {score_str} | Source: {doc.metadata.get('source')}") | |
logger.info(f" CONTENT: {doc.page_content[:400]}...") # Log dài hơn | |
logger.info("-" * 25) | |
return final_reranked_docs[:self.default_k] | |
def _extract_query_info_with_intent(self, query: str) -> Dict[str, Any]: | |
""" | |
Trích xuất filter và xác định ý định của câu hỏi để ưu tiên loại văn bản. | |
""" | |
info = {"base_filters": {}, "preferred_doc_type": None} | |
query_lower = query.lower() | |
intent_found = False | |
# --- Giai đoạn 1: Xác định ý định và loại văn bản ưu tiên --- | |
# Đặt các intent cụ thể và dễ xung đột lên đầu. | |
# INTENT 1: Dữ kiện cụ thể (Diện tích, Dân số) - Rất cụ thể | |
if not intent_found and ( | |
any(kw in query_lower for kw in ["diện tích", "dân số", "số lượng"]) | |
and any(loc_kw in query_lower for loc_kw in ["tỉnh", "thành phố", "huyện", "xã", "việt nam"]) | |
): | |
info["preferred_doc_type"] = "NGHỊ QUYẾT" | |
logger.info("Intent detected: Specific Factual Data (Area, Population) -> Preferring 'NGHỊ QUYẾT'.") | |
intent_found = True | |
# Với intent này, chúng ta không áp dụng bộ lọc entity_type tự động. | |
# INTENT 2: Mức phạt / Chế tài - Rất cụ thể | |
if not intent_found and any(kw in query_lower for kw in [ | |
"phạt bao nhiêu", "mức xử phạt", "tiền phạt", "xử phạt", "phạt tiền", | |
"mức phạt", "số tiền phạt", "bị phạt", "chế phạt", "tước bằng", | |
"tạm giữ phương tiện", "tịch thu", "vi phạm" | |
]): | |
info["preferred_doc_type"] = "NGHỊ ĐỊNH" | |
logger.info("Intent detected: Sanction/Penalty -> Preferring 'NGHỊ ĐỊNH'.") | |
intent_found = True | |
# INTENT 3: Thủ tục hành chính - Cụ thể | |
if not intent_found and any(kw in query_lower for kw in [ | |
"thủ tục", "hồ sơ", "đăng ký", "cấp phép", "giấy phép", "chứng nhận", | |
"xin phép", "nộp hồ sơ", "thời hạn giải quyết", "lệ phí", "phí", "thẩm quyền", | |
"cơ quan nào", "nộp ở đâu", "ai cấp", "quy trình", "trình tự" | |
]): | |
info["preferred_doc_type"] = "THÔNG TƯ" | |
logger.info("Intent detected: Administrative Procedure -> Preferring 'THÔNG TƯ'.") | |
intent_found = True | |
# INTENT 4: Biểu mẫu / Tiêu chuẩn kỹ thuật - Cụ thể | |
if not intent_found and any(kw in query_lower for kw in [ | |
"biểu mẫu", "mẫu đơn", "form", "danh mục", "bảng biểu", "định mức", | |
"khung", "tiêu chuẩn kỹ thuật", "quy chuẩn", "đơn giá" | |
]): | |
info["preferred_doc_type"] = "THÔNG TƯ" | |
logger.info("Intent detected: Forms/Standards/Rates -> Preferring 'THÔNG TƯ'.") | |
intent_found = True | |
# INTENT 5: Tổ chức bộ máy cụ thể (Sáp nhập, thành lập) - Chung hơn | |
# Phải đứng trước intent về Nguyên tắc chung để bắt các case "sáp nhập đơn vị hành chính" | |
if not intent_found and any(kw in query_lower for kw in [ | |
"bộ máy", "cơ cấu tổ chức", "chức năng", "nhiệm vụ cụ thể", | |
"thành lập", "giải thể", "sáp nhập", "chia tách", "tái cơ cấu", | |
"bổ nhiệm", "miễn nhiệm" | |
]): | |
info["preferred_doc_type"] = "NGHỊ ĐỊNH" # Có thể là NGHỊ QUYẾT nữa | |
logger.info("Intent detected: Specific Organization Structure -> Preferring 'NGHỊ ĐỊNH'.") | |
intent_found = True | |
# INTENT 6: Nguyên tắc chung / Định nghĩa / Cấu trúc cơ bản - Rất chung | |
# Đặt cuối cùng để nó chỉ bắt các câu hỏi chung chung thực sự. | |
if not intent_found and any(kw in query_lower for kw in [ | |
"nguyên tắc", "quyền và nghĩa vụ", "quyền", "nghĩa vụ", "trách nhiệm", | |
"khái niệm", "định nghĩa", "là gì", "hiểu như thế nào", "nghiêm cấm", | |
"đơn vị hành chính", "cấp hành chính", "bao nhiêu cấp", "phân cấp", | |
"cấu trúc nhà nước", "hệ thống chính quyền" | |
]): | |
info["preferred_doc_type"] = "LUẬT" | |
logger.info("Intent detected: General Principle/Definition/Rights/Administrative Structure -> Preferring 'LUẬT'.") | |
intent_found = True | |
# --- Giai đoạn 2: Áp dụng bộ lọc (filters) một cách có kiểm soát --- | |
# 2.1. Lọc theo lĩnh vực (field) - Có thể áp dụng cho hầu hết các trường hợp | |
inferred_field = infer_field(query, None) | |
if inferred_field and inferred_field != "khac": | |
info["base_filters"]["field"] = inferred_field | |
logger.info(f"Applying filter: field = {inferred_field}") | |
# 2.2. Lọc theo loại thực thể (entity_type) - CHỈ áp dụng khi thực sự cần thiết | |
# Chúng ta KHÔNG áp dụng bộ lọc này cho các câu hỏi về dữ kiện địa lý | |
# hoặc các câu hỏi chung chung có thể không có entity_type rõ ràng. | |
# Ví dụ, chỉ áp dụng cho các câu hỏi về thủ tục, tổ chức, chế tài... | |
# Lấy lại intent đã xác định ở trên để quyết định | |
current_intent_log = logger.handlers[0].records[-1].msg if logger.handlers and logger.handlers[0].records else "" | |
# Chỉ áp dụng entity filter cho các intent này | |
applicable_intents_for_entity_filter = [ | |
"Sanction/Penalty", | |
"Administrative Procedure", | |
"Specific Organization Structure" | |
] | |
if any(intent_name in current_intent_log for intent_name in applicable_intents_for_entity_filter): | |
inferred_entities = infer_entity_type(query, inferred_field) | |
if inferred_entities: | |
info["base_filters"]["entity_type"] = inferred_entities | |
logger.info(f"Applying filter: entity_type = {inferred_entities}") | |
else: | |
logger.info("Skipping entity_type filter for the detected intent.") | |
logger.info(f"Final extracted query info: {info}") | |
return info |