mgbam commited on
Commit
b305748
Β·
verified Β·
1 Parent(s): 96208dc

Update mcp/knowledge_graph.py

Browse files
Files changed (1) hide show
  1. mcp/knowledge_graph.py +74 -163
mcp/knowledge_graph.py CHANGED
@@ -1,182 +1,93 @@
1
- # mcp/knowledge_graph.py
2
  """
3
- Build an interactive Network Graph (Streamlit-Agraph) that links together
4
- the heterogeneous entities returned by the MedGenesis pipeline:
5
-
6
- β€’ Papers (arXiv / PubMed)
7
- β€’ UMLS concepts
8
- β€’ DrugSafety records (OpenFDA)
9
- β€’ Genes (optional – when available)
10
-
11
- Design goals
12
- ------------
13
- 1. **Resilience.** Any malformed record (e.g. a RuntimeError placeholder from
14
- an upstream failure) is skipped silently - the graph still renders.
15
- 2. **Deduplication.** Identical nodes (same CUI, drug-name, paper ID…) are
16
- added only once; edges are de-duplicated as well.
17
- 3. **Visual semantics.**
18
- β€’ Papers β†’ blue (#3498db)
19
- β€’ UMLS concepts β†’ green (#2ecc71)
20
- β€’ Drugs / safety β†’ orange (#e67e22)
21
- β€’ Genes β†’ purple (#9b59b6)
22
- 4. **Config tuned for large graphs.** Collapsible nodes, physics-based layout,
23
- highlight on hover, full-width canvas.
24
  """
25
 
26
  from __future__ import annotations
27
-
28
- from typing import Any, List, Tuple, Dict, Iterable
29
  import re
30
- from collections import defaultdict
31
-
32
  from streamlit_agraph import Node, Edge, Config
33
 
34
- # --------------------------------------------------------------------------- #
35
- # utility helpers #
36
- # --------------------------------------------------------------------------- #
37
- RGB = {
38
- "paper": "#3498db",
39
- "umls": "#2ecc71",
40
- "drug": "#e67e22",
41
- "gene": "#9b59b6",
42
- }
43
-
44
-
45
- def _safe_iter(obj: Any) -> Iterable:
46
- """Yield from *obj* if it is list-like, else yield the obj itself."""
47
- if obj is None:
48
- return ()
49
- if isinstance(obj, (list, tuple, set)):
50
- return obj
51
- return (obj,)
52
 
 
 
 
53
 
54
- def _dedup(seq: Iterable[Tuple]) -> List[Tuple]:
55
- """Remove duplicates while preserving order."""
56
- seen = set()
57
- out: List[Tuple] = []
58
- for item in seq:
59
- if item not in seen:
60
- out.append(item)
61
- seen.add(item)
62
- return out
63
-
64
-
65
- # --------------------------------------------------------------------------- #
66
- # main builder #
67
- # --------------------------------------------------------------------------- #
68
- def build_agraph(
69
- papers: List[Dict[str, Any]],
70
- umls: List[Dict[str, Any]],
71
- drug_safety: List[Dict[str, Any]],
72
- genes: List[Dict[str, Any]] | None = None,
73
- ) -> Tuple[List[Node], List[Edge], Config]:
74
  """
75
- Parameters
76
- ----------
77
- papers : list[dict]
78
- Output of `fetch_arxiv`/`fetch_pubmed` (must have *title* & *summary*).
79
- umls : list[dict]
80
- Items from `lookup_umls` (may be RuntimeError objects if failed).
81
- drug_safety : list[dict]
82
- Flattened OpenFDA adverse-event records.
83
- genes : list[dict] | None
84
- Optional – from gene resolver hub.
85
-
86
- Returns
87
- -------
88
- nodes, edges, config
89
- Objects suitable for `streamlit_agraph.agraph(...)`
90
  """
 
91
  nodes: List[Node] = []
92
  edges: List[Edge] = []
93
 
94
- # ---- Papers ----------------------------------------------------------- #
95
- for idx, p in enumerate(_safe_iter(papers)):
96
- if not isinstance(p, dict):
 
 
 
 
 
 
 
 
 
 
 
 
97
  continue
98
- pid = f"paper_{idx}"
99
- label = f"P{idx + 1}"
100
- tooltip = p.get("title", "Paper")
101
- nodes.append(Node(id=pid, label=label, tooltip=tooltip, size=15,
102
- color=RGB["paper"]))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
- # ---- UMLS concepts ---------------------------------------------------- #
105
- for c in _safe_iter(umls):
106
- if not isinstance(c, dict):
107
- continue
108
- cui = c.get("cui")
109
- name = c.get("name")
110
- if cui and name:
111
- nodes.append(Node(id=f"umls_{cui}", label=name, size=22,
112
- color=RGB["umls"]))
113
-
114
- # ---- Drug Safety ------------------------------------------------------ #
115
- for i, rec in enumerate(_safe_iter(drug_safety)):
116
- if not isinstance(rec, dict):
117
- continue
118
- dn = (
119
- rec.get("drug_name")
120
- or rec.get("medicinalproduct")
121
- or "drug_{}".format(i)
122
- )
123
- nodes.append(Node(id=f"drug_{i}", label=dn, size=25,
124
- color=RGB["drug"]))
125
-
126
- # ---- Genes (optional) -------------------------------------------------- #
127
- if genes:
128
- for g in _safe_iter(genes):
129
- if not isinstance(g, dict):
130
- continue
131
- sym = g.get("symbol") or g.get("approvedSymbol") or g.get("name")
132
- if sym:
133
- gid = f"gene_{sym}"
134
- nodes.append(Node(id=gid, label=sym, size=20,
135
- color=RGB["gene"]))
136
-
137
- # ---------------------------------------------------------------------- #
138
- # Edges – naΓ―ve co-occurrence linking #
139
- # ---------------------------------------------------------------------- #
140
- paper_texts = [
141
- (n.id, f"{p.get('title','')} {p.get('summary','')}".lower())
142
- for n, p in zip(nodes, papers)
143
- if n.id.startswith("paper_")
144
- ]
145
-
146
- # connect paper ↔ umls / drugs / genes if mention appears in text
147
- def _link(target_nodes: List[Node], pattern_getter):
148
- for nid, blob in paper_texts:
149
- for tn in target_nodes:
150
- pat = pattern_getter(tn)
151
- if pat and pat.search(blob):
152
- edges.append(Edge(source=nid, target=tn.id, label="mentions"))
153
-
154
- umls_nodes = [n for n in nodes if n.id.startswith("umls_")]
155
- drug_nodes = [n for n in nodes if n.id.startswith("drug_")]
156
- gene_nodes = [n for n in nodes if n.id.startswith("gene_")]
157
-
158
- _link(umls_nodes, lambda n: re.compile(re.escape(n.label.lower())))
159
- _link(drug_nodes, lambda n: re.compile(rf"\b{re.escape(n.label.lower())}\b"))
160
- _link(gene_nodes, lambda n: re.compile(rf"\b{re.escape(n.label.lower())}\b"))
161
-
162
- # de-duplicate everything ------------------------------------------------ #
163
- nodes = _dedup(nodes)
164
- edges = _dedup(edges)
165
-
166
- # ---------------------------------------------------------------------- #
167
- # Graph config #
168
- # ---------------------------------------------------------------------- #
169
  cfg = Config(
170
- width="100%",
171
- height="650px",
172
- directed=False,
173
- nodeHighlightBehavior=True,
174
- highlightColor="#f1c40f",
175
- collapsible=True,
176
- physics=True,
177
- hierarchical=False,
178
- node={"labelProperty": "label"},
179
- link={"labelProperty": "label", "renderLabel": False},
180
  )
181
-
182
  return nodes, edges, cfg
 
 
1
  """
2
+ Graph builder for Streamlit-agraph.
3
+ β€’ Safely skips bad nodes (RuntimeError / None)
4
+ β€’ Colour-codes papers = #3498db (blue)
5
+ drugs = #e67e22 (orange)
6
+ concepts= #2ecc71 (green)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  """
8
 
9
  from __future__ import annotations
 
 
10
  import re
11
+ from typing import List, Dict, Tuple
 
12
  from streamlit_agraph import Node, Edge, Config
13
 
14
+ # ── colour palette ----------------------------------------------------
15
+ CLR_PAPER = "#3498db"
16
+ CLR_DRUG = "#e67e22"
17
+ CLR_CONCEPT = "#2ecc71"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
+ def _add_node_safe(nodes: List[Node], nid: str, **kwargs) -> None:
20
+ if not any(n.id == nid for n in nodes):
21
+ nodes.append(Node(id=nid, **kwargs))
22
 
23
+ def build_agraph(papers: List[Dict],
24
+ umls : List[Dict],
25
+ safety: List[Dict]) -> Tuple[List[Node], List[Edge], Config]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  """
27
+ Map MedGenesis payload β†’ (nodes, edges, config) for Streamlit-agraph.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  """
29
+
30
  nodes: List[Node] = []
31
  edges: List[Edge] = []
32
 
33
+ # 1 UMLS CONCEPTS ---------------------------------------------------
34
+ for concept in umls:
35
+ try:
36
+ cui = concept.get("cui")
37
+ name = concept.get("name", "")
38
+ if cui and name:
39
+ cid = f"CUI_{cui}"
40
+ _add_node_safe(nodes, cid, label=name, size=25, color=CLR_CONCEPT)
41
+ except Exception:
42
+ continue # malformed concept
43
+
44
+ # 2 DRUGS (OpenFDA) -------------------------------------------------
45
+ drug_label_pairs: List[Tuple[str, str]] = []
46
+ for i, rec in enumerate(safety):
47
+ if not rec: # may be {} if rate-limited
48
  continue
49
+ drug_name = (rec.get("drug_name") or
50
+ rec.get("patient", {}).get("drug") or
51
+ rec.get("medicinalproduct") or
52
+ f"unknown_{i}")
53
+ did = f"DRUG_{i}"
54
+ drug_label_pairs.append((did, drug_name))
55
+ _add_node_safe(nodes, did, label=drug_name, size=22, color=CLR_DRUG)
56
+
57
+ # 3 PAPERS ----------------------------------------------------------
58
+ for idx, p in enumerate(papers):
59
+ pid = f"PAPER_{idx}"
60
+ _add_node_safe(nodes, pid,
61
+ label=f"P{idx+1}",
62
+ tooltip=p.get("title", ""),
63
+ size=15,
64
+ color=CLR_PAPER)
65
+
66
+ plain = (p.get("title", "") + " " + p.get("summary", "")).lower()
67
+
68
+ # (i) concept edges
69
+ for concept in umls:
70
+ name = concept.get("name", "")
71
+ cui = concept.get("cui")
72
+ if name and cui and name.lower() in plain:
73
+ edges.append(Edge(source=pid,
74
+ target=f"CUI_{cui}",
75
+ label="mentions"))
76
+
77
+ # (ii) drug edges
78
+ for did, dname in drug_label_pairs:
79
+ if dname.lower() in plain:
80
+ edges.append(Edge(source=pid,
81
+ target=did,
82
+ label="mentions"))
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  cfg = Config(
85
+ width = "100%",
86
+ height = 600,
87
+ directed = False,
88
+ nodeHighlightBehavior = True,
89
+ highlightColor = "#f1c40f",
90
+ collapsible = True,
91
+ node = {"labelProperty": "label"}
 
 
 
92
  )
 
93
  return nodes, edges, cfg