deddoggo commited on
Commit
43d4d74
·
1 Parent(s): 3264b15
Files changed (1) hide show
  1. rag_pipeline.py +44 -66
rag_pipeline.py CHANGED
@@ -80,60 +80,16 @@ def initialize_components(data_path):
80
  "bm25_model": bm25_model
81
  }
82
 
83
- def _format_context_with_summary(retrieved_results: list[dict]) -> str:
84
- """
85
- Hàm phụ trợ: Định dạng ngữ cảnh từ kết quả truy xuất, bổ sung tóm tắt từ metadata.
86
- Hàm này được tách ra để làm cho code sạch sẽ và dễ quản lý hơn.
87
- """
88
- if not retrieved_results:
89
- return "Không tìm thấy thông tin luật liên quan trong cơ sở dữ liệu."
90
-
91
- context_parts = []
92
- for i, res in enumerate(retrieved_results):
93
- metadata = res.get('metadata', {})
94
- text = res.get('text', '*Nội dung không có*')
95
-
96
- # Tạo header rõ ràng
97
- header = f"Trích dẫn {i+1}: Điều {metadata.get('article', 'N/A')}, Điểm {metadata.get('point_id', '')} Khoản {metadata.get('clause_number', 'N/A')} (Nguồn: {metadata.get('source_document', 'N/A')})"
98
-
99
- # --- LOGIC TÓM TẮT THÔNG MINH TỪ METADATA ---
100
- metadata_summary = ""
101
- penalty_details_list = metadata.get("penalties_detail", [])
102
-
103
- if penalty_details_list:
104
- summary_parts = []
105
- # Chỉ lấy thông tin từ mục hình phạt đầu tiên trong danh sách
106
- details = penalty_details_list[0].get('details', {})
107
-
108
- # Tóm tắt mức phạt tiền cho cá nhân (phổ biến nhất)
109
- i_min = details.get("individual_fine_min")
110
- i_max = details.get("individual_fine_max")
111
- if i_min is not None and i_max is not None:
112
- summary_parts.append(f"Phạt tiền cá nhân từ {i_min:,} - {i_max:,} đồng.")
113
-
114
- # Tóm tắt mức trừ điểm
115
- points = details.get("points_deducted")
116
- if points is not None:
117
- summary_parts.append(f"Trừ {points} điểm GPLX.")
118
-
119
- if summary_parts:
120
- # Chèn dòng tóm tắt vào giữa header và text
121
- metadata_summary = f"\n[Tóm tắt từ metadata: {' '.join(summary_parts)}]"
122
-
123
- context_parts.append(f"{header}{metadata_summary}\n{text}")
124
-
125
- return "\n\n---\n\n".join(context_parts)
126
-
127
-
128
  def generate_response(query: str, components: dict) -> str:
129
  """
130
  Tạo câu trả lời (single-turn) bằng cách sử dụng các thành phần đã được khởi tạo.
131
- Phiên bản đã được tối ưu tái cấu trúc.
 
132
  """
133
  print("--- Bắt đầu quy trình RAG cho query mới ---")
134
 
135
- # 1. Truy xuất ngữ cảnh bằng retriever đã được nâng cấp
136
- # (Giả định search_relevant_laws đã được sửa để ưu tiên loại xe)
137
  retrieved_results = search_relevant_laws(
138
  query_text=query,
139
  embedding_model=components["embedding_model"],
@@ -144,39 +100,61 @@ def generate_response(query: str, components: dict) -> str:
144
  initial_k_multiplier=15
145
  )
146
 
147
- # 2. Định dạng Context một cách thông minh bằng hàm phụ trợ
148
- context = _format_context_with_summary(retrieved_results)
149
-
150
- # 3. Xây dựng Prompt
151
- prompt = f"""Bạn là một trợ lý pháp luật chuyên trả lời các câu hỏi về luật giao thông Việt Nam. Dựa vào các trích dẫn luật dưới đây để trả lời câu hỏi của người dùng một cách chính xác.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
  ### Thông tin luật:
154
  {context}
155
 
156
  ### Câu hỏi:
157
  {query}
 
 
 
 
 
 
 
 
158
 
159
- ### Trả lời:"""
160
-
161
- # 4. Tạo câu trả lời từ LLM
162
- llm_model = components["llm_model"]
163
- tokenizer = components["tokenizer"]
164
 
165
- # Chuyển input lên cùng device với model
166
- inputs = tokenizer(prompt, return_tensors="pt").to(llm_model.device)
167
 
168
- # Cấu hình generation tối ưu cho việc trả lời câu hỏi dựa trên facts
169
  generation_config = dict(
170
  max_new_tokens=256,
171
- temperature=0.1, # Rất thấp để câu trả lời bám sát ngữ cảnh
172
- repetition_penalty=1.1, # Phạt nhẹ việc lặp từ
173
- do_sample=True, # Vẫn cần bật để temperature và các tham số khác có hiệu lực
174
  pad_token_id=tokenizer.eos_token_id
175
  )
176
 
177
  output_ids = llm_model.generate(**inputs, **generation_config)
178
-
179
- # Chỉ decode phần văn bản được sinh ra mới, bỏ qua phần prompt
180
  response_text = tokenizer.decode(output_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
181
 
182
  print("--- Tạo câu trả lời hoàn tất ---")
 
80
  "bm25_model": bm25_model
81
  }
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  def generate_response(query: str, components: dict) -> str:
84
  """
85
  Tạo câu trả lời (single-turn) bằng cách sử dụng các thành phần đã được khởi tạo.
86
+ Phiên bản này được sửa đổi để tương thích với các mô hình Vision (đa phương thức)
87
+ bằng cách sử dụng chat template.
88
  """
89
  print("--- Bắt đầu quy trình RAG cho query mới ---")
90
 
91
+ # --- Bước 1: Truy xuất Ngữ cảnh (Không thay đổi) ---
92
+ # (Hàm này giả định bạn đã phiên bản retriever.py đã sửa lỗi logic boosting)
93
  retrieved_results = search_relevant_laws(
94
  query_text=query,
95
  embedding_model=components["embedding_model"],
 
100
  initial_k_multiplier=15
101
  )
102
 
103
+ # --- Bước 2: Định dạng Ngữ cảnh (Không thay đổi, nhưng nên tách ra hàm riêng) ---
104
+ if not retrieved_results:
105
+ context = "Không tìm thấy thông tin luật liên quan trong cơ sở dữ liệu."
106
+ else:
107
+ context_parts = []
108
+ for i, res in enumerate(retrieved_results):
109
+ metadata = res.get('metadata', {})
110
+ header = f"Trích dẫn {i+1}: Điều {metadata.get('article', 'N/A')}, Khoản {metadata.get('clause_number', 'N/A')} (Nguồn: {metadata.get('source_document', 'N/A')})"
111
+ text = res.get('text', '*Nội dung không có*')
112
+ context_parts.append(f"{header}\n{text}")
113
+ context = "\n\n---\n\n".join(context_parts)
114
+
115
+ # --- Bước 3: Xây dựng Prompt bằng Chat Template (Đây là phần thay đổi chính) ---
116
+ print("--- Xây dựng prompt bằng chat template ---")
117
+ llm_model = components["llm_model"]
118
+ tokenizer = components["tokenizer"]
119
+
120
+ # Tạo cấu trúc tin nhắn theo chuẩn của các mô hình hội thoại
121
+ messages = [
122
+ {
123
+ "role": "system",
124
+ "content": "Bạn là một trợ lý pháp luật chuyên trả lời các câu hỏi liên quan đến luật giao thông đường bộ Việt Nam. Hãy dựa vào các thông tin được cung cấp để trả lời một cách chính xác và dễ hiểu."
125
+ },
126
+ {
127
+ "role": "user",
128
+ "content": f"""Dựa vào các trích dẫn luật dưới đây:
129
 
130
  ### Thông tin luật:
131
  {context}
132
 
133
  ### Câu hỏi:
134
  {query}
135
+ """
136
+ }
137
+ ]
138
+
139
+ # Sử dụng apply_chat_template để tạo prompt hoàn chỉnh.
140
+ # Thêm `add_generation_prompt=True` để nó tự động thêm vai trò "assistant" ở cuối,
141
+ # báo hiệu cho model bắt đầu sinh câu trả lời.
142
+ prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
143
 
144
+ # --- Bước 4: Tạo câu trả lời từ LLM (Không thay đổi) ---
145
+ print("--- Bắt đầu tạo câu trả lời từ LLM ---")
 
 
 
146
 
147
+ inputs = tokenizer([prompt], return_tensors="pt").to(llm_model.device)
 
148
 
 
149
  generation_config = dict(
150
  max_new_tokens=256,
151
+ temperature=0.1,
152
+ repetition_penalty=1.1,
153
+ do_sample=True,
154
  pad_token_id=tokenizer.eos_token_id
155
  )
156
 
157
  output_ids = llm_model.generate(**inputs, **generation_config)
 
 
158
  response_text = tokenizer.decode(output_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
159
 
160
  print("--- Tạo câu trả lời hoàn tất ---")