import base64
import io
import json
import os
import tempfile
from collections import Counter
from typing import Dict, List, Any, Tuple, Optional

import matplotlib

matplotlib.use('Agg')  # must run before pyplot is imported so the backend is fixed

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from pyvis.network import Network


class GraphVisualizer:
    """Render a directed knowledge graph as images, HTML, Plotly figures or Markdown.

    The renderers read these optional node attributes: ``type``, ``size``,
    ``importance`` and ``description``. They read these optional edge
    attributes: ``relationship`` and ``description``. A missing attribute
    falls back to the default shown in the method that reads it.
    """

    # Colour used for any entity type that has no entry in ``color_map``.
    DEFAULT_COLOR = '#95A5A6'

    def __init__(self):
        # Fill colour per entity type; looked up with DEFAULT_COLOR as fallback.
        self.color_map = {
            'PERSON': '#FF6B6B',
            'ORGANIZATION': '#4ECDC4',
            'LOCATION': '#45B7D1',
            'CONCEPT': '#96CEB4',
            'EVENT': '#FFEAA7',
            'OBJECT': '#DDA0DD',
            'UNKNOWN': '#95A5A6',
        }

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _color_for(self, node_type: Any) -> str:
        """Return the fill colour for ``node_type`` (grey when unmapped)."""
        return self.color_map.get(node_type, self.DEFAULT_COLOR)

    @staticmethod
    def _reserve_temp_path(suffix: str) -> str:
        """Create an empty temporary file and return its path, handle closed.

        The descriptor is closed immediately so that the library writing the
        real content can open the path itself without a second handle being
        held on it.
        """
        fd, path = tempfile.mkstemp(suffix=suffix)
        os.close(fd)
        return path

    def _save_figure(self, fig) -> str:
        """Write ``fig`` to a temporary PNG (150 dpi, tight box) and return the path."""
        path = self._reserve_temp_path('.png')
        try:
            fig.savefig(path, format='png', dpi=150, bbox_inches='tight')
        except Exception:
            # Do not leave an empty placeholder file behind on failure.
            os.unlink(path)
            raise
        return path

    def _save_network(self, net) -> str:
        """Write a pyvis ``Network`` to a temporary ``.html`` file and return the path."""
        path = self._reserve_temp_path('.html')
        net.save_graph(path)
        return path

    @staticmethod
    def _script_safe_json(value: Any) -> str:
        """Serialise ``value`` as JSON that is safe inside a ``<script>`` element.

        ``<`` is emitted as ``\\u003c`` so that graph text such as
        ``</script>`` cannot end the script block early. JavaScript decodes
        the escape back to ``<`` when it evaluates the literal.
        """
        return json.dumps(value).replace("<", "\\u003c")

    # ------------------------------------------------------------------
    # Static image (matplotlib)
    # ------------------------------------------------------------------

    def visualize_graph(self,
                        graph: nx.DiGraph,
                        layout_type: str = "spring",
                        show_labels: bool = True,
                        show_edge_labels: bool = False,
                        node_size_factor: float = 1.0,
                        figsize: Tuple[int, int] = (12, 8)) -> str:
        """Draw the graph with matplotlib and return the path of the PNG file.

        Args:
            graph: Graph to draw.
            layout_type: Layout name; see ``get_layout_options``. Unknown
                names use the spring layout.
            show_labels: Draw ``"<node>\\n(<importance>)"`` on every node.
            show_edge_labels: Draw each edge's ``relationship`` attribute.
            node_size_factor: Multiplier applied to each node's ``size``
                attribute (default 20) before the fixed x10 scaling.
            figsize: Figure size passed to ``plt.figure``.

        Returns:
            Path of a temporary ``.png`` file. The file is not deleted here;
            the caller owns it.
        """
        if not graph.nodes():
            return self._create_empty_graph_image()

        fig = plt.figure(figsize=figsize)
        try:
            pos = self._calculate_layout(graph, layout_type)

            node_colors = []
            node_sizes = []
            for _, attrs in graph.nodes(data=True):
                node_colors.append(self._color_for(attrs.get('type', 'UNKNOWN')))
                node_sizes.append(attrs.get('size', 20) * node_size_factor * 10)

            nx.draw_networkx_nodes(graph, pos,
                                   node_color=node_colors,
                                   node_size=node_sizes,
                                   alpha=0.8)
            nx.draw_networkx_edges(graph, pos,
                                   edge_color='gray',
                                   arrows=True,
                                   arrowsize=20,
                                   alpha=0.6,
                                   width=1.5)

            if show_labels:
                labels = {
                    node: f"{node}\n({attrs.get('importance', 0.0):.2f})"
                    for node, attrs in graph.nodes(data=True)
                }
                nx.draw_networkx_labels(graph, pos, labels, font_size=8)

            if show_edge_labels:
                edge_labels = {
                    (u, v): attrs.get('relationship', '')
                    for u, v, attrs in graph.edges(data=True)
                }
                nx.draw_networkx_edge_labels(graph, pos, edge_labels, font_size=6)

            plt.title("Knowledge Graph", fontsize=16, fontweight='bold')
            plt.axis('off')
            plt.tight_layout()

            return self._save_figure(fig)
        finally:
            # Always release the figure, including when drawing or saving fails.
            plt.close(fig)

    def _calculate_layout(self, graph: nx.DiGraph, layout_type: str) -> Dict[str, Tuple[float, float]]:
        """Return node positions for ``layout_type``.

        Unknown layout names, and layouts that raise, fall back to the spring
        layout. An error raised by that fallback itself propagates.
        """
        def spring():
            return nx.spring_layout(graph, k=1, iterations=50)

        layouts = {
            "spring": spring,
            "circular": lambda: nx.circular_layout(graph),
            "shell": lambda: nx.shell_layout(graph),
            "kamada_kawai": lambda: nx.kamada_kawai_layout(graph),
            "random": lambda: nx.random_layout(graph),
        }

        try:
            return layouts.get(layout_type, spring)()
        except Exception:
            # Best-effort: a layout that fails degrades to the spring layout.
            # Narrower than a bare ``except`` so Ctrl-C and SystemExit still
            # propagate.
            return spring()

    def _create_empty_graph_image(self) -> str:
        """Render a placeholder PNG for an empty graph and return its path."""
        fig = plt.figure(figsize=(8, 6))
        try:
            plt.text(0.5, 0.5, 'No graph data to display',
                     horizontalalignment='center', verticalalignment='center',
                     fontsize=16, transform=plt.gca().transAxes)
            plt.axis('off')
            return self._save_figure(fig)
        finally:
            plt.close(fig)

    # ------------------------------------------------------------------
    # Stand-alone HTML (vis.js)
    # ------------------------------------------------------------------

    def create_interactive_html(self, graph: nx.DiGraph) -> str:
        """Return a complete HTML page that renders the graph with vis-network.

        The page loads vis-network from unpkg, so viewing it needs network
        access. For an empty graph a short ``<div>`` placeholder is returned
        instead of a full page.

        Raises:
            TypeError: If a node id or attribute is not JSON serialisable.
        """
        if not graph.nodes():
            return "<div>No graph data to display</div>"

        # NOTE(review): tooltip text interpolates node attributes unescaped.
        # Whether vis-network renders ``title`` as HTML or as plain text
        # appears to depend on the library version -- confirm before adding
        # HTML escaping here.
        nodes = [
            {
                "id": node,
                "label": node,
                "color": self._color_for(attrs.get('type', 'UNKNOWN')),
                "size": attrs.get('size', 20),
                "title": f"Type: {attrs.get('type', 'UNKNOWN')}<br>"
                         f"Importance: {attrs.get('importance', 0.0):.2f}<br>"
                         f"Description: {attrs.get('description', 'N/A')}",
            }
            for node, attrs in graph.nodes(data=True)
        ]
        edges = [
            {
                "from": u,
                "to": v,
                "label": attrs.get('relationship', ''),
                "title": attrs.get('description', ''),
                "arrows": {"to": {"enabled": True}},
            }
            for u, v, attrs in graph.edges(data=True)
        ]

        nodes_json = self._script_safe_json(nodes)
        edges_json = self._script_safe_json(edges)

        return f"""
<!DOCTYPE html>
<html>
<head>
    <script src="https://unpkg.com/vis-network/standalone/umd/vis-network.min.js"></script>
    <style>
        #mynetworkid {{
            width: 100%;
            height: 600px;
            border: 1px solid lightgray;
        }}
    </style>
</head>
<body>
    <div id="mynetworkid"></div>

    <script>
        var nodes = new vis.DataSet({nodes_json});
        var edges = new vis.DataSet({edges_json});
        var container = document.getElementById('mynetworkid');

        var data = {{
            nodes: nodes,
            edges: edges
        }};

        var options = {{
            nodes: {{
                shape: 'dot',
                scaling: {{
                    min: 10,
                    max: 30
                }},
                font: {{
                    size: 12,
                    face: 'Tahoma'
                }}
            }},
            edges: {{
                font: {{align: 'middle'}},
                color: {{color:'gray'}},
                arrows: {{to: {{enabled: true, scaleFactor: 1}}}}
            }},
            physics: {{
                enabled: true,
                stabilization: {{enabled: true, iterations: 200}}
            }},
            interaction: {{
                hover: true,
                tooltipDelay: 200
            }}
        }};

        var network = new vis.Network(container, data, options);
    </script>
</body>
</html>
"""

    # ------------------------------------------------------------------
    # Markdown summaries
    # ------------------------------------------------------------------

    def create_statistics_summary(self, graph: nx.DiGraph, stats: Dict[str, Any]) -> str:
        """Return a Markdown summary of graph metrics and type counts.

        Args:
            graph: Graph whose node ``type`` and edge ``relationship``
                attributes are counted.
            stats: Precomputed metrics. Must contain ``num_nodes``,
                ``num_edges``, ``density``, ``is_connected``,
                ``num_components`` and ``avg_degree``.

        Raises:
            KeyError: If ``stats`` lacks one of the required keys and the
                graph is not empty.
        """
        if not graph.nodes():
            return "No graph statistics available."

        type_counts = Counter(
            attrs.get('type', 'UNKNOWN') for _, attrs in graph.nodes(data=True)
        )
        rel_counts = Counter(
            attrs.get('relationship', 'unknown') for _, _, attrs in graph.edges(data=True)
        )

        header = "\n".join([
            "",
            "## Graph Statistics",
            "",
            "**Basic Metrics:**",
            f"- Nodes: {stats['num_nodes']}",
            f"- Edges: {stats['num_edges']}",
            f"- Density: {stats['density']:.3f}",
            f"- Connected: {'Yes' if stats['is_connected'] else 'No'}",
            f"- Components: {stats['num_components']}",
            f"- Average Degree: {stats['avg_degree']:.2f}",
            "",
            "**Entity Types:**",
            "",
        ])
        type_lines = "".join(
            f"\n- {entity_type}: {count}"
            for entity_type, count in sorted(type_counts.items())
        )
        rel_lines = "".join(
            f"\n- {rel_type}: {count}"
            for rel_type, count in sorted(rel_counts.items())
        )

        return header + type_lines + "\n\n**Relationship Types:**" + rel_lines

    def create_entity_list(self, graph: nx.DiGraph, sort_by: str = "importance") -> str:
        """Return a Markdown list of every entity in the graph.

        Args:
            graph: Graph whose nodes are listed.
            sort_by: ``"importance"`` or ``"connections"`` (both descending)
                or ``"name"`` (ascending). Any other value keeps the graph's
                own node order.
        """
        if not graph.nodes():
            return "No entities found."

        entities = [
            {
                'name': node,
                'type': attrs.get('type', 'UNKNOWN'),
                'importance': attrs.get('importance', 0.0),
                'description': attrs.get('description', 'N/A'),
                'connections': graph.degree(node),
            }
            for node, attrs in graph.nodes(data=True)
        ]

        # sort_by value -> whether the sort is descending.
        descending_by_field = {"importance": True, "connections": True, "name": False}
        if sort_by in descending_by_field:
            entities.sort(key=lambda entity: entity[sort_by],
                          reverse=descending_by_field[sort_by])

        entries = [
            f"\n**{entity['name']}** ({entity['type']})\n"
            f"- Importance: {entity['importance']:.2f}\n"
            f"- Connections: {entity['connections']}\n"
            f"- Description: {entity['description']}\n"
            "\n"
            for entity in entities
        ]
        return "## Entities\n\n" + "".join(entries)

    # ------------------------------------------------------------------
    # Option / metadata queries
    # ------------------------------------------------------------------

    def get_layout_options(self) -> List[str]:
        """Return the layout names accepted by ``_calculate_layout``."""
        return ["spring", "circular", "shell", "kamada_kawai", "random"]

    def get_entity_types(self, graph: nx.DiGraph) -> List[str]:
        """Return the sorted unique node ``type`` values (default ``'UNKNOWN'``)."""
        return sorted({attrs.get('type', 'UNKNOWN') for _, attrs in graph.nodes(data=True)})

    def get_visualization_options(self) -> List[str]:
        """Return the names of the available visualization back ends."""
        return ["matplotlib", "plotly", "pyvis", "vis.js"]

    def get_relationship_types(self, graph: nx.DiGraph) -> List[str]:
        """Return the sorted unique edge ``relationship`` values (default ``'unknown'``)."""
        return sorted(
            {attrs.get('relationship', 'unknown') for _, _, attrs in graph.edges(data=True)}
        )

    # ------------------------------------------------------------------
    # Plotly
    # ------------------------------------------------------------------

    def create_plotly_interactive(self, graph: nx.DiGraph, layout_type: str = "spring") -> go.Figure:
        """Return a Plotly figure of the graph with hover details on nodes.

        An empty graph yields a figure holding only a centred
        "No graph data to display" annotation.
        """
        if not graph.nodes():
            fig = go.Figure()
            fig.add_annotation(
                text="No graph data to display",
                xref="paper", yref="paper",
                x=0.5, y=0.5, xanchor='center', yanchor='middle',
                showarrow=False, font=dict(size=16)
            )
            return fig

        pos = self._calculate_layout(graph, layout_type)

        node_x = []
        node_y = []
        node_text = []
        node_info = []
        node_colors = []
        node_sizes = []

        for node, attrs in graph.nodes(data=True):
            x, y = pos[node]
            node_x.append(x)
            node_y.append(y)

            node_type = attrs.get('type', 'UNKNOWN')
            node_text.append(node)
            node_info.append(
                f"<b>{node}</b><br>"
                f"Type: {node_type}<br>"
                f"Importance: {attrs.get('importance', 0.0):.2f}<br>"
                f"Connections: {graph.degree(node)}<br>"
                f"Description: {attrs.get('description', 'N/A')}"
            )
            node_colors.append(self._color_for(node_type))
            # Marker sizes are clamped to a minimum of 10.
            node_sizes.append(max(10, attrs.get('size', 20)))

        # All edges share one trace; a None entry ends each line segment.
        edge_x = []
        edge_y = []
        for source, target in graph.edges():
            x0, y0 = pos[source]
            x1, y1 = pos[target]
            edge_x.extend([x0, x1, None])
            edge_y.extend([y0, y1, None])

        edge_trace = go.Scatter(
            x=edge_x, y=edge_y,
            line=dict(width=2, color='gray'),
            hoverinfo='none',
            mode='lines'
        )

        node_trace = go.Scatter(
            x=node_x, y=node_y,
            mode='markers+text',
            hoverinfo='text',
            text=node_text,
            hovertext=node_info,
            textposition="middle center",
            marker=dict(
                size=node_sizes,
                color=node_colors,
                line=dict(width=2, color='white')
            )
        )

        layout = go.Layout(
            # ``title.font`` replaces the deprecated ``titlefont`` attribute.
            title=dict(text='Interactive Knowledge Graph', font=dict(size=16)),
            showlegend=False,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40),
            annotations=[dict(
                text="Hover over nodes for details. Drag to pan, scroll to zoom.",
                showarrow=False,
                xref="paper", yref="paper",
                x=0.005, y=-0.002,
                xanchor='left', yanchor='bottom',
                font=dict(color="gray", size=12)
            )],
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            plot_bgcolor='white'
        )

        return go.Figure(data=[edge_trace, node_trace], layout=layout)

    # ------------------------------------------------------------------
    # pyvis
    # ------------------------------------------------------------------

    def create_pyvis_interactive(self, graph: nx.DiGraph, layout_type: str = "spring") -> str:
        """Build a pyvis network for the graph and return the path of its HTML file.

        Args:
            graph: Graph to render.
            layout_type: Accepted for signature parity with the other
                renderers; not used, because the options set below leave
                node placement to the physics simulation.

        Returns:
            Path of a temporary ``.html`` file. The file is not deleted here;
            the caller owns it.
        """
        if not graph.nodes():
            return self._create_empty_pyvis_graph()

        net = Network(height="600px", width="100%", bgcolor="#ffffff", font_color="black")

        net.set_options("""
        {
          "physics": {
            "enabled": true,
            "stabilization": {"enabled": true, "iterations": 200},
            "barnesHut": {
              "gravitationalConstant": -2000,
              "centralGravity": 0.3,
              "springLength": 95,
              "springConstant": 0.04,
              "damping": 0.09
            }
          },
          "interaction": {
            "hover": true,
            "tooltipDelay": 200,
            "hideEdgesOnDrag": false
          }
        }
        """)

        for node, attrs in graph.nodes(data=True):
            node_type = attrs.get('type', 'UNKNOWN')
            tooltip = "\n".join([
                "",
                f"<b>{node}</b><br>",
                f"Type: {node_type}<br>",
                f"Importance: {attrs.get('importance', 0.0):.2f}<br>",
                f"Connections: {graph.degree(node)}<br>",
                f"Description: {attrs.get('description', 'N/A')}",
                "",
            ])
            net.add_node(node,
                         label=node,
                         title=tooltip,
                         color=self._color_for(node_type),
                         size=max(10, attrs.get('size', 20)))

        for u, v, attrs in graph.edges(data=True):
            relationship = attrs.get('relationship', 'connected')
            net.add_edge(u, v,
                         title=f"{u} → {v}<br>Relationship: {relationship}",
                         arrows="to",
                         color="gray")

        return self._save_network(net)

    def _create_empty_pyvis_graph(self) -> str:
        """Write a pyvis page holding one grey placeholder node and return its path."""
        net = Network(height="600px", width="100%", bgcolor="#ffffff", font_color="black")
        net.add_node(1, label="No graph data", color="#cccccc")
        return self._save_network(net)