MCP_Res / mcp /knowledge_graph.py
mgbam's picture
Update mcp/knowledge_graph.py
eb1f007 verified
raw
history blame
6.62 kB
# mcp/knowledge_graph.py
"""
Build an interactive Network Graph (Streamlit-Agraph) that links together
the heterogeneous entities returned by the MedGenesis pipeline:
• Papers (arXiv / PubMed)
• UMLS concepts
• DrugSafety records (OpenFDA)
• Genes (optional – when available)
Design goals
------------
1. **Resilience.** Any malformed record (e.g. a RuntimeError placeholder from
an upstream failure) is skipped silently - the graph still renders.
2. **Deduplication.** Identical nodes (same CUI, drug-name, paper ID…) are
added only once; edges are de-duplicated as well.
3. **Visual semantics.**
   • Papers → blue (#3498db)
   • UMLS concepts → green (#2ecc71)
   • Drugs / safety → orange (#e67e22)
   • Genes → purple (#9b59b6)
4. **Config tuned for large graphs.** Collapsible nodes, physics-based layout,
highlight on hover, full-width canvas.
"""
from __future__ import annotations
from typing import Any, List, Tuple, Dict, Iterable
import re
from collections import defaultdict
from streamlit_agraph import Node, Edge, Config
# --------------------------------------------------------------------------- #
# utility helpers #
# --------------------------------------------------------------------------- #
# Hex colour used for each category of graph node, keyed by category name.
RGB = dict(
    paper="#3498db",
    umls="#2ecc71",
    drug="#e67e22",
    gene="#9b59b6",
)
def _safe_iter(obj: Any) -> Iterable:
"""Yield from *obj* if it is list-like, else yield the obj itself."""
if obj is None:
return ()
if isinstance(obj, (list, tuple, set)):
return obj
return (obj,)
def _dedup(seq: Iterable[Tuple]) -> List[Tuple]:
"""Remove duplicates while preserving order."""
seen = set()
out: List[Tuple] = []
for item in seq:
if item not in seen:
out.append(item)
seen.add(item)
return out
# --------------------------------------------------------------------------- #
# main builder #
# --------------------------------------------------------------------------- #
def build_agraph(
    papers: List[Dict[str, Any]],
    umls: List[Dict[str, Any]],
    drug_safety: List[Dict[str, Any]],
    genes: List[Dict[str, Any]] | None = None,
) -> Tuple[List[Node], List[Edge], Config]:
    """
    Assemble the nodes, edges and config for a `streamlit_agraph` graph.

    Every argument is passed through `_safe_iter`, so ``None`` or a single
    record is tolerated, and any entry that is not a ``dict`` is skipped.

    Parameters
    ----------
    papers : list[dict]
        Output of `fetch_arxiv`/`fetch_pubmed`. The *title* and *summary*
        fields (when present) form the text searched for mentions.
    umls : list[dict]
        Items from `lookup_umls` (may be RuntimeError objects if failed).
        Only records with both *cui* and *name* become nodes.
    drug_safety : list[dict]
        Flattened OpenFDA adverse-event records. The label is taken from
        *drug_name*, then *medicinalproduct*, else a ``drug_<index>``
        placeholder.
    genes : list[dict] | None
        Optional – from gene resolver hub. The label is taken from *symbol*,
        *approvedSymbol* or *name*.

    Returns
    -------
    nodes, edges, config
        Objects suitable for `streamlit_agraph.agraph(...)`

    Notes
    -----
    * Node ids are ``paper_<index>``, ``umls_<cui>``, ``drug_<index>`` and
      ``gene_<symbol>``.
    * Nodes and edges are de-duplicated while they are created (by node id
      and by ``(source, target)`` pair), so the result does not depend on
      `Node`/`Edge` implementing value-based equality or hashing.
    * Several drug records carrying the same name (compared case-insensitively
      after stripping whitespace) share the node of the first such record.
    * A "mentions" edge is added when a label occurs in a paper's lower-cased
      title + summary: UMLS names match as plain substrings, drug and gene
      labels must not be embedded in a longer word.
    """
    nodes: List[Node] = []
    edges: List[Edge] = []
    seen_node_ids: set = set()

    def _add_node(node_id: str, **attrs: Any) -> bool:
        """Append a node unless its id is already present; report if added."""
        if node_id in seen_node_ids:
            return False
        seen_node_ids.add(node_id)
        nodes.append(Node(id=node_id, **attrs))
        return True

    def _mention_pattern(label: Any, whole_word: bool):
        """Compile the search pattern for *label*, or None for a blank label."""
        needle = str(label).strip().lower()
        if not needle:
            return None
        escaped = re.escape(needle)
        if whole_word:
            # Lookarounds instead of ``\b``: ``\b`` can never match next to a
            # label that starts or ends with a non-word character.
            escaped = rf"(?<!\w){escaped}(?!\w)"
        return re.compile(escaped)

    def _add_target(bucket: list, node_id: str, label: Any, whole_word: bool) -> None:
        """Remember (node id, compiled pattern) so each label compiles once."""
        pattern = _mention_pattern(label, whole_word)
        if pattern is not None:
            bucket.append((node_id, pattern))

    # ---- Papers ----------------------------------------------------------- #
    # The searchable text is captured in the same loop that creates the node,
    # so a skipped (non-dict) entry can never shift node/text alignment.
    paper_texts: List[Tuple[str, str]] = []
    for idx, paper in enumerate(_safe_iter(papers)):
        if not isinstance(paper, dict):
            continue
        pid = f"paper_{idx}"
        # NOTE(review): vis-network takes hover text from `title`; confirm that
        # streamlit_agraph's Node does something useful with `tooltip`.
        _add_node(pid, label=f"P{idx + 1}", tooltip=paper.get("title", "Paper"),
                  size=15, color=RGB["paper"])
        # ``or ''`` keeps an explicit None from becoming the text "none".
        blob = f"{paper.get('title') or ''} {paper.get('summary') or ''}".lower()
        paper_texts.append((pid, blob))

    # ---- UMLS concepts ---------------------------------------------------- #
    umls_targets: list = []
    for concept in _safe_iter(umls):
        if not isinstance(concept, dict):
            continue
        cui = concept.get("cui")
        name = concept.get("name")
        if not (cui and name):
            continue
        node_id = f"umls_{cui}"
        if _add_node(node_id, label=name, size=22, color=RGB["umls"]):
            _add_target(umls_targets, node_id, name, whole_word=False)

    # ---- Drug Safety ------------------------------------------------------ #
    drug_targets: list = []
    seen_drug_names: set = set()
    for i, rec in enumerate(_safe_iter(drug_safety)):
        if not isinstance(rec, dict):
            continue
        drug_name = (
            rec.get("drug_name")
            or rec.get("medicinalproduct")
            or f"drug_{i}"
        )
        # One node per distinct drug name; the first record supplies the id.
        name_key = str(drug_name).strip().casefold()
        if name_key in seen_drug_names:
            continue
        seen_drug_names.add(name_key)
        node_id = f"drug_{i}"
        if _add_node(node_id, label=drug_name, size=25, color=RGB["drug"]):
            _add_target(drug_targets, node_id, drug_name, whole_word=True)

    # ---- Genes (optional) -------------------------------------------------- #
    gene_targets: list = []
    for gene in _safe_iter(genes):
        if not isinstance(gene, dict):
            continue
        sym = gene.get("symbol") or gene.get("approvedSymbol") or gene.get("name")
        if not sym:
            continue
        node_id = f"gene_{sym}"
        if _add_node(node_id, label=sym, size=20, color=RGB["gene"]):
            _add_target(gene_targets, node_id, sym, whole_word=True)

    # ---------------------------------------------------------------------- #
    # Edges – naïve co-occurrence linking                                     #
    # ---------------------------------------------------------------------- #
    seen_edges: set = set()

    # connect paper ↔ umls / drugs / genes if mention appears in text
    def _link(targets: list) -> None:
        for pid, blob in paper_texts:
            for target_id, pattern in targets:
                pair = (pid, target_id)
                if pair in seen_edges or not pattern.search(blob):
                    continue
                seen_edges.add(pair)
                edges.append(Edge(source=pid, target=target_id, label="mentions"))

    _link(umls_targets)
    _link(drug_targets)
    _link(gene_targets)

    # ---------------------------------------------------------------------- #
    # Graph config                                                            #
    # ---------------------------------------------------------------------- #
    cfg = Config(
        width="100%",
        height="650px",
        directed=False,
        nodeHighlightBehavior=True,
        highlightColor="#f1c40f",
        collapsible=True,
        physics=True,
        hierarchical=False,
        node={"labelProperty": "label"},
        link={"labelProperty": "label", "renderLabel": False},
    )
    return nodes, edges, cfg