import io
import json
from typing import Any, Callable, Dict, List, Optional, Tuple

import networkx as nx


class GraphBuilder:
    """Build and analyse a directed graph of entities and their relationships.

    Entities become nodes keyed by their stripped ``name``; relationships
    become directed edges (source -> target) between two known entities.
    The graph lives in ``self.graph`` and is rebuilt in place by
    :meth:`build_graph`.
    """

    def __init__(self):
        # Directed graph: edge direction follows relationship source -> target.
        self.graph: nx.DiGraph = nx.DiGraph()

    def build_graph(
        self,
        entities: List[Dict[str, Any]],
        relationships: List[Dict[str, Any]],
    ) -> nx.DiGraph:
        """Build NetworkX graph from entities and relationships.

        Any previous contents of ``self.graph`` are cleared first.

        Args:
            entities: Dicts with a ``name`` key and optional ``type``,
                ``importance`` and ``description`` keys. Entities whose name
                is missing, ``None`` or blank are skipped.
            relationships: Dicts with ``source`` and ``target`` entity names
                and optional ``relationship`` and ``description`` keys.
                Relationships whose endpoints are not known entities are
                skipped.

        Returns:
            ``self.graph`` (the same object, now populated).
        """
        self.graph.clear()

        for entity in entities:
            # ``or ""`` also covers keys present with a None value, which
            # would otherwise crash on ``.strip()``.
            node_id = (entity.get("name") or "").strip()
            if not node_id:
                continue
            importance = entity.get("importance") or 0.0
            # None attribute values are replaced by defaults because the
            # GraphML/GEXF writers cannot serialize None.
            self.graph.add_node(
                node_id,
                type=entity.get("type") or "UNKNOWN",
                importance=importance,
                description=entity.get("description") or "",
                size=self._calculate_node_size(importance),
            )

        for relationship in relationships:
            source = (relationship.get("source") or "").strip()
            target = (relationship.get("target") or "").strip()
            # Blank ids are never added as nodes above, so a membership test
            # alone rejects missing or unknown endpoints.
            if source in self.graph and target in self.graph:
                self.graph.add_edge(
                    source,
                    target,
                    relationship=relationship.get("relationship") or "related_to",
                    description=relationship.get("description") or "",
                    weight=1.0,
                )

        return self.graph

    def _calculate_node_size(self, importance: float) -> int:
        """Calculate node size based on importance score.

        Maps importance in [0.0, 1.0] linearly onto [10, 50]. Scores outside
        that range are clamped so the size always stays within [10, 50].
        """
        min_size, max_size = 10, 50
        clamped = min(max(importance, 0.0), 1.0)
        return int(min_size + (max_size - min_size) * clamped)

    def get_graph_statistics(self) -> Dict[str, Any]:
        """Get basic statistics about the graph.

        Returns:
            Dict with ``num_nodes``, ``num_edges``, ``density``,
            ``is_connected``, ``num_components`` and ``avg_degree``. The same
            keys are returned for an empty graph. Connectivity ignores edge
            direction (the graph is viewed as undirected).
        """
        num_nodes = self.graph.number_of_nodes()
        if num_nodes == 0:
            return {
                "num_nodes": 0,
                "num_edges": 0,
                "density": 0.0,
                "is_connected": False,
                "num_components": 0,
                "avg_degree": 0.0,
            }

        # Read-only undirected view: avoids copying the whole graph just to
        # analyse connectivity.
        undirected = self.graph.to_undirected(as_view=True)
        # For a DiGraph, degree() is in-degree + out-degree.
        total_degree = sum(degree for _, degree in self.graph.degree())
        return {
            "num_nodes": num_nodes,
            "num_edges": self.graph.number_of_edges(),
            "density": nx.density(self.graph),
            "is_connected": nx.is_connected(undirected),
            "num_components": nx.number_connected_components(undirected),
            "avg_degree": total_degree / num_nodes,
        }

    def get_central_nodes(self, top_k: int = 5) -> List[Tuple[str, float]]:
        """Get most central nodes using various centrality measures.

        Each node's score is ``0.3 * degree centrality + 0.3 * betweenness
        centrality + 0.2 * PageRank + 0.2 * importance`` (the node's
        ``importance`` attribute, default 0.0).

        Returns:
            Up to ``top_k`` ``(node, score)`` pairs, highest score first;
            ``[]`` for an empty graph.
        """
        num_nodes = self.graph.number_of_nodes()
        if num_nodes == 0:
            return []

        degree_centrality = nx.degree_centrality(self.graph)

        # Betweenness is only informative once a node can lie between two others.
        if num_nodes > 2:
            betweenness_centrality = nx.betweenness_centrality(self.graph)
        else:
            betweenness_centrality = dict.fromkeys(self.graph.nodes(), 0.0)

        try:
            pagerank = nx.pagerank(self.graph)
        except (nx.NetworkXException, ImportError):
            # Best-effort fallback to a uniform score: power iteration may not
            # converge (PowerIterationFailedConvergence is a NetworkXException),
            # and some networkx versions need an optional numeric backend.
            pagerank = dict.fromkeys(self.graph.nodes(), 1.0 / num_nodes)

        scores = {
            node: (
                0.3 * degree_centrality.get(node, 0.0)
                + 0.3 * betweenness_centrality.get(node, 0.0)
                + 0.2 * pagerank.get(node, 0.0)
                + 0.2 * data.get("importance", 0.0)
            )
            for node, data in self.graph.nodes(data=True)
        }
        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
        return ranked[:top_k]

    def filter_graph(
        self,
        entity_types: Optional[List[str]] = None,
        min_importance: Optional[float] = None,
        relationship_types: Optional[List[str]] = None,
    ) -> nx.DiGraph:
        """Filter graph by entity types, importance, or relationship types.

        Works on a copy; ``self.graph`` is not modified. An empty list or
        ``None`` disables the corresponding filter.

        Args:
            entity_types: Keep only nodes whose ``type`` is in this list.
            min_importance: Keep only nodes whose ``importance`` (default
                0.0) is at least this value. ``0.0`` is a real threshold.
            relationship_types: Keep only edges whose ``relationship`` is in
                this list.

        Returns:
            The filtered copy.
        """
        filtered_graph = self.graph.copy()

        allowed_types = set(entity_types) if entity_types else None
        nodes_to_remove = []
        for node, data in filtered_graph.nodes(data=True):
            if allowed_types is not None and data.get("type") not in allowed_types:
                nodes_to_remove.append(node)
            elif min_importance is not None and data.get("importance", 0.0) < min_importance:
                nodes_to_remove.append(node)
        filtered_graph.remove_nodes_from(nodes_to_remove)

        if relationship_types:
            allowed_relationships = set(relationship_types)
            edges_to_remove = [
                (u, v)
                for u, v, data in filtered_graph.edges(data=True)
                if data.get("relationship") not in allowed_relationships
            ]
            filtered_graph.remove_edges_from(edges_to_remove)

        return filtered_graph

    def export_graph(self, format_type: str = "json") -> str:
        """Export graph in various formats.

        Args:
            format_type: ``"json"``, ``"graphml"`` or ``"gexf"``
                (case-insensitive).

        Returns:
            The serialized graph as text.

        Raises:
            ValueError: If ``format_type`` is not supported.
        """
        exporters = {
            "json": self._export_json,
            "graphml": self._export_graphml,
            "gexf": self._export_gexf,
        }
        exporter = exporters.get(format_type.lower())
        if exporter is None:
            raise ValueError(f"Unsupported export format: {format_type}")
        return exporter()

    def _export_json(self) -> str:
        """Export graph as JSON with ``nodes`` and ``edges`` lists.

        Each node is ``{"id": node, **attributes}``; each edge is
        ``{"source": u, "target": v, **attributes}``.
        """
        data = {
            "nodes": [
                {"id": node, **attrs} for node, attrs in self.graph.nodes(data=True)
            ],
            "edges": [
                {"source": u, "target": v, **attrs}
                for u, v, attrs in self.graph.edges(data=True)
            ],
        }
        return json.dumps(data, indent=2)

    def _export_graphml(self) -> str:
        """Export graph as GraphML."""
        return self._export_xml(nx.write_graphml)

    def _export_gexf(self) -> str:
        """Export graph as GEXF."""
        return self._export_xml(nx.write_gexf)

    def _export_xml(self, writer: Callable[[nx.DiGraph, Any], None]) -> str:
        """Serialize ``self.graph`` with a networkx XML writer and return text.

        The networkx GraphML/GEXF writers emit UTF-8 encoded bytes, so they
        need a binary buffer. Passing an ``io.StringIO`` raises TypeError.
        """
        buffer = io.BytesIO()
        writer(self.graph, buffer)
        return buffer.getvalue().decode("utf-8")

    def get_subgraph_around_node(self, node: str, radius: int = 1) -> nx.DiGraph:
        """Get subgraph within specified radius of a node.

        Neighbours are followed in both edge directions, so this returns
        every node reachable in at most ``radius`` hops while ignoring edge
        direction.

        Returns:
            A copy of the induced subgraph, or an empty DiGraph if ``node``
            is not in the graph.
        """
        if node not in self.graph:
            return nx.DiGraph()

        visited = {node}
        frontier = {node}
        for _ in range(radius):
            neighbours = set()
            for current in frontier:
                neighbours.update(self.graph.successors(current))
                neighbours.update(self.graph.predecessors(current))
            # Compute the new frontier BEFORE merging into ``visited``;
            # otherwise it is always empty and expansion stops after one hop.
            frontier = neighbours - visited
            if not frontier:
                break
            visited |= frontier

        return self.graph.subgraph(visited).copy()

    def get_shortest_path(self, source: str, target: str) -> List[str]:
        """Get shortest path between two nodes, ignoring edge direction.

        Returns:
            The node sequence from ``source`` to ``target``, or ``[]`` if
            either node is missing or no path exists.
        """
        try:
            # Undirected view: path finding without copying the graph.
            undirected = self.graph.to_undirected(as_view=True)
            return nx.shortest_path(undirected, source, target)
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            return []

    def get_node_info(self, node: str) -> Dict[str, Any]:
        """Get detailed information about a specific node.

        Returns:
            The node's stored attributes plus ``in_degree``, ``out_degree``,
            ``predecessors``, ``successors`` and ``total_connections``
            (predecessor count + successor count, so a neighbour linked in
            both directions counts twice). Returns ``{}`` for an unknown node.
        """
        if node not in self.graph:
            return {}

        predecessors = list(self.graph.predecessors(node))
        successors = list(self.graph.successors(node))

        node_data = dict(self.graph.nodes[node])
        node_data.update({
            "in_degree": self.graph.in_degree(node),
            "out_degree": self.graph.out_degree(node),
            "predecessors": predecessors,
            "successors": successors,
            "total_connections": len(predecessors) + len(successors),
        })
        return node_data