Spaces:
Paused
Paused
update main
Browse files
app.py
CHANGED
@@ -1,7 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
|
3 |
-
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
-
|
7 |
-
demo.launch()
|
|
|
1 |
+
# app.py
|
2 |
+
# File triển khai hoàn chỉnh cho đồ án Chatbot Luật Giao thông
|
3 |
+
# Tác giả: (Tên của bạn)
|
4 |
+
# Ngày: (Ngày bạn tạo)
|
5 |
+
|
6 |
+
# --- PHẦN 1: IMPORT CÁC THƯ VIỆN CẦN THIẾT ---
|
7 |
+
print("Bắt đầu import các thư viện...")
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
import json
|
11 |
+
import re
|
12 |
+
import time
|
13 |
+
from collections import defaultdict
|
14 |
+
|
15 |
+
# Core ML/DL và Unsloth
|
16 |
+
import torch
|
17 |
+
from unsloth import FastLanguageModel
|
18 |
+
from transformers import TextStreamer
|
19 |
+
|
20 |
+
# RAG - Retrieval
|
21 |
+
import faiss
|
22 |
+
from sentence_transformers import SentenceTransformer
|
23 |
+
from rank_bm25 import BM25Okapi
|
24 |
+
import numpy as np
|
25 |
+
|
26 |
+
# Deployment
|
27 |
import gradio as gr
|
28 |
|
29 |
+
print("✅ Import thư viện thành công.")
|
30 |
+
|
31 |
+
# --- PHẦN 2: CẤU HÌNH VÀ TẢI TÀI NGUYÊN (MODELS & DATA) ---
|
32 |
+
# Phần này sẽ chỉ chạy một lần khi ứng dụng khởi động.
|
33 |
+
|
34 |
+
# Cấu hình mô hình
|
35 |
+
MAX_SEQ_LENGTH = 2048
|
36 |
+
DTYPE = None
|
37 |
+
LOAD_IN_4BIT = True
|
38 |
+
EMBEDDING_MODEL_NAME = "bkai-foundation-models/vietnamese-bi-encoder"
|
39 |
+
LLM_MODEL_NAME = "unsloth/Llama-3.2-3B-Instruct-bnb-4bit"
|
40 |
+
LAW_DATA_FILE = "luat_chi_tiet_output_openai_sdk_final_cleaned.json"
|
41 |
+
|
42 |
+
# Biến toàn cục để lưu các tài nguyên đã tải
|
43 |
+
# Điều này giúp tránh việc phải tải lại mô hình mỗi khi người dùng gửi yêu cầu.
|
44 |
+
MODELS_AND_DATA = {
|
45 |
+
"llm_model": None,
|
46 |
+
"tokenizer": None,
|
47 |
+
"embedding_model": None,
|
48 |
+
"faiss_index": None,
|
49 |
+
"bm25_model": None,
|
50 |
+
"chunks_data": None,
|
51 |
+
"tokenized_corpus_bm25": None,
|
52 |
+
}
|
53 |
+
|
54 |
+
# --- Các hàm xử lý dữ liệu (từ các notebook của bạn) ---
|
55 |
+
|
56 |
+
def process_law_data_to_chunks(structured_data_input):
|
57 |
+
"""
|
58 |
+
Hàm làm phẳng dữ liệu luật có cấu trúc chi tiết thành danh sách các chunks.
|
59 |
+
Mỗi chunk chứa 'text' và 'metadata'.
|
60 |
+
"""
|
61 |
+
flat_list = []
|
62 |
+
articles_list = []
|
63 |
+
if isinstance(structured_data_input, dict) and "article" in structured_data_input:
|
64 |
+
articles_list = [structured_data_input]
|
65 |
+
elif isinstance(structured_data_input, list):
|
66 |
+
articles_list = structured_data_input
|
67 |
+
else:
|
68 |
+
print("Lỗi: Dữ liệu đầu vào không hợp lệ.")
|
69 |
+
return flat_list
|
70 |
+
|
71 |
+
for article_data in articles_list:
|
72 |
+
if not isinstance(article_data, dict): continue
|
73 |
+
article_metadata_base = {
|
74 |
+
"source_document": article_data.get("source_document"),
|
75 |
+
"article": article_data.get("article"),
|
76 |
+
"article_title": article_data.get("article_title")
|
77 |
+
}
|
78 |
+
clauses = article_data.get("clauses", [])
|
79 |
+
if not isinstance(clauses, list): continue
|
80 |
+
|
81 |
+
for clause_data in clauses:
|
82 |
+
if not isinstance(clause_data, dict): continue
|
83 |
+
clause_metadata_base = article_metadata_base.copy()
|
84 |
+
clause_metadata_base.update({
|
85 |
+
"clause_number": clause_data.get("clause_number"),
|
86 |
+
"clause_metadata_summary": clause_data.get("clause_metadata_summary")
|
87 |
+
})
|
88 |
+
points_in_clause = clause_data.get("points_in_clause", [])
|
89 |
+
if not isinstance(points_in_clause, list): continue
|
90 |
+
|
91 |
+
if points_in_clause:
|
92 |
+
for point_data in points_in_clause:
|
93 |
+
if not isinstance(point_data, dict): continue
|
94 |
+
chunk_text = point_data.get("point_text_original") or point_data.get("violation_description_summary")
|
95 |
+
if not chunk_text: continue
|
96 |
+
|
97 |
+
current_point_metadata = clause_metadata_base.copy()
|
98 |
+
point_specific_metadata = point_data.copy()
|
99 |
+
if "point_text_original" in point_specific_metadata:
|
100 |
+
del point_specific_metadata["point_text_original"]
|
101 |
+
current_point_metadata.update(point_specific_metadata)
|
102 |
+
final_metadata_cleaned = {k: v for k, v in current_point_metadata.items() if v is not None}
|
103 |
+
flat_list.append({"text": chunk_text, "metadata": final_metadata_cleaned})
|
104 |
+
else:
|
105 |
+
chunk_text = clause_data.get("clause_text_original")
|
106 |
+
if chunk_text:
|
107 |
+
current_clause_metadata = clause_metadata_base.copy()
|
108 |
+
additional_clause_info = {k: v for k, value in clause_data.items() if k not in ["clause_text_original", "points_in_clause", "clause_number", "clause_metadata_summary"]}
|
109 |
+
if additional_clause_info:
|
110 |
+
current_clause_metadata.update(additional_clause_info)
|
111 |
+
final_metadata_cleaned = {k: v for k, v in current_clause_metadata.items() if v is not None}
|
112 |
+
flat_list.append({"text": chunk_text, "metadata": final_metadata_cleaned})
|
113 |
+
return flat_list
|
114 |
+
|
115 |
+
def tokenize_vi_for_bm25(text):
|
116 |
+
"""Hàm tokenize tiếng Việt đơn giản cho BM25."""
|
117 |
+
text = text.lower()
|
118 |
+
text = re.sub(r'[^\w\s]', '', text)
|
119 |
+
return text.split()
|
120 |
+
|
121 |
+
def load_all_resources():
|
122 |
+
"""
|
123 |
+
Hàm chính để tải tất cả mô hình và dữ liệu cần thiết.
|
124 |
+
Chỉ chạy một lần khi ứng dụng khởi động.
|
125 |
+
"""
|
126 |
+
print("--- Bắt đầu quá trình tải tài nguyên ---")
|
127 |
+
|
128 |
+
# 1. Tải mô hình LLM và Tokenizer
|
129 |
+
print(f"1. Đang tải LLM và Tokenizer: {LLM_MODEL_NAME}...")
|
130 |
+
llm_model, tokenizer = FastLanguageModel.from_pretrained(
|
131 |
+
model_name=LLM_MODEL_NAME,
|
132 |
+
max_seq_length=MAX_SEQ_LENGTH,
|
133 |
+
dtype=DTYPE,
|
134 |
+
load_in_4bit=LOAD_IN_4BIT,
|
135 |
+
)
|
136 |
+
FastLanguageModel.for_inference(llm_model)
|
137 |
+
MODELS_AND_DATA["llm_model"] = llm_model
|
138 |
+
MODELS_AND_DATA["tokenizer"] = tokenizer
|
139 |
+
print("✅ Tải LLM và Tokenizer thành công.")
|
140 |
+
|
141 |
+
# 2. Tải mô hình Embedding
|
142 |
+
print(f"2. Đang tải Embedding Model: {EMBEDDING_MODEL_NAME}...")
|
143 |
+
embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME, device="cuda" if torch.cuda.is_available() else "cpu")
|
144 |
+
MODELS_AND_DATA["embedding_model"] = embedding_model
|
145 |
+
print("✅ Tải Embedding Model thành công.")
|
146 |
+
|
147 |
+
# 3. Tải và xử lý dữ liệu luật
|
148 |
+
print(f"3. Đang tải và xử lý dữ liệu từ: {LAW_DATA_FILE}...")
|
149 |
+
if not os.path.exists(LAW_DATA_FILE):
|
150 |
+
raise FileNotFoundError(f"Không tìm thấy file dữ liệu luật: {LAW_DATA_FILE}. Vui lòng upload file này lên Space.")
|
151 |
+
with open(LAW_DATA_FILE, 'r', encoding='utf-8') as f:
|
152 |
+
raw_data_from_file = json.load(f)
|
153 |
+
chunks_data = process_law_data_to_chunks(raw_data_from_file)
|
154 |
+
MODELS_AND_DATA["chunks_data"] = chunks_data
|
155 |
+
print(f"✅ Đã xử lý thành {len(chunks_data)} chunks.")
|
156 |
+
|
157 |
+
# 4. Tạo BM25 Model
|
158 |
+
print("4. Đang tạo BM25 Model...")
|
159 |
+
corpus_texts = [chunk.get('text', '') for chunk in chunks_data]
|
160 |
+
tokenized_corpus = [tokenize_vi_for_bm25(text) for text in corpus_texts]
|
161 |
+
bm25_model = BM25Okapi(tokenized_corpus)
|
162 |
+
MODELS_AND_DATA["bm25_model"] = bm25_model
|
163 |
+
MODELS_AND_DATA["tokenized_corpus_bm25"] = tokenized_corpus
|
164 |
+
print("✅ Tạo BM25 Model thành công.")
|
165 |
+
|
166 |
+
# 5. Tạo FAISS Index
|
167 |
+
print("5. Đang tạo FAISS Index...")
|
168 |
+
texts_to_encode = [chunk.get('text', '') for chunk in chunks_data]
|
169 |
+
chunk_embeddings = embedding_model.encode(texts_to_encode, convert_to_tensor=True, device=embedding_model.device)
|
170 |
+
chunk_embeddings_np = chunk_embeddings.cpu().numpy().astype('float32')
|
171 |
+
faiss.normalize_L2(chunk_embeddings_np)
|
172 |
+
dimension = chunk_embeddings_np.shape[1]
|
173 |
+
index = faiss.IndexFlatIP(dimension)
|
174 |
+
index.add(chunk_embeddings_np)
|
175 |
+
MODELS_AND_DATA["faiss_index"] = index
|
176 |
+
print(f"✅ Tạo FAISS Index thành công với {index.ntotal} vectors.")
|
177 |
+
|
178 |
+
print("\n--- Tải tài nguyên hoàn tất! Ứng dụng đã sẵn sàng. ---")
|
179 |
+
|
180 |
+
# --- PHẦN 3: CÁC HÀM LÕI CHO RAG ---
|
181 |
+
|
182 |
+
def search_relevant_laws(query_text, k=5, initial_k_multiplier=10, rrf_k_constant=60):
|
183 |
+
"""
|
184 |
+
Hàm thực hiện Hybrid Search để tìm các đoạn luật liên quan.
|
185 |
+
"""
|
186 |
+
# Lấy các tài nguyên đã tải
|
187 |
+
embedding_model = MODELS_AND_DATA["embedding_model"]
|
188 |
+
faiss_index = MODELS_AND_DATA["faiss_index"]
|
189 |
+
chunks_data = MODELS_AND_DATA["chunks_data"]
|
190 |
+
bm25_model = MODELS_AND_DATA["bm25_model"]
|
191 |
+
|
192 |
+
if not all([embedding_model, faiss_index, chunks_data, bm25_model]):
|
193 |
+
return "Lỗi: Tài nguyên chưa được tải xong. Vui lòng chờ."
|
194 |
+
|
195 |
+
# 1. Semantic Search (FAISS)
|
196 |
+
query_embedding = embedding_model.encode([query_text], convert_to_tensor=True, device=embedding_model.device)
|
197 |
+
query_embedding_np = query_embedding.cpu().numpy().astype('float32')
|
198 |
+
faiss.normalize_L2(query_embedding_np)
|
199 |
+
num_candidates = min(k * initial_k_multiplier, faiss_index.ntotal)
|
200 |
+
semantic_scores, semantic_indices = faiss_index.search(query_embedding_np, num_candidates)
|
201 |
+
|
202 |
+
# 2. Keyword Search (BM25)
|
203 |
+
tokenized_query = tokenize_vi_for_bm25(query_text)
|
204 |
+
bm25_scores = bm25_model.get_scores(tokenized_query)
|
205 |
+
bm25_results = sorted(enumerate(bm25_scores), key=lambda x: x[1], reverse=True)[:num_candidates]
|
206 |
+
|
207 |
+
# 3. Reciprocal Rank Fusion (RRF)
|
208 |
+
rrf_scores = defaultdict(float)
|
209 |
+
if semantic_indices.size > 0:
|
210 |
+
for rank, doc_idx in enumerate(semantic_indices[0]):
|
211 |
+
if doc_idx != -1: rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank)
|
212 |
+
for rank, (doc_idx, score) in enumerate(bm25_results):
|
213 |
+
if score > 0: rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank)
|
214 |
+
|
215 |
+
fused_results = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True)
|
216 |
+
|
217 |
+
# 4. Lấy kết quả cuối cùng
|
218 |
+
final_results = []
|
219 |
+
for doc_idx, score in fused_results[:k]:
|
220 |
+
result = chunks_data[doc_idx].copy()
|
221 |
+
result['score'] = score
|
222 |
+
final_results.append(result)
|
223 |
+
|
224 |
+
return final_results
|
225 |
+
|
226 |
+
def generate_llm_response(query, context):
|
227 |
+
"""
|
228 |
+
Hàm sinh câu trả lời từ LLM dựa trên query và context.
|
229 |
+
"""
|
230 |
+
llm_model = MODELS_AND_DATA["llm_model"]
|
231 |
+
tokenizer = MODELS_AND_DATA["tokenizer"]
|
232 |
+
|
233 |
+
prompt = f"""Dưới đây là một số thông tin trích dẫn từ văn bản luật giao thông đường bộ Việt Nam.
|
234 |
+
Hãy SỬ DỤNG CÁC THÔNG TIN NÀY để trả lời câu hỏi một cách chính xác và đầy đủ.
|
235 |
+
Nếu câu hỏi đưa ra nhiều đáp án thì chọn 1 đáp án đúng nhất.
|
236 |
+
|
237 |
+
### Thông tin luật:
|
238 |
+
{context}
|
239 |
+
|
240 |
+
### Câu hỏi:
|
241 |
+
{query}
|
242 |
+
|
243 |
+
### Trả lời:"""
|
244 |
+
|
245 |
+
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
|
246 |
+
generation_config = dict(
|
247 |
+
max_new_tokens=300,
|
248 |
+
temperature=0.2,
|
249 |
+
top_p=0.7,
|
250 |
+
do_sample=True,
|
251 |
+
pad_token_id=tokenizer.eos_token_id,
|
252 |
+
eos_token_id=tokenizer.eos_token_id
|
253 |
+
)
|
254 |
+
output_ids = llm_model.generate(**inputs, **generation_config)
|
255 |
+
input_length = inputs.input_ids.shape[1]
|
256 |
+
generated_ids = output_ids[0][input_length:]
|
257 |
+
response_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
|
258 |
+
return response_text
|
259 |
+
|
260 |
+
# --- PHẦN 4: CÁC HÀM XỬ LÝ CHO GRADIO INTERFACE ---
|
261 |
+
|
262 |
+
def run_retriever_only(query):
|
263 |
+
"""
|
264 |
+
Chức năng 1: Chỉ tìm kiếm và trả về các điều luật liên quan.
|
265 |
+
"""
|
266 |
+
print(f"Chạy chức năng Retriever cho query: '{query}'")
|
267 |
+
retrieved_results = search_relevant_laws(query)
|
268 |
+
|
269 |
+
if isinstance(retrieved_results, str): # Xử lý trường hợp lỗi
|
270 |
+
return retrieved_results
|
271 |
+
|
272 |
+
if not retrieved_results:
|
273 |
+
return "Không tìm thấy điều luật nào liên quan."
|
274 |
+
|
275 |
+
# Định dạng output cho Gradio Markdown
|
276 |
+
formatted_output = f"### Các điều luật liên quan nhất đến truy vấn của bạn:\n\n"
|
277 |
+
for i, res in enumerate(retrieved_results):
|
278 |
+
metadata = res.get('metadata', {})
|
279 |
+
article = metadata.get('article', 'N/A')
|
280 |
+
clause = metadata.get('clause_number', 'N/A')
|
281 |
+
source = metadata.get('source_document', 'N/A')
|
282 |
+
text = res.get('text', 'N/A')
|
283 |
+
|
284 |
+
formatted_output += f"**{i+1}. Nguồn: {source} | Điều {article} | Khoản {clause}**\n"
|
285 |
+
formatted_output += f"> {text}\n\n---\n\n"
|
286 |
+
|
287 |
+
return formatted_output
|
288 |
+
|
289 |
+
def run_full_rag(query, progress=gr.Progress()):
|
290 |
+
"""
|
291 |
+
Chức năng 2: Thực hiện toàn bộ pipeline RAG.
|
292 |
+
"""
|
293 |
+
progress(0, desc="Bắt đầu...")
|
294 |
+
|
295 |
+
# Bước 1: Truy xuất ngữ cảnh
|
296 |
+
progress(0.2, desc="Đang tìm kiếm các điều luật liên quan (Hybrid Search)...")
|
297 |
+
print(f"Chạy chức năng RAG cho query: '{query}'")
|
298 |
+
retrieved_results = search_relevant_laws(query)
|
299 |
+
|
300 |
+
if isinstance(retrieved_results, str) or not retrieved_results:
|
301 |
+
context_for_llm = "Không tìm thấy thông tin luật liên quan."
|
302 |
+
context_for_display = context_for_llm
|
303 |
+
else:
|
304 |
+
# Định dạng context cho LLM
|
305 |
+
context_parts = []
|
306 |
+
for res in retrieved_results:
|
307 |
+
text = res.get('text', '')
|
308 |
+
context_parts.append(text)
|
309 |
+
context_for_llm = "\n\n---\n\n".join(context_parts)
|
310 |
+
|
311 |
+
# Định dạng context để hiển thị cho người dùng
|
312 |
+
context_for_display = run_retriever_only(query) # Tái sử dụng hàm retriever
|
313 |
+
|
314 |
+
# Bước 2: Sinh câu trả lời
|
315 |
+
progress(0.7, desc="Đã có ngữ cảnh, đang yêu cầu LLM tạo câu trả lời...")
|
316 |
+
final_answer = generate_llm_response(query, context_for_llm)
|
317 |
+
|
318 |
+
progress(1, desc="Hoàn tất!")
|
319 |
+
|
320 |
+
return final_answer, context_for_display
|
321 |
+
|
322 |
+
|
323 |
+
# --- PHẦN 5: KHỞI CHẠY ỨNG DỤNG GRADIO ---
|
324 |
+
|
325 |
+
# Tải tài nguyên ngay khi script được chạy
|
326 |
+
load_all_resources()
|
327 |
+
|
328 |
+
with gr.Blocks(theme=gr.themes.Soft(), title="Chatbot Luật Giao thông") as demo:
|
329 |
+
gr.Markdown(
|
330 |
+
"""
|
331 |
+
# ⚖️ Chatbot Luật Giao thông Việt Nam
|
332 |
+
Ứng dụng này sử dụng mô hình RAG (Retrieval-Augmented Generation) để trả lời các câu hỏi về luật giao thông.
|
333 |
+
"""
|
334 |
+
)
|
335 |
+
|
336 |
+
with gr.Tabs():
|
337 |
+
# Tab 1: Chỉ tìm kiếm
|
338 |
+
with gr.TabItem("Tìm kiếm Điều luật (Retriever)"):
|
339 |
+
with gr.Row():
|
340 |
+
retriever_query = gr.Textbox(label="Nhập nội dung cần tìm kiếm", placeholder="Ví dụ: Vượt đèn đỏ bị phạt bao nhiêu tiền?", scale=4)
|
341 |
+
retriever_button = gr.Button("Tìm kiếm", variant="secondary", scale=1)
|
342 |
+
retriever_output = gr.Markdown(label="Các điều luật liên quan")
|
343 |
+
|
344 |
+
# Tab 2: Hỏi-đáp RAG đầy đủ
|
345 |
+
with gr.TabItem("Hỏi-Đáp (RAG)"):
|
346 |
+
with gr.Row():
|
347 |
+
rag_query = gr.Textbox(label="Nhập câu hỏi của bạn", placeholder="Ví dụ: Phương tiện giao thông đường bộ gồm những loại nào?", scale=4)
|
348 |
+
rag_button = gr.Button("Gửi câu hỏi", variant="primary", scale=1)
|
349 |
+
rag_answer = gr.Textbox(label="Câu trả lời của Chatbot", lines=5)
|
350 |
+
with gr.Accordion("Xem ngữ cảnh đã sử dụng để tạo câu trả lời", open=False):
|
351 |
+
rag_context = gr.Markdown(label="Ngữ cảnh")
|
352 |
+
|
353 |
+
# Xử lý sự kiện click
|
354 |
+
retriever_button.click(fn=run_retriever_only, inputs=retriever_query, outputs=retriever_output)
|
355 |
+
rag_button.click(fn=run_full_rag, inputs=rag_query, outputs=[rag_answer, rag_context])
|
356 |
|
357 |
+
if __name__ == "__main__":
|
358 |
+
demo.launch(share=True) # share=True để tạo link public nếu chạy trên Colab/local
|