Spaces:
Paused
Paused
# file: rag_pipeline.py | |
import torch | |
import json | |
import faiss | |
import numpy as np | |
import re | |
from unsloth import FastLanguageModel | |
from sentence_transformers import SentenceTransformer | |
from rank_bm25 import BM25Okapi | |
from transformers import TextStreamer | |
# Import các hàm từ file khác | |
from data_processor import process_law_data_to_chunks | |
from retriever import search_relevant_laws, tokenize_vi_for_bm25_setup | |
def initialize_components(data_path): | |
""" | |
Khởi tạo và tải tất cả các thành phần cần thiết cho RAG pipeline. | |
Hàm này chỉ nên được gọi một lần khi ứng dụng khởi động. | |
""" | |
print("--- Bắt đầu khởi tạo các thành phần ---") | |
# 1. Tải LLM và Tokenizer (Processor) từ Unsloth | |
print("1. Tải mô hình LLM (Unsloth)...") | |
model, tokenizer = FastLanguageModel.from_pretrained( | |
model_name="unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit", | |
max_seq_length=4096, # Có thể tăng cho các mô hình mới | |
dtype=None, | |
load_in_4bit=True, | |
) | |
print("✅ Tải LLM và Tokenizer thành công.") | |
# 2. Tải mô hình Embedding | |
print("2. Tải mô hình Embedding...") | |
embedding_model = SentenceTransformer( | |
"bkai-foundation-models/vietnamese-bi-encoder", | |
device="cuda" if torch.cuda.is_available() else "cpu" | |
) | |
print("✅ Tải mô hình Embedding thành công.") | |
# 3. Tải và xử lý dữ liệu JSON | |
print(f"3. Tải và xử lý dữ liệu từ {data_path}...") | |
with open(data_path, 'r', encoding='utf-8') as f: | |
raw_data = json.load(f) | |
chunks_data = process_law_data_to_chunks(raw_data) | |
print(f"✅ Xử lý dữ liệu thành công, có {len(chunks_data)} chunks.") | |
# 4. Tạo Embeddings và FAISS Index | |
print("4. Tạo embeddings và FAISS index...") | |
texts_to_encode = [chunk.get('text', '') for chunk in chunks_data] | |
chunk_embeddings_tensor = embedding_model.encode( | |
texts_to_encode, | |
convert_to_tensor=True, | |
device=embedding_model.device | |
) | |
chunk_embeddings_np = chunk_embeddings_tensor.cpu().numpy().astype('float32') | |
faiss.normalize_L2(chunk_embeddings_np) | |
dimension = chunk_embeddings_np.shape[1] | |
faiss_index = faiss.IndexFlatIP(dimension) | |
faiss_index.add(chunk_embeddings_np) | |
print(f"✅ Tạo FAISS index thành công với {faiss_index.ntotal} vector.") | |
# 5. Tạo BM25 Model | |
print("5. Tạo mô hình BM25...") | |
corpus_texts_for_bm25 = [chunk.get('text', '') for chunk in chunks_data] | |
tokenized_corpus_bm25 = [tokenize_vi_for_bm25_setup(text) for text in corpus_texts_for_bm25] | |
bm25_model = BM25Okapi(tokenized_corpus_bm25) | |
print("✅ Tạo mô hình BM25 thành công.") | |
print("--- ✅ Khởi tạo tất cả thành phần hoàn tất ---") | |
return { | |
"llm_model": model, | |
"tokenizer": tokenizer, | |
"embedding_model": embedding_model, | |
"chunks_data": chunks_data, | |
"faiss_index": faiss_index, | |
"bm25_model": bm25_model | |
} | |
def generate_response(query: str, components: dict) -> str: | |
""" | |
Tạo câu trả lời (single-turn) bằng cách sử dụng các thành phần đã được khởi tạo. | |
Phiên bản cuối cùng, sửa lỗi ValueError cho mô hình Vision bằng cách | |
sử dụng apply_chat_template để tokenization trực tiếp. | |
""" | |
print("--- Bắt đầu quy trình RAG cho query mới ---") | |
# --- Bước 1: Truy xuất Ngữ cảnh --- | |
retrieved_results = search_relevant_laws( | |
query_text=query, | |
embedding_model=components["embedding_model"], | |
faiss_index=components["faiss_index"], | |
chunks_data=components["chunks_data"], | |
bm25_model=components["bm25_model"], | |
k=5, | |
initial_k_multiplier=15 | |
) | |
# --- Bước 2: Định dạng Ngữ cảnh --- | |
if not retrieved_results: | |
context = "Không tìm thấy thông tin luật liên quan trong cơ sở dữ liệu." | |
else: | |
context_parts = [] | |
for i, res in enumerate(retrieved_results): | |
metadata = res.get('metadata', {}) | |
header = f"Trích dẫn {i+1}: Điều {metadata.get('article', 'N/A')}, Khoản {metadata.get('clause_number', 'N/A')} (Nguồn: {metadata.get('source_document', 'N/A')})" | |
text = res.get('text', '*Nội dung không có*') | |
context_parts.append(f"{header}\n{text}") | |
context = "\n\n---\n\n".join(context_parts) | |
# --- Bước 3: Chuẩn bị Dữ liệu và Tokenize bằng Chat Template (Phần sửa lỗi cốt lõi) --- | |
print("--- Chuẩn bị và tokenize prompt bằng chat template ---") | |
llm_model = components["llm_model"] | |
tokenizer = components["tokenizer"] | |
# Tạo cấu trúc tin nhắn theo chuẩn | |
messages = [ | |
{ | |
"role": "system", | |
"content": [ | |
{"type": "text", "text": "Bạn là một trợ lý pháp luật chuyên trả lời các câu hỏi về luật giao thông Việt Nam..."} | |
] | |
}, | |
{ | |
"role": "user", | |
"content": [ | |
{"type": "text", "text": f"""Dựa vào các trích dẫn luật dưới đây: | |
### Thông tin luật: | |
{context} | |
### Câu hỏi: | |
{query} | |
"""} | |
] | |
} | |
] | |
# SỬA LỖI: Dùng apply_chat_template để tokenize trực tiếp | |
# Nó sẽ tự động định dạng và chuyển thành tensor, tương thích với mô hình Vision | |
inputs = tokenizer.apply_chat_template( | |
messages, | |
tokenize=True, | |
add_generation_prompt=True, | |
return_tensors="pt" | |
).to(llm_model.device) | |
# --- Bước 4: Tạo câu trả lời từ LLM --- | |
print("--- Bắt đầu tạo câu trả lời từ LLM ---") | |
generation_config = dict( | |
max_new_tokens=256, | |
temperature=0.1, | |
repetition_penalty=1.1, | |
do_sample=True, | |
pad_token_id=tokenizer.eos_token_id | |
) | |
output_ids = llm_model.generate(inputs, **generation_config) | |
# Decode như cũ, nhưng đầu vào là `inputs` thay vì `inputs.input_ids` | |
response_text = tokenizer.decode(output_ids[0][inputs.shape[1]:], skip_special_tokens=True) | |
print("--- Tạo câu trả lời hoàn tất ---") | |
return response_text |