ashhal commited on
Commit
20f2352
·
verified ·
1 Parent(s): 1cb30a0

Create utils/rag_system.py

Browse files
Files changed (1) hide show
  1. utils/rag_system.py +322 -0
utils/rag_system.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Enhanced RAG (Retrieval-Augmented Generation) System
3
+ for Power Systems Knowledge Base
4
+ """
5
+
6
+ import json
7
+ import re
8
+ from typing import Dict, List, Tuple, Optional
9
+ import pandas as pd
10
+ from datetime import datetime
11
+ import os
12
+
13
+ class EnhancedRAGSystem:
14
+ """
15
+ Advanced RAG system with semantic search and context ranking
16
+ """
17
+
18
+ def __init__(self, knowledge_base_path: str = 'data/knowledge_base.json'):
19
+ self.knowledge_base_path = knowledge_base_path
20
+ self.knowledge_base = self.load_knowledge_base()
21
+ self.indexed_content = self.create_search_index()
22
+
23
+ def load_knowledge_base(self) -> Dict:
24
+ """Load the power systems knowledge base"""
25
+ try:
26
+ with open(self.knowledge_base_path, 'r', encoding='utf-8') as f:
27
+ return json.load(f)
28
+ except FileNotFoundError:
29
+ print(f"Knowledge base not found at {self.knowledge_base_path}")
30
+ return self.get_fallback_knowledge_base()
31
+
32
+ def get_fallback_knowledge_base(self) -> Dict:
33
+ """Fallback knowledge base if file is not found"""
34
+ return {
35
+ "faults": {
36
+ "symmetrical": "Three-phase faults with balanced conditions",
37
+ "unsymmetrical": "Single-phase or two-phase faults"
38
+ },
39
+ "protection": {
40
+ "overcurrent": "Current-based protection schemes",
41
+ "differential": "Current comparison protection"
42
+ }
43
+ }
44
+
45
+ def create_search_index(self) -> List[Dict]:
46
+ """Create searchable index from knowledge base"""
47
+ indexed_items = []
48
+
49
+ def index_recursive(data, path="", category=""):
50
+ if isinstance(data, dict):
51
+ for key, value in data.items():
52
+ current_path = f"{path}.{key}" if path else key
53
+ current_category = category or key
54
+
55
+ if isinstance(value, (str, int, float)):
56
+ indexed_items.append({
57
+ 'path': current_path,
58
+ 'category': current_category,
59
+ 'key': key,
60
+ 'content': str(value),
61
+ 'keywords': self.extract_keywords(f"{key} {value}")
62
+ })
63
+ else:
64
+ index_recursive(value, current_path, current_category)
65
+ elif isinstance(data, list):
66
+ for i, item in enumerate(data):
67
+ index_recursive(item, f"{path}[{i}]", category)
68
+
69
+ index_recursive(self.knowledge_base)
70
+ return indexed_items
71
+
72
+ def extract_keywords(self, text: str) -> List[str]:
73
+ """Extract keywords from text for better matching"""
74
+ # Convert to lowercase and split
75
+ words = re.findall(r'\b\w+\b', text.lower())
76
+
77
+ # Remove common stop words
78
+ stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at',
79
+ 'to', 'for', 'of', 'with', 'by', 'is', 'are', 'was', 'were'}
80
+
81
+ keywords = [word for word in words if word not in stop_words and len(word) > 2]
82
+ return keywords
83
+
84
+ def semantic_search(self, query: str, top_k: int = 5) -> List[Dict]:
85
+ """Perform semantic search on the knowledge base"""
86
+ query_keywords = self.extract_keywords(query)
87
+ scored_results = []
88
+
89
+ for item in self.indexed_content:
90
+ score = self.calculate_relevance_score(query_keywords, item)
91
+ if score > 0:
92
+ scored_results.append({
93
+ **item,
94
+ 'relevance_score': score,
95
+ 'matched_keywords': self.get_matched_keywords(query_keywords, item['keywords'])
96
+ })
97
+
98
+ # Sort by relevance score
99
+ scored_results.sort(key=lambda x: x['relevance_score'], reverse=True)
100
+ return scored_results[:top_k]
101
+
102
+ def calculate_relevance_score(self, query_keywords: List[str], item: Dict) -> float:
103
+ """Calculate relevance score between query and item"""
104
+ item_keywords = item['keywords']
105
+ item_text = f"{item['key']} {item['content']}".lower()
106
+
107
+ score = 0.0
108
+
109
+ # Exact keyword matches
110
+ for keyword in query_keywords:
111
+ if keyword in item_keywords:
112
+ score += 2.0
113
+ elif keyword in item_text:
114
+ score += 1.0
115
+
116
+ # Category boost for relevant topics
117
+ category_boost = {
118
+ 'fault': 1.5, 'protection': 1.5, 'standard': 1.3,
119
+ 'power': 1.2, 'analysis': 1.2, 'calculation': 1.3
120
+ }
121
+
122
+ for boost_term, boost_value in category_boost.items():
123
+ if boost_term in item['category'].lower():
124
+ for keyword in query_keywords:
125
+ if boost_term in keyword:
126
+ score *= boost_value
127
+ break
128
+
129
+ # Length normalization
130
+ if len(item_keywords) > 0:
131
+ score = score / (1 + len(item_keywords) * 0.1)
132
+
133
+ return score
134
+
135
+ def get_matched_keywords(self, query_keywords: List[str], item_keywords: List[str]) -> List[str]:
136
+ """Get keywords that matched between query and item"""
137
+ return [kw for kw in query_keywords if kw in item_keywords]
138
+
139
+ def retrieve_context(self, query: str, max_context_length: int = 1000) -> str:
140
+ """Retrieve relevant context for the query"""
141
+ relevant_items = self.semantic_search(query, top_k=10)
142
+
143
+ if not relevant_items:
144
+ return "No specific context found in knowledge base."
145
+
146
+ context_parts = []
147
+ total_length = 0
148
+
149
+ for item in relevant_items:
150
+ context_part = f"**{item['category']} - {item['key']}**: {item['content']}"
151
+
152
+ if total_length + len(context_part) < max_context_length:
153
+ context_parts.append(context_part)
154
+ total_length += len(context_part)
155
+ else:
156
+ break
157
+
158
+ return "\n\n".join(context_parts)
159
+
160
+ def get_topic_overview(self, topic: str) -> str:
161
+ """Get comprehensive overview of a specific topic"""
162
+ topic_items = []
163
+
164
+ for item in self.indexed_content:
165
+ if topic.lower() in item['category'].lower() or topic.lower() in item['key'].lower():
166
+ topic_items.append(item)
167
+
168
+ if not topic_items:
169
+ return f"No information found for topic: {topic}"
170
+
171
+ # Group by category
172
+ categories = {}
173
+ for item in topic_items:
174
+ category = item['category']
175
+ if category not in categories:
176
+ categories[category] = []
177
+ categories[category].append(item)
178
+
179
+ overview_parts = []
180
+ for category, items in categories.items():
181
+ overview_parts.append(f"## {category.title()}")
182
+ for item in items[:5]: # Limit items per category
183
+ overview_parts.append(f"- **{item['key']}**: {item['content'][:200]}...")
184
+
185
+ return "\n".join(overview_parts)
186
+
187
+ def suggest_related_topics(self, query: str) -> List[str]:
188
+ """Suggest related topics based on the query"""
189
+ relevant_items = self.semantic_search(query, top_k=15)
190
+ categories = set()
191
+
192
+ for item in relevant_items:
193
+ categories.add(item['category'])
194
+
195
+ return list(categories)[:5]
196
+
197
+ def get_formulas_for_topic(self, topic: str) -> List[str]:
198
+ """Extract formulas related to a specific topic"""
199
+ formulas = []
200
+
201
+ # Search in formulas section
202
+ if 'formulas' in self.knowledge_base:
203
+ formulas_data = self.knowledge_base['formulas']
204
+ for category, formulas_dict in formulas_data.items():
205
+ if topic.lower() in category.lower():
206
+ if isinstance(formulas_dict, dict):
207
+ for formula_name, formula in formulas_dict.items():
208
+ formulas.append(f"**{formula_name}**: {formula}")
209
+
210
+ # Search in general content for formula patterns
211
+ formula_patterns = [
212
+ r'[A-Z]_[a-z]+ = [^.]+',
213
+ r'[A-Z] = [^.]+',
214
+ r'I_fault = [^.]+',
215
+ r'V_[a-z]+ = [^.]+',
216
+ r'Z_[a-z]+ = [^.]+',
217
+ r'P = [^.]+',
218
+ r'Q = [^.]+',
219
+ ]
220
+
221
+ for item in self.indexed_content:
222
+ if topic.lower() in item['content'].lower():
223
+ for pattern in formula_patterns:
224
+ matches = re.findall(pattern, item['content'])
225
+ formulas.extend(matches)
226
+
227
+ return list(set(formulas))[:10] # Remove duplicates and limit
228
+
229
+ def update_knowledge_base(self, new_data: Dict, category: str):
230
+ """Update knowledge base with new information"""
231
+ if category in self.knowledge_base:
232
+ self.knowledge_base[category].update(new_data)
233
+ else:
234
+ self.knowledge_base[category] = new_data
235
+
236
+ # Recreate search index
237
+ self.indexed_content = self.create_search_index()
238
+
239
+ # Save updated knowledge base
240
+ try:
241
+ with open(self.knowledge_base_path, 'w', encoding='utf-8') as f:
242
+ json.dump(self.knowledge_base, f, indent=2)
243
+ except Exception as e:
244
+ print(f"Error saving knowledge base: {e}")
245
+
246
+ def get_statistics(self) -> Dict:
247
+ """Get statistics about the knowledge base"""
248
+ stats = {
249
+ 'total_entries': len(self.indexed_content),
250
+ 'categories': len(set(item['category'] for item in self.indexed_content)),
251
+ 'total_keywords': sum(len(item['keywords']) for item in self.indexed_content),
252
+ 'last_updated': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
253
+ }
254
+
255
+ # Category breakdown
256
+ category_counts = {}
257
+ for item in self.indexed_content:
258
+ category = item['category']
259
+ category_counts[category] = category_counts.get(category, 0) + 1
260
+
261
+ stats['category_breakdown'] = category_counts
262
+ return stats
263
+
264
+ def export_context_report(self, query: str, filename: str = None) -> str:
265
+ """Export detailed context report for a query"""
266
+ if filename is None:
267
+ timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
268
+ filename = f"context_report_{timestamp}.md"
269
+
270
+ relevant_items = self.semantic_search(query, top_k=20)
271
+
272
+ report_content = f"""# Context Report for Query: "{query}"
273
+
274
+ Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
275
+
276
+ ## Search Results ({len(relevant_items)} items found)
277
+
278
+ """
279
+
280
+ for i, item in enumerate(relevant_items, 1):
281
+ report_content += f"""### {i}. {item['category']} - {item['key']}
282
+ - **Content**: {item['content']}
283
+ - **Relevance Score**: {item['relevance_score']:.2f}
284
+ - **Matched Keywords**: {', '.join(item['matched_keywords'])}
285
+
286
+ """
287
+
288
+ # Save report
289
+ try:
290
+ with open(filename, 'w', encoding='utf-8') as f:
291
+ f.write(report_content)
292
+ return f"Context report saved to {filename}"
293
+ except Exception as e:
294
+ return f"Error saving report: {e}"
295
+
296
+ # Example usage and testing
297
+ if __name__ == "__main__":
298
+ # Test the RAG system
299
+ rag = EnhancedRAGSystem()
300
+
301
+ # Test queries
302
+ test_queries = [
303
+ "fault analysis",
304
+ "IEEE standards",
305
+ "protection systems",
306
+ "short circuit calculation",
307
+ "transformer protection"
308
+ ]
309
+
310
+ for query in test_queries:
311
+ print(f"\nQuery: {query}")
312
+ context = rag.retrieve_context(query)
313
+ print(f"Context: {context[:200]}...")
314
+
315
+ related_topics = rag.suggest_related_topics(query)
316
+ print(f"Related topics: {related_topics}")
317
+
318
+ # Print statistics
319
+ stats = rag.get_statistics()
320
+ print(f"\nKnowledge Base Statistics:")
321
+ for key, value in stats.items():
322
+ print(f" {key}: {value}")