File size: 6,658 Bytes
25950e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import gradio as gr
import torch # Cần cho việc kiểm tra CUDA và device
# Import các hàm và lớp từ các file của bạn
from retrieval import (
    process_law_data_to_chunks,
    # VEHICLE_TYPE_MAP, # Có thể không cần import trực tiếp nếu chỉ dùng trong retrieval.py
    # get_standardized_vehicle_type, # Tương tự
    # analyze_query, # Được gọi bởi search_relevant_laws
    tokenize_vi_for_bm25_setup, # Cần cho BM25
    search_relevant_laws
)
from llm_handler import generate_response # Giả sử hàm này đã được điều chỉnh để nhận model, tokenizer, etc.
from sentence_transformers import SentenceTransformer
import faiss
from rank_bm25 import BM25Okapi
import json
from unsloth import FastLanguageModel # Từ llm_handler.py hoặc import trực tiếp nếu logic tải model ở đây

# --- KHỞI TẠO MỘT LẦN KHI APP KHỞI ĐỘNG ---
# Đường dẫn (điều chỉnh nếu cần, có thể dùng os.path.join)
JSON_FILE_PATH = "data/luat_chi_tiet_output_openai_sdk_final_cleaned.json"
FAISS_INDEX_PATH = "data/my_law_faiss_flatip_normalized.index"
LLM_MODEL_PATH = "models/lora_model_base" # Hoặc đường dẫn cục bộ
EMBEDDING_MODEL_PATH = "models/embedding_model"

# 1. Tải và xử lý dữ liệu luật
print("Loading and processing law data...")
try:
    with open(JSON_FILE_PATH, 'r', encoding='utf-8') as f:
        raw_data_from_file = json.load(f)
    chunks_data = process_law_data_to_chunks(raw_data_from_file)
    print(f"Loaded {len(chunks_data)} chunks.")
    if not chunks_data:
        raise ValueError("Chunks data is empty after processing.")
except Exception as e:
    print(f"Error loading/processing law data: {e}")
    chunks_data = [] # Hoặc xử lý lỗi khác

# 2. Tải mô hình embedding
print(f"Loading embedding model: {EMBEDDING_MODEL_PATH}...")
device = "cuda" if torch.cuda.is_available() else "cpu"
try:
    embedding_model = SentenceTransformer(EMBEDDING_MODEL_PATH, device=device)
    print("Embedding model loaded successfully.")
except Exception as e:
    print(f"Error loading embedding model: {e}")
    embedding_model = None # Xử lý lỗi

# 3. Tải FAISS index
print(f"Loading FAISS index from: {FAISS_INDEX_PATH}...")
try:
    faiss_index = faiss.read_index(FAISS_INDEX_PATH)
    print(f"FAISS index loaded. Total vectors: {faiss_index.ntotal}")
except Exception as e:
    print(f"Error loading FAISS index: {e}")
    faiss_index = None # Xử lý lỗi

# 4. Tạo BM25 model
print("Creating BM25 model...")
bm25_model = None
if chunks_data:
    try:
        corpus_texts_for_bm25 = [chunk.get('text', '') for chunk in chunks_data]
        tokenized_corpus_bm25 = [tokenize_vi_for_bm25_setup(text) for text in corpus_texts_for_bm25]
        bm25_model = BM25Okapi(tokenized_corpus_bm25)
        print("BM25 model created successfully.")
    except Exception as e:
        print(f"Error creating BM25 model: {e}")
else:
    print("Skipping BM25 model creation as chunks_data is empty.")


# 5. Tải mô hình LLM và tokenizer (sử dụng Unsloth)
print(f"Loading LLM model: {LLM_MODEL_PATH}...")
try:
    # Nên đặt logic tải model LLM vào llm_handler.py và gọi hàm đó ở đây
    # Hoặc trực tiếp:
    llm_model, llm_tokenizer = FastLanguageModel.from_pretrained(
        model_name=LLM_MODEL_PATH, # Đường dẫn tới model đã fine-tune
        max_seq_length=2048,
        dtype=None, # Unsloth sẽ tự động chọn
        load_in_4bit=True, # Sử dụng 4-bit quantization
    )
    FastLanguageModel.for_inference(llm_model) # Tối ưu cho inference
    print("LLM model and tokenizer loaded successfully.")
except Exception as e:
    print(f"Error loading LLM model: {e}")
    llm_model = None
    llm_tokenizer = None
# --- KẾT THÚC KHỞI TẠO MỘT LẦN ---

# Hàm respond mới sẽ sử dụng các model và data đã tải ở trên
def respond(message, history: list[tuple[str, str]]):
    if not all([chunks_data, embedding_model, faiss_index, bm25_model, llm_model, llm_tokenizer]):
        # Ghi log chi tiết hơn ở đây nếu cần để biết thành phần nào bị thiếu
        missing_components = []
        if not chunks_data: missing_components.append("chunks_data")
        if not embedding_model: missing_components.append("embedding_model")
        if not faiss_index: missing_components.append("faiss_index")
        if not bm25_model: missing_components.append("bm25_model")
        if not llm_model: missing_components.append("llm_model")
        if not llm_tokenizer: missing_components.append("llm_tokenizer")
        error_msg = f"Lỗi: Một hoặc nhiều thành phần của hệ thống chưa được khởi tạo thành công. Thành phần thiếu: {', '.join(missing_components)}. Vui lòng kiểm tra logs của Space."
        print(error_msg) # In ra console log của Space
        return error_msg # Trả về cho người dùng

    try:
        response_text = generate_response(
            query=message,
            llama_model=llm_model,
            tokenizer=llm_tokenizer,
            faiss_index=faiss_index,
            embed_model=embedding_model,
            chunks_data_list=chunks_data,
            bm25_model=bm25_model,
            search_function=search_relevant_laws  # << RẤT QUAN TRỌNG: Đã thêm tham số này
            # Bạn có thể truyền thêm các tham số search_k, search_multiplier,
            # rrf_k_constant, max_new_tokens, temperature, etc. vào đây
            # nếu bạn muốn ghi đè giá trị mặc định trong llm_handler.generate_response
            # Ví dụ:
            # search_k=5,
            # max_new_tokens=768
        )
        yield response_text

    except Exception as e:
        # Ghi log lỗi chi tiết hơn
        import traceback
        print(f"Error during response generation for query '{message}': {e}")
        print(traceback.format_exc()) # In stack trace để debug
        yield f"Đã xảy ra lỗi nghiêm trọng khi xử lý yêu cầu của bạn. Vui lòng thử lại sau hoặc liên hệ quản trị viên."

# Giao diện Gradio
# Bỏ các additional_inputs không cần thiết nếu chúng được xử lý bên trong generate_response
# hoặc nếu bạn không muốn người dùng cuối thay đổi chúng.
demo = gr.ChatInterface(
    respond,
    # additional_inputs=[ # Bạn có thể thêm lại nếu muốn người dùng tùy chỉnh
    #     gr.Textbox(value="You are a helpful Law Chatbot.", label="System message"), # Ví dụ
    # ]
)

if __name__ == "__main__":
    demo.launch()