Create utils/rag_system.py
utils/rag_system.py  +322 -0
ADDED
@@ -0,0 +1,322 @@
"""
Enhanced RAG (Retrieval-Augmented Generation) System
for Power Systems Knowledge Base
"""

import json
import re
from typing import Dict, List, Tuple, Optional
import pandas as pd
from datetime import datetime
import os


class EnhancedRAGSystem:
    """
    Advanced RAG system with semantic search and context ranking
    """

    def __init__(self, knowledge_base_path: str = 'data/knowledge_base.json'):
        self.knowledge_base_path = knowledge_base_path
        self.knowledge_base = self.load_knowledge_base()
        self.indexed_content = self.create_search_index()

    def load_knowledge_base(self) -> Dict:
        """Load the power systems knowledge base"""
        try:
            with open(self.knowledge_base_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except FileNotFoundError:
            print(f"Knowledge base not found at {self.knowledge_base_path}")
            return self.get_fallback_knowledge_base()

    def get_fallback_knowledge_base(self) -> Dict:
        """Fallback knowledge base if file is not found"""
        return {
            "faults": {
                "symmetrical": "Three-phase faults with balanced conditions",
                "unsymmetrical": "Single-phase or two-phase faults"
            },
            "protection": {
                "overcurrent": "Current-based protection schemes",
                "differential": "Current comparison protection"
            }
        }

    def create_search_index(self) -> List[Dict]:
        """Create searchable index from knowledge base"""
        indexed_items = []

        def index_recursive(data, path="", category=""):
            if isinstance(data, dict):
                for key, value in data.items():
                    current_path = f"{path}.{key}" if path else key
                    current_category = category or key

                    if isinstance(value, (str, int, float)):
                        indexed_items.append({
                            'path': current_path,
                            'category': current_category,
                            'key': key,
                            'content': str(value),
                            'keywords': self.extract_keywords(f"{key} {value}")
                        })
                    else:
                        index_recursive(value, current_path, current_category)
            elif isinstance(data, list):
                for i, item in enumerate(data):
                    index_recursive(item, f"{path}[{i}]", category)

        index_recursive(self.knowledge_base)
        return indexed_items

    def extract_keywords(self, text: str) -> List[str]:
        """Extract keywords from text for better matching"""
        # Convert to lowercase and split
        words = re.findall(r'\b\w+\b', text.lower())

        # Remove common stop words
        stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at',
                      'to', 'for', 'of', 'with', 'by', 'is', 'are', 'was', 'were'}

        keywords = [word for word in words if word not in stop_words and len(word) > 2]
        return keywords

    def semantic_search(self, query: str, top_k: int = 5) -> List[Dict]:
        """Perform semantic search on the knowledge base"""
        query_keywords = self.extract_keywords(query)
        scored_results = []

        for item in self.indexed_content:
            score = self.calculate_relevance_score(query_keywords, item)
            if score > 0:
                scored_results.append({
                    **item,
                    'relevance_score': score,
                    'matched_keywords': self.get_matched_keywords(query_keywords, item['keywords'])
                })

        # Sort by relevance score
        scored_results.sort(key=lambda x: x['relevance_score'], reverse=True)
        return scored_results[:top_k]

    def calculate_relevance_score(self, query_keywords: List[str], item: Dict) -> float:
        """Calculate relevance score between query and item"""
        item_keywords = item['keywords']
        item_text = f"{item['key']} {item['content']}".lower()

        score = 0.0

        # Exact keyword matches
        for keyword in query_keywords:
            if keyword in item_keywords:
                score += 2.0
            elif keyword in item_text:
                score += 1.0

        # Category boost for relevant topics
        category_boost = {
            'fault': 1.5, 'protection': 1.5, 'standard': 1.3,
            'power': 1.2, 'analysis': 1.2, 'calculation': 1.3
        }

        for boost_term, boost_value in category_boost.items():
            if boost_term in item['category'].lower():
                for keyword in query_keywords:
                    if boost_term in keyword:
                        score *= boost_value
                        break

        # Length normalization
        if len(item_keywords) > 0:
            score = score / (1 + len(item_keywords) * 0.1)

        return score

    def get_matched_keywords(self, query_keywords: List[str], item_keywords: List[str]) -> List[str]:
        """Get keywords that matched between query and item"""
        return [kw for kw in query_keywords if kw in item_keywords]

    def retrieve_context(self, query: str, max_context_length: int = 1000) -> str:
        """Retrieve relevant context for the query"""
        relevant_items = self.semantic_search(query, top_k=10)

        if not relevant_items:
            return "No specific context found in knowledge base."

        context_parts = []
        total_length = 0

        for item in relevant_items:
            context_part = f"**{item['category']} - {item['key']}**: {item['content']}"

            if total_length + len(context_part) < max_context_length:
                context_parts.append(context_part)
                total_length += len(context_part)
            else:
                break

        return "\n\n".join(context_parts)

    def get_topic_overview(self, topic: str) -> str:
        """Get comprehensive overview of a specific topic"""
        topic_items = []

        for item in self.indexed_content:
            if topic.lower() in item['category'].lower() or topic.lower() in item['key'].lower():
                topic_items.append(item)

        if not topic_items:
            return f"No information found for topic: {topic}"

        # Group by category
        categories = {}
        for item in topic_items:
            category = item['category']
            if category not in categories:
                categories[category] = []
            categories[category].append(item)

        overview_parts = []
        for category, items in categories.items():
            overview_parts.append(f"## {category.title()}")
            for item in items[:5]:  # Limit items per category
                overview_parts.append(f"- **{item['key']}**: {item['content'][:200]}...")

        return "\n".join(overview_parts)

    def suggest_related_topics(self, query: str) -> List[str]:
        """Suggest related topics based on the query"""
        relevant_items = self.semantic_search(query, top_k=15)
        categories = set()

        for item in relevant_items:
            categories.add(item['category'])

        return list(categories)[:5]

    def get_formulas_for_topic(self, topic: str) -> List[str]:
        """Extract formulas related to a specific topic"""
        formulas = []

        # Search in formulas section
        if 'formulas' in self.knowledge_base:
            formulas_data = self.knowledge_base['formulas']
            for category, formulas_dict in formulas_data.items():
                if topic.lower() in category.lower():
                    if isinstance(formulas_dict, dict):
                        for formula_name, formula in formulas_dict.items():
                            formulas.append(f"**{formula_name}**: {formula}")

        # Search in general content for formula patterns
        formula_patterns = [
            r'[A-Z]_[a-z]+ = [^.]+',
            r'[A-Z] = [^.]+',
            r'I_fault = [^.]+',
            r'V_[a-z]+ = [^.]+',
            r'Z_[a-z]+ = [^.]+',
            r'P = [^.]+',
            r'Q = [^.]+',
        ]

        for item in self.indexed_content:
            if topic.lower() in item['content'].lower():
                for pattern in formula_patterns:
                    matches = re.findall(pattern, item['content'])
                    formulas.extend(matches)

        return list(set(formulas))[:10]  # Remove duplicates and limit

    def update_knowledge_base(self, new_data: Dict, category: str):
        """Update knowledge base with new information"""
        if category in self.knowledge_base:
            self.knowledge_base[category].update(new_data)
        else:
            self.knowledge_base[category] = new_data

        # Recreate search index
        self.indexed_content = self.create_search_index()

        # Save updated knowledge base
        try:
            with open(self.knowledge_base_path, 'w', encoding='utf-8') as f:
                json.dump(self.knowledge_base, f, indent=2)
        except Exception as e:
            print(f"Error saving knowledge base: {e}")

    def get_statistics(self) -> Dict:
        """Get statistics about the knowledge base"""
        stats = {
            'total_entries': len(self.indexed_content),
            'categories': len(set(item['category'] for item in self.indexed_content)),
            'total_keywords': sum(len(item['keywords']) for item in self.indexed_content),
            'last_updated': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        }

        # Category breakdown
        category_counts = {}
        for item in self.indexed_content:
            category = item['category']
            category_counts[category] = category_counts.get(category, 0) + 1

        stats['category_breakdown'] = category_counts
        return stats

    def export_context_report(self, query: str, filename: Optional[str] = None) -> str:
        """Export detailed context report for a query"""
        if filename is None:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            filename = f"context_report_{timestamp}.md"

        relevant_items = self.semantic_search(query, top_k=20)

        report_content = f"""# Context Report for Query: "{query}"

Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

## Search Results ({len(relevant_items)} items found)

"""

        for i, item in enumerate(relevant_items, 1):
            report_content += f"""### {i}. {item['category']} - {item['key']}
- **Content**: {item['content']}
- **Relevance Score**: {item['relevance_score']:.2f}
- **Matched Keywords**: {', '.join(item['matched_keywords'])}

"""

        # Save report
        try:
            with open(filename, 'w', encoding='utf-8') as f:
                f.write(report_content)
            return f"Context report saved to {filename}"
        except Exception as e:
            return f"Error saving report: {e}"


# Example usage and testing
if __name__ == "__main__":
    # Test the RAG system
    rag = EnhancedRAGSystem()

    # Test queries
    test_queries = [
        "fault analysis",
        "IEEE standards",
        "protection systems",
        "short circuit calculation",
        "transformer protection"
    ]

    for query in test_queries:
        print(f"\nQuery: {query}")
        context = rag.retrieve_context(query)
        print(f"Context: {context[:200]}...")

        related_topics = rag.suggest_related_topics(query)
        print(f"Related topics: {related_topics}")

    # Print statistics
    stats = rag.get_statistics()
    print("\nKnowledge Base Statistics:")
    for key, value in stats.items():
        print(f"  {key}: {value}")
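
Note: load_knowledge_base() expects data/knowledge_base.json to be a nested JSON object of categories, sub-topics, and text entries, plus an optional "formulas" section shaped like {category: {formula_name: formula}} for get_formulas_for_topic(). The snippet below is a minimal, hypothetical sketch of such a file (entry texts and the formula are illustrative assumptions based on the fallback knowledge base above, not content from the actual Space).

# Hypothetical sketch: generate a minimal data/knowledge_base.json in the nested
# category -> entry layout that load_knowledge_base() can consume.
import json
import os

example_kb = {
    "faults": {
        "symmetrical": "Three-phase faults with balanced conditions",
        "unsymmetrical": "Single-phase or two-phase faults"
    },
    "protection": {
        "overcurrent": "Current-based protection schemes",
        "differential": "Current comparison protection"
    },
    # Optional section scanned by get_formulas_for_topic(); formula text is illustrative
    "formulas": {
        "fault_analysis": {
            "three_phase_fault_current": "I_fault = V_prefault / Z_thevenin"
        }
    }
}

os.makedirs("data", exist_ok=True)
with open("data/knowledge_base.json", "w", encoding="utf-8") as f:
    json.dump(example_kb, f, indent=2)

With a file like this in place, EnhancedRAGSystem() indexes every leaf string, and get_formulas_for_topic("fault") would pick up the example formula both from the "formulas" section and via the I_fault regex pattern.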