import base64
import io
import json
import os
import tempfile
from collections import Counter
from typing import Dict, List, Any, Tuple, Optional

import matplotlib
matplotlib.use('Agg') # Use non-interactive backend to avoid GUI issues
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from pyvis.network import Network
class GraphVisualizer:
    """Render a networkx ``DiGraph`` as images, interactive HTML and Markdown.

    Node attributes read (each optional, with a default): ``type``, ``size``,
    ``importance``, ``description``. Edge attributes read: ``relationship``,
    ``description``.

    Methods that produce files return paths to temporary files that are NOT
    deleted automatically; callers are responsible for cleaning them up.
    """

    # Colour used for any node type that has no entry in ``color_map``.
    _DEFAULT_COLOR = '#95A5A6'

    # Layout name -> callable(graph) returning {node: position}. Insertion
    # order is the order reported by ``get_layout_options``.
    _LAYOUTS = {
        "spring": lambda g: nx.spring_layout(g, k=1, iterations=50),
        "circular": nx.circular_layout,
        "shell": nx.shell_layout,
        "kamada_kawai": nx.kamada_kawai_layout,
        "random": nx.random_layout,
    }

    # Static parts of the page emitted by ``create_interactive_html``; the
    # node/edge JSON is spliced in between them. The inline script is wrapped
    # in an IIFE so its variables stay out of the page's global scope.
    _VIS_HTML_HEAD = (
        '<div id="knowledge-graph" style="width: 100%; height: 600px; '
        'border: 1px solid #ddd;"></div>\n'
        '<script src="https://unpkg.com/vis-network/standalone/umd/'
        'vis-network.min.js"></script>\n'
        '<script type="text/javascript">\n'
        '(function () {\n'
    )
    _VIS_HTML_TAIL = (
        '  var container = document.getElementById("knowledge-graph");\n'
        '  var options = {\n'
        '    nodes: { shape: "dot" },\n'
        '    edges: { font: { align: "middle", size: 10 }, color: { color: "gray" } },\n'
        '    physics: { stabilization: { iterations: 200 } },\n'
        '    interaction: { hover: true, tooltipDelay: 200 }\n'
        '  };\n'
        '  new vis.Network(container, { nodes: nodes, edges: edges }, options);\n'
        '})();\n'
        '</script>\n'
    )

    def __init__(self):
        # Entity type -> hex colour used by every renderer.
        self.color_map = {
            'PERSON': '#FF6B6B',
            'ORGANIZATION': '#4ECDC4',
            'LOCATION': '#45B7D1',
            'CONCEPT': '#96CEB4',
            'EVENT': '#FFEAA7',
            'OBJECT': '#DDA0DD',
            'UNKNOWN': '#95A5A6'
        }

    # ------------------------------------------------------------------ helpers

    def _color_for(self, node_type: str) -> str:
        """Return the colour for *node_type*, falling back to the default grey."""
        return self.color_map.get(node_type, self._DEFAULT_COLOR)

    @staticmethod
    def _new_temp_path(suffix: str) -> str:
        """Create a persistent empty temp file and return its path.

        The OS handle is closed immediately so that writers which reopen the
        path by name (matplotlib, pyvis) work on every platform, including
        Windows, and no file descriptor is leaked.
        """
        fd, path = tempfile.mkstemp(suffix=suffix)
        os.close(fd)
        return path

    @staticmethod
    def _json_for_script(obj: Any) -> str:
        """Serialise *obj* to JSON that is safe to inline inside ``<script>``.

        ``<``, ``>`` and ``&`` can only occur inside JSON strings, where the
        ``\\uXXXX`` escapes are equivalent. Escaping them prevents node text
        such as ``</script>`` from terminating the script element.
        """
        text = json.dumps(obj)
        return (text.replace('<', '\\u003c')
                    .replace('>', '\\u003e')
                    .replace('&', '\\u0026'))

    # ------------------------------------------------------------- matplotlib

    def visualize_graph(self,
                        graph: nx.DiGraph,
                        layout_type: str = "spring",
                        show_labels: bool = True,
                        show_edge_labels: bool = False,
                        node_size_factor: float = 1.0,
                        figsize: Tuple[int, int] = (12, 8)) -> str:
        """Draw *graph* with matplotlib and return the path of a PNG file.

        Args:
            graph: Graph to draw; an empty graph yields a placeholder image.
            layout_type: One of ``get_layout_options()``; unknown names fall
                back to the spring layout.
            show_labels: Label nodes with their name and ``importance``.
            show_edge_labels: Label edges with their ``relationship``.
            node_size_factor: Multiplier applied to each node's ``size``.
            figsize: Matplotlib figure size in inches.

        Returns:
            Path to a temporary PNG file (not deleted automatically).
        """
        if not graph.nodes():
            return self._create_empty_graph_image()

        pos = self._calculate_layout(graph, layout_type)
        node_colors = [self._color_for(attrs.get('type', 'UNKNOWN'))
                       for _, attrs in graph.nodes(data=True)]
        node_sizes = [attrs.get('size', 20) * node_size_factor * 10
                      for _, attrs in graph.nodes(data=True)]

        fig, ax = plt.subplots(figsize=figsize)
        try:
            nx.draw_networkx_nodes(graph, pos, ax=ax,
                                   node_color=node_colors,
                                   node_size=node_sizes,
                                   alpha=0.8)
            # Pass the real node sizes so arrowheads end at each node's rim
            # instead of being positioned for networkx's default marker size.
            nx.draw_networkx_edges(graph, pos, ax=ax,
                                   node_size=node_sizes,
                                   edge_color='gray',
                                   arrows=True,
                                   arrowsize=20,
                                   alpha=0.6,
                                   width=1.5)
            if show_labels:
                # Each label shows the node name plus its importance score.
                labels = {node: f"{node}\n({attrs.get('importance', 0.0):.2f})"
                          for node, attrs in graph.nodes(data=True)}
                nx.draw_networkx_labels(graph, pos, labels=labels,
                                        font_size=8, ax=ax)
            if show_edge_labels:
                edge_labels = {(u, v): attrs.get('relationship', '')
                               for u, v, attrs in graph.edges(data=True)}
                nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels,
                                             font_size=6, ax=ax)
            ax.set_title("Knowledge Graph", fontsize=16, fontweight='bold')
            ax.axis('off')
            fig.tight_layout()
            path = self._new_temp_path('.png')
            fig.savefig(path, format='png', dpi=150, bbox_inches='tight')
        finally:
            # Always release the figure, even if drawing or saving fails.
            plt.close(fig)
        return path

    def _calculate_layout(self, graph: nx.DiGraph, layout_type: str) -> Dict[str, Tuple[float, float]]:
        """Return node positions computed by the layout named *layout_type*.

        Unknown layout names use the spring layout. If the chosen algorithm
        raises (e.g. an optional dependency it needs is missing), the spring
        layout is used instead.
        """
        spring = self._LAYOUTS["spring"]
        try:
            return self._LAYOUTS.get(layout_type, spring)(graph)
        except Exception:
            return spring(graph)

    def _create_empty_graph_image(self) -> str:
        """Render a placeholder PNG for an empty graph and return its path."""
        fig, ax = plt.subplots(figsize=(8, 6))
        try:
            ax.text(0.5, 0.5, 'No graph data to display',
                    horizontalalignment='center', verticalalignment='center',
                    fontsize=16, transform=ax.transAxes)
            ax.axis('off')
            path = self._new_temp_path('.png')
            fig.savefig(path, format='png', dpi=150, bbox_inches='tight')
        finally:
            plt.close(fig)
        return path

    # ------------------------------------------------------------------ vis.js

    def create_interactive_html(self, graph: nx.DiGraph) -> str:
        """Return an HTML snippet that renders *graph* with vis.js.

        The snippet loads vis-network from a CDN and embeds the node/edge data
        as escaped JSON. An empty graph yields a short placeholder ``<div>``.
        """
        if not graph.nodes():
            return "<div>No graph data to display</div>"

        nodes = [
            {
                "id": node,
                "label": str(node),
                "color": self._color_for(attrs.get('type', 'UNKNOWN')),
                "size": attrs.get('size', 20),
                "title": (f"Type: {attrs.get('type', 'UNKNOWN')}\n"
                          f"Importance: {attrs.get('importance', 0.0):.2f}\n"
                          f"Description: {attrs.get('description', 'N/A')}"),
            }
            for node, attrs in graph.nodes(data=True)
        ]
        edges = [
            {
                "from": u,
                "to": v,
                "label": attrs.get('relationship', ''),
                "title": attrs.get('description', ''),
                "arrows": {"to": {"enabled": True}},
            }
            for u, v, attrs in graph.edges(data=True)
        ]
        # Plain concatenation (not an f-string/format) so JS braces need no
        # escaping and data can never be re-interpreted as a placeholder.
        return (self._VIS_HTML_HEAD
                + "  var nodes = new vis.DataSet(" + self._json_for_script(nodes) + ");\n"
                + "  var edges = new vis.DataSet(" + self._json_for_script(edges) + ");\n"
                + self._VIS_HTML_TAIL)

    # -------------------------------------------------------------- markdown

    def create_statistics_summary(self, graph: nx.DiGraph, stats: Dict[str, Any]) -> str:
        """Return a Markdown summary of *stats* plus type/relationship tallies.

        *stats* must provide ``num_nodes``, ``num_edges``, ``density``,
        ``is_connected``, ``num_components`` and ``avg_degree``.
        """
        if not graph.nodes():
            return "No graph statistics available."

        type_counts = Counter(attrs.get('type', 'UNKNOWN')
                              for _, attrs in graph.nodes(data=True))
        rel_counts = Counter(attrs.get('relationship', 'unknown')
                             for _, _, attrs in graph.edges(data=True))

        # Blank lines before each bold section header are required: without
        # them Markdown treats the header as a continuation of the last bullet.
        lines = [
            "## Graph Statistics",
            "",
            "**Basic Metrics:**",
            f"- Nodes: {stats['num_nodes']}",
            f"- Edges: {stats['num_edges']}",
            f"- Density: {stats['density']:.3f}",
            f"- Connected: {'Yes' if stats['is_connected'] else 'No'}",
            f"- Components: {stats['num_components']}",
            f"- Average Degree: {stats['avg_degree']:.2f}",
            "",
            "**Entity Types:**",
        ]
        lines.extend(f"- {entity_type}: {count}"
                     for entity_type, count in sorted(type_counts.items()))
        lines += ["", "**Relationship Types:**"]
        lines.extend(f"- {rel_type}: {count}"
                     for rel_type, count in sorted(rel_counts.items()))
        return "\n".join(lines)

    def create_entity_list(self, graph: nx.DiGraph, sort_by: str = "importance") -> str:
        """Return a Markdown list of every node with its key attributes.

        Args:
            sort_by: ``"importance"`` or ``"connections"`` (both descending) or
                ``"name"`` (ascending); any other value keeps graph order.
        """
        if not graph.nodes():
            return "No entities found."

        entities = [
            {
                'name': node,
                'type': attrs.get('type', 'UNKNOWN'),
                'importance': attrs.get('importance', 0.0),
                'description': attrs.get('description', 'N/A'),
                'connections': graph.degree(node),
            }
            for node, attrs in graph.nodes(data=True)
        ]

        if sort_by == "importance":
            entities.sort(key=lambda e: e['importance'], reverse=True)
        elif sort_by == "connections":
            entities.sort(key=lambda e: e['connections'], reverse=True)
        elif sort_by == "name":
            entities.sort(key=lambda e: e['name'])

        parts = ["## Entities\n\n"]
        for entity in entities:
            parts.append(
                f"\n**{entity['name']}** ({entity['type']})\n"
                f"- Importance: {entity['importance']:.2f}\n"
                f"- Connections: {entity['connections']}\n"
                f"- Description: {entity['description']}\n"
            )
        return "".join(parts)

    def get_layout_options(self) -> List[str]:
        """Return the layout names accepted by the matplotlib/Plotly renderers."""
        return list(self._LAYOUTS)

    def get_entity_types(self, graph: nx.DiGraph) -> List[str]:
        """Return the sorted, de-duplicated node ``type`` values of *graph*."""
        return sorted({attrs.get('type', 'UNKNOWN')
                       for _, attrs in graph.nodes(data=True)})

    # ----------------------------------------------------------------- plotly

    def create_plotly_interactive(self, graph: nx.DiGraph, layout_type: str = "spring") -> go.Figure:
        """Return a Plotly figure of *graph* with hover details on each node.

        An empty graph yields a figure containing only a centred message.
        """
        if not graph.nodes():
            fig = go.Figure()
            fig.add_annotation(
                text="No graph data to display",
                xref="paper", yref="paper",
                x=0.5, y=0.5, xanchor='center', yanchor='middle',
                showarrow=False, font=dict(size=16)
            )
            return fig

        pos = self._calculate_layout(graph, layout_type)

        node_x, node_y = [], []
        node_text, node_info = [], []
        node_colors, node_sizes = [], []
        for node, attrs in graph.nodes(data=True):
            x, y = pos[node]
            node_x.append(x)
            node_y.append(y)
            node_type = attrs.get('type', 'UNKNOWN')
            node_text.append(node)
            # Plotly hover text uses <br> (not "\n") for line breaks.
            node_info.append(
                f"<b>{node}</b><br>"
                f"Type: {node_type}<br>"
                f"Importance: {attrs.get('importance', 0.0):.2f}<br>"
                f"Connections: {graph.degree(node)}<br>"
                f"Description: {attrs.get('description', 'N/A')}"
            )
            node_colors.append(self._color_for(node_type))
            node_sizes.append(max(10, attrs.get('size', 20)))

        # All edges go into one line trace; a None between segments tells
        # Plotly to lift the pen so edges are not joined to each other.
        edge_x, edge_y = [], []
        for u, v in graph.edges():
            x0, y0 = pos[u]
            x1, y1 = pos[v]
            edge_x.extend([x0, x1, None])
            edge_y.extend([y0, y1, None])

        edge_trace = go.Scatter(
            x=edge_x, y=edge_y,
            line=dict(width=2, color='gray'),
            hoverinfo='none',
            mode='lines'
        )
        node_trace = go.Scatter(
            x=node_x, y=node_y,
            mode='markers+text',
            hoverinfo='text',
            text=node_text,
            hovertext=node_info,
            textposition="middle center",
            marker=dict(
                size=node_sizes,
                color=node_colors,
                line=dict(width=2, color='white')
            )
        )

        fig = go.Figure(
            data=[edge_trace, node_trace],
            layout=go.Layout(
                # ``titlefont`` is deprecated and no longer accepted by current
                # Plotly; ``title.font`` is the supported spelling.
                title=dict(text='Interactive Knowledge Graph', font=dict(size=16)),
                showlegend=False,
                hovermode='closest',
                margin=dict(b=20, l=5, r=5, t=40),
                annotations=[dict(
                    text="Hover over nodes for details. Drag to pan, scroll to zoom.",
                    showarrow=False,
                    xref="paper", yref="paper",
                    x=0.005, y=-0.002,
                    xanchor='left', yanchor='bottom',
                    font=dict(color="gray", size=12)
                )],
                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                plot_bgcolor='white'
            )
        )
        return fig

    # ------------------------------------------------------------------ pyvis

    def create_pyvis_interactive(self, graph: nx.DiGraph, layout_type: str = "spring") -> str:
        """Render *graph* with pyvis and return the path of an HTML file.

        ``layout_type`` is accepted for parity with the other renderers but is
        not used: pyvis positions nodes with the physics options set below.

        Returns:
            Path to a temporary HTML file (not deleted automatically).
        """
        if not graph.nodes():
            return self._create_empty_pyvis_graph()

        net = Network(height="600px", width="100%", bgcolor="#ffffff", font_color="black")
        net.set_options("""
        {
            "physics": {
                "enabled": true,
                "stabilization": {"enabled": true, "iterations": 200},
                "barnesHut": {
                    "gravitationalConstant": -2000,
                    "centralGravity": 0.3,
                    "springLength": 95,
                    "springConstant": 0.04,
                    "damping": 0.09
                }
            },
            "interaction": {
                "hover": true,
                "tooltipDelay": 200,
                "hideEdgesOnDrag": false
            }
        }
        """)

        for node, attrs in graph.nodes(data=True):
            node_type = attrs.get('type', 'UNKNOWN')
            title = "\n".join([
                str(node),
                f"Type: {node_type}",
                f"Importance: {attrs.get('importance', 0.0):.2f}",
                f"Connections: {graph.degree(node)}",
                f"Description: {attrs.get('description', 'N/A')}",
            ])
            net.add_node(node, label=str(node), title=title,
                         color=self._color_for(node_type),
                         size=max(10, attrs.get('size', 20)))

        for u, v, attrs in graph.edges(data=True):
            relationship = attrs.get('relationship', 'connected')
            title = f"{u} → {v}\nRelationship: {relationship}"
            net.add_edge(u, v, title=title, arrows="to", color="gray")

        path = self._new_temp_path('.html')
        net.save_graph(path)
        return path

    def _create_empty_pyvis_graph(self) -> str:
        """Write a single-node placeholder pyvis page and return its path."""
        net = Network(height="600px", width="100%", bgcolor="#ffffff", font_color="black")
        net.add_node(1, label="No graph data", color="#cccccc")
        path = self._new_temp_path('.html')
        net.save_graph(path)
        return path

    def get_visualization_options(self) -> List[str]:
        """Return the names of the supported visualization back-ends."""
        return ["matplotlib", "plotly", "pyvis", "vis.js"]

    def get_relationship_types(self, graph: nx.DiGraph) -> List[str]:
        """Return the sorted, de-duplicated edge ``relationship`` values."""
        return sorted({attrs.get('relationship', 'unknown')
                       for _, _, attrs in graph.edges(data=True)})