Spaces:
Paused
Paused
update
Browse files- rag_pipeline.py +15 -8
rag_pipeline.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
# file: rag_pipeline.py
|
|
|
2 |
import torch
|
3 |
import json
|
4 |
import faiss
|
@@ -20,11 +21,11 @@ def initialize_components(data_path):
|
|
20 |
"""
|
21 |
print("--- Bắt đầu khởi tạo các thành phần ---")
|
22 |
|
23 |
-
# 1. Tải LLM và Tokenizer từ Unsloth
|
24 |
print("1. Tải mô hình LLM (Unsloth)...")
|
25 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
26 |
model_name="unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit",
|
27 |
-
max_seq_length=
|
28 |
dtype=None,
|
29 |
load_in_4bit=True,
|
30 |
)
|
@@ -81,15 +82,15 @@ def initialize_components(data_path):
|
|
81 |
|
82 |
def generate_response(query, components):
|
83 |
"""
|
84 |
-
Tạo câu trả lời cho một query bằng cách sử dụng các thành phần đã được khởi tạo.
|
85 |
"""
|
86 |
-
print("--- Bắt đầu quy trình RAG cho query mới ---")
|
87 |
|
88 |
# Unpack các thành phần
|
89 |
llm_model = components["llm_model"]
|
90 |
tokenizer = components["tokenizer"]
|
91 |
|
92 |
-
# 1. Truy xuất ngữ cảnh
|
93 |
retrieved_results = search_relevant_laws(
|
94 |
query_text=query,
|
95 |
embedding_model=components["embedding_model"],
|
@@ -112,7 +113,7 @@ def generate_response(query, components):
|
|
112 |
context_parts.append(f"{header}\n{text}")
|
113 |
context = "\n\n---\n\n".join(context_parts)
|
114 |
|
115 |
-
# 3. Xây dựng Prompt
|
116 |
prompt = f"""Dưới đây là một số thông tin trích dẫn từ văn bản luật giao thông đường bộ Việt Nam.
|
117 |
Hãy SỬ DỤNG CÁC THÔNG TIN NÀY để trả lời câu hỏi một cách chính xác và đầy đủ.
|
118 |
Nếu câu hỏi đưa ra nhiều đáp án thì chọn 1 đáp án đúng nhất.
|
@@ -126,7 +127,13 @@ Nếu câu hỏi đưa ra nhiều đáp án thì chọn 1 đáp án đúng nhấ
|
|
126 |
### Trả lời:"""
|
127 |
|
128 |
print("--- Bắt đầu tạo câu trả lời từ LLM ---")
|
129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
|
131 |
generation_config = dict(
|
132 |
max_new_tokens=256,
|
@@ -145,4 +152,4 @@ Nếu câu hỏi đưa ra nhiều đáp án thì chọn 1 đáp án đúng nhấ
|
|
145 |
response_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
|
146 |
|
147 |
print("--- Tạo câu trả lời hoàn tất ---")
|
148 |
-
return response_text
|
|
|
1 |
# file: rag_pipeline.py
|
2 |
+
|
3 |
import torch
|
4 |
import json
|
5 |
import faiss
|
|
|
21 |
"""
|
22 |
print("--- Bắt đầu khởi tạo các thành phần ---")
|
23 |
|
24 |
+
# 1. Tải LLM và Tokenizer (Processor) từ Unsloth
|
25 |
print("1. Tải mô hình LLM (Unsloth)...")
|
26 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
27 |
model_name="unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit",
|
28 |
+
max_seq_length=4096, # Có thể tăng cho các mô hình mới
|
29 |
dtype=None,
|
30 |
load_in_4bit=True,
|
31 |
)
|
|
|
82 |
|
83 |
def generate_response(query, components):
|
84 |
"""
|
85 |
+
Tạo câu trả lời cho một query (single-turn) bằng cách sử dụng các thành phần đã được khởi tạo.
|
86 |
"""
|
87 |
+
print("--- Bắt đầu quy trình RAG (Single-turn) cho query mới ---")
|
88 |
|
89 |
# Unpack các thành phần
|
90 |
llm_model = components["llm_model"]
|
91 |
tokenizer = components["tokenizer"]
|
92 |
|
93 |
+
# 1. Truy xuất ngữ cảnh trực tiếp từ câu hỏi của người dùng
|
94 |
retrieved_results = search_relevant_laws(
|
95 |
query_text=query,
|
96 |
embedding_model=components["embedding_model"],
|
|
|
113 |
context_parts.append(f"{header}\n{text}")
|
114 |
context = "\n\n---\n\n".join(context_parts)
|
115 |
|
116 |
+
# 3. Xây dựng Prompt đơn giản (không có lịch sử trò chuyện)
|
117 |
prompt = f"""Dưới đây là một số thông tin trích dẫn từ văn bản luật giao thông đường bộ Việt Nam.
|
118 |
Hãy SỬ DỤNG CÁC THÔNG TIN NÀY để trả lời câu hỏi một cách chính xác và đầy đủ.
|
119 |
Nếu câu hỏi đưa ra nhiều đáp án thì chọn 1 đáp án đúng nhất.
|
|
|
127 |
### Trả lời:"""
|
128 |
|
129 |
print("--- Bắt đầu tạo câu trả lời từ LLM ---")
|
130 |
+
|
131 |
+
# SỬA LỖI CHO VISION MODEL: Sử dụng API tường minh
|
132 |
+
inputs = tokenizer(
|
133 |
+
text=prompt,
|
134 |
+
images=None,
|
135 |
+
return_tensors="pt"
|
136 |
+
).to("cuda" if torch.cuda.is_available() else "cpu")
|
137 |
|
138 |
generation_config = dict(
|
139 |
max_new_tokens=256,
|
|
|
152 |
response_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
|
153 |
|
154 |
print("--- Tạo câu trả lời hoàn tất ---")
|
155 |
+
return response_text
|