File size: 2,474 Bytes
b7556e4
08a3e96
7808af5
 
 
 
 
 
 
 
 
 
 
08a3e96
b7556e4
 
08a3e96
b7556e4
7808af5
b7556e4
 
7808af5
 
 
 
b7556e4
7808af5
 
 
 
 
 
 
 
 
b7556e4
7808af5
 
 
 
b7556e4
7808af5
 
 
 
b7556e4
7808af5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7556e4
08a3e96
b7556e4
7808af5
08a3e96
7808af5
 
b7556e4
7808af5
b7556e4
7808af5
08a3e96
7808af5
 
 
b7556e4
 
 
08a3e96
 
 
7808af5
 
 
 
08a3e96
 
b7556e4
 
 
7808af5
 
 
b7556e4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# mcp/graph_metrics.py
"""
Tiny NetworkX helpers for MedGenesis graphs.

✨ 2025-06-25 REVAMP
────────────────────
• Accepts edge-dicts that use either
      {'source': 'n1', 'target': 'n2'}   ← agraph / d3.js
   or {'from'  : 'n1', 'to'    : 'n2'}   ← PyVis

• Silently skips malformed edges (no more KeyError).

• Works whether you pass plain dicts or streamlit-agraph Node/Edge objects.
"""

from __future__ import annotations

import networkx as nx
from typing import List, Dict, Tuple, Union


# ── helpers -----------------------------------------------------------------
def _edge_ends(e: Dict) -> Tuple[str, str] | None:
    """
    Normalise edge formats.

    Returns
    -------
    (src, dst) tuple  – or None if either endpoint is missing.
    """
    src = e.get("source") or e.get("from")
    dst = e.get("target") or e.get("to")
    if src and dst:
        return str(src), str(dst)
    return None


def _node_id(n: Union[Dict, object]) -> str:
    """
    Accept either a dict *or* a streamlit_agraph.Node and return its id.
    """
    if isinstance(n, dict):
        return str(n.get("id"))
    # fallback for Node dataclass
    return str(getattr(n, "id", ""))


def _node_label(n: Union[Dict, object]) -> str:
    """
    Extract label safely from dict or Node.
    """
    if isinstance(n, dict):
        return str(n.get("label", n.get("id")))
    return str(getattr(n, "label", getattr(n, "id", "")))


# ── public API --------------------------------------------------------------
def build_nx(
    nodes: List[Dict | object],
    edges: List[Dict | object],
) -> nx.Graph:
    """
    Convert generic node/edge payloads into an undirected NetworkX graph.

    Parameters
    ----------
    nodes : list of dict or Node-like objects
        Each item's id is resolved with ``_node_id``. Items whose id resolves
        to an empty string are skipped. The ``_node_label`` result is stored
        as the node's ``label`` attribute.
    edges : list of dict or Edge-like objects
        Endpoints are resolved with ``_edge_ends``. Non-dict items are read
        through their instance ``__dict__``. Items that have no ``__dict__``,
        or that lack an endpoint, are skipped silently.

    Returns
    -------
    nx.Graph
        The assembled graph. Edge endpoints that were not listed in *nodes*
        are created by ``add_edge`` without a ``label`` attribute.
    """
    G = nx.Graph()

    # add nodes
    for n in nodes:
        nid = _node_id(n)
        if not nid:
            continue  # no usable id -> the node cannot be placed
        G.add_node(nid, label=_node_label(n))

    # add edges
    for e in edges:
        if not isinstance(e, dict):
            # Edge dataclass/object -> dict of its instance attributes.
            # Items without a __dict__ (None, tuples, __slots__ classes)
            # are malformed for our purposes and are skipped, not raised.
            e = getattr(e, "__dict__", None)
            if e is None:
                continue
        ends = _edge_ends(e)
        if ends:
            G.add_edge(*ends)

    return G


def get_top_hubs(G: nx.Graph, k: int = 5) -> List[Tuple[str, float]]:
    """
    Rank the nodes of *G* by **degree centrality** and keep the best *k*.

    Returns
    -------
    list of (node, centrality) pairs, highest centrality first. Nodes with
    equal centrality keep the order in which ``nx.degree_centrality``
    returned them, because ``sorted`` is stable even with ``reverse=True``.
    """
    centrality = nx.degree_centrality(G)
    ranked = sorted(
        centrality.items(),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return ranked[:k]


def get_density(G: nx.Graph) -> float:
    """
    Return the edge density of *G*, as computed by ``nx.density``.

    For a simple graph the value lies in [0, 1].
    """
    density = nx.density(G)
    return density