Spaces:
Paused
Paused
| # file: rag_pipeline.py | |
| import torch | |
| import json | |
| import faiss | |
| import numpy as np | |
| import re | |
| from unsloth import FastLanguageModel | |
| from sentence_transformers import SentenceTransformer | |
| from rank_bm25 import BM25Okapi | |
| from transformers import TextStreamer | |
| # Import các hàm từ file khác | |
| from data_processor import process_law_data_to_chunks | |
| from retriever import search_relevant_laws, tokenize_vi_for_bm25_setup | |
| def initialize_components(data_path): | |
| """ | |
| Khởi tạo và tải tất cả các thành phần cần thiết cho RAG pipeline. | |
| Hàm này chỉ nên được gọi một lần khi ứng dụng khởi động. | |
| """ | |
| print("--- Bắt đầu khởi tạo các thành phần ---") | |
| # 1. Tải LLM và Tokenizer từ Unsloth | |
| print("1. Tải mô hình LLM (Unsloth)...") | |
| model, tokenizer = FastLanguageModel.from_pretrained( | |
| model_name="unsloth/gemma-3-4b-it-unsloth-bnb-4bit", | |
| max_seq_length=2048, | |
| dtype=None, | |
| load_in_4bit=True, | |
| ) | |
| print("✅ Tải LLM và Tokenizer thành công.") | |
| # 2. Tải mô hình Embedding | |
| print("2. Tải mô hình Embedding...") | |
| embedding_model = SentenceTransformer( | |
| "bkai-foundation-models/vietnamese-bi-encoder", | |
| device="cuda" if torch.cuda.is_available() else "cpu" | |
| ) | |
| print("✅ Tải mô hình Embedding thành công.") | |
| # 3. Tải và xử lý dữ liệu JSON | |
| print(f"3. Tải và xử lý dữ liệu từ {data_path}...") | |
| with open(data_path, 'r', encoding='utf-8') as f: | |
| raw_data = json.load(f) | |
| chunks_data = process_law_data_to_chunks(raw_data) | |
| print(f"✅ Xử lý dữ liệu thành công, có {len(chunks_data)} chunks.") | |
| # 4. Tạo Embeddings và FAISS Index | |
| print("4. Tạo embeddings và FAISS index...") | |
| texts_to_encode = [chunk.get('text', '') for chunk in chunks_data] | |
| chunk_embeddings_tensor = embedding_model.encode( | |
| texts_to_encode, | |
| convert_to_tensor=True, | |
| device=embedding_model.device | |
| ) | |
| chunk_embeddings_np = chunk_embeddings_tensor.cpu().numpy().astype('float32') | |
| faiss.normalize_L2(chunk_embeddings_np) | |
| dimension = chunk_embeddings_np.shape[1] | |
| faiss_index = faiss.IndexFlatIP(dimension) | |
| faiss_index.add(chunk_embeddings_np) | |
| print(f"✅ Tạo FAISS index thành công với {faiss_index.ntotal} vector.") | |
| # 5. Tạo BM25 Model | |
| print("5. Tạo mô hình BM25...") | |
| corpus_texts_for_bm25 = [chunk.get('text', '') for chunk in chunks_data] | |
| tokenized_corpus_bm25 = [tokenize_vi_for_bm25_setup(text) for text in corpus_texts_for_bm25] | |
| bm25_model = BM25Okapi(tokenized_corpus_bm25) | |
| print("✅ Tạo mô hình BM25 thành công.") | |
| print("--- ✅ Khởi tạo tất cả thành phần hoàn tất ---") | |
| return { | |
| "llm_model": model, | |
| "tokenizer": tokenizer, | |
| "embedding_model": embedding_model, | |
| "chunks_data": chunks_data, | |
| "faiss_index": faiss_index, | |
| "bm25_model": bm25_model | |
| } | |
| def generate_response(query, components): | |
| """ | |
| Tạo câu trả lời cho một query bằng cách sử dụng các thành phần đã được khởi tạo. | |
| """ | |
| print("--- Bắt đầu quy trình RAG cho query mới ---") | |
| # Unpack các thành phần | |
| llm_model = components["llm_model"] | |
| tokenizer = components["tokenizer"] | |
| # 1. Truy xuất ngữ cảnh | |
| retrieved_results = search_relevant_laws( | |
| query_text=query, | |
| embedding_model=components["embedding_model"], | |
| faiss_index=components["faiss_index"], | |
| chunks_data=components["chunks_data"], | |
| bm25_model=components["bm25_model"], | |
| k=5, | |
| initial_k_multiplier=18 | |
| ) | |
| # 2. Định dạng Context | |
| if not retrieved_results: | |
| context = "Không tìm thấy thông tin luật liên quan trong cơ sở dữ liệu." | |
| else: | |
| context_parts = [] | |
| for i, res in enumerate(retrieved_results): | |
| metadata = res.get('metadata', {}) | |
| header = f"Trích dẫn {i+1}: Điều {metadata.get('article', 'N/A')}, Khoản {metadata.get('clause_number', 'N/A')} (Nguồn: {metadata.get('source_document', 'N/A')})" | |
| text = res.get('text', '*Nội dung không có*') | |
| context_parts.append(f"{header}\n{text}") | |
| context = "\n\n---\n\n".join(context_parts) | |
| # 3. Xây dựng Prompt và tạo câu trả lời | |
| prompt = f"""Dưới đây là một số thông tin trích dẫn từ văn bản luật giao thông đường bộ Việt Nam. | |
| Hãy SỬ DỤNG CÁC THÔNG TIN NÀY để trả lời câu hỏi một cách chính xác và đầy đủ. | |
| Nếu câu hỏi đưa ra nhiều đáp án thì chọn 1 đáp án đúng nhất. | |
| ### Thông tin luật: | |
| {context} | |
| ### Câu hỏi: | |
| {query} | |
| ### Trả lời:""" | |
| print("--- Bắt đầu tạo câu trả lời từ LLM ---") | |
| inputs = tokenizer(prompt, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu") | |
| generation_config = dict( | |
| max_new_tokens=256, | |
| temperature=0.5, | |
| top_p=0.7, | |
| top_k=50, | |
| repetition_penalty=1.1, | |
| do_sample=True, | |
| pad_token_id=tokenizer.eos_token_id, | |
| eos_token_id=tokenizer.eos_token_id | |
| ) | |
| output_ids = llm_model.generate(**inputs, **generation_config) | |
| input_length = inputs.input_ids.shape[1] | |
| generated_ids = output_ids[0][input_length:] | |
| response_text = tokenizer.decode(generated_ids, skip_special_tokens=True) | |
| print("--- Tạo câu trả lời hoàn tất ---") | |
| return response_text |