deddoggo commited on
Commit
85f9d54
·
1 Parent(s): 43d4d74
Files changed (1) hide show
  1. rag_pipeline.py +33 -15
rag_pipeline.py CHANGED
@@ -83,14 +83,16 @@ def initialize_components(data_path):
83
  def generate_response(query: str, components: dict) -> str:
84
  """
85
  Tạo câu trả lời (single-turn) bằng cách sử dụng các thành phần đã được khởi tạo.
86
- Phiên bản này được sửa đổi để tương thích với các mô hình Vision (đa phương thức)
87
- bằng cách sử dụng chat template.
 
 
88
  """
89
  print("--- Bắt đầu quy trình RAG cho query mới ---")
90
 
91
- # --- Bước 1: Truy xuất Ngữ cảnh (Không thay đổi) ---
92
- # (Hàm này giả định bạn đã có phiên bản retriever.py đã sửa lỗi logic boosting)
93
- retrieved_results = search_relevant_laws(
94
  query_text=query,
95
  embedding_model=components["embedding_model"],
96
  faiss_index=components["faiss_index"],
@@ -100,7 +102,7 @@ def generate_response(query: str, components: dict) -> str:
100
  initial_k_multiplier=15
101
  )
102
 
103
- # --- Bước 2: Định dạng Ngữ cảnh (Không thay đổi, nhưng nên tách ra hàm riêng) ---
104
  if not retrieved_results:
105
  context = "Không tìm thấy thông tin luật liên quan trong cơ sở dữ liệu."
106
  else:
@@ -109,19 +111,30 @@ def generate_response(query: str, components: dict) -> str:
109
  metadata = res.get('metadata', {})
110
  header = f"Trích dẫn {i+1}: Điều {metadata.get('article', 'N/A')}, Khoản {metadata.get('clause_number', 'N/A')} (Nguồn: {metadata.get('source_document', 'N/A')})"
111
  text = res.get('text', '*Nội dung không có*')
 
 
 
 
 
 
 
 
 
 
 
112
  context_parts.append(f"{header}\n{text}")
113
  context = "\n\n---\n\n".join(context_parts)
114
 
115
- # --- Bước 3: Xây dựng Prompt bằng Chat Template (Đây là phần thay đổi chính) ---
116
  print("--- Xây dựng prompt bằng chat template ---")
117
  llm_model = components["llm_model"]
118
  tokenizer = components["tokenizer"]
119
 
120
- # Tạo cấu trúc tin nhắn theo chuẩn của các mô hình hội thoại
121
  messages = [
122
  {
123
  "role": "system",
124
- "content": "Bạn là một trợ lý pháp luật chuyên trả lời các câu hỏi liên quan đến luật giao thông đường bộ Việt Nam. Hãy dựa vào các thông tin được cung cấp để trả lời một cách chính xác và dễ hiểu."
125
  },
126
  {
127
  "role": "user",
@@ -136,14 +149,19 @@ def generate_response(query: str, components: dict) -> str:
136
  }
137
  ]
138
 
139
- # Sử dụng apply_chat_template để tạo prompt hoàn chỉnh.
140
- # Thêm `add_generation_prompt=True` để tự động thêm vai trò "assistant" cuối,
141
- # báo hiệu cho model bắt đầu sinh câu trả lời.
142
- prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
 
 
 
143
 
144
- # --- Bước 4: Tạo câu trả lời từ LLM (Không thay đổi) ---
145
  print("--- Bắt đầu tạo câu trả lời từ LLM ---")
146
 
 
147
  inputs = tokenizer([prompt], return_tensors="pt").to(llm_model.device)
148
 
149
  generation_config = dict(
@@ -158,4 +176,4 @@ def generate_response(query: str, components: dict) -> str:
158
  response_text = tokenizer.decode(output_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
159
 
160
  print("--- Tạo câu trả lời hoàn tất ---")
161
- return response_text
 
83
  def generate_response(query: str, components: dict) -> str:
84
  """
85
  Tạo câu trả lời (single-turn) bằng cách sử dụng các thành phần đã được khởi tạo.
86
+ Phiên bản cuối cùng:
87
+ - Tương thích với mô hình Vision bằng cách sử dụng chat template.
88
+ - Nhận và sử dụng thông tin `matched_vehicle` từ retriever.
89
+ - Định dạng context với tóm tắt thông minh từ metadata.
90
  """
91
  print("--- Bắt đầu quy trình RAG cho query mới ---")
92
 
93
+ # === THAY ĐỔI 1: Nhận cả 2 giá trị trả về từ retriever ===
94
+ # 1. Truy xuất ngữ cảnh bằng retriever đã được nâng cấp
95
+ retrieved_results, matched_vehicle = search_relevant_laws(
96
  query_text=query,
97
  embedding_model=components["embedding_model"],
98
  faiss_index=components["faiss_index"],
 
102
  initial_k_multiplier=15
103
  )
104
 
105
+ # 2. Định dạng Context một cách thông minh
106
  if not retrieved_results:
107
  context = "Không tìm thấy thông tin luật liên quan trong cơ sở dữ liệu."
108
  else:
 
111
  metadata = res.get('metadata', {})
112
  header = f"Trích dẫn {i+1}: Điều {metadata.get('article', 'N/A')}, Khoản {metadata.get('clause_number', 'N/A')} (Nguồn: {metadata.get('source_document', 'N/A')})"
113
  text = res.get('text', '*Nội dung không có*')
114
+
115
+ # === THAY ĐỔI 2: Thêm gợi ý về loại xe vào header ===
116
+ if matched_vehicle:
117
+ vehicle_keywords = {
118
+ "ô tô": ["ô tô", "xe con"], "xe máy": ["xe máy", "xe mô tô"],
119
+ "xe đạp": ["xe đạp", "xe thô sơ"], "máy kéo": ["máy kéo", "xe chuyên dùng"]
120
+ }
121
+ article_title_lower = metadata.get("article_title", "").lower()
122
+ if any(keyword in article_title_lower for keyword in vehicle_keywords.get(matched_vehicle, [])):
123
+ header += f" [GỢI Ý: Thông tin này áp dụng cho {matched_vehicle.upper()}]"
124
+
125
  context_parts.append(f"{header}\n{text}")
126
  context = "\n\n---\n\n".join(context_parts)
127
 
128
+ # 3. Xây dựng Prompt bằng Chat Template
129
  print("--- Xây dựng prompt bằng chat template ---")
130
  llm_model = components["llm_model"]
131
  tokenizer = components["tokenizer"]
132
 
133
+ # Tạo cấu trúc tin nhắn theo chuẩn
134
  messages = [
135
  {
136
  "role": "system",
137
+ "content": "Bạn là một trợ lý pháp luật chuyên trả lời các câu hỏi về luật giao thông Việt Nam. Hãy dựa vào các thông tin được cung cấp để trả lời một cách chính xác và dễ hiểu."
138
  },
139
  {
140
  "role": "user",
 
149
  }
150
  ]
151
 
152
+ # === THAY ĐỔI 3 (Cốt lõi): Sử dụng apply_chat_template ===
153
+ # Phương thức này sẽ tạo ra chuỗi prompt hoàn chỉnh với các token đặc biệt,
154
+ # tương thích với cả hình text vision (khi không có ảnh).
155
+ prompt = tokenizer.apply_chat_template(
156
+ messages,
157
+ tokenize=False,
158
+ add_generation_prompt=True
159
+ )
160
 
161
+ # 4. Tạo câu trả lời từ LLM
162
  print("--- Bắt đầu tạo câu trả lời từ LLM ---")
163
 
164
+ # Tokenize chuỗi prompt đã được định dạng đúng
165
  inputs = tokenizer([prompt], return_tensors="pt").to(llm_model.device)
166
 
167
  generation_config = dict(
 
176
  response_text = tokenizer.decode(output_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
177
 
178
  print("--- Tạo câu trả lời hoàn tất ---")
179
+ return response_text