MCP_Res / mcp /graph_metrics.py
mgbam's picture
Update mcp/graph_metrics.py
7808af5 verified
# mcp/graph_metrics.py
"""
Tiny NetworkX helpers for MedGenesis graphs.
✨ 2025-06-25 REVAMP
────────────────────
• Accepts edge-dicts that use either
{'source': 'n1', 'target': 'n2'} ← agraph / d3.js
or {'from' : 'n1', 'to' : 'n2'} ← PyVis
• Silently skips malformed edges (no more KeyError).
• Works whether you pass plain dicts or streamlit-agraph Node/Edge objects.
"""
from __future__ import annotations
import networkx as nx
from typing import List, Dict, Tuple, Union
# ── helpers -----------------------------------------------------------------
def _edge_ends(e: Dict) -> Tuple[str, str] | None:
"""
Normalise edge formats.
Returns
-------
(src, dst) tuple – or None if either endpoint is missing.
"""
src = e.get("source") or e.get("from")
dst = e.get("target") or e.get("to")
if src and dst:
return str(src), str(dst)
return None
def _node_id(n: Union[Dict, object]) -> str:
"""
Accept either a dict *or* a streamlit_agraph.Node and return its id.
"""
if isinstance(n, dict):
return str(n.get("id"))
# fallback for Node dataclass
return str(getattr(n, "id", ""))
def _node_label(n: Union[Dict, object]) -> str:
"""
Extract label safely from dict or Node.
"""
if isinstance(n, dict):
return str(n.get("label", n.get("id")))
return str(getattr(n, "label", getattr(n, "id", "")))
# ── public API --------------------------------------------------------------
def build_nx(
    nodes: List[Dict | object],
    edges: List[Dict | object],
) -> nx.Graph:
    """
    Convert generic node/edge payloads into an undirected NetworkX graph.

    Parameters
    ----------
    nodes : dicts or Node-like objects exposing ``id`` (and optionally
        ``label``). Nodes whose id resolves to an empty string are skipped.
    edges : dicts or Edge-like objects whose attributes use either
        ``source``/``target`` or ``from``/``to``. Edges missing an endpoint,
        and objects without an attribute dict (e.g. tuples), are skipped.
    ``None`` for either argument is treated as an empty list.

    Returns
    -------
    nx.Graph with a ``label`` attribute on every accepted node. Edge
    endpoints not listed in ``nodes`` are added implicitly by ``add_edge``
    and carry no label.
    """
    G = nx.Graph()

    # add nodes
    for n in nodes or ():
        nid = _node_id(n)
        if not nid:
            continue
        G.add_node(nid, label=_node_label(n))

    # add edges
    for e in edges or ():
        if not isinstance(e, dict):
            # Edge dataclass / plain object -> attribute dict. Objects with
            # no __dict__ (tuples, slotted classes) are treated as malformed
            # and skipped instead of raising AttributeError.
            e = getattr(e, "__dict__", None)
            if e is None:
                continue
        ends = _edge_ends(e)
        if ends:
            G.add_edge(*ends)
    return G
def get_top_hubs(G: nx.Graph, k: int = 5) -> List[Tuple[str, float]]:
    """
    Return the top-``k`` nodes by **degree centrality**.

    Returns a list of ``(node, centrality)`` pairs in descending order of
    centrality; ties keep the order in which ``nx.degree_centrality``
    yields them. ``k <= 0`` returns ``[]`` (a negative slice would instead
    drop nodes from the tail of the full ranking).
    """
    import heapq

    if k <= 0:
        return []
    dc = nx.degree_centrality(G)
    # nlargest(k, ...) is documented as equivalent to
    # sorted(..., reverse=True)[:k], without sorting the whole list.
    return heapq.nlargest(k, dc.items(), key=lambda item: item[1])
def get_density(G: nx.Graph) -> float:
    """
    Return the graph density as a float.

    For a simple undirected graph the value lies in [0, 1]. NetworkX counts
    self-loops as edges, so a graph containing them (e.g. an edge whose
    source equals its target) can report a density above 1.
    """
    # nx.density returns the int 0 for edgeless / single-node graphs;
    # cast so the declared float return type always holds.
    return float(nx.density(G))