mgbam commited on
Commit
1e92bb0
·
verified ·
1 Parent(s): 57eeff5

Update mcp/knowledge_graph.py

Browse files
Files changed (1) hide show
  1. mcp/knowledge_graph.py +32 -77
mcp/knowledge_graph.py CHANGED
@@ -1,82 +1,37 @@
1
  # mcp/knowledge_graph.py
2
-
3
  from streamlit_agraph import Node, Edge, Config
4
- import re
5
-
6
- # Set colors for node types
7
- PAPER_COLOR = "#0984e3"
8
- UMLS_COLOR = "#00b894"
9
- DRUG_COLOR = "#d35400"
10
 
11
- def build_agraph(papers, umls, drug_safety):
12
- """
13
- Build interactive agraph nodes and edges.
14
- Defensive: handles unexpected types gracefully.
15
- """
16
  nodes, edges = [], []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- # UMLS nodes
19
- for c in umls or []:
20
- if not isinstance(c, dict):
21
- continue
22
- cui = str(c.get("cui", "") or "")
23
- name = str(c.get("name", "") or "")
24
- if cui and name:
25
- nid = f"concept_{cui}"
26
- nodes.append(Node(
27
- id=nid, label=name, size=25, color=UMLS_COLOR,
28
- tooltip=f"UMLS {cui}: {name}"
29
- ))
30
-
31
- # Drug nodes
32
- drug_names = []
33
- for i, dr in enumerate(drug_safety or []):
34
- if not dr:
35
- continue
36
- # Normalize to single dict
37
- recs = dr if isinstance(dr, list) else [dr]
38
- for j, rec in enumerate(recs):
39
- if not isinstance(rec, dict):
40
- continue
41
- dn = rec.get("drug_name") \
42
- or (rec.get("patient", {}) or {}).get("drug", "") \
43
- or rec.get("medicinalproduct", "")
44
- dn = str(dn or f"drug_{i}_{j}")
45
- did = f"drug_{i}_{j}"
46
- drug_names.append((did, dn))
47
- nodes.append(Node(id=did, label=dn, size=25, color=DRUG_COLOR,
48
- tooltip=f"Drug: {dn}"))
49
-
50
- # Paper nodes and edges
51
- for k, p in enumerate(papers or []):
52
- pid = f"paper_{k}"
53
- title = str(p.get("title", f"Paper {k+1}"))
54
- summary = str(p.get("summary", ""))
55
- label = f"P{k+1}"
56
- nodes.append(Node(
57
- id=pid,
58
- label=label,
59
- tooltip=title,
60
- size=14,
61
- color=PAPER_COLOR,
62
- ))
63
- txt = (title + " " + summary).lower()
64
- # Link to concepts
65
- for c in umls or []:
66
- name = str(c.get("name", "") or "")
67
- cui = str(c.get("cui", "") or "")
68
- if name and name.lower() in txt and cui:
69
- edges.append(Edge(source=pid, target=f"concept_{cui}", label="mentions"))
70
- # Link to drugs
71
- for did, dn in drug_names:
72
- if dn and dn.lower() in txt:
73
- edges.append(Edge(source=pid, target=did, label="mentions"))
74
-
75
- config = Config(
76
- width="100%", height="600", directed=False,
77
- nodeHighlightBehavior=True, highlightColor="#f1c40f",
78
- collapsible=True,
79
- node={"labelProperty": "label"},
80
- link={"labelProperty": "label"},
81
- )
82
- return nodes, edges, config
 
1
  # mcp/knowledge_graph.py
 
2
  from streamlit_agraph import Node, Edge, Config
 
 
 
 
 
 
3
 
4
+ def build_agraph(res: Dict) -> (list, list, Config):
 
 
 
 
5
  nodes, edges = [], []
6
+ # add each paper as a node
7
+ for i,p in enumerate(res["papers"]):
8
+ nid = f"paper_{i}"
9
+ nodes.append(Node(id=nid, label=p["title"], size=20, color="#0984e3"))
10
+ # connect to AI summary?
11
+ # add UMLS concepts
12
+ for u in res["umls"]:
13
+ cid = f"cui_{u['cui']}"
14
+ label = f"{u['name']} ({u['cui']})"
15
+ nodes.append(Node(id=cid, label=label, size=25, color="#00b894"))
16
+ # connect concept → first paper
17
+ edges.append(Edge(source=cid, target="paper_0", label="mentioned_in"))
18
+ # genes
19
+ g = res.get("gene",{})
20
+ if g:
21
+ gid = "gene_node"
22
+ nodes.append(Node(id=gid, label=g.get("symbol",g.get("name","gene")), color="#d63031"))
23
+ edges.append(Edge(source=gid, target="cui_"+res["umls"][0]["cui"], label="related"))
24
+ # variants
25
+ for v in res["variants"]:
26
+ vid = f"var_{v['mutationId']}"
27
+ nodes.append(Node(id=vid, label=v["mutationId"], color="#fdcb6e", size=15))
28
+ edges.append(Edge(source=vid, target=gid, label="affects"))
29
+ # trials
30
+ for t in res["trials"]:
31
+ tid = t["NCTId"][0]
32
+ nodes.append(Node(id=tid, label=tid, color="#6c5ce7"))
33
+ edges.append(Edge(source=tid, target=gid, label="studies"))
34
 
35
+ cfg = Config(width="100%", height="600", directed=True,
36
+ nodeHighlightBehavior=True, highlightColor="#fdcb6e")
37
+ return nodes, edges, cfg