Commit 36eab75
Parent(s): c9db278
Update scripts/rag_engine.py

scripts/rag_engine.py  CHANGED  (+199 -199)
@@ -1,200 +1,200 @@

The diff re-commits all 200 lines, but the only substantive change is the import path on line 7:

-from document_processor import create_llama_documents, process_single_document, save_processed_chunks, load_processed_chunks
+from script.document_processor import create_llama_documents, process_single_document, save_processed_chunks, load_processed_chunks

scripts/rag_engine.py after this commit:
from llama_index.core import VectorStoreIndex, Document, StorageContext, Settings
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.response_synthesizers import ResponseMode, get_response_synthesizer
from script.document_processor import create_llama_documents, process_single_document, save_processed_chunks, load_processed_chunks
import pandas as pd
import faiss
import pickle
import os

EMBEDDING_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
RETRIEVER_TOP_K = 10
RETRIEVER_SIMILARITY_CUTOFF = 0.7
RAG_FILES_DIR = "processed_data"
PROCESSED_DATA_FILE = "processed_data/processed_chunks.csv"
def setup_llm_settings():
    embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL)
    Settings.embed_model = embed_model

def create_vector_index_with_faiss(documents):
    setup_llm_settings()

    # 384 is the embedding dimension of the MiniLM model configured above.
    d = 384
    faiss_index = faiss.IndexFlatIP(d)  # exact inner-product index
    vector_store = FaissVectorStore(faiss_index=faiss_index)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)

    index = VectorStoreIndex.from_documents(
        documents,
        storage_context=storage_context
    )

    return index, faiss_index
def create_retriever(index):
    return VectorIndexRetriever(
        index=index,
        similarity_top_k=RETRIEVER_TOP_K,
        similarity_cutoff=RETRIEVER_SIMILARITY_CUTOFF
    )

def create_response_synthesizer():
    return get_response_synthesizer(
        response_mode=ResponseMode.TREE_SUMMARIZE,
        streaming=False
    )

def create_query_engine(index):
    retriever = create_retriever(index)
    response_synthesizer = create_response_synthesizer()

    return RetrieverQueryEngine(
        retriever=retriever,
        response_synthesizer=response_synthesizer
    )
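A caveat on the retriever above: recent LlamaIndex releases do not appear to document a similarity_cutoff parameter on VectorIndexRetriever, so the value may be silently ignored; score filtering is normally applied through a node postprocessor. A minimal sketch of that variant, assuming SimilarityPostprocessor is available at this import path in the installed version:

from llama_index.core.postprocessor import SimilarityPostprocessor

def create_query_engine_with_cutoff(index):
    # Hypothetical variant: enforce the score threshold after retrieval
    # instead of passing similarity_cutoff to the retriever.
    return RetrieverQueryEngine(
        retriever=create_retriever(index),
        response_synthesizer=create_response_synthesizer(),
        node_postprocessors=[SimilarityPostprocessor(similarity_cutoff=RETRIEVER_SIMILARITY_CUTOFF)],
    )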
def save_rag_system(index, faiss_index, documents):
    os.makedirs(RAG_FILES_DIR, exist_ok=True)

    # Persist the raw FAISS index alongside the LlamaIndex storage context.
    faiss.write_index(faiss_index, os.path.join(RAG_FILES_DIR, 'faiss_index.index'))

    index.storage_context.persist(persist_dir=RAG_FILES_DIR)

    with open(os.path.join(RAG_FILES_DIR, 'documents.pkl'), 'wb') as f:
        pickle.dump(documents, f)

    metadata_dict = {}
    for doc in documents:
        metadata_dict[doc.id_] = doc.metadata

    with open(os.path.join(RAG_FILES_DIR, 'chunk_metadata.pkl'), 'wb') as f:
        pickle.dump(metadata_dict, f)

    config = {
        'embed_model_name': EMBEDDING_MODEL,
        'vector_dim': 384,
        'total_documents': len(documents),
        'index_type': 'faiss_flat_ip'
    }

    with open(os.path.join(RAG_FILES_DIR, 'config.pkl'), 'wb') as f:
        pickle.dump(config, f)
def load_rag_system():
    if not os.path.exists(os.path.join(RAG_FILES_DIR, 'faiss_index.index')):
        return None

    try:
        setup_llm_settings()

        faiss_index = faiss.read_index(os.path.join(RAG_FILES_DIR, 'faiss_index.index'))
        vector_store = FaissVectorStore(faiss_index=faiss_index)
        storage_context = StorageContext.from_defaults(vector_store=vector_store)

        index = VectorStoreIndex.from_documents([], storage_context=storage_context)

        with open(os.path.join(RAG_FILES_DIR, 'documents.pkl'), 'rb') as f:
            documents = pickle.load(f)

        # Re-inserting every pickled document re-embeds the whole corpus on load.
        for doc in documents:
            index.insert(doc)

        query_engine = create_query_engine(index)
        return query_engine

    except Exception as e:
        print(f"Error loading RAG system: {str(e)}")
        return None
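Because load_rag_system re-inserts (and therefore re-embeds) every document, startup cost grows with corpus size. LlamaIndex's documented FAISS persistence path can rehydrate the index without re-embedding; a minimal sketch, assuming FaissVectorStore.from_persist_dir and load_index_from_storage are available in the installed version:

from llama_index.core import load_index_from_storage

def load_rag_system_from_storage():
    # Hypothetical faster loader: rebuild the index from the persisted
    # storage context instead of re-embedding every document.
    setup_llm_settings()
    vector_store = FaissVectorStore.from_persist_dir(RAG_FILES_DIR)
    storage_context = StorageContext.from_defaults(
        vector_store=vector_store, persist_dir=RAG_FILES_DIR
    )
    return create_query_engine(load_index_from_storage(storage_context))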
def build_rag_system(processed_chunks):
    setup_llm_settings()

    documents = create_llama_documents(processed_chunks)
    print(f"Created {len(documents)} documents for RAG system")

    index, faiss_index = create_vector_index_with_faiss(documents)
    query_engine = create_query_engine(index)

    save_rag_system(index, faiss_index, documents)

    return query_engine
def add_new_document_to_system(file_path, existing_query_engine):
    try:
        new_chunks = process_single_document(file_path)

        if not new_chunks:
            return existing_query_engine

        if os.path.exists(PROCESSED_DATA_FILE):
            existing_df = load_processed_chunks(PROCESSED_DATA_FILE)
            existing_chunks = existing_df.to_dict('records')
        else:
            existing_chunks = []

        all_chunks = existing_chunks + new_chunks
        save_processed_chunks(all_chunks, PROCESSED_DATA_FILE)

        # Rebuilds (and re-embeds) the entire system from the merged chunk set.
        query_engine = build_rag_system(all_chunks)

        print(f"Added {len(new_chunks)} new chunks from {os.path.basename(file_path)}")
        return query_engine

    except Exception as e:
        print(f"Error adding new document: {str(e)}")
        return existing_query_engine
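Rebuilding from scratch keeps the on-disk CSV authoritative but costs a full re-embedding per added file. An incremental sketch using only functions already defined in this module (persistence is left out; a production variant would still call save_rag_system):

def add_document_incrementally(index, file_path):
    # Hypothetical variant: insert the new chunks into the live index
    # rather than rebuilding; embeds only the new document's chunks.
    new_chunks = process_single_document(file_path)
    for doc in create_llama_documents(new_chunks):
        index.insert(doc)
    return create_query_engine(index)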
def query_documents(query_engine, question):
    response = query_engine.query(question)
    return response
def get_response_sources(response):
    sources = []
    for i, node in enumerate(response.source_nodes):
        source_info = {
            'chunk_number': i + 1,
            # 'Не указан' is the Russian fallback label, "Not specified".
            'section': node.metadata.get('section', 'Не указан'),
            'subsection': node.metadata.get('subsection', 'Не указан'),
            'chunk_id': node.metadata.get('chunk_id', 'Не указан'),
            'document_id': node.metadata.get('document_id', 'Не указан'),
            'txt_file_id': node.metadata.get('txt_file_id', 'Не указан'),
            'file_link': node.metadata.get('file_link', 'Не указан'),
            'text_preview': node.text[:200] + "..." if len(node.text) > 200 else node.text,
            'score': getattr(node, 'score', 0.0)
        }
        sources.append(source_info)
    return sources

def format_response_with_sources(response):
    formatted_response = {
        'answer': response.response,
        'sources': get_response_sources(response)
    }
    return formatted_response
def test_rag_system(query_engine, test_questions):
    results = []

    for question in test_questions:
        print(f"Question: {question}")
        response = query_documents(query_engine, question)
        formatted_response = format_response_with_sources(response)

        print(f"Answer: {formatted_response['answer']}")
        print("Sources:")
        for source in formatted_response['sources']:
            print(f"  - Chunk {source['chunk_number']}: {source['document_id']}")
            print(f"    Section: {source['section']}, Subsection: {source['subsection']}")
            print(f"    Preview: {source['text_preview']}")
        print("=" * 80)

        results.append({
            'question': question,
            'response': formatted_response
        })

    return results
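End-to-end usage follows directly from the functions above; a minimal sketch, assuming processed chunks already exist on disk (the question string is a placeholder, and no names beyond this module are introduced):

if __name__ == "__main__":
    # Reuse a persisted system when present; otherwise build one from the
    # processed-chunk CSV produced by the document_processor pipeline.
    engine = load_rag_system()
    if engine is None:
        chunks = load_processed_chunks(PROCESSED_DATA_FILE).to_dict('records')
        engine = build_rag_system(chunks)

    response = query_documents(engine, "Какие разделы описывает документ?")
    result = format_response_with_sources(response)
    print(result['answer'])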