Spaces:
Paused
Paused
update
Browse files- rag_pipeline.py +15 -8
rag_pipeline.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
# file: rag_pipeline.py
|
|
|
|
| 2 |
import torch
|
| 3 |
import json
|
| 4 |
import faiss
|
|
@@ -20,11 +21,11 @@ def initialize_components(data_path):
|
|
| 20 |
"""
|
| 21 |
print("--- Bắt đầu khởi tạo các thành phần ---")
|
| 22 |
|
| 23 |
-
# 1. Tải LLM và Tokenizer từ Unsloth
|
| 24 |
print("1. Tải mô hình LLM (Unsloth)...")
|
| 25 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 26 |
model_name="unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit",
|
| 27 |
-
max_seq_length=
|
| 28 |
dtype=None,
|
| 29 |
load_in_4bit=True,
|
| 30 |
)
|
|
@@ -81,15 +82,15 @@ def initialize_components(data_path):
|
|
| 81 |
|
| 82 |
def generate_response(query, components):
|
| 83 |
"""
|
| 84 |
-
Tạo câu trả lời cho một query bằng cách sử dụng các thành phần đã được khởi tạo.
|
| 85 |
"""
|
| 86 |
-
print("--- Bắt đầu quy trình RAG cho query mới ---")
|
| 87 |
|
| 88 |
# Unpack các thành phần
|
| 89 |
llm_model = components["llm_model"]
|
| 90 |
tokenizer = components["tokenizer"]
|
| 91 |
|
| 92 |
-
# 1. Truy xuất ngữ cảnh
|
| 93 |
retrieved_results = search_relevant_laws(
|
| 94 |
query_text=query,
|
| 95 |
embedding_model=components["embedding_model"],
|
|
@@ -112,7 +113,7 @@ def generate_response(query, components):
|
|
| 112 |
context_parts.append(f"{header}\n{text}")
|
| 113 |
context = "\n\n---\n\n".join(context_parts)
|
| 114 |
|
| 115 |
-
# 3. Xây dựng Prompt
|
| 116 |
prompt = f"""Dưới đây là một số thông tin trích dẫn từ văn bản luật giao thông đường bộ Việt Nam.
|
| 117 |
Hãy SỬ DỤNG CÁC THÔNG TIN NÀY để trả lời câu hỏi một cách chính xác và đầy đủ.
|
| 118 |
Nếu câu hỏi đưa ra nhiều đáp án thì chọn 1 đáp án đúng nhất.
|
|
@@ -126,7 +127,13 @@ Nếu câu hỏi đưa ra nhiều đáp án thì chọn 1 đáp án đúng nhấ
|
|
| 126 |
### Trả lời:"""
|
| 127 |
|
| 128 |
print("--- Bắt đầu tạo câu trả lời từ LLM ---")
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
generation_config = dict(
|
| 132 |
max_new_tokens=256,
|
|
@@ -145,4 +152,4 @@ Nếu câu hỏi đưa ra nhiều đáp án thì chọn 1 đáp án đúng nhấ
|
|
| 145 |
response_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
|
| 146 |
|
| 147 |
print("--- Tạo câu trả lời hoàn tất ---")
|
| 148 |
-
return response_text
|
|
|
|
| 1 |
# file: rag_pipeline.py
|
| 2 |
+
|
| 3 |
import torch
|
| 4 |
import json
|
| 5 |
import faiss
|
|
|
|
| 21 |
"""
|
| 22 |
print("--- Bắt đầu khởi tạo các thành phần ---")
|
| 23 |
|
| 24 |
+
# 1. Tải LLM và Tokenizer (Processor) từ Unsloth
|
| 25 |
print("1. Tải mô hình LLM (Unsloth)...")
|
| 26 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 27 |
model_name="unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit",
|
| 28 |
+
max_seq_length=4096, # Có thể tăng cho các mô hình mới
|
| 29 |
dtype=None,
|
| 30 |
load_in_4bit=True,
|
| 31 |
)
|
|
|
|
| 82 |
|
| 83 |
def generate_response(query, components):
|
| 84 |
"""
|
| 85 |
+
Tạo câu trả lời cho một query (single-turn) bằng cách sử dụng các thành phần đã được khởi tạo.
|
| 86 |
"""
|
| 87 |
+
print("--- Bắt đầu quy trình RAG (Single-turn) cho query mới ---")
|
| 88 |
|
| 89 |
# Unpack các thành phần
|
| 90 |
llm_model = components["llm_model"]
|
| 91 |
tokenizer = components["tokenizer"]
|
| 92 |
|
| 93 |
+
# 1. Truy xuất ngữ cảnh trực tiếp từ câu hỏi của người dùng
|
| 94 |
retrieved_results = search_relevant_laws(
|
| 95 |
query_text=query,
|
| 96 |
embedding_model=components["embedding_model"],
|
|
|
|
| 113 |
context_parts.append(f"{header}\n{text}")
|
| 114 |
context = "\n\n---\n\n".join(context_parts)
|
| 115 |
|
| 116 |
+
# 3. Xây dựng Prompt đơn giản (không có lịch sử trò chuyện)
|
| 117 |
prompt = f"""Dưới đây là một số thông tin trích dẫn từ văn bản luật giao thông đường bộ Việt Nam.
|
| 118 |
Hãy SỬ DỤNG CÁC THÔNG TIN NÀY để trả lời câu hỏi một cách chính xác và đầy đủ.
|
| 119 |
Nếu câu hỏi đưa ra nhiều đáp án thì chọn 1 đáp án đúng nhất.
|
|
|
|
| 127 |
### Trả lời:"""
|
| 128 |
|
| 129 |
print("--- Bắt đầu tạo câu trả lời từ LLM ---")
|
| 130 |
+
|
| 131 |
+
# SỬA LỖI CHO VISION MODEL: Sử dụng API tường minh
|
| 132 |
+
inputs = tokenizer(
|
| 133 |
+
text=prompt,
|
| 134 |
+
images=None,
|
| 135 |
+
return_tensors="pt"
|
| 136 |
+
).to("cuda" if torch.cuda.is_available() else "cpu")
|
| 137 |
|
| 138 |
generation_config = dict(
|
| 139 |
max_new_tokens=256,
|
|
|
|
| 152 |
response_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
|
| 153 |
|
| 154 |
print("--- Tạo câu trả lời hoàn tất ---")
|
| 155 |
+
return response_text
|