# Spaces:
# Runtime error
# Runtime error
""" | |
Enhanced RAG (Retrieval-Augmented Generation) System | |
for Power Systems Knowledge Base with Advanced Features | |
""" | |
import hashlib
import json
import os
import re
import sqlite3
from contextlib import closing
from datetime import datetime
from typing import Dict, List, Optional, Tuple

import pandas as pd
class EnhancedRAGSystem:
    """
    Advanced RAG system with semantic search, context ranking, and knowledge management.

    Retrieval is keyword based: every scalar leaf of the nested knowledge-base
    dict becomes one index entry, and queries are scored against the entries'
    keywords. A SQLite database (``db_path``) holds a query-result cache and
    per-query analytics.
    """

    # Common English words ignored when extracting keywords.
    _STOP_WORDS = frozenset({
        'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at',
        'to', 'for', 'of', 'with', 'by', 'is', 'are', 'was', 'were',
        'this', 'that', 'these', 'those', 'be', 'have', 'has', 'had',
    })

    # Domain terms that multiply an entry's score when the term appears both in
    # the entry's category/key and inside one of the query keywords.
    _CATEGORY_BOOSTS = {
        'fault': 1.5, 'protection': 1.5, 'standard': 1.3,
        'power': 1.2, 'analysis': 1.2, 'calculation': 1.3,
        'equipment': 1.3, 'transformer': 1.4, 'generator': 1.4,
        'transmission': 1.3, 'ieee': 1.2, 'iec': 1.2,
    }

    # Heuristic "symbol = expression" patterns; each match runs up to the next period.
    _FORMULA_PATTERNS = tuple(re.compile(pattern) for pattern in (
        r'[A-Z][a-z]*\s*=\s*[^.]+',
        r'I_[a-zA-Z]+\s*=\s*[^.]+',
        r'V_[a-zA-Z]+\s*=\s*[^.]+',
        r'Z_[a-zA-Z]+\s*=\s*[^.]+',
        r'P\s*=\s*[^.]+',
        r'Q\s*=\s*[^.]+',
        r'S\s*=\s*[^.]+',
    ))

    def __init__(self, knowledge_base_path: str = 'data/knowledge_base.json',
                 db_path: str = 'rag_cache.db'):
        """
        Args:
            knowledge_base_path: JSON file holding the knowledge base. If it does
                not exist, a built-in default knowledge base is used and written there.
            db_path: SQLite file for the query cache and analytics tables.
        """
        self.knowledge_base_path = knowledge_base_path
        self.db_path = db_path
        self.knowledge_base = self.load_knowledge_base()
        self.indexed_content = self.create_search_index()
        # Part of every cache key, so results cached for one version of the
        # knowledge base are never served for another.
        self._kb_fingerprint = self._fingerprint_knowledge_base()
        self.init_cache_database()

    def _connect(self) -> sqlite3.Connection:
        """Open a new connection to the cache database; the caller must close it."""
        return sqlite3.connect(self.db_path)

    def init_cache_database(self):
        """Create the ``query_cache`` and ``query_analytics`` tables if they do not exist."""
        with closing(self._connect()) as conn:
            cursor = conn.cursor()
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS query_cache (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    query_hash TEXT UNIQUE,
                    query_text TEXT,
                    response_context TEXT,
                    relevance_scores TEXT,
                    timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    access_count INTEGER DEFAULT 1
                )
            ''')
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS query_analytics (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    query_text TEXT,
                    topic_category TEXT,
                    response_quality REAL,
                    user_feedback TEXT,
                    timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')
            conn.commit()

    def load_knowledge_base(self) -> Dict:
        """Load the knowledge base JSON, falling back to the built-in default if the file is missing."""
        try:
            with open(self.knowledge_base_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except FileNotFoundError:
            print(f"Knowledge base not found at {self.knowledge_base_path}, creating default...")
            return self.create_comprehensive_knowledge_base()

    def _write_knowledge_base(self, data: Dict) -> None:
        """Write *data* as JSON to ``knowledge_base_path``, creating its directory if needed."""
        directory = os.path.dirname(self.knowledge_base_path)
        # os.makedirs('') raises, so only create a directory when the path has one.
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(self.knowledge_base_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)

    def create_comprehensive_knowledge_base(self) -> Dict:
        """
        Build the default power-systems knowledge base and try to save it.

        Saving is best-effort: if the file cannot be written, a warning is
        printed and the in-memory knowledge base is still returned.
        """
        knowledge_base = {
            "fault_analysis": {
                "symmetrical_faults": {
                    "description": "Three-phase faults where all phases are equally affected",
                    "characteristics": "Balanced conditions, highest fault current magnitude",
                    "analysis_method": "Single-phase equivalent circuit using positive sequence only",
                    "calculation": "If = Ea / Z1",
                    "occurrence": "5-10% of all power system faults",
                    "protection": "Instantaneous overcurrent, differential protection"
                },
                "unsymmetrical_faults": {
                    "line_to_ground": {
                        "description": "Single-phase to ground fault",
                        "occurrence": "70-80% of all transmission line faults",
                        "calculation": "If = 3 × Ea / (Z1 + Z2 + Z0)",
                        "sequence_networks": "All three sequence networks in series",
                        "factors": "Ground resistance, tower footing resistance affect magnitude"
                    },
                    "line_to_line": {
                        "description": "Two phases short-circuited together",
                        "occurrence": "15-20% of all faults",
                        "calculation": "If = √3 × Ea / (Z1 + Z2)",
                        "sequence_networks": "Positive and negative sequence in parallel",
                        "characteristics": "No zero sequence current"
                    },
                    "double_line_to_ground": {
                        "description": "Two phases short-circuited to ground",
                        "occurrence": "2-5% of all faults",
                        "calculation": "Complex involving all sequence networks",
                        "severity": "Can be more severe than three-phase fault"
                    }
                },
                "sequence_components": {
                    "positive_sequence": {
                        "description": "Represents balanced three-phase system",
                        "rotation": "ABC phase rotation, same as system",
                        "impedance": "Lowest impedance path, mainly system reactance"
                    },
                    "negative_sequence": {
                        "description": "Represents phase imbalance with ACB rotation",
                        "rotation": "Opposite to system rotation",
                        "impedance": "Usually equal to positive sequence for static equipment"
                    },
                    "zero_sequence": {
                        "description": "All phases in phase, returns through ground/neutral",
                        "impedance": "Highest impedance, depends on grounding",
                        "path": "Through ground, neutral conductors, transformer connections"
                    }
                }
            },
            "protection_systems": {
                "overcurrent_protection": {
                    "principles": "Current magnitude based protection",
                    "types": {
                        "instantaneous": "No time delay, fast tripping for high currents",
                        "definite_time": "Fixed time delay regardless of current magnitude",
                        "inverse_time": "Time inversely related to current magnitude",
                        "very_inverse": "Steeper inverse characteristic",
                        "extremely_inverse": "Very steep characteristic for high currents"
                    },
                    "settings": {
                        "pickup_current": "1.05-1.25 × Full load current",
                        "time_multiplier": "Adjust operating time",
                        "curve_selection": "Based on coordination requirements"
                    },
                    "applications": "Distribution feeders, motor protection, backup protection"
                },
                "differential_protection": {
                    "principle": "Compares currents entering and leaving protected zone",
                    "equation": "Id = I1 + I2 + ... + In (vector sum)",
                    "sensitivity": "Can detect internal faults as low as 5-10% of rated current",
                    "applications": {
                        "transformers": "High impedance or low impedance schemes",
                        "generators": "Stator winding and rotor protection",
                        "buses": "High speed bus protection",
                        "transmission_lines": "Pilot wire or communication based"
                    },
                    "advantages": "Selective, sensitive, fast operating",
                    "limitations": "Requires CTs at all terminals, communication links"
                },
                "distance_protection": {
                    "principle": "Measures impedance to fault location",
                    "zones": {
                        "zone_1": "80-90% of line length, instantaneous",
                        "zone_2": "120% of line + 50% of shortest adjacent line, time delayed",
                        "zone_3": "Backup protection, longer time delay",
                        "zone_4": "Reverse direction protection if required"
                    },
                    "characteristics": {
                        "mho": "Circle passing through origin and fault point",
                        "impedance": "Circle centered at origin",
                        "reactance": "Straight line parallel to R-axis"
                    },
                    "settings": {
                        "reach": "Based on line impedance and coordination",
                        "angle": "Line angle ± 15°, typically 60-85°",
                        "time_delays": "Zone 1: 0s, Zone 2: 0.3s, Zone 3: 1.0s"
                    }
                }
            },
            "standards": {
                "ieee_standards": {
                    "C37.2": {
                        "title": "Electrical Power System Device Function Numbers",
                        "scope": "Standard device function numbers for protective relays",
                        "common_functions": {
                            "21": "Distance protection",
                            "27": "Undervoltage relay",
                            "50": "Instantaneous overcurrent",
                            "51": "AC time overcurrent",
                            "59": "Overvoltage relay",
                            "67": "Directional overcurrent",
                            "87": "Differential protection"
                        }
                    },
                    "C37.90": {
                        "title": "Standard for Relays and Relay Systems",
                        "scope": "General requirements for protective relays"
                    },
                    "C37.118": {
                        "title": "Synchrophasor Standard",
                        "scope": "PMU data format and communication protocol"
                    }
                },
                "iec_standards": {
                    "61850": {
                        "title": "Communication protocols for intelligent electronic devices",
                        "scope": "Substation automation and communication"
                    },
                    "60909": {
                        "title": "Short-circuit currents in three-phase AC systems",
                        "scope": "Calculation methods for fault currents"
                    }
                }
            },
            "formulas": {
                "fault_calculations": {
                    "three_phase_fault": "If = Vf / Z1",
                    "line_to_ground": "If = 3 × Vf / (Z1 + Z2 + Z0)",
                    "line_to_line": "If = √3 × Vf / (Z1 + Z2)",
                    "double_line_to_ground": "If = 3 × Vf × (Z2 + Z0) / ((Z1 + Z2) × (Z1 + Z0) + Z1 × (Z2 + Z0))"
                },
                "power_calculations": {
                    "apparent_power": "S = V × I* (complex conjugate)",
                    "real_power": "P = V × I × cos(θ)",
                    "reactive_power": "Q = V × I × sin(θ)",
                    "power_factor": "pf = P / S = cos(θ)"
                },
                "impedance_calculations": {
                    "series_impedance": "Z_total = Z1 + Z2 + ... + Zn",
                    "parallel_impedance": "1/Z_total = 1/Z1 + 1/Z2 + ... + 1/Zn",
                    "transmission_line": "Z = R + jωL, Y = G + jωC"
                }
            },
            "equipment": {
                "transformers": {
                    "types": {
                        "power_transformers": "High voltage, high power rating",
                        "distribution_transformers": "Medium to low voltage distribution",
                        "instrument_transformers": "Current and voltage measurement"
                    },
                    "protection": {
                        "differential": "Primary protection for internal faults",
                        "overcurrent": "Backup protection and overload",
                        "buchholz": "Gas-actuated relay for oil-filled transformers",
                        "temperature": "Winding and oil temperature monitoring"
                    },
                    "connections": {
                        "wye_wye": "Y-Y connection, neutral available",
                        "delta_delta": "Δ-Δ connection, no neutral",
                        "wye_delta": "Y-Δ connection, phase shift 30°",
                        "delta_wye": "Δ-Y connection, phase shift -30°"
                    }
                },
                "generators": {
                    "types": {
                        "synchronous": "Constant speed, grid connected",
                        "induction": "Variable speed, wind turbines",
                        "dc": "Direct current, special applications"
                    },
                    "protection": {
                        "differential": "Stator winding protection",
                        "reverse_power": "Motoring protection",
                        "loss_of_excitation": "Field loss protection",
                        "overvoltage": "Terminal voltage protection",
                        "frequency": "Under/over frequency protection"
                    }
                },
                "transmission_lines": {
                    "types": {
                        "overhead": "Air insulated, towers and poles",
                        "underground": "Cable systems, higher cost",
                        "submarine": "Underwater cables, special insulation"
                    },
                    "parameters": {
                        "resistance": "R = ρL/A (conductor resistance)",
                        "inductance": "L = μ₀μᵣ(ln(D/r))/(2π) per unit length",
                        "capacitance": "C = πε₀εᵣ/ln(D/r) per unit length",
                        "conductance": "G = σπd (leakage conductance)"
                    }
                }
            }
        }
        # Persist the default; a read-only location must not prevent start-up.
        try:
            self._write_knowledge_base(knowledge_base)
        except OSError as e:
            print(f"Warning: could not save default knowledge base to {self.knowledge_base_path}: {e}")
        return knowledge_base

    def _fingerprint_knowledge_base(self) -> str:
        """Return a short digest of the current knowledge-base content (used in cache keys)."""
        serialized = json.dumps(self.knowledge_base, ensure_ascii=False, default=str)
        return hashlib.md5(serialized.encode('utf-8')).hexdigest()[:12]

    def create_search_index(self) -> List[Dict]:
        """
        Flatten the knowledge base into a list of searchable entries.

        Every str/int/float value stored under a dict key becomes one entry with
        ``path`` (dotted path, list positions as ``[i]``), ``category`` (the
        top-level key it lives under), ``key`` (its own key), ``content`` (the
        value as text) and ``keywords`` (from ``extract_keywords``).
        Scalars stored directly as list elements are not indexed.
        """
        indexed_items = []

        def index_recursive(data, path="", category=""):
            if isinstance(data, dict):
                for key, value in data.items():
                    current_path = f"{path}.{key}" if path else key
                    current_category = category or key
                    if isinstance(value, (str, int, float)):
                        indexed_items.append({
                            'path': current_path,
                            'category': current_category,
                            'key': key,
                            'content': str(value),
                            'keywords': self.extract_keywords(f"{key} {value}")
                        })
                    else:
                        index_recursive(value, current_path, current_category)
            elif isinstance(data, list):
                for i, item in enumerate(data):
                    index_recursive(item, f"{path}[{i}]", category)

        index_recursive(self.knowledge_base)
        return indexed_items

    def extract_keywords(self, text: str) -> List[str]:
        """Return the lowercased words of *text* longer than two characters, minus stop words."""
        words = re.findall(r'\b\w+\b', text.lower())
        return [word for word in words if word not in self._STOP_WORDS and len(word) > 2]

    def get_query_hash(self, query: str) -> str:
        """Return the cache key hash for *query* (case- and surrounding-whitespace-insensitive)."""
        return hashlib.md5(query.lower().strip().encode()).hexdigest()

    def get_cached_response(self, query: str) -> Optional[Dict]:
        """
        Look up *query* in the cache.

        On a hit, the entry's access count and timestamp are bumped and a dict
        with ``context`` (formatted result text), ``scores`` (decoded relevance
        scores) and ``access_count`` (count after this access) is returned.
        Returns None on a miss.
        """
        query_hash = self.get_query_hash(query)
        with closing(self._connect()) as conn:
            cursor = conn.cursor()
            cursor.execute('''
                SELECT response_context, relevance_scores, access_count
                FROM query_cache
                WHERE query_hash = ?
            ''', (query_hash,))
            row = cursor.fetchone()
            if row is None:
                return None
            cursor.execute('''
                UPDATE query_cache
                SET access_count = access_count + 1, timestamp = CURRENT_TIMESTAMP
                WHERE query_hash = ?
            ''', (query_hash,))
            conn.commit()
        return {
            'context': row[0],
            'scores': json.loads(row[1]),
            'access_count': row[2] + 1
        }

    def cache_response(self, query: str, context: str, relevance_scores: List[float]):
        """Insert the cache entry for *query*, or overwrite it if one already exists."""
        query_hash = self.get_query_hash(query)
        scores_json = json.dumps(relevance_scores)
        with closing(self._connect()) as conn:
            cursor = conn.cursor()
            try:
                cursor.execute('''
                    INSERT INTO query_cache (query_hash, query_text, response_context, relevance_scores)
                    VALUES (?, ?, ?, ?)
                ''', (query_hash, query, context, scores_json))
            except sqlite3.IntegrityError:
                # query_hash is UNIQUE: the query is already cached, so refresh it.
                cursor.execute('''
                    UPDATE query_cache
                    SET response_context = ?, relevance_scores = ?, timestamp = CURRENT_TIMESTAMP
                    WHERE query_hash = ?
                ''', (context, scores_json, query_hash))
            conn.commit()

    def _cache_query_text(self, query: str, top_k: int) -> str:
        """
        Build the text under which a search is cached.

        Search results depend on top_k and on the knowledge-base content as well
        as on the query, so both are folded into the key.
        """
        return f"{query.strip()} [top_k={top_k}; kb={self._kb_fingerprint}]"

    def semantic_search(self, query: str, top_k: int = 5) -> List[Dict]:
        """
        Rank index entries against *query* and return the best *top_k*.

        Each result is a copy of an index entry extended with
        ``relevance_score`` and ``matched_keywords``, sorted by descending
        score; entries scoring 0 are dropped. Repeated searches are rebuilt
        from the cache when possible and recomputed otherwise.
        """
        query_keywords = self.extract_keywords(query)
        cache_query = self._cache_query_text(query, top_k)

        cached = self.get_cached_response(cache_query)
        if cached:
            cached_results = self._parse_cached_results(cached, top_k)
            if cached_results:
                for result in cached_results:
                    result['matched_keywords'] = self.get_matched_keywords(query_keywords, result['keywords'])
                return cached_results
            # Entry could not be mapped back onto the index: fall through and recompute.

        scored_results = []
        for item in self.indexed_content:
            score = self.calculate_relevance_score(query_keywords, item)
            if score > 0:
                scored_results.append({
                    **item,
                    'relevance_score': score,
                    'matched_keywords': self.get_matched_keywords(query_keywords, item['keywords'])
                })

        # sort() is stable, so equal scores keep index order.
        scored_results.sort(key=lambda x: x['relevance_score'], reverse=True)
        top_results = scored_results[:top_k]

        if top_results:
            context = self._format_results_for_cache(top_results)
            scores = [r['relevance_score'] for r in top_results]
            self.cache_response(cache_query, context, scores)

        return top_results

    def _parse_cached_results(self, cached: Dict, top_k: int) -> List[Dict]:
        """
        Rebuild search results from a cache entry.

        The cache holds the formatted result lines (see
        ``_format_results_for_cache``) plus one score per result; each line is
        looked up in the current index. Returns an empty list when that fails
        (e.g. content that itself contains blank lines), which callers treat
        as a cache miss. ``matched_keywords`` is not filled in here.
        """
        context = cached.get('context') or ''
        scores = cached.get('scores') or []
        parts = context.split('\n\n') if context else []
        if not parts or len(parts) != len(scores):
            return []

        by_rendering = {}
        for item in self.indexed_content:
            # Identical renderings score identically, so the first one is what a fresh search ranks first.
            by_rendering.setdefault(self._format_results_for_cache([item]), item)

        results = []
        for part, score in zip(parts, scores):
            item = by_rendering.get(part)
            if item is None:
                return []
            results.append({**item, 'relevance_score': score})
        return results[:top_k]

    def _format_results_for_cache(self, results: List[Dict]) -> str:
        """Render results as ``**category - key**: content`` lines separated by blank lines."""
        formatted = []
        for item in results:
            formatted.append(f"**{item['category']} - {item['key']}**: {item['content']}")
        return "\n\n".join(formatted)

    def calculate_relevance_score(self, query_keywords: List[str], item: Dict) -> float:
        """
        Score an index entry against the query keywords.

        +3 per query keyword found among the entry's keywords (else +1.5 if it
        occurs in the key/content text), +1 per substring overlap between a
        query keyword and an entry keyword, multiplied by the matching
        ``_CATEGORY_BOOSTS`` factors, then damped by the entry's keyword count.
        """
        item_keywords = item['keywords']
        item_text = f"{item['key']} {item['content']}".lower()
        score = 0.0

        # Exact keyword matches (higher weight)
        for keyword in query_keywords:
            if keyword in item_keywords:
                score += 3.0
            elif keyword in item_text:
                score += 1.5

        # Partial matches
        for keyword in query_keywords:
            for item_keyword in item_keywords:
                if keyword in item_keyword or item_keyword in keyword:
                    score += 1.0

        # Category and domain-specific boosts
        category = item['category'].lower()
        key = item['key'].lower()
        for boost_term, boost_value in self._CATEGORY_BOOSTS.items():
            if boost_term in category or boost_term in key:
                if any(boost_term in keyword for keyword in query_keywords):
                    score *= boost_value

        # Length normalization to prevent bias toward longer content
        if len(item_keywords) > 0:
            score = score / (1 + len(item_keywords) * 0.05)

        return score

    def get_matched_keywords(self, query_keywords: List[str], item_keywords: List[str]) -> List[str]:
        """Return the query keywords (deduplicated, in query order) that overlap any item keyword."""
        matched = [
            qk for qk in query_keywords
            if any(qk == ik or qk in ik or ik in qk for ik in item_keywords)
        ]
        return list(dict.fromkeys(matched))

    def retrieve_context(self, query: str, max_context_length: int = 1000) -> str:
        """
        Return the top search results for *query* as formatted text.

        Results are added until *max_context_length* characters would be
        exceeded; the first result that does not fit is truncated with "..."
        if more than 100 characters of space remain.
        """
        relevant_items = self.semantic_search(query, top_k=10)

        if not relevant_items:
            return "No specific context found in knowledge base."

        context_parts = []
        total_length = 0

        for item in relevant_items:
            context_part = f"**{item['category']} - {item['key']}**: {item['content']}"

            if total_length + len(context_part) < max_context_length:
                context_parts.append(context_part)
                total_length += len(context_part)
            else:
                # Add truncated version if space allows
                remaining_space = max_context_length - total_length - 20
                if remaining_space > 100:
                    truncated = context_part[:remaining_space] + "..."
                    context_parts.append(truncated)
                break

        return "\n\n".join(context_parts)

    def get_topic_overview(self, topic: str) -> str:
        """Summarize index entries whose category, key or content contains *topic*, grouped by category."""
        needle = topic.lower()
        topic_items = [
            item for item in self.indexed_content
            if (needle in item['category'].lower() or
                needle in item['key'].lower() or
                needle in item['content'].lower())
        ]

        if not topic_items:
            return f"No information found for topic: {topic}"

        # Group by category, preserving first-seen order
        categories = {}
        for item in topic_items:
            categories.setdefault(item['category'], []).append(item)

        overview_parts = []
        for category, items in categories.items():
            overview_parts.append(f"## {category.title().replace('_', ' ')}")
            for item in items[:5]:  # Limit items per category
                content_preview = item['content'][:200]
                if len(item['content']) > 200:
                    content_preview += "..."
                overview_parts.append(f"- **{item['key']}**: {content_preview}")

        return "\n\n".join(overview_parts)

    def suggest_related_topics(self, query: str) -> List[str]:
        """Return up to five title-cased categories (alphabetical) among the top 15 search results."""
        relevant_items = self.semantic_search(query, top_k=15)
        categories = {item['category'].replace('_', ' ').title() for item in relevant_items}
        return sorted(categories)[:5]

    def get_formulas_for_topic(self, topic: str) -> List[str]:
        """
        Return up to ten formulas related to *topic*, without duplicates.

        Entries of the ``formulas`` section whose category name contains
        *topic* come first, followed by formula-like snippets found in the
        content of index entries that mention *topic*.
        """
        needle = topic.lower()
        formulas = []

        # Search in formulas section
        formulas_data = self.knowledge_base.get('formulas')
        if isinstance(formulas_data, dict):
            for category, formulas_dict in formulas_data.items():
                if needle in category.lower() and isinstance(formulas_dict, dict):
                    for formula_name, formula in formulas_dict.items():
                        formulas.append(f"**{formula_name.replace('_', ' ').title()}**: {formula}")

        # Search in general content for formula patterns
        for item in self.indexed_content:
            if needle in item['content'].lower():
                for pattern in self._FORMULA_PATTERNS:
                    for match in pattern.findall(item['content']):
                        if len(match.strip()) > 5:  # Filter out very short matches
                            formulas.append(match.strip())

        # Order-preserving dedupe keeps the result deterministic.
        return list(dict.fromkeys(formulas))[:10]

    def update_knowledge_base(self, new_data: Dict, category: str):
        """
        Merge *new_data* into *category*, rebuild the index and save the file.

        Dict data is merged into an existing dict category (top-level keys of
        *new_data* overwrite); anything else replaces the category. Save errors
        are printed, not raised.
        """
        if category in self.knowledge_base:
            if isinstance(self.knowledge_base[category], dict) and isinstance(new_data, dict):
                self.knowledge_base[category].update(new_data)
            else:
                self.knowledge_base[category] = new_data
        else:
            self.knowledge_base[category] = new_data

        # Recreate search index; the new fingerprint stops stale cache hits.
        self.indexed_content = self.create_search_index()
        self._kb_fingerprint = self._fingerprint_knowledge_base()

        # Save updated knowledge base
        try:
            self._write_knowledge_base(self.knowledge_base)
            print(f"Knowledge base updated successfully in category: {category}")
        except Exception as e:
            print(f"Error saving knowledge base: {e}")

    def get_statistics(self) -> Dict:
        """Return index sizes, per-category entry counts and cache/analytics row counts."""
        stats = {
            'total_entries': len(self.indexed_content),
            'categories': len(set(item['category'] for item in self.indexed_content)),
            'total_keywords': sum(len(item['keywords']) for item in self.indexed_content),
            # NOTE: this is when the statistics were generated, not a KB modification time.
            'last_updated': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        }

        # Category breakdown
        category_counts = {}
        for item in self.indexed_content:
            category = item['category']
            category_counts[category] = category_counts.get(category, 0) + 1

        stats['category_breakdown'] = category_counts

        # Cache statistics
        with closing(self._connect()) as conn:
            cursor = conn.cursor()
            cursor.execute('SELECT COUNT(*) FROM query_cache')
            cached_queries = cursor.fetchone()[0]
            cursor.execute('SELECT COUNT(*) FROM query_analytics')
            analytics_entries = cursor.fetchone()[0]

        stats['cached_queries'] = cached_queries
        stats['analytics_entries'] = analytics_entries

        return stats

    def log_query_analytics(self, query: str, topic_category: str, response_quality: float = 0.0, user_feedback: str = ""):
        """Append one row to the ``query_analytics`` table."""
        with closing(self._connect()) as conn:
            cursor = conn.cursor()
            cursor.execute('''
                INSERT INTO query_analytics (query_text, topic_category, response_quality, user_feedback)
                VALUES (?, ?, ?, ?)
            ''', (query, topic_category, response_quality, user_feedback))
            conn.commit()

    def get_query_analytics(self, days: int = 30) -> pd.DataFrame:
        """Return analytics rows from the last *days* days, newest first, as a DataFrame."""
        # The day offset is bound as a parameter instead of being formatted into the SQL.
        query = '''
            SELECT query_text, topic_category, response_quality, user_feedback, timestamp
            FROM query_analytics
            WHERE timestamp >= datetime('now', ?)
            ORDER BY timestamp DESC
        '''
        with closing(self._connect()) as conn:
            return pd.read_sql_query(query, conn, params=(f'-{days} days',))

    def export_context_report(self, query: str, filename: Optional[str] = None) -> str:
        """
        Write a Markdown report of the search results for *query*.

        Defaults to ``context_report_<timestamp>.md`` in the working directory.
        Returns a status message; write errors are reported in it, not raised.
        """
        if filename is None:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            filename = f"context_report_{timestamp}.md"

        relevant_items = self.semantic_search(query, top_k=20)

        report_content = f"""# Context Report for Query: "{query}"
Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
## Search Results ({len(relevant_items)} items found)
"""

        for i, item in enumerate(relevant_items, 1):
            matched_kw = ', '.join(item.get('matched_keywords', []))
            report_content += f"""### {i}. {item['category'].replace('_', ' ').title()} - {item['key'].replace('_', ' ').title()}
- **Content**: {item['content']}
- **Relevance Score**: {item['relevance_score']:.3f}
- **Matched Keywords**: {matched_kw if matched_kw else 'None'}
- **Full Path**: {item['path']}
"""

        # Add related formulas if available
        formulas = self.get_formulas_for_topic(query)
        if formulas:
            report_content += """## Related Formulas
"""
            for formula in formulas:
                report_content += f"- {formula}\n"

        # Add suggested topics
        related_topics = self.suggest_related_topics(query)
        if related_topics:
            report_content += f"""
## Related Topics
{', '.join(related_topics)}
"""

        # Save report
        try:
            with open(filename, 'w', encoding='utf-8') as f:
                f.write(report_content)
            return f"Context report saved to {filename}"
        except Exception as e:
            return f"Error saving report: {e}"

    def clear_cache(self):
        """Delete all rows from both the query cache and the query analytics tables."""
        with closing(self._connect()) as conn:
            cursor = conn.cursor()
            cursor.execute('DELETE FROM query_cache')
            cursor.execute('DELETE FROM query_analytics')
            conn.commit()
        print("Cache cleared successfully")
# Example usage and testing
def demo_rag_system():
    """
    Demonstrate the Enhanced RAG System end to end and return the instance.

    Uses ``EnhancedRAGSystem()`` with its default paths, runs sample queries,
    exercises overview/formula/cache/report/analytics features, adds a
    "pilot_protection" entry to the knowledge base, and prints everything.
    """
    import time

    print("=== Enhanced RAG System for Power Systems ===\n")

    # Initialize the system
    rag = EnhancedRAGSystem()

    # Display system statistics
    stats = rag.get_statistics()
    print("Knowledge Base Statistics:")
    for key, value in stats.items():
        if key == 'category_breakdown':
            print(f"  {key}:")
            for cat, count in value.items():
                print(f"    {cat.replace('_', ' ').title()}: {count} entries")
        else:
            print(f"  {key}: {value}")
    print()

    # Test queries with different complexity levels
    test_queries = [
        "fault analysis three phase",
        "IEEE standards protection relays",
        "transformer differential protection",
        "short circuit calculation methods",
        "distance protection zones settings",
        "sequence components impedance",
        "overcurrent relay coordination",
        "power system formulas calculations"
    ]

    print("=== Testing Search Capabilities ===\n")

    for i, query in enumerate(test_queries, 1):
        print(f"{i}. Query: '{query}'")

        # Get context
        context = rag.retrieve_context(query, max_context_length=500)
        print(f"   Context Preview: {context[:200]}{'...' if len(context) > 200 else ''}")

        # Get related topics
        related_topics = rag.suggest_related_topics(query)
        print(f"   Related Topics: {', '.join(related_topics)}")

        # Get formulas if available
        formulas = rag.get_formulas_for_topic(query)
        if formulas:
            print(f"   Related Formulas: {formulas[0]}")

        # Log analytics
        rag.log_query_analytics(query, related_topics[0] if related_topics else "general", 0.85)
        print()

    print("=== Advanced Features Demo ===\n")

    # Topic overview
    print("1. Topic Overview for 'protection':")
    overview = rag.get_topic_overview("protection")
    print(overview[:300] + "..." if len(overview) > 300 else overview)
    print()

    # Formula extraction
    print("2. Formulas for 'fault calculations':")
    formulas = rag.get_formulas_for_topic("fault")
    for formula in formulas[:3]:
        print(f"   - {formula}")
    print()

    # Cache demonstration (perf_counter: monotonic, high resolution)
    print("3. Cache Performance Test:")
    test_query = "differential protection applications"

    # First query (no cache)
    start_time = time.perf_counter()
    rag.retrieve_context(test_query)
    time1 = time.perf_counter() - start_time
    print(f"   First query time: {time1:.4f} seconds")

    # Second query (with cache)
    start_time = time.perf_counter()
    rag.retrieve_context(test_query)
    time2 = time.perf_counter() - start_time
    print(f"   Cached query time: {time2:.4f} seconds")
    # Guard: a zero first timing would otherwise raise ZeroDivisionError.
    if time1 > 0:
        print(f"   Speed improvement: {((time1 - time2) / time1 * 100):.1f}%")
    else:
        print("   Speed improvement: n/a (first query too fast to measure)")
    print()

    # Export report
    print("4. Exporting Context Report:")
    report_status = rag.export_context_report("protection systems analysis")
    print(f"   {report_status}")
    print()

    # Analytics summary
    print("5. Query Analytics Summary:")
    try:
        analytics_df = rag.get_query_analytics(days=1)
        if not analytics_df.empty:
            print(f"   Total queries today: {len(analytics_df)}")
            categories = analytics_df['topic_category'].value_counts()
            print(f"   Top categories: {dict(categories.head(3))}")
        else:
            print("   No analytics data available yet")
    except Exception as e:
        print(f"   Analytics error: {e}")

    print("\n=== System Update Demo ===\n")

    # Add new knowledge
    new_protection_data = {
        "pilot_protection": {
            "description": "High-speed protection using communication channels",
            "types": {
                "pilot_wire": "Dedicated metallic circuit communication",
                "microwave": "Radio frequency communication",
                "fiber_optic": "Optical fiber communication",
                "power_line_carrier": "Communication over power lines"
            },
            "advantages": "High speed, secure communication, reliable",
            "applications": "Long transmission lines, critical circuits"
        }
    }

    print("Adding new protection system data...")
    rag.update_knowledge_base(new_protection_data, "protection_systems")

    # Test the new data. The formatted context never contains the phrase
    # "pilot protection" (only entry paths mention pilot_protection), so
    # check the result paths instead.
    pilot_hits = rag.semantic_search("pilot protection communication", top_k=10)
    found_new_data = any('pilot_protection' in item['path'] for item in pilot_hits)
    print(f"New data retrieval test: {'Success' if found_new_data else 'Failed'}")
    print()

    # Final statistics
    final_stats = rag.get_statistics()
    print("Final Statistics:")
    print(f"  Total entries: {final_stats['total_entries']}")
    print(f"  Cached queries: {final_stats['cached_queries']}")
    print(f"  Analytics entries: {final_stats['analytics_entries']}")

    return rag
class RAGSystemInterface:
    """
    Interactive command-line interface for the RAG system.
    """

    def __init__(self, rag_system: "EnhancedRAGSystem"):
        # Forward-reference annotation: no import-time dependency on the class.
        self.rag = rag_system
        self.session_queries = []

    def interactive_session(self):
        """
        Read commands/queries from stdin until 'quit' or end of input.

        End of input (EOFError) ends the session. Ctrl+C at any point prints a
        hint and continues. Errors raised while handling a command are printed
        and the session continues.
        """
        print("\n=== Interactive RAG System Session ===")
        print("Commands:")
        print("  'help' - Show available commands")
        print("  'stats' - Show system statistics")
        print("  'topics' - List main topics")
        print("  'formulas [topic]' - Get formulas for topic")
        print("  'overview [topic]' - Get topic overview")
        print("  'export [query]' - Export context report")
        print("  'clear' - Clear cache")
        print("  'quit' - Exit session")
        print("  Or enter any query for search\n")

        while True:
            try:
                user_input = input("RAG> ").strip()
            except EOFError:
                # No more input (closed or non-interactive stdin). Catching this
                # in the generic handler below would loop forever.
                print()
                break
            except KeyboardInterrupt:
                print("\nSession interrupted. Type 'quit' to exit properly.")
                continue

            if not user_input:
                continue

            try:
                if not self._dispatch(user_input):
                    break
            except KeyboardInterrupt:
                print("\nSession interrupted. Type 'quit' to exit properly.")
            except Exception as e:
                print(f"Error: {e}")

        print("Session ended. Goodbye!")

    def _dispatch(self, user_input: str) -> bool:
        """
        Handle one line of input; return False when the session should end.

        Argument-free commands must be the whole (case-insensitive) input.
        'formulas', 'overview' and 'export' must be the first word, optionally
        followed by an argument. Anything else is a search query.
        """
        user_input = user_input.strip()
        if not user_input:
            return True

        lowered = user_input.lower()
        if lowered == 'quit':
            return False
        if lowered == 'help':
            self.show_help()
        elif lowered == 'stats':
            self.show_stats()
        elif lowered == 'topics':
            self.show_topics()
        elif lowered == 'clear':
            self.clear_cache()
        else:
            # Match on the first word so queries such as "exporting ..." are
            # not mistaken for commands.
            parts = user_input.split(maxsplit=1)
            command = parts[0].lower()
            argument = parts[1].strip() if len(parts) > 1 else ''
            if command == 'formulas':
                self.show_formulas(argument or "fault")
            elif command == 'overview':
                self.show_overview(argument or "protection")
            elif command == 'export':
                self.export_report(argument or "power systems")
            else:
                self.process_query(user_input)
        return True

    def show_help(self):
        """Print detailed help."""
        help_text = """
Available Commands:
Query Search:
  - Enter any natural language query about power systems
  - Example: "How does differential protection work?"
System Commands:
  - stats: Show knowledge base statistics
  - topics: List all available main topics
  - formulas [topic]: Show formulas related to topic (default: fault)
  - overview [topic]: Get comprehensive overview (default: protection)
  - export [query]: Export detailed context report (default: power systems)
  - clear: Clear query cache and analytics
  - quit: Exit the interactive session
Tips:
  - Be specific in queries for better results
  - Use technical terms for more precise matches
  - Try related topic suggestions for exploration
"""
        print(help_text)

    def show_stats(self):
        """Print system statistics."""
        stats = self.rag.get_statistics()
        print("\nSystem Statistics:")
        print("-" * 40)
        for key, value in stats.items():
            if key == 'category_breakdown':
                print(f"{key.replace('_', ' ').title()}:")
                for cat, count in value.items():
                    print(f"  • {cat.replace('_', ' ').title()}: {count}")
            else:
                print(f"{key.replace('_', ' ').title()}: {value}")
        print()

    def show_topics(self):
        """Print the distinct index categories, alphabetically."""
        categories = set(item['category'] for item in self.rag.indexed_content)
        print("\nAvailable Topics:")
        print("-" * 30)
        for i, category in enumerate(sorted(categories), 1):
            print(f"{i:2d}. {category.replace('_', ' ').title()}")
        print()

    def show_formulas(self, topic: str):
        """Print formulas for *topic*."""
        formulas = self.rag.get_formulas_for_topic(topic)
        print(f"\nFormulas for '{topic}':")
        print("-" * 40)
        if formulas:
            for i, formula in enumerate(formulas, 1):
                print(f"{i:2d}. {formula}")
        else:
            print(f"No formulas found for topic '{topic}'")
        print()

    def show_overview(self, topic: str):
        """Print the overview for *topic*."""
        overview = self.rag.get_topic_overview(topic)
        print(f"\nOverview for '{topic}':")
        print("-" * 50)
        print(overview)
        print()

    def export_report(self, query: str):
        """Export a context report for *query* and print the status message."""
        result = self.rag.export_context_report(query)
        print(f"\nExport Result: {result}\n")

    def clear_cache(self):
        """Clear the system cache and analytics."""
        self.rag.clear_cache()
        print("\nCache cleared successfully!\n")

    def process_query(self, query: str):
        """Search for *query*, print results, context and related topics, and log analytics."""
        self.session_queries.append(query)

        print(f"\nQuery: {query}")
        print("=" * 50)

        # Get search results
        results = self.rag.semantic_search(query, top_k=5)

        if not results:
            print("No relevant results found.")
            return

        # Show top results
        print(f"Found {len(results)} relevant results:\n")

        for i, result in enumerate(results, 1):
            print(f"{i}. {result['category'].replace('_', ' ').title()} - {result['key'].replace('_', ' ').title()}")
            print(f"   Score: {result['relevance_score']:.3f}")
            print(f"   Content: {result['content'][:150]}{'...' if len(result['content']) > 150 else ''}")
            if result.get('matched_keywords'):
                print(f"   Keywords: {', '.join(result['matched_keywords'])}")
            print()

        # Show context
        context = self.rag.retrieve_context(query)
        print("Context Summary:")
        print("-" * 20)
        print(context)
        print()

        # Show related topics
        related_topics = self.rag.suggest_related_topics(query)
        if related_topics:
            print(f"Related Topics: {', '.join(related_topics)}")
            print()

        # Log analytics (results is non-empty here thanks to the early return above)
        top_result = results[0]
        self.rag.log_query_analytics(query, top_result['category'], top_result['relevance_score'])
# Main execution
if __name__ == "__main__": | |
print("Enhanced RAG System for Power Systems Knowledge Base") | |
print("=" * 60) | |
# Run demonstration | |
rag_system = demo_rag_system() | |
# Ask user if they want interactive session | |
while True: | |
choice = input("\nWould you like to start an interactive session? (y/n): ").lower().strip() | |
if choice in ['y', 'yes']: | |
interface = RAGSystemInterface(rag_system) | |
interface.interactive_session() | |
break | |
elif choice in ['n', 'no']: | |
print("Thank you for using the Enhanced RAG System!") | |
break | |
else: | |
print("Please enter 'y' for yes or 'n' for no.") | |
print("\nSystem shutdown complete.") |