Spaces:
Paused
Paused
File size: 6,621 Bytes
a53f1d8 f998afd c69c2f9 a53f1d8 c69c2f9 a53f1d8 f998afd a53f1d8 268c26f f998afd a53f1d8 c69c2f9 a53f1d8 3264b15 43d4d74 3264b15 b8e326d 43d4d74 a53f1d8 cb8759a a53f1d8 3264b15 a53f1d8 43d4d74 3264b15 8829c2d 3264b15 43d4d74 8829c2d 43d4d74 f998afd 43d4d74 c69c2f9 a53f1d8 43d4d74 3264b15 c69c2f9 a53f1d8 3264b15 c69c2f9 a53f1d8 3264b15 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
# file: rag_pipeline.py
import torch
import json
import faiss
import numpy as np
import re
from unsloth import FastLanguageModel
from sentence_transformers import SentenceTransformer
from rank_bm25 import BM25Okapi
from transformers import TextStreamer
# Import các hàm từ file khác
from data_processor import process_law_data_to_chunks
from retriever import search_relevant_laws, tokenize_vi_for_bm25_setup
def initialize_components(data_path):
"""
Khởi tạo và tải tất cả các thành phần cần thiết cho RAG pipeline.
Hàm này chỉ nên được gọi một lần khi ứng dụng khởi động.
"""
print("--- Bắt đầu khởi tạo các thành phần ---")
# 1. Tải LLM và Tokenizer (Processor) từ Unsloth
print("1. Tải mô hình LLM (Unsloth)...")
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit",
max_seq_length=4096, # Có thể tăng cho các mô hình mới
dtype=None,
load_in_4bit=True,
)
print("✅ Tải LLM và Tokenizer thành công.")
# 2. Tải mô hình Embedding
print("2. Tải mô hình Embedding...")
embedding_model = SentenceTransformer(
"bkai-foundation-models/vietnamese-bi-encoder",
device="cuda" if torch.cuda.is_available() else "cpu"
)
print("✅ Tải mô hình Embedding thành công.")
# 3. Tải và xử lý dữ liệu JSON
print(f"3. Tải và xử lý dữ liệu từ {data_path}...")
with open(data_path, 'r', encoding='utf-8') as f:
raw_data = json.load(f)
chunks_data = process_law_data_to_chunks(raw_data)
print(f"✅ Xử lý dữ liệu thành công, có {len(chunks_data)} chunks.")
# 4. Tạo Embeddings và FAISS Index
print("4. Tạo embeddings và FAISS index...")
texts_to_encode = [chunk.get('text', '') for chunk in chunks_data]
chunk_embeddings_tensor = embedding_model.encode(
texts_to_encode,
convert_to_tensor=True,
device=embedding_model.device
)
chunk_embeddings_np = chunk_embeddings_tensor.cpu().numpy().astype('float32')
faiss.normalize_L2(chunk_embeddings_np)
dimension = chunk_embeddings_np.shape[1]
faiss_index = faiss.IndexFlatIP(dimension)
faiss_index.add(chunk_embeddings_np)
print(f"✅ Tạo FAISS index thành công với {faiss_index.ntotal} vector.")
# 5. Tạo BM25 Model
print("5. Tạo mô hình BM25...")
corpus_texts_for_bm25 = [chunk.get('text', '') for chunk in chunks_data]
tokenized_corpus_bm25 = [tokenize_vi_for_bm25_setup(text) for text in corpus_texts_for_bm25]
bm25_model = BM25Okapi(tokenized_corpus_bm25)
print("✅ Tạo mô hình BM25 thành công.")
print("--- ✅ Khởi tạo tất cả thành phần hoàn tất ---")
return {
"llm_model": model,
"tokenizer": tokenizer,
"embedding_model": embedding_model,
"chunks_data": chunks_data,
"faiss_index": faiss_index,
"bm25_model": bm25_model
}
def generate_response(query: str, components: dict) -> str:
"""
Tạo câu trả lời (single-turn) bằng cách sử dụng các thành phần đã được khởi tạo.
Phiên bản này được sửa đổi để tương thích với các mô hình Vision (đa phương thức)
bằng cách sử dụng chat template.
"""
print("--- Bắt đầu quy trình RAG cho query mới ---")
# --- Bước 1: Truy xuất Ngữ cảnh (Không thay đổi) ---
# (Hàm này giả định bạn đã có phiên bản retriever.py đã sửa lỗi logic boosting)
retrieved_results = search_relevant_laws(
query_text=query,
embedding_model=components["embedding_model"],
faiss_index=components["faiss_index"],
chunks_data=components["chunks_data"],
bm25_model=components["bm25_model"],
k=5,
initial_k_multiplier=15
)
# --- Bước 2: Định dạng Ngữ cảnh (Không thay đổi, nhưng nên tách ra hàm riêng) ---
if not retrieved_results:
context = "Không tìm thấy thông tin luật liên quan trong cơ sở dữ liệu."
else:
context_parts = []
for i, res in enumerate(retrieved_results):
metadata = res.get('metadata', {})
header = f"Trích dẫn {i+1}: Điều {metadata.get('article', 'N/A')}, Khoản {metadata.get('clause_number', 'N/A')} (Nguồn: {metadata.get('source_document', 'N/A')})"
text = res.get('text', '*Nội dung không có*')
context_parts.append(f"{header}\n{text}")
context = "\n\n---\n\n".join(context_parts)
# --- Bước 3: Xây dựng Prompt bằng Chat Template (Đây là phần thay đổi chính) ---
print("--- Xây dựng prompt bằng chat template ---")
llm_model = components["llm_model"]
tokenizer = components["tokenizer"]
# Tạo cấu trúc tin nhắn theo chuẩn của các mô hình hội thoại
messages = [
{
"role": "system",
"content": "Bạn là một trợ lý pháp luật chuyên trả lời các câu hỏi liên quan đến luật giao thông đường bộ Việt Nam. Hãy dựa vào các thông tin được cung cấp để trả lời một cách chính xác và dễ hiểu."
},
{
"role": "user",
"content": f"""Dựa vào các trích dẫn luật dưới đây:
### Thông tin luật:
{context}
### Câu hỏi:
{query}
"""
}
]
# Sử dụng apply_chat_template để tạo prompt hoàn chỉnh.
# Thêm `add_generation_prompt=True` để nó tự động thêm vai trò "assistant" ở cuối,
# báo hiệu cho model bắt đầu sinh câu trả lời.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# --- Bước 4: Tạo câu trả lời từ LLM (Không thay đổi) ---
print("--- Bắt đầu tạo câu trả lời từ LLM ---")
inputs = tokenizer([prompt], return_tensors="pt").to(llm_model.device)
generation_config = dict(
max_new_tokens=256,
temperature=0.1,
repetition_penalty=1.1,
do_sample=True,
pad_token_id=tokenizer.eos_token_id
)
output_ids = llm_model.generate(**inputs, **generation_config)
response_text = tokenizer.decode(output_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
print("--- Tạo câu trả lời hoàn tất ---")
return response_text |