File size: 6,617 Bytes
eb1f007 633ba95 eb1f007 633ba95 3d1def9 eb1f007 f3dd8bc eb1f007 3d1def9 633ba95 eb1f007 633ba95 eb1f007 633ba95 eb1f007 3d1def9 633ba95 a392df0 eb1f007 633ba95 eb1f007 633ba95 eb1f007 633ba95 f3dd8bc eb1f007 633ba95 eb1f007 a392df0 633ba95 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
# mcp/knowledge_graph.py
"""
Build an interactive Network Graph (Streamlit-Agraph) that links together
the heterogeneous entities returned by the MedGenesis pipeline:
• Papers (arXiv / PubMed)
• UMLS concepts
• DrugSafety records (OpenFDA)
• Genes (optional – when available)
Design goals
------------
1. **Resilience.** Any malformed record (e.g. a RuntimeError placeholder from
an upstream failure) is skipped silently - the graph still renders.
2. **Deduplication.** Identical nodes (same CUI, drug-name, paper ID…) are
added only once; edges are de-duplicated as well.
3. **Visual semantics.**
• Papers → blue (#3498db)
• UMLS concepts → green (#2ecc71)
• Drugs / safety → orange (#e67e22)
• Genes → purple (#9b59b6)
4. **Config tuned for large graphs.** Collapsible nodes, physics-based layout,
highlight on hover, full-width canvas.
"""
from __future__ import annotations

import re
from collections import defaultdict
from typing import Any, Callable, Dict, Hashable, Iterable, List, Tuple

from streamlit_agraph import Config, Edge, Node
# --------------------------------------------------------------------------- #
# utility helpers #
# --------------------------------------------------------------------------- #
# Node colour per entity kind; build_agraph passes these as Node(color=...).
RGB = {
    "paper": "#3498db",  # papers          -> blue
    "umls":  "#2ecc71",  # UMLS concepts   -> green
    "drug":  "#e67e22",  # drugs / safety  -> orange
    "gene":  "#9b59b6",  # genes           -> purple
}
def _safe_iter(obj: Any) -> Iterable:
"""Yield from *obj* if it is list-like, else yield the obj itself."""
if obj is None:
return ()
if isinstance(obj, (list, tuple, set)):
return obj
return (obj,)
def _dedup(seq: Iterable[Tuple]) -> List[Tuple]:
"""Remove duplicates while preserving order."""
seen = set()
out: List[Tuple] = []
for item in seq:
if item not in seen:
out.append(item)
seen.add(item)
return out
# --------------------------------------------------------------------------- #
# main builder #
# --------------------------------------------------------------------------- #
def build_agraph(
    papers: List[Dict[str, Any]],
    umls: List[Dict[str, Any]],
    drug_safety: List[Dict[str, Any]],
    genes: List[Dict[str, Any]] | None = None,
) -> Tuple[List[Node], List[Edge], Config]:
    """
    Build the nodes, edges and config of the MedGenesis knowledge graph.

    Parameters
    ----------
    papers : list[dict]
        Output of `fetch_arxiv`/`fetch_pubmed`. *title* and *summary* are read
        when present; non-dict entries are skipped.
    umls : list[dict]
        Items from `lookup_umls` (may be RuntimeError objects if failed –
        non-dict items are skipped). Items need both *cui* and *name*.
    drug_safety : list[dict]
        Flattened OpenFDA adverse-event records. The label is *drug_name*,
        else *medicinalproduct*, else ``drug_<index>``; a record whose label
        was already seen (case-insensitively) adds no further node.
    genes : list[dict] | None
        Optional – from gene resolver hub. The label is *symbol*, else
        *approvedSymbol*, else *name*.

    Returns
    -------
    nodes, edges, config
        Objects suitable for `streamlit_agraph.agraph(...)`. Node ids are
        unique and each paper → entity "mentions" edge appears at most once.
    """
    nodes: List[Node] = []
    edges: List[Edge] = []
    node_ids: set[str] = set()

    def _add_node(node_id: str, **attrs: Any) -> bool:
        """Append a Node unless *node_id* is already used; True if added."""
        # De-duplicate by id up front: Node objects may not define value
        # equality, so filtering the finished list cannot be relied upon.
        if node_id in node_ids:
            return False
        node_ids.add(node_id)
        nodes.append(Node(id=node_id, **attrs))
        return True

    def _mention_pattern(label: str, whole_word: bool) -> "re.Pattern[str]":
        """Compile a case-folded pattern that finds *label* in paper text."""
        text = re.escape(label.lower())
        if whole_word:
            # Lookarounds rather than \b: \b never matches at a label edge
            # that is itself a non-word character (e.g. "5-FU (oral)").
            text = rf"(?<!\w){text}(?!\w)"
        return re.compile(text)

    # ---- Papers ----------------------------------------------------------- #
    # (node id, lower-cased searchable text), captured together so skipped
    # entries can never misalign paper nodes and paper records.
    paper_texts: List[Tuple[str, str]] = []
    for idx, p in enumerate(_safe_iter(papers)):
        if not isinstance(p, dict):
            continue
        pid = f"paper_{idx}"
        _add_node(pid, label=f"P{idx + 1}", tooltip=p.get("title", "Paper"),
                  size=15, color=RGB["paper"])
        blob = f"{p.get('title') or ''} {p.get('summary') or ''}".lower()
        paper_texts.append((pid, blob))

    # ---- UMLS concepts ---------------------------------------------------- #
    umls_targets: List[Tuple[str, "re.Pattern[str]"]] = []
    for c in _safe_iter(umls):
        if not isinstance(c, dict):
            continue
        cui, name = c.get("cui"), c.get("name")
        if not (cui and name):
            continue
        label = str(name)
        nid = f"umls_{cui}"
        if _add_node(nid, label=label, size=22, color=RGB["umls"]):
            # Plain substring match (no word boundaries), as before.
            umls_targets.append((nid, _mention_pattern(label, whole_word=False)))

    # ---- Drug Safety ------------------------------------------------------ #
    drug_targets: List[Tuple[str, "re.Pattern[str]"]] = []
    drug_names: set[str] = set()
    for i, rec in enumerate(_safe_iter(drug_safety)):
        if not isinstance(rec, dict):
            continue
        label = str(
            rec.get("drug_name")
            or rec.get("medicinalproduct")
            or f"drug_{i}"
        )
        folded = label.lower()
        if folded in drug_names:
            continue  # one node per drug name, not per adverse-event record
        drug_names.add(folded)
        nid = f"drug_{i}"
        if _add_node(nid, label=label, size=25, color=RGB["drug"]):
            drug_targets.append((nid, _mention_pattern(label, whole_word=True)))

    # ---- Genes (optional) -------------------------------------------------- #
    gene_targets: List[Tuple[str, "re.Pattern[str]"]] = []
    for g in _safe_iter(genes):  # _safe_iter(None) yields nothing
        if not isinstance(g, dict):
            continue
        sym = g.get("symbol") or g.get("approvedSymbol") or g.get("name")
        if not sym:
            continue
        label = str(sym)
        nid = f"gene_{label}"
        if _add_node(nid, label=label, size=20, color=RGB["gene"]):
            gene_targets.append((nid, _mention_pattern(label, whole_word=True)))

    # ---------------------------------------------------------------------- #
    # Edges – naïve co-occurrence linking                                     #
    # ---------------------------------------------------------------------- #
    # Paper ids and target ids are unique, so every (paper, target) pair is
    # visited once and no duplicate edge can be produced.
    for targets in (umls_targets, drug_targets, gene_targets):
        for pid, blob in paper_texts:
            for tid, pattern in targets:
                if pattern.search(blob):
                    edges.append(Edge(source=pid, target=tid, label="mentions"))

    # ---------------------------------------------------------------------- #
    # Graph config                                                           #
    # ---------------------------------------------------------------------- #
    cfg = Config(
        width="100%",
        height="650px",
        directed=False,
        nodeHighlightBehavior=True,
        highlightColor="#f1c40f",
        collapsible=True,
        physics=True,
        hierarchical=False,
        node={"labelProperty": "label"},
        link={"labelProperty": "label", "renderLabel": False},
    )
    return nodes, edges, cfg
|