chatbot_demo / retriever.py
deddoggo's picture
update
0c16fc9
raw
history blame
3.69 kB
# retrieval_handler.py
# Chịu trách nhiệm cho mọi logic liên quan đến việc truy xuất thông tin (Retrieval).
import json
import re
import numpy as np
import faiss
from collections import defaultdict
from typing import List, Dict, Any, Optional
from utils import tokenize_vi_simple # Import từ file utils.py
# --- HÀM XỬ LÝ DỮ LIỆU ---
def process_law_data_to_chunks(structured_data: Any) -> List[Dict]:
"""Làm phẳng dữ liệu luật có cấu trúc thành danh sách các chunks."""
flat_list = []
articles = [structured_data] if isinstance(structured_data, dict) else structured_data
for article_data in articles:
if not isinstance(article_data, dict): continue
# (Logic xử lý chi tiết của bạn ở đây... đã được rút gọn để dễ đọc)
# Giả sử logic này hoạt động đúng như bạn đã thiết kế
# và trả về một danh sách các chunk, mỗi chunk là một dict có "text" và "metadata".
# Để đảm bảo, tôi sẽ thêm một phiên bản đơn giản hóa ở đây.
clauses = article_data.get("clauses", [])
for clause in clauses:
points = clause.get("points_in_clause", [])
if points:
for point in points:
text = point.get("point_text_original")
if text:
flat_list.append({"text": text, "metadata": {"article": article_data.get("article"), "clause": clause.get("clause_number"), "point": point.get("point_id")}})
else:
text = clause.get("clause_text_original")
if text:
flat_list.append({"text": text, "metadata": {"article": article_data.get("article"), "clause": clause.get("clause_number")}})
return flat_list
# --- HÀM TÌM KIẾM ---
def search_relevant_laws(
query_text: str,
embedding_model,
faiss_index,
chunks_data: List[Dict],
bm25_model,
k: int = 5,
rrf_k_constant: int = 60
) -> List[Dict]:
"""
Thực hiện Hybrid Search (Semantic + Keyword) với RRF để tìm các chunk liên quan.
"""
print(f"🔎 Bắt đầu tìm kiếm cho: '{query_text}'")
# 1. Semantic Search (FAISS)
query_embedding = embedding_model.encode([query_text], convert_to_tensor=True)
query_embedding_np = query_embedding.cpu().numpy().astype('float32')
faiss.normalize_L2(query_embedding_np)
num_candidates = min(k * 10, faiss_index.ntotal)
_, semantic_indices = faiss_index.search(query_embedding_np, num_candidates)
# 2. Keyword Search (BM25)
tokenized_query = tokenize_vi_simple(query_text)
bm25_scores = bm25_model.get_scores(tokenized_query)
bm25_results = sorted(enumerate(bm25_scores), key=lambda x: x[1], reverse=True)[:num_candidates]
# 3. Reciprocal Rank Fusion (RRF)
rrf_scores = defaultdict(float)
if semantic_indices.size > 0:
for rank, doc_idx in enumerate(semantic_indices[0]):
if doc_idx != -1: rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank)
for rank, (doc_idx, score) in enumerate(bm25_results):
if score > 0: rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank)
fused_results = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True)
# 4. Trả về top K kết quả cuối cùng
final_results = []
for doc_idx, score in fused_results[:k]:
result = chunks_data[doc_idx].copy()
result['retrieval_score'] = score
final_results.append(result)
print(f"✅ Tìm kiếm hoàn tất, trả về {len(final_results)} kết quả.")
return final_results