# Generate-Knowledge-Graphs/src/graph_builder.py
# Commit e86199a ("First commit") by CultriX; original file size 8.95 kB.
import io
import json
from typing import Any, Dict, List, Optional, Tuple

import networkx as nx
class GraphBuilder:
    """Build, query, filter and export a directed knowledge graph.

    Entities become nodes keyed by their stripped ``name``; relationships
    become directed edges between entities that were added as nodes. The graph
    is held in ``self.graph`` (a ``networkx.DiGraph``) and is rebuilt from
    scratch on every call to ``build_graph``.
    """

    # Node sizes assigned by _calculate_node_size for importance 0.0 and 1.0.
    MIN_NODE_SIZE = 10
    MAX_NODE_SIZE = 50

    def __init__(self):
        self.graph = nx.DiGraph()  # Directed graph for relationships

    def build_graph(self, entities: List[Dict[str, Any]],
                    relationships: List[Dict[str, Any]]) -> nx.DiGraph:
        """Build the NetworkX graph from entities and relationships.

        Any previous contents of ``self.graph`` are discarded.

        Args:
            entities: Dicts with a ``name`` key and optional ``type``,
                ``importance`` and ``description`` keys. Entries whose name is
                missing, ``None`` or blank are skipped; a repeated name
                overwrites the earlier node's attributes. ``importance`` is
                stored as a float (missing or non-numeric values become 0.0).
            relationships: Dicts with ``source`` and ``target`` entity names
                and optional ``relationship`` and ``description`` keys. Edges
                whose endpoints are not known entity names are skipped.

        Returns:
            ``self.graph`` itself (not a copy).
        """
        self.graph.clear()

        for entity in entities:
            node_id = self._clean_name(entity.get("name"))
            if not node_id:
                continue
            importance = self._coerce_importance(entity.get("importance", 0.0))
            self.graph.add_node(
                node_id,
                type=entity.get("type", "UNKNOWN"),
                importance=importance,
                description=entity.get("description", ""),
                size=self._calculate_node_size(importance),
            )

        for relationship in relationships:
            source = self._clean_name(relationship.get("source"))
            target = self._clean_name(relationship.get("target"))
            # Only link entities that were added as nodes, so an edge never
            # implicitly creates an attribute-less node.
            if source and target and source in self.graph and target in self.graph:
                self.graph.add_edge(
                    source,
                    target,
                    relationship=relationship.get("relationship", "related_to"),
                    description=relationship.get("description", ""),
                    weight=1.0,
                )

        return self.graph

    @staticmethod
    def _clean_name(value: Any) -> str:
        """Return *value* as a stripped string, or "" when it is None."""
        if value is None:
            return ""
        return str(value).strip()

    @staticmethod
    def _coerce_importance(value: Any) -> float:
        """Return *value* as a float, using 0.0 for None, non-numeric or NaN."""
        try:
            score = float(value)
        except (TypeError, ValueError):
            return 0.0
        # NaN is the only float that compares unequal to itself.
        return score if score == score else 0.0

    def _calculate_node_size(self, importance: float) -> int:
        """Map an importance score to an integer node size.

        The score is clamped to [0.0, 1.0] and mapped linearly onto
        [MIN_NODE_SIZE, MAX_NODE_SIZE] (10-50), truncating toward zero.
        Missing or non-numeric scores count as 0.0.
        """
        score = min(max(self._coerce_importance(importance), 0.0), 1.0)
        span = self.MAX_NODE_SIZE - self.MIN_NODE_SIZE
        return int(self.MIN_NODE_SIZE + span * score)

    def get_graph_statistics(self) -> Dict[str, Any]:
        """Return basic statistics about the graph.

        Keys: ``num_nodes``, ``num_edges``, ``density``, ``is_connected`` and
        ``num_components`` (both computed ignoring edge direction), and
        ``avg_degree`` (mean of in-degree + out-degree). The same keys are
        returned for an empty graph.
        """
        num_nodes = self.graph.number_of_nodes()
        if num_nodes == 0:
            return {
                "num_nodes": 0,
                "num_edges": 0,
                "density": 0.0,
                "is_connected": False,
                "num_components": 0,
                "avg_degree": 0.0,
            }

        # Connectivity is judged on the undirected structure; a view avoids
        # copying the whole graph.
        undirected = self.graph.to_undirected(as_view=True)
        return {
            "num_nodes": num_nodes,
            "num_edges": self.graph.number_of_edges(),
            "density": nx.density(self.graph),
            "is_connected": nx.is_connected(undirected),
            "num_components": nx.number_connected_components(undirected),
            "avg_degree": sum(degree for _, degree in self.graph.degree()) / num_nodes,
        }

    def get_central_nodes(self, top_k: int = 5) -> List[Tuple[str, float]]:
        """Rank nodes by a blended centrality score.

        score = 0.3 * degree centrality + 0.3 * betweenness centrality
              + 0.2 * PageRank + 0.2 * node ``importance``

        Returns:
            The first ``top_k`` ``(node, score)`` pairs, highest score first;
            an empty list for an empty graph.
        """
        num_nodes = self.graph.number_of_nodes()
        if num_nodes == 0:
            return []

        degree_centrality = nx.degree_centrality(self.graph)

        # With two or fewer nodes no node can lie between two others.
        if num_nodes > 2:
            betweenness_centrality = nx.betweenness_centrality(self.graph)
        else:
            betweenness_centrality = dict.fromkeys(self.graph, 0.0)

        try:
            pagerank = nx.pagerank(self.graph)
        except (nx.NetworkXException, ImportError):
            # Best effort: PageRank can fail to converge, or (in some networkx
            # versions) need an optional numeric backend. Fall back to a
            # uniform score instead of aborting the whole ranking.
            pagerank = dict.fromkeys(self.graph, 1.0 / num_nodes)

        centralities = {
            node: (
                0.3 * degree_centrality.get(node, 0.0)
                + 0.3 * betweenness_centrality.get(node, 0.0)
                + 0.2 * pagerank.get(node, 0.0)
                + 0.2 * self._coerce_importance(data.get("importance", 0.0))
            )
            for node, data in self.graph.nodes(data=True)
        }

        ranked = sorted(centralities.items(), key=lambda item: item[1], reverse=True)
        return ranked[:top_k]

    def filter_graph(self,
                     entity_types: Optional[List[str]] = None,
                     min_importance: Optional[float] = None,
                     relationship_types: Optional[List[str]] = None) -> nx.DiGraph:
        """Return a filtered copy of the graph; ``self.graph`` is not modified.

        Args:
            entity_types: Keep only nodes whose ``type`` is listed. ``None``
                or an empty list disables this filter.
            min_importance: Drop nodes whose ``importance`` is below this
                value. ``None`` disables the filter; 0.0 is a real threshold.
            relationship_types: Keep only edges whose ``relationship`` is
                listed. ``None`` or an empty list disables this filter. Nodes
                left without edges are kept.
        """
        filtered_graph = self.graph.copy()

        nodes_to_remove = [
            node
            for node, data in filtered_graph.nodes(data=True)
            if (entity_types and data.get("type") not in entity_types)
            or (min_importance is not None
                and self._coerce_importance(data.get("importance", 0.0)) < min_importance)
        ]
        filtered_graph.remove_nodes_from(nodes_to_remove)

        if relationship_types:
            edges_to_remove = [
                (u, v)
                for u, v, rel in filtered_graph.edges(data="relationship")
                if rel not in relationship_types
            ]
            filtered_graph.remove_edges_from(edges_to_remove)

        return filtered_graph

    def export_graph(self, format_type: str = "json") -> str:
        """Serialize the graph as "json", "graphml" or "gexf" (case-insensitive).

        Raises:
            ValueError: If ``format_type`` is not one of the supported formats.
        """
        exporters = {
            "json": self._export_json,
            "graphml": self._export_graphml,
            "gexf": self._export_gexf,
        }
        exporter = exporters.get(format_type.lower())
        if exporter is None:
            raise ValueError(f"Unsupported export format: {format_type}")
        return exporter()

    def _export_json(self) -> str:
        """Export the graph as JSON: {"nodes": [...], "edges": [...]}.

        Each node is ``{"id": node, **attributes}`` and each edge is
        ``{"source": u, "target": v, **attributes}``.
        """
        data = {
            "nodes": [
                {"id": node, **attrs}
                for node, attrs in self.graph.nodes(data=True)
            ],
            "edges": [
                {"source": u, "target": v, **attrs}
                for u, v, attrs in self.graph.edges(data=True)
            ],
        }
        return json.dumps(data, indent=2)

    def _export_graphml(self) -> str:
        """Export the graph as a GraphML document string.

        networkx's GraphML writer emits UTF-8 encoded bytes, so it writes into
        a binary buffer that is decoded afterwards; a text buffer
        (``io.StringIO``) would reject the bytes.
        """
        buffer = io.BytesIO()
        nx.write_graphml(self.graph, buffer)
        return buffer.getvalue().decode("utf-8")

    def _export_gexf(self) -> str:
        """Export the graph as a GEXF document string.

        Like GraphML, networkx writes GEXF as UTF-8 encoded bytes, so a binary
        buffer is used and decoded.
        """
        buffer = io.BytesIO()
        nx.write_gexf(self.graph, buffer)
        return buffer.getvalue().decode("utf-8")

    def get_subgraph_around_node(self, node: str, radius: int = 1) -> nx.DiGraph:
        """Return a copy of the subgraph within ``radius`` hops of ``node``.

        Edge direction is ignored when counting hops (predecessors and
        successors are both neighbours). An unknown ``node`` yields an empty
        graph; ``radius <= 0`` yields just ``node``.
        """
        if node not in self.graph:
            return nx.DiGraph()

        visited = {node}
        frontier = {node}
        for _ in range(radius):
            neighbours = set()
            for current in frontier:
                neighbours.update(self.graph.successors(current))
                neighbours.update(self.graph.predecessors(current))
            # Take the new frontier *before* marking it visited; doing it the
            # other way round leaves the frontier empty after the first hop.
            frontier = neighbours - visited
            if not frontier:
                break
            visited |= frontier

        return self.graph.subgraph(visited).copy()

    def get_shortest_path(self, source: str, target: str) -> List[str]:
        """Return a shortest path from ``source`` to ``target``, ignoring direction.

        Returns an empty list when either node is missing or no path exists.
        """
        try:
            undirected = self.graph.to_undirected(as_view=True)
            return nx.shortest_path(undirected, source, target)
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            return []

    def get_node_info(self, node: str) -> Dict[str, Any]:
        """Return a node's attributes plus its connectivity details.

        Adds ``in_degree``, ``out_degree``, ``predecessors``, ``successors``
        and ``total_connections`` (predecessor count + successor count, so a
        neighbour linked in both directions is counted twice). Returns an
        empty dict for an unknown node.
        """
        if node not in self.graph:
            return {}

        predecessors = list(self.graph.predecessors(node))
        successors = list(self.graph.successors(node))
        return {
            **self.graph.nodes[node],
            "in_degree": self.graph.in_degree(node),
            "out_degree": self.graph.out_degree(node),
            "predecessors": predecessors,
            "successors": successors,
            "total_connections": len(predecessors) + len(successors),
        }