File size: 5,613 Bytes
a53f1d8
c69c2f9
a53f1d8
 
 
 
c69c2f9
a53f1d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f2f7dde
a53f1d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c69c2f9
a53f1d8
 
 
 
 
 
 
 
 
 
 
cb8759a
8829c2d
cb8759a
8829c2d
cb8759a
a53f1d8
 
 
 
b8e326d
cb8759a
a53f1d8
cb8759a
a53f1d8
 
 
 
cb8759a
a53f1d8
 
 
cb8759a
a53f1d8
 
 
 
 
 
 
 
 
 
 
cb8759a
 
 
 
8829c2d
cb8759a
8829c2d
 
cb8759a
8829c2d
 
cb8759a
c69c2f9
cb8759a
a53f1d8
c69c2f9
 
a53f1d8
 
 
 
 
b8e326d
a53f1d8
 
c69c2f9
a53f1d8
 
 
 
 
c69c2f9
a53f1d8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# file: rag_pipeline.py
import torch
import json
import faiss
import numpy as np
import re
from unsloth import FastLanguageModel
from sentence_transformers import SentenceTransformer
from rank_bm25 import BM25Okapi
from transformers import TextStreamer

# Import các hàm từ file khác
from data_processor import process_law_data_to_chunks
from retriever import search_relevant_laws, tokenize_vi_for_bm25_setup

def initialize_components(data_path):
    """
    Khởi tạo và tải tất cả các thành phần cần thiết cho RAG pipeline.
    Hàm này chỉ nên được gọi một lần khi ứng dụng khởi động.
    """
    print("--- Bắt đầu khởi tạo các thành phần ---")

    # 1. Tải LLM và Tokenizer từ Unsloth
    print("1. Tải mô hình LLM (Unsloth)...")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="unsloth/gemma-3-4b-it-unsloth-bnb-4bit",
        max_seq_length=2048,
        dtype=None,
        load_in_4bit=True,
    )
    print("✅ Tải LLM và Tokenizer thành công.")

    # 2. Tải mô hình Embedding
    print("2. Tải mô hình Embedding...")
    embedding_model = SentenceTransformer(
        "bkai-foundation-models/vietnamese-bi-encoder",
        device="cuda" if torch.cuda.is_available() else "cpu"
    )
    print("✅ Tải mô hình Embedding thành công.")

    # 3. Tải và xử lý dữ liệu JSON
    print(f"3. Tải và xử lý dữ liệu từ {data_path}...")
    with open(data_path, 'r', encoding='utf-8') as f:
        raw_data = json.load(f)
    chunks_data = process_law_data_to_chunks(raw_data)
    print(f"✅ Xử lý dữ liệu thành công, có {len(chunks_data)} chunks.")

    # 4. Tạo Embeddings và FAISS Index
    print("4. Tạo embeddings và FAISS index...")
    texts_to_encode = [chunk.get('text', '') for chunk in chunks_data]
    chunk_embeddings_tensor = embedding_model.encode(
        texts_to_encode,
        convert_to_tensor=True,
        device=embedding_model.device
    )
    chunk_embeddings_np = chunk_embeddings_tensor.cpu().numpy().astype('float32')
    faiss.normalize_L2(chunk_embeddings_np)
    
    dimension = chunk_embeddings_np.shape[1]
    faiss_index = faiss.IndexFlatIP(dimension)
    faiss_index.add(chunk_embeddings_np)
    print(f"✅ Tạo FAISS index thành công với {faiss_index.ntotal} vector.")

    # 5. Tạo BM25 Model
    print("5. Tạo mô hình BM25...")
    corpus_texts_for_bm25 = [chunk.get('text', '') for chunk in chunks_data]
    tokenized_corpus_bm25 = [tokenize_vi_for_bm25_setup(text) for text in corpus_texts_for_bm25]
    bm25_model = BM25Okapi(tokenized_corpus_bm25)
    print("✅ Tạo mô hình BM25 thành công.")

    print("--- ✅ Khởi tạo tất cả thành phần hoàn tất ---")
    
    return {
        "llm_model": model,
        "tokenizer": tokenizer,
        "embedding_model": embedding_model,
        "chunks_data": chunks_data,
        "faiss_index": faiss_index,
        "bm25_model": bm25_model
    }

def generate_response(query, components):
    """
    Tạo câu trả lời cho một query bằng cách sử dụng các thành phần đã được khởi tạo.
    """
    print("--- Bắt đầu quy trình RAG cho query mới ---")
    
    # Unpack các thành phần
    llm_model = components["llm_model"]
    tokenizer = components["tokenizer"]
    
    # 1. Truy xuất ngữ cảnh
    retrieved_results = search_relevant_laws(
        query_text=query,
        embedding_model=components["embedding_model"],
        faiss_index=components["faiss_index"],
        chunks_data=components["chunks_data"],
        bm25_model=components["bm25_model"],
        k=5,
        initial_k_multiplier=18
    )

    # 2. Định dạng Context
    if not retrieved_results:
        context = "Không tìm thấy thông tin luật liên quan trong cơ sở dữ liệu."
    else:
        context_parts = []
        for i, res in enumerate(retrieved_results):
            metadata = res.get('metadata', {})
            header = f"Trích dẫn {i+1}: Điều {metadata.get('article', 'N/A')}, Khoản {metadata.get('clause_number', 'N/A')} (Nguồn: {metadata.get('source_document', 'N/A')})"
            text = res.get('text', '*Nội dung không có*')
            context_parts.append(f"{header}\n{text}")
        context = "\n\n---\n\n".join(context_parts)

    # 3. Xây dựng Prompt và tạo câu trả lời
    prompt = f"""Dưới đây là một số thông tin trích dẫn từ văn bản luật giao thông đường bộ Việt Nam.
Hãy SỬ DỤNG CÁC THÔNG TIN NÀY để trả lời câu hỏi một cách chính xác và đầy đủ.
Nếu câu hỏi đưa ra nhiều đáp án thì chọn 1 đáp án đúng nhất.

### Thông tin luật:
{context}

### Câu hỏi:
{query}

### Trả lời:"""

    print("--- Bắt đầu tạo câu trả lời từ LLM ---")
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")

    generation_config = dict(
        max_new_tokens=256,
        temperature=0.5,
        top_p=0.7,
        top_k=50,
        repetition_penalty=1.1,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id
    )
    
    output_ids = llm_model.generate(**inputs, **generation_config)
    input_length = inputs.input_ids.shape[1]
    generated_ids = output_ids[0][input_length:]
    response_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

    print("--- Tạo câu trả lời hoàn tất ---")
    return response_text