import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend to avoid GUI issues
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from collections import Counter
from typing import Dict, List, Any, Tuple, Optional
import json
import io
import base64
import tempfile
import os
import plotly.graph_objects as go
import plotly.express as px
from pyvis.network import Network


class GraphVisualizer:
    """Render a knowledge graph (``networkx.DiGraph``) in several formats.

    Node attributes read by the renderers: ``type`` (looked up in
    ``color_map``), ``size``, ``importance`` and ``description``.
    Edge attributes read: ``relationship`` and ``description``.
    Missing attributes fall back to the defaults used in each method.
    """

    # Colour used for node types that are not present in ``color_map``.
    _FALLBACK_COLOR = '#95A5A6'

    def __init__(self):
        # Node fill colour per entity type.
        self.color_map = {
            'PERSON': '#FF6B6B',
            'ORGANIZATION': '#4ECDC4',
            'LOCATION': '#45B7D1',
            'CONCEPT': '#96CEB4',
            'EVENT': '#FFEAA7',
            'OBJECT': '#DDA0DD',
            'UNKNOWN': '#95A5A6',
        }

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------
    def _node_color(self, node_type: str) -> str:
        """Return the fill colour for ``node_type`` (fallback grey if unknown)."""
        return self.color_map.get(node_type, self._FALLBACK_COLOR)

    @staticmethod
    def _new_temp_path(suffix: str) -> str:
        """Create a persistent temporary file and return its path.

        The handle is closed before returning so that nothing leaks and the
        file can be reopened by name by matplotlib / pyvis (an open
        ``NamedTemporaryFile`` cannot be reopened on Windows). The caller
        owns the file and is responsible for deleting it.
        """
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            return tmp.name

    def _save_figure(self, fig) -> str:
        """Save ``fig`` as a 150-dpi PNG in a new temp file and return its path."""
        path = self._new_temp_path('.png')
        fig.savefig(path, format='png', dpi=150, bbox_inches='tight')
        return path

    @staticmethod
    def _to_script_json(obj: Any) -> str:
        """Serialise ``obj`` to JSON that is safe to inline in a <script> tag.

        Every ``<`` is written as the JSON escape ``\\u003c`` so that text
        such as ``</script>`` inside a node name cannot terminate the script
        element; JavaScript decodes it back to ``<`` at runtime.
        """
        return json.dumps(obj).replace("<", "\\u003c")

    # ------------------------------------------------------------------
    # Static (matplotlib) rendering
    # ------------------------------------------------------------------
    def visualize_graph(self, graph: nx.DiGraph, layout_type: str = "spring",
                        show_labels: bool = True, show_edge_labels: bool = False,
                        node_size_factor: float = 1.0,
                        figsize: Tuple[int, int] = (12, 8)) -> str:
        """Draw ``graph`` with matplotlib and return the path of a PNG file.

        Args:
            graph: Graph to draw.
            layout_type: One of ``get_layout_options()``; unknown names use
                the spring layout.
            show_labels: Label each node with its name and importance score.
            show_edge_labels: Label each edge with its ``relationship``.
            node_size_factor: Multiplier applied to each node's ``size``.
            figsize: Matplotlib figure size in inches.

        Returns:
            Path to a temporary PNG file (not deleted automatically). An empty
            graph yields a placeholder image.
        """
        if not graph.nodes():
            return self._create_empty_graph_image()

        fig = plt.figure(figsize=figsize)
        try:
            pos = self._calculate_layout(graph, layout_type)

            node_colors = [self._node_color(data.get('type', 'UNKNOWN'))
                           for _, data in graph.nodes(data=True)]
            node_sizes = [data.get('size', 20) * node_size_factor * 10
                          for _, data in graph.nodes(data=True)]

            nx.draw_networkx_nodes(graph, pos, node_color=node_colors,
                                   node_size=node_sizes, alpha=0.8)
            nx.draw_networkx_edges(graph, pos, edge_color='gray', arrows=True,
                                   arrowsize=20, alpha=0.6, width=1.5)

            if show_labels:
                # Node name with its importance score on a second line.
                labels = {node: f"{node}\n({data.get('importance', 0.0):.2f})"
                          for node, data in graph.nodes(data=True)}
                nx.draw_networkx_labels(graph, pos, labels, font_size=8)

            if show_edge_labels:
                edge_labels = {(u, v): data.get('relationship', '')
                               for u, v, data in graph.edges(data=True)}
                nx.draw_networkx_edge_labels(graph, pos, edge_labels, font_size=6)

            plt.title("Knowledge Graph", fontsize=16, fontweight='bold')
            plt.axis('off')
            plt.tight_layout()
            return self._save_figure(fig)
        finally:
            # Release the figure even if layout/drawing raised, so repeated
            # calls in a long-running process do not accumulate open figures.
            plt.close(fig)

    def _calculate_layout(self, graph: nx.DiGraph, layout_type: str) -> Dict[str, Tuple[float, float]]:
        """Compute node positions with the named networkx layout algorithm.

        Unknown ``layout_type`` values, and any layout that raises, fall back
        to ``spring_layout(k=1, iterations=50)``.
        """
        def spring(g):
            return nx.spring_layout(g, k=1, iterations=50)

        layouts = {
            "spring": spring,
            "circular": nx.circular_layout,
            "shell": nx.shell_layout,
            "kamada_kawai": nx.kamada_kawai_layout,
            "random": nx.random_layout,
        }
        try:
            return layouts.get(layout_type, spring)(graph)
        except Exception:
            # A layout may fail (e.g. a missing optional dependency); degrade
            # to the spring layout instead of failing the whole render.
            # Exception (not a bare except) lets KeyboardInterrupt propagate.
            return spring(graph)

    def _create_empty_graph_image(self) -> str:
        """Render a 'No graph data to display' placeholder PNG; return its path."""
        fig = plt.figure(figsize=(8, 6))
        try:
            plt.text(0.5, 0.5, 'No graph data to display',
                     horizontalalignment='center', verticalalignment='center',
                     fontsize=16, transform=plt.gca().transAxes)
            plt.axis('off')
            return self._save_figure(fig)
        finally:
            plt.close(fig)

    # ------------------------------------------------------------------
    # vis.js HTML
    # ------------------------------------------------------------------
    def create_interactive_html(self, graph: nx.DiGraph) -> str:
        """Build a self-contained vis.js HTML snippet for ``graph``.

        Nodes carry their colour, size and a Type/Importance/Description
        tooltip; edges are directed and labelled with their ``relationship``.

        NOTE(review): the original template markup could not be recovered
        from the reviewed source; this snippet (including the vis-network
        CDN URL) is a reconstruction — confirm against the intended page.
        """
        if not graph.nodes():
            return "<div>No graph data to display</div>"

        nodes = [
            {
                "id": node,
                "label": node,
                "color": self._node_color(data.get('type', 'UNKNOWN')),
                "size": data.get('size', 20),
                "title": (f"Type: {data.get('type', 'UNKNOWN')}<br>"
                          f"Importance: {data.get('importance', 0.0):.2f}<br>"
                          f"Description: {data.get('description', 'N/A')}"),
            }
            for node, data in graph.nodes(data=True)
        ]
        edges = [
            {
                "from": u,
                "to": v,
                "label": data.get('relationship', ''),
                "title": data.get('description', ''),
                "arrows": {"to": {"enabled": True}},
            }
            for u, v, data in graph.edges(data=True)
        ]

        # Escaped JSON: node names come from graph data and must not be able
        # to close the <script> element.
        nodes_json = self._to_script_json(nodes)
        edges_json = self._to_script_json(edges)

        html_template = f"""
<div id="knowledge-graph" style="width: 100%; height: 600px; border: 1px solid #ddd;"></div>
<script src="https://unpkg.com/vis-network/standalone/umd/vis-network.min.js"></script>
<script>
  (function () {{
    var container = document.getElementById("knowledge-graph");
    var data = {{
      nodes: new vis.DataSet({nodes_json}),
      edges: new vis.DataSet({edges_json})
    }};
    var options = {{
      nodes: {{ shape: "dot", font: {{ size: 14 }} }},
      edges: {{ font: {{ size: 10, align: "middle" }}, color: {{ color: "gray" }} }},
      physics: {{ stabilization: {{ iterations: 200 }} }},
      interaction: {{ hover: true, tooltipDelay: 200 }}
    }};
    new vis.Network(container, data, options);
  }})();
</script>
"""
        return html_template

    # ------------------------------------------------------------------
    # Markdown summaries
    # ------------------------------------------------------------------
    def create_statistics_summary(self, graph: nx.DiGraph, stats: Dict[str, Any]) -> str:
        """Return a Markdown summary of ``stats`` plus type/relationship counts.

        ``stats`` must provide ``num_nodes``, ``num_edges``, ``density``,
        ``is_connected``, ``num_components`` and ``avg_degree``; a missing
        key raises ``KeyError``.
        """
        if not graph.nodes():
            return "No graph statistics available."

        type_counts = Counter(data.get('type', 'UNKNOWN')
                              for _, data in graph.nodes(data=True))
        rel_counts = Counter(data.get('relationship', 'unknown')
                             for _, _, data in graph.edges(data=True))

        lines = [
            "",
            "## Graph Statistics",
            "",
            "**Basic Metrics:**",
            f"- Nodes: {stats['num_nodes']}",
            f"- Edges: {stats['num_edges']}",
            f"- Density: {stats['density']:.3f}",
            f"- Connected: {'Yes' if stats['is_connected'] else 'No'}",
            f"- Components: {stats['num_components']}",
            f"- Average Degree: {stats['avg_degree']:.2f}",
            "",
            "**Entity Types:**",
        ]
        lines.extend(f"- {entity_type}: {count}"
                     for entity_type, count in sorted(type_counts.items()))
        lines.extend(["", "**Relationship Types:**"])
        lines.extend(f"- {rel_type}: {count}"
                     for rel_type, count in sorted(rel_counts.items()))
        return "\n".join(lines)

    def create_entity_list(self, graph: nx.DiGraph, sort_by: str = "importance") -> str:
        """Return a Markdown list of entities.

        Args:
            graph: Source graph.
            sort_by: ``"importance"`` or ``"connections"`` (both descending)
                or ``"name"`` (ascending); any other value keeps graph order.
        """
        if not graph.nodes():
            return "No entities found."

        entities = [
            {
                'name': node,
                'type': data.get('type', 'UNKNOWN'),
                'importance': data.get('importance', 0.0),
                'description': data.get('description', 'N/A'),
                'connections': graph.degree(node),  # in-degree + out-degree
            }
            for node, data in graph.nodes(data=True)
        ]

        # sort_by -> (key function, descending?)
        orderings = {
            "importance": (lambda e: e['importance'], True),
            "connections": (lambda e: e['connections'], True),
            "name": (lambda e: e['name'], False),
        }
        if sort_by in orderings:
            key, descending = orderings[sort_by]
            entities.sort(key=key, reverse=descending)

        sections = [
            f"**{e['name']}** ({e['type']})\n"
            f"- Importance: {e['importance']:.2f}\n"
            f"- Connections: {e['connections']}\n"
            f"- Description: {e['description']}\n"
            for e in entities
        ]
        return "## Entities\n\n" + "\n".join(sections)

    def get_layout_options(self) -> List[str]:
        """Return the layout names understood by ``_calculate_layout``."""
        return ["spring", "circular", "shell", "kamada_kawai", "random"]

    def get_entity_types(self, graph: nx.DiGraph) -> List[str]:
        """Return the sorted distinct node ``type`` values (default ``UNKNOWN``)."""
        return sorted({data.get('type', 'UNKNOWN') for _, data in graph.nodes(data=True)})

    # ------------------------------------------------------------------
    # Plotly
    # ------------------------------------------------------------------
    def create_plotly_interactive(self, graph: nx.DiGraph, layout_type: str = "spring") -> go.Figure:
        """Build an interactive Plotly figure of ``graph``.

        Nodes are drawn as coloured markers with name labels and a hover
        tooltip; edges are drawn as plain grey line segments (no hover).
        An empty graph yields a figure containing only a notice annotation.
        """
        if not graph.nodes():
            fig = go.Figure()
            fig.add_annotation(
                text="No graph data to display",
                xref="paper", yref="paper",
                x=0.5, y=0.5, xanchor='center', yanchor='middle',
                showarrow=False, font=dict(size=16)
            )
            return fig

        pos = self._calculate_layout(graph, layout_type)

        node_x, node_y = [], []
        node_text, node_info = [], []
        node_colors, node_sizes = [], []
        for node, data in graph.nodes(data=True):
            x, y = pos[node]
            node_x.append(x)
            node_y.append(y)

            node_type = data.get('type', 'UNKNOWN')
            node_text.append(node)
            # Plotly hover text uses <br> for line breaks.
            node_info.append(
                f"{node}<br>"
                f"Type: {node_type}<br>"
                f"Importance: {data.get('importance', 0.0):.2f}<br>"
                f"Connections: {graph.degree(node)}<br>"
                f"Description: {data.get('description', 'N/A')}"
            )
            node_colors.append(self._node_color(node_type))
            node_sizes.append(max(10, data.get('size', 20)))  # keep tiny nodes visible

        # All edges go into one trace; None breaks the line between segments.
        edge_x, edge_y = [], []
        for u, v in graph.edges():
            x0, y0 = pos[u]
            x1, y1 = pos[v]
            edge_x.extend([x0, x1, None])
            edge_y.extend([y0, y1, None])

        edge_trace = go.Scatter(
            x=edge_x, y=edge_y,
            line=dict(width=2, color='gray'),
            hoverinfo='none',
            mode='lines'
        )

        node_trace = go.Scatter(
            x=node_x, y=node_y,
            mode='markers+text',
            hoverinfo='text',
            text=node_text,
            hovertext=node_info,
            textposition="middle center",
            marker=dict(
                size=node_sizes,
                color=node_colors,
                line=dict(width=2, color='white')
            )
        )

        fig = go.Figure(
            data=[edge_trace, node_trace],
            layout=go.Layout(
                # title.font replaces the removed `titlefont` layout attribute.
                title=dict(text='Interactive Knowledge Graph', font=dict(size=16)),
                showlegend=False,
                hovermode='closest',
                margin=dict(b=20, l=5, r=5, t=40),
                annotations=[
                    dict(
                        text="Hover over nodes for details. Drag to pan, scroll to zoom.",
                        showarrow=False,
                        xref="paper", yref="paper",
                        x=0.005, y=-0.002,
                        xanchor='left', yanchor='bottom',
                        font=dict(color="gray", size=12)
                    )
                ],
                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                plot_bgcolor='white'
            )
        )
        return fig

    # ------------------------------------------------------------------
    # pyvis
    # ------------------------------------------------------------------
    def create_pyvis_interactive(self, graph: nx.DiGraph, layout_type: str = "spring") -> str:
        """Render ``graph`` with pyvis and return the path of an HTML file.

        Node placement is left to the vis.js physics simulation configured
        below; ``layout_type`` is accepted for signature parity with the other
        renderers but is not used here.

        Returns:
            Path to a temporary ``.html`` file (not deleted automatically).
        """
        if not graph.nodes():
            return self._create_empty_pyvis_graph()

        net = Network(height="600px", width="100%", bgcolor="#ffffff", font_color="black")

        net.set_options("""
        {
          "physics": {
            "enabled": true,
            "stabilization": {"enabled": true, "iterations": 200},
            "barnesHut": {
              "gravitationalConstant": -2000,
              "centralGravity": 0.3,
              "springLength": 95,
              "springConstant": 0.04,
              "damping": 0.09
            }
          },
          "interaction": {
            "hover": true,
            "tooltipDelay": 200,
            "hideEdgesOnDrag": false
          }
        }
        """)

        for node, data in graph.nodes(data=True):
            node_type = data.get('type', 'UNKNOWN')
            title = (
                f"{node}<br>"
                f"Type: {node_type}<br>"
                f"Importance: {data.get('importance', 0.0):.2f}<br>"
                f"Connections: {graph.degree(node)}<br>"
                f"Description: {data.get('description', 'N/A')}"
            )
            net.add_node(node, label=node, title=title,
                         color=self._node_color(node_type),
                         size=max(10, data.get('size', 20)))

        for u, v, data in graph.edges(data=True):
            relationship = data.get('relationship', 'connected')
            net.add_edge(u, v, title=f"{u} → {v}<br>Relationship: {relationship}",
                         arrows="to", color="gray")

        # Reserve the path with the handle already closed, then let pyvis
        # write the file by name.
        path = self._new_temp_path('.html')
        net.save_graph(path)
        return path

    def _create_empty_pyvis_graph(self) -> str:
        """Write a one-node 'No graph data' pyvis page and return its path."""
        net = Network(height="600px", width="100%", bgcolor="#ffffff", font_color="black")
        net.add_node(1, label="No graph data", color="#cccccc")
        path = self._new_temp_path('.html')
        net.save_graph(path)
        return path

    def get_visualization_options(self) -> List[str]:
        """Return the names of the supported visualisation back-ends."""
        return ["matplotlib", "plotly", "pyvis", "vis.js"]

    def get_relationship_types(self, graph: nx.DiGraph) -> List[str]:
        """Return the sorted distinct edge ``relationship`` values (default ``unknown``)."""
        return sorted({data.get('relationship', 'unknown') for _, _, data in graph.edges(data=True)})