mgbam commited on
Commit
d55cbab
Β·
verified Β·
1 Parent(s): fbb4b8d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -90
app.py CHANGED
@@ -4,8 +4,8 @@
4
  import os, pathlib
5
 
6
  # ── Streamlit telemetry dir fix ───────────────────────────────────────
7
- os.environ["STREAMLIT_DATA_DIR"] = "/tmp/.streamlit"
8
- os.environ["XDG_STATE_HOME"] = "/tmp"
9
  os.environ["STREAMLIT_BROWSER_GATHERUSAGESTATS"] = "false"
10
  pathlib.Path("/tmp/.streamlit").mkdir(parents=True, exist_ok=True)
11
 
@@ -18,10 +18,10 @@ from fpdf import FPDF
18
  from streamlit_agraph import agraph
19
 
20
  from mcp.orchestrator import orchestrate_search, answer_ai_question
21
- from mcp.workspace import get_workspace, save_query
22
  from mcp.knowledge_graph import build_agraph
23
  from mcp.graph_metrics import build_nx, get_top_hubs, get_density
24
- from mcp.alerts import check_alerts
25
 
26
  ROOT = Path(__file__).parent
27
  LOGO = ROOT / "assets" / "logo.png"
@@ -59,11 +59,9 @@ def _workspace_sidebar():
59
  def render_ui():
60
  st.set_page_config("MedGenesis AI", layout="wide")
61
 
62
- # Initialize session-state
63
  if "followup_input" not in st.session_state:
64
  st.session_state.followup_input = ""
65
- if "tab_index" not in st.session_state:
66
- st.session_state.tab_index = 0
67
 
68
  _workspace_sidebar()
69
 
@@ -78,7 +76,6 @@ def render_ui():
78
  llm = st.radio("LLM engine", ["openai", "gemini"], horizontal=True)
79
  query = st.text_input("Enter biomedical question", placeholder="e.g. CRISPR glioblastoma therapy")
80
 
81
- # Alert check
82
  if get_workspace():
83
  try:
84
  news = asyncio.run(check_alerts([w["query"] for w in get_workspace()]))
@@ -96,94 +93,86 @@ def render_ui():
96
  res = asyncio.run(orchestrate_search(query, llm=llm))
97
  st.success(f"Completed with **{res['llm_used'].title()}**")
98
  st.session_state.query_result = res
99
- # Reset follow-up input
100
  st.session_state.followup_input = ""
101
- st.session_state.tab_index = 0
102
  else:
103
  res = st.session_state.get("query_result", None)
104
 
105
  if res:
106
- tabs_list = ["Results", "Genes", "Trials", "Graph", "Metrics", "Visuals"]
107
- tabs = st.tabs(tabs_list, index=st.session_state.tab_index)
108
-
109
- for idx, name in enumerate(tabs_list):
110
- with tabs[idx]:
111
- st.session_state.tab_index = idx
112
- if name == "Results":
113
- for i, p in enumerate(res["papers"], 1):
114
- st.markdown(f"**{i}. [{p['title']}]({p['link']})** *{p['authors']}*")
115
- st.write(p["summary"])
116
- col1, col2 = st.columns(2)
117
- with col1:
118
- st.download_button("CSV", pd.DataFrame(res["papers"]).to_csv(index=False),
119
- "papers.csv", "text/csv")
120
- with col2:
121
- st.download_button("PDF", _pdf(res["papers"]), "papers.pdf", "application/pdf")
122
- if st.button("πŸ’Ύ Save"):
123
- save_query(query, res)
124
- st.success("Saved to workspace")
125
- st.subheader("UMLS concepts")
126
- for c in res["umls"]:
127
- if c.get("cui"):
128
- st.write(f"- **{c['name']}** ({c['cui']})")
129
- st.subheader("OpenFDA safety")
130
- for d in res["drug_safety"]:
131
- st.json(d)
132
- st.subheader("AI summary")
133
- st.info(res["ai_summary"])
134
-
135
- elif name == "Genes":
136
- st.header("Gene / Variant signals")
137
- for g in res["genes"]:
138
- st.write(f"- **{g.get('name', g.get('geneid'))}** "
139
- f"{g.get('description', '')}")
140
- if res["gene_disease"]:
141
- st.markdown("### DisGeNET links")
142
- st.json(res["gene_disease"][:15])
143
- if res["mesh_defs"]:
144
- st.markdown("### MeSH definitions")
145
- for d in res["mesh_defs"]:
146
- if d:
147
- st.write("-", d)
148
-
149
- elif name == "Trials":
150
- st.header("Clinical trials")
151
- if not res["clinical_trials"]:
152
- st.info("No trials (rate-limited or none found).")
153
- for t in res["clinical_trials"]:
154
- st.markdown(f"**{t['NCTId'][0]}** – {t['BriefTitle'][0]}")
155
- st.write(f"Phase {t.get('Phase',[''])[0]} | "
156
- f"Status {t['OverallStatus'][0]}")
157
-
158
- elif name == "Graph":
159
- nodes, edges, cfg = build_agraph(res["papers"], res["umls"], res["drug_safety"])
160
- hl = st.text_input("Highlight node:", key="hl")
161
- if hl:
162
- pat = re.compile(re.escape(hl), re.I)
163
- for n in nodes:
164
- n.color = "#f1c40f" if pat.search(n.label) else "#d3d3d3"
165
- agraph(nodes, edges, cfg)
166
-
167
- elif name == "Metrics":
168
- nodes, edges, _ = build_agraph(res["papers"], res["umls"], res["drug_safety"])
169
- G = build_nx([n.__dict__ for n in nodes], [e.__dict__ for e in edges])
170
- st.metric("Density", f"{get_density(G):.3f}")
171
- st.markdown("**Top hubs**")
172
- for nid, sc in get_top_hubs(G):
173
- lab = next((n.label for n in nodes if n.id == nid), nid)
174
- st.write(f"- {lab} {sc:.3f}")
175
-
176
- elif name == "Visuals":
177
- years = [p["published"] for p in res["papers"] if p.get("published")]
178
- if years:
179
- st.plotly_chart(px.histogram(years, nbins=12,
180
- title="Publication Year"))
181
-
182
- # Follow-up Q-A persistently under tabs
183
  st.markdown("---")
184
- follow = st.text_input("Ask follow‑up question:",
185
- value=st.session_state.followup_input,
186
- key="followup_input")
 
 
187
  if st.button("Ask AI"):
188
  st.session_state.followup_input = follow
189
  if follow.strip():
 
4
  import os, pathlib
5
 
6
  # ── Streamlit telemetry dir fix ───────────────────────────────────────
7
+ os.environ["STREAMLIT_DATA_DIR"] = "/tmp/.streamlit"
8
+ os.environ["XDG_STATE_HOME"] = "/tmp"
9
  os.environ["STREAMLIT_BROWSER_GATHERUSAGESTATS"] = "false"
10
  pathlib.Path("/tmp/.streamlit").mkdir(parents=True, exist_ok=True)
11
 
 
18
  from streamlit_agraph import agraph
19
 
20
  from mcp.orchestrator import orchestrate_search, answer_ai_question
21
+ from mcp.workspace import get_workspace, save_query
22
  from mcp.knowledge_graph import build_agraph
23
  from mcp.graph_metrics import build_nx, get_top_hubs, get_density
24
+ from mcp.alerts import check_alerts
25
 
26
  ROOT = Path(__file__).parent
27
  LOGO = ROOT / "assets" / "logo.png"
 
59
  def render_ui():
60
  st.set_page_config("MedGenesis AI", layout="wide")
61
 
62
+ # Persist follow-up input
63
  if "followup_input" not in st.session_state:
64
  st.session_state.followup_input = ""
 
 
65
 
66
  _workspace_sidebar()
67
 
 
76
  llm = st.radio("LLM engine", ["openai", "gemini"], horizontal=True)
77
  query = st.text_input("Enter biomedical question", placeholder="e.g. CRISPR glioblastoma therapy")
78
 
 
79
  if get_workspace():
80
  try:
81
  news = asyncio.run(check_alerts([w["query"] for w in get_workspace()]))
 
93
  res = asyncio.run(orchestrate_search(query, llm=llm))
94
  st.success(f"Completed with **{res['llm_used'].title()}**")
95
  st.session_state.query_result = res
 
96
  st.session_state.followup_input = ""
 
97
  else:
98
  res = st.session_state.get("query_result", None)
99
 
100
  if res:
101
+ tabs = st.tabs(["Results", "Genes", "Trials", "Graph", "Metrics", "Visuals"])
102
+
103
+ with tabs[0]: # Results
104
+ for i, p in enumerate(res["papers"], 1):
105
+ st.markdown(f"**{i}. [{p['title']}]({p['link']})** *{p['authors']}*")
106
+ st.write(p["summary"])
107
+ col1, col2 = st.columns(2)
108
+ with col1:
109
+ st.download_button("CSV", pd.DataFrame(res["papers"]).to_csv(index=False),
110
+ "papers.csv", "text/csv")
111
+ with col2:
112
+ st.download_button("PDF", _pdf(res["papers"]), "papers.pdf", "application/pdf")
113
+ if st.button("πŸ’Ύ Save"):
114
+ save_query(query, res)
115
+ st.success("Saved to workspace")
116
+ st.subheader("UMLS concepts")
117
+ for c in res["umls"]:
118
+ if c.get("cui"):
119
+ st.write(f"- **{c['name']}** ({c['cui']})")
120
+ st.subheader("OpenFDA safety")
121
+ for d in res["drug_safety"]:
122
+ st.json(d)
123
+ st.subheader("AI summary")
124
+ st.info(res["ai_summary"])
125
+
126
+ with tabs[1]: # Genes
127
+ st.header("Gene / Variant signals")
128
+ for g in res["genes"]:
129
+ st.write(f"- **{g.get('name', g.get('geneid'))}** {g.get('description', '')}")
130
+ if res["gene_disease"]:
131
+ st.markdown("### DisGeNET links")
132
+ st.json(res["gene_disease"][:15])
133
+ if res["mesh_defs"]:
134
+ st.markdown("### MeSH definitions")
135
+ for d in res["mesh_defs"]:
136
+ if d:
137
+ st.write("-", d)
138
+
139
+ with tabs[2]: # Trials
140
+ st.header("Clinical trials")
141
+ if not res["clinical_trials"]:
142
+ st.info("No trials (rate-limited or none found).")
143
+ for t in res["clinical_trials"]:
144
+ st.markdown(f"**{t['NCTId'][0]}** – {t['BriefTitle'][0]}")
145
+ st.write(f"Phase {t.get('Phase',[''])[0]} | Status {t['OverallStatus'][0]}")
146
+
147
+ with tabs[3]: # Graph
148
+ nodes, edges, cfg = build_agraph(res["papers"], res["umls"], res["drug_safety"])
149
+ hl = st.text_input("Highlight node:", key="hl")
150
+ if hl:
151
+ pat = re.compile(re.escape(hl), re.I)
152
+ for n in nodes:
153
+ n.color = "#f1c40f" if pat.search(n.label) else "#d3d3d3"
154
+ agraph(nodes, edges, cfg)
155
+
156
+ with tabs[4]: # Metrics
157
+ nodes, edges, _ = build_agraph(res["papers"], res["umls"], res["drug_safety"])
158
+ G = build_nx([n.__dict__ for n in nodes], [e.__dict__ for e in edges])
159
+ st.metric("Density", f"{get_density(G):.3f}")
160
+ st.markdown("**Top hubs**")
161
+ for nid, sc in get_top_hubs(G):
162
+ lab = next((n.label for n in nodes if n.id == nid), nid)
163
+ st.write(f"- {lab} {sc:.3f}")
164
+
165
+ with tabs[5]: # Visuals
166
+ years = [p["published"] for p in res["papers"] if p.get("published")]
167
+ if years:
168
+ st.plotly_chart(px.histogram(years, nbins=12, title="Publication Year"))
169
+
 
 
 
 
 
 
 
 
170
  st.markdown("---")
171
+ follow = st.text_input(
172
+ "Ask follow‑up question:",
173
+ value=st.session_state.followup_input,
174
+ key="followup_input"
175
+ )
176
  if st.button("Ask AI"):
177
  st.session_state.followup_input = follow
178
  if follow.strip():