MrSimple01 commited on
Commit
36eab75
·
1 Parent(s): c9db278

Update scripts/rag_engine.py

Browse files
Files changed (1) hide show
  1. scripts/rag_engine.py +199 -199
scripts/rag_engine.py CHANGED
@@ -1,200 +1,200 @@
1
- from llama_index.core import VectorStoreIndex, Document, StorageContext, Settings
2
- from llama_index.vector_stores.faiss import FaissVectorStore
3
- from llama_index.embeddings.huggingface import HuggingFaceEmbedding
4
- from llama_index.core.query_engine import RetrieverQueryEngine
5
- from llama_index.core.retrievers import VectorIndexRetriever
6
- from llama_index.core.response_synthesizers import ResponseMode, get_response_synthesizer
7
- from document_processor import create_llama_documents, process_single_document, save_processed_chunks, load_processed_chunks
8
- import pandas as pd
9
- import faiss
10
- import pickle
11
- import os
12
-
13
- EMBEDDING_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
14
- RETRIEVER_TOP_K = 10
15
- RETRIEVER_SIMILARITY_CUTOFF = 0.7
16
- RAG_FILES_DIR = "processed_data"
17
- PROCESSED_DATA_FILE = "processed_data/processed_chunks.csv"
18
-
19
- def setup_llm_settings():
20
- embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL)
21
- Settings.embed_model = embed_model
22
-
23
- def create_vector_index_with_faiss(documents):
24
- setup_llm_settings()
25
-
26
- d = 384
27
- faiss_index = faiss.IndexFlatIP(d)
28
- vector_store = FaissVectorStore(faiss_index=faiss_index)
29
- storage_context = StorageContext.from_defaults(vector_store=vector_store)
30
-
31
- index = VectorStoreIndex.from_documents(
32
- documents,
33
- storage_context=storage_context
34
- )
35
-
36
- return index, faiss_index
37
-
38
- def create_retriever(index):
39
- return VectorIndexRetriever(
40
- index=index,
41
- similarity_top_k=RETRIEVER_TOP_K,
42
- similarity_cutoff=RETRIEVER_SIMILARITY_CUTOFF
43
- )
44
-
45
- def create_response_synthesizer():
46
- return get_response_synthesizer(
47
- response_mode=ResponseMode.TREE_SUMMARIZE,
48
- streaming=False
49
- )
50
-
51
- def create_query_engine(index):
52
- retriever = create_retriever(index)
53
- response_synthesizer = create_response_synthesizer()
54
-
55
- return RetrieverQueryEngine(
56
- retriever=retriever,
57
- response_synthesizer=response_synthesizer
58
- )
59
-
60
- def save_rag_system(index, faiss_index, documents):
61
- os.makedirs(RAG_FILES_DIR, exist_ok=True)
62
-
63
- faiss.write_index(faiss_index, os.path.join(RAG_FILES_DIR, 'faiss_index.index'))
64
-
65
- index.storage_context.persist(persist_dir=RAG_FILES_DIR)
66
-
67
- with open(os.path.join(RAG_FILES_DIR, 'documents.pkl'), 'wb') as f:
68
- pickle.dump(documents, f)
69
-
70
- metadata_dict = {}
71
- for doc in documents:
72
- metadata_dict[doc.id_] = doc.metadata
73
-
74
- with open(os.path.join(RAG_FILES_DIR, 'chunk_metadata.pkl'), 'wb') as f:
75
- pickle.dump(metadata_dict, f)
76
-
77
- config = {
78
- 'embed_model_name': EMBEDDING_MODEL,
79
- 'vector_dim': 384,
80
- 'total_documents': len(documents),
81
- 'index_type': 'faiss_flat_ip'
82
- }
83
-
84
- with open(os.path.join(RAG_FILES_DIR, 'config.pkl'), 'wb') as f:
85
- pickle.dump(config, f)
86
-
87
- def load_rag_system():
88
- if not os.path.exists(os.path.join(RAG_FILES_DIR, 'faiss_index.index')):
89
- return None
90
-
91
- try:
92
- setup_llm_settings()
93
-
94
- faiss_index = faiss.read_index(os.path.join(RAG_FILES_DIR, 'faiss_index.index'))
95
- vector_store = FaissVectorStore(faiss_index=faiss_index)
96
- storage_context = StorageContext.from_defaults(vector_store=vector_store)
97
-
98
- index = VectorStoreIndex.from_documents([], storage_context=storage_context)
99
-
100
- with open(os.path.join(RAG_FILES_DIR, 'documents.pkl'), 'rb') as f:
101
- documents = pickle.load(f)
102
-
103
- for doc in documents:
104
- index.insert(doc)
105
-
106
- query_engine = create_query_engine(index)
107
- return query_engine
108
-
109
- except Exception as e:
110
- print(f"Error loading RAG system: {str(e)}")
111
- return None
112
-
113
- def build_rag_system(processed_chunks):
114
- setup_llm_settings()
115
-
116
- documents = create_llama_documents(processed_chunks)
117
- print(f"Created {len(documents)} documents for RAG system")
118
-
119
- index, faiss_index = create_vector_index_with_faiss(documents)
120
- query_engine = create_query_engine(index)
121
-
122
- save_rag_system(index, faiss_index, documents)
123
-
124
- return query_engine
125
-
126
- def add_new_document_to_system(file_path, existing_query_engine):
127
- try:
128
- new_chunks = process_single_document(file_path)
129
-
130
- if not new_chunks:
131
- return existing_query_engine
132
-
133
- if os.path.exists(PROCESSED_DATA_FILE):
134
- existing_df = load_processed_chunks(PROCESSED_DATA_FILE)
135
- existing_chunks = existing_df.to_dict('records')
136
- else:
137
- existing_chunks = []
138
-
139
- all_chunks = existing_chunks + new_chunks
140
- save_processed_chunks(all_chunks, PROCESSED_DATA_FILE)
141
-
142
- query_engine = build_rag_system(all_chunks)
143
-
144
- print(f"Added {len(new_chunks)} new chunks from {os.path.basename(file_path)}")
145
- return query_engine
146
-
147
- except Exception as e:
148
- print(f"Error adding new document: {str(e)}")
149
- return existing_query_engine
150
-
151
- def query_documents(query_engine, question):
152
- response = query_engine.query(question)
153
- return response
154
-
155
- def get_response_sources(response):
156
- sources = []
157
- for i, node in enumerate(response.source_nodes):
158
- source_info = {
159
- 'chunk_number': i + 1,
160
- 'section': node.metadata.get('section', 'Не указан'),
161
- 'subsection': node.metadata.get('subsection', 'Не указан'),
162
- 'chunk_id': node.metadata.get('chunk_id', 'Не указан'),
163
- 'document_id': node.metadata.get('document_id', 'Не указан'),
164
- 'txt_file_id': node.metadata.get('txt_file_id', 'Не указан'),
165
- 'file_link': node.metadata.get('file_link', 'Не указан'),
166
- 'text_preview': node.text[:200] + "..." if len(node.text) > 200 else node.text,
167
- 'score': getattr(node, 'score', 0.0)
168
- }
169
- sources.append(source_info)
170
- return sources
171
-
172
- def format_response_with_sources(response):
173
- formatted_response = {
174
- 'answer': response.response,
175
- 'sources': get_response_sources(response)
176
- }
177
- return formatted_response
178
-
179
- def test_rag_system(query_engine, test_questions):
180
- results = []
181
-
182
- for question in test_questions:
183
- print(f"Question: {question}")
184
- response = query_documents(query_engine, question)
185
- formatted_response = format_response_with_sources(response)
186
-
187
- print(f"Answer: {formatted_response['answer']}")
188
- print("Sources:")
189
- for source in formatted_response['sources']:
190
- print(f" - Chunk {source['chunk_number']}: {source['document_id']}")
191
- print(f" Section: {source['section']}, Subsection: {source['subsection']}")
192
- print(f" Preview: {source['text_preview']}")
193
- print("=" * 80)
194
-
195
- results.append({
196
- 'question': question,
197
- 'response': formatted_response
198
- })
199
-
200
  return results
 
1
+ from llama_index.core import VectorStoreIndex, Document, StorageContext, Settings
2
+ from llama_index.vector_stores.faiss import FaissVectorStore
3
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
4
+ from llama_index.core.query_engine import RetrieverQueryEngine
5
+ from llama_index.core.retrievers import VectorIndexRetriever
6
+ from llama_index.core.response_synthesizers import ResponseMode, get_response_synthesizer
7
+ from script.document_processor import create_llama_documents, process_single_document, save_processed_chunks, load_processed_chunks
8
+ import pandas as pd
9
+ import faiss
10
+ import pickle
11
+ import os
12
+
13
+ EMBEDDING_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
14
+ RETRIEVER_TOP_K = 10
15
+ RETRIEVER_SIMILARITY_CUTOFF = 0.7
16
+ RAG_FILES_DIR = "processed_data"
17
+ PROCESSED_DATA_FILE = "processed_data/processed_chunks.csv"
18
+
19
def setup_llm_settings():
    """Install the multilingual HuggingFace embedding model as the global default.

    Side effect: mutates ``llama_index.core.Settings.embed_model`` so that every
    subsequently-built index/retriever embeds with EMBEDDING_MODEL.
    """
    Settings.embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL)
22
+
23
def create_vector_index_with_faiss(documents, dim=384):
    """Embed *documents* into a FAISS-backed LlamaIndex vector index.

    Improvement over the original: the embedding dimensionality was a
    hard-coded local (``d = 384``); it is now a defaulted parameter so the
    function works with other embedding models without edits. The default
    matches paraphrase-multilingual-MiniLM-L12-v2 (384-dim), keeping all
    existing callers behavior-identical.

    Args:
        documents: iterable of LlamaIndex ``Document`` objects to index.
        dim: embedding vector dimensionality for the FAISS index (default 384).

    Returns:
        Tuple ``(index, faiss_index)`` — the ``VectorStoreIndex`` and the raw
        ``faiss.IndexFlatIP`` it writes into (inner-product similarity).
    """
    setup_llm_settings()

    faiss_index = faiss.IndexFlatIP(dim)
    vector_store = FaissVectorStore(faiss_index=faiss_index)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)

    index = VectorStoreIndex.from_documents(
        documents,
        storage_context=storage_context,
    )

    return index, faiss_index
37
+
38
def create_retriever(index):
    """Build a top-k vector retriever over *index*.

    NOTE(review): ``similarity_cutoff`` is conventionally an option of
    ``SimilarityPostprocessor``, not of ``VectorIndexRetriever`` — confirm
    the retriever actually honors it rather than silently ignoring it.
    """
    retriever_config = {
        "index": index,
        "similarity_top_k": RETRIEVER_TOP_K,
        "similarity_cutoff": RETRIEVER_SIMILARITY_CUTOFF,
    }
    return VectorIndexRetriever(**retriever_config)
44
+
45
def create_response_synthesizer():
    """Return a non-streaming, tree-summarize response synthesizer."""
    return get_response_synthesizer(
        streaming=False,
        response_mode=ResponseMode.TREE_SUMMARIZE,
    )
50
+
51
def create_query_engine(index):
    """Wire a retriever and a response synthesizer into a RetrieverQueryEngine
    over *index*."""
    return RetrieverQueryEngine(
        retriever=create_retriever(index),
        response_synthesizer=create_response_synthesizer(),
    )
59
+
60
def save_rag_system(index, faiss_index, documents):
    """Persist the full RAG state under RAG_FILES_DIR.

    Writes, in order: the raw FAISS index (``faiss_index.index``), the
    LlamaIndex storage context, the pickled documents (``documents.pkl``),
    a doc-id -> metadata map (``chunk_metadata.pkl``), and a small config
    dict (``config.pkl``).
    """
    os.makedirs(RAG_FILES_DIR, exist_ok=True)

    def _pickle_to(filename, payload):
        # Helper: pickle *payload* into RAG_FILES_DIR/filename.
        with open(os.path.join(RAG_FILES_DIR, filename), 'wb') as fh:
            pickle.dump(payload, fh)

    faiss.write_index(faiss_index, os.path.join(RAG_FILES_DIR, 'faiss_index.index'))
    index.storage_context.persist(persist_dir=RAG_FILES_DIR)

    _pickle_to('documents.pkl', documents)
    _pickle_to('chunk_metadata.pkl', {doc.id_: doc.metadata for doc in documents})
    _pickle_to('config.pkl', {
        'embed_model_name': EMBEDDING_MODEL,
        'vector_dim': 384,
        'total_documents': len(documents),
        'index_type': 'faiss_flat_ip',
    })
86
+
87
def load_rag_system():
    """Restore a query engine from a previously saved RAG system.

    Returns the query engine, or ``None`` when no saved system exists or
    loading fails (the error is printed, not raised).

    BUG FIX: the original loaded the persisted (already-populated) FAISS index
    and then called ``index.insert(doc)`` for every pickled document, which
    re-embeds each document and APPENDS its vector to the loaded index —
    duplicating every vector on each load. We instead rebuild a fresh index
    from the pickled documents via create_vector_index_with_faiss, so the
    in-memory index holds each chunk exactly once.

    NOTE(review): documents.pkl is loaded with pickle — only safe for files
    this application wrote itself.
    """
    if not os.path.exists(os.path.join(RAG_FILES_DIR, 'faiss_index.index')):
        return None

    try:
        setup_llm_settings()

        with open(os.path.join(RAG_FILES_DIR, 'documents.pkl'), 'rb') as f:
            documents = pickle.load(f)

        # Rebuild into a fresh FAISS index instead of inserting into the
        # persisted one (which would duplicate vectors).
        index, _faiss_index = create_vector_index_with_faiss(documents)

        return create_query_engine(index)

    except Exception as e:
        print(f"Error loading RAG system: {str(e)}")
        return None
112
+
113
def build_rag_system(processed_chunks):
    """Build, persist, and return a query engine from processed chunk records."""
    setup_llm_settings()

    docs = create_llama_documents(processed_chunks)
    print(f"Created {len(docs)} documents for RAG system")

    index, faiss_index = create_vector_index_with_faiss(docs)
    engine = create_query_engine(index)

    save_rag_system(index, faiss_index, docs)

    return engine
125
+
126
def add_new_document_to_system(file_path, existing_query_engine):
    """Chunk *file_path*, merge its chunks with the saved ones, and rebuild
    the whole RAG system.

    Returns the rebuilt query engine; on failure or when the file yields no
    chunks, returns *existing_query_engine* unchanged (errors are printed).
    """
    try:
        fresh_chunks = process_single_document(file_path)

        if not fresh_chunks:
            return existing_query_engine

        previous_chunks = []
        if os.path.exists(PROCESSED_DATA_FILE):
            previous_chunks = load_processed_chunks(PROCESSED_DATA_FILE).to_dict('records')

        combined = previous_chunks + fresh_chunks
        save_processed_chunks(combined, PROCESSED_DATA_FILE)

        rebuilt_engine = build_rag_system(combined)

        print(f"Added {len(fresh_chunks)} new chunks from {os.path.basename(file_path)}")
        return rebuilt_engine

    except Exception as e:
        print(f"Error adding new document: {str(e)}")
        return existing_query_engine
150
+
151
def query_documents(query_engine, question):
    """Run *question* through *query_engine* and return its raw response."""
    return query_engine.query(question)
154
+
155
def get_response_sources(response):
    """Summarize a query response's source nodes as plain dicts.

    Each dict carries a 1-based chunk number, the node's section/id metadata
    (falling back to 'Не указан' when a key is absent), a text preview capped
    at 200 characters (with '...' appended when truncated), and the node's
    retrieval score (0.0 when missing).
    """
    summaries = []
    for position, node in enumerate(response.source_nodes, start=1):
        meta = node.metadata
        text = node.text
        preview = text if len(text) <= 200 else text[:200] + "..."
        summaries.append({
            'chunk_number': position,
            'section': meta.get('section', 'Не указан'),
            'subsection': meta.get('subsection', 'Не указан'),
            'chunk_id': meta.get('chunk_id', 'Не указан'),
            'document_id': meta.get('document_id', 'Не указан'),
            'txt_file_id': meta.get('txt_file_id', 'Не указан'),
            'file_link': meta.get('file_link', 'Не указан'),
            'text_preview': preview,
            'score': getattr(node, 'score', 0.0),
        })
    return summaries
171
+
172
def format_response_with_sources(response):
    """Package a query response as ``{'answer': ..., 'sources': [...]}``."""
    return {
        'answer': response.response,
        'sources': get_response_sources(response),
    }
178
+
179
def test_rag_system(query_engine, test_questions):
    """Run each question through *query_engine*, printing answer and sources.

    Returns a list of ``{'question': ..., 'response': ...}`` dicts, where
    'response' is the formatted answer+sources payload.
    """
    collected = []

    for question in test_questions:
        print(f"Question: {question}")
        formatted = format_response_with_sources(
            query_documents(query_engine, question)
        )

        print(f"Answer: {formatted['answer']}")
        print("Sources:")
        for src in formatted['sources']:
            print(f" - Chunk {src['chunk_number']}: {src['document_id']}")
            print(f" Section: {src['section']}, Subsection: {src['subsection']}")
            print(f" Preview: {src['text_preview']}")
        print("=" * 80)

        collected.append({
            'question': question,
            'response': formatted,
        })

    return collected