deddoggo commited on
Commit
a53f1d8
·
1 Parent(s): 0c16fc9
Files changed (2) hide show
  1. rag_pipeline.py +131 -41
  2. retriever.py +111 -76
rag_pipeline.py CHANGED
@@ -1,58 +1,148 @@
1
- # llm_handler.py
2
- # Chịu trách nhiệm cho mọi logic liên quan đến mô hình ngôn ngữ lớn (LLM).
3
-
4
  import torch
 
 
 
 
5
  from unsloth import FastLanguageModel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- # --- HÀM TẠO CÂU TRẢ LỜI ---
8
- def generate_llm_response(
9
- query: str,
10
- context: str,
11
- llm_model,
12
- tokenizer,
13
- max_new_tokens: int = 512,
14
- temperature: float = 0.3,
15
- top_p: float = 0.9,
16
- ) -> str:
 
 
17
  """
18
- Sinh câu trả lời từ LLM dựa trên câu hỏi ngữ cảnh đã được truy xuất.
19
  """
20
- print("🧠 Bắt đầu sinh câu trả lời từ LLM...")
 
 
 
 
21
 
22
- # Xây dựng prompt
23
- prompt = f"""Bạn là một trợ lý AI chuyên tư vấn về luật giao thông đường bộ Việt Nam.
24
- Dựa vào các thông tin luật được cung cấp dưới đây để trả lời câu hỏi của người dùng một cách chính xác và chi tiết.
25
- Nếu thông tin không đủ, hãy trả lời rằng bạn không tìm thấy thông tin cụ thể trong tài liệu.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
- ### Thông tin luật được trích dẫn:
28
  {context}
29
 
30
- ### Câu hỏi của người dùng:
31
  {query}
32
 
33
- ### Trả lời của bạn:"""
34
 
35
- # Tạo input cho model
36
- device = llm_model.device
37
- inputs = tokenizer(prompt, return_tensors="pt").to(device)
38
 
39
- # Cấu hình cho việc sinh văn bản
40
  generation_config = dict(
41
- max_new_tokens=max_new_tokens,
42
- temperature=temperature,
43
- top_p=top_p,
 
 
44
  do_sample=True,
45
- pad_token_id=tokenizer.eos_token_id
 
46
  )
 
 
 
 
 
47
 
48
- try:
49
- output_ids = llm_model.generate(**inputs, **generation_config)
50
- input_length = inputs.input_ids.shape[1]
51
- generated_ids = output_ids[0][input_length:]
52
- response_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
53
- print("✅ Sinh câu trả lời hoàn tất.")
54
- return response_text
55
- except Exception as e:
56
- print(f"❌ Lỗi khi sinh câu trả lời từ LLM: {e}")
57
- return "Xin lỗi, đã có lỗi xảy ra trong quá trình tạo câu trả lời."
58
-
 
1
+ # file: rag_pipeline.py
 
 
2
  import torch
3
+ import json
4
+ import faiss
5
+ import numpy as np
6
+ import re
7
  from unsloth import FastLanguageModel
8
+ from sentence_transformers import SentenceTransformer
9
+ from rank_bm25 import BM25Okapi
10
+ from transformers import TextStreamer
11
+
12
+ # Import các hàm từ file khác
13
+ from data_processor import process_law_data_to_chunks
14
+ from retriever import search_relevant_laws, tokenize_vi_for_bm25_setup
15
+
16
+ def initialize_components(data_path):
17
+ """
18
+ Khởi tạo và tải tất cả các thành phần cần thiết cho RAG pipeline.
19
+ Hàm này chỉ nên được gọi một lần khi ứng dụng khởi động.
20
+ """
21
+ print("--- Bắt đầu khởi tạo các thành phần ---")
22
+
23
+ # 1. Tải LLM và Tokenizer từ Unsloth
24
+ print("1. Tải mô hình LLM (Unsloth)...")
25
+ model, tokenizer = FastLanguageModel.from_pretrained(
26
+ model_name="unsloth/Llama-3.2-3B-Instruct-bnb-4bit",
27
+ max_seq_length=2048,
28
+ dtype=None,
29
+ load_in_4bit=True,
30
+ )
31
+ print("✅ Tải LLM và Tokenizer thành công.")
32
+
33
+ # 2. Tải mô hình Embedding
34
+ print("2. Tải mô hình Embedding...")
35
+ embedding_model = SentenceTransformer(
36
+ "bkai-foundation-models/vietnamese-bi-encoder",
37
+ device="cuda" if torch.cuda.is_available() else "cpu"
38
+ )
39
+ print("✅ Tải mô hình Embedding thành công.")
40
+
41
+ # 3. Tải và xử lý dữ liệu JSON
42
+ print(f"3. Tải và xử lý dữ liệu từ {data_path}...")
43
+ with open(data_path, 'r', encoding='utf-8') as f:
44
+ raw_data = json.load(f)
45
+ chunks_data = process_law_data_to_chunks(raw_data)
46
+ print(f"✅ Xử lý dữ liệu thành công, có {len(chunks_data)} chunks.")
47
+
48
+ # 4. Tạo Embeddings và FAISS Index
49
+ print("4. Tạo embeddings và FAISS index...")
50
+ texts_to_encode = [chunk.get('text', '') for chunk in chunks_data]
51
+ chunk_embeddings_tensor = embedding_model.encode(
52
+ texts_to_encode,
53
+ convert_to_tensor=True,
54
+ device=embedding_model.device
55
+ )
56
+ chunk_embeddings_np = chunk_embeddings_tensor.cpu().numpy().astype('float32')
57
+ faiss.normalize_L2(chunk_embeddings_np)
58
+
59
+ dimension = chunk_embeddings_np.shape[1]
60
+ faiss_index = faiss.IndexFlatIP(dimension)
61
+ faiss_index.add(chunk_embeddings_np)
62
+ print(f"✅ Tạo FAISS index thành công với {faiss_index.ntotal} vector.")
63
+
64
+ # 5. Tạo BM25 Model
65
+ print("5. Tạo mô hình BM25...")
66
+ corpus_texts_for_bm25 = [chunk.get('text', '') for chunk in chunks_data]
67
+ tokenized_corpus_bm25 = [tokenize_vi_for_bm25_setup(text) for text in corpus_texts_for_bm25]
68
+ bm25_model = BM25Okapi(tokenized_corpus_bm25)
69
+ print("✅ Tạo mô hình BM25 thành công.")
70
 
71
+ print("--- Khởi tạo tất cả thành phần hoàn tất ---")
72
+
73
+ return {
74
+ "llm_model": model,
75
+ "tokenizer": tokenizer,
76
+ "embedding_model": embedding_model,
77
+ "chunks_data": chunks_data,
78
+ "faiss_index": faiss_index,
79
+ "bm25_model": bm25_model
80
+ }
81
+
82
+ def generate_response(query, components):
83
  """
84
+ Tạo câu trả lời cho một query bằng cách sử dụng các thành phần đã được khởi tạo.
85
  """
86
+ print("--- Bắt đầu quy trình RAG cho query mới ---")
87
+
88
+ # Unpack các thành phần
89
+ llm_model = components["llm_model"]
90
+ tokenizer = components["tokenizer"]
91
 
92
+ # 1. Truy xuất ngữ cảnh
93
+ retrieved_results = search_relevant_laws(
94
+ query_text=query,
95
+ embedding_model=components["embedding_model"],
96
+ faiss_index=components["faiss_index"],
97
+ chunks_data=components["chunks_data"],
98
+ bm25_model=components["bm25_model"],
99
+ k=5,
100
+ initial_k_multiplier=18
101
+ )
102
+
103
+ # 2. Định dạng Context
104
+ if not retrieved_results:
105
+ context = "Không tìm thấy thông tin luật liên quan trong cơ sở dữ liệu."
106
+ else:
107
+ context_parts = []
108
+ for i, res in enumerate(retrieved_results):
109
+ metadata = res.get('metadata', {})
110
+ header = f"Trích dẫn {i+1}: Điều {metadata.get('article', 'N/A')}, Khoản {metadata.get('clause_number', 'N/A')} (Nguồn: {metadata.get('source_document', 'N/A')})"
111
+ text = res.get('text', '*Nội dung không có*')
112
+ context_parts.append(f"{header}\n{text}")
113
+ context = "\n\n---\n\n".join(context_parts)
114
+
115
+ # 3. Xây dựng Prompt và tạo câu trả lời
116
+ prompt = f"""Dưới đây là một số thông tin trích dẫn từ văn bản luật giao thông đường bộ Việt Nam.
117
+ Hãy SỬ DỤNG CÁC THÔNG TIN NÀY để trả lời câu hỏi một cách chính xác và đầy đủ.
118
+ Nếu câu hỏi đưa ra nhiều đáp án thì chọn 1 đáp án đúng nhất.
119
 
120
+ ### Thông tin luật:
121
  {context}
122
 
123
+ ### Câu hỏi:
124
  {query}
125
 
126
+ ### Trả lời:"""
127
 
128
+ print("--- Bắt đầu tạo câu trả lời từ LLM ---")
129
+ inputs = tokenizer(prompt, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
 
130
 
 
131
  generation_config = dict(
132
+ max_new_tokens=256,
133
+ temperature=0.5,
134
+ top_p=0.7,
135
+ top_k=50,
136
+ repetition_penalty=1.1,
137
  do_sample=True,
138
+ pad_token_id=tokenizer.eos_token_id,
139
+ eos_token_id=tokenizer.eos_token_id
140
  )
141
+
142
+ output_ids = llm_model.generate(**inputs, **generation_config)
143
+ input_length = inputs.input_ids.shape[1]
144
+ generated_ids = output_ids[0][input_length:]
145
+ response_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
146
 
147
+ print("--- Tạo câu trả lời hoàn tất ---")
148
+ return response_text
 
 
 
 
 
 
 
 
 
retriever.py CHANGED
@@ -1,82 +1,117 @@
1
- # retrieval_handler.py
2
- # Chịu trách nhiệm cho mọi logic liên quan đến việc truy xuất thông tin (Retrieval).
3
-
4
- import json
5
- import re
6
- import numpy as np
7
  import faiss
 
 
 
8
  from collections import defaultdict
9
- from typing import List, Dict, Any, Optional
10
- from utils import tokenize_vi_simple # Import từ file utils.py
11
-
12
- # --- HÀM XỬ DỮ LIỆU ---
13
- def process_law_data_to_chunks(structured_data: Any) -> List[Dict]:
14
- """Làm phẳng dữ liệu luật có cấu trúc thành danh sách các chunks."""
15
- flat_list = []
16
- articles = [structured_data] if isinstance(structured_data, dict) else structured_data
17
- for article_data in articles:
18
- if not isinstance(article_data, dict): continue
19
- # (Logic xử lý chi tiết của bạn ở đây... đã được rút gọn để dễ đọc)
20
- # Giả sử logic này hoạt động đúng như bạn đã thiết kế
21
- # và trả về một danh sách các chunk, mỗi chunk là một dict có "text" và "metadata".
22
- # Để đảm bảo, tôi sẽ thêm một phiên bản đơn giản hóa ở đây.
23
- clauses = article_data.get("clauses", [])
24
- for clause in clauses:
25
- points = clause.get("points_in_clause", [])
26
- if points:
27
- for point in points:
28
- text = point.get("point_text_original")
29
- if text:
30
- flat_list.append({"text": text, "metadata": {"article": article_data.get("article"), "clause": clause.get("clause_number"), "point": point.get("point_id")}})
31
- else:
32
- text = clause.get("clause_text_original")
33
- if text:
34
- flat_list.append({"text": text, "metadata": {"article": article_data.get("article"), "clause": clause.get("clause_number")}})
35
- return flat_list
36
-
37
-
38
- # --- HÀM TÌM KIẾM ---
39
  def search_relevant_laws(
40
- query_text: str,
41
- embedding_model,
42
- faiss_index,
43
- chunks_data: List[Dict],
44
- bm25_model,
45
- k: int = 5,
46
- rrf_k_constant: int = 60
47
- ) -> List[Dict]:
 
48
  """
49
- Thực hiện Hybrid Search (Semantic + Keyword) với RRF để tìm các chunk liên quan.
 
50
  """
51
- print(f"🔎 Bắt đầu tìm kiếm cho: '{query_text}'")
52
- # 1. Semantic Search (FAISS)
53
- query_embedding = embedding_model.encode([query_text], convert_to_tensor=True)
54
- query_embedding_np = query_embedding.cpu().numpy().astype('float32')
55
- faiss.normalize_L2(query_embedding_np)
56
- num_candidates = min(k * 10, faiss_index.ntotal)
57
- _, semantic_indices = faiss_index.search(query_embedding_np, num_candidates)
58
-
59
- # 2. Keyword Search (BM25)
60
- tokenized_query = tokenize_vi_simple(query_text)
61
- bm25_scores = bm25_model.get_scores(tokenized_query)
62
- bm25_results = sorted(enumerate(bm25_scores), key=lambda x: x[1], reverse=True)[:num_candidates]
63
-
64
- # 3. Reciprocal Rank Fusion (RRF)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  rrf_scores = defaultdict(float)
66
- if semantic_indices.size > 0:
67
- for rank, doc_idx in enumerate(semantic_indices[0]):
68
- if doc_idx != -1: rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank)
69
- for rank, (doc_idx, score) in enumerate(bm25_results):
70
- if score > 0: rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank)
71
-
72
- fused_results = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True)
73
-
74
- # 4. Trả về top K kết quả cuối cùng
75
- final_results = []
76
- for doc_idx, score in fused_results[:k]:
77
- result = chunks_data[doc_idx].copy()
78
- result['retrieval_score'] = score
79
- final_results.append(result)
80
-
81
- print(f"✅ Tìm kiếm hoàn tất, trả về {len(final_results)} kết quả.")
82
- return final_results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # file: retriever.py
 
 
 
 
 
2
  import faiss
3
+ import numpy as np
4
+ import torch
5
+ import re
6
  from collections import defaultdict
7
+ from rank_bm25 import BM25Okapi
8
+
9
+ def tokenize_vi_for_bm25_setup(text):
10
+ """Tokenize tiếng Việt đơn giản cho BM25."""
11
+ text = text.lower()
12
+ text = re.sub(r'[^\w\s]', '', text)
13
+ return text.split()
14
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def search_relevant_laws(
16
+ query_text,
17
+ embedding_model,
18
+ faiss_index,
19
+ chunks_data,
20
+ bm25_model,
21
+ k=5,
22
+ initial_k_multiplier=10,
23
+ rrf_k_constant=60
24
+ ):
25
  """
26
+ Thực hiện tìm kiếm lai (Hybrid Search) kết hợp Semantic Search (FAISS) Keyword Search (BM25),
27
+ sau đó kết hợp kết quả bằng Reciprocal Rank Fusion (RRF) và tăng cường bằng metadata.
28
  """
29
+ if k <= 0:
30
+ print("Lỗi: k (số lượng kết quả) phải là số dương.")
31
+ return []
32
+
33
+ print(f"\n🔎 Đang tìm kiếm (Hybrid) cho truy vấn: '{query_text}'")
34
+ query_lower = query_text.lower()
35
+
36
+ # Phân tích query
37
+ fine_keywords = r'tiền|phạt|bao nhiêu đồng|bao nhiêu tiền|mức phạt|xử phạt hành chính'
38
+ points_keywords = r'điểm|trừ điểm|mấy điểm|trừ bao nhiêu điểm|bằng lái|gplx'
39
+ query_mentions_fine = bool(re.search(fine_keywords, query_lower))
40
+ query_mentions_points = bool(re.search(points_keywords, query_lower))
41
+ needs_specific_metadata_filter = query_mentions_fine or query_mentions_points
42
+ print(f" Phân tích query: Đề cập tiền phạt? {query_mentions_fine}, Đề cập điểm trừ? {query_mentions_points}")
43
+
44
+ num_vectors_in_index = faiss_index.ntotal
45
+ if num_vectors_in_index == 0:
46
+ print("Lỗi: FAISS index rỗng.")
47
+ return []
48
+
49
+ num_candidates_each_retriever = min(k * initial_k_multiplier, num_vectors_in_index)
50
+
51
+ # === 1. Semantic Search (FAISS) ===
52
+ try:
53
+ query_embedding_tensor = embedding_model.encode([query_text], convert_to_tensor=True, device=embedding_model.device)
54
+ query_embedding_np = query_embedding_tensor.cpu().numpy().astype('float32')
55
+ faiss.normalize_L2(query_embedding_np)
56
+ semantic_scores_raw, semantic_indices_raw = faiss_index.search(query_embedding_np, num_candidates_each_retriever)
57
+ except Exception as e:
58
+ print(f"Lỗi khi tìm kiếm ngữ nghĩa (FAISS): {e}")
59
+ semantic_indices_raw = np.array([[]], dtype=int)
60
+
61
+ # === 2. Keyword Search (BM25) ===
62
+ try:
63
+ tokenized_query_bm25 = tokenize_vi_for_bm25_setup(query_text)
64
+ all_bm25_scores = bm25_model.get_scores(tokenized_query_bm25)
65
+ bm25_results_with_indices = [{'index': i, 'score': score} for i, score in enumerate(all_bm25_scores) if score > 0]
66
+ bm25_results_with_indices.sort(key=lambda x: x['score'], reverse=True)
67
+ top_bm25_results = bm25_results_with_indices[:num_candidates_each_retriever]
68
+ except Exception as e:
69
+ print(f"Lỗi khi tìm kiếm từ khóa (BM25): {e}")
70
+ top_bm25_results = []
71
+
72
+ # === 3. Result Fusion (RRF) ===
73
  rrf_scores = defaultdict(float)
74
+ all_retrieved_indices_set = set()
75
+
76
+ if semantic_indices_raw.size > 0:
77
+ for rank, doc_idx in enumerate(semantic_indices_raw[0]):
78
+ if 0 <= doc_idx < num_vectors_in_index:
79
+ rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank)
80
+ all_retrieved_indices_set.add(doc_idx)
81
+
82
+ for rank, item in enumerate(top_bm25_results):
83
+ doc_idx = item['index']
84
+ rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank)
85
+ all_retrieved_indices_set.add(doc_idx)
86
+
87
+ fused_initial_results = [{'index': doc_idx, 'fused_score': rrf_scores[doc_idx]} for doc_idx in all_retrieved_indices_set]
88
+ fused_initial_results.sort(key=lambda x: x['fused_score'], reverse=True)
89
+
90
+ # === 4. Lọc và Tái xếp hạng cuối cùng ===
91
+ final_processed_results = []
92
+ num_to_process_metadata = min(len(fused_initial_results), num_candidates_each_retriever * 2)
93
+
94
+ for rank_idx, res_item in enumerate(fused_initial_results[:num_to_process_metadata]):
95
+ try:
96
+ result_index = res_item['index']
97
+ base_score_from_fusion = res_item['fused_score']
98
+ original_chunk = chunks_data[result_index]
99
+ original_metadata = original_chunk.get('metadata', {})
100
+ # Thêm logic xử lý metadata boosting ở đây nếu cần...
101
+ # Hiện tại, chỉ trả về kết quả đã fusion.
102
+ # Bạn có thể copy lại toàn bộ logic boosting từ script gốc vào đây.
103
+
104
+ final_score_calculated = base_score_from_fusion # (Thêm boosting vào đây)
105
+
106
+ final_processed_results.append({
107
+ "rank_after_fusion": rank_idx + 1,
108
+ "index": int(result_index),
109
+ "final_score": final_score_calculated,
110
+ "text": original_chunk.get('text', '*Không có text*'),
111
+ "metadata": original_metadata
112
+ })
113
+ except Exception as e:
114
+ print(f"Lỗi khi xử lý ứng viên tại chỉ số {res_item.get('index')}: {e}")
115
+
116
+ final_processed_results.sort(key=lambda x: x["final_score"], reverse=True)
117
+ return final_processed_results[:k]