mgbam commited on
Commit
140bf8b
·
verified ·
1 Parent(s): 0bd4f6b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -105
app.py CHANGED
@@ -1,6 +1,21 @@
1
- # app.py - MedGenesis AI Streamlit app (OpenAI/Gemini)
2
-
3
- import os, pathlib, asyncio, re
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from pathlib import Path
5
 
6
  import streamlit as st
@@ -9,25 +24,38 @@ import plotly.express as px
9
  from fpdf import FPDF
10
  from streamlit_agraph import agraph
11
 
 
 
 
12
  from mcp.orchestrator import orchestrate_search, answer_ai_question
13
  from mcp.workspace import get_workspace, save_query
14
  from mcp.knowledge_graph import build_agraph
15
  from mcp.graph_metrics import build_nx, get_top_hubs, get_density
16
  from mcp.alerts import check_alerts
17
 
18
- # --- Fix Streamlit temp dir ---
19
- os.environ["STREAMLIT_DATA_DIR"] = "/tmp/.streamlit"
20
- os.environ["XDG_STATE_HOME"] = "/tmp"
21
- os.environ["STREAMLIT_BROWSER_GATHERUSAGESTATS"] = "false"
 
 
 
 
22
  pathlib.Path("/tmp/.streamlit").mkdir(parents=True, exist_ok=True)
23
 
24
  ROOT = Path(__file__).parent
25
  LOGO = ROOT / "assets" / "logo.png"
26
 
 
 
 
 
27
  def _latin1_safe(txt: str) -> str:
 
28
  return txt.encode("latin-1", "replace").decode("latin-1")
29
 
30
- def _pdf(papers):
 
31
  pdf = FPDF()
32
  pdf.set_auto_page_break(auto=True, margin=15)
33
  pdf.add_page()
@@ -36,13 +64,16 @@ def _pdf(papers):
36
  pdf.ln(3)
37
  for i, p in enumerate(papers, 1):
38
  pdf.set_font("Helvetica", "B", 11)
39
- pdf.multi_cell(0, 7, _latin1_safe(f"{i}. {p.get('title', '')}"))
40
  pdf.set_font("Helvetica", "", 9)
41
- body = f"{p.get('authors','')}\n{p.get('summary','')}\n{p.get('link','')}\n"
 
 
42
  pdf.multi_cell(0, 6, _latin1_safe(body))
43
  pdf.ln(1)
44
  return pdf.output(dest="S").encode("latin-1", "replace")
45
 
 
46
  def _workspace_sidebar():
47
  with st.sidebar:
48
  st.header("🗂️ Workspace")
@@ -52,141 +83,140 @@ def _workspace_sidebar():
52
  return
53
  for i, item in enumerate(ws, 1):
54
  with st.expander(f"{i}. {item['query']}"):
55
- st.write(item["result"].get("ai_summary", ""))
 
 
 
 
56
 
57
  def render_ui():
58
  st.set_page_config("MedGenesis AI", layout="wide")
59
 
60
- # Session state
61
- for k, v in [
62
- ("query_result", None), ("followup_input", ""),
63
- ("followup_response", None), ("last_query", ""), ("last_llm", "")
64
- ]:
65
- if k not in st.session_state:
66
- st.session_state[k] = v
 
 
 
67
 
68
  _workspace_sidebar()
69
- c1, c2 = st.columns([0.15, 0.85])
70
- with c1:
 
 
71
  if LOGO.exists():
72
  st.image(str(LOGO), width=105)
73
- with c2:
74
  st.markdown("## 🧬 **MedGenesis AI**")
75
- st.caption("Multi-source biomedical assistant · OpenAI / Gemini")
76
 
77
- llm = st.radio("LLM engine", ["openai", "gemini"], horizontal=True)
78
- query = st.text_input("Enter biomedical question", placeholder="e.g. CRISPR glioblastoma therapy")
 
79
 
80
  # Alerts
81
- wsq = get_workspace()
82
- if wsq:
83
  try:
84
- news = asyncio.run(check_alerts([w["query"] for w in wsq]))
85
- if news:
86
  with st.sidebar:
87
  st.subheader("🔔 New papers")
88
- for q, lnks in news.items():
89
  st.write(f"**{q}** – {len(lnks)} new")
90
  except Exception:
91
  pass
92
 
 
93
  if st.button("Run Search 🚀") and query:
94
  with st.spinner("Collecting literature & biomedical data …"):
95
- res = asyncio.run(orchestrate_search(query, llm=llm))
96
- st.success(f"Completed with **{res.get('llm_used','LLM').title()}**")
97
- st.session_state.query_result = res
98
- st.session_state.last_query = query
99
- st.session_state.last_llm = llm
100
- st.session_state.followup_input = ""
101
- st.session_state.followup_response = None
 
 
102
 
103
  res = st.session_state.query_result
104
  if not res:
105
  st.info("Enter a question and press **Run Search 🚀**")
106
  return
107
 
 
108
  tabs = st.tabs(["Results", "Genes", "Trials", "Variants", "Graph", "Metrics", "Visuals"])
109
- # --------------- Results Tab ---------------
 
110
  with tabs[0]:
111
- for i, p in enumerate(res.get("papers", []), 1):
 
112
  st.markdown(f"**{i}. [{p.get('title','')}]({p.get('link','')})** *{p.get('authors','')}*")
113
- st.write(p.get("summary", ""))
114
- col1, col2 = st.columns(2)
115
- with col1:
116
- st.download_button("CSV", pd.DataFrame(res.get("papers", [])).to_csv(index=False),
117
- "papers.csv", "text/csv")
118
- with col2:
119
- st.download_button("PDF", _pdf(res.get("papers", [])), "papers.pdf", "application/pdf")
120
  if st.button("💾 Save"):
121
  save_query(st.session_state.last_query, res)
122
  st.success("Saved to workspace")
 
123
  st.subheader("UMLS concepts")
124
- for c in res.get("umls", []):
125
- if isinstance(c, dict) and c.get("cui"):
126
  st.write(f"- **{c.get('name','')}** ({c.get('cui')})")
 
127
  st.subheader("OpenFDA safety signals")
128
- st.json(res.get("drug_safety", []))
 
 
129
  st.subheader("AI summary")
130
- st.info(res.get("ai_summary", ""))
131
 
132
- # --------------- Genes Tab ---------------
133
  with tabs[1]:
134
  st.header("Gene / Variant signals")
135
- genes = res.get("genes", [])
136
- if not genes:
137
- st.info("No gene hits (rate-limited or none found).")
138
- else:
139
- for g in genes:
140
- if isinstance(g, dict):
141
- lab = g.get("name") or g.get("symbol") or g.get("geneid")
142
- st.write(f"- **{lab}** {g.get('description','')}")
143
- if res.get("gene_disease"):
144
- st.markdown("### DisGeNET associations")
145
- st.json(res.get("gene_disease")[:15])
146
- if res.get("mesh_defs"):
147
  st.markdown("### MeSH definitions")
148
- for d in res["mesh_defs"]:
149
- if d:
150
- st.write("-", d)
 
 
151
 
152
- # --------------- Trials Tab ---------------
153
  with tabs[2]:
154
  st.header("Clinical trials")
155
- trials = res.get("clinical_trials", [])
156
  if not trials:
157
- st.info("No trials (rate-limited or none found).")
158
- else:
159
- for t in trials:
160
- nct = t.get("nctId") or (t.get("NCTId", [""])[0] if isinstance(t.get("NCTId"), list) else "")
161
- title = t.get("briefTitle") or (t.get("BriefTitle", [""])[0] if isinstance(t.get("BriefTitle"), list) else "")
162
- phase = t.get("phase") or (t.get("Phase", [""])[0] if isinstance(t.get("Phase"), list) else "")
163
- status = t.get("status") or (t.get("OverallStatus", [""])[0] if isinstance(t.get("OverallStatus"), list) else "")
164
- st.markdown(f"**{nct}** – {title}")
165
- st.write(f"Phase {phase} | Status {status}")
166
-
167
- # --------------- Variants Tab ---------------
168
  with tabs[3]:
169
  st.header("Cancer variants (cBioPortal)")
170
- variants = res.get("variants", [])
171
  if not variants:
172
- st.info("No variant data.")
173
  else:
174
- for v in variants:
175
- st.json(v)
176
 
177
- # --------------- Graph Tab ---------------
178
  with tabs[4]:
179
- nodes, edges, cfg = build_agraph(res.get("papers", []), res.get("umls", []), res.get("drug_safety", []))
180
- hl = st.text_input("Highlight node:", key="hl")
181
- if hl:
182
- pat = re.compile(re.escape(hl), re.I)
183
- for n in nodes:
184
- n.color = "#f1c40f" if pat.search(n.label) else "#d3d3d3"
185
  agraph(nodes, edges, cfg)
186
 
187
- # --------------- Metrics Tab ---------------
188
  with tabs[5]:
189
- nodes, edges, _ = build_agraph(res.get("papers", []), res.get("umls", []), res.get("drug_safety", []))
190
  G = build_nx([n.__dict__ for n in nodes], [e.__dict__ for e in edges])
191
  st.metric("Density", f"{get_density(G):.3f}")
192
  st.markdown("**Top hubs**")
@@ -194,28 +224,31 @@ def render_ui():
194
  lab = next((n.label for n in nodes if n.id == nid), nid)
195
  st.write(f"- {lab} {sc:.3f}")
196
 
197
- # --------------- Visuals Tab ---------------
198
  with tabs[6]:
199
- years = [p.get("published", "") for p in res.get("papers", []) if p.get("published")]
200
  if years:
201
  st.plotly_chart(px.histogram(years, nbins=12, title="Publication Year"))
202
 
203
- # --------------- Follow-up Q&A ---------------
204
  st.markdown("---")
205
  st.text_input("Ask follow‑up question:", key="followup_input")
206
- def handle_followup():
207
- follow = st.session_state.followup_input
208
- if follow.strip():
209
- ans = asyncio.run(answer_ai_question(
210
- follow,
211
- context=st.session_state.last_query,
212
- llm=st.session_state.last_llm))
213
- st.session_state.followup_response = ans.get("answer", "No answer.")
214
- else:
215
- st.session_state.followup_response = None
216
- st.button("Ask AI", on_click=handle_followup)
217
  if st.session_state.followup_response:
218
  st.write(st.session_state.followup_response)
219
 
 
 
 
220
  if __name__ == "__main__":
221
  render_ui()
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ MedGenesis AI Streamlit front‑end (v3)
4
+ --------------------------------------
5
+ Supports **OpenAI** and **Gemini** engines and the enriched backend
6
+ payload introduced in orchestrator v3:
7
+ • papers, umls, drug_safety, genes, mesh_defs, gene_disease,
8
+ clinical_trials, variants, ai_summary
9
+ Tabs:
10
+ Results | Genes | Trials | Variants | Graph | Metrics | Visuals
11
+ """
12
+
13
+ ##############################################################################
14
+ # Std‑lib / third‑party
15
+ ##############################################################################
16
+ import os
17
+ import pathlib
18
+ import asyncio
19
  from pathlib import Path
20
 
21
  import streamlit as st
 
24
  from fpdf import FPDF
25
  from streamlit_agraph import agraph
26
 
27
+ ##############################################################################
28
+ # Internal helpers
29
+ ##############################################################################
30
  from mcp.orchestrator import orchestrate_search, answer_ai_question
31
  from mcp.workspace import get_workspace, save_query
32
  from mcp.knowledge_graph import build_agraph
33
  from mcp.graph_metrics import build_nx, get_top_hubs, get_density
34
  from mcp.alerts import check_alerts
35
 
36
+ # ---------------------------------------------------------------------------
37
+ # Streamlit telemetry directory → /tmp
38
+ # ---------------------------------------------------------------------------
39
+ os.environ.update({
40
+ "STREAMLIT_DATA_DIR": "/tmp/.streamlit",
41
+ "XDG_STATE_HOME" : "/tmp",
42
+ "STREAMLIT_BROWSER_GATHERUSAGESTATS": "false",
43
+ })
44
  pathlib.Path("/tmp/.streamlit").mkdir(parents=True, exist_ok=True)
45
 
46
  ROOT = Path(__file__).parent
47
  LOGO = ROOT / "assets" / "logo.png"
48
 
49
+ ##############################################################################
50
+ # Utility helpers
51
+ ##############################################################################
52
+
53
  def _latin1_safe(txt: str) -> str:
54
+ """Coerce UTF‑8 → Latin‑1 with replacement (for FPDF)."""
55
  return txt.encode("latin-1", "replace").decode("latin-1")
56
 
57
+
58
+ def _pdf(papers: list[dict]) -> bytes:
59
  pdf = FPDF()
60
  pdf.set_auto_page_break(auto=True, margin=15)
61
  pdf.add_page()
 
64
  pdf.ln(3)
65
  for i, p in enumerate(papers, 1):
66
  pdf.set_font("Helvetica", "B", 11)
67
+ pdf.multi_cell(0, 7, _latin1_safe(f"{i}. {p.get('title','')}"))
68
  pdf.set_font("Helvetica", "", 9)
69
+ body = f"{p.get('authors','')}
70
+ {p.get('summary','')}
71
+ {p.get('link','')}\n"
72
  pdf.multi_cell(0, 6, _latin1_safe(body))
73
  pdf.ln(1)
74
  return pdf.output(dest="S").encode("latin-1", "replace")
75
 
76
+
77
  def _workspace_sidebar():
78
  with st.sidebar:
79
  st.header("🗂️ Workspace")
 
83
  return
84
  for i, item in enumerate(ws, 1):
85
  with st.expander(f"{i}. {item['query']}"):
86
+ st.write(item['result']['ai_summary'])
87
+
88
+ ##############################################################################
89
+ # Main UI renderer
90
+ ##############################################################################
91
 
92
  def render_ui():
93
  st.set_page_config("MedGenesis AI", layout="wide")
94
 
95
+ # Session‑state defaults
96
+ defaults = dict(
97
+ query_result=None,
98
+ followup_input="",
99
+ followup_response=None,
100
+ last_query="",
101
+ last_llm="openai",
102
+ )
103
+ for k, v in defaults.items():
104
+ st.session_state.setdefault(k, v)
105
 
106
  _workspace_sidebar()
107
+
108
+ # Header
109
+ col1, col2 = st.columns([0.15, 0.85])
110
+ with col1:
111
  if LOGO.exists():
112
  st.image(str(LOGO), width=105)
113
+ with col2:
114
  st.markdown("## 🧬 **MedGenesis AI**")
115
+ st.caption("Multisource biomedical assistant · OpenAI / Gemini")
116
 
117
+ # Controls
118
+ engine = st.radio("LLM engine", ["openai", "gemini"], horizontal=True)
119
+ query = st.text_input("Enter biomedical question", placeholder="e.g. CRISPR glioblastoma therapy")
120
 
121
  # Alerts
122
+ if get_workspace():
 
123
  try:
124
+ alerts = asyncio.run(check_alerts([w["query"] for w in get_workspace()]))
125
+ if alerts:
126
  with st.sidebar:
127
  st.subheader("🔔 New papers")
128
+ for q, lnks in alerts.items():
129
  st.write(f"**{q}** – {len(lnks)} new")
130
  except Exception:
131
  pass
132
 
133
+ # Run Search
134
  if st.button("Run Search 🚀") and query:
135
  with st.spinner("Collecting literature & biomedical data …"):
136
+ res = asyncio.run(orchestrate_search(query, llm=engine))
137
+ st.session_state.update(
138
+ query_result=res,
139
+ last_query=query,
140
+ last_llm=engine,
141
+ followup_input="",
142
+ followup_response=None,
143
+ )
144
+ st.success(f"Completed with **{res['llm_used'].title()}**")
145
 
146
  res = st.session_state.query_result
147
  if not res:
148
  st.info("Enter a question and press **Run Search 🚀**")
149
  return
150
 
151
+ # Tabs
152
  tabs = st.tabs(["Results", "Genes", "Trials", "Variants", "Graph", "Metrics", "Visuals"])
153
+
154
+ # --- Results tab
155
  with tabs[0]:
156
+ st.subheader("Literature")
157
+ for i, p in enumerate(res['papers'], 1):
158
  st.markdown(f"**{i}. [{p.get('title','')}]({p.get('link','')})** *{p.get('authors','')}*")
159
+ st.write(p.get('summary',''))
160
+ c1, c2 = st.columns(2)
161
+ with c1:
162
+ st.download_button("CSV", pd.DataFrame(res['papers']).to_csv(index=False), "papers.csv", "text/csv")
163
+ with c2:
164
+ st.download_button("PDF", _pdf(res['papers']), "papers.pdf", "application/pdf")
 
165
  if st.button("💾 Save"):
166
  save_query(st.session_state.last_query, res)
167
  st.success("Saved to workspace")
168
+
169
  st.subheader("UMLS concepts")
170
+ for c in res['umls']:
171
+ if c.get('cui'):
172
  st.write(f"- **{c.get('name','')}** ({c.get('cui')})")
173
+
174
  st.subheader("OpenFDA safety signals")
175
+ for d in res['drug_safety']:
176
+ st.json(d)
177
+
178
  st.subheader("AI summary")
179
+ st.info(res['ai_summary'])
180
 
181
+ # --- Genes tab
182
  with tabs[1]:
183
  st.header("Gene / Variant signals")
184
+ for g in res['genes']:
185
+ sym = g.get('symbol') or g.get('name') or ''
186
+ st.write(f"- **{sym}**")
187
+ if res['mesh_defs']:
 
 
 
 
 
 
 
 
188
  st.markdown("### MeSH definitions")
189
+ for d in res['mesh_defs']:
190
+ st.write(f"- {d}")
191
+ if res['gene_disease']:
192
+ st.markdown("### DisGeNET links")
193
+ st.json(res['gene_disease'][:15])
194
 
195
+ # --- Trials tab
196
  with tabs[2]:
197
  st.header("Clinical trials")
198
+ trials = res['clinical_trials']
199
  if not trials:
200
+ st.info("No trials returned (ratelimited or none found).")
201
+ for t in trials:
202
+ st.markdown(f"**{t.get('nctId','')}** – {t.get('briefTitle','')} Phase {t.get('phase','?')} | Status {t.get('status','?')}")
203
+
204
+ # --- Variants tab
 
 
 
 
 
 
205
  with tabs[3]:
206
  st.header("Cancer variants (cBioPortal)")
207
+ variants = res['variants']
208
  if not variants:
209
+ st.info("No variants for this gene/profile.")
210
  else:
211
+ st.json(variants[:30])
 
212
 
213
+ # --- Graph tab
214
  with tabs[4]:
215
+ nodes, edges, cfg = build_agraph(res['papers'], res['umls'], res['drug_safety'])
 
 
 
 
 
216
  agraph(nodes, edges, cfg)
217
 
218
+ # --- Metrics tab
219
  with tabs[5]:
 
220
  G = build_nx([n.__dict__ for n in nodes], [e.__dict__ for e in edges])
221
  st.metric("Density", f"{get_density(G):.3f}")
222
  st.markdown("**Top hubs**")
 
224
  lab = next((n.label for n in nodes if n.id == nid), nid)
225
  st.write(f"- {lab} {sc:.3f}")
226
 
227
+ # --- Visuals tab
228
  with tabs[6]:
229
+ years = [p.get('published') for p in res['papers'] if p.get('published')]
230
  if years:
231
  st.plotly_chart(px.histogram(years, nbins=12, title="Publication Year"))
232
 
233
+ # Followup QA
234
  st.markdown("---")
235
  st.text_input("Ask follow‑up question:", key="followup_input")
236
+
237
+ def _on_ask():
238
+ q = st.session_state.followup_input
239
+ if not q.strip():
240
+ st.warning("Please type a question first.")
241
+ return
242
+ with st.spinner("Querying LLM …"):
243
+ ans = asyncio.run(answer_ai_question(q, context=st.session_state.last_query, llm=st.session_state.last_llm))
244
+ st.session_state.followup_response = ans['answer']
245
+
246
+ st.button("Ask AI", on_click=_on_ask)
247
  if st.session_state.followup_response:
248
  st.write(st.session_state.followup_response)
249
 
250
+ ##############################################################################
251
+ # Entrypoint
252
+ ##############################################################################
253
  if __name__ == "__main__":
254
  render_ui()