Spaces:
Runtime error
Runtime error
File size: 12,269 Bytes
a6fd1a3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 |
import logging
from typing import List, Dict, Any, Optional
import weaviate
import weaviate.classes.query as wvc_query
from concurrent.futures import ThreadPoolExecutor
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from utils.process_data import infer_field, infer_entity_type
from utils.synonym_map import rewrite_query_with_legal_synonyms
import prompt_templete
logger = logging.getLogger(__name__)
class AdvancedLawRetriever(BaseRetriever):
client: weaviate.WeaviateClient
collection_name: str
llm: Any
reranker: Any
embeddings_model: Any
default_k: int = 5
initial_k: int = 15 # Lấy nhiều ứng viên ban đầu
hybrid_search_alpha: float = 0.5
doc_type_boost: float = 0.4
class ConfigDict:
arbitrary_types_allowed = True
# === CÁC HÀM HELPER ===
def _extract_searchable_keywords_with_llm(self, question: str) -> List[str]:
"""Sử dụng LLM để trích xuất các cụm từ khóa tìm kiếm hiệu quả."""
keyword_extraction_prompt = ChatPromptTemplate.from_template(prompt_templete.KEYWORD_EXTRACTION_PROMPT)
keyword_chain = keyword_extraction_prompt | self.llm | StrOutputParser() | (lambda text: [k.strip() for k in text.strip().split("\n") if k.strip()])
try:
keywords = keyword_chain.invoke({"question": question})
# Luôn bao gồm cả câu hỏi gốc đã được viết lại làm một truy vấn để không mất ngữ cảnh
return [question] + keywords
except Exception as e:
logger.error(f"Failed to extract keywords: {e}")
return [question]
def _extract_and_build_filters(self, filters_dict: Dict[str, Any]) -> Optional[wvc_query.Filter]:
"""
CẢI TIẾN: Hàm này CHỈ nhận một dict và xây dựng đối tượng Filter.
Nó không còn nhiệm vụ suy luận nữa.
"""
if not filters_dict:
return None
filter_conditions = []
for key, value in filters_dict.items():
if value is None:
continue
# Logic xây dựng Filter
if key == "entity_type" and isinstance(value, list) and value:
filter_conditions.append(wvc_query.Filter.by_property(key).contains_any(value))
elif isinstance(value, str):
filter_conditions.append(wvc_query.Filter.by_property(key).equal(value))
# Thêm các điều kiện khác nếu cần
if not filter_conditions:
return None
return wvc_query.Filter.all_of(filter_conditions) if len(filter_conditions) > 1 else filter_conditions[0]
def _perform_hybrid_search(self, query: str, k: int, where_filter: Optional[wvc_query.Filter]) -> List[Document]:
# ... (giữ nguyên logic) ...
try:
collection = self.client.collections.get(self.collection_name)
query_vector = self.embeddings_model.embed_query(query)
response = collection.query.hybrid(query=query, vector=query_vector, limit=k, alpha=self.hybrid_search_alpha, filters=where_filter, return_metadata=wvc_query.MetadataQuery(score=True))
docs = [Document(page_content=obj.properties.pop('text', ''), metadata={**obj.properties, 'hybrid_score': obj.metadata.score if obj.metadata else 0}) for obj in response.objects]
return docs
except Exception: return []
# === HÀM CHÍNH _get_relevant_documents ===
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
# =================================================================
# PHASE 0: PREPARATION - Chuẩn bị và làm giàu truy vấn
# =================================================================
# 0.1. Đảm bảo an toàn cho input
safe_query = str(query)
logger.info(f"--- Starting Advanced Retrieval (FINAL) for Original Query: '{safe_query}' ---")
# 0.2. Trích xuất thông tin và ý định từ câu hỏi gốc
query_info = self._extract_query_info_with_intent(safe_query)
inferred_field = query_info.get("base_filters", {}).get("field")
preferred_doc_type = query_info.get("preferred_doc_type")
# 0.3. "Dịch" câu hỏi sang ngôn ngữ pháp lý bằng từ điển
rewritten_query = rewrite_query_with_legal_synonyms(safe_query, field=inferred_field)
if safe_query != rewritten_query:
logger.info(f"Query after Synonym Rewriting: '{rewritten_query}'")
# 0.4. Trích xuất từ khóa "vàng" bằng LLM từ câu hỏi đã được viết lại
search_terms = self._extract_searchable_keywords_with_llm(rewritten_query)
logger.info(f"Extracted {len(search_terms)} searchable terms: {search_terms}")
# 0.5. Xây dựng bộ lọc Weaviate từ thông tin đã trích xuất
base_weaviate_filter = self._extract_and_build_filters(query_info["base_filters"])
# =================================================================
# PHASE 1: RETRIEVAL - Truy xuất dữ liệu có fallback
# =================================================================
def run_search_tasks(filters: Optional[wvc_query.Filter]) -> List[Document]:
"""Hàm nội bộ để thực hiện tìm kiếm song song."""
docs = []
with ThreadPoolExecutor(max_workers=len(search_terms) or 1) as executor:
futures = [executor.submit(self._perform_hybrid_search, term, self.initial_k, filters) for term in search_terms]
for future in futures:
try: docs.extend(future.result())
except Exception as e: logger.error(f"A search task failed: {e}")
return docs
logger.info(f"--- Attempt 1: Searching with inferred filters: {base_weaviate_filter} ---")
retrieved_docs = run_search_tasks(base_weaviate_filter)
# Lọc trùng lặp
unique_docs_dict = {doc.page_content: doc for doc in retrieved_docs if isinstance(doc.page_content, str)}
# Cơ chế Fallback
if len(unique_docs_dict) < self.default_k and base_weaviate_filter is not None:
logger.warning("Initial search yielded few results. Retrying without any filters (fallback)...")
fallback_docs = run_search_tasks(None)
for doc in fallback_docs:
if isinstance(doc.page_content, str) and doc.page_content not in unique_docs_dict:
unique_docs_dict[doc.page_content] = doc
candidate_docs_list = list(unique_docs_dict.values())
# =================================================================
# PHASE 2: REFINEMENT - Tinh chỉnh, ưu tiên và xếp hạng kết quả
# =================================================================
# 2.1. Intent-based Boosting: Tăng điểm dựa trên loại văn bản ưu tiên
final_candidates_for_rerank = candidate_docs_list
if preferred_doc_type:
logger.info(f"Applying INTENT-BASED BOOST for preferred type: '{preferred_doc_type}'")
docs_with_scores = []
for doc in candidate_docs_list:
score = doc.metadata.get('hybrid_score', 0.5)
if doc.metadata.get("loai_van_ban") == preferred_doc_type:
score += self.doc_type_boost
else:
score -= 0.05 # Giảm nhẹ điểm của các loại không ưu tiên
docs_with_scores.append((doc, score))
docs_with_scores.sort(key=lambda x: x[1], reverse=True)
final_candidates_for_rerank = [doc for doc, score in docs_with_scores]
logger.info(f"Found {len(final_candidates_for_rerank)} candidates for re-ranking.")
if not final_candidates_for_rerank: return []
# 2.2. Cross-Encoder Re-ranking với Structured Context
logger.info("Applying Cross-Encoder re-ranking with STRUCTURED CONTEXT...")
docs_for_reranking = []
for doc in final_candidates_for_rerank:
# Tạo chuỗi context giàu thông tin
structured_content = (
f"Loại văn bản: {doc.metadata.get('loai_van_ban', 'N/A')}. "
f"Lĩnh vực: {doc.metadata.get('field', 'N/A')}. "
f"Đối tượng: {doc.metadata.get('entity_type', 'N/A')}.\n"
f"Nội dung trích từ {doc.metadata.get('title', 'N/A')}: {doc.page_content}"
)
docs_for_reranking.append({"original_doc": doc, "structured_content": structured_content})
contents_to_rank = [item["structured_content"] for item in docs_for_reranking]
try:
# Sử dụng câu hỏi đã được viết lại để có ngữ cảnh tốt nhất
ranked_results_info = self.reranker.rank(rewritten_query, contents_to_rank, return_documents=False, top_k=self.default_k * 2) # Lấy nhiều hơn một chút
except Exception as e:
logger.error(f"Failed to re-rank with custom structured content: {e}. Falling back to default re-ranking.")
# Fallback về cách re-rank mặc định nếu có lỗi
reranked_docs = self.reranker.compress_documents(final_candidates_for_rerank, rewritten_query)
return reranked_docs[:self.default_k]
# Lấy lại các Document gốc theo thứ tự đã được re-rank
final_reranked_docs = []
for rank_info in ranked_results_info:
original_doc = docs_for_reranking[rank_info['corpus_id']]["original_doc"]
original_doc.metadata['rerank_score'] = rank_info['score']
final_reranked_docs.append(original_doc)
# 2.3. Log và Trả về kết quả cuối cùng
logger.info(f"--- Re-ranked down to {len(final_reranked_docs)} documents. Final results: ---")
for i, doc in enumerate(final_reranked_docs[:self.default_k]):
score_str = f"{doc.metadata.get('rerank_score', 0.0):.4f}"
logger.info(f" - RANK #{i+1} | ReRank Score: {score_str} | Source: {doc.metadata.get('source')}")
logger.info(f" CONTENT: {doc.page_content[:400]}...") # Log dài hơn
logger.info("-" * 25)
return final_reranked_docs[:self.default_k]
def _extract_query_info_with_intent(self, query: str) -> Dict[str, Any]:
"""
Trích xuất filter và xác định ý định của câu hỏi để ưu tiên loại văn bản.
"""
info = {"base_filters": {}, "preferred_doc_type": None}
query_lower = query.lower()
# 1. Suy luận field và entity
inferred_field = infer_field(query, None)
if inferred_field and inferred_field != "khac":
info["base_filters"]["field"] = inferred_field
inferred_entities = infer_entity_type(query, inferred_field)
if inferred_entities:
info["base_filters"]["entity_type"] = inferred_entities
# 2. XÁC ĐỊNH Ý ĐỊNH -> ƯU TIÊN LOẠI VĂN BẢN
# Nếu câu hỏi về MỨC PHẠT, ưu tiên tuyệt đối NGHỊ ĐỊNH
if any(kw in query_lower for kw in ["phạt bao nhiêu", "mức xử phạt", "tiền phạt", "xử phạt"]):
info["preferred_doc_type"] = "NGHỊ ĐỊNH"
logger.info("Intent detected: Sanction/Penalty -> Preferring 'NGHỊ ĐỊNH'.")
# Nếu câu hỏi về NGUYÊN TẮC CHUNG, QUYỀN, NGHĨA VỤ, ưu tiên LUẬT
elif any(kw in query_lower for kw in ["nguyên tắc", "quyền và nghĩa vụ", "cấm", "được phép", "khái niệm", "định nghĩa"]):
info["preferred_doc_type"] = "LUẬT"
logger.info("Intent detected: General Rule/Definition -> Preferring 'LUẬT'.")
logger.info(f"Extracted query info: {info}")
return info
|