Spaces:
Paused
Paused
fix bug
Browse files- rag_pipeline.py +44 -66
rag_pipeline.py
CHANGED
@@ -80,60 +80,16 @@ def initialize_components(data_path):
|
|
80 |
"bm25_model": bm25_model
|
81 |
}
|
82 |
|
83 |
-
def _format_context_with_summary(retrieved_results: list[dict]) -> str:
|
84 |
-
"""
|
85 |
-
Hàm phụ trợ: Định dạng ngữ cảnh từ kết quả truy xuất, bổ sung tóm tắt từ metadata.
|
86 |
-
Hàm này được tách ra để làm cho code sạch sẽ và dễ quản lý hơn.
|
87 |
-
"""
|
88 |
-
if not retrieved_results:
|
89 |
-
return "Không tìm thấy thông tin luật liên quan trong cơ sở dữ liệu."
|
90 |
-
|
91 |
-
context_parts = []
|
92 |
-
for i, res in enumerate(retrieved_results):
|
93 |
-
metadata = res.get('metadata', {})
|
94 |
-
text = res.get('text', '*Nội dung không có*')
|
95 |
-
|
96 |
-
# Tạo header rõ ràng
|
97 |
-
header = f"Trích dẫn {i+1}: Điều {metadata.get('article', 'N/A')}, Điểm {metadata.get('point_id', '')} Khoản {metadata.get('clause_number', 'N/A')} (Nguồn: {metadata.get('source_document', 'N/A')})"
|
98 |
-
|
99 |
-
# --- LOGIC TÓM TẮT THÔNG MINH TỪ METADATA ---
|
100 |
-
metadata_summary = ""
|
101 |
-
penalty_details_list = metadata.get("penalties_detail", [])
|
102 |
-
|
103 |
-
if penalty_details_list:
|
104 |
-
summary_parts = []
|
105 |
-
# Chỉ lấy thông tin từ mục hình phạt đầu tiên trong danh sách
|
106 |
-
details = penalty_details_list[0].get('details', {})
|
107 |
-
|
108 |
-
# Tóm tắt mức phạt tiền cho cá nhân (phổ biến nhất)
|
109 |
-
i_min = details.get("individual_fine_min")
|
110 |
-
i_max = details.get("individual_fine_max")
|
111 |
-
if i_min is not None and i_max is not None:
|
112 |
-
summary_parts.append(f"Phạt tiền cá nhân từ {i_min:,} - {i_max:,} đồng.")
|
113 |
-
|
114 |
-
# Tóm tắt mức trừ điểm
|
115 |
-
points = details.get("points_deducted")
|
116 |
-
if points is not None:
|
117 |
-
summary_parts.append(f"Trừ {points} điểm GPLX.")
|
118 |
-
|
119 |
-
if summary_parts:
|
120 |
-
# Chèn dòng tóm tắt vào giữa header và text
|
121 |
-
metadata_summary = f"\n[Tóm tắt từ metadata: {' '.join(summary_parts)}]"
|
122 |
-
|
123 |
-
context_parts.append(f"{header}{metadata_summary}\n{text}")
|
124 |
-
|
125 |
-
return "\n\n---\n\n".join(context_parts)
|
126 |
-
|
127 |
-
|
128 |
def generate_response(query: str, components: dict) -> str:
|
129 |
"""
|
130 |
Tạo câu trả lời (single-turn) bằng cách sử dụng các thành phần đã được khởi tạo.
|
131 |
-
Phiên bản
|
|
|
132 |
"""
|
133 |
print("--- Bắt đầu quy trình RAG cho query mới ---")
|
134 |
|
135 |
-
# 1
|
136 |
-
# (
|
137 |
retrieved_results = search_relevant_laws(
|
138 |
query_text=query,
|
139 |
embedding_model=components["embedding_model"],
|
@@ -144,39 +100,61 @@ def generate_response(query: str, components: dict) -> str:
|
|
144 |
initial_k_multiplier=15
|
145 |
)
|
146 |
|
147 |
-
# 2
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
|
153 |
### Thông tin luật:
|
154 |
{context}
|
155 |
|
156 |
### Câu hỏi:
|
157 |
{query}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
|
159 |
-
|
160 |
-
|
161 |
-
# 4. Tạo câu trả lời từ LLM
|
162 |
-
llm_model = components["llm_model"]
|
163 |
-
tokenizer = components["tokenizer"]
|
164 |
|
165 |
-
|
166 |
-
inputs = tokenizer(prompt, return_tensors="pt").to(llm_model.device)
|
167 |
|
168 |
-
# Cấu hình generation tối ưu cho việc trả lời câu hỏi dựa trên facts
|
169 |
generation_config = dict(
|
170 |
max_new_tokens=256,
|
171 |
-
temperature=0.1,
|
172 |
-
repetition_penalty=1.1,
|
173 |
-
do_sample=True,
|
174 |
pad_token_id=tokenizer.eos_token_id
|
175 |
)
|
176 |
|
177 |
output_ids = llm_model.generate(**inputs, **generation_config)
|
178 |
-
|
179 |
-
# Chỉ decode phần văn bản được sinh ra mới, bỏ qua phần prompt
|
180 |
response_text = tokenizer.decode(output_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
|
181 |
|
182 |
print("--- Tạo câu trả lời hoàn tất ---")
|
|
|
80 |
"bm25_model": bm25_model
|
81 |
}
|
82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
def generate_response(query: str, components: dict) -> str:
|
84 |
"""
|
85 |
Tạo câu trả lời (single-turn) bằng cách sử dụng các thành phần đã được khởi tạo.
|
86 |
+
Phiên bản này được sửa đổi để tương thích với các mô hình Vision (đa phương thức)
|
87 |
+
bằng cách sử dụng chat template.
|
88 |
"""
|
89 |
print("--- Bắt đầu quy trình RAG cho query mới ---")
|
90 |
|
91 |
+
# --- Bước 1: Truy xuất Ngữ cảnh (Không thay đổi) ---
|
92 |
+
# (Hàm này giả định bạn đã có phiên bản retriever.py đã sửa lỗi logic boosting)
|
93 |
retrieved_results = search_relevant_laws(
|
94 |
query_text=query,
|
95 |
embedding_model=components["embedding_model"],
|
|
|
100 |
initial_k_multiplier=15
|
101 |
)
|
102 |
|
103 |
+
# --- Bước 2: Định dạng Ngữ cảnh (Không thay đổi, nhưng nên tách ra hàm riêng) ---
|
104 |
+
if not retrieved_results:
|
105 |
+
context = "Không tìm thấy thông tin luật liên quan trong cơ sở dữ liệu."
|
106 |
+
else:
|
107 |
+
context_parts = []
|
108 |
+
for i, res in enumerate(retrieved_results):
|
109 |
+
metadata = res.get('metadata', {})
|
110 |
+
header = f"Trích dẫn {i+1}: Điều {metadata.get('article', 'N/A')}, Khoản {metadata.get('clause_number', 'N/A')} (Nguồn: {metadata.get('source_document', 'N/A')})"
|
111 |
+
text = res.get('text', '*Nội dung không có*')
|
112 |
+
context_parts.append(f"{header}\n{text}")
|
113 |
+
context = "\n\n---\n\n".join(context_parts)
|
114 |
+
|
115 |
+
# --- Bước 3: Xây dựng Prompt bằng Chat Template (Đây là phần thay đổi chính) ---
|
116 |
+
print("--- Xây dựng prompt bằng chat template ---")
|
117 |
+
llm_model = components["llm_model"]
|
118 |
+
tokenizer = components["tokenizer"]
|
119 |
+
|
120 |
+
# Tạo cấu trúc tin nhắn theo chuẩn của các mô hình hội thoại
|
121 |
+
messages = [
|
122 |
+
{
|
123 |
+
"role": "system",
|
124 |
+
"content": "Bạn là một trợ lý pháp luật chuyên trả lời các câu hỏi liên quan đến luật giao thông đường bộ Việt Nam. Hãy dựa vào các thông tin được cung cấp để trả lời một cách chính xác và dễ hiểu."
|
125 |
+
},
|
126 |
+
{
|
127 |
+
"role": "user",
|
128 |
+
"content": f"""Dựa vào các trích dẫn luật dưới đây:
|
129 |
|
130 |
### Thông tin luật:
|
131 |
{context}
|
132 |
|
133 |
### Câu hỏi:
|
134 |
{query}
|
135 |
+
"""
|
136 |
+
}
|
137 |
+
]
|
138 |
+
|
139 |
+
# Sử dụng apply_chat_template để tạo prompt hoàn chỉnh.
|
140 |
+
# Thêm `add_generation_prompt=True` để nó tự động thêm vai trò "assistant" ở cuối,
|
141 |
+
# báo hiệu cho model bắt đầu sinh câu trả lời.
|
142 |
+
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
143 |
|
144 |
+
# --- Bước 4: Tạo câu trả lời từ LLM (Không thay đổi) ---
|
145 |
+
print("--- Bắt đầu tạo câu trả lời từ LLM ---")
|
|
|
|
|
|
|
146 |
|
147 |
+
inputs = tokenizer([prompt], return_tensors="pt").to(llm_model.device)
|
|
|
148 |
|
|
|
149 |
generation_config = dict(
|
150 |
max_new_tokens=256,
|
151 |
+
temperature=0.1,
|
152 |
+
repetition_penalty=1.1,
|
153 |
+
do_sample=True,
|
154 |
pad_token_id=tokenizer.eos_token_id
|
155 |
)
|
156 |
|
157 |
output_ids = llm_model.generate(**inputs, **generation_config)
|
|
|
|
|
158 |
response_text = tokenizer.decode(output_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
|
159 |
|
160 |
print("--- Tạo câu trả lời hoàn tất ---")
|