AIEXP_RAG_1 / scripts /rag_engine.py
MrSimple07's picture
fixed config + added all the necessary files
1b98e0e
raw
history blame
6.8 kB
from llama_index.core import VectorStoreIndex, Document, StorageContext, Settings
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.response_synthesizers import ResponseMode, get_response_synthesizer
from scripts.document_processor import create_llama_documents, process_single_document, save_processed_chunks, load_processed_chunks
import pandas as pd
import faiss
import pickle
import os
from scripts.config import *
def setup_llm_settings():
    """Install the project's embedding model as the global LlamaIndex default.

    Builds a ``HuggingFaceEmbedding`` for the configured ``EMBEDDING_MODEL``
    and assigns it to ``Settings.embed_model``. Returns nothing.
    """
    Settings.embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL)
def create_vector_index_with_faiss(documents, embedding_dim=384):
    """Embed ``documents`` into a new FAISS-backed ``VectorStoreIndex``.

    Args:
        documents: LlamaIndex documents to embed and index.
        embedding_dim: Dimensionality of the FAISS index. Defaults to 384,
            the value previously hard-coded here. It has to equal the size of
            the vectors produced by the configured embedding model, so pass a
            different value when ``EMBEDDING_MODEL`` changes.

    Returns:
        A ``(index, faiss_index)`` tuple: the LlamaIndex vector index and the
        raw FAISS inner-product index that backs it.
    """
    # Make sure the global embedding model is configured before embedding.
    setup_llm_settings()

    faiss_index = faiss.IndexFlatIP(embedding_dim)
    storage_context = StorageContext.from_defaults(
        vector_store=FaissVectorStore(faiss_index=faiss_index)
    )
    index = VectorStoreIndex.from_documents(
        documents,
        storage_context=storage_context,
    )
    return index, faiss_index
def create_retriever(index):
    """Build the vector retriever that fetches candidate chunks from ``index``.

    The number of results and the similarity cutoff come from the
    ``RETRIEVER_TOP_K`` and ``RETRIEVER_SIMILARITY_CUTOFF`` config values.
    """
    retriever_options = {
        "similarity_top_k": RETRIEVER_TOP_K,
        # NOTE(review): confirm that VectorIndexRetriever actually applies
        # this keyword; it may be accepted without filtering any results.
        "similarity_cutoff": RETRIEVER_SIMILARITY_CUTOFF,
    }
    return VectorIndexRetriever(index=index, **retriever_options)
def create_response_synthesizer():
    """Return a non-streaming response synthesizer in TREE_SUMMARIZE mode."""
    synthesizer = get_response_synthesizer(
        streaming=False,
        response_mode=ResponseMode.TREE_SUMMARIZE,
    )
    return synthesizer
def create_query_engine(index):
    """Assemble the retriever + synthesizer query engine for ``index``.

    The configured ``RETRIEVER_SIMILARITY_CUTOFF`` is enforced here with a
    ``SimilarityPostprocessor``, so retrieved nodes scoring below the cutoff
    are dropped before the answer is synthesized. When the cutoff is ``None``
    no postprocessor is installed and every retrieved node is kept.

    Args:
        index: The ``VectorStoreIndex`` to query.

    Returns:
        A ``RetrieverQueryEngine`` ready to answer questions.
    """
    # Function-scope import: the file's top-level imports do not provide it.
    from llama_index.core.postprocessor import SimilarityPostprocessor

    node_postprocessors = []
    if RETRIEVER_SIMILARITY_CUTOFF is not None:
        node_postprocessors.append(
            SimilarityPostprocessor(similarity_cutoff=RETRIEVER_SIMILARITY_CUTOFF)
        )

    return RetrieverQueryEngine(
        retriever=create_retriever(index),
        response_synthesizer=create_response_synthesizer(),
        node_postprocessors=node_postprocessors,
    )
def save_rag_system(index, faiss_index, documents):
    """Persist the RAG system to ``RAG_FILES_DIR``.

    Writes the raw FAISS index (``faiss_index.index``), the LlamaIndex storage
    context, the pickled documents (``documents.pkl``), a mapping from
    document id to metadata (``chunk_metadata.pkl``) and a small description
    of the index (``config.pkl``).

    Args:
        index: The ``VectorStoreIndex`` whose storage context is persisted.
        faiss_index: The FAISS index backing ``index``.
        documents: The documents that were indexed.
    """
    os.makedirs(RAG_FILES_DIR, exist_ok=True)

    def _dump(obj, filename):
        # Pickle ``obj`` into RAG_FILES_DIR under ``filename``.
        with open(os.path.join(RAG_FILES_DIR, filename), 'wb') as f:
            pickle.dump(obj, f)

    faiss.write_index(faiss_index, os.path.join(RAG_FILES_DIR, 'faiss_index.index'))
    index.storage_context.persist(persist_dir=RAG_FILES_DIR)

    _dump(documents, 'documents.pkl')
    _dump({doc.id_: doc.metadata for doc in documents}, 'chunk_metadata.pkl')
    _dump(
        {
            'embed_model_name': EMBEDDING_MODEL,
            # Record the dimension of the index actually being saved rather
            # than a fixed number, so the config cannot drift from the data.
            'vector_dim': faiss_index.d,
            'total_documents': len(documents),
            'index_type': 'faiss_flat_ip',
        },
        'config.pkl',
    )
def load_rag_system():
    """Load the persisted RAG system and return a query engine for it.

    Returns:
        A query engine built on the persisted index, or ``None`` when no
        ``faiss_index.index`` file exists in ``RAG_FILES_DIR`` or when loading
        fails (the error is printed).
    """
    faiss_path = os.path.join(RAG_FILES_DIR, 'faiss_index.index')
    if not os.path.exists(faiss_path):
        return None

    try:
        # Function-scope import: the file's top-level imports do not provide it.
        from llama_index.core import load_index_from_storage

        setup_llm_settings()
        vector_store = FaissVectorStore(faiss_index=faiss.read_index(faiss_path))

        # Reuse the docstore and index store persisted next to the FAISS file
        # instead of re-inserting the pickled documents: the FAISS index
        # already holds their vectors, so inserting again would re-embed
        # every chunk and append a duplicate of each vector.
        storage_context = StorageContext.from_defaults(
            vector_store=vector_store,
            persist_dir=RAG_FILES_DIR,
        )
        index = load_index_from_storage(storage_context)
        return create_query_engine(index)
    except Exception as e:
        # Best-effort load: report the problem and let the caller rebuild.
        print(f"Error loading RAG system: {str(e)}")
        return None
def build_rag_system(processed_chunks):
    """Build, persist and return a query engine from processed chunks.

    Converts ``processed_chunks`` to LlamaIndex documents, indexes them in
    FAISS, creates the query engine and saves everything to disk.
    """
    setup_llm_settings()

    docs = create_llama_documents(processed_chunks)
    print(f"Created {len(docs)} documents for RAG system")

    vector_index, raw_faiss_index = create_vector_index_with_faiss(docs)
    engine = create_query_engine(vector_index)
    save_rag_system(vector_index, raw_faiss_index, docs)
    return engine
def add_new_document_to_system(file_path, existing_query_engine):
    """Process one new file and rebuild the RAG system with its chunks added.

    The new chunks are appended to those stored in ``PROCESSED_DATA_FILE``
    (if that file exists), the combined list is saved back, and the whole
    system is rebuilt from it.

    Returns:
        The rebuilt query engine, or ``existing_query_engine`` unchanged when
        the file yields no chunks or any step raises (the error is printed).
    """
    try:
        fresh_chunks = process_single_document(file_path)
        if not fresh_chunks:
            return existing_query_engine

        previous_chunks = (
            load_processed_chunks(PROCESSED_DATA_FILE).to_dict('records')
            if os.path.exists(PROCESSED_DATA_FILE)
            else []
        )

        combined_chunks = previous_chunks + fresh_chunks
        save_processed_chunks(combined_chunks, PROCESSED_DATA_FILE)
        rebuilt_engine = build_rag_system(combined_chunks)
        print(f"Added {len(fresh_chunks)} new chunks from {os.path.basename(file_path)}")
        return rebuilt_engine
    except Exception as e:
        print(f"Error adding new document: {str(e)}")
        return existing_query_engine
def query_documents(query_engine, question):
    """Run ``question`` through ``query_engine`` and return its response."""
    return query_engine.query(question)
def get_response_sources(response):
    """Summarize the source nodes of a query response as plain dictionaries.

    Args:
        response: An object exposing ``source_nodes``; each node must provide
            ``metadata`` (a dict) and ``text``, and may provide ``score``.

    Returns:
        A list with one dict per source node, in order, containing the
        1-based ``chunk_number``, the metadata fields ``section``,
        ``subsection``, ``chunk_id``, ``document_id``, ``txt_file_id`` and
        ``file_link`` (each falling back to a placeholder when absent), a
        ``text_preview`` limited to 200 characters plus ``"..."`` when
        truncated, and a numeric ``score``.
    """
    placeholder = 'Не указан'
    metadata_fields = (
        'section',
        'subsection',
        'chunk_id',
        'document_id',
        'txt_file_id',
        'file_link',
    )

    sources = []
    for chunk_number, node in enumerate(response.source_nodes, start=1):
        text = node.text
        score = getattr(node, 'score', None)

        entry = {'chunk_number': chunk_number}
        for field in metadata_fields:
            entry[field] = node.metadata.get(field, placeholder)
        entry['text_preview'] = text[:200] + "..." if len(text) > 200 else text
        # A node can carry an explicit ``score=None``; report 0.0 in that case
        # too, not only when the attribute is missing altogether.
        entry['score'] = 0.0 if score is None else score
        sources.append(entry)
    return sources
def format_response_with_sources(response):
    """Package a query response as ``{'answer': ..., 'sources': ...}``.

    ``answer`` is the response text (``response.response``) and ``sources``
    is the list produced by ``get_response_sources``.
    """
    answer_text = response.response
    source_list = get_response_sources(response)
    return {'answer': answer_text, 'sources': source_list}
def test_rag_system(query_engine, test_questions):
    """Ask every test question, print the answers with sources, collect results.

    Args:
        query_engine: Engine used to answer the questions.
        test_questions: Iterable of question strings.

    Returns:
        A list of ``{'question': ..., 'response': ...}`` dicts, one per
        question, where ``response`` is the formatted answer with sources.
    """
    separator = "=" * 80
    collected = []
    for question in test_questions:
        print(f"Question: {question}")
        formatted = format_response_with_sources(
            query_documents(query_engine, question)
        )
        print(f"Answer: {formatted['answer']}")
        print("Sources:")
        for src in formatted['sources']:
            print(f" - Chunk {src['chunk_number']}: {src['document_id']}")
            print(f" Section: {src['section']}, Subsection: {src['subsection']}")
            print(f" Preview: {src['text_preview']}")
        print(separator)
        collected.append({'question': question, 'response': formatted})
    return collected