mgbam commited on
Commit
a7d3db7
·
verified ·
1 Parent(s): 0ca81fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -36
app.py CHANGED
@@ -1,8 +1,9 @@
1
  #!/usr/bin/env python3
2
  # MedGenesis AI · CPU-only Streamlit app (OpenAI / Gemini)
3
 
4
- # ── Streamlit telemetry dir fix ───────────────────────────────────────
5
  import os, pathlib
 
 
6
  os.environ["STREAMLIT_DATA_DIR"] = "/tmp/.streamlit"
7
  os.environ["XDG_STATE_HOME"] = "/tmp"
8
  os.environ["STREAMLIT_BROWSER_GATHERUSAGESTATS"] = "false"
@@ -14,7 +15,7 @@ from pathlib import Path
14
  import streamlit as st
15
  import pandas as pd
16
  import plotly.express as px
17
- from fpdf import FPDF # classic FPDF → Latin-1 only
18
  from streamlit_agraph import agraph
19
 
20
  # ── Internal helpers ────────────────────────────────────────────────
@@ -29,7 +30,6 @@ LOGO = ROOT / "assets" / "logo.png"
29
 
30
  # ── PDF export helper (UTF-8 → Latin-1 “safe”) ──────────────────────
31
  def _latin1_safe(txt: str) -> str:
32
- """Return text that FPDF(latin-1) can embed; replace unknown chars."""
33
  return txt.encode("latin-1", "replace").decode("latin-1")
34
 
35
  def _pdf(papers):
@@ -71,7 +71,6 @@ def render_ui():
71
  st.set_page_config("MedGenesis AI", layout="wide")
72
  _workspace_sidebar()
73
 
74
- # Header
75
  c1, c2 = st.columns([0.15, 0.85])
76
  with c1:
77
  if LOGO.exists():
@@ -81,8 +80,7 @@ def render_ui():
81
  st.caption("Multi-source biomedical assistant · OpenAI / Gemini")
82
 
83
  llm = st.radio("LLM engine", ["openai", "gemini"], horizontal=True)
84
- query = st.text_input("Enter biomedical question",
85
- placeholder="e.g. CRISPR glioblastoma therapy")
86
 
87
  # Alert check
88
  if get_workspace():
@@ -102,10 +100,8 @@ def render_ui():
102
  res = asyncio.run(orchestrate_search(query, llm=llm))
103
  st.success(f"Completed with **{res['llm_used'].title()}**")
104
 
105
- tabs = st.tabs(["Results", "Genes", "Trials", "Graph",
106
- "Metrics", "Visuals"])
107
 
108
- # Results
109
  with tabs[0]:
110
  for i, p in enumerate(res["papers"], 1):
111
  st.markdown(f"**{i}. [{p['title']}]({p['link']})** *{p['authors']}*")
@@ -113,12 +109,9 @@ def render_ui():
113
 
114
  col1, col2 = st.columns(2)
115
  with col1:
116
- st.download_button("CSV",
117
- pd.DataFrame(res["papers"]).to_csv(index=False),
118
- "papers.csv", "text/csv")
119
  with col2:
120
- st.download_button("PDF", _pdf(res["papers"]),
121
- "papers.pdf", "application/pdf")
122
 
123
  if st.button("💾 Save"):
124
  save_query(query, res)
@@ -136,12 +129,10 @@ def render_ui():
136
  st.subheader("AI summary")
137
  st.info(res["ai_summary"])
138
 
139
- # Genes
140
  with tabs[1]:
141
  st.header("Gene / Variant signals")
142
  for g in res["genes"]:
143
- st.write(f"- **{g.get('name', g.get('geneid'))}** "
144
- f"{g.get('description', '')}")
145
  if res["gene_disease"]:
146
  st.markdown("### DisGeNET links")
147
  st.json(res["gene_disease"][:15])
@@ -151,21 +142,16 @@ def render_ui():
151
  if d:
152
  st.write("-", d)
153
 
154
- # Trials
155
  with tabs[2]:
156
  st.header("Clinical trials")
157
  if not res["clinical_trials"]:
158
  st.info("No trials (rate-limited or none found).")
159
  for t in res["clinical_trials"]:
160
  st.markdown(f"**{t['NCTId'][0]}** – {t['BriefTitle'][0]}")
161
- st.write(f"Phase {t.get('Phase', [''])[0]} "
162
- f"| Status {t['OverallStatus'][0]}")
163
 
164
- # Graph
165
  with tabs[3]:
166
- nodes, edges, cfg = build_agraph(res["papers"],
167
- res["umls"],
168
- res["drug_safety"])
169
  hl = st.text_input("Highlight node:", key="hl")
170
  if hl:
171
  pat = re.compile(re.escape(hl), re.I)
@@ -173,31 +159,29 @@ def render_ui():
173
  n.color = "#f1c40f" if pat.search(n.label) else "#d3d3d3"
174
  agraph(nodes, edges, cfg)
175
 
176
- # Metrics
177
  with tabs[4]:
178
- G = build_nx([n.__dict__ for n in nodes],
179
- [e.__dict__ for e in edges])
180
  st.metric("Density", f"{get_density(G):.3f}")
181
  st.markdown("**Top hubs**")
182
  for nid, sc in get_top_hubs(G):
183
  lab = next((n.label for n in nodes if n.id == nid), nid)
184
  st.write(f"- {lab} {sc:.3f}")
185
 
186
- # Visuals
187
  with tabs[5]:
188
  years = [p["published"] for p in res["papers"] if p.get("published")]
189
  if years:
190
- st.plotly_chart(px.histogram(years, nbins=12,
191
- title="Publication Year"))
192
 
193
- # Follow-up Q-A
194
  st.markdown("---")
195
- follow = st.text_input("Ask follow-up:")
196
  if st.button("Ask AI"):
197
- ans = asyncio.run(answer_ai_question(follow,
198
- context=query,
199
- llm=llm))
200
- st.write(ans["answer"])
 
 
201
 
202
  else:
203
  st.info("Enter a question and press **Run Search 🚀**")
 
1
  #!/usr/bin/env python3
2
  # MedGenesis AI · CPU-only Streamlit app (OpenAI / Gemini)
3
 
 
4
  import os, pathlib
5
+
6
+ # ── Streamlit telemetry dir fix ───────────────────────────────────────
7
  os.environ["STREAMLIT_DATA_DIR"] = "/tmp/.streamlit"
8
  os.environ["XDG_STATE_HOME"] = "/tmp"
9
  os.environ["STREAMLIT_BROWSER_GATHERUSAGESTATS"] = "false"
 
15
  import streamlit as st
16
  import pandas as pd
17
  import plotly.express as px
18
+ from fpdf import FPDF
19
  from streamlit_agraph import agraph
20
 
21
  # ── Internal helpers ────────────────────────────────────────────────
 
30
 
31
  # ── PDF export helper (UTF-8 → Latin-1 “safe”) ──────────────────────
32
  def _latin1_safe(txt: str) -> str:
 
33
  return txt.encode("latin-1", "replace").decode("latin-1")
34
 
35
  def _pdf(papers):
 
71
  st.set_page_config("MedGenesis AI", layout="wide")
72
  _workspace_sidebar()
73
 
 
74
  c1, c2 = st.columns([0.15, 0.85])
75
  with c1:
76
  if LOGO.exists():
 
80
  st.caption("Multi-source biomedical assistant · OpenAI / Gemini")
81
 
82
  llm = st.radio("LLM engine", ["openai", "gemini"], horizontal=True)
83
+ query = st.text_input("Enter biomedical question", placeholder="e.g. CRISPR glioblastoma therapy")
 
84
 
85
  # Alert check
86
  if get_workspace():
 
100
  res = asyncio.run(orchestrate_search(query, llm=llm))
101
  st.success(f"Completed with **{res['llm_used'].title()}**")
102
 
103
+ tabs = st.tabs(["Results", "Genes", "Trials", "Graph", "Metrics", "Visuals"])
 
104
 
 
105
  with tabs[0]:
106
  for i, p in enumerate(res["papers"], 1):
107
  st.markdown(f"**{i}. [{p['title']}]({p['link']})** *{p['authors']}*")
 
109
 
110
  col1, col2 = st.columns(2)
111
  with col1:
112
+ st.download_button("CSV", pd.DataFrame(res["papers"]).to_csv(index=False), "papers.csv", "text/csv")
 
 
113
  with col2:
114
+ st.download_button("PDF", _pdf(res["papers"]), "papers.pdf", "application/pdf")
 
115
 
116
  if st.button("💾 Save"):
117
  save_query(query, res)
 
129
  st.subheader("AI summary")
130
  st.info(res["ai_summary"])
131
 
 
132
  with tabs[1]:
133
  st.header("Gene / Variant signals")
134
  for g in res["genes"]:
135
+ st.write(f"- **{g.get('name', g.get('geneid'))}** {g.get('description', '')}")
 
136
  if res["gene_disease"]:
137
  st.markdown("### DisGeNET links")
138
  st.json(res["gene_disease"][:15])
 
142
  if d:
143
  st.write("-", d)
144
 
 
145
  with tabs[2]:
146
  st.header("Clinical trials")
147
  if not res["clinical_trials"]:
148
  st.info("No trials (rate-limited or none found).")
149
  for t in res["clinical_trials"]:
150
  st.markdown(f"**{t['NCTId'][0]}** – {t['BriefTitle'][0]}")
151
+ st.write(f"Phase {t.get('Phase',[''])[0]} | Status {t['OverallStatus'][0]}")
 
152
 
 
153
  with tabs[3]:
154
+ nodes, edges, cfg = build_agraph(res["papers"], res["umls"], res["drug_safety"])
 
 
155
  hl = st.text_input("Highlight node:", key="hl")
156
  if hl:
157
  pat = re.compile(re.escape(hl), re.I)
 
159
  n.color = "#f1c40f" if pat.search(n.label) else "#d3d3d3"
160
  agraph(nodes, edges, cfg)
161
 
 
162
  with tabs[4]:
163
+ G = build_nx([n.__dict__ for n in nodes], [e.__dict__ for e in edges])
 
164
  st.metric("Density", f"{get_density(G):.3f}")
165
  st.markdown("**Top hubs**")
166
  for nid, sc in get_top_hubs(G):
167
  lab = next((n.label for n in nodes if n.id == nid), nid)
168
  st.write(f"- {lab} {sc:.3f}")
169
 
 
170
  with tabs[5]:
171
  years = [p["published"] for p in res["papers"] if p.get("published")]
172
  if years:
173
+ st.plotly_chart(px.histogram(years, nbins=12, title="Publication Year"))
 
174
 
175
+ # ── Follow-up Q-A (fixed) ───────────────────────────────────────
176
  st.markdown("---")
177
+ follow = st.text_input("Ask follow-up question:", key="followup_input") # ✅ UPDATED
178
  if st.button("Ask AI"):
179
+ if follow.strip(): # ✅ UPDATED
180
+ with st.spinner("Generating AI response..."):
181
+ ans = asyncio.run(answer_ai_question(follow, context=query, llm=llm))
182
+ st.write(ans["answer"])
183
+ else:
184
+ st.warning("Please type a follow-up question before submitting.") # ✅ UPDATED
185
 
186
  else:
187
  st.info("Enter a question and press **Run Search 🚀**")