|
import networkx as nx |
|
from typing import List, Dict, Any, Tuple |
|
import json |
|
|
|
class GraphBuilder:
    """Build and query a directed graph of entities and their relationships.

    Entities become nodes keyed by their stripped ``name``; relationships become
    directed edges between existing entity nodes. The graph lives on
    ``self.graph`` and is cleared and rebuilt by every call to ``build_graph``.
    """

    def __init__(self):
        self.graph = nx.DiGraph()

    @staticmethod
    def _clean_id(value: Any) -> str:
        """Return *value* as a stripped string, treating ``None`` as empty."""
        if value is None:
            return ""
        return str(value).strip()

    def build_graph(self, entities: List[Dict[str, Any]], relationships: List[Dict[str, Any]]) -> nx.DiGraph:
        """Build the NetworkX graph from entities and relationships.

        Any previous contents of ``self.graph`` are discarded first.

        Args:
            entities: Dicts with a ``name`` key and optional ``type``,
                ``importance`` and ``description`` keys. Entries whose name is
                missing, ``None`` or blank are skipped.
            relationships: Dicts with ``source`` and ``target`` entity names and
                optional ``relationship`` and ``description`` keys. Entries whose
                endpoints are not entity nodes added above are skipped.

        Returns:
            ``self.graph`` itself (not a copy), now populated.
        """
        self.graph.clear()

        for entity in entities:
            node_id = self._clean_id(entity.get("name"))
            if not node_id:
                continue
            importance = entity.get("importance", 0.0)
            if importance is None:
                # An explicit None would otherwise break size/score arithmetic.
                importance = 0.0
            self.graph.add_node(
                node_id,
                type=entity.get("type", "UNKNOWN"),
                importance=importance,
                description=entity.get("description", ""),
                size=self._calculate_node_size(importance),
            )

        for relationship in relationships:
            source = self._clean_id(relationship.get("source"))
            target = self._clean_id(relationship.get("target"))
            # Only link nodes created from entities; add_edge would otherwise
            # silently create attribute-less nodes. Blank ids are never nodes.
            if source in self.graph and target in self.graph:
                self.graph.add_edge(
                    source,
                    target,
                    relationship=relationship.get("relationship", "related_to"),
                    description=relationship.get("description", ""),
                    weight=1.0,
                )

        return self.graph

    def _calculate_node_size(self, importance: float) -> int:
        """Map an importance score to an integer node size in ``[10, 50]``.

        The score is clamped to ``[0, 1]`` first so out-of-range importance
        values cannot produce sizes outside those bounds.
        """
        min_size, max_size = 10, 50
        clamped = min(max(float(importance), 0.0), 1.0)
        return int(min_size + (max_size - min_size) * clamped)

    def get_graph_statistics(self) -> Dict[str, Any]:
        """Return basic statistics about the graph.

        Keys: ``num_nodes``, ``num_edges``, ``density``, ``is_connected`` and
        ``num_components`` (both computed ignoring edge direction), and
        ``avg_degree`` (mean of in-degree + out-degree). The same keys are
        returned for an empty graph.
        """
        num_nodes = self.graph.number_of_nodes()
        if num_nodes == 0:
            return {
                "num_nodes": 0,
                "num_edges": 0,
                "density": 0.0,
                "is_connected": False,
                "num_components": 0,
                "avg_degree": 0.0,
            }

        # A read-only undirected view avoids copying the whole graph.
        undirected = self.graph.to_undirected(as_view=True)

        return {
            "num_nodes": num_nodes,
            "num_edges": self.graph.number_of_edges(),
            "density": nx.density(self.graph),
            "is_connected": nx.is_connected(undirected),
            "num_components": nx.number_connected_components(undirected),
            "avg_degree": sum(degree for _, degree in self.graph.degree()) / num_nodes,
        }

    def get_central_nodes(self, top_k: int = 5) -> List[Tuple[str, float]]:
        """Return the ``top_k`` nodes ranked by a combined centrality score.

        The score is 0.3 * degree centrality + 0.3 * betweenness centrality
        + 0.2 * PageRank + 0.2 * the node's ``importance`` attribute.

        Returns:
            ``(node, score)`` pairs sorted by descending score; ``[]`` for an
            empty graph.
        """
        num_nodes = self.graph.number_of_nodes()
        if num_nodes == 0:
            return []

        degree_centrality = nx.degree_centrality(self.graph)

        # Betweenness is trivially zero with fewer than three nodes.
        if num_nodes > 2:
            betweenness_centrality = nx.betweenness_centrality(self.graph)
        else:
            betweenness_centrality = {node: 0.0 for node in self.graph.nodes()}

        try:
            pagerank = nx.pagerank(self.graph)
        except (nx.NetworkXException, ImportError):
            # Best-effort fallback (e.g. power iteration did not converge, or
            # the optional scipy backend is missing): uniform PageRank.
            pagerank = {node: 1.0 / num_nodes for node in self.graph.nodes()}

        centralities = {}
        for node, data in self.graph.nodes(data=True):
            importance = data.get("importance") or 0.0
            centralities[node] = (
                0.3 * degree_centrality.get(node, 0.0)
                + 0.3 * betweenness_centrality.get(node, 0.0)
                + 0.2 * pagerank.get(node, 0.0)
                + 0.2 * importance
            )

        sorted_nodes = sorted(centralities.items(), key=lambda item: item[1], reverse=True)
        return sorted_nodes[:top_k]

    def filter_graph(self,
                     entity_types: List[str] = None,
                     min_importance: float = None,
                     relationship_types: List[str] = None) -> nx.DiGraph:
        """Return a filtered copy of the graph; ``self.graph`` is not modified.

        Args:
            entity_types: Keep only nodes whose ``type`` is in this collection.
                ``None`` or empty means no type filtering.
            min_importance: Drop nodes whose ``importance`` is below this value.
                ``None`` disables the check; ``0.0`` is honoured as a threshold.
            relationship_types: Keep only edges whose ``relationship`` is in
                this collection. ``None`` or empty means no edge filtering.
        """
        filtered_graph = self.graph.copy()

        nodes_to_remove = []
        for node, data in filtered_graph.nodes(data=True):
            if entity_types and data.get("type") not in entity_types:
                nodes_to_remove.append(node)
            elif min_importance is not None and (data.get("importance") or 0.0) < min_importance:
                nodes_to_remove.append(node)
        filtered_graph.remove_nodes_from(nodes_to_remove)

        if relationship_types:
            edges_to_remove = [
                (u, v)
                for u, v, data in filtered_graph.edges(data=True)
                if data.get("relationship") not in relationship_types
            ]
            filtered_graph.remove_edges_from(edges_to_remove)

        return filtered_graph

    def export_graph(self, format_type: str = "json") -> str:
        """Serialize the graph to a string.

        Args:
            format_type: ``"json"``, ``"graphml"`` or ``"gexf"`` (case-insensitive).

        Raises:
            ValueError: If the format is not supported.
        """
        exporters = {
            "json": self._export_json,
            "graphml": self._export_graphml,
            "gexf": self._export_gexf,
        }
        exporter = exporters.get(format_type.lower())
        if exporter is None:
            raise ValueError(f"Unsupported export format: {format_type}")
        return exporter()

    def _export_json(self) -> str:
        """Export the graph as JSON with ``nodes`` and ``edges`` lists.

        Each node entry is ``{"id": node, **attrs}``; each edge entry is
        ``{"source": u, "target": v, **attrs}``.
        """
        data = {
            "nodes": [{"id": node, **attrs} for node, attrs in self.graph.nodes(data=True)],
            "edges": [
                {"source": u, "target": v, **attrs}
                for u, v, attrs in self.graph.edges(data=True)
            ],
        }
        return json.dumps(data, indent=2)

    def _export_graphml(self) -> str:
        """Export the graph as a GraphML document string.

        ``nx.write_graphml`` writes encoded bytes and fails on an
        ``io.StringIO``; ``nx.generate_graphml`` yields the document as text lines.
        """
        return "\n".join(nx.generate_graphml(self.graph))

    def _export_gexf(self) -> str:
        """Export the graph as a GEXF document string.

        ``nx.write_gexf`` writes encoded bytes and fails on an
        ``io.StringIO``; ``nx.generate_gexf`` yields the document as text lines.
        """
        return "\n".join(nx.generate_gexf(self.graph))

    def get_subgraph_around_node(self, node: str, radius: int = 1) -> nx.DiGraph:
        """Return a copy of the subgraph within ``radius`` hops of ``node``.

        Hops follow edges in either direction. Returns an empty ``DiGraph``
        when ``node`` is not in the graph.
        """
        if node not in self.graph:
            return nx.DiGraph()

        nodes_in_radius = {node}
        frontier = {node}

        for _ in range(radius):
            neighbours = set()
            for current in frontier:
                neighbours.update(self.graph.successors(current))
                neighbours.update(self.graph.predecessors(current))

            # Compute the new frontier BEFORE merging into the visited set;
            # otherwise it is always empty and expansion stops after one hop.
            frontier = neighbours - nodes_in_radius
            if not frontier:
                break
            nodes_in_radius |= frontier

        return self.graph.subgraph(nodes_in_radius).copy()

    def get_shortest_path(self, source: str, target: str) -> List[str]:
        """Return a shortest path from ``source`` to ``target``, ignoring edge direction.

        Returns ``[]`` when either node is missing or no path exists.
        """
        undirected = self.graph.to_undirected(as_view=True)
        try:
            return nx.shortest_path(undirected, source, target)
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            return []

    def get_node_info(self, node: str) -> Dict[str, Any]:
        """Return a node's attributes plus its connectivity details.

        Adds ``in_degree``, ``out_degree``, ``predecessors``, ``successors`` and
        ``total_connections`` (len(predecessors) + len(successors), so a
        neighbour linked in both directions counts twice). Returns ``{}`` for an
        unknown node.
        """
        if node not in self.graph:
            return {}

        node_data = dict(self.graph.nodes[node])
        predecessors = list(self.graph.predecessors(node))
        successors = list(self.graph.successors(node))

        node_data.update({
            "in_degree": self.graph.in_degree(node),
            "out_degree": self.graph.out_degree(node),
            "predecessors": predecessors,
            "successors": successors,
            "total_connections": len(predecessors) + len(successors),
        })

        return node_data
|
|