File size: 6,658 Bytes
25950e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import gradio as gr
import torch # Cần cho việc kiểm tra CUDA và device
# Import các hàm và lớp từ các file của bạn
from retrieval import (
process_law_data_to_chunks,
# VEHICLE_TYPE_MAP, # Có thể không cần import trực tiếp nếu chỉ dùng trong retrieval.py
# get_standardized_vehicle_type, # Tương tự
# analyze_query, # Được gọi bởi search_relevant_laws
tokenize_vi_for_bm25_setup, # Cần cho BM25
search_relevant_laws
)
from llm_handler import generate_response # Giả sử hàm này đã được điều chỉnh để nhận model, tokenizer, etc.
from sentence_transformers import SentenceTransformer
import faiss
from rank_bm25 import BM25Okapi
import json
from unsloth import FastLanguageModel # Từ llm_handler.py hoặc import trực tiếp nếu logic tải model ở đây
# --- KHỞI TẠO MỘT LẦN KHI APP KHỞI ĐỘNG ---
# Đường dẫn (điều chỉnh nếu cần, có thể dùng os.path.join)
JSON_FILE_PATH = "data/luat_chi_tiet_output_openai_sdk_final_cleaned.json"
FAISS_INDEX_PATH = "data/my_law_faiss_flatip_normalized.index"
LLM_MODEL_PATH = "models/lora_model_base" # Hoặc đường dẫn cục bộ
EMBEDDING_MODEL_PATH = "models/embedding_model"
# 1. Tải và xử lý dữ liệu luật
print("Loading and processing law data...")
try:
with open(JSON_FILE_PATH, 'r', encoding='utf-8') as f:
raw_data_from_file = json.load(f)
chunks_data = process_law_data_to_chunks(raw_data_from_file)
print(f"Loaded {len(chunks_data)} chunks.")
if not chunks_data:
raise ValueError("Chunks data is empty after processing.")
except Exception as e:
print(f"Error loading/processing law data: {e}")
chunks_data = [] # Hoặc xử lý lỗi khác
# 2. Tải mô hình embedding
print(f"Loading embedding model: {EMBEDDING_MODEL_PATH}...")
device = "cuda" if torch.cuda.is_available() else "cpu"
try:
embedding_model = SentenceTransformer(EMBEDDING_MODEL_PATH, device=device)
print("Embedding model loaded successfully.")
except Exception as e:
print(f"Error loading embedding model: {e}")
embedding_model = None # Xử lý lỗi
# 3. Tải FAISS index
print(f"Loading FAISS index from: {FAISS_INDEX_PATH}...")
try:
faiss_index = faiss.read_index(FAISS_INDEX_PATH)
print(f"FAISS index loaded. Total vectors: {faiss_index.ntotal}")
except Exception as e:
print(f"Error loading FAISS index: {e}")
faiss_index = None # Xử lý lỗi
# 4. Tạo BM25 model
print("Creating BM25 model...")
bm25_model = None
if chunks_data:
try:
corpus_texts_for_bm25 = [chunk.get('text', '') for chunk in chunks_data]
tokenized_corpus_bm25 = [tokenize_vi_for_bm25_setup(text) for text in corpus_texts_for_bm25]
bm25_model = BM25Okapi(tokenized_corpus_bm25)
print("BM25 model created successfully.")
except Exception as e:
print(f"Error creating BM25 model: {e}")
else:
print("Skipping BM25 model creation as chunks_data is empty.")
# 5. Tải mô hình LLM và tokenizer (sử dụng Unsloth)
print(f"Loading LLM model: {LLM_MODEL_PATH}...")
try:
# Nên đặt logic tải model LLM vào llm_handler.py và gọi hàm đó ở đây
# Hoặc trực tiếp:
llm_model, llm_tokenizer = FastLanguageModel.from_pretrained(
model_name=LLM_MODEL_PATH, # Đường dẫn tới model đã fine-tune
max_seq_length=2048,
dtype=None, # Unsloth sẽ tự động chọn
load_in_4bit=True, # Sử dụng 4-bit quantization
)
FastLanguageModel.for_inference(llm_model) # Tối ưu cho inference
print("LLM model and tokenizer loaded successfully.")
except Exception as e:
print(f"Error loading LLM model: {e}")
llm_model = None
llm_tokenizer = None
# --- KẾT THÚC KHỞI TẠO MỘT LẦN ---
# Hàm respond mới sẽ sử dụng các model và data đã tải ở trên
def respond(message, history: list[tuple[str, str]]):
if not all([chunks_data, embedding_model, faiss_index, bm25_model, llm_model, llm_tokenizer]):
# Ghi log chi tiết hơn ở đây nếu cần để biết thành phần nào bị thiếu
missing_components = []
if not chunks_data: missing_components.append("chunks_data")
if not embedding_model: missing_components.append("embedding_model")
if not faiss_index: missing_components.append("faiss_index")
if not bm25_model: missing_components.append("bm25_model")
if not llm_model: missing_components.append("llm_model")
if not llm_tokenizer: missing_components.append("llm_tokenizer")
error_msg = f"Lỗi: Một hoặc nhiều thành phần của hệ thống chưa được khởi tạo thành công. Thành phần thiếu: {', '.join(missing_components)}. Vui lòng kiểm tra logs của Space."
print(error_msg) # In ra console log của Space
return error_msg # Trả về cho người dùng
try:
response_text = generate_response(
query=message,
llama_model=llm_model,
tokenizer=llm_tokenizer,
faiss_index=faiss_index,
embed_model=embedding_model,
chunks_data_list=chunks_data,
bm25_model=bm25_model,
search_function=search_relevant_laws # << RẤT QUAN TRỌNG: Đã thêm tham số này
# Bạn có thể truyền thêm các tham số search_k, search_multiplier,
# rrf_k_constant, max_new_tokens, temperature, etc. vào đây
# nếu bạn muốn ghi đè giá trị mặc định trong llm_handler.generate_response
# Ví dụ:
# search_k=5,
# max_new_tokens=768
)
yield response_text
except Exception as e:
# Ghi log lỗi chi tiết hơn
import traceback
print(f"Error during response generation for query '{message}': {e}")
print(traceback.format_exc()) # In stack trace để debug
yield f"Đã xảy ra lỗi nghiêm trọng khi xử lý yêu cầu của bạn. Vui lòng thử lại sau hoặc liên hệ quản trị viên."
# Giao diện Gradio
# Bỏ các additional_inputs không cần thiết nếu chúng được xử lý bên trong generate_response
# hoặc nếu bạn không muốn người dùng cuối thay đổi chúng.
demo = gr.ChatInterface(
respond,
# additional_inputs=[ # Bạn có thể thêm lại nếu muốn người dùng tùy chỉnh
# gr.Textbox(value="You are a helpful Law Chatbot.", label="System message"), # Ví dụ
# ]
)
if __name__ == "__main__":
demo.launch() |