File size: 6,617 Bytes
eb1f007
633ba95
eb1f007
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
633ba95
3d1def9
eb1f007
f3dd8bc
eb1f007
3d1def9
633ba95
 
eb1f007
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
633ba95
eb1f007
 
 
 
 
 
633ba95
 
eb1f007
 
 
 
 
 
 
 
 
 
 
 
 
3d1def9
633ba95
 
a392df0
eb1f007
 
 
633ba95
 
eb1f007
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
633ba95
 
eb1f007
 
633ba95
 
f3dd8bc
eb1f007
 
633ba95
eb1f007
a392df0
633ba95
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
# mcp/knowledge_graph.py
"""
Build an interactive Network Graph (Streamlit-Agraph) that links together
the heterogeneous entities returned by the MedGenesis pipeline:

• Papers (arXiv / PubMed)
• UMLS concepts
• DrugSafety records (OpenFDA)
• Genes (optional – when available)

Design goals
------------
1. **Resilience.**  Any malformed record (e.g. a RuntimeError placeholder from
   an upstream failure) is skipped silently - the graph still renders.
2. **Deduplication.**  Identical nodes (same CUI, drug-name, paper ID…) are
   added only once; edges are de-duplicated as well.
3. **Visual semantics.**
      • Papers          → blue   (#3498db)
      • UMLS concepts   → green  (#2ecc71)
      • Drugs / safety  → orange (#e67e22)
      • Genes           → purple (#9b59b6)
4. **Config tuned for large graphs.**  Collapsible nodes, physics-based layout,
   highlight on hover, full-width canvas.
"""

from __future__ import annotations

from typing import Any, List, Tuple, Dict, Iterable
import re
from collections import defaultdict

from streamlit_agraph import Node, Edge, Config

# --------------------------------------------------------------------------- #
# utility helpers                                                             #
# --------------------------------------------------------------------------- #
# Hex colour assigned to each entity kind drawn in the graph.
RGB = dict(
    paper="#3498db",
    umls="#2ecc71",
    drug="#e67e22",
    gene="#9b59b6",
)


def _safe_iter(obj: Any) -> Iterable:
    """Yield from *obj* if it is list-like, else yield the obj itself."""
    if obj is None:
        return ()
    if isinstance(obj, (list, tuple, set)):
        return obj
    return (obj,)


def _dedup(seq: Iterable[Tuple]) -> List[Tuple]:
    """Remove duplicates while preserving order."""
    seen = set()
    out: List[Tuple] = []
    for item in seq:
        if item not in seen:
            out.append(item)
            seen.add(item)
    return out


# --------------------------------------------------------------------------- #
# main builder                                                                #
# --------------------------------------------------------------------------- #
def build_agraph(
    papers: List[Dict[str, Any]],
    umls:   List[Dict[str, Any]],
    drug_safety: List[Dict[str, Any]],
    genes: List[Dict[str, Any]] | None = None,
) -> Tuple[List[Node], List[Edge], Config]:
    """
    Build the nodes, edges and config for a paper / concept / drug / gene graph.

    Every input is passed through `_safe_iter`, so ``None``, a single record
    or a list are all accepted; entries that are not dicts (e.g. exception
    placeholders from a failed upstream call) are skipped.

    Parameters
    ----------
    papers : list[dict]
        Paper records.  ``title`` and ``summary`` are optional; when present
        they form the text that is searched for mentions.
    umls : list[dict]
        Concept records; only those with both ``cui`` and ``name`` become
        nodes.  Records sharing a CUI produce a single node.
    drug_safety : list[dict]
        Drug records; the label is taken from ``drug_name``, then
        ``medicinalproduct``, then a generated ``drug_<index>`` fallback.
        Records whose label matches an earlier one (case-insensitively)
        are merged into that earlier node.
    genes : list[dict] | None
        Optional gene records; the label is taken from ``symbol``,
        ``approvedSymbol`` or ``name``.  Repeated symbols produce one node.

    Returns
    -------
    nodes, edges, config
        Objects suitable for `streamlit_agraph.agraph(...)`.  An edge labelled
        "mentions" links a paper to every concept (substring match) and to
        every drug / gene (whole-word match) whose label occurs in the
        paper's lower-cased title + summary.
    """
    nodes: List[Node] = []
    seen_ids = set()

    def _add_node(node_id: str, **attrs: Any) -> bool:
        """Append a node unless its id is already present; report if added."""
        # De-duplication is keyed on the id string because Node objects
        # cannot be relied on to compare or hash by value.
        if node_id in seen_ids:
            return False
        seen_ids.add(node_id)
        nodes.append(Node(id=node_id, **attrs))
        return True

    # ---- Papers ----------------------------------------------------------- #
    # The searchable text is captured in the same loop that creates the node,
    # so node id and text stay paired even when malformed entries are skipped.
    paper_texts: List[Tuple[str, str]] = []
    for idx, p in enumerate(_safe_iter(papers)):
        if not isinstance(p, dict):
            continue
        pid = f"paper_{idx}"
        if _add_node(pid, label=f"P{idx + 1}", tooltip=p.get("title", "Paper"),
                     size=15, color=RGB["paper"]):
            # `or ''` keeps a None title/summary from becoming the text "none".
            blob = f"{p.get('title') or ''} {p.get('summary') or ''}".lower()
            paper_texts.append((pid, blob))

    # ---- UMLS concepts ---------------------------------------------------- #
    umls_targets: List[Tuple[str, str]] = []
    for c in _safe_iter(umls):
        if not isinstance(c, dict):
            continue
        cui = c.get("cui")
        name = c.get("name")
        if not (cui and name):
            continue
        name = str(name)
        nid = f"umls_{cui}"
        if _add_node(nid, label=name, size=22, color=RGB["umls"]):
            umls_targets.append((nid, name))

    # ---- Drug Safety ------------------------------------------------------ #
    drug_targets: List[Tuple[str, str]] = []
    seen_drugs = set()
    for i, rec in enumerate(_safe_iter(drug_safety)):
        if not isinstance(rec, dict):
            continue
        dn = str(
            rec.get("drug_name")
            or rec.get("medicinalproduct")
            or "drug_{}".format(i)
        )
        # Many records refer to the same product; keep one node per name.
        drug_key = dn.strip().lower()
        if drug_key in seen_drugs:
            continue
        seen_drugs.add(drug_key)
        nid = f"drug_{i}"
        if _add_node(nid, label=dn, size=25, color=RGB["drug"]):
            drug_targets.append((nid, dn))

    # ---- Genes (optional) -------------------------------------------------- #
    gene_targets: List[Tuple[str, str]] = []
    for g in _safe_iter(genes):
        if not isinstance(g, dict):
            continue
        sym = g.get("symbol") or g.get("approvedSymbol") or g.get("name")
        if not sym:
            continue
        sym = str(sym)
        nid = f"gene_{sym}"
        if _add_node(nid, label=sym, size=20, color=RGB["gene"]):
            gene_targets.append((nid, sym))

    # ---------------------------------------------------------------------- #
    # Edges – naïve co-occurrence linking                                    #
    # ---------------------------------------------------------------------- #
    edge_keys: List[Tuple[str, str]] = []

    def _link(targets: List[Tuple[str, str]], whole_word: bool) -> None:
        """Record a (paper id, target id) pair for every textual mention."""
        # Compile each target's pattern once, not once per paper.
        compiled = []
        for tid, label in targets:
            needle = label.strip().lower()
            if not needle:
                # An empty pattern would match every paper.
                continue
            escaped = re.escape(needle)
            compiled.append(
                (tid, re.compile(rf"\b{escaped}\b" if whole_word else escaped))
            )
        for pid, blob in paper_texts:
            for tid, pat in compiled:
                if pat.search(blob):
                    edge_keys.append((pid, tid))

    _link(umls_targets, whole_word=False)
    _link(drug_targets, whole_word=True)
    _link(gene_targets, whole_word=True)

    # De-duplicate on plain (source, target) tuples, then build Edge objects.
    edges: List[Edge] = [
        Edge(source=src, target=dst, label="mentions")
        for src, dst in _dedup(edge_keys)
    ]

    # ---------------------------------------------------------------------- #
    # Graph config                                                            #
    # ---------------------------------------------------------------------- #
    cfg = Config(
        width="100%",
        height="650px",
        directed=False,
        nodeHighlightBehavior=True,
        highlightColor="#f1c40f",
        collapsible=True,
        physics=True,
        hierarchical=False,
        node={"labelProperty": "label"},
        link={"labelProperty": "label", "renderLabel": False},
    )

    return nodes, edges, cfg