# Source: AIEXP_RAG_1 / scripts/rag_engine.py
# Last commit by MrSimple07 (aa622c0): "added download link + dataset from hf"
import os
import pickle

import faiss
import pandas as pd
from llama_index.core import (
    Document,
    Settings,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.postprocessor import SimilarityPostprocessor
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import ResponseMode, get_response_synthesizer
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore

from scripts.document_processor import (
    create_llama_documents,
    load_processed_chunks,
    process_single_document,
    save_processed_chunks,
)
# Kept last, as in the original ordering, so the star import still wins on any name clash.
from scripts.config import *
def setup_llm_settings():
    """Register the HuggingFace embedding model named by EMBEDDING_MODEL as the global embed model.

    A new ``HuggingFaceEmbedding`` instance is constructed on every call and
    stored on ``Settings.embed_model``.
    """
    Settings.embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL)
def create_vector_index_with_faiss(documents, embedding_dim=384):
    """Embed ``documents`` into a new FAISS inner-product index.

    Args:
        documents: Documents to index; passed straight to
            ``VectorStoreIndex.from_documents``.
        embedding_dim: Vector dimension of the FAISS index. Defaults to 384,
            the value previously hard-coded here; it must match the output
            size of the embedding model configured by ``setup_llm_settings``.

    Returns:
        A ``(index, faiss_index)`` tuple: the llama_index vector index and the
        underlying raw FAISS index.
    """
    # The embedding model must be configured before any index objects are created.
    setup_llm_settings()

    faiss_index = faiss.IndexFlatIP(embedding_dim)
    vector_store = FaissVectorStore(faiss_index=faiss_index)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)

    index = VectorStoreIndex.from_documents(
        documents,
        storage_context=storage_context,
        # Pass the configured model object rather than a model-name string.
        embed_model=Settings.embed_model,
    )
    return index, faiss_index
def create_retriever(index):
    """Build a ``VectorIndexRetriever`` over ``index`` using the configured top-k and cutoff."""
    # NOTE(review): VectorIndexRetriever does not appear to act on
    # `similarity_cutoff` itself (it looks like an unused extra kwarg) --
    # confirm against the installed llama_index version.
    retriever_options = {
        'similarity_top_k': RETRIEVER_TOP_K,
        'similarity_cutoff': RETRIEVER_SIMILARITY_CUTOFF,
    }
    return VectorIndexRetriever(index=index, **retriever_options)
def create_enhanced_retriever(index, query_str=None):
    """Create the retriever used for score-aware querying.

    This builds exactly the same retriever as ``create_retriever``, so it
    delegates to that helper and the two cannot drift apart.

    Args:
        index: The vector index to retrieve from.
        query_str: Unused; kept only so existing callers that pass it keep
            working.

    Returns:
        The retriever produced by ``create_retriever(index)``.
    """
    return create_retriever(index)
def query_documents_with_scores(query_engine, question, relevance_threshold=0.6):
    """Run ``question`` through ``query_engine`` and attach similarity-score details.

    Args:
        query_engine: Object exposing ``query(question)`` whose result has
            ``response`` (the answer text) and ``source_nodes``.
        question: The user question.
        relevance_threshold: Minimum best-source score for the query to count
            as relevant. Defaults to 0.6, the value previously hard-coded here.

    Returns:
        A dict with keys ``original_response``, ``answer``,
        ``max_similarity_score``, ``is_query_relevant``, ``scored_sources``
        (sorted by descending score) and ``total_sources``. When the query is
        not relevant, ``answer`` is replaced with a fixed fallback message and
        ``scored_sources`` is emptied; ``total_sources`` still reports the
        number of retrieved nodes.
    """
    response = query_engine.query(question)

    scored_sources = []
    for node in response.source_nodes:
        # A node may carry `score=None`; treat it like a missing score so the
        # max() and sort below never compare None with a float.
        raw_score = getattr(node, 'score', None)
        score = 0.0 if raw_score is None else raw_score
        text = node.text
        scored_sources.append({
            'node': node,
            'score': score,
            'text_preview': text[:200] + "..." if len(text) > 200 else text,
        })

    # The reported maximum is floored at 0.0, as before.
    max_score = max([0.0] + [source['score'] for source in scored_sources])
    scored_sources.sort(key=lambda source: source['score'], reverse=True)

    is_query_relevant = max_score >= relevance_threshold

    enhanced_response = {
        'original_response': response,
        'answer': response.response,
        'max_similarity_score': max_score,
        'is_query_relevant': is_query_relevant,
        'scored_sources': scored_sources,
        'total_sources': len(scored_sources),
    }

    if not is_query_relevant:
        # Replace the generated answer and hide the irrelevant sources.
        enhanced_response['answer'] = (
            "На основе доступных нормативных документов я не могу дать точный ответ на ваш вопрос. "
            f"Максимальная релевантность найденных документов: {max_score:.2f}. "
            "Попробуйте переформулировать вопрос или быть более конкретным."
        )
        enhanced_response['scored_sources'] = []

    return enhanced_response
def format_enhanced_response_with_sources(enhanced_response):
    """Render an enhanced response dict into an answer plus a text block of sources.

    Args:
        enhanced_response: Dict with ``is_query_relevant``,
            ``max_similarity_score``, ``answer`` and, for relevant queries,
            ``scored_sources`` (entries holding ``node``, ``score`` and
            ``text_preview``).

    Returns:
        A dict with ``answer``, ``sources`` (newline-joined text),
        ``is_relevant`` and ``max_score``. At most the first five sources are
        listed.
    """
    relevant = enhanced_response['is_query_relevant']
    top_score = enhanced_response['max_similarity_score']

    if not relevant:
        lines = [
            "⚠️ Запрос имеет низкую релевантность к базе нормативных документов",
            f"🎯 Максимальная найденная релевантность: {top_score:.3f}",
            "💡 Рекомендация: Переформулируйте вопрос более конкретно",
        ]
    else:
        lines = [
            "📚 Источники из нормативной документации:",
            f"🎯 Максимальная релевантность: {top_score:.3f}",
        ]
        for position, entry in enumerate(enhanced_response['scored_sources'][:5], 1):
            metadata = entry['node'].metadata
            lines.append(f"\n{position}. Релевантность: {entry['score']:.3f}")
            lines.append(f" Документ: {metadata.get('document_id', 'Неизвестен')}")
            # Section and subsection lines appear only when the metadata value is truthy.
            section = metadata.get('section')
            if section:
                lines.append(f" Раздел: {section}")
            subsection = metadata.get('subsection')
            if subsection:
                lines.append(f" Подраздел: {subsection}")
            lines.append(f" Фрагмент: ...{entry['text_preview']}")

    return {
        'answer': enhanced_response['answer'],
        'sources': "\n".join(lines),
        'is_relevant': relevant,
        'max_score': top_score,
    }
def create_response_synthesizer():
    """Return a non-streaming response synthesizer in TREE_SUMMARIZE mode."""
    synthesizer_options = {
        'response_mode': ResponseMode.TREE_SUMMARIZE,
        'streaming': False,
    }
    return get_response_synthesizer(**synthesizer_options)
def create_query_engine(index):
    """Assemble the retriever-based query engine for ``index``.

    The engine combines the retriever from ``create_retriever``, the
    synthesizer from ``create_response_synthesizer`` and a similarity
    post-processor that enforces ``RETRIEVER_SIMILARITY_CUTOFF``.

    Args:
        index: The vector index to query.

    Returns:
        A configured ``RetrieverQueryEngine``.
    """
    retriever = create_retriever(index)
    response_synthesizer = create_response_synthesizer()
    # The retriever is not expected to apply a similarity cutoff itself; in
    # llama_index that is done by a node post-processor, so the configured
    # cutoff is enforced here before nodes reach the synthesizer.
    cutoff_filter = SimilarityPostprocessor(
        similarity_cutoff=RETRIEVER_SIMILARITY_CUTOFF
    )
    return RetrieverQueryEngine(
        retriever=retriever,
        response_synthesizer=response_synthesizer,
        node_postprocessors=[cutoff_filter],
    )
def save_rag_system(index, faiss_index, documents):
    """Write the FAISS index, storage context, documents, metadata and config to RAG_FILES_DIR.

    NOTE: a second ``save_rag_system`` is defined later in this module and
    rebinds the name at import time, so this earlier definition is never the
    one callers reach.
    """
    os.makedirs(RAG_FILES_DIR, exist_ok=True)

    def _dump(filename, payload):
        # Pickle `payload` into RAG_FILES_DIR under `filename`.
        with open(os.path.join(RAG_FILES_DIR, filename), 'wb') as handle:
            pickle.dump(payload, handle)

    faiss.write_index(faiss_index, os.path.join(RAG_FILES_DIR, 'faiss_index.index'))
    index.storage_context.persist(persist_dir=RAG_FILES_DIR)

    _dump('documents.pkl', documents)

    # Map each document id to its metadata.
    metadata_by_id = {}
    for document in documents:
        metadata_by_id[document.id_] = document.metadata
    _dump('chunk_metadata.pkl', metadata_by_id)

    _dump('config.pkl', {
        'embed_model_name': EMBEDDING_MODEL,
        'vector_dim': 384,
        'total_documents': len(documents),
        'index_type': 'faiss_flat_ip',
    })
def load_rag_system():
    """Load the persisted RAG system from RAG_FILES_DIR and return a query engine.

    Returns:
        A query engine built by ``create_query_engine``, or ``None`` when a
        required file is missing or any step of loading fails (the reason is
        printed).
    """
    required_files = (
        'faiss_index.index',
        'default__vector_store.json',
        'docstore.json',
        'index_store.json',
    )

    missing_files = [
        name for name in required_files
        if not os.path.exists(os.path.join(RAG_FILES_DIR, name))
    ]
    if missing_files:
        print(f"Missing RAG system files: {missing_files}")
        return None

    try:
        setup_llm_settings()

        faiss_index = faiss.read_index(os.path.join(RAG_FILES_DIR, 'faiss_index.index'))
        vector_store = FaissVectorStore(faiss_index=faiss_index)

        # Docstore and index store come from the persisted JSON files; the
        # vector store is the FAISS index read above.
        storage_context = StorageContext.from_defaults(
            vector_store=vector_store,
            persist_dir=RAG_FILES_DIR,
        )

        # Restore the persisted index structure. Calling
        # VectorStoreIndex.from_documents([]) here would create a new, empty
        # index struct, leaving FAISS hits with no mapping back to stored nodes.
        index = load_index_from_storage(
            storage_context,
            embed_model=Settings.embed_model,
        )

        print(f"✅ RAG system loaded with {faiss_index.ntotal} vectors")
        return create_query_engine(index)
    except Exception as e:
        # Best-effort loader: report the failure and signal it with None.
        print(f"❌ Error loading RAG system: {str(e)}")
        return None
def save_rag_system(index, faiss_index, documents):
    """Persist the RAG system to RAG_FILES_DIR.

    Writes the raw FAISS index, the llama_index storage context, pickles of
    the documents and of a ``{doc id: metadata}`` mapping, and a small config
    pickle.

    Args:
        index: The llama_index vector index whose storage context is persisted.
        faiss_index: The raw FAISS index backing ``index``.
        documents: The documents that were indexed.

    Raises:
        Exception: Any error raised while saving is printed and re-raised.
    """
    try:
        os.makedirs(RAG_FILES_DIR, exist_ok=True)

        faiss.write_index(faiss_index, os.path.join(RAG_FILES_DIR, 'faiss_index.index'))

        # Persisting the storage context produces the JSON store files that
        # load_rag_system checks for.
        index.storage_context.persist(persist_dir=RAG_FILES_DIR)

        # Documents pickle (kept for compatibility).
        with open(os.path.join(RAG_FILES_DIR, 'documents.pkl'), 'wb') as f:
            pickle.dump(documents, f)

        # Metadata pickle (kept for compatibility).
        metadata_dict = {doc.id_: doc.metadata for doc in documents}
        with open(os.path.join(RAG_FILES_DIR, 'chunk_metadata.pkl'), 'wb') as f:
            pickle.dump(metadata_dict, f)

        config = {
            'embed_model_name': EMBEDDING_MODEL,
            # Record the dimension of the index actually being saved instead
            # of a hard-coded literal.
            'vector_dim': faiss_index.d,
            'total_documents': len(documents),
            'index_type': 'faiss_flat_ip',
        }
        with open(os.path.join(RAG_FILES_DIR, 'config.pkl'), 'wb') as f:
            pickle.dump(config, f)

        print(f"✅ RAG system saved successfully with {len(documents)} documents")
    except Exception as e:
        print(f"❌ Error saving RAG system: {str(e)}")
        raise
def build_rag_system(processed_chunks):
    """Build, persist and return a query engine from ``processed_chunks``.

    The chunks are converted to documents, indexed with FAISS, wrapped in a
    query engine, and the resulting system is saved via ``save_rag_system``.
    """
    setup_llm_settings()

    llama_docs = create_llama_documents(processed_chunks)
    print(f"Created {len(llama_docs)} documents for RAG system")

    vector_index, raw_faiss_index = create_vector_index_with_faiss(llama_docs)
    engine = create_query_engine(vector_index)

    # The engine is created before saving, matching the original call order.
    save_rag_system(vector_index, raw_faiss_index, llama_docs)
    return engine
def add_new_document_to_system(file_path, existing_query_engine):
    """Process one file, merge its chunks with the stored ones and rebuild the RAG system.

    Args:
        file_path: Path of the document to add.
        existing_query_engine: Engine to fall back to.

    Returns:
        The rebuilt query engine, or ``existing_query_engine`` when the file
        yields no chunks or any step raises (the error is printed).
    """
    try:
        fresh_chunks = process_single_document(file_path)
        if not fresh_chunks:
            return existing_query_engine

        # Start from the previously stored chunks when the data file exists.
        previous_chunks = []
        if os.path.exists(PROCESSED_DATA_FILE):
            stored_frame = load_processed_chunks(PROCESSED_DATA_FILE)
            previous_chunks = stored_frame.to_dict('records')

        combined_chunks = previous_chunks + fresh_chunks
        save_processed_chunks(combined_chunks, PROCESSED_DATA_FILE)

        rebuilt_engine = build_rag_system(combined_chunks)
        print(f"Added {len(fresh_chunks)} new chunks from {os.path.basename(file_path)}")
        return rebuilt_engine
    except Exception as exc:
        print(f"Error adding new document: {exc!s}")
        return existing_query_engine
def query_documents(query_engine, question):
    """Pass ``question`` to ``query_engine.query`` and return the result unchanged."""
    return query_engine.query(question)
def get_response_sources(response):
    """Summarise each source node of ``response`` as a plain dict.

    Every entry holds a 1-based ``chunk_number``, six metadata fields (each
    defaulting to 'Не указан' when absent), a ``text_preview`` limited to 200
    characters plus "..." for longer texts, and the node's ``score`` (0.0 if
    the node has no such attribute).
    """
    placeholder = 'Не указан'
    metadata_fields = (
        'section',
        'subsection',
        'chunk_id',
        'document_id',
        'txt_file_id',
        'file_link',
    )

    summaries = []
    for number, node in enumerate(response.source_nodes, start=1):
        metadata = node.metadata
        text = node.text

        if len(text) > 200:
            preview = text[:200] + "..."
        else:
            preview = text

        summary = {'chunk_number': number}
        for field in metadata_fields:
            summary[field] = metadata.get(field, placeholder)
        summary['text_preview'] = preview
        summary['score'] = getattr(node, 'score', 0.0)
        summaries.append(summary)
    return summaries
def format_response_with_sources(response):
    """Return a dict with the answer text and the per-source summaries of ``response``."""
    answer_text = response.response
    source_summaries = get_response_sources(response)
    return {'answer': answer_text, 'sources': source_summaries}
def test_rag_system(query_engine, test_questions):
    """Run each question through the engine, print the answer and sources, and collect results.

    Returns:
        A list of ``{'question': ..., 'response': ...}`` dicts, one per
        question, where ``response`` is the output of
        ``format_response_with_sources``.
    """
    divider = "=" * 80
    collected = []

    for question in test_questions:
        print(f"Question: {question}")

        formatted = format_response_with_sources(
            query_documents(query_engine, question)
        )

        print(f"Answer: {formatted['answer']}")
        print("Sources:")
        for src in formatted['sources']:
            print(f" - Chunk {src['chunk_number']}: {src['document_id']}")
            print(f" Section: {src['section']}, Subsection: {src['subsection']}")
            print(f" Preview: {src['text_preview']}")
        print(divider)

        collected.append({'question': question, 'response': formatted})

    return collected