deddoggo commited on
Commit
f998afd
·
1 Parent(s): 268c26f
Files changed (1) hide show
  1. rag_pipeline.py +15 -8
rag_pipeline.py CHANGED
@@ -1,4 +1,5 @@
1
  # file: rag_pipeline.py
 
2
  import torch
3
  import json
4
  import faiss
@@ -20,11 +21,11 @@ def initialize_components(data_path):
20
  """
21
  print("--- Bắt đầu khởi tạo các thành phần ---")
22
 
23
- # 1. Tải LLM và Tokenizer từ Unsloth
24
  print("1. Tải mô hình LLM (Unsloth)...")
25
  model, tokenizer = FastLanguageModel.from_pretrained(
26
  model_name="unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit",
27
- max_seq_length=2048,
28
  dtype=None,
29
  load_in_4bit=True,
30
  )
@@ -81,15 +82,15 @@ def initialize_components(data_path):
81
 
82
  def generate_response(query, components):
83
  """
84
- Tạo câu trả lời cho một query bằng cách sử dụng các thành phần đã được khởi tạo.
85
  """
86
- print("--- Bắt đầu quy trình RAG cho query mới ---")
87
 
88
  # Unpack các thành phần
89
  llm_model = components["llm_model"]
90
  tokenizer = components["tokenizer"]
91
 
92
- # 1. Truy xuất ngữ cảnh
93
  retrieved_results = search_relevant_laws(
94
  query_text=query,
95
  embedding_model=components["embedding_model"],
@@ -112,7 +113,7 @@ def generate_response(query, components):
112
  context_parts.append(f"{header}\n{text}")
113
  context = "\n\n---\n\n".join(context_parts)
114
 
115
- # 3. Xây dựng Prompt tạo câu trả lời
116
  prompt = f"""Dưới đây là một số thông tin trích dẫn từ văn bản luật giao thông đường bộ Việt Nam.
117
  Hãy SỬ DỤNG CÁC THÔNG TIN NÀY để trả lời câu hỏi một cách chính xác và đầy đủ.
118
  Nếu câu hỏi đưa ra nhiều đáp án thì chọn 1 đáp án đúng nhất.
@@ -126,7 +127,13 @@ Nếu câu hỏi đưa ra nhiều đáp án thì chọn 1 đáp án đúng nhấ
126
  ### Trả lời:"""
127
 
128
  print("--- Bắt đầu tạo câu trả lời từ LLM ---")
129
- inputs = tokenizer(prompt, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
130
 
131
  generation_config = dict(
132
  max_new_tokens=256,
@@ -145,4 +152,4 @@ Nếu câu hỏi đưa ra nhiều đáp án thì chọn 1 đáp án đúng nhấ
145
  response_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
146
 
147
  print("--- Tạo câu trả lời hoàn tất ---")
148
- return response_text
 
1
  # file: rag_pipeline.py
2
+
3
  import torch
4
  import json
5
  import faiss
 
21
  """
22
  print("--- Bắt đầu khởi tạo các thành phần ---")
23
 
24
+ # 1. Tải LLM và Tokenizer (Processor) từ Unsloth
25
  print("1. Tải mô hình LLM (Unsloth)...")
26
  model, tokenizer = FastLanguageModel.from_pretrained(
27
  model_name="unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit",
28
+ max_seq_length=4096, # Có thể tăng cho các mô hình mới
29
  dtype=None,
30
  load_in_4bit=True,
31
  )
 
82
 
83
  def generate_response(query, components):
84
  """
85
+ Tạo câu trả lời cho một query (single-turn) bằng cách sử dụng các thành phần đã được khởi tạo.
86
  """
87
+ print("--- Bắt đầu quy trình RAG (Single-turn) cho query mới ---")
88
 
89
  # Unpack các thành phần
90
  llm_model = components["llm_model"]
91
  tokenizer = components["tokenizer"]
92
 
93
+ # 1. Truy xuất ngữ cảnh trực tiếp từ câu hỏi của người dùng
94
  retrieved_results = search_relevant_laws(
95
  query_text=query,
96
  embedding_model=components["embedding_model"],
 
113
  context_parts.append(f"{header}\n{text}")
114
  context = "\n\n---\n\n".join(context_parts)
115
 
116
+ # 3. Xây dựng Prompt đơn giản (không lịch sử trò chuyện)
117
  prompt = f"""Dưới đây là một số thông tin trích dẫn từ văn bản luật giao thông đường bộ Việt Nam.
118
  Hãy SỬ DỤNG CÁC THÔNG TIN NÀY để trả lời câu hỏi một cách chính xác và đầy đủ.
119
  Nếu câu hỏi đưa ra nhiều đáp án thì chọn 1 đáp án đúng nhất.
 
127
  ### Trả lời:"""
128
 
129
  print("--- Bắt đầu tạo câu trả lời từ LLM ---")
130
+
131
+ # SỬA LỖI CHO VISION MODEL: Sử dụng API tường minh
132
+ inputs = tokenizer(
133
+ text=prompt,
134
+ images=None,
135
+ return_tensors="pt"
136
+ ).to("cuda" if torch.cuda.is_available() else "cpu")
137
 
138
  generation_config = dict(
139
  max_new_tokens=256,
 
152
  response_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
153
 
154
  print("--- Tạo câu trả lời hoàn tất ---")
155
+ return response_text