# Spaces: Runtime error
"""
Enhanced RAG (Retrieval-Augmented Generation) System
for Power Systems Knowledge Base
"""
import json | |
import re | |
from typing import Dict, List, Tuple, Optional | |
import pandas as pd | |
from datetime import datetime | |
import os | |
class EnhancedRAGSystem:
    """
    Keyword-based retrieval layer over a nested JSON knowledge base.

    The knowledge base is flattened into a list of index entries, one per
    scalar leaf. Queries are scored against those entries by keyword overlap;
    despite the "semantic" naming, no embeddings are involved.
    """

    # Words dropped during keyword extraction.
    _STOP_WORDS = frozenset({
        'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at',
        'to', 'for', 'of', 'with', 'by', 'is', 'are', 'was', 'were',
    })

    # Score multipliers applied when the term occurs both in an entry's
    # category and inside at least one query keyword.
    _CATEGORY_BOOST = {
        'fault': 1.5, 'protection': 1.5, 'standard': 1.3,
        'power': 1.2, 'analysis': 1.2, 'calculation': 1.3,
    }

    # Heuristic formula patterns, tried in this order.
    # NOTE: every pattern stops at the first '.', so a formula containing a
    # decimal number (e.g. "1.5") is cut short at that point.
    _FORMULA_PATTERNS = (
        r'[A-Z]_[a-z]+ = [^.]+',
        r'[A-Z] = [^.]+',
        r'I_fault = [^.]+',
        r'V_[a-z]+ = [^.]+',
        r'Z_[a-z]+ = [^.]+',
        r'P = [^.]+',
        r'Q = [^.]+',
    )

    # Text placed between context parts returned by retrieve_context().
    _CONTEXT_SEPARATOR = "\n\n"

    # Limits used by get_topic_overview().
    _OVERVIEW_ITEMS_PER_CATEGORY = 5
    _OVERVIEW_SNIPPET_LENGTH = 200

    def __init__(self, knowledge_base_path: str = 'data/knowledge_base.json'):
        """
        Load the knowledge base and build the search index.

        Args:
            knowledge_base_path: Path of the JSON file to load and, on
                update, to write back to.
        """
        self.knowledge_base_path = knowledge_base_path
        self.knowledge_base = self.load_knowledge_base()
        self.indexed_content = self.create_search_index()

    def load_knowledge_base(self) -> Dict:
        """
        Load the power systems knowledge base from disk.

        Returns:
            The parsed JSON document, or the built-in fallback knowledge
            base when the file does not exist.

        Raises:
            json.JSONDecodeError: If the file exists but is not valid JSON.
        """
        try:
            with open(self.knowledge_base_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except FileNotFoundError:
            print(f"Knowledge base not found at {self.knowledge_base_path}")
            return self.get_fallback_knowledge_base()

    def get_fallback_knowledge_base(self) -> Dict:
        """Return the minimal built-in knowledge base used when no file is found."""
        return {
            "faults": {
                "symmetrical": "Three-phase faults with balanced conditions",
                "unsymmetrical": "Single-phase or two-phase faults"
            },
            "protection": {
                "overcurrent": "Current-based protection schemes",
                "differential": "Current comparison protection"
            }
        }

    def create_search_index(self) -> List[Dict]:
        """
        Flatten the knowledge base into a list of searchable entries.

        Every str/int/float leaf becomes one entry with the keys ``path``
        (dotted/indexed location), ``category`` (first dict key on the way
        down), ``key`` (nearest enclosing dict key), ``content`` (the leaf
        as a string) and ``keywords``. Scalar items inside lists are indexed
        too, under the key of the list that holds them. Other leaf types
        (e.g. ``None``) are skipped.

        Returns:
            The list of index entries, in traversal order.
        """
        entries: List[Dict] = []

        def visit(node, path: str, category: str, key) -> None:
            if isinstance(node, dict):
                for child_key, child in node.items():
                    child_path = f"{path}.{child_key}" if path else child_key
                    # The category is fixed by the first key encountered.
                    visit(child, child_path, category or child_key, child_key)
            elif isinstance(node, list):
                for position, child in enumerate(node):
                    # List items inherit the key of the list itself.
                    visit(child, f"{path}[{position}]", category, key)
            elif isinstance(node, (str, int, float)):
                entries.append({
                    'path': path,
                    'category': category,
                    'key': key,
                    'content': str(node),
                    'keywords': self.extract_keywords(f"{key} {node}")
                })

        # A bare scalar at the root has no key to file it under.
        if isinstance(self.knowledge_base, (dict, list)):
            visit(self.knowledge_base, "", "", "")
        return entries

    def extract_keywords(self, text: str) -> List[str]:
        """
        Extract lowercase keywords from text.

        Args:
            text: Arbitrary text.

        Returns:
            Word tokens in order of appearance (duplicates kept), excluding
            stop words and tokens of two characters or fewer.
        """
        words = re.findall(r'\b\w+\b', text.lower())
        return [
            word for word in words
            if len(word) > 2 and word not in self._STOP_WORDS
        ]

    def semantic_search(self, query: str, top_k: int = 5) -> List[Dict]:
        """
        Rank index entries against the query by keyword overlap.

        Args:
            query: Free-text query.
            top_k: Maximum number of results; values <= 0 yield no results.

        Returns:
            Up to ``top_k`` copies of index entries with a positive score,
            best first, each extended with ``relevance_score`` and
            ``matched_keywords``. Ties keep index order.
        """
        if top_k <= 0:
            return []

        query_keywords = self.extract_keywords(query)
        if not query_keywords:
            # Without keywords every entry scores 0, so nothing can match.
            return []

        scored_results = []
        for item in self.indexed_content:
            score = self.calculate_relevance_score(query_keywords, item)
            if score > 0:
                scored_results.append({
                    **item,
                    'relevance_score': score,
                    'matched_keywords': self.get_matched_keywords(
                        query_keywords, item['keywords']
                    )
                })

        scored_results.sort(key=lambda x: x['relevance_score'], reverse=True)
        return scored_results[:top_k]

    def calculate_relevance_score(self, query_keywords: List[str], item: Dict) -> float:
        """
        Score one index entry against the query keywords.

        Each query keyword adds 2.0 when it is one of the entry's keywords,
        or 1.0 when it merely occurs as a substring of the entry's key and
        content. Category boosts are then multiplied in, and the result is
        divided by ``1 + 0.1 * len(keywords)`` to damp long entries.

        Args:
            query_keywords: Keywords extracted from the query.
            item: An index entry as produced by create_search_index().

        Returns:
            The relevance score; 0.0 when nothing matched.
        """
        item_keywords = item['keywords']
        keyword_set = set(item_keywords)
        item_text = f"{item['key']} {item['content']}".lower()

        score = 0.0
        for keyword in query_keywords:
            if keyword in keyword_set:
                score += 2.0
            elif keyword in item_text:
                score += 1.0

        category = item['category'].lower()
        for boost_term, boost_value in self._CATEGORY_BOOST.items():
            # Each boost applies at most once, however many keywords hit it.
            if boost_term in category and any(
                boost_term in keyword for keyword in query_keywords
            ):
                score *= boost_value

        if item_keywords:
            score = score / (1 + len(item_keywords) * 0.1)
        return score

    def get_matched_keywords(self, query_keywords: List[str], item_keywords: List[str]) -> List[str]:
        """Return the query keywords that also appear in the item's keywords."""
        return [kw for kw in query_keywords if kw in item_keywords]

    @staticmethod
    def _format_context_part(item: Dict) -> str:
        """Render one search result as a single context snippet."""
        return f"**{item['category']} - {item['key']}**: {item['content']}"

    def retrieve_context(self, query: str, max_context_length: int = 1000) -> str:
        """
        Build a context string from the best matches for the query.

        Parts are added in relevance order until the next one would push the
        result, separators included, past ``max_context_length``. If even the
        best match does not fit, it is truncated to the limit rather than
        returning an empty context.

        Args:
            query: Free-text query.
            max_context_length: Upper bound on the length of the result.

        Returns:
            The joined context, or a fixed "not found" message when the
            search returns nothing.
        """
        relevant_items = self.semantic_search(query, top_k=10)
        if not relevant_items:
            return "No specific context found in knowledge base."

        separator = self._CONTEXT_SEPARATOR
        context_parts: List[str] = []
        total_length = 0
        for item in relevant_items:
            context_part = self._format_context_part(item)
            # The separator only costs space from the second part onwards.
            added = len(context_part) + (len(separator) if context_parts else 0)
            if total_length + added > max_context_length:
                break
            context_parts.append(context_part)
            total_length += added

        if not context_parts:
            best = self._format_context_part(relevant_items[0])
            return best[:max(max_context_length, 0)]
        return separator.join(context_parts)

    def get_topic_overview(self, topic: str) -> str:
        """
        Summarise the entries whose category or key contains the topic.

        Args:
            topic: Case-insensitive substring to look for.

        Returns:
            Markdown with one ``##`` heading per category and up to five
            bullet points each; content longer than 200 characters is cut
            and marked with ``...``. A "no information" message is returned
            when nothing matches.
        """
        needle = topic.lower()
        categories: Dict[str, List[Dict]] = {}
        for item in self.indexed_content:
            if needle in item['category'].lower() or needle in item['key'].lower():
                categories.setdefault(item['category'], []).append(item)

        if not categories:
            return f"No information found for topic: {topic}"

        limit = self._OVERVIEW_SNIPPET_LENGTH
        overview_parts = []
        for category, items in categories.items():
            overview_parts.append(f"## {category.title()}")
            for item in items[:self._OVERVIEW_ITEMS_PER_CATEGORY]:
                content = item['content']
                # Only flag truncation when something was actually cut.
                snippet = content if len(content) <= limit else content[:limit] + "..."
                overview_parts.append(f"- **{item['key']}**: {snippet}")
        return "\n".join(overview_parts)

    def suggest_related_topics(self, query: str) -> List[str]:
        """
        Suggest up to five categories related to the query.

        Returns:
            Distinct categories of the 15 best matches, most relevant first.
        """
        relevant_items = self.semantic_search(query, top_k=15)
        # dict.fromkeys de-duplicates while keeping relevance order.
        ordered = dict.fromkeys(item['category'] for item in relevant_items)
        return list(ordered)[:5]

    def get_formulas_for_topic(self, topic: str) -> List[str]:
        """
        Collect formulas related to a topic.

        Formulas come from two places: the ``formulas`` section of the
        knowledge base (sub-sections whose name contains the topic), then
        pattern matches inside any indexed content that mentions the topic.

        Args:
            topic: Case-insensitive substring to look for.

        Returns:
            Up to ten distinct formulas, in discovery order.
        """
        needle = topic.lower()
        formulas: List[str] = []

        formulas_data = self.knowledge_base.get('formulas')
        if isinstance(formulas_data, dict):
            for category, formulas_dict in formulas_data.items():
                if needle in category.lower() and isinstance(formulas_dict, dict):
                    for formula_name, formula in formulas_dict.items():
                        formulas.append(f"**{formula_name}**: {formula}")

        for item in self.indexed_content:
            content = item['content']
            if needle in content.lower():
                for pattern in self._FORMULA_PATTERNS:
                    formulas.extend(re.findall(pattern, content))

        # The patterns overlap, so the same formula is usually found twice.
        return list(dict.fromkeys(formulas))[:10]

    def update_knowledge_base(self, new_data: Dict, category: str):
        """
        Merge new data into a category, re-index, and save to disk.

        When the category already holds a dict, ``new_data`` is merged into
        it; otherwise the category is set to ``new_data``. The file is only
        opened once the whole document has been serialised, so a value JSON
        cannot encode does not truncate the stored knowledge base. Save
        failures are printed, not raised; the in-memory update is kept.

        Args:
            new_data: Entries to add to the category.
            category: Top-level key to update or create.
        """
        existing = self.knowledge_base.get(category)
        if isinstance(existing, dict):
            existing.update(new_data)
        else:
            self.knowledge_base[category] = new_data

        self.indexed_content = self.create_search_index()

        try:
            serialized = json.dumps(self.knowledge_base, indent=2)
        except (TypeError, ValueError) as e:
            print(f"Error saving knowledge base: {e}")
            return

        try:
            with open(self.knowledge_base_path, 'w', encoding='utf-8') as f:
                f.write(serialized)
        except OSError as e:
            print(f"Error saving knowledge base: {e}")

    def get_statistics(self) -> Dict:
        """
        Summarise the search index.

        Returns:
            A dict with ``total_entries``, ``categories`` (distinct count),
            ``total_keywords``, ``last_updated`` and ``category_breakdown``
            (entries per category). ``last_updated`` is the time of this
            call, not of the last change to the knowledge base.
        """
        category_counts: Dict[str, int] = {}
        for item in self.indexed_content:
            category = item['category']
            category_counts[category] = category_counts.get(category, 0) + 1

        return {
            'total_entries': len(self.indexed_content),
            'categories': len(category_counts),
            'total_keywords': sum(len(item['keywords']) for item in self.indexed_content),
            'last_updated': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
            'category_breakdown': category_counts,
        }

    def export_context_report(self, query: str, filename: Optional[str] = None) -> str:
        """
        Write a Markdown report of the top 20 matches for a query.

        Args:
            query: Free-text query.
            filename: Destination path; defaults to a timestamped
                ``context_report_*.md`` in the working directory.

        Returns:
            A status message: success with the file name, or the error text
            when the file could not be written.
        """
        if filename is None:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            filename = f"context_report_{timestamp}.md"

        relevant_items = self.semantic_search(query, top_k=20)
        generated_at = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

        sections = [
            f'# Context Report for Query: "{query}"\n'
            f"Generated on: {generated_at}\n"
            f"## Search Results ({len(relevant_items)} items found)\n"
        ]
        for i, item in enumerate(relevant_items, 1):
            sections.append(
                f"### {i}. {item['category']} - {item['key']}\n"
                f"- **Content**: {item['content']}\n"
                f"- **Relevance Score**: {item['relevance_score']:.2f}\n"
                f"- **Matched Keywords**: {', '.join(item['matched_keywords'])}\n"
            )
        report_content = "".join(sections)

        try:
            with open(filename, 'w', encoding='utf-8') as f:
                f.write(report_content)
        except (OSError, ValueError) as e:
            return f"Error saving report: {e}"
        return f"Context report saved to {filename}"
# Example usage and testing
def _run_demo() -> None:
    """Run a few sample queries against the RAG system and print the results."""
    system = EnhancedRAGSystem()

    # Test queries
    sample_queries = (
        "fault analysis",
        "IEEE standards",
        "protection systems",
        "short circuit calculation",
        "transformer protection",
    )

    for sample in sample_queries:
        print(f"\nQuery: {sample}")
        retrieved = system.retrieve_context(sample)
        # Only a 200-character preview of the context is shown.
        print(f"Context: {retrieved[:200]}...")
        topics = system.suggest_related_topics(sample)
        print(f"Related topics: {topics}")

    # Print statistics
    summary = system.get_statistics()
    print("\nKnowledge Base Statistics:")
    for name, value in summary.items():
        print(f" {name}: {value}")


if __name__ == "__main__":
    _run_demo()