Spaces:
Paused
Paused
# retrieval_handler.py | |
# Chịu trách nhiệm cho mọi logic liên quan đến việc truy xuất thông tin (Retrieval). | |
import json | |
import re | |
import numpy as np | |
import faiss | |
from collections import defaultdict | |
from typing import List, Dict, Any, Optional | |
from utils import tokenize_vi_simple # Import từ file utils.py | |
# --- HÀM XỬ LÝ DỮ LIỆU --- | |
def process_law_data_to_chunks(structured_data: Any) -> List[Dict]: | |
"""Làm phẳng dữ liệu luật có cấu trúc thành danh sách các chunks.""" | |
flat_list = [] | |
articles = [structured_data] if isinstance(structured_data, dict) else structured_data | |
for article_data in articles: | |
if not isinstance(article_data, dict): continue | |
# (Logic xử lý chi tiết của bạn ở đây... đã được rút gọn để dễ đọc) | |
# Giả sử logic này hoạt động đúng như bạn đã thiết kế | |
# và trả về một danh sách các chunk, mỗi chunk là một dict có "text" và "metadata". | |
# Để đảm bảo, tôi sẽ thêm một phiên bản đơn giản hóa ở đây. | |
clauses = article_data.get("clauses", []) | |
for clause in clauses: | |
points = clause.get("points_in_clause", []) | |
if points: | |
for point in points: | |
text = point.get("point_text_original") | |
if text: | |
flat_list.append({"text": text, "metadata": {"article": article_data.get("article"), "clause": clause.get("clause_number"), "point": point.get("point_id")}}) | |
else: | |
text = clause.get("clause_text_original") | |
if text: | |
flat_list.append({"text": text, "metadata": {"article": article_data.get("article"), "clause": clause.get("clause_number")}}) | |
return flat_list | |
# --- HÀM TÌM KIẾM --- | |
def search_relevant_laws( | |
query_text: str, | |
embedding_model, | |
faiss_index, | |
chunks_data: List[Dict], | |
bm25_model, | |
k: int = 5, | |
rrf_k_constant: int = 60 | |
) -> List[Dict]: | |
""" | |
Thực hiện Hybrid Search (Semantic + Keyword) với RRF để tìm các chunk liên quan. | |
""" | |
print(f"🔎 Bắt đầu tìm kiếm cho: '{query_text}'") | |
# 1. Semantic Search (FAISS) | |
query_embedding = embedding_model.encode([query_text], convert_to_tensor=True) | |
query_embedding_np = query_embedding.cpu().numpy().astype('float32') | |
faiss.normalize_L2(query_embedding_np) | |
num_candidates = min(k * 10, faiss_index.ntotal) | |
_, semantic_indices = faiss_index.search(query_embedding_np, num_candidates) | |
# 2. Keyword Search (BM25) | |
tokenized_query = tokenize_vi_simple(query_text) | |
bm25_scores = bm25_model.get_scores(tokenized_query) | |
bm25_results = sorted(enumerate(bm25_scores), key=lambda x: x[1], reverse=True)[:num_candidates] | |
# 3. Reciprocal Rank Fusion (RRF) | |
rrf_scores = defaultdict(float) | |
if semantic_indices.size > 0: | |
for rank, doc_idx in enumerate(semantic_indices[0]): | |
if doc_idx != -1: rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank) | |
for rank, (doc_idx, score) in enumerate(bm25_results): | |
if score > 0: rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank) | |
fused_results = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True) | |
# 4. Trả về top K kết quả cuối cùng | |
final_results = [] | |
for doc_idx, score in fused_results[:k]: | |
result = chunks_data[doc_idx].copy() | |
result['retrieval_score'] = score | |
final_results.append(result) | |
print(f"✅ Tìm kiếm hoàn tất, trả về {len(final_results)} kết quả.") | |
return final_results | |