Spaces:
Runtime error
Runtime error
from langchain_huggingface import HuggingFaceEmbeddings | |
import config | |
import prompt_templete | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_core.runnables import RunnableLambda, RunnablePassthrough | |
from langchain_core.documents import Document | |
import logging | |
from langchain_core.output_parsers import StrOutputParser | |
from typing import List,Any,Dict | |
from langchain_weaviate.vectorstores import WeaviateVectorStore | |
from langchain_google_genai import ChatGoogleGenerativeAI | |
from utils.process_data import filter_and_serialize_complex_metadata | |
import weaviate | |
import weaviate.classes.config as wvc_config | |
from weaviate.exceptions import WeaviateQueryException | |
from google.api_core.exceptions import ResourceExhausted, PermissionDenied | |
import time | |
import json | |
import re | |
from operator import itemgetter | |
logger = logging.getLogger(__name__) | |
WEAVIATE_SCHEMA_CONFIG: List[Dict[str, Any]] = [ | |
# Tên trường, Kiểu dữ liệu trong Weaviate, Có nên vector hóa trường này không? | |
{"name": "source", "dataType": wvc_config.DataType.TEXT,"index_searchable": False, "vectorize": False}, | |
{"name": "title", "dataType": wvc_config.DataType.TEXT, "index_searchable": True, "tokenization": wvc_config.Tokenization.WORD, "vectorize": True}, | |
{"name": "field", "dataType": wvc_config.DataType.TEXT,"index_searchable": True, "vectorize": True}, | |
{"name": "so_hieu", "dataType": wvc_config.DataType.TEXT, "index_searchable": False,"vectorize": False}, | |
{"name": "loai_van_ban", "dataType": wvc_config.DataType.TEXT, "index_searchable": True,"vectorize": True}, | |
{"name": "ten_van_ban", "dataType": wvc_config.DataType.TEXT,"index_searchable": True, "tokenization": wvc_config.Tokenization.WORD, "vectorize": True}, | |
{"name": "co_quan_ban_hanh", "dataType": wvc_config.DataType.TEXT, "index_searchable": False,"vectorize": False}, | |
{"name": "ngay_ban_hanh_str", "dataType": wvc_config.DataType.TEXT,"index_searchable": False, "vectorize": False}, | |
{"name": "nam_ban_hanh", "dataType": wvc_config.DataType.INT,"index_searchable": True, "vectorize": False}, | |
{"name": "phan_code", "dataType": wvc_config.DataType.TEXT,"index_searchable": False, "vectorize": False}, | |
{"name": "chuong_code", "dataType": wvc_config.DataType.TEXT, "index_searchable": False,"vectorize": False}, | |
{"name": "muc_code", "dataType": wvc_config.DataType.TEXT,"index_searchable": False, "vectorize": False}, | |
{"name": "dieu_code", "dataType": wvc_config.DataType.TEXT,"index_searchable": False, "vectorize": False}, | |
{"name": "entity_type", "dataType": wvc_config.DataType.TEXT,"index_searchable": True, "vectorize": False}, | |
{"name": "penalties", "dataType": wvc_config.DataType.TEXT,"index_searchable": False, "vectorize": False}, | |
{"name": "cross_references", "dataType": wvc_config.DataType.TEXT, "index_searchable": False, "vectorize": False}, | |
] | |
# Hàm get_huggingface_embeddings giữ nguyên | |
def get_huggingface_embeddings(model_name: str, device: str = 'cpu'): | |
logger.info(f"🔸Đang khởi tạo model embedding: {model_name} trên thiết bị {device}...") | |
model_kwargs = { | |
'device': device, | |
'trust_remote_code': True # thêm để đảm bảo load được những model custom | |
} | |
encode_kwargs = { | |
'batch_size': 32, # kích thước batch cho embedding | |
'normalize_embeddings': True # normalize để cosine similarity chuẩn | |
} | |
try: | |
embeddings = HuggingFaceEmbeddings( | |
model_name=model_name, | |
model_kwargs=model_kwargs, | |
encode_kwargs=encode_kwargs | |
) | |
logger.info("🔸Khởi tạo model embedding thành công.") | |
return embeddings | |
except Exception as e: | |
logger.error(f"🔸Lỗi khi khởi tạo model embedding: {e}") | |
raise Exception(f"Khởi tạo model embedding thất bại: {str(e)}") | |
# Begin New | |
def create_weaviate_schema_if_not_exists(client: weaviate.WeaviateClient, collection_name: str): | |
""" | |
CẢI TIẾN: Tạo schema với cấu hình chi tiết cho filtering và hybrid search. | |
""" | |
if client.collections.exists(collection_name): | |
logger.info(f"✅ Schema for collection '{collection_name}' already exists.") | |
return | |
logger.info(f"🔸 Schema for collection '{collection_name}' not found. Creating...") | |
try: | |
properties = [] | |
for prop_config in WEAVIATE_SCHEMA_CONFIG: | |
properties.append( | |
wvc_config.Property( | |
name=prop_config["name"], | |
data_type=prop_config["dataType"], | |
# Bỏ qua vector hóa nếu vectorize=False hoặc không được định nghĩa | |
skip_vectorization=not prop_config.get("vectorize", False), | |
# Kích hoạt tokenization cho các trường cần tìm kiếm từ khóa | |
tokenization=prop_config.get("tokenization") | |
) | |
) | |
# Thêm trường 'text' chính, tối ưu cho cả vector và keyword search | |
properties.append( | |
wvc_config.Property( | |
name="text", | |
data_type=wvc_config.DataType.TEXT, | |
skip_vectorization=False, # Luôn vector hóa nội dung chính | |
tokenization=wvc_config.Tokenization.WORD # Cho phép tìm kiếm BM25 trên nội dung | |
) | |
) | |
client.collections.create( | |
name=collection_name, | |
properties=properties, | |
# Kích hoạt inverted index (bắt buộc cho filtering và BM25) | |
inverted_index_config=wvc_config.Configure.inverted_index( | |
index_null_state=True, | |
index_property_length=True, | |
index_timestamps=True, | |
bm25_b=0.75, # Tham số BM25, có thể điều chỉnh | |
bm25_k1=1.2 # Tham số K1 cho BM25 | |
), | |
vectorizer_config=wvc_config.Configure.Vectorizer.none(), | |
vector_index_config=wvc_config.Configure.VectorIndex.hnsw( | |
distance_metric=wvc_config.VectorDistances.COSINE | |
) | |
) | |
logger.info(f"✅ Successfully created schema for collection '{collection_name}'.") | |
except WeaviateQueryException as e: | |
logger.error(f"❌ Error creating schema: {e}", exc_info=True) | |
raise | |
def ingest_chunks_with_native_batching(client: weaviate.WeaviateClient, collection_name: str, chunks: List[Document], embeddings_model): | |
"""Sử dụng API batch gốc của Weaviate, an toàn và hiệu suất cao.""" | |
logger.info(f"🚀 Bắt đầu quá trình ingestion cho {len(chunks)} chunks...") | |
texts_to_embed = [chunk.page_content for chunk in chunks] | |
logger.info(f"🧠 Đang tạo embeddings cho {len(texts_to_embed)} chunks...") | |
start_embed_time = time.time() | |
chunk_vectors = embeddings_model.embed_documents(texts_to_embed) | |
logger.info(f"⏱️ Thời gian tạo embedding: {time.time() - start_embed_time:.2f} giây.") | |
# 3. CẢI TIẾN: Đảm bảo chỉ ingest các thuộc tính hợp lệ | |
valid_property_names = {prop["name"] for prop in WEAVIATE_SCHEMA_CONFIG} | |
valid_property_names.add("text") # Thêm trường 'text' | |
with client.batch.dynamic() as batch: | |
for i, chunk in enumerate(chunks): | |
if not isinstance(chunk,Document) or not hasattr(chunk, 'id') or not chunk.id: | |
logger.warning(f"Bỏ qua chunk ở vị trí {i} do không hợp lệ (sai type hoặc thiếu ID).") | |
continue | |
properties = {"text": chunk.page_content} | |
# Lọc metadata để chỉ giữ lại các key hợp lệ đã định nghĩa trong schema | |
filtered_metadata = { | |
k: v for k, v in chunk.metadata.items() if k in valid_property_names | |
} | |
properties.update(filtered_metadata) | |
batch.add_object( | |
collection=collection_name, | |
properties=properties, | |
uuid=chunk.id, | |
vector=chunk_vectors[i] | |
) | |
logger.info(f"✅ Batching hoàn tất. Đã gửi {len(chunks)} objects.") | |
if batch.number_errors > 0: | |
logger.error(f"❌ Có {batch.number_errors} lỗi xảy ra trong quá trình batching.") | |
# Log ra 5 lỗi đầu tiên để dễ gỡ lỗi | |
for i, error_msg in enumerate(batch.errors): | |
if i >= 5: break | |
logger.error(f" - Lỗi {i+1}: {error_msg}") | |
# End new | |
def create_or_load_vectorstore(embeddings, weaviate_url, collection_name, weaviate_client, chunks=None): | |
vectorstore = None | |
if not embeddings: | |
logger.error("🔸Không có model embedding để tạo/tải vector store.") | |
return None | |
logger.info(f"🔸Truy cập Weaviate tại: {weaviate_url} với collection: {collection_name}") | |
try: | |
# Kết nối tới Weaviate | |
client = weaviate_client | |
if not client: | |
logger.error("🔸Không thể kết nối tới Weaviate.") | |
return None | |
# Tên collection cần kiểm tra | |
collection_name = config.WEAVIATE_COLLECTION_NAME | |
# Kiểm tra xem collection có tồn tại không | |
collection_exists = client.collections.exists(collection_name) | |
logger.info(f"Collection {collection_name} exists: {collection_exists}") | |
if chunks is not None and not collection_exists: | |
logger.info(f"🔸Tạo Weaviate collection mới từ {len(chunks)} chunks...") | |
# Kiểm tra mẫu dữ liệu đầu tiên | |
logger.info(f"🔸Chunk đầu tiên:\n{chunks[0].metadata}") | |
logger.info(f"🔸Nội dung:\n{chunks[0].page_content[:500]}...") | |
# Lọc metadata để đảm bảo tương thích với Weaviate | |
chunks = filter_and_serialize_complex_metadata(chunks) | |
logger.info(f"🔸Metadata chunk đầu tiên sau khi lọc/serialize:\n{chunks[0].metadata}") | |
# KIỂM TRA TYPE | |
if chunks: | |
logger.info(f"Type của chunk đầu tiên: {type(chunks[0])}") | |
# Kiểm tra xem có phải là langchain Document không | |
from langchain_core.documents import Document as LangchainDocument | |
is_langchain_doc = isinstance(chunks[0], LangchainDocument) | |
logger.info(f"Chunk đầu tiên có phải là langchain_core.documents.Document không? {is_langchain_doc}") | |
if not is_langchain_doc: | |
logger.error("!!! LỖI NGHIÊM TRỌNG: Chunks không phải là instance của langchain_core.documents.Document") | |
# In ra các attribute của object để xem nó là gì | |
try: | |
logger.error(f"Attributes của chunk[0]: {dir(chunks[0])}") | |
if hasattr(chunks[0], "metadata"): | |
logger.error(f"Metadata của chunk[0] (nếu có): {chunks[0].metadata}") | |
if hasattr(chunks[0], "page_content"): | |
logger.error(f"Page_content của chunk[0] (nếu có): {chunks[0].page_content[:100]}") | |
except: | |
pass # Bỏ qua nếu không thể dir() | |
return None # Dừng ở đây nếu type sai | |
# Tạo vectorstore | |
max_batch_size = 1000 # Kích thước batch an toàn | |
total_chunks = len(chunks) | |
logger.info("🔸Đang nhúng dữ liệu...") | |
# Tạo collection mới | |
vectorstore = WeaviateVectorStore.from_documents( | |
documents=chunks[:1], # Khởi tạo với 1 tài liệu để tạo schema | |
embedding=embeddings, | |
client=client, | |
index_name=collection_name, | |
text_key="text", # Tên trường văn bản trong tài liệu | |
# by_texts=False # Nếu dùng ids thì không cần by_texts, nhưng để rõ ràng | |
) | |
# Thêm tài liệu theo batch | |
for i in range(1, total_chunks, max_batch_size): | |
end_idx = min(i + max_batch_size, total_chunks) | |
current_batch = chunks[i:end_idx] | |
logger.info(f"🔸Đang xử lý batch {i//max_batch_size + 1}/{(total_chunks-1)//max_batch_size + 1}: từ {i} đến {end_idx-1}") | |
try: | |
vectorstore.add_documents(current_batch) | |
logger.info(f"🔸Đã thêm batch {i//max_batch_size + 1} thành công") | |
except Exception as batch_error: | |
logger.error(f"🔸Lỗi khi xử lý batch từ {i} đến {end_idx-1}: {str(batch_error)}") | |
# Thử với batch nhỏ hơn | |
smaller_batch_size = max_batch_size // 2 | |
if smaller_batch_size >= 10: | |
logger.info(f"🔸Thử lại với batch size nhỏ hơn: {smaller_batch_size}") | |
for j in range(i, end_idx, smaller_batch_size): | |
end_j = min(j + smaller_batch_size, end_idx) | |
smaller_batch = chunks[j:end_j] | |
try: | |
vectorstore.add_documents(smaller_batch) | |
logger.info(f"🔸Đã thêm batch nhỏ từ {j} đến {end_j-1} thành công") | |
except Exception as small_batch_error: | |
logger.error(f"🔸Vẫn lỗi với batch nhỏ hơn từ {j} đến {end_j-1}: {str(small_batch_error)}") | |
else: | |
logger.error(f"🔸Batch size đã quá nhỏ, không thể giảm thêm. Bỏ qua batch này.") | |
logger.info(f"🔸Tạo Weaviate collection thành công: {collection_name}") | |
elif collection_exists: | |
logger.info(f"🔸Tải Weaviate collection đã tồn tại: {collection_name}") | |
vectorstore = WeaviateVectorStore( | |
client=client, | |
index_name=collection_name, | |
embedding=embeddings, | |
text_key="text", | |
attributes=[ # Liệt kê TẤT CẢ các metadata bạn cần để retriever hoạt động | |
"nam_ban_hanh", "title", "source", "field", "loai_van_ban", "so_hieu", | |
"ten_van_ban", "ngay_ban_hanh_str", "co_quan_ban_hanh", "entity_type", | |
# Các trường serialize thành JSON cũng cần được liệt kê nếu muốn lấy về | |
"cross_references", "penalties" | |
] | |
) | |
logger.info("🔸Tải Weaviate collection thành công.") | |
else: | |
logger.error(f"🔸Collection '{collection_name}' không tồn tại và không có dữ liệu chunks để tạo mới.") | |
return None | |
logger.info("🔸Vectorstore sẵn sàng.") | |
return vectorstore | |
except Exception as e: | |
if client: | |
client.close() | |
logger.info("🔸Đã đóng kết nối tới Weaviate.") | |
logger.error(f"🔸Lỗi khi tạo/tải Weaviate vector store: {e}") | |
return None | |
def get_google_llm(google_api_key): | |
logger.info("🔸Đang khởi tạo LLM từ Google Generative AI...") | |
if not google_api_key: | |
logger.error("🔸Google API Key không được cung cấp.") | |
return None | |
try: | |
def create_chat_google(): | |
return ChatGoogleGenerativeAI( | |
model="gemini-2.5-flash", | |
google_api_key=google_api_key, | |
temperature=0.0, # Điều chỉnh nhiệt độ nếu cần, 0.1-0.3 thường tốt cho RAG | |
safety_settings={}, | |
) | |
llm = create_chat_google() | |
logger.info("🔸Khởi tạo Google Generative AI LLM thành công.") | |
return llm | |
except Exception as e: | |
logger.error(f"🔸Lỗi khi khởi tạo Google Generative AI LLM: {e}") | |
return None | |
def create_llm_from_google_key_list(google_api_keys: List[str]): | |
""" | |
Khởi tạo một LLM duy nhất có khả năng tự động fallback qua một danh sách | |
các API key của Google. | |
Khi một key hết hạn mức (lỗi ResourceExhausted), nó sẽ tự động thử key tiếp theo. | |
:param google_api_keys: Một list chứa các chuỗi API key của Google. | |
:return: Một đối tượng LLM của LangChain, hoặc None nếu có lỗi. | |
""" | |
if not google_api_keys or not isinstance(google_api_keys, list): | |
logger.error("❌ Danh sách API key không hợp lệ hoặc bị rỗng.") | |
return None | |
logger.info(f"🔸 Đang khởi tạo chuỗi LLM từ {len(google_api_keys)} API key của Google...") | |
try: | |
# --- 1. Tạo một danh sách các instance LLM, mỗi cái với một key khác nhau --- | |
llm_instances = [ | |
ChatGoogleGenerativeAI( | |
model="gemini-2.5-flash-preview-05-20", | |
google_api_key=key, | |
temperature=0.0, # Điều chỉnh nhiệt độ nếu cần, 0.1-0.3 thường tốt cho RAG | |
safety_settings={}, | |
) | |
for key in google_api_keys | |
] | |
# --- 2. Nếu chỉ có một key, không cần fallback --- | |
if len(llm_instances) == 1: | |
logger.info("✅ Chỉ có một API key được cung cấp. Không cấu hình fallback.") | |
return llm_instances[0] | |
# --- 3. Dùng LLM đầu tiên làm LLM chính, phần còn lại làm fallback --- | |
primary_llm = llm_instances[0] | |
fallback_llms = llm_instances[1:] | |
logger.info(f"▶️ LLM chính sẽ dùng key: '...{google_api_keys[0][-4:]}'") | |
for i, llm in enumerate(fallback_llms): | |
logger.info(f"↪️ Fallback {i+1} sẽ dùng key: '...{google_api_keys[i+1][-4:]}'") | |
# --- 4. Kết hợp chúng lại --- | |
llm_with_fallbacks = primary_llm.with_fallbacks( | |
fallbacks=fallback_llms, | |
exceptions_to_handle=(ResourceExhausted, PermissionDenied) | |
) | |
logger.info("✅ Đã tạo thành công chuỗi LLM với cơ chế fallback giữa các key Google!") | |
return llm_with_fallbacks | |
except Exception as e: | |
logger.error(f"❌ Lỗi nghiêm trọng khi tạo chuỗi LLM từ danh sách key: {e}", exc_info=True) | |
return None | |
#new update | |
def _extract_and_parse_json(text_with_json: str) -> dict: | |
""" | |
Hàm trợ giúp để tìm và trích xuất khối JSON đầu tiên từ một chuỗi văn bản. | |
""" | |
json_match = re.search(r'\{.*\}', text_with_json, re.DOTALL) | |
if json_match: | |
json_str = json_match.group(0) | |
try: | |
return json.loads(json_str) | |
except json.JSONDecodeError: | |
logger.error(f"Không thể phân tích chuỗi JSON được trích xuất: {json_str}") | |
raise | |
else: | |
logger.error(f"Không tìm thấy khối JSON nào trong output: {text_with_json}") | |
raise ValueError("Không tìm thấy đối tượng JSON trong output của LLM") | |
def _extract_final_answer(rag_output_with_thinking: str) -> str: | |
""" | |
Hàm trợ giúp để trích xuất câu trả lời cuối cùng từ output của QA_PROMPT_TEMPLATE. | |
""" | |
start_tag = "[BEGIN FINAL ANSWER]" | |
end_tag = "[END FINAL ANSWER]" | |
start_index = rag_output_with_thinking.find(start_tag) | |
end_index = rag_output_with_thinking.find(end_tag) | |
if start_index != -1 and end_index != -1: | |
return rag_output_with_thinking[start_index + len(start_tag):end_index].strip() | |
logger.warning("Không tìm thấy thẻ đánh dấu trả lời trong output của RAG. Trả về toàn bộ output.") | |
return rag_output_with_thinking | |
def create_qa_chain( | |
llm: any, | |
retriever: any, | |
process_input_llm: any = None | |
): | |
""" | |
PHIÊN BẢN CUỐI CÙNG: Tạo ra một RAG chain hoàn chỉnh, có khả năng xử lý đa ngôn ngữ | |
bằng cách bảo toàn dữ liệu đầu vào gốc. | |
""" | |
if not all([llm, retriever]): | |
logger.error("🔸 Thiếu LLM hoặc Retriever chính để tạo QA Chain.") | |
return None | |
try: | |
logger.info("🔸 Bắt đầu tạo QA Chain phiên bản Tối Ưu Nhất...") | |
preprocessing_llm = process_input_llm or llm | |
# ----- 1. KHAI BÁO PROMPTS ----- | |
unified_preprocessing_prompt = ChatPromptTemplate.from_template( | |
prompt_templete.UNIFIED_PREPROCESSING_PROMPT | |
) | |
qa_rag_prompt = ChatPromptTemplate.from_template( | |
prompt_templete.QA_PROMPT_TEMPLATE | |
) | |
general_response_prompt = ChatPromptTemplate.from_template( | |
prompt_templete.GENERAL_RESPONSE_PROMPT | |
) | |
# ----- 2. ĐỊNH NGHĨA CÁC NHÁNH XỬ LÝ ----- | |
# --- Nhánh A: LEGAL RAG (Nhận dict đầy đủ) --- | |
legal_rag_chain = ( | |
RunnablePassthrough.assign( | |
context=itemgetter("rewritten_question") | retriever | |
).assign( | |
# Lấy câu trả lời đã được dọn dẹp | |
answer=( | |
qa_rag_prompt | |
| llm | |
| StrOutputParser() | |
| RunnableLambda(_extract_final_answer) | |
) | |
) | |
# Chỉ trả về 2 key quan trọng nhất | |
| (lambda x: {"answer": x["answer"], "context": x["context"]}) | |
).with_config({"run_name": "LegalRAGChain"}) | |
# --- Nhánh B: GENERAL RESPONSE (Nhận dict đầy đủ) --- | |
general_response_chain = ( | |
general_response_prompt | |
| llm | |
| StrOutputParser() | |
| (lambda answer_str: {"answer": answer_str, "context": []}) | |
).with_config({"run_name": "GeneralResponseChain"}) | |
# ----- 3. BỘ ĐỊNH TUYẾN (ROUTER) ----- | |
def route(info: dict): | |
classification = info.get("classification") | |
logger.info(f"➡️ Định tuyến truy vấn với phân loại: '{classification}'") | |
if classification == "legal_rag": | |
return legal_rag_chain | |
else: | |
return general_response_chain | |
# ----- 4. KẾT HỢP THÀNH FULL CHAIN ----- | |
# <--- THAY ĐỔI QUAN TRỌNG BẮT ĐẦU TỪ ĐÂY ---> | |
# 4.1. Chuỗi con để thực hiện tiền xử lý và trả về JSON | |
preprocessing_logic = ( | |
unified_preprocessing_prompt | |
| preprocessing_llm | |
| StrOutputParser() | |
| RunnableLambda(_extract_and_parse_json) | |
) | |
# 4.2. Xây dựng chuỗi chính để BẢO TOÀN và HỢP NHẤT dữ liệu | |
full_chain = ( | |
# Bắt đầu với một Passthrough để giữ lại dữ liệu gốc (input, chat_history) | |
RunnablePassthrough.assign( | |
# Chạy chuỗi tiền xử lý và gán kết quả của nó vào một key mới là `processed` | |
processed=preprocessing_logic | |
) | |
| RunnableLambda( | |
# Hàm này sẽ "làm phẳng" dict trên thành một dict duy nhất | |
# để các nhánh sau có thể truy cập tất cả các key | |
lambda x: { | |
"input": x["input"], | |
"chat_history": x.get("chat_history", []), | |
"classification": x["processed"]["classification"], | |
"rewritten_question": x["processed"]["rewritten_question"] | |
} | |
) | |
# 4.3. Chạy bộ định tuyến với dict đã được làm phẳng | |
| RunnableLambda( | |
# `info_dict` bây giờ chứa tất cả các key cần thiết | |
lambda info_dict: route(info_dict).invoke(info_dict) | |
) | |
) | |
# <--- THAY ĐỔI QUAN TRỌNG KẾT THÚC TẠI ĐÂY ---> | |
logger.info("✅ Tạo thành công QA Chain phiên bản TỐI ƯU NHẤT.") | |
return full_chain | |
except Exception as e: | |
logger.error(f"❌ Lỗi khi tạo QA Chain: {e}", exc_info=True) | |
return None |