MCP_Res / mcp /graph_metrics.py
mgbam's picture
Update mcp/graph_metrics.py
1bc973b verified
raw
history blame
3.77 kB
#!/usr/bin/env python3
"""MedGenesis – NetworkX helpers (robust version)
Key upgrades over the legacy helper:
1. **Edge‑key flexibility** – `build_nx` now recognises *four* common
schemas produced by Streamlit‑agraph, PyVis, Neo4j exports or OT graphs:
• `{"source": "n1", "target": "n2"}` (agraph)
• `{"from": "n1", "to": "n2"}` (PyVis)
• `{"src": "n1", "dst": "n2"}` (neo4j/json)
• `{"u": "n1", "v": "n2"}` (NetworkX native)
2. **Weight aware** – optional numeric `weight` (or `value`) field becomes
an edge attribute (defaults to 1).
3. **Self‑loop skip** – ignores self‑edges to keep density sensible.
4. **Utility metrics** – adds `betweenness` & `clustering` helpers in
addition to top‑hub degree ranking.
"""
from __future__ import annotations
from typing import Dict, List, Tuple
import networkx as nx
# Names exported by a star-import: the graph builder first, then the
# metric helpers.
__all__ = ["build_nx", "get_top_hubs", "get_density",
           "get_betweenness", "get_clustering_coeff"]
# ---------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------
def _edge_endpoints(e: Dict) -> Tuple[str, str] | None:
"""Return (src, dst) if both ends exist; else None."""
src = e.get("source") or e.get("from") or e.get("src") or e.get("u")
dst = e.get("target") or e.get("to") or e.get("dst") or e.get("v")
if src and dst and src != dst:
return str(src), str(dst)
return None
# ---------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------
def _edge_weight(e: Dict) -> float:
    """Return the edge weight from ``weight``, else ``value``, else 1, as float.

    Uses explicit ``None``/``""`` checks instead of ``or`` so an explicit
    weight of 0 is preserved rather than silently replaced by 1.
    """
    for key in ("weight", "value"):
        candidate = e.get(key)
        if candidate is not None and candidate != "":
            return float(candidate)
    return 1.0


def build_nx(nodes: List[Dict], edges: List[Dict]) -> nx.Graph:
    """Convert heterogeneous node/edge dicts into an undirected NetworkX graph.

    Parameters
    ----------
    nodes : list of node dicts – each must contain an ``id`` key; every other
        key is copied onto the node as an attribute. ``None`` is treated as
        an empty list.
    edges : list of edge dicts – endpoint keys may follow any schema accepted
        by ``_edge_endpoints``; edges missing an end, or self-loops, are
        skipped. ``None`` is treated as an empty list.

    Returns
    -------
    nx.Graph – node ids are stringified; each edge carries a float ``weight``
    taken from ``weight``, else ``value``, else 1.

    Raises
    ------
    KeyError
        If a node dict has no ``id`` key.
    ValueError, TypeError
        If an edge weight cannot be converted with ``float()``.
    """
    G = nx.Graph()

    # Nodes ----------------------------------------------------------------
    for n in nodes or ():
        node_id = str(n["id"])
        G.add_node(node_id)
        # Update the attribute dict directly instead of passing ``**attrs`` to
        # add_node, so an attribute key that clashes with add_node's own
        # parameter name cannot raise TypeError.
        G.nodes[node_id].update({k: v for k, v in n.items() if k != "id"})

    # Edges ----------------------------------------------------------------
    for e in edges or ():
        endpoints = _edge_endpoints(e)
        if endpoints is None:
            continue
        u, v = endpoints
        G.add_edge(u, v, weight=_edge_weight(e))
    return G
# ---------------------------------------------------------------------
# Metrics helpers
# ---------------------------------------------------------------------
def get_top_hubs(G: nx.Graph, k: int = 5) -> List[Tuple[str, float]]:
    """Rank nodes by degree centrality and return the *k* highest.

    Result is a list of ``(node, centrality)`` pairs, highest first; ties
    keep the order in which NetworkX reports the nodes.
    """
    ranked = sorted(
        nx.degree_centrality(G).items(),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return ranked[:k]
def get_betweenness(
    G: nx.Graph,
    k: int = 5,
    *,
    approx_threshold: int = 500,
    sample_size: int = 200,
    seed: int | None = 42,
) -> List[Tuple[str, float]]:
    """Top-*k* nodes by betweenness centrality (approx if |V| > threshold).

    Exact betweenness is computed for graphs with at most
    ``approx_threshold`` nodes. Larger graphs use NetworkX's pivot-sampling
    approximation with ``sample_size`` pivots (capped at the node count) and
    a fixed ``seed`` so repeated calls give the same ranking.

    Parameters
    ----------
    G : graph to score. No ``weight`` argument is passed, so edge weights
        are not used for path lengths.
    k : number of ``(node, score)`` pairs to return, highest first.
    approx_threshold : node count above which sampling is used (default 500).
    sample_size : number of pivot nodes for the approximation (default 200).
    seed : RNG seed for pivot selection (default 42).
    """
    n_nodes = G.number_of_nodes()
    if n_nodes > approx_threshold:
        # NetworkX requires the pivot count to be <= the number of nodes.
        bc = nx.betweenness_centrality(
            G, k=min(sample_size, n_nodes), seed=seed
        )
    else:
        bc = nx.betweenness_centrality(G)
    return sorted(bc.items(), key=lambda kv: kv[1], reverse=True)[:k]
def get_clustering_coeff(G: nx.Graph) -> float:
    """Return the average clustering coefficient of *G* (0-1).

    An empty graph returns ``0.0`` instead of letting ``nx.average_clustering``
    divide by zero. This matches ``get_density``, which returns 0 for an
    empty graph.
    """
    if G.number_of_nodes() == 0:
        return 0.0
    return float(nx.average_clustering(G))
def get_density(G: nx.Graph) -> float:
    """Return the edge density of *G*, a value between 0 and 1.

    Thin wrapper over ``nx.density``; see its documentation for how
    empty and single-node graphs are handled.
    """
    density = nx.density(G)
    return density