mgbam committed on
Commit
eb1f007
·
verified ·
1 Parent(s): 4b96f76

Update mcp/knowledge_graph.py

Browse files
Files changed (1) hide show
  1. mcp/knowledge_graph.py +157 -136
mcp/knowledge_graph.py CHANGED
@@ -1,161 +1,182 @@
1
- #!/usr/bin/env python3
2
- """MedGenesis – knowledge‑graph builder for Streamlit‑Agraph.
3
-
4
- This version recognises **all new enrichment layers** introduced in the
5
- latest orchestrator:
6
- • UMLS concepts → green nodes
7
- • MyGene / NCBI gene hits → purple nodes
8
- • openFDA / DrugCentral drugs → orange nodes
9
- • ClinicalTrials.gov studies → pink nodes
10
- • Open Targets associations → red drug–gene / gene–disease edges
11
- • Literature papers → blue nodes (tooltip = title)
12
-
13
- The entry‑point `build_agraph` now receives a richer payload and returns
14
- *(nodes, edges, config)* ready for `streamlit_agraph.agraph`.
15
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  from __future__ import annotations
17
 
 
18
  import re
19
- from typing import List, Dict, Tuple
20
 
21
  from streamlit_agraph import Node, Edge, Config
22
 
23
- # ---------------------------------------------------------------------
24
- # Colour palette (flat‑UI)
25
- # ---------------------------------------------------------------------
26
- C_PAPER = "#0984e3"
27
- C_CONCEPT = "#00b894"
28
- C_GENE = "#6c5ce7"
29
- C_DRUG = "#d35400"
30
- C_TRIAL = "#fd79a8"
31
- C_OT_EDGE = "#c0392b"
32
-
33
-
34
- # ---------------------------------------------------------------------
35
- # Helper builders
36
- # ---------------------------------------------------------------------
37
-
38
- def _add_node(nodes: List[Node], node_id: str, label: str, color: str, tooltip: str | None = None, size: int = 25):
39
- """Append Node only if id not yet present (agraph duplicates crash)."""
40
- if any(n.id == node_id for n in nodes):
41
- return
42
- nodes.append(Node(id=node_id, label=label, color=color, size=size, tooltip=tooltip))
43
-
44
-
45
- def _match(text: str, pattern: str) -> bool:
46
- return bool(re.search(re.escape(pattern), text, flags=re.I))
47
-
48
-
49
- # ---------------------------------------------------------------------
50
- # Public API
51
- # ---------------------------------------------------------------------
52
-
 
 
 
 
53
  def build_agraph(
54
- papers: List[Dict],
55
- umls: List[Dict],
56
- drug_safety: List[Dict],
57
- genes: List[Dict] | None = None,
58
- trials: List[Dict] | None = None,
59
- ot_associations: List[Dict] | None = None,
60
- ):
61
- """Return (nodes, edges, config) for streamlit_agraph. Safe‑duplicates.
62
-
63
  Parameters
64
  ----------
65
- papers : PubMed / arXiv merged list (dicts with title & summary).
66
- umls : List of UMLS concept dicts `{cui, name}`.
67
- drug_safety : openFDA / DrugCentral outputs (mixed dict / list).
68
- genes : Optional list with MyGene/NCBI dicts (symbol, name,...).
69
- trials : Optional ClinicalTrials.gov v2 studies list.
70
- ot_associations : Optional list from Open Targets.
 
 
 
 
 
 
 
71
  """
72
-
73
  nodes: List[Node] = []
74
  edges: List[Edge] = []
75
 
76
- # 1️⃣ Concepts ----------------------------------------------------
77
- for c in umls:
78
- cui, name = c.get("cui"), c.get("name", "")
79
- if cui and name:
80
- cid = f"concept_{cui}"
81
- _add_node(nodes, cid, name, C_CONCEPT)
82
-
83
- # 2️⃣ Genes -------------------------------------------------------
84
- genes = genes or []
85
- for g in genes:
86
- sym = g.get("symbol") or g.get("name")
87
- gid = f"gene_{sym}"
88
- tooltip = g.get("summary", "")
89
- _add_node(nodes, gid, sym, C_GENE, tooltip=tooltip)
90
-
91
- # 3️⃣ Drugs (normalize mixed structures) -------------------------
92
- drug_tuples: List[Tuple[str, str]] = [] # (node_id, drug_name)
93
- for i, dr in enumerate(drug_safety):
94
- recs = dr if isinstance(dr, list) else [dr]
95
- for j, rec in enumerate(recs):
96
- name = (
97
- rec.get("drug_name") or
98
- rec.get("patient", {}).get("drug") or
99
- rec.get("medicinalproduct") or
100
- f"drug_{i}_{j}"
101
- )
102
- did = f"drug_{i}_{j}"
103
- drug_tuples.append((did, name))
104
- _add_node(nodes, did, name, C_DRUG)
105
-
106
- # 4️⃣ Trials ------------------------------------------------------
107
- trials = trials or []
108
- for t in trials:
109
- nct = t.get("nctId") or t.get("nctid")
110
- if not nct:
111
  continue
112
- tid = f"trial_{nct}"
113
- label = nct
114
- tooltip = t.get("briefTitle") or "Clinical trial"
115
- _add_node(nodes, tid, label, C_TRIAL, tooltip=tooltip, size=20)
116
-
117
- # 5️⃣ Papers & mention edges -------------------------------------
118
- for idx, p in enumerate(papers):
119
  pid = f"paper_{idx}"
120
- _add_node(nodes, pid, f"P{idx+1}", C_PAPER, tooltip=p.get("title", ""), size=15)
121
-
122
- text_blob = f"{p.get('title','')} {p.get('summary','')}".lower()
123
-
124
- # concept links
125
- for c in umls:
126
- if c.get("name") and _match(text_blob, c["name"]):
127
- edges.append(Edge(source=pid, target=f"concept_{c['cui']}", label="mentions"))
128
- # gene links
129
- for g in genes:
130
- if g.get("symbol") and _match(text_blob, g["symbol"]):
131
- edges.append(Edge(source=pid, target=f"gene_{g['symbol']}", label="mentions"))
132
- # drug links
133
- for did, dname in drug_tuples:
134
- if _match(text_blob, dname):
135
- edges.append(Edge(source=pid, target=did, label="mentions"))
136
-
137
- # 6️⃣ Open Targets edges (drug–gene / gene–disease) --------------
138
- if ot_associations:
139
- for row in ot_associations:
140
- gsym = row.get("target", {}).get("symbol")
141
- dis = row.get("disease", {}).get("name")
142
- score = row.get("score", 0)
143
- if gsym and dis:
144
- gid = f"gene_{gsym}"
145
- did = f"disease_{dis}"
146
- _add_node(nodes, did, dis, C_CONCEPT, size=20)
147
- edges.append(Edge(source=gid, target=did, color=C_OT_EDGE, label=f"OT {score:.2f}"))
148
-
149
- # 7️⃣ Config ------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  cfg = Config(
151
- directed=False,
152
  width="100%",
153
- height="600",
 
154
  nodeHighlightBehavior=True,
155
  highlightColor="#f1c40f",
156
  collapsible=True,
157
- showLegend=False,
 
158
  node={"labelProperty": "label"},
 
159
  )
160
 
161
  return nodes, edges, cfg
 
1
+ # mcp/knowledge_graph.py
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  """
3
+ Build an interactive Network Graph (Streamlit-Agraph) that links together
4
+ the heterogeneous entities returned by the MedGenesis pipeline:
5
+
6
+ • Papers (arXiv / PubMed)
7
+ • UMLS concepts
8
+ • DrugSafety records (OpenFDA)
9
+ • Genes (optional – when available)
10
+
11
+ Design goals
12
+ ------------
13
+ 1. **Resilience.** Any malformed record (e.g. a RuntimeError placeholder from
14
+ an upstream failure) is skipped silently - the graph still renders.
15
+ 2. **Deduplication.** Identical nodes (same CUI, drug-name, paper ID…) are
16
+ added only once; edges are de-duplicated as well.
17
+ 3. **Visual semantics.**
18
+ • Papers → blue (#3498db)
19
+ • UMLS concepts → green (#2ecc71)
20
+ • Drugs / safety → orange (#e67e22)
21
+ • Genes → purple (#9b59b6)
22
+ 4. **Config tuned for large graphs.** Collapsible nodes, physics-based layout,
23
+ highlight on hover, full-width canvas.
24
+ """
25
+
26
  from __future__ import annotations
27
 
28
+ from typing import Any, List, Tuple, Dict, Iterable
29
  import re
30
+ from collections import defaultdict
31
 
32
  from streamlit_agraph import Node, Edge, Config
33
 
34
+ # --------------------------------------------------------------------------- #
35
+ # utility helpers #
36
+ # --------------------------------------------------------------------------- #
37
+ RGB = {
38
+ "paper": "#3498db",
39
+ "umls": "#2ecc71",
40
+ "drug": "#e67e22",
41
+ "gene": "#9b59b6",
42
+ }
43
+
44
+
45
+ def _safe_iter(obj: Any) -> Iterable:
46
+ """Yield from *obj* if it is list-like, else yield the obj itself."""
47
+ if obj is None:
48
+ return ()
49
+ if isinstance(obj, (list, tuple, set)):
50
+ return obj
51
+ return (obj,)
52
+
53
+
54
+ def _dedup(seq: Iterable[Tuple]) -> List[Tuple]:
55
+ """Remove duplicates while preserving order."""
56
+ seen = set()
57
+ out: List[Tuple] = []
58
+ for item in seq:
59
+ if item not in seen:
60
+ out.append(item)
61
+ seen.add(item)
62
+ return out
63
+
64
+
65
+ # --------------------------------------------------------------------------- #
66
+ # main builder #
67
+ # --------------------------------------------------------------------------- #
68
  def build_agraph(
69
+ papers: List[Dict[str, Any]],
70
+ umls: List[Dict[str, Any]],
71
+ drug_safety: List[Dict[str, Any]],
72
+ genes: List[Dict[str, Any]] | None = None,
73
+ ) -> Tuple[List[Node], List[Edge], Config]:
74
+ """
 
 
 
75
  Parameters
76
  ----------
77
+ papers : list[dict]
78
+ Output of `fetch_arxiv`/`fetch_pubmed` (must have *title* & *summary*).
79
+ umls : list[dict]
80
+ Items from `lookup_umls` (may be RuntimeError objects if failed).
81
+ drug_safety : list[dict]
82
+ Flattened OpenFDA adverse-event records.
83
+ genes : list[dict] | None
84
+ Optional – from gene resolver hub.
85
+
86
+ Returns
87
+ -------
88
+ nodes, edges, config
89
+ Objects suitable for `streamlit_agraph.agraph(...)`
90
  """
 
91
  nodes: List[Node] = []
92
  edges: List[Edge] = []
93
 
94
+ # ---- Papers ----------------------------------------------------------- #
95
+ for idx, p in enumerate(_safe_iter(papers)):
96
+ if not isinstance(p, dict):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  continue
 
 
 
 
 
 
 
98
  pid = f"paper_{idx}"
99
+ label = f"P{idx + 1}"
100
+ tooltip = p.get("title", "Paper")
101
+ nodes.append(Node(id=pid, label=label, tooltip=tooltip, size=15,
102
+ color=RGB["paper"]))
103
+
104
+ # ---- UMLS concepts ---------------------------------------------------- #
105
+ for c in _safe_iter(umls):
106
+ if not isinstance(c, dict):
107
+ continue
108
+ cui = c.get("cui")
109
+ name = c.get("name")
110
+ if cui and name:
111
+ nodes.append(Node(id=f"umls_{cui}", label=name, size=22,
112
+ color=RGB["umls"]))
113
+
114
+ # ---- Drug Safety ------------------------------------------------------ #
115
+ for i, rec in enumerate(_safe_iter(drug_safety)):
116
+ if not isinstance(rec, dict):
117
+ continue
118
+ dn = (
119
+ rec.get("drug_name")
120
+ or rec.get("medicinalproduct")
121
+ or "drug_{}".format(i)
122
+ )
123
+ nodes.append(Node(id=f"drug_{i}", label=dn, size=25,
124
+ color=RGB["drug"]))
125
+
126
+ # ---- Genes (optional) -------------------------------------------------- #
127
+ if genes:
128
+ for g in _safe_iter(genes):
129
+ if not isinstance(g, dict):
130
+ continue
131
+ sym = g.get("symbol") or g.get("approvedSymbol") or g.get("name")
132
+ if sym:
133
+ gid = f"gene_{sym}"
134
+ nodes.append(Node(id=gid, label=sym, size=20,
135
+ color=RGB["gene"]))
136
+
137
+ # ---------------------------------------------------------------------- #
138
+ # Edges – naïve co-occurrence linking #
139
+ # ---------------------------------------------------------------------- #
140
+ paper_texts = [
141
+ (n.id, f"{p.get('title','')} {p.get('summary','')}".lower())
142
+ for n, p in zip(nodes, papers)
143
+ if n.id.startswith("paper_")
144
+ ]
145
+
146
+ # connect paper ↔ umls / drugs / genes if mention appears in text
147
+ def _link(target_nodes: List[Node], pattern_getter):
148
+ for nid, blob in paper_texts:
149
+ for tn in target_nodes:
150
+ pat = pattern_getter(tn)
151
+ if pat and pat.search(blob):
152
+ edges.append(Edge(source=nid, target=tn.id, label="mentions"))
153
+
154
+ umls_nodes = [n for n in nodes if n.id.startswith("umls_")]
155
+ drug_nodes = [n for n in nodes if n.id.startswith("drug_")]
156
+ gene_nodes = [n for n in nodes if n.id.startswith("gene_")]
157
+
158
+ _link(umls_nodes, lambda n: re.compile(re.escape(n.label.lower())))
159
+ _link(drug_nodes, lambda n: re.compile(rf"\b{re.escape(n.label.lower())}\b"))
160
+ _link(gene_nodes, lambda n: re.compile(rf"\b{re.escape(n.label.lower())}\b"))
161
+
162
+ # de-duplicate everything ------------------------------------------------ #
163
+ nodes = _dedup(nodes)
164
+ edges = _dedup(edges)
165
+
166
+ # ---------------------------------------------------------------------- #
167
+ # Graph config #
168
+ # ---------------------------------------------------------------------- #
169
  cfg = Config(
 
170
  width="100%",
171
+ height="650px",
172
+ directed=False,
173
  nodeHighlightBehavior=True,
174
  highlightColor="#f1c40f",
175
  collapsible=True,
176
+ physics=True,
177
+ hierarchical=False,
178
  node={"labelProperty": "label"},
179
+ link={"labelProperty": "label", "renderLabel": False},
180
  )
181
 
182
  return nodes, edges, cfg