|
|
|
""" |
|
Build an interactive Network Graph (Streamlit-Agraph) that links together |
|
the heterogeneous entities returned by the MedGenesis pipeline: |
|
|
|
• Papers (arXiv / PubMed) |
|
• UMLS concepts |
|
• DrugSafety records (OpenFDA) |
|
• Genes (optional – when available) |
|
|
|
Design goals |
|
------------ |
|
1. **Resilience.** Any malformed record (e.g. a RuntimeError placeholder from |
|
an upstream failure) is skipped silently - the graph still renders. |
|
2. **Deduplication.** Identical nodes (same CUI, drug-name, paper ID…) are |
|
added only once; edges are de-duplicated as well. |
|
3. **Visual semantics.** |
|
• Papers → blue (#3498db) |
|
• UMLS concepts → green (#2ecc71) |
|
• Drugs / safety → orange (#e67e22) |
|
• Genes → purple (#9b59b6) |
|
4. **Config tuned for large graphs.** Collapsible nodes, physics-based layout, |
|
highlight on hover, full-width canvas. |
|
""" |
|
|
|
from __future__ import annotations |
|
|
|
from typing import Any, List, Tuple, Dict, Iterable |
|
import re |
|
from collections import defaultdict |
|
|
|
from streamlit_agraph import Node, Edge, Config |
|
|
|
|
|
|
|
|
|
# Hex fill colour for each kind of graph node, keyed by entity type.
RGB: Dict[str, str] = dict(
    paper="#3498db",  # blue
    umls="#2ecc71",  # green
    drug="#e67e22",  # orange
    gene="#9b59b6",  # purple
)
|
|
|
|
|
def _safe_iter(obj: Any) -> Iterable: |
|
"""Yield from *obj* if it is list-like, else yield the obj itself.""" |
|
if obj is None: |
|
return () |
|
if isinstance(obj, (list, tuple, set)): |
|
return obj |
|
return (obj,) |
|
|
|
|
|
def _dedup(seq: Iterable[Tuple]) -> List[Tuple]: |
|
"""Remove duplicates while preserving order.""" |
|
seen = set() |
|
out: List[Tuple] = [] |
|
for item in seq: |
|
if item not in seen: |
|
out.append(item) |
|
seen.add(item) |
|
return out |
|
|
|
|
|
|
|
|
|
|
|
def build_agraph(
    papers: List[Dict[str, Any]],
    umls: List[Dict[str, Any]],
    drug_safety: List[Dict[str, Any]],
    genes: List[Dict[str, Any]] | None = None,
) -> Tuple[List[Node], List[Edge], Config]:
    """
    Build the nodes, edges and config of the entity network graph.

    Each paper, UMLS concept, drug-safety record and (optionally) gene becomes
    one node. A paper is connected to another entity by a "mentions" edge when
    that entity's label occurs in the paper's lower-cased ``title`` +
    ``summary`` text (plain substring match for UMLS concepts, whole-word
    match for drugs and genes).

    Records that are not dicts (e.g. an exception object left behind by a
    failed upstream call) are skipped. ``None`` is accepted for any input and
    treated as "no records".

    Parameters
    ----------
    papers : list[dict]
        Output of `fetch_arxiv`/`fetch_pubmed`. *title* and *summary* are read
        when present; a missing key contributes an empty string.
    umls : list[dict]
        Items from `lookup_umls` (may be RuntimeError objects if failed).
        Only items with both a truthy *cui* and *name* produce a node.
    drug_safety : list[dict]
        Flattened OpenFDA adverse-event records. The label is taken from
        *drug_name*, then *medicinalproduct*, then a ``drug_<index>``
        placeholder. Records repeating an already seen name (compared
        case-insensitively) do not create a second node.
    genes : list[dict] | None
        Optional – from gene resolver hub. The label is taken from *symbol*,
        then *approvedSymbol*, then *name*.

    Returns
    -------
    nodes, edges, config
        Objects suitable for `streamlit_agraph.agraph(...)`. Node ids are
        unique and there is at most one edge per (paper, entity) pair.
    """
    nodes: List[Node] = []
    edges: List[Edge] = []

    # De-duplication is tracked with plain keys instead of relying on
    # Node/Edge objects defining value-based __eq__/__hash__.
    seen_node_ids: set = set()
    seen_edge_keys: set = set()

    def _add_node(node_id: str, **attrs: Any) -> bool:
        """Append a node unless its id is already present; report if added."""
        if node_id in seen_node_ids:
            return False
        seen_node_ids.add(node_id)
        nodes.append(Node(id=node_id, **attrs))
        return True

    # ── Papers ──────────────────────────────────────────────────────────
    # The searchable text is collected while the node is created, so a
    # skipped (malformed) entry can never shift texts onto the wrong node.
    paper_texts: List[Tuple[str, str]] = []
    for idx, p in enumerate(_safe_iter(papers)):
        if not isinstance(p, dict):
            continue
        pid = f"paper_{idx}"
        # NOTE(review): `tooltip` is forwarded to Node as-is; confirm that the
        # installed streamlit_agraph version displays it on hover.
        _add_node(pid, label=f"P{idx + 1}", tooltip=p.get("title", "Paper"),
                  size=15, color=RGB["paper"])
        paper_texts.append(
            (pid, f"{p.get('title', '')} {p.get('summary', '')}".lower())
        )

    # (node id, label) pairs per entity kind, used for linking below.
    umls_targets: List[Tuple[str, str]] = []
    drug_targets: List[Tuple[str, str]] = []
    gene_targets: List[Tuple[str, str]] = []

    # ── UMLS concepts (one node per CUI) ────────────────────────────────
    for c in _safe_iter(umls):
        if not isinstance(c, dict):
            continue
        cui = c.get("cui")
        name = c.get("name")
        if cui and name:
            label = str(name)
            node_id = f"umls_{cui}"
            if _add_node(node_id, label=label, size=22, color=RGB["umls"]):
                umls_targets.append((node_id, label))

    # ── Drug-safety records (one node per distinct drug name) ───────────
    seen_drug_names: set = set()
    for i, rec in enumerate(_safe_iter(drug_safety)):
        if not isinstance(rec, dict):
            continue
        dn = str(
            rec.get("drug_name")
            or rec.get("medicinalproduct")
            or "drug_{}".format(i)
        )
        # Matching against papers is case-insensitive, so names differing
        # only by case would produce identical links: treat them as one drug.
        name_key = dn.lower()
        if name_key in seen_drug_names:
            continue
        seen_drug_names.add(name_key)
        node_id = f"drug_{i}"
        if _add_node(node_id, label=dn, size=25, color=RGB["drug"]):
            drug_targets.append((node_id, dn))

    # ── Genes (optional, one node per symbol) ───────────────────────────
    for g in _safe_iter(genes):
        if not isinstance(g, dict):
            continue
        sym = g.get("symbol") or g.get("approvedSymbol") or g.get("name")
        if sym:
            label = str(sym)
            node_id = f"gene_{label}"
            if _add_node(node_id, label=label, size=20, color=RGB["gene"]):
                gene_targets.append((node_id, label))

    # ── Edges: paper --mentions--> entity ───────────────────────────────
    def _link(targets: List[Tuple[str, str]], whole_word: bool) -> None:
        """Add a "mentions" edge for every paper whose text matches a target."""
        # Compile each pattern once, not once per (paper, target) pair.
        compiled = []
        for node_id, label in targets:
            needle = re.escape(label.lower())
            pattern = rf"\b{needle}\b" if whole_word else needle
            compiled.append((node_id, re.compile(pattern)))

        for pid, blob in paper_texts:
            for node_id, pat in compiled:
                edge_key = (pid, node_id)
                if edge_key in seen_edge_keys or not pat.search(blob):
                    continue
                seen_edge_keys.add(edge_key)
                edges.append(Edge(source=pid, target=node_id, label="mentions"))

    _link(umls_targets, whole_word=False)
    _link(drug_targets, whole_word=True)
    _link(gene_targets, whole_word=True)

    # ── Rendering configuration ─────────────────────────────────────────
    cfg = Config(
        width="100%",
        height="650px",
        directed=False,
        nodeHighlightBehavior=True,
        highlightColor="#f1c40f",
        collapsible=True,
        physics=True,
        hierarchical=False,
        node={"labelProperty": "label"},
        link={"labelProperty": "label", "renderLabel": False},
    )

    return nodes, edges, cfg
|
|