chatbot_demo / app.py
deddoggo's picture
update main
845a94d
raw
history blame
15.4 kB
# app.py
# File triển khai hoàn chỉnh cho đồ án Chatbot Luật Giao thông
# Tác giả: (Tên của bạn)
# Ngày: (Ngày bạn tạo)
# --- PHẦN 1: IMPORT CÁC THƯ VIỆN CẦN THIẾT ---
print("Bắt đầu import các thư viện...")
import os
import sys
import json
import re
import time
from collections import defaultdict
# Core ML/DL và Unsloth
import torch
from unsloth import FastLanguageModel
from transformers import TextStreamer
# RAG - Retrieval
import faiss
from sentence_transformers import SentenceTransformer
from rank_bm25 import BM25Okapi
import numpy as np
# Deployment
import gradio as gr
print("✅ Import thư viện thành công.")
# --- PHẦN 2: CẤU HÌNH VÀ TẢI TÀI NGUYÊN (MODELS & DATA) ---
# Phần này sẽ chỉ chạy một lần khi ứng dụng khởi động.
# Cấu hình mô hình
MAX_SEQ_LENGTH = 2048
DTYPE = None
LOAD_IN_4BIT = True
EMBEDDING_MODEL_NAME = "bkai-foundation-models/vietnamese-bi-encoder"
LLM_MODEL_NAME = "unsloth/Llama-3.2-3B-Instruct-bnb-4bit"
LAW_DATA_FILE = "luat_chi_tiet_output_openai_sdk_final_cleaned.json"
# Biến toàn cục để lưu các tài nguyên đã tải
# Điều này giúp tránh việc phải tải lại mô hình mỗi khi người dùng gửi yêu cầu.
MODELS_AND_DATA = {
"llm_model": None,
"tokenizer": None,
"embedding_model": None,
"faiss_index": None,
"bm25_model": None,
"chunks_data": None,
"tokenized_corpus_bm25": None,
}
# --- Các hàm xử lý dữ liệu (từ các notebook của bạn) ---
def process_law_data_to_chunks(structured_data_input):
"""
Hàm làm phẳng dữ liệu luật có cấu trúc chi tiết thành danh sách các chunks.
Mỗi chunk chứa 'text' và 'metadata'.
"""
flat_list = []
articles_list = []
if isinstance(structured_data_input, dict) and "article" in structured_data_input:
articles_list = [structured_data_input]
elif isinstance(structured_data_input, list):
articles_list = structured_data_input
else:
print("Lỗi: Dữ liệu đầu vào không hợp lệ.")
return flat_list
for article_data in articles_list:
if not isinstance(article_data, dict): continue
article_metadata_base = {
"source_document": article_data.get("source_document"),
"article": article_data.get("article"),
"article_title": article_data.get("article_title")
}
clauses = article_data.get("clauses", [])
if not isinstance(clauses, list): continue
for clause_data in clauses:
if not isinstance(clause_data, dict): continue
clause_metadata_base = article_metadata_base.copy()
clause_metadata_base.update({
"clause_number": clause_data.get("clause_number"),
"clause_metadata_summary": clause_data.get("clause_metadata_summary")
})
points_in_clause = clause_data.get("points_in_clause", [])
if not isinstance(points_in_clause, list): continue
if points_in_clause:
for point_data in points_in_clause:
if not isinstance(point_data, dict): continue
chunk_text = point_data.get("point_text_original") or point_data.get("violation_description_summary")
if not chunk_text: continue
current_point_metadata = clause_metadata_base.copy()
point_specific_metadata = point_data.copy()
if "point_text_original" in point_specific_metadata:
del point_specific_metadata["point_text_original"]
current_point_metadata.update(point_specific_metadata)
final_metadata_cleaned = {k: v for k, v in current_point_metadata.items() if v is not None}
flat_list.append({"text": chunk_text, "metadata": final_metadata_cleaned})
else:
chunk_text = clause_data.get("clause_text_original")
if chunk_text:
current_clause_metadata = clause_metadata_base.copy()
additional_clause_info = {k: v for k, value in clause_data.items() if k not in ["clause_text_original", "points_in_clause", "clause_number", "clause_metadata_summary"]}
if additional_clause_info:
current_clause_metadata.update(additional_clause_info)
final_metadata_cleaned = {k: v for k, v in current_clause_metadata.items() if v is not None}
flat_list.append({"text": chunk_text, "metadata": final_metadata_cleaned})
return flat_list
def tokenize_vi_for_bm25(text):
"""Hàm tokenize tiếng Việt đơn giản cho BM25."""
text = text.lower()
text = re.sub(r'[^\w\s]', '', text)
return text.split()
def load_all_resources():
"""
Hàm chính để tải tất cả mô hình và dữ liệu cần thiết.
Chỉ chạy một lần khi ứng dụng khởi động.
"""
print("--- Bắt đầu quá trình tải tài nguyên ---")
# 1. Tải mô hình LLM và Tokenizer
print(f"1. Đang tải LLM và Tokenizer: {LLM_MODEL_NAME}...")
llm_model, tokenizer = FastLanguageModel.from_pretrained(
model_name=LLM_MODEL_NAME,
max_seq_length=MAX_SEQ_LENGTH,
dtype=DTYPE,
load_in_4bit=LOAD_IN_4BIT,
)
FastLanguageModel.for_inference(llm_model)
MODELS_AND_DATA["llm_model"] = llm_model
MODELS_AND_DATA["tokenizer"] = tokenizer
print("✅ Tải LLM và Tokenizer thành công.")
# 2. Tải mô hình Embedding
print(f"2. Đang tải Embedding Model: {EMBEDDING_MODEL_NAME}...")
embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME, device="cuda" if torch.cuda.is_available() else "cpu")
MODELS_AND_DATA["embedding_model"] = embedding_model
print("✅ Tải Embedding Model thành công.")
# 3. Tải và xử lý dữ liệu luật
print(f"3. Đang tải và xử lý dữ liệu từ: {LAW_DATA_FILE}...")
if not os.path.exists(LAW_DATA_FILE):
raise FileNotFoundError(f"Không tìm thấy file dữ liệu luật: {LAW_DATA_FILE}. Vui lòng upload file này lên Space.")
with open(LAW_DATA_FILE, 'r', encoding='utf-8') as f:
raw_data_from_file = json.load(f)
chunks_data = process_law_data_to_chunks(raw_data_from_file)
MODELS_AND_DATA["chunks_data"] = chunks_data
print(f"✅ Đã xử lý thành {len(chunks_data)} chunks.")
# 4. Tạo BM25 Model
print("4. Đang tạo BM25 Model...")
corpus_texts = [chunk.get('text', '') for chunk in chunks_data]
tokenized_corpus = [tokenize_vi_for_bm25(text) for text in corpus_texts]
bm25_model = BM25Okapi(tokenized_corpus)
MODELS_AND_DATA["bm25_model"] = bm25_model
MODELS_AND_DATA["tokenized_corpus_bm25"] = tokenized_corpus
print("✅ Tạo BM25 Model thành công.")
# 5. Tạo FAISS Index
print("5. Đang tạo FAISS Index...")
texts_to_encode = [chunk.get('text', '') for chunk in chunks_data]
chunk_embeddings = embedding_model.encode(texts_to_encode, convert_to_tensor=True, device=embedding_model.device)
chunk_embeddings_np = chunk_embeddings.cpu().numpy().astype('float32')
faiss.normalize_L2(chunk_embeddings_np)
dimension = chunk_embeddings_np.shape[1]
index = faiss.IndexFlatIP(dimension)
index.add(chunk_embeddings_np)
MODELS_AND_DATA["faiss_index"] = index
print(f"✅ Tạo FAISS Index thành công với {index.ntotal} vectors.")
print("\n--- Tải tài nguyên hoàn tất! Ứng dụng đã sẵn sàng. ---")
# --- PHẦN 3: CÁC HÀM LÕI CHO RAG ---
def search_relevant_laws(query_text, k=5, initial_k_multiplier=10, rrf_k_constant=60):
"""
Hàm thực hiện Hybrid Search để tìm các đoạn luật liên quan.
"""
# Lấy các tài nguyên đã tải
embedding_model = MODELS_AND_DATA["embedding_model"]
faiss_index = MODELS_AND_DATA["faiss_index"]
chunks_data = MODELS_AND_DATA["chunks_data"]
bm25_model = MODELS_AND_DATA["bm25_model"]
if not all([embedding_model, faiss_index, chunks_data, bm25_model]):
return "Lỗi: Tài nguyên chưa được tải xong. Vui lòng chờ."
# 1. Semantic Search (FAISS)
query_embedding = embedding_model.encode([query_text], convert_to_tensor=True, device=embedding_model.device)
query_embedding_np = query_embedding.cpu().numpy().astype('float32')
faiss.normalize_L2(query_embedding_np)
num_candidates = min(k * initial_k_multiplier, faiss_index.ntotal)
semantic_scores, semantic_indices = faiss_index.search(query_embedding_np, num_candidates)
# 2. Keyword Search (BM25)
tokenized_query = tokenize_vi_for_bm25(query_text)
bm25_scores = bm25_model.get_scores(tokenized_query)
bm25_results = sorted(enumerate(bm25_scores), key=lambda x: x[1], reverse=True)[:num_candidates]
# 3. Reciprocal Rank Fusion (RRF)
rrf_scores = defaultdict(float)
if semantic_indices.size > 0:
for rank, doc_idx in enumerate(semantic_indices[0]):
if doc_idx != -1: rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank)
for rank, (doc_idx, score) in enumerate(bm25_results):
if score > 0: rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank)
fused_results = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True)
# 4. Lấy kết quả cuối cùng
final_results = []
for doc_idx, score in fused_results[:k]:
result = chunks_data[doc_idx].copy()
result['score'] = score
final_results.append(result)
return final_results
def generate_llm_response(query, context):
"""
Hàm sinh câu trả lời từ LLM dựa trên query và context.
"""
llm_model = MODELS_AND_DATA["llm_model"]
tokenizer = MODELS_AND_DATA["tokenizer"]
prompt = f"""Dưới đây là một số thông tin trích dẫn từ văn bản luật giao thông đường bộ Việt Nam.
Hãy SỬ DỤNG CÁC THÔNG TIN NÀY để trả lời câu hỏi một cách chính xác và đầy đủ.
Nếu câu hỏi đưa ra nhiều đáp án thì chọn 1 đáp án đúng nhất.
### Thông tin luật:
{context}
### Câu hỏi:
{query}
### Trả lời:"""
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
generation_config = dict(
max_new_tokens=300,
temperature=0.2,
top_p=0.7,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id
)
output_ids = llm_model.generate(**inputs, **generation_config)
input_length = inputs.input_ids.shape[1]
generated_ids = output_ids[0][input_length:]
response_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
return response_text
# --- PHẦN 4: CÁC HÀM XỬ LÝ CHO GRADIO INTERFACE ---
def run_retriever_only(query):
"""
Chức năng 1: Chỉ tìm kiếm và trả về các điều luật liên quan.
"""
print(f"Chạy chức năng Retriever cho query: '{query}'")
retrieved_results = search_relevant_laws(query)
if isinstance(retrieved_results, str): # Xử lý trường hợp lỗi
return retrieved_results
if not retrieved_results:
return "Không tìm thấy điều luật nào liên quan."
# Định dạng output cho Gradio Markdown
formatted_output = f"### Các điều luật liên quan nhất đến truy vấn của bạn:\n\n"
for i, res in enumerate(retrieved_results):
metadata = res.get('metadata', {})
article = metadata.get('article', 'N/A')
clause = metadata.get('clause_number', 'N/A')
source = metadata.get('source_document', 'N/A')
text = res.get('text', 'N/A')
formatted_output += f"**{i+1}. Nguồn: {source} | Điều {article} | Khoản {clause}**\n"
formatted_output += f"> {text}\n\n---\n\n"
return formatted_output
def run_full_rag(query, progress=gr.Progress()):
"""
Chức năng 2: Thực hiện toàn bộ pipeline RAG.
"""
progress(0, desc="Bắt đầu...")
# Bước 1: Truy xuất ngữ cảnh
progress(0.2, desc="Đang tìm kiếm các điều luật liên quan (Hybrid Search)...")
print(f"Chạy chức năng RAG cho query: '{query}'")
retrieved_results = search_relevant_laws(query)
if isinstance(retrieved_results, str) or not retrieved_results:
context_for_llm = "Không tìm thấy thông tin luật liên quan."
context_for_display = context_for_llm
else:
# Định dạng context cho LLM
context_parts = []
for res in retrieved_results:
text = res.get('text', '')
context_parts.append(text)
context_for_llm = "\n\n---\n\n".join(context_parts)
# Định dạng context để hiển thị cho người dùng
context_for_display = run_retriever_only(query) # Tái sử dụng hàm retriever
# Bước 2: Sinh câu trả lời
progress(0.7, desc="Đã có ngữ cảnh, đang yêu cầu LLM tạo câu trả lời...")
final_answer = generate_llm_response(query, context_for_llm)
progress(1, desc="Hoàn tất!")
return final_answer, context_for_display
# --- PHẦN 5: KHỞI CHẠY ỨNG DỤNG GRADIO ---
# Tải tài nguyên ngay khi script được chạy
load_all_resources()
with gr.Blocks(theme=gr.themes.Soft(), title="Chatbot Luật Giao thông") as demo:
gr.Markdown(
"""
# ⚖️ Chatbot Luật Giao thông Việt Nam
Ứng dụng này sử dụng mô hình RAG (Retrieval-Augmented Generation) để trả lời các câu hỏi về luật giao thông.
"""
)
with gr.Tabs():
# Tab 1: Chỉ tìm kiếm
with gr.TabItem("Tìm kiếm Điều luật (Retriever)"):
with gr.Row():
retriever_query = gr.Textbox(label="Nhập nội dung cần tìm kiếm", placeholder="Ví dụ: Vượt đèn đỏ bị phạt bao nhiêu tiền?", scale=4)
retriever_button = gr.Button("Tìm kiếm", variant="secondary", scale=1)
retriever_output = gr.Markdown(label="Các điều luật liên quan")
# Tab 2: Hỏi-đáp RAG đầy đủ
with gr.TabItem("Hỏi-Đáp (RAG)"):
with gr.Row():
rag_query = gr.Textbox(label="Nhập câu hỏi của bạn", placeholder="Ví dụ: Phương tiện giao thông đường bộ gồm những loại nào?", scale=4)
rag_button = gr.Button("Gửi câu hỏi", variant="primary", scale=1)
rag_answer = gr.Textbox(label="Câu trả lời của Chatbot", lines=5)
with gr.Accordion("Xem ngữ cảnh đã sử dụng để tạo câu trả lời", open=False):
rag_context = gr.Markdown(label="Ngữ cảnh")
# Xử lý sự kiện click
retriever_button.click(fn=run_retriever_only, inputs=retriever_query, outputs=retriever_output)
rag_button.click(fn=run_full_rag, inputs=rag_query, outputs=[rag_answer, rag_context])
if __name__ == "__main__":
demo.launch(share=True) # share=True để tạo link public nếu chạy trên Colab/local