from llama_index.core import (
    VectorStoreIndex,
    Document,
    StorageContext,
    Settings,
    load_index_from_storage,
)
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.response_synthesizers import ResponseMode, get_response_synthesizer
from llama_index.core.postprocessor import SimilarityPostprocessor
from scripts.document_processor import create_llama_documents, process_single_document, save_processed_chunks, load_processed_chunks
import pandas as pd
import faiss
import pickle
import os
from scripts.config import *

def setup_llm_settings():
    """Register the HuggingFace embedding model as the global default."""
    embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL)
    Settings.embed_model = embed_model

def create_vector_index_with_faiss(documents):
    # Set up Settings FIRST, before creating any objects
    setup_llm_settings()
    # Embedding dimension; must match EMBEDDING_MODEL
    # (384 fits MiniLM-class sentence-transformers models)
    d = 384
    faiss_index = faiss.IndexFlatIP(d)
    vector_store = FaissVectorStore(faiss_index=faiss_index)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    # Use the embedding model from Settings
    index = VectorStoreIndex.from_documents(
        documents,
        storage_context=storage_context,
        embed_model=Settings.embed_model  # pass the model object, not a string
    )
    return index, faiss_index

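# Scoring note (an observation, not from the original code): IndexFlatIP ranks
# by raw inner product, so retrieval scores are true cosine similarities only
# if the embedding model emits L2-normalized vectors. If EMBEDDING_MODEL does
# not, one option is to normalize vectors (e.g. with faiss.normalize_L2)
# before they are added to and queried against the index.
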
def create_retriever(index):
    # Note: VectorIndexRetriever has no similarity_cutoff argument; the score
    # cutoff is applied via SimilarityPostprocessor in create_query_engine.
    return VectorIndexRetriever(
        index=index,
        similarity_top_k=RETRIEVER_TOP_K
    )

def create_enhanced_retriever(index, query_str=None):
    """Create retriever with score transparency.

    query_str is currently unused and kept for interface compatibility;
    this simply delegates to create_retriever.
    """
    return create_retriever(index)

def query_documents_with_scores(query_engine, question):
    response = query_engine.query(question)
    # Extract scores from source nodes
    scored_sources = []
    max_score = 0.0
    for node in response.source_nodes:
        # node.score can be None for some retrievers; treat that as 0.0
        score = node.score if node.score is not None else 0.0
        max_score = max(max_score, score)
        scored_sources.append({
            'node': node,
            'score': score,
            'text_preview': node.text[:200] + "..." if len(node.text) > 200 else node.text
        })
    scored_sources.sort(key=lambda x: x['score'], reverse=True)
    QUERY_RELEVANCE_THRESHOLD = 0.6
    is_query_relevant = max_score >= QUERY_RELEVANCE_THRESHOLD
    # Enhanced response object
    enhanced_response = {
        'original_response': response,
        'answer': response.response,
        'max_similarity_score': max_score,
        'is_query_relevant': is_query_relevant,
        'scored_sources': scored_sources,
        'total_sources': len(scored_sources)
    }
    # If the query is not relevant, replace the answer
    if not is_query_relevant:
        enhanced_response['answer'] = (
            "Based on the available regulatory documents, I cannot give a precise answer to your question. "
            f"Maximum relevance of the retrieved documents: {max_score:.2f}. "
            "Try rephrasing the question or being more specific."
        )
        enhanced_response['scored_sources'] = []  # Don't show irrelevant sources
    return enhanced_response

def format_enhanced_response_with_sources(enhanced_response):
    """Format response with detailed scoring info"""
    sources_info = []
    if enhanced_response['is_query_relevant']:
        sources_info.append("📚 Sources from the regulatory documentation:")
        sources_info.append(f"🎯 Maximum relevance: {enhanced_response['max_similarity_score']:.3f}")
        for i, source_data in enumerate(enhanced_response['scored_sources'][:5], 1):
            node = source_data['node']
            score = source_data['score']
            sources_info.append(f"\n{i}. Relevance: {score:.3f}")
            sources_info.append(f"   Document: {node.metadata.get('document_id', 'Unknown')}")
            if node.metadata.get('section'):
                sources_info.append(f"   Section: {node.metadata.get('section')}")
            if node.metadata.get('subsection'):
                sources_info.append(f"   Subsection: {node.metadata.get('subsection')}")
            sources_info.append(f"   Excerpt: ...{source_data['text_preview']}")
    else:
        sources_info.append("⚠️ The query has low relevance to the regulatory document base")
        sources_info.append(f"🎯 Maximum relevance found: {enhanced_response['max_similarity_score']:.3f}")
        sources_info.append("💡 Recommendation: rephrase the question more specifically")
    return {
        'answer': enhanced_response['answer'],
        'sources': "\n".join(sources_info),
        'is_relevant': enhanced_response['is_query_relevant'],
        'max_score': enhanced_response['max_similarity_score']
    }

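# Usage sketch (a hypothetical helper, not part of the original pipeline):
# chains query_documents_with_scores with the formatter above.
def demo_scored_query(query_engine, question):
    """Run one scored query and print the formatted result."""
    enhanced = query_documents_with_scores(query_engine, question)
    formatted = format_enhanced_response_with_sources(enhanced)
    print(formatted['answer'])
    print(formatted['sources'])
    print(f"Relevant: {formatted['is_relevant']}, max score: {formatted['max_score']:.3f}")
    return formatted
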
def create_response_synthesizer():
    return get_response_synthesizer(
        response_mode=ResponseMode.TREE_SUMMARIZE,
        streaming=False
    )

def create_query_engine(index):
    retriever = create_retriever(index)
    response_synthesizer = create_response_synthesizer()
    return RetrieverQueryEngine(
        retriever=retriever,
        response_synthesizer=response_synthesizer,
        # Apply the similarity cutoff here; VectorIndexRetriever itself
        # does not filter by score.
        node_postprocessors=[SimilarityPostprocessor(similarity_cutoff=RETRIEVER_SIMILARITY_CUTOFF)]
    )

def load_rag_system():
    """Load RAG system with better error handling and file verification"""
    required_files = [
        'faiss_index.index',
        'default__vector_store.json',
        'docstore.json',
        'index_store.json'
    ]
    # Check that all required files exist
    missing_files = [
        file for file in required_files
        if not os.path.exists(os.path.join(RAG_FILES_DIR, file))
    ]
    if missing_files:
        print(f"Missing RAG system files: {missing_files}")
        return None
    try:
        setup_llm_settings()
        # Load FAISS index
        faiss_index = faiss.read_index(os.path.join(RAG_FILES_DIR, 'faiss_index.index'))
        vector_store = FaissVectorStore(faiss_index=faiss_index)
        # Load storage context from persisted files
        storage_context = StorageContext.from_defaults(
            vector_store=vector_store,
            persist_dir=RAG_FILES_DIR
        )
        # Restore the persisted index. VectorStoreIndex.from_documents([])
        # would create a new, empty index; load_index_from_storage rebuilds
        # the one written by save_rag_system.
        index = load_index_from_storage(
            storage_context,
            embed_model=Settings.embed_model
        )
        # Verify the index loaded correctly
        print(f"✅ RAG system loaded with {faiss_index.ntotal} vectors")
        query_engine = create_query_engine(index)
        return query_engine
    except Exception as e:
        print(f"❌ Error loading RAG system: {str(e)}")
        return None

def save_rag_system(index, faiss_index, documents):
    """Enhanced save function with verification"""
    try:
        os.makedirs(RAG_FILES_DIR, exist_ok=True)
        # Save FAISS index
        faiss.write_index(faiss_index, os.path.join(RAG_FILES_DIR, 'faiss_index.index'))
        # Persist storage context (writes docstore.json, index_store.json, default__vector_store.json)
        index.storage_context.persist(persist_dir=RAG_FILES_DIR)
        # Save documents pickle (for compatibility)
        with open(os.path.join(RAG_FILES_DIR, 'documents.pkl'), 'wb') as f:
            pickle.dump(documents, f)
        # Save metadata pickle (for compatibility)
        metadata_dict = {doc.id_: doc.metadata for doc in documents}
        with open(os.path.join(RAG_FILES_DIR, 'chunk_metadata.pkl'), 'wb') as f:
            pickle.dump(metadata_dict, f)
        # Save config
        config = {
            'embed_model_name': EMBEDDING_MODEL,
            'vector_dim': 384,
            'total_documents': len(documents),
            'index_type': 'faiss_flat_ip'
        }
        with open(os.path.join(RAG_FILES_DIR, 'config.pkl'), 'wb') as f:
            pickle.dump(config, f)
        print(f"✅ RAG system saved successfully with {len(documents)} documents")
    except Exception as e:
        print(f"❌ Error saving RAG system: {str(e)}")
        raise

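# Persistence round-trip (assumed typical usage): once save_rag_system has
# written RAG_FILES_DIR, a fresh process can restore the engine without
# re-embedding anything:
#   engine = load_rag_system()
#   if engine is not None:
#       print(query_documents(engine, "...").response)
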
def build_rag_system(processed_chunks):
    setup_llm_settings()
    documents = create_llama_documents(processed_chunks)
    print(f"Created {len(documents)} documents for RAG system")
    index, faiss_index = create_vector_index_with_faiss(documents)
    query_engine = create_query_engine(index)
    save_rag_system(index, faiss_index, documents)
    return query_engine

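# End-to-end sketch (assumes PROCESSED_DATA_FILE already holds chunks produced
# by scripts.document_processor):
#   chunks = load_processed_chunks(PROCESSED_DATA_FILE).to_dict('records')
#   engine = build_rag_system(chunks)
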
def add_new_document_to_system(file_path, existing_query_engine):
    try:
        new_chunks = process_single_document(file_path)
        if not new_chunks:
            return existing_query_engine
        if os.path.exists(PROCESSED_DATA_FILE):
            existing_df = load_processed_chunks(PROCESSED_DATA_FILE)
            existing_chunks = existing_df.to_dict('records')
        else:
            existing_chunks = []
        all_chunks = existing_chunks + new_chunks
        save_processed_chunks(all_chunks, PROCESSED_DATA_FILE)
        query_engine = build_rag_system(all_chunks)
        print(f"Added {len(new_chunks)} new chunks from {os.path.basename(file_path)}")
        return query_engine
    except Exception as e:
        print(f"Error adding new document: {str(e)}")
        return existing_query_engine

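# Incremental-update sketch (the file path is a placeholder): the whole index
# is rebuilt from the merged chunk set, so the returned engine replaces the
# old one.
#   engine = add_new_document_to_system("new_regulation.pdf", engine)
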
def query_documents(query_engine, question):
    response = query_engine.query(question)
    return response

def get_response_sources(response):
    sources = []
    for i, node in enumerate(response.source_nodes):
        source_info = {
            'chunk_number': i + 1,
            'section': node.metadata.get('section', 'Not specified'),
            'subsection': node.metadata.get('subsection', 'Not specified'),
            'chunk_id': node.metadata.get('chunk_id', 'Not specified'),
            'document_id': node.metadata.get('document_id', 'Not specified'),
            'txt_file_id': node.metadata.get('txt_file_id', 'Not specified'),
            'file_link': node.metadata.get('file_link', 'Not specified'),
            'text_preview': node.text[:200] + "..." if len(node.text) > 200 else node.text,
            # node.score can be None; treat that as 0.0
            'score': node.score if node.score is not None else 0.0
        }
        sources.append(source_info)
    return sources

def format_response_with_sources(response):
    formatted_response = {
        'answer': response.response,
        'sources': get_response_sources(response)
    }
    return formatted_response

def test_rag_system(query_engine, test_questions):
    results = []
    for question in test_questions:
        print(f"Question: {question}")
        response = query_documents(query_engine, question)
        formatted_response = format_response_with_sources(response)
        print(f"Answer: {formatted_response['answer']}")
        print("Sources:")
        for source in formatted_response['sources']:
            print(f"  - Chunk {source['chunk_number']}: {source['document_id']}")
            print(f"    Section: {source['section']}, Subsection: {source['subsection']}")
            print(f"    Preview: {source['text_preview']}")
        print("=" * 80)
        results.append({
            'question': question,
            'response': formatted_response
        })
    return results
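
if __name__ == "__main__":
    # Minimal manual smoke test (a sketch; the sample question is a
    # placeholder, not taken from the original project).
    engine = load_rag_system()
    if engine is None:
        print("No persisted RAG system found; run build_rag_system first.")
    else:
        test_rag_system(engine, ["What is the scope of this standard?"])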