File size: 15,379 Bytes
845a94d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b4aaddd
 
845a94d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b4aaddd
845a94d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
# app.py
# File triển khai hoàn chỉnh cho đồ án Chatbot Luật Giao thông
# Tác giả: (Tên của bạn)
# Ngày: (Ngày bạn tạo)

# --- PHẦN 1: IMPORT CÁC THƯ VIỆN CẦN THIẾT ---
print("Bắt đầu import các thư viện...")
import os
import sys
import json
import re
import time
from collections import defaultdict

# Core ML/DL và Unsloth
import torch
from unsloth import FastLanguageModel
from transformers import TextStreamer

# RAG - Retrieval
import faiss
from sentence_transformers import SentenceTransformer
from rank_bm25 import BM25Okapi
import numpy as np

# Deployment
import gradio as gr

print("✅ Import thư viện thành công.")

# --- PHẦN 2: CẤU HÌNH VÀ TẢI TÀI NGUYÊN (MODELS & DATA) ---
# Phần này sẽ chỉ chạy một lần khi ứng dụng khởi động.

# Cấu hình mô hình
MAX_SEQ_LENGTH = 2048
DTYPE = None
LOAD_IN_4BIT = True
EMBEDDING_MODEL_NAME = "bkai-foundation-models/vietnamese-bi-encoder"
LLM_MODEL_NAME = "unsloth/Llama-3.2-3B-Instruct-bnb-4bit"
LAW_DATA_FILE = "luat_chi_tiet_output_openai_sdk_final_cleaned.json"

# Biến toàn cục để lưu các tài nguyên đã tải
# Điều này giúp tránh việc phải tải lại mô hình mỗi khi người dùng gửi yêu cầu.
MODELS_AND_DATA = {
    "llm_model": None,
    "tokenizer": None,
    "embedding_model": None,
    "faiss_index": None,
    "bm25_model": None,
    "chunks_data": None,
    "tokenized_corpus_bm25": None,
}

# --- Các hàm xử lý dữ liệu (từ các notebook của bạn) ---

def process_law_data_to_chunks(structured_data_input):
    """
    Hàm làm phẳng dữ liệu luật có cấu trúc chi tiết thành danh sách các chunks.
    Mỗi chunk chứa 'text' và 'metadata'.
    """
    flat_list = []
    articles_list = []
    if isinstance(structured_data_input, dict) and "article" in structured_data_input:
        articles_list = [structured_data_input]
    elif isinstance(structured_data_input, list):
        articles_list = structured_data_input
    else:
        print("Lỗi: Dữ liệu đầu vào không hợp lệ.")
        return flat_list

    for article_data in articles_list:
        if not isinstance(article_data, dict): continue
        article_metadata_base = {
            "source_document": article_data.get("source_document"),
            "article": article_data.get("article"),
            "article_title": article_data.get("article_title")
        }
        clauses = article_data.get("clauses", [])
        if not isinstance(clauses, list): continue

        for clause_data in clauses:
            if not isinstance(clause_data, dict): continue
            clause_metadata_base = article_metadata_base.copy()
            clause_metadata_base.update({
                "clause_number": clause_data.get("clause_number"),
                "clause_metadata_summary": clause_data.get("clause_metadata_summary")
            })
            points_in_clause = clause_data.get("points_in_clause", [])
            if not isinstance(points_in_clause, list): continue

            if points_in_clause:
                for point_data in points_in_clause:
                    if not isinstance(point_data, dict): continue
                    chunk_text = point_data.get("point_text_original") or point_data.get("violation_description_summary")
                    if not chunk_text: continue
                    
                    current_point_metadata = clause_metadata_base.copy()
                    point_specific_metadata = point_data.copy()
                    if "point_text_original" in point_specific_metadata:
                        del point_specific_metadata["point_text_original"]
                    current_point_metadata.update(point_specific_metadata)
                    final_metadata_cleaned = {k: v for k, v in current_point_metadata.items() if v is not None}
                    flat_list.append({"text": chunk_text, "metadata": final_metadata_cleaned})
            else:
                chunk_text = clause_data.get("clause_text_original")
                if chunk_text:
                    current_clause_metadata = clause_metadata_base.copy()
                    additional_clause_info = {k: v for k, value in clause_data.items() if k not in ["clause_text_original", "points_in_clause", "clause_number", "clause_metadata_summary"]}
                    if additional_clause_info:
                        current_clause_metadata.update(additional_clause_info)
                    final_metadata_cleaned = {k: v for k, v in current_clause_metadata.items() if v is not None}
                    flat_list.append({"text": chunk_text, "metadata": final_metadata_cleaned})
    return flat_list

def tokenize_vi_for_bm25(text):
    """Hàm tokenize tiếng Việt đơn giản cho BM25."""
    text = text.lower()
    text = re.sub(r'[^\w\s]', '', text)
    return text.split()

def load_all_resources():
    """
    Hàm chính để tải tất cả mô hình và dữ liệu cần thiết.
    Chỉ chạy một lần khi ứng dụng khởi động.
    """
    print("--- Bắt đầu quá trình tải tài nguyên ---")
    
    # 1. Tải mô hình LLM và Tokenizer
    print(f"1. Đang tải LLM và Tokenizer: {LLM_MODEL_NAME}...")
    llm_model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=LLM_MODEL_NAME,
        max_seq_length=MAX_SEQ_LENGTH,
        dtype=DTYPE,
        load_in_4bit=LOAD_IN_4BIT,
    )
    FastLanguageModel.for_inference(llm_model)
    MODELS_AND_DATA["llm_model"] = llm_model
    MODELS_AND_DATA["tokenizer"] = tokenizer
    print("✅ Tải LLM và Tokenizer thành công.")

    # 2. Tải mô hình Embedding
    print(f"2. Đang tải Embedding Model: {EMBEDDING_MODEL_NAME}...")
    embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME, device="cuda" if torch.cuda.is_available() else "cpu")
    MODELS_AND_DATA["embedding_model"] = embedding_model
    print("✅ Tải Embedding Model thành công.")

    # 3. Tải và xử lý dữ liệu luật
    print(f"3. Đang tải và xử lý dữ liệu từ: {LAW_DATA_FILE}...")
    if not os.path.exists(LAW_DATA_FILE):
        raise FileNotFoundError(f"Không tìm thấy file dữ liệu luật: {LAW_DATA_FILE}. Vui lòng upload file này lên Space.")
    with open(LAW_DATA_FILE, 'r', encoding='utf-8') as f:
        raw_data_from_file = json.load(f)
    chunks_data = process_law_data_to_chunks(raw_data_from_file)
    MODELS_AND_DATA["chunks_data"] = chunks_data
    print(f"✅ Đã xử lý thành {len(chunks_data)} chunks.")

    # 4. Tạo BM25 Model
    print("4. Đang tạo BM25 Model...")
    corpus_texts = [chunk.get('text', '') for chunk in chunks_data]
    tokenized_corpus = [tokenize_vi_for_bm25(text) for text in corpus_texts]
    bm25_model = BM25Okapi(tokenized_corpus)
    MODELS_AND_DATA["bm25_model"] = bm25_model
    MODELS_AND_DATA["tokenized_corpus_bm25"] = tokenized_corpus
    print("✅ Tạo BM25 Model thành công.")

    # 5. Tạo FAISS Index
    print("5. Đang tạo FAISS Index...")
    texts_to_encode = [chunk.get('text', '') for chunk in chunks_data]
    chunk_embeddings = embedding_model.encode(texts_to_encode, convert_to_tensor=True, device=embedding_model.device)
    chunk_embeddings_np = chunk_embeddings.cpu().numpy().astype('float32')
    faiss.normalize_L2(chunk_embeddings_np)
    dimension = chunk_embeddings_np.shape[1]
    index = faiss.IndexFlatIP(dimension)
    index.add(chunk_embeddings_np)
    MODELS_AND_DATA["faiss_index"] = index
    print(f"✅ Tạo FAISS Index thành công với {index.ntotal} vectors.")

    print("\n--- Tải tài nguyên hoàn tất! Ứng dụng đã sẵn sàng. ---")

# --- PHẦN 3: CÁC HÀM LÕI CHO RAG ---

def search_relevant_laws(query_text, k=5, initial_k_multiplier=10, rrf_k_constant=60):
    """
    Hàm thực hiện Hybrid Search để tìm các đoạn luật liên quan.
    """
    # Lấy các tài nguyên đã tải
    embedding_model = MODELS_AND_DATA["embedding_model"]
    faiss_index = MODELS_AND_DATA["faiss_index"]
    chunks_data = MODELS_AND_DATA["chunks_data"]
    bm25_model = MODELS_AND_DATA["bm25_model"]

    if not all([embedding_model, faiss_index, chunks_data, bm25_model]):
        return "Lỗi: Tài nguyên chưa được tải xong. Vui lòng chờ."

    # 1. Semantic Search (FAISS)
    query_embedding = embedding_model.encode([query_text], convert_to_tensor=True, device=embedding_model.device)
    query_embedding_np = query_embedding.cpu().numpy().astype('float32')
    faiss.normalize_L2(query_embedding_np)
    num_candidates = min(k * initial_k_multiplier, faiss_index.ntotal)
    semantic_scores, semantic_indices = faiss_index.search(query_embedding_np, num_candidates)

    # 2. Keyword Search (BM25)
    tokenized_query = tokenize_vi_for_bm25(query_text)
    bm25_scores = bm25_model.get_scores(tokenized_query)
    bm25_results = sorted(enumerate(bm25_scores), key=lambda x: x[1], reverse=True)[:num_candidates]

    # 3. Reciprocal Rank Fusion (RRF)
    rrf_scores = defaultdict(float)
    if semantic_indices.size > 0:
        for rank, doc_idx in enumerate(semantic_indices[0]):
            if doc_idx != -1: rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank)
    for rank, (doc_idx, score) in enumerate(bm25_results):
        if score > 0: rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank)
    
    fused_results = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True)

    # 4. Lấy kết quả cuối cùng
    final_results = []
    for doc_idx, score in fused_results[:k]:
        result = chunks_data[doc_idx].copy()
        result['score'] = score
        final_results.append(result)
        
    return final_results

def generate_llm_response(query, context):
    """
    Hàm sinh câu trả lời từ LLM dựa trên query và context.
    """
    llm_model = MODELS_AND_DATA["llm_model"]
    tokenizer = MODELS_AND_DATA["tokenizer"]

    prompt = f"""Dưới đây là một số thông tin trích dẫn từ văn bản luật giao thông đường bộ Việt Nam.
Hãy SỬ DỤNG CÁC THÔNG TIN NÀY để trả lời câu hỏi một cách chính xác và đầy đủ.
Nếu câu hỏi đưa ra nhiều đáp án thì chọn 1 đáp án đúng nhất.

### Thông tin luật:
{context}

### Câu hỏi:
{query}

### Trả lời:"""

    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    generation_config = dict(
        max_new_tokens=300,
        temperature=0.2,
        top_p=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id
    )
    output_ids = llm_model.generate(**inputs, **generation_config)
    input_length = inputs.input_ids.shape[1]
    generated_ids = output_ids[0][input_length:]
    response_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
    return response_text

# --- PHẦN 4: CÁC HÀM XỬ LÝ CHO GRADIO INTERFACE ---

def run_retriever_only(query):
    """
    Chức năng 1: Chỉ tìm kiếm và trả về các điều luật liên quan.
    """
    print(f"Chạy chức năng Retriever cho query: '{query}'")
    retrieved_results = search_relevant_laws(query)
    
    if isinstance(retrieved_results, str): # Xử lý trường hợp lỗi
        return retrieved_results

    if not retrieved_results:
        return "Không tìm thấy điều luật nào liên quan."

    # Định dạng output cho Gradio Markdown
    formatted_output = f"### Các điều luật liên quan nhất đến truy vấn của bạn:\n\n"
    for i, res in enumerate(retrieved_results):
        metadata = res.get('metadata', {})
        article = metadata.get('article', 'N/A')
        clause = metadata.get('clause_number', 'N/A')
        source = metadata.get('source_document', 'N/A')
        text = res.get('text', 'N/A')
        
        formatted_output += f"**{i+1}. Nguồn: {source} | Điều {article} | Khoản {clause}**\n"
        formatted_output += f"> {text}\n\n---\n\n"
        
    return formatted_output

def run_full_rag(query, progress=gr.Progress()):
    """
    Chức năng 2: Thực hiện toàn bộ pipeline RAG.
    """
    progress(0, desc="Bắt đầu...")
    
    # Bước 1: Truy xuất ngữ cảnh
    progress(0.2, desc="Đang tìm kiếm các điều luật liên quan (Hybrid Search)...")
    print(f"Chạy chức năng RAG cho query: '{query}'")
    retrieved_results = search_relevant_laws(query)
    
    if isinstance(retrieved_results, str) or not retrieved_results:
        context_for_llm = "Không tìm thấy thông tin luật liên quan."
        context_for_display = context_for_llm
    else:
        # Định dạng context cho LLM
        context_parts = []
        for res in retrieved_results:
            text = res.get('text', '')
            context_parts.append(text)
        context_for_llm = "\n\n---\n\n".join(context_parts)
        
        # Định dạng context để hiển thị cho người dùng
        context_for_display = run_retriever_only(query) # Tái sử dụng hàm retriever
    
    # Bước 2: Sinh câu trả lời
    progress(0.7, desc="Đã có ngữ cảnh, đang yêu cầu LLM tạo câu trả lời...")
    final_answer = generate_llm_response(query, context_for_llm)
    
    progress(1, desc="Hoàn tất!")
    
    return final_answer, context_for_display


# --- PHẦN 5: KHỞI CHẠY ỨNG DỤNG GRADIO ---

# Tải tài nguyên ngay khi script được chạy
load_all_resources()

with gr.Blocks(theme=gr.themes.Soft(), title="Chatbot Luật Giao thông") as demo:
    gr.Markdown(
        """
        # ⚖️ Chatbot Luật Giao thông Việt Nam
        Ứng dụng này sử dụng mô hình RAG (Retrieval-Augmented Generation) để trả lời các câu hỏi về luật giao thông.
        """
    )
    
    with gr.Tabs():
        # Tab 1: Chỉ tìm kiếm
        with gr.TabItem("Tìm kiếm Điều luật (Retriever)"):
            with gr.Row():
                retriever_query = gr.Textbox(label="Nhập nội dung cần tìm kiếm", placeholder="Ví dụ: Vượt đèn đỏ bị phạt bao nhiêu tiền?", scale=4)
                retriever_button = gr.Button("Tìm kiếm", variant="secondary", scale=1)
            retriever_output = gr.Markdown(label="Các điều luật liên quan")
        
        # Tab 2: Hỏi-đáp RAG đầy đủ
        with gr.TabItem("Hỏi-Đáp (RAG)"):
            with gr.Row():
                rag_query = gr.Textbox(label="Nhập câu hỏi của bạn", placeholder="Ví dụ: Phương tiện giao thông đường bộ gồm những loại nào?", scale=4)
                rag_button = gr.Button("Gửi câu hỏi", variant="primary", scale=1)
            rag_answer = gr.Textbox(label="Câu trả lời của Chatbot", lines=5)
            with gr.Accordion("Xem ngữ cảnh đã sử dụng để tạo câu trả lời", open=False):
                rag_context = gr.Markdown(label="Ngữ cảnh")

    # Xử lý sự kiện click
    retriever_button.click(fn=run_retriever_only, inputs=retriever_query, outputs=retriever_output)
    rag_button.click(fn=run_full_rag, inputs=rag_query, outputs=[rag_answer, rag_context])

if __name__ == "__main__":
    demo.launch(share=True) # share=True để tạo link public nếu chạy trên Colab/local