mgbam committed on
Commit
590f907
·
verified ·
1 Parent(s): 72a7fea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +172 -130
app.py CHANGED
@@ -1,8 +1,8 @@
1
  #!/usr/bin/env python3
2
- # MedGenesis AI · CPU-only Streamlit app (OpenAI / Gemini)
3
 
4
- import os, pathlib
5
- import asyncio, re
6
  from pathlib import Path
7
 
8
  import streamlit as st
@@ -11,25 +11,29 @@ import plotly.express as px
11
  from fpdf import FPDF
12
  from streamlit_agraph import agraph
13
 
14
- from mcp.orchestrator import orchestrate_search, answer_ai_question
15
- from mcp.workspace import get_workspace, save_query
16
  from mcp.knowledge_graph import build_agraph
17
- from mcp.graph_metrics import build_nx, get_top_hubs, get_density
18
- from mcp.alerts import check_alerts
19
 
20
- # ── Streamlit telemetry dir fix ───────────────────────────────────────
21
- os.environ["STREAMLIT_DATA_DIR"] = "/tmp/.streamlit"
22
- os.environ["XDG_STATE_HOME"] = "/tmp"
23
  os.environ["STREAMLIT_BROWSER_GATHERUSAGESTATS"] = "false"
24
  pathlib.Path("/tmp/.streamlit").mkdir(parents=True, exist_ok=True)
25
 
26
  ROOT = Path(__file__).parent
27
  LOGO = ROOT / "assets" / "logo.png"
28
 
 
 
 
29
  def _latin1_safe(txt: str) -> str:
30
  return txt.encode("latin-1", "replace").decode("latin-1")
31
 
32
- def _pdf(papers):
 
33
  pdf = FPDF()
34
  pdf.set_auto_page_break(auto=True, margin=15)
35
  pdf.add_page()
@@ -45,34 +49,41 @@ def _pdf(papers):
45
  pdf.ln(1)
46
  return pdf.output(dest="S").encode("latin-1", "replace")
47
 
48
- def _workspace_sidebar():
 
49
  with st.sidebar:
50
  st.header("🗂️ Workspace")
51
  ws = get_workspace()
52
  if not ws:
53
  st.info("Run a search then press **Save** to populate this list.")
54
  return
 
 
 
55
  for i, item in enumerate(ws, 1):
56
  with st.expander(f"{i}. {item['query']}"):
57
  st.write(item["result"]["ai_summary"])
58
 
59
- def render_ui():
 
 
 
 
60
  st.set_page_config("MedGenesis AI", layout="wide")
61
 
62
- # Initialize session state keys if missing
63
- if "query_result" not in st.session_state:
64
- st.session_state.query_result = None
65
- if "followup_input" not in st.session_state:
66
- st.session_state.followup_input = ""
67
- if "followup_response" not in st.session_state:
68
- st.session_state.followup_response = None
69
- if "last_query" not in st.session_state:
70
- st.session_state.last_query = ""
71
- if "last_llm" not in st.session_state:
72
- st.session_state.last_llm = ""
73
 
74
  _workspace_sidebar()
75
 
 
76
  c1, c2 = st.columns([0.15, 0.85])
77
  with c1:
78
  if LOGO.exists():
@@ -81,13 +92,14 @@ def render_ui():
81
  st.markdown("## 🧬 **MedGenesis AI**")
82
  st.caption("Multi-source biomedical assistant · OpenAI / Gemini")
83
 
84
- llm = st.radio("LLM engine", ["openai", "gemini"], horizontal=True)
85
- query = st.text_input("Enter biomedical question", placeholder="e.g. CRISPR glioblastoma therapy")
 
86
 
87
- # Alerts
88
- if get_workspace():
89
  try:
90
- news = asyncio.run(check_alerts([w["query"] for w in get_workspace()]))
91
  if news:
92
  with st.sidebar:
93
  st.subheader("🔔 New papers")
@@ -96,113 +108,143 @@ def render_ui():
96
  except Exception:
97
  pass
98
 
99
- # Run search button
100
- if st.button("Run Search 🚀") and query:
101
  with st.spinner("Collecting literature & biomedical data …"):
102
  res = asyncio.run(orchestrate_search(query, llm=llm))
103
  st.success(f"Completed with **{res['llm_used'].title()}**")
104
- st.session_state.query_result = res
105
- st.session_state.last_query = query
106
- st.session_state.last_llm = llm
107
- st.session_state.followup_input = ""
108
- st.session_state.followup_response = None
 
 
109
 
110
  res = st.session_state.query_result
111
-
112
- if res:
113
- tabs = st.tabs(["Results", "Genes", "Trials", "Graph", "Metrics", "Visuals"])
114
-
115
- with tabs[0]:
116
- for i, p in enumerate(res["papers"], 1):
117
- st.markdown(f"**{i}. [{p['title']}]({p['link']})** *{p['authors']}*")
118
- st.write(p["summary"])
119
- col1, col2 = st.columns(2)
120
- with col1:
121
- st.download_button("CSV", pd.DataFrame(res["papers"]).to_csv(index=False),
122
- "papers.csv", "text/csv")
123
- with col2:
124
- st.download_button("PDF", _pdf(res["papers"]), "papers.pdf", "application/pdf")
125
- if st.button("💾 Save"):
126
- save_query(st.session_state.last_query, res)
127
- st.success("Saved to workspace")
128
-
129
- st.subheader("UMLS concepts")
130
- for c in res["umls"]:
131
- if c.get("cui"):
132
- st.write(f"- **{c['name']}** ({c['cui']})")
133
-
134
- st.subheader("OpenFDA safety")
135
- for d in res["drug_safety"]:
136
- st.json(d)
137
-
138
- st.subheader("AI summary")
139
- st.info(res["ai_summary"])
140
-
141
- with tabs[1]:
142
- st.header("Gene / Variant signals")
143
- for g in res["genes"]:
144
- st.write(f"- **{g.get('name', g.get('geneid'))}** {g.get('description','')}")
145
- if res["gene_disease"]:
146
- st.markdown("### DisGeNET links")
147
- st.json(res["gene_disease"][:15])
148
- if res["mesh_defs"]:
149
- st.markdown("### MeSH definitions")
150
- for d in res["mesh_defs"]:
151
- if d:
152
- st.write("-", d)
153
-
154
- with tabs[2]:
155
- st.header("Clinical trials")
156
- if not res["clinical_trials"]:
157
- st.info("No trials (rate-limited or none found).")
158
- for t in res["clinical_trials"]:
159
- st.markdown(f"**{t['NCTId'][0]}** – {t['BriefTitle'][0]}")
160
- st.write(f"Phase {t.get('Phase',[''])[0]} | Status {t['OverallStatus'][0]}")
161
-
162
- with tabs[3]:
163
- nodes, edges, cfg = build_agraph(res["papers"], res["umls"], res["drug_safety"])
164
- hl = st.text_input("Highlight node:", key="hl")
165
- if hl:
166
- pat = re.compile(re.escape(hl), re.I)
167
- for n in nodes:
168
- n.color = "#f1c40f" if pat.search(n.label) else "#d3d3d3"
169
- agraph(nodes, edges, cfg)
170
-
171
- with tabs[4]:
172
- nodes, edges, _ = build_agraph(res["papers"], res["umls"], res["drug_safety"])
173
- G = build_nx([n.__dict__ for n in nodes], [e.__dict__ for e in edges])
174
- st.metric("Density", f"{get_density(G):.3f}")
175
- st.markdown("**Top hubs**")
176
- for nid, sc in get_top_hubs(G):
177
- lab = next((n.label for n in nodes if n.id == nid), nid)
178
- st.write(f"- {lab} {sc:.3f}")
179
-
180
- with tabs[5]:
181
- years = [p["published"] for p in res["papers"] if p.get("published")]
182
- if years:
183
- st.plotly_chart(px.histogram(years, nbins=12, title="Publication Year"))
184
-
185
- # Follow-up Q&A block with callback
186
- st.markdown("---")
187
- st.text_input("Ask follow‑up question:", key="followup_input")
188
- def handle_followup():
189
- follow = st.session_state.followup_input
190
- if follow.strip():
191
- ans = asyncio.run(answer_ai_question(
192
- follow,
193
- context=st.session_state.last_query,
194
- llm=st.session_state.last_llm))
195
- st.session_state.followup_response = ans["answer"]
196
- else:
197
- st.session_state.followup_response = None
198
-
199
- st.button("Ask AI", on_click=handle_followup)
200
-
201
- if st.session_state.followup_response:
202
- st.write(st.session_state.followup_response)
203
-
204
- else:
205
  st.info("Enter a question and press **Run Search 🚀**")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
 
207
  if __name__ == "__main__":
208
  render_ui()
 
1
  #!/usr/bin/env python3
2
+ # MedGenesis AI · CPU-only Streamlit front-end (OpenAI / Gemini)
3
 
4
+ from __future__ import annotations
5
+ import os, pathlib, asyncio, re
6
  from pathlib import Path
7
 
8
  import streamlit as st
 
11
  from fpdf import FPDF
12
  from streamlit_agraph import agraph
13
 
14
+ from mcp.orchestrator import orchestrate_search, answer_ai_question
15
+ from mcp.workspace import get_workspace, save_query, clear_workspace
16
  from mcp.knowledge_graph import build_agraph
17
+ from mcp.graph_utils import build_nx, get_top_hubs, get_density
18
+ from mcp.alerts import check_alerts
19
 
20
+ # ── Streamlit telemetry dir fix ──────────────────────────────────────
21
+ os.environ["STREAMLIT_DATA_DIR"] = "/tmp/.streamlit"
22
+ os.environ["XDG_STATE_HOME"] = "/tmp"
23
  os.environ["STREAMLIT_BROWSER_GATHERUSAGESTATS"] = "false"
24
  pathlib.Path("/tmp/.streamlit").mkdir(parents=True, exist_ok=True)
25
 
26
  ROOT = Path(__file__).parent
27
  LOGO = ROOT / "assets" / "logo.png"
28
 
29
+ # -------------------------------------------------------------------#
30
+ # Utility helpers #
31
+ # -------------------------------------------------------------------#
32
  def _latin1_safe(txt: str) -> str:
33
  return txt.encode("latin-1", "replace").decode("latin-1")
34
 
35
+
36
+ def _pdf(papers: list[dict]) -> bytes:
37
  pdf = FPDF()
38
  pdf.set_auto_page_break(auto=True, margin=15)
39
  pdf.add_page()
 
49
  pdf.ln(1)
50
  return pdf.output(dest="S").encode("latin-1", "replace")
51
 
52
+
53
+ def _workspace_sidebar() -> None:
54
  with st.sidebar:
55
  st.header("🗂️ Workspace")
56
  ws = get_workspace()
57
  if not ws:
58
  st.info("Run a search then press **Save** to populate this list.")
59
  return
60
+ if st.button("Clear workspace 🗑️"):
61
+ clear_workspace()
62
+ st.experimental_rerun()
63
  for i, item in enumerate(ws, 1):
64
  with st.expander(f"{i}. {item['query']}"):
65
  st.write(item["result"]["ai_summary"])
66
 
67
+
68
+ # -------------------------------------------------------------------#
69
+ # Streamlit main UI #
70
+ # -------------------------------------------------------------------#
71
+ def render_ui() -> None:
72
  st.set_page_config("MedGenesis AI", layout="wide")
73
 
74
+ # ── session_state bootstrap ────────────────────────────────────
75
+ for key, default in {
76
+ "query_result" : None,
77
+ "followup_input" : "",
78
+ "followup_response" : None,
79
+ "last_query" : "",
80
+ "last_llm" : "",
81
+ }.items():
82
+ st.session_state.setdefault(key, default)
 
 
83
 
84
  _workspace_sidebar()
85
 
86
+ # ── header ─────────────────────────────────────────────────────
87
  c1, c2 = st.columns([0.15, 0.85])
88
  with c1:
89
  if LOGO.exists():
 
92
  st.markdown("## 🧬 **MedGenesis AI**")
93
  st.caption("Multi-source biomedical assistant · OpenAI / Gemini")
94
 
95
+ llm = st.radio("LLM engine", ["openai", "gemini"], horizontal=True)
96
+ query = st.text_input("Enter biomedical question",
97
+ placeholder="e.g. CRISPR glioblastoma therapy")
98
 
99
+ # ── alerts for saved queries ───────────────────────────────────
100
+ if ws := get_workspace():
101
  try:
102
+ news = asyncio.run(check_alerts([w["query"] for w in ws]))
103
  if news:
104
  with st.sidebar:
105
  st.subheader("🔔 New papers")
 
108
  except Exception:
109
  pass
110
 
111
+ # ── primary search trigger ─────────────────────────────────────
112
+ if st.button("Run Search 🚀") and query.strip():
113
  with st.spinner("Collecting literature & biomedical data …"):
114
  res = asyncio.run(orchestrate_search(query, llm=llm))
115
  st.success(f"Completed with **{res['llm_used'].title()}**")
116
+ st.session_state.update({
117
+ "query_result" : res,
118
+ "last_query" : query,
119
+ "last_llm" : llm,
120
+ "followup_input" : "",
121
+ "followup_response" : None,
122
+ })
123
 
124
  res = st.session_state.query_result
125
+ if not res:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  st.info("Enter a question and press **Run Search 🚀**")
127
+ return
128
+
129
+ # ----------------------------------------------------------------#
130
+ # Tabs #
131
+ # ----------------------------------------------------------------#
132
+ tabs = st.tabs(["Results", "Genes", "Trials", "Graph", "Metrics", "Visuals"])
133
+
134
+ # Results ---------------------------------------------------------
135
+ with tabs[0]:
136
+ for i, p in enumerate(res["papers"], 1):
137
+ st.markdown(f"**{i}. [{p['title']}]({p['link']})** *{p['authors']}*")
138
+ st.write(p["summary"])
139
+
140
+ c1, c2 = st.columns(2)
141
+ with c1:
142
+ st.download_button("CSV",
143
+ pd.DataFrame(res["papers"]).to_csv(index=False),
144
+ "papers.csv", "text/csv")
145
+ with c2:
146
+ st.download_button("PDF", _pdf(res["papers"]),
147
+ "papers.pdf", "application/pdf")
148
+
149
+ if st.button("💾 Save this result"):
150
+ save_query(st.session_state.last_query, res)
151
+ st.success("Saved to workspace")
152
+
153
+ st.subheader("UMLS concepts")
154
+ for c in res["umls"]:
155
+ if c.get("cui"):
156
+ st.write(f"- **{c['name']}** ({c['cui']})")
157
+
158
+ st.subheader("OpenFDA safety")
159
+ for d in res["drug_safety"]:
160
+ st.json(d)
161
+
162
+ st.subheader("AI summary")
163
+ st.info(res["ai_summary"])
164
+
165
+ # Genes -----------------------------------------------------------
166
+ with tabs[1]:
167
+ st.header("Gene / Variant signals")
168
+ if not res["genes"]:
169
+ st.info("No gene hits (rate-limited or none found).")
170
+ for g in res["genes"]:
171
+ st.write(f"- **{g.get('symbol', g.get('name', ''))}** "
172
+ f"{g.get('summary', '')[:120]}…")
173
+
174
+ if res["gene_disease"]:
175
+ st.markdown("### DisGeNET links")
176
+ st.json(res["gene_disease"][:15])
177
+
178
+ if res["mesh_defs"]:
179
+ st.markdown("### MeSH definitions")
180
+ for d in res["mesh_defs"]:
181
+ if d:
182
+ st.write("-", d)
183
+
184
+ # Trials ----------------------------------------------------------
185
+ with tabs[2]:
186
+ st.header("Clinical trials")
187
+ trials = res["clinical_trials"]
188
+ if not trials:
189
+ st.info("No trials (rate-limited or none found).")
190
+ for t in trials:
191
+ st.markdown(f"**{t['nctId']}** – {t['briefTitle']}")
192
+ st.write(f"Phase {t.get('phase','')} | Status {t.get('status')}")
193
+
194
+ # Graph -----------------------------------------------------------
195
+ with tabs[3]:
196
+ nodes, edges, cfg = build_agraph(
197
+ res["papers"], res["umls"], res["drug_safety"],
198
+ res["genes"], res["clinical_trials"], res.get("ot_associations", [])
199
+ )
200
+ hl = st.text_input("Highlight node:")
201
+ if hl:
202
+ pat = re.compile(re.escape(hl), re.I)
203
+ for n in nodes:
204
+ n.color = "#f1c40f" if pat.search(n.label) else "#d3d3d3"
205
+ agraph(nodes, edges, cfg)
206
+
207
+ # Metrics ---------------------------------------------------------
208
+ with tabs[4]:
209
+ G = build_nx([n.__dict__ for n in nodes], [e.__dict__ for e in edges])
210
+ st.metric("Density", f"{get_density(G):.3f}")
211
+ st.markdown("**Top hubs**")
212
+ for nid, sc in get_top_hubs(G):
213
+ lab = next((n.label for n in nodes if n.id == nid), nid)
214
+ st.write(f"- {lab} {sc:.3f}")
215
+
216
+ # Visuals ---------------------------------------------------------
217
+ with tabs[5]:
218
+ years = [p["published"] for p in res["papers"] if p.get("published")]
219
+ if years:
220
+ st.plotly_chart(px.histogram(years, nbins=12,
221
+ title="Publication Year"))
222
+
223
+ # ----------------------------------------------------------------#
224
+ # Follow-up Q & A #
225
+ # ----------------------------------------------------------------#
226
+ st.markdown("---")
227
+ st.text_input("Ask follow-up question:",
228
+ key="followup_input",
229
+ placeholder="e.g. Any phase III trials recruiting now?")
230
+
231
+ def _on_ask() -> None:
232
+ q = st.session_state.followup_input.strip()
233
+ if not q:
234
+ st.warning("Please type a question first.")
235
+ return
236
+ with st.spinner("Querying LLM …"):
237
+ ans = asyncio.run(
238
+ answer_ai_question(q,
239
+ context=st.session_state.last_query,
240
+ llm=st.session_state.last_llm))
241
+ st.session_state.followup_response = ans["answer"]
242
+
243
+ st.button("Ask AI", on_click=_on_ask)
244
+ if st.session_state.followup_response:
245
+ st.write(st.session_state.followup_response)
246
+
247
 
248
+ # -------------------------------------------------------------------#
249
  if __name__ == "__main__":
250
  render_ui()