Spaces:
Paused
Paused
# file: rag_pipeline.py | |
import torch | |
import json | |
import faiss | |
import numpy as np | |
import re | |
from unsloth import FastLanguageModel | |
from sentence_transformers import SentenceTransformer | |
from rank_bm25 import BM25Okapi | |
from transformers import TextStreamer | |
# Import các hàm từ file khác | |
from data_processor import process_law_data_to_chunks | |
from retriever import search_relevant_laws, tokenize_vi_for_bm25_setup | |
def initialize_components(data_path): | |
""" | |
Khởi tạo và tải tất cả các thành phần cần thiết cho RAG pipeline. | |
Hàm này chỉ nên được gọi một lần khi ứng dụng khởi động. | |
""" | |
print("--- Bắt đầu khởi tạo các thành phần ---") | |
# 1. Tải LLM và Tokenizer (Processor) từ Unsloth | |
print("1. Tải mô hình LLM (Unsloth)...") | |
model, tokenizer = FastLanguageModel.from_pretrained( | |
model_name="unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit", | |
max_seq_length=4096, # Có thể tăng cho các mô hình mới | |
dtype=None, | |
load_in_4bit=True, | |
) | |
print("✅ Tải LLM và Tokenizer thành công.") | |
# 2. Tải mô hình Embedding | |
print("2. Tải mô hình Embedding...") | |
embedding_model = SentenceTransformer( | |
"bkai-foundation-models/vietnamese-bi-encoder", | |
device="cuda" if torch.cuda.is_available() else "cpu" | |
) | |
print("✅ Tải mô hình Embedding thành công.") | |
# 3. Tải và xử lý dữ liệu JSON | |
print(f"3. Tải và xử lý dữ liệu từ {data_path}...") | |
with open(data_path, 'r', encoding='utf-8') as f: | |
raw_data = json.load(f) | |
chunks_data = process_law_data_to_chunks(raw_data) | |
print(f"✅ Xử lý dữ liệu thành công, có {len(chunks_data)} chunks.") | |
# 4. Tạo Embeddings và FAISS Index | |
print("4. Tạo embeddings và FAISS index...") | |
texts_to_encode = [chunk.get('text', '') for chunk in chunks_data] | |
chunk_embeddings_tensor = embedding_model.encode( | |
texts_to_encode, | |
convert_to_tensor=True, | |
device=embedding_model.device | |
) | |
chunk_embeddings_np = chunk_embeddings_tensor.cpu().numpy().astype('float32') | |
faiss.normalize_L2(chunk_embeddings_np) | |
dimension = chunk_embeddings_np.shape[1] | |
faiss_index = faiss.IndexFlatIP(dimension) | |
faiss_index.add(chunk_embeddings_np) | |
print(f"✅ Tạo FAISS index thành công với {faiss_index.ntotal} vector.") | |
# 5. Tạo BM25 Model | |
print("5. Tạo mô hình BM25...") | |
corpus_texts_for_bm25 = [chunk.get('text', '') for chunk in chunks_data] | |
tokenized_corpus_bm25 = [tokenize_vi_for_bm25_setup(text) for text in corpus_texts_for_bm25] | |
bm25_model = BM25Okapi(tokenized_corpus_bm25) | |
print("✅ Tạo mô hình BM25 thành công.") | |
print("--- ✅ Khởi tạo tất cả thành phần hoàn tất ---") | |
return { | |
"llm_model": model, | |
"tokenizer": tokenizer, | |
"embedding_model": embedding_model, | |
"chunks_data": chunks_data, | |
"faiss_index": faiss_index, | |
"bm25_model": bm25_model | |
} | |
def generate_response(query: str, components: dict) -> str: | |
""" | |
Tạo câu trả lời (single-turn). | |
Phiên bản đơn giản hóa, không có logic vehicle_type. | |
""" | |
print("--- Bắt đầu quy trình RAG cho query mới ---") | |
# === THAY ĐỔI 1: Chỉ nhận 1 giá trị trả về === | |
# 1. Truy xuất ngữ cảnh | |
retrieved_results = search_relevant_laws( | |
query_text=query, | |
embedding_model=components["embedding_model"], | |
faiss_index=components["faiss_index"], | |
chunks_data=components["chunks_data"], | |
bm25_model=components["bm25_model"], | |
k=5, | |
initial_k_multiplier=15 | |
) | |
# === THAY ĐỔI 2: Loại bỏ logic vehicle_type trong context === | |
# 2. Định dạng Context | |
if not retrieved_results: | |
context = "Không tìm thấy thông tin luật liên quan trong cơ sở dữ liệu." | |
else: | |
context_parts = [] | |
for i, res in enumerate(retrieved_results): | |
metadata = res.get('metadata', {}) | |
# Tạo header đơn giản, không có gợi ý | |
header = f"Trích dẫn {i+1}: Điều {metadata.get('article', 'N/A')}, Khoản {metadata.get('clause_number', 'N/A')} (Nguồn: {metadata.get('source_document', 'N/A')})" | |
text = res.get('text', '*Nội dung không có*') | |
context_parts.append(f"{header}\n{text}") | |
context = "\n\n---\n\n".join(context_parts) | |
# 3. Xây dựng Prompt bằng Chat Template (giữ nguyên logic tương thích Vision) | |
print("--- Xây dựng prompt bằng chat template ---") | |
llm_model = components["llm_model"] | |
tokenizer = components["tokenizer"] | |
messages = [ | |
{ | |
"role": "system", | |
"content": "Bạn là một trợ lý pháp luật chuyên trả lời các câu hỏi về luật giao thông Việt Nam. Hãy dựa vào các thông tin được cung cấp để trả lời một cách chính xác và dễ hiểu." | |
}, | |
{ | |
"role": "user", | |
"content": f"""Dựa vào các trích dẫn luật dưới đây: | |
### Thông tin luật: | |
{context} | |
### Câu hỏi: | |
{query} | |
""" | |
} | |
] | |
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) | |
# 4. Tạo câu trả lời từ LLM | |
print("--- Bắt đầu tạo câu trả lời từ LLM ---") | |
inputs = tokenizer([prompt], return_tensors="pt").to(llm_model.device) | |
generation_config = dict( | |
max_new_tokens=256, | |
temperature=0.1, | |
repetition_penalty=1.1, | |
do_sample=True, | |
pad_token_id=tokenizer.eos_token_id | |
) | |
output_ids = llm_model.generate(**inputs, **generation_config) | |
response_text = tokenizer.decode(output_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) | |
print("--- Tạo câu trả lời hoàn tất ---") | |
return response_text |