Spaces:
Paused
Paused
# file: rag_pipeline.py | |
import torch | |
import json | |
import faiss | |
import numpy as np | |
import re | |
from unsloth import FastLanguageModel | |
from sentence_transformers import SentenceTransformer | |
from rank_bm25 import BM25Okapi | |
from transformers import TextStreamer | |
# Import các hàm từ file khác | |
from data_processor import process_law_data_to_chunks | |
from retriever import search_relevant_laws, tokenize_vi_for_bm25_setup | |
def initialize_components(data_path): | |
""" | |
Khởi tạo và tải tất cả các thành phần cần thiết cho RAG pipeline. | |
Hàm này chỉ nên được gọi một lần khi ứng dụng khởi động. | |
""" | |
print("--- Bắt đầu khởi tạo các thành phần ---") | |
# 1. Tải LLM và Tokenizer từ Unsloth | |
print("1. Tải mô hình LLM (Unsloth)...") | |
model, tokenizer = FastLanguageModel.from_pretrained( | |
model_name="unsloth/gemma-3-4b-it-unsloth-bnb-4bit", | |
max_seq_length=2048, | |
dtype=None, | |
load_in_4bit=True, | |
) | |
print("✅ Tải LLM và Tokenizer thành công.") | |
# 2. Tải mô hình Embedding | |
print("2. Tải mô hình Embedding...") | |
embedding_model = SentenceTransformer( | |
"bkai-foundation-models/vietnamese-bi-encoder", | |
device="cuda" if torch.cuda.is_available() else "cpu" | |
) | |
print("✅ Tải mô hình Embedding thành công.") | |
# 3. Tải và xử lý dữ liệu JSON | |
print(f"3. Tải và xử lý dữ liệu từ {data_path}...") | |
with open(data_path, 'r', encoding='utf-8') as f: | |
raw_data = json.load(f) | |
chunks_data = process_law_data_to_chunks(raw_data) | |
print(f"✅ Xử lý dữ liệu thành công, có {len(chunks_data)} chunks.") | |
# 4. Tạo Embeddings và FAISS Index | |
print("4. Tạo embeddings và FAISS index...") | |
texts_to_encode = [chunk.get('text', '') for chunk in chunks_data] | |
chunk_embeddings_tensor = embedding_model.encode( | |
texts_to_encode, | |
convert_to_tensor=True, | |
device=embedding_model.device | |
) | |
chunk_embeddings_np = chunk_embeddings_tensor.cpu().numpy().astype('float32') | |
faiss.normalize_L2(chunk_embeddings_np) | |
dimension = chunk_embeddings_np.shape[1] | |
faiss_index = faiss.IndexFlatIP(dimension) | |
faiss_index.add(chunk_embeddings_np) | |
print(f"✅ Tạo FAISS index thành công với {faiss_index.ntotal} vector.") | |
# 5. Tạo BM25 Model | |
print("5. Tạo mô hình BM25...") | |
corpus_texts_for_bm25 = [chunk.get('text', '') for chunk in chunks_data] | |
tokenized_corpus_bm25 = [tokenize_vi_for_bm25_setup(text) for text in corpus_texts_for_bm25] | |
bm25_model = BM25Okapi(tokenized_corpus_bm25) | |
print("✅ Tạo mô hình BM25 thành công.") | |
print("--- ✅ Khởi tạo tất cả thành phần hoàn tất ---") | |
return { | |
"llm_model": model, | |
"tokenizer": tokenizer, | |
"embedding_model": embedding_model, | |
"chunks_data": chunks_data, | |
"faiss_index": faiss_index, | |
"bm25_model": bm25_model | |
} | |
def generate_response(query, components): | |
""" | |
Tạo câu trả lời cho một query bằng cách sử dụng các thành phần đã được khởi tạo. | |
""" | |
print("--- Bắt đầu quy trình RAG cho query mới ---") | |
# Unpack các thành phần | |
llm_model = components["llm_model"] | |
tokenizer = components["tokenizer"] | |
# 1. Truy xuất ngữ cảnh | |
retrieved_results = search_relevant_laws( | |
query_text=query, | |
embedding_model=components["embedding_model"], | |
faiss_index=components["faiss_index"], | |
chunks_data=components["chunks_data"], | |
bm25_model=components["bm25_model"], | |
k=5, | |
initial_k_multiplier=18 | |
) | |
# 2. Định dạng Context | |
if not retrieved_results: | |
context = "Không tìm thấy thông tin luật liên quan trong cơ sở dữ liệu." | |
else: | |
context_parts = [] | |
for i, res in enumerate(retrieved_results): | |
metadata = res.get('metadata', {}) | |
header = f"Trích dẫn {i+1}: Điều {metadata.get('article', 'N/A')}, Khoản {metadata.get('clause_number', 'N/A')} (Nguồn: {metadata.get('source_document', 'N/A')})" | |
text = res.get('text', '*Nội dung không có*') | |
context_parts.append(f"{header}\n{text}") | |
context = "\n\n---\n\n".join(context_parts) | |
# 3. Xây dựng Prompt và tạo câu trả lời | |
prompt = f"""Dưới đây là một số thông tin trích dẫn từ văn bản luật giao thông đường bộ Việt Nam. | |
Hãy SỬ DỤNG CÁC THÔNG TIN NÀY để trả lời câu hỏi một cách chính xác và đầy đủ. | |
Nếu câu hỏi đưa ra nhiều đáp án thì chọn 1 đáp án đúng nhất. | |
### Thông tin luật: | |
{context} | |
### Câu hỏi: | |
{query} | |
### Trả lời:""" | |
print("--- Bắt đầu tạo câu trả lời từ LLM ---") | |
inputs = tokenizer(prompt, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu") | |
generation_config = dict( | |
max_new_tokens=256, | |
temperature=0.5, | |
top_p=0.7, | |
top_k=50, | |
repetition_penalty=1.1, | |
do_sample=True, | |
pad_token_id=tokenizer.eos_token_id, | |
eos_token_id=tokenizer.eos_token_id | |
) | |
output_ids = llm_model.generate(**inputs, **generation_config) | |
input_length = inputs.input_ids.shape[1] | |
generated_ids = output_ids[0][input_length:] | |
response_text = tokenizer.decode(generated_ids, skip_special_tokens=True) | |
print("--- Tạo câu trả lời hoàn tất ---") | |
return response_text |