Spaces:
Paused
Paused
| # file: retriever.py | |
| import faiss | |
| import numpy as np | |
| import torch | |
| import re | |
| from collections import defaultdict | |
| from rank_bm25 import BM25Okapi | |
| def tokenize_vi_for_bm25_setup(text): | |
| """Tokenize tiếng Việt đơn giản cho BM25.""" | |
| text = text.lower() | |
| text = re.sub(r'[^\w\s]', '', text) | |
| return text.split() | |
| def _get_vehicle_type(query_lower: str) -> str | None: | |
| """Xác định loại xe được đề cập trong câu truy vấn.""" | |
| # Từ điển định nghĩa các từ khóa cho từng loại xe | |
| vehicle_keywords = { | |
| "ô tô": ["ô tô", "xe con", "xe chở người", "xe chở hàng"], | |
| "xe máy": ["xe máy", "xe mô tô", "xe gắn máy"], | |
| "xe đạp": ["xe đạp", "xe thô sơ"], | |
| "máy kéo": ["máy kéo", "xe chuyên dùng"] | |
| } | |
| for vehicle_type, keywords in vehicle_keywords.items(): | |
| if any(keyword in query_lower for keyword in keywords): | |
| return vehicle_type | |
| return None | |
| def search_relevant_laws( | |
| query_text: str, | |
| embedding_model, | |
| faiss_index, | |
| chunks_data: list[dict], | |
| bm25_model, | |
| k: int = 5, | |
| initial_k_multiplier: int = 15, | |
| rrf_k_constant: int = 60 | |
| ) -> list[dict]: | |
| """ | |
| Thực hiện Tìm kiếm Lai (Hybrid Search) với logic tăng điểm (boosting) cho loại xe. | |
| Quy trình: | |
| 1. Tìm kiếm song song bằng FAISS (ngữ nghĩa) và BM25 (từ khóa). | |
| 2. Kết hợp kết quả bằng Reciprocal Rank Fusion (RRF). | |
| 3. Tăng điểm (boost) cho các kết quả khớp với metadata quan trọng (loại xe). | |
| 4. Sắp xếp lại và trả về top-k kết quả cuối cùng. | |
| """ | |
| if k <= 0: | |
| return [] | |
| num_vectors_in_index = faiss_index.ntotal | |
| if num_vectors_in_index == 0: | |
| return [] | |
| num_candidates = min(k * initial_k_multiplier, num_vectors_in_index) | |
| # --- 1. Semantic Search (FAISS) --- | |
| try: | |
| query_embedding = embedding_model.encode([query_text], convert_to_tensor=True) | |
| query_embedding_np = query_embedding.cpu().numpy().astype('float32') | |
| faiss.normalize_L2(query_embedding_np) | |
| _, semantic_indices = faiss_index.search(query_embedding_np, num_candidates) | |
| semantic_indices = semantic_indices[0] | |
| except Exception as e: | |
| print(f"Lỗi FAISS search: {e}") | |
| semantic_indices = [] | |
| # --- 2. Keyword Search (BM25) --- | |
| try: | |
| tokenized_query = tokenize_vi_for_bm25_setup(query_text) | |
| bm25_scores = bm25_model.get_scores(tokenized_query) | |
| # Lấy top N chỉ mục từ BM25 | |
| top_bm25_indices = np.argsort(bm25_scores)[::-1][:num_candidates] | |
| except Exception as e: | |
| print(f"Lỗi BM25 search: {e}") | |
| top_bm25_indices = [] | |
| # --- 3. Result Fusion (RRF) --- | |
| rrf_scores = defaultdict(float) | |
| all_indices = set(semantic_indices) | set(top_bm25_indices) | |
| for rank, doc_idx in enumerate(semantic_indices): | |
| rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank) | |
| for rank, doc_idx in enumerate(top_bm25_indices): | |
| rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank) | |
| # --- 4. Metadata Boosting & Final Ranking --- | |
| query_lower = query_text.lower() | |
| matched_vehicle = _get_vehicle_type(query_lower) | |
| final_results = [] | |
| for doc_idx in all_indices: | |
| try: | |
| metadata = chunks_data[doc_idx].get('metadata', {}) | |
| final_score = rrf_scores[doc_idx] | |
| # **LOGIC BOOSTING QUAN TRỌNG NHẤT** | |
| if matched_vehicle: | |
| article_title_lower = metadata.get("article_title", "").lower() | |
| # Định nghĩa lại từ khóa bên trong để tránh phụ thuộc bên ngoài | |
| vehicle_keywords = { | |
| "ô tô": ["ô tô", "xe con"], "xe máy": ["xe máy", "xe mô tô"], | |
| "xe đạp": ["xe đạp", "xe thô sơ"], "máy kéo": ["máy kéo", "xe chuyên dùng"] | |
| } | |
| if any(keyword in article_title_lower for keyword in vehicle_keywords.get(matched_vehicle, [])): | |
| # Cộng một điểm thưởng rất lớn để đảm bảo nó được ưu tiên | |
| final_score += 0.5 | |
| final_results.append({ | |
| 'index': doc_idx, | |
| 'final_score': final_score | |
| }) | |
| except IndexError: | |
| continue | |
| final_results.sort(key=lambda x: x['final_score'], reverse=True) | |
| # Lấy đầy đủ thông tin cho top-k kết quả cuối cùng | |
| top_k_results = [] | |
| for res in final_results[:k]: | |
| doc_idx = res['index'] | |
| top_k_results.append({ | |
| 'index': doc_idx, | |
| 'final_score': res['final_score'], | |
| 'text': chunks_data[doc_idx].get('text', ''), | |
| 'metadata': chunks_data[doc_idx].get('metadata', {}) | |
| }) | |
| return top_k_results |