Spaces:
Paused
Paused
# app.py | |
# File triển khai hoàn chỉnh cho đồ án Chatbot Luật Giao thông | |
# Tác giả: (Tên của bạn) | |
# Ngày: (Ngày bạn tạo) | |
# --- PHẦN 1: IMPORT CÁC THƯ VIỆN CẦN THIẾT --- | |
print("Bắt đầu import các thư viện...") | |
import os | |
import sys | |
import json | |
import re | |
import time | |
from collections import defaultdict | |
# Core ML/DL và Unsloth | |
import torch | |
from unsloth import FastLanguageModel | |
from transformers import TextStreamer | |
# RAG - Retrieval | |
import faiss | |
from sentence_transformers import SentenceTransformer | |
from rank_bm25 import BM25Okapi | |
import numpy as np | |
# Deployment | |
import gradio as gr | |
print("✅ Import thư viện thành công.") | |
# --- PHẦN 2: CẤU HÌNH VÀ TẢI TÀI NGUYÊN (MODELS & DATA) --- | |
# Phần này sẽ chỉ chạy một lần khi ứng dụng khởi động. | |
# Cấu hình mô hình | |
MAX_SEQ_LENGTH = 2048 | |
DTYPE = None | |
LOAD_IN_4BIT = True | |
EMBEDDING_MODEL_NAME = "bkai-foundation-models/vietnamese-bi-encoder" | |
LLM_MODEL_NAME = "unsloth/Llama-3.2-3B-Instruct-bnb-4bit" | |
LAW_DATA_FILE = "luat_chi_tiet_output_openai_sdk_final_cleaned.json" | |
# Biến toàn cục để lưu các tài nguyên đã tải | |
# Điều này giúp tránh việc phải tải lại mô hình mỗi khi người dùng gửi yêu cầu. | |
MODELS_AND_DATA = { | |
"llm_model": None, | |
"tokenizer": None, | |
"embedding_model": None, | |
"faiss_index": None, | |
"bm25_model": None, | |
"chunks_data": None, | |
"tokenized_corpus_bm25": None, | |
} | |
# --- Các hàm xử lý dữ liệu (từ các notebook của bạn) --- | |
def process_law_data_to_chunks(structured_data_input): | |
""" | |
Hàm làm phẳng dữ liệu luật có cấu trúc chi tiết thành danh sách các chunks. | |
Mỗi chunk chứa 'text' và 'metadata'. | |
""" | |
flat_list = [] | |
articles_list = [] | |
if isinstance(structured_data_input, dict) and "article" in structured_data_input: | |
articles_list = [structured_data_input] | |
elif isinstance(structured_data_input, list): | |
articles_list = structured_data_input | |
else: | |
print("Lỗi: Dữ liệu đầu vào không hợp lệ.") | |
return flat_list | |
for article_data in articles_list: | |
if not isinstance(article_data, dict): continue | |
article_metadata_base = { | |
"source_document": article_data.get("source_document"), | |
"article": article_data.get("article"), | |
"article_title": article_data.get("article_title") | |
} | |
clauses = article_data.get("clauses", []) | |
if not isinstance(clauses, list): continue | |
for clause_data in clauses: | |
if not isinstance(clause_data, dict): continue | |
clause_metadata_base = article_metadata_base.copy() | |
clause_metadata_base.update({ | |
"clause_number": clause_data.get("clause_number"), | |
"clause_metadata_summary": clause_data.get("clause_metadata_summary") | |
}) | |
points_in_clause = clause_data.get("points_in_clause", []) | |
if not isinstance(points_in_clause, list): continue | |
if points_in_clause: | |
for point_data in points_in_clause: | |
if not isinstance(point_data, dict): continue | |
chunk_text = point_data.get("point_text_original") or point_data.get("violation_description_summary") | |
if not chunk_text: continue | |
current_point_metadata = clause_metadata_base.copy() | |
point_specific_metadata = point_data.copy() | |
if "point_text_original" in point_specific_metadata: | |
del point_specific_metadata["point_text_original"] | |
current_point_metadata.update(point_specific_metadata) | |
final_metadata_cleaned = {k: v for k, v in current_point_metadata.items() if v is not None} | |
flat_list.append({"text": chunk_text, "metadata": final_metadata_cleaned}) | |
else: | |
chunk_text = clause_data.get("clause_text_original") | |
if chunk_text: | |
current_clause_metadata = clause_metadata_base.copy() | |
additional_clause_info = {k: v for k, value in clause_data.items() if k not in ["clause_text_original", "points_in_clause", "clause_number", "clause_metadata_summary"]} | |
if additional_clause_info: | |
current_clause_metadata.update(additional_clause_info) | |
final_metadata_cleaned = {k: v for k, v in current_clause_metadata.items() if v is not None} | |
flat_list.append({"text": chunk_text, "metadata": final_metadata_cleaned}) | |
return flat_list | |
def tokenize_vi_for_bm25(text): | |
"""Hàm tokenize tiếng Việt đơn giản cho BM25.""" | |
text = text.lower() | |
text = re.sub(r'[^\w\s]', '', text) | |
return text.split() | |
def load_all_resources(): | |
""" | |
Hàm chính để tải tất cả mô hình và dữ liệu cần thiết. | |
Chỉ chạy một lần khi ứng dụng khởi động. | |
""" | |
print("--- Bắt đầu quá trình tải tài nguyên ---") | |
# 1. Tải mô hình LLM và Tokenizer | |
print(f"1. Đang tải LLM và Tokenizer: {LLM_MODEL_NAME}...") | |
llm_model, tokenizer = FastLanguageModel.from_pretrained( | |
model_name=LLM_MODEL_NAME, | |
max_seq_length=MAX_SEQ_LENGTH, | |
dtype=DTYPE, | |
load_in_4bit=LOAD_IN_4BIT, | |
) | |
FastLanguageModel.for_inference(llm_model) | |
MODELS_AND_DATA["llm_model"] = llm_model | |
MODELS_AND_DATA["tokenizer"] = tokenizer | |
print("✅ Tải LLM và Tokenizer thành công.") | |
# 2. Tải mô hình Embedding | |
print(f"2. Đang tải Embedding Model: {EMBEDDING_MODEL_NAME}...") | |
embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME, device="cuda" if torch.cuda.is_available() else "cpu") | |
MODELS_AND_DATA["embedding_model"] = embedding_model | |
print("✅ Tải Embedding Model thành công.") | |
# 3. Tải và xử lý dữ liệu luật | |
print(f"3. Đang tải và xử lý dữ liệu từ: {LAW_DATA_FILE}...") | |
if not os.path.exists(LAW_DATA_FILE): | |
raise FileNotFoundError(f"Không tìm thấy file dữ liệu luật: {LAW_DATA_FILE}. Vui lòng upload file này lên Space.") | |
with open(LAW_DATA_FILE, 'r', encoding='utf-8') as f: | |
raw_data_from_file = json.load(f) | |
chunks_data = process_law_data_to_chunks(raw_data_from_file) | |
MODELS_AND_DATA["chunks_data"] = chunks_data | |
print(f"✅ Đã xử lý thành {len(chunks_data)} chunks.") | |
# 4. Tạo BM25 Model | |
print("4. Đang tạo BM25 Model...") | |
corpus_texts = [chunk.get('text', '') for chunk in chunks_data] | |
tokenized_corpus = [tokenize_vi_for_bm25(text) for text in corpus_texts] | |
bm25_model = BM25Okapi(tokenized_corpus) | |
MODELS_AND_DATA["bm25_model"] = bm25_model | |
MODELS_AND_DATA["tokenized_corpus_bm25"] = tokenized_corpus | |
print("✅ Tạo BM25 Model thành công.") | |
# 5. Tạo FAISS Index | |
print("5. Đang tạo FAISS Index...") | |
texts_to_encode = [chunk.get('text', '') for chunk in chunks_data] | |
chunk_embeddings = embedding_model.encode(texts_to_encode, convert_to_tensor=True, device=embedding_model.device) | |
chunk_embeddings_np = chunk_embeddings.cpu().numpy().astype('float32') | |
faiss.normalize_L2(chunk_embeddings_np) | |
dimension = chunk_embeddings_np.shape[1] | |
index = faiss.IndexFlatIP(dimension) | |
index.add(chunk_embeddings_np) | |
MODELS_AND_DATA["faiss_index"] = index | |
print(f"✅ Tạo FAISS Index thành công với {index.ntotal} vectors.") | |
print("\n--- Tải tài nguyên hoàn tất! Ứng dụng đã sẵn sàng. ---") | |
# --- PHẦN 3: CÁC HÀM LÕI CHO RAG --- | |
def search_relevant_laws(query_text, k=5, initial_k_multiplier=10, rrf_k_constant=60): | |
""" | |
Hàm thực hiện Hybrid Search để tìm các đoạn luật liên quan. | |
""" | |
# Lấy các tài nguyên đã tải | |
embedding_model = MODELS_AND_DATA["embedding_model"] | |
faiss_index = MODELS_AND_DATA["faiss_index"] | |
chunks_data = MODELS_AND_DATA["chunks_data"] | |
bm25_model = MODELS_AND_DATA["bm25_model"] | |
if not all([embedding_model, faiss_index, chunks_data, bm25_model]): | |
return "Lỗi: Tài nguyên chưa được tải xong. Vui lòng chờ." | |
# 1. Semantic Search (FAISS) | |
query_embedding = embedding_model.encode([query_text], convert_to_tensor=True, device=embedding_model.device) | |
query_embedding_np = query_embedding.cpu().numpy().astype('float32') | |
faiss.normalize_L2(query_embedding_np) | |
num_candidates = min(k * initial_k_multiplier, faiss_index.ntotal) | |
semantic_scores, semantic_indices = faiss_index.search(query_embedding_np, num_candidates) | |
# 2. Keyword Search (BM25) | |
tokenized_query = tokenize_vi_for_bm25(query_text) | |
bm25_scores = bm25_model.get_scores(tokenized_query) | |
bm25_results = sorted(enumerate(bm25_scores), key=lambda x: x[1], reverse=True)[:num_candidates] | |
# 3. Reciprocal Rank Fusion (RRF) | |
rrf_scores = defaultdict(float) | |
if semantic_indices.size > 0: | |
for rank, doc_idx in enumerate(semantic_indices[0]): | |
if doc_idx != -1: rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank) | |
for rank, (doc_idx, score) in enumerate(bm25_results): | |
if score > 0: rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank) | |
fused_results = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True) | |
# 4. Lấy kết quả cuối cùng | |
final_results = [] | |
for doc_idx, score in fused_results[:k]: | |
result = chunks_data[doc_idx].copy() | |
result['score'] = score | |
final_results.append(result) | |
return final_results | |
def generate_llm_response(query, context): | |
""" | |
Hàm sinh câu trả lời từ LLM dựa trên query và context. | |
""" | |
llm_model = MODELS_AND_DATA["llm_model"] | |
tokenizer = MODELS_AND_DATA["tokenizer"] | |
prompt = f"""Dưới đây là một số thông tin trích dẫn từ văn bản luật giao thông đường bộ Việt Nam. | |
Hãy SỬ DỤNG CÁC THÔNG TIN NÀY để trả lời câu hỏi một cách chính xác và đầy đủ. | |
Nếu câu hỏi đưa ra nhiều đáp án thì chọn 1 đáp án đúng nhất. | |
### Thông tin luật: | |
{context} | |
### Câu hỏi: | |
{query} | |
### Trả lời:""" | |
inputs = tokenizer(prompt, return_tensors="pt").to("cuda") | |
generation_config = dict( | |
max_new_tokens=300, | |
temperature=0.2, | |
top_p=0.7, | |
do_sample=True, | |
pad_token_id=tokenizer.eos_token_id, | |
eos_token_id=tokenizer.eos_token_id | |
) | |
output_ids = llm_model.generate(**inputs, **generation_config) | |
input_length = inputs.input_ids.shape[1] | |
generated_ids = output_ids[0][input_length:] | |
response_text = tokenizer.decode(generated_ids, skip_special_tokens=True) | |
return response_text | |
# --- PHẦN 4: CÁC HÀM XỬ LÝ CHO GRADIO INTERFACE --- | |
def run_retriever_only(query): | |
""" | |
Chức năng 1: Chỉ tìm kiếm và trả về các điều luật liên quan. | |
""" | |
print(f"Chạy chức năng Retriever cho query: '{query}'") | |
retrieved_results = search_relevant_laws(query) | |
if isinstance(retrieved_results, str): # Xử lý trường hợp lỗi | |
return retrieved_results | |
if not retrieved_results: | |
return "Không tìm thấy điều luật nào liên quan." | |
# Định dạng output cho Gradio Markdown | |
formatted_output = f"### Các điều luật liên quan nhất đến truy vấn của bạn:\n\n" | |
for i, res in enumerate(retrieved_results): | |
metadata = res.get('metadata', {}) | |
article = metadata.get('article', 'N/A') | |
clause = metadata.get('clause_number', 'N/A') | |
source = metadata.get('source_document', 'N/A') | |
text = res.get('text', 'N/A') | |
formatted_output += f"**{i+1}. Nguồn: {source} | Điều {article} | Khoản {clause}**\n" | |
formatted_output += f"> {text}\n\n---\n\n" | |
return formatted_output | |
def run_full_rag(query, progress=gr.Progress()): | |
""" | |
Chức năng 2: Thực hiện toàn bộ pipeline RAG. | |
""" | |
progress(0, desc="Bắt đầu...") | |
# Bước 1: Truy xuất ngữ cảnh | |
progress(0.2, desc="Đang tìm kiếm các điều luật liên quan (Hybrid Search)...") | |
print(f"Chạy chức năng RAG cho query: '{query}'") | |
retrieved_results = search_relevant_laws(query) | |
if isinstance(retrieved_results, str) or not retrieved_results: | |
context_for_llm = "Không tìm thấy thông tin luật liên quan." | |
context_for_display = context_for_llm | |
else: | |
# Định dạng context cho LLM | |
context_parts = [] | |
for res in retrieved_results: | |
text = res.get('text', '') | |
context_parts.append(text) | |
context_for_llm = "\n\n---\n\n".join(context_parts) | |
# Định dạng context để hiển thị cho người dùng | |
context_for_display = run_retriever_only(query) # Tái sử dụng hàm retriever | |
# Bước 2: Sinh câu trả lời | |
progress(0.7, desc="Đã có ngữ cảnh, đang yêu cầu LLM tạo câu trả lời...") | |
final_answer = generate_llm_response(query, context_for_llm) | |
progress(1, desc="Hoàn tất!") | |
return final_answer, context_for_display | |
# --- PHẦN 5: KHỞI CHẠY ỨNG DỤNG GRADIO --- | |
# Tải tài nguyên ngay khi script được chạy | |
load_all_resources() | |
with gr.Blocks(theme=gr.themes.Soft(), title="Chatbot Luật Giao thông") as demo: | |
gr.Markdown( | |
""" | |
# ⚖️ Chatbot Luật Giao thông Việt Nam | |
Ứng dụng này sử dụng mô hình RAG (Retrieval-Augmented Generation) để trả lời các câu hỏi về luật giao thông. | |
""" | |
) | |
with gr.Tabs(): | |
# Tab 1: Chỉ tìm kiếm | |
with gr.TabItem("Tìm kiếm Điều luật (Retriever)"): | |
with gr.Row(): | |
retriever_query = gr.Textbox(label="Nhập nội dung cần tìm kiếm", placeholder="Ví dụ: Vượt đèn đỏ bị phạt bao nhiêu tiền?", scale=4) | |
retriever_button = gr.Button("Tìm kiếm", variant="secondary", scale=1) | |
retriever_output = gr.Markdown(label="Các điều luật liên quan") | |
# Tab 2: Hỏi-đáp RAG đầy đủ | |
with gr.TabItem("Hỏi-Đáp (RAG)"): | |
with gr.Row(): | |
rag_query = gr.Textbox(label="Nhập câu hỏi của bạn", placeholder="Ví dụ: Phương tiện giao thông đường bộ gồm những loại nào?", scale=4) | |
rag_button = gr.Button("Gửi câu hỏi", variant="primary", scale=1) | |
rag_answer = gr.Textbox(label="Câu trả lời của Chatbot", lines=5) | |
with gr.Accordion("Xem ngữ cảnh đã sử dụng để tạo câu trả lời", open=False): | |
rag_context = gr.Markdown(label="Ngữ cảnh") | |
# Xử lý sự kiện click | |
retriever_button.click(fn=run_retriever_only, inputs=retriever_query, outputs=retriever_output) | |
rag_button.click(fn=run_full_rag, inputs=rag_query, outputs=[rag_answer, rag_context]) | |
if __name__ == "__main__": | |
demo.launch(share=True) # share=True để tạo link public nếu chạy trên Colab/local | |