deddoggo commited on
Commit
845a94d
·
1 Parent(s): b4aaddd

update main

Browse files
Files changed (1) hide show
  1. app.py +355 -4
app.py CHANGED
@@ -1,7 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
1
+ # app.py
2
+ # File triển khai hoàn chỉnh cho đồ án Chatbot Luật Giao thông
3
+ # Tác giả: (Tên của bạn)
4
+ # Ngày: (Ngày bạn tạo)
5
+
6
+ # --- PHẦN 1: IMPORT CÁC THƯ VIỆN CẦN THIẾT ---
7
+ print("Bắt đầu import các thư viện...")
8
+ import os
9
+ import sys
10
+ import json
11
+ import re
12
+ import time
13
+ from collections import defaultdict
14
+
15
+ # Core ML/DL và Unsloth
16
+ import torch
17
+ from unsloth import FastLanguageModel
18
+ from transformers import TextStreamer
19
+
20
+ # RAG - Retrieval
21
+ import faiss
22
+ from sentence_transformers import SentenceTransformer
23
+ from rank_bm25 import BM25Okapi
24
+ import numpy as np
25
+
26
+ # Deployment
27
  import gradio as gr
28
 
29
+ print("✅ Import thư viện thành công.")
30
+
31
+ # --- PHẦN 2: CẤU HÌNH VÀ TẢI TÀI NGUYÊN (MODELS & DATA) ---
32
+ # Phần này sẽ chỉ chạy một lần khi ứng dụng khởi động.
33
+
34
+ # Cấu hình mô hình
35
+ MAX_SEQ_LENGTH = 2048
36
+ DTYPE = None
37
+ LOAD_IN_4BIT = True
38
+ EMBEDDING_MODEL_NAME = "bkai-foundation-models/vietnamese-bi-encoder"
39
+ LLM_MODEL_NAME = "unsloth/Llama-3.2-3B-Instruct-bnb-4bit"
40
+ LAW_DATA_FILE = "luat_chi_tiet_output_openai_sdk_final_cleaned.json"
41
+
42
+ # Biến toàn cục để lưu các tài nguyên đã tải
43
+ # Điều này giúp tránh việc phải tải lại mô hình mỗi khi người dùng gửi yêu cầu.
44
+ MODELS_AND_DATA = {
45
+ "llm_model": None,
46
+ "tokenizer": None,
47
+ "embedding_model": None,
48
+ "faiss_index": None,
49
+ "bm25_model": None,
50
+ "chunks_data": None,
51
+ "tokenized_corpus_bm25": None,
52
+ }
53
+
54
+ # --- Các hàm xử lý dữ liệu (từ các notebook của bạn) ---
55
+
56
+ def process_law_data_to_chunks(structured_data_input):
57
+ """
58
+ Hàm làm phẳng dữ liệu luật có cấu trúc chi tiết thành danh sách các chunks.
59
+ Mỗi chunk chứa 'text' và 'metadata'.
60
+ """
61
+ flat_list = []
62
+ articles_list = []
63
+ if isinstance(structured_data_input, dict) and "article" in structured_data_input:
64
+ articles_list = [structured_data_input]
65
+ elif isinstance(structured_data_input, list):
66
+ articles_list = structured_data_input
67
+ else:
68
+ print("Lỗi: Dữ liệu đầu vào không hợp lệ.")
69
+ return flat_list
70
+
71
+ for article_data in articles_list:
72
+ if not isinstance(article_data, dict): continue
73
+ article_metadata_base = {
74
+ "source_document": article_data.get("source_document"),
75
+ "article": article_data.get("article"),
76
+ "article_title": article_data.get("article_title")
77
+ }
78
+ clauses = article_data.get("clauses", [])
79
+ if not isinstance(clauses, list): continue
80
+
81
+ for clause_data in clauses:
82
+ if not isinstance(clause_data, dict): continue
83
+ clause_metadata_base = article_metadata_base.copy()
84
+ clause_metadata_base.update({
85
+ "clause_number": clause_data.get("clause_number"),
86
+ "clause_metadata_summary": clause_data.get("clause_metadata_summary")
87
+ })
88
+ points_in_clause = clause_data.get("points_in_clause", [])
89
+ if not isinstance(points_in_clause, list): continue
90
+
91
+ if points_in_clause:
92
+ for point_data in points_in_clause:
93
+ if not isinstance(point_data, dict): continue
94
+ chunk_text = point_data.get("point_text_original") or point_data.get("violation_description_summary")
95
+ if not chunk_text: continue
96
+
97
+ current_point_metadata = clause_metadata_base.copy()
98
+ point_specific_metadata = point_data.copy()
99
+ if "point_text_original" in point_specific_metadata:
100
+ del point_specific_metadata["point_text_original"]
101
+ current_point_metadata.update(point_specific_metadata)
102
+ final_metadata_cleaned = {k: v for k, v in current_point_metadata.items() if v is not None}
103
+ flat_list.append({"text": chunk_text, "metadata": final_metadata_cleaned})
104
+ else:
105
+ chunk_text = clause_data.get("clause_text_original")
106
+ if chunk_text:
107
+ current_clause_metadata = clause_metadata_base.copy()
108
+ additional_clause_info = {k: v for k, value in clause_data.items() if k not in ["clause_text_original", "points_in_clause", "clause_number", "clause_metadata_summary"]}
109
+ if additional_clause_info:
110
+ current_clause_metadata.update(additional_clause_info)
111
+ final_metadata_cleaned = {k: v for k, v in current_clause_metadata.items() if v is not None}
112
+ flat_list.append({"text": chunk_text, "metadata": final_metadata_cleaned})
113
+ return flat_list
114
+
115
+ def tokenize_vi_for_bm25(text):
116
+ """Hàm tokenize tiếng Việt đơn giản cho BM25."""
117
+ text = text.lower()
118
+ text = re.sub(r'[^\w\s]', '', text)
119
+ return text.split()
120
+
121
+ def load_all_resources():
122
+ """
123
+ Hàm chính để tải tất cả mô hình và dữ liệu cần thiết.
124
+ Chỉ chạy một lần khi ứng dụng khởi động.
125
+ """
126
+ print("--- Bắt đầu quá trình tải tài nguyên ---")
127
+
128
+ # 1. Tải mô hình LLM và Tokenizer
129
+ print(f"1. Đang tải LLM và Tokenizer: {LLM_MODEL_NAME}...")
130
+ llm_model, tokenizer = FastLanguageModel.from_pretrained(
131
+ model_name=LLM_MODEL_NAME,
132
+ max_seq_length=MAX_SEQ_LENGTH,
133
+ dtype=DTYPE,
134
+ load_in_4bit=LOAD_IN_4BIT,
135
+ )
136
+ FastLanguageModel.for_inference(llm_model)
137
+ MODELS_AND_DATA["llm_model"] = llm_model
138
+ MODELS_AND_DATA["tokenizer"] = tokenizer
139
+ print("✅ Tải LLM và Tokenizer thành công.")
140
+
141
+ # 2. Tải mô hình Embedding
142
+ print(f"2. Đang tải Embedding Model: {EMBEDDING_MODEL_NAME}...")
143
+ embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME, device="cuda" if torch.cuda.is_available() else "cpu")
144
+ MODELS_AND_DATA["embedding_model"] = embedding_model
145
+ print("✅ Tải Embedding Model thành công.")
146
+
147
+ # 3. Tải và xử lý dữ liệu luật
148
+ print(f"3. Đang tải và xử lý dữ liệu từ: {LAW_DATA_FILE}...")
149
+ if not os.path.exists(LAW_DATA_FILE):
150
+ raise FileNotFoundError(f"Không tìm thấy file dữ liệu luật: {LAW_DATA_FILE}. Vui lòng upload file này lên Space.")
151
+ with open(LAW_DATA_FILE, 'r', encoding='utf-8') as f:
152
+ raw_data_from_file = json.load(f)
153
+ chunks_data = process_law_data_to_chunks(raw_data_from_file)
154
+ MODELS_AND_DATA["chunks_data"] = chunks_data
155
+ print(f"✅ Đã xử lý thành {len(chunks_data)} chunks.")
156
+
157
+ # 4. Tạo BM25 Model
158
+ print("4. Đang tạo BM25 Model...")
159
+ corpus_texts = [chunk.get('text', '') for chunk in chunks_data]
160
+ tokenized_corpus = [tokenize_vi_for_bm25(text) for text in corpus_texts]
161
+ bm25_model = BM25Okapi(tokenized_corpus)
162
+ MODELS_AND_DATA["bm25_model"] = bm25_model
163
+ MODELS_AND_DATA["tokenized_corpus_bm25"] = tokenized_corpus
164
+ print("✅ Tạo BM25 Model thành công.")
165
+
166
+ # 5. Tạo FAISS Index
167
+ print("5. Đang tạo FAISS Index...")
168
+ texts_to_encode = [chunk.get('text', '') for chunk in chunks_data]
169
+ chunk_embeddings = embedding_model.encode(texts_to_encode, convert_to_tensor=True, device=embedding_model.device)
170
+ chunk_embeddings_np = chunk_embeddings.cpu().numpy().astype('float32')
171
+ faiss.normalize_L2(chunk_embeddings_np)
172
+ dimension = chunk_embeddings_np.shape[1]
173
+ index = faiss.IndexFlatIP(dimension)
174
+ index.add(chunk_embeddings_np)
175
+ MODELS_AND_DATA["faiss_index"] = index
176
+ print(f"✅ Tạo FAISS Index thành công với {index.ntotal} vectors.")
177
+
178
+ print("\n--- Tải tài nguyên hoàn tất! Ứng dụng đã sẵn sàng. ---")
179
+
180
+ # --- PHẦN 3: CÁC HÀM LÕI CHO RAG ---
181
+
182
+ def search_relevant_laws(query_text, k=5, initial_k_multiplier=10, rrf_k_constant=60):
183
+ """
184
+ Hàm thực hiện Hybrid Search để tìm các đoạn luật liên quan.
185
+ """
186
+ # Lấy các tài nguyên đã tải
187
+ embedding_model = MODELS_AND_DATA["embedding_model"]
188
+ faiss_index = MODELS_AND_DATA["faiss_index"]
189
+ chunks_data = MODELS_AND_DATA["chunks_data"]
190
+ bm25_model = MODELS_AND_DATA["bm25_model"]
191
+
192
+ if not all([embedding_model, faiss_index, chunks_data, bm25_model]):
193
+ return "Lỗi: Tài nguyên chưa được tải xong. Vui lòng chờ."
194
+
195
+ # 1. Semantic Search (FAISS)
196
+ query_embedding = embedding_model.encode([query_text], convert_to_tensor=True, device=embedding_model.device)
197
+ query_embedding_np = query_embedding.cpu().numpy().astype('float32')
198
+ faiss.normalize_L2(query_embedding_np)
199
+ num_candidates = min(k * initial_k_multiplier, faiss_index.ntotal)
200
+ semantic_scores, semantic_indices = faiss_index.search(query_embedding_np, num_candidates)
201
+
202
+ # 2. Keyword Search (BM25)
203
+ tokenized_query = tokenize_vi_for_bm25(query_text)
204
+ bm25_scores = bm25_model.get_scores(tokenized_query)
205
+ bm25_results = sorted(enumerate(bm25_scores), key=lambda x: x[1], reverse=True)[:num_candidates]
206
+
207
+ # 3. Reciprocal Rank Fusion (RRF)
208
+ rrf_scores = defaultdict(float)
209
+ if semantic_indices.size > 0:
210
+ for rank, doc_idx in enumerate(semantic_indices[0]):
211
+ if doc_idx != -1: rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank)
212
+ for rank, (doc_idx, score) in enumerate(bm25_results):
213
+ if score > 0: rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank)
214
+
215
+ fused_results = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True)
216
+
217
+ # 4. Lấy kết quả cuối cùng
218
+ final_results = []
219
+ for doc_idx, score in fused_results[:k]:
220
+ result = chunks_data[doc_idx].copy()
221
+ result['score'] = score
222
+ final_results.append(result)
223
+
224
+ return final_results
225
+
226
+ def generate_llm_response(query, context):
227
+ """
228
+ Hàm sinh câu trả lời từ LLM dựa trên query và context.
229
+ """
230
+ llm_model = MODELS_AND_DATA["llm_model"]
231
+ tokenizer = MODELS_AND_DATA["tokenizer"]
232
+
233
+ prompt = f"""Dưới đây là một số thông tin trích dẫn từ văn bản luật giao thông đường bộ Việt Nam.
234
+ Hãy SỬ DỤNG CÁC THÔNG TIN NÀY để trả lời câu hỏi một cách chính xác và đầy đủ.
235
+ Nếu câu hỏi đưa ra nhiều đáp án thì chọn 1 đáp án đúng nhất.
236
+
237
+ ### Thông tin luật:
238
+ {context}
239
+
240
+ ### Câu hỏi:
241
+ {query}
242
+
243
+ ### Trả lời:"""
244
+
245
+ inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
246
+ generation_config = dict(
247
+ max_new_tokens=300,
248
+ temperature=0.2,
249
+ top_p=0.7,
250
+ do_sample=True,
251
+ pad_token_id=tokenizer.eos_token_id,
252
+ eos_token_id=tokenizer.eos_token_id
253
+ )
254
+ output_ids = llm_model.generate(**inputs, **generation_config)
255
+ input_length = inputs.input_ids.shape[1]
256
+ generated_ids = output_ids[0][input_length:]
257
+ response_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
258
+ return response_text
259
+
260
+ # --- PHẦN 4: CÁC HÀM XỬ LÝ CHO GRADIO INTERFACE ---
261
+
262
+ def run_retriever_only(query):
263
+ """
264
+ Chức năng 1: Chỉ tìm kiếm và trả về các điều luật liên quan.
265
+ """
266
+ print(f"Chạy chức năng Retriever cho query: '{query}'")
267
+ retrieved_results = search_relevant_laws(query)
268
+
269
+ if isinstance(retrieved_results, str): # Xử lý trường hợp lỗi
270
+ return retrieved_results
271
+
272
+ if not retrieved_results:
273
+ return "Không tìm thấy điều luật nào liên quan."
274
+
275
+ # Định dạng output cho Gradio Markdown
276
+ formatted_output = f"### Các điều luật liên quan nhất đến truy vấn của bạn:\n\n"
277
+ for i, res in enumerate(retrieved_results):
278
+ metadata = res.get('metadata', {})
279
+ article = metadata.get('article', 'N/A')
280
+ clause = metadata.get('clause_number', 'N/A')
281
+ source = metadata.get('source_document', 'N/A')
282
+ text = res.get('text', 'N/A')
283
+
284
+ formatted_output += f"**{i+1}. Nguồn: {source} | Điều {article} | Khoản {clause}**\n"
285
+ formatted_output += f"> {text}\n\n---\n\n"
286
+
287
+ return formatted_output
288
+
289
+ def run_full_rag(query, progress=gr.Progress()):
290
+ """
291
+ Chức năng 2: Thực hiện toàn bộ pipeline RAG.
292
+ """
293
+ progress(0, desc="Bắt đầu...")
294
+
295
+ # Bước 1: Truy xuất ngữ cảnh
296
+ progress(0.2, desc="Đang tìm kiếm các điều luật liên quan (Hybrid Search)...")
297
+ print(f"Chạy chức năng RAG cho query: '{query}'")
298
+ retrieved_results = search_relevant_laws(query)
299
+
300
+ if isinstance(retrieved_results, str) or not retrieved_results:
301
+ context_for_llm = "Không tìm thấy thông tin luật liên quan."
302
+ context_for_display = context_for_llm
303
+ else:
304
+ # Định dạng context cho LLM
305
+ context_parts = []
306
+ for res in retrieved_results:
307
+ text = res.get('text', '')
308
+ context_parts.append(text)
309
+ context_for_llm = "\n\n---\n\n".join(context_parts)
310
+
311
+ # Định dạng context để hiển thị cho người dùng
312
+ context_for_display = run_retriever_only(query) # Tái sử dụng hàm retriever
313
+
314
+ # Bước 2: Sinh câu trả lời
315
+ progress(0.7, desc="Đã có ngữ cảnh, đang yêu cầu LLM tạo câu trả lời...")
316
+ final_answer = generate_llm_response(query, context_for_llm)
317
+
318
+ progress(1, desc="Hoàn tất!")
319
+
320
+ return final_answer, context_for_display
321
+
322
+
323
+ # --- PHẦN 5: KHỞI CHẠY ỨNG DỤNG GRADIO ---
324
+
325
+ # Tải tài nguyên ngay khi script được chạy
326
+ load_all_resources()
327
+
328
+ with gr.Blocks(theme=gr.themes.Soft(), title="Chatbot Luật Giao thông") as demo:
329
+ gr.Markdown(
330
+ """
331
+ # ⚖️ Chatbot Luật Giao thông Việt Nam
332
+ Ứng dụng này sử dụng mô hình RAG (Retrieval-Augmented Generation) để trả lời các câu hỏi về luật giao thông.
333
+ """
334
+ )
335
+
336
+ with gr.Tabs():
337
+ # Tab 1: Chỉ tìm kiếm
338
+ with gr.TabItem("Tìm kiếm Điều luật (Retriever)"):
339
+ with gr.Row():
340
+ retriever_query = gr.Textbox(label="Nhập nội dung cần tìm kiếm", placeholder="Ví dụ: Vượt đèn đỏ bị phạt bao nhiêu tiền?", scale=4)
341
+ retriever_button = gr.Button("Tìm kiếm", variant="secondary", scale=1)
342
+ retriever_output = gr.Markdown(label="Các điều luật liên quan")
343
+
344
+ # Tab 2: Hỏi-đáp RAG đầy đủ
345
+ with gr.TabItem("Hỏi-Đáp (RAG)"):
346
+ with gr.Row():
347
+ rag_query = gr.Textbox(label="Nhập câu hỏi của bạn", placeholder="Ví dụ: Phương tiện giao thông đường bộ gồm những loại nào?", scale=4)
348
+ rag_button = gr.Button("Gửi câu hỏi", variant="primary", scale=1)
349
+ rag_answer = gr.Textbox(label="Câu trả lời của Chatbot", lines=5)
350
+ with gr.Accordion("Xem ngữ cảnh đã sử dụng để tạo câu trả lời", open=False):
351
+ rag_context = gr.Markdown(label="Ngữ cảnh")
352
+
353
+ # Xử lý sự kiện click
354
+ retriever_button.click(fn=run_retriever_only, inputs=retriever_query, outputs=retriever_output)
355
+ rag_button.click(fn=run_full_rag, inputs=rag_query, outputs=[rag_answer, rag_context])
356
 
357
+ if __name__ == "__main__":
358
+ demo.launch(share=True) # share=True để tạo link public nếu chạy trên Colab/local