MCP_Res / app.py
mgbam's picture
Update app.py
39219c6 verified
raw
history blame
11.4 kB
#!/usr/bin/env python3
# ──────────────────────────────────────────────────────────────────────
# MedGenesis AI – Streamlit UI (OpenAI + Gemini, CPU-only)
# ──────────────────────────────────────────────────────────────────────
import asyncio
import logging
import os
import pathlib
import re
from datetime import datetime
from pathlib import Path

import pandas as pd
import plotly.express as px
import streamlit as st
from fpdf import FPDF
from streamlit_agraph import agraph

# ── internal helpers --------------------------------------------------
from mcp.orchestrator import orchestrate_search, answer_ai_question
from mcp.workspace import get_workspace, save_query
from mcp.knowledge_graph import build_agraph
from mcp.graph_metrics import build_nx, get_top_hubs, get_density
from mcp.alerts import check_alerts
# ── Streamlit telemetry dir fix (HF Spaces sandbox quirks) ------------
# Point Streamlit's data/state directories at /tmp and switch usage
# statistics off, then make sure the data directory actually exists.
os.environ.update(
    {
        "STREAMLIT_DATA_DIR": "/tmp/.streamlit",
        "XDG_STATE_HOME": "/tmp",
        "STREAMLIT_BROWSER_GATHERUSAGESTATS": "false",
    }
)
Path("/tmp/.streamlit").mkdir(parents=True, exist_ok=True)

# Directory containing this file, and the logo image expected below it.
ROOT = Path(__file__).parent
LOGO = ROOT.joinpath("assets", "logo.png")
# ══════════════════════════════════════════════════════════════════════
# Small util helpers
# ══════════════════════════════════════════════════════════════════════
def _latin1_safe(txt: str) -> str:
"""Replace non-Latin-1 chars – keeps FPDF happy."""
return txt.encode("latin-1", "replace").decode("latin-1")
def _pdf(papers: list[dict]) -> bytes:
    """Render *papers* into a PDF document and return it as raw bytes.

    Every paper dict must provide the ``title``, ``authors``, ``summary``
    and ``link`` keys; each paper becomes a bold numbered title followed by
    a smaller body paragraph.  All text is passed through ``_latin1_safe``
    before being written.

    Args:
        papers: Paper records to list, in display order.

    Returns:
        The serialized PDF, suitable for ``st.download_button``.

    Raises:
        KeyError: If a paper lacks one of the required keys.
    """
    pdf = FPDF()
    pdf.set_auto_page_break(auto=True, margin=15)
    pdf.add_page()

    # Centered document heading.
    pdf.set_font("Helvetica", size=11)
    pdf.cell(200, 8, _latin1_safe("MedGenesis AI – Literature results"),
             ln=True, align="C")
    pdf.ln(3)

    for i, p in enumerate(papers, 1):
        pdf.set_font("Helvetica", "B", 11)
        pdf.multi_cell(0, 7, _latin1_safe(f"{i}. {p['title']}"))
        pdf.set_font("Helvetica", "", 9)
        body = (
            f"{p['authors']}\n"
            f"{p['summary']}\n"
            f"{p['link']}\n"
        )
        pdf.multi_cell(0, 6, _latin1_safe(body))
        pdf.ln(1)

    # The return type of output() depends on the installed FPDF flavour:
    # legacy PyFPDF hands back a latin-1 text string that must be encoded,
    # whereas fpdf2 already returns a bytes-like object (which has no
    # .encode method, so calling it unconditionally would raise).
    raw = pdf.output(dest="S")
    if isinstance(raw, str):
        return raw.encode("latin-1", "replace")
    return bytes(raw)
def _workspace_sidebar() -> None:
    """Show the saved workspace entries in the sidebar.

    Reads the workspace via ``get_workspace()``.  When it is empty, a hint
    about the **Save** button is shown instead; otherwise each saved item
    gets a numbered expander titled with its query and containing the AI
    summary stored in its result.
    """
    with st.sidebar:
        # Header emoji restored from its mis-decoded byte sequence.
        st.header("🗂 Workspace")
        ws = get_workspace()
        if not ws:
            st.info("Run a search then press **Save** to populate this list.")
            return
        for i, item in enumerate(ws, 1):
            with st.expander(f"{i}. {item['query']}"):
                st.write(item["result"]["ai_summary"])
# ══════════════════════════════════════════════════════════════════════
# Main Streamlit UI
# ══════════════════════════════════════════════════════════════════════
def render_ui() -> None:
    """Render the complete MedGenesis AI page.

    Configures the page, seeds session-state defaults, draws the workspace
    sidebar, the header and the query controls, runs a search when the
    button is pressed with a non-blank query, and then displays the result
    kept in ``st.session_state.query_result`` across six tabs followed by a
    follow-up question box.  Returns early with a hint when no result has
    been stored yet.
    """
    st.set_page_config("MedGenesis AI", layout="wide")

    # ── Session-state defaults ────────────────────────────────────────
    session_defaults = {
        "query_result": None,
        "followup_input": "",
        "followup_response": None,
        "last_query": "",
        "last_llm": "",
    }
    for key, default in session_defaults.items():
        st.session_state.setdefault(key, default)

    _workspace_sidebar()

    col_logo, col_title = st.columns([0.15, 0.85])
    with col_logo:
        if LOGO.exists():
            # A plain string path is accepted by every Streamlit version.
            st.image(str(LOGO), width=110)
    with col_title:
        st.markdown("## 🧬 **MedGenesis AI**")
        st.caption("Multi-source biomedical assistant · OpenAI / Gemini")

    llm = st.radio("LLM engine", ["openai", "gemini"], horizontal=True)
    query = st.text_input("Enter biomedical question",
                          placeholder="e.g. CRISPR glioblastoma therapy")

    # ── alert notifications (async) ───────────────────────────────────
    # `or []` mirrors the emptiness guard used by the workspace sidebar.
    saved_qs = [w["query"] for w in (get_workspace() or [])]
    _show_alerts(saved_qs)

    # ── Run Search ----------------------------------------------------
    if st.button("Run Search 🚀") and query.strip():
        with st.spinner("Collecting literature & biomedical data …"):
            res = asyncio.run(orchestrate_search(query, llm=llm))
        # Store in session so the result survives reruns; the follow-up
        # box is reset because it referred to the previous query.
        st.session_state.update(
            query_result=res,
            last_query=query,
            last_llm=llm,
            followup_input="",
            followup_response=None,
        )
        st.success(f"Completed with **{res['llm_used'].title()}**")

    res = st.session_state.query_result
    if not res:
        st.info("Enter a biomedical question and press **Run Search 🚀**")
        return

    # ── Tabs ----------------------------------------------------------
    (tab_results, tab_genes, tab_trials,
     tab_graph, tab_metrics, tab_visuals) = st.tabs(
        ["Results", "Genes", "Trials", "Graph", "Metrics", "Visuals"]
    )
    with tab_results:
        _render_results_tab(res)
    with tab_genes:
        _render_genes_tab(res)
    with tab_trials:
        _render_trials_tab(res)
    with tab_graph:
        nodes, edges = _render_graph_tab(res)
    with tab_metrics:
        # Metrics are computed from the very graph drawn in the Graph tab.
        _render_metrics_tab(nodes, edges)
    with tab_visuals:
        _render_visuals_tab(res)

    _render_followup()


def _first_value(record: dict, key: str) -> str:
    """Return the first entry of ``record[key]`` as text, ``""`` if missing or empty."""
    value = record.get(key)
    if isinstance(value, (list, tuple)):
        return str(value[0]) if value else ""
    return "" if value is None else str(value)


def _show_alerts(saved_queries: list) -> None:
    """List saved queries with new papers in the sidebar; failures are logged, never raised."""
    if not saved_queries:
        return
    try:
        news = asyncio.run(check_alerts(saved_queries))
        if news:
            with st.sidebar:
                st.subheader("🔔 New papers")
                for q, lnks in news.items():
                    st.write(f"**{q}** – {len(lnks)} new")
    except Exception:
        # Alerts are optional, so a network hiccup must not break the
        # page – but leave a trace instead of swallowing it silently.
        logging.getLogger(__name__).warning(
            "Alert check failed", exc_info=True
        )


def _render_results_tab(res: dict) -> None:
    """Draw the paper list, CSV/PDF downloads, Save button, UMLS, OpenFDA and AI summary."""
    papers = res["papers"] or []
    for i, p in enumerate(papers, 1):
        st.markdown(
            f"**{i}. [{p['title']}]({p['link']})** "
            f"*{p['authors']}*"
        )
        st.write(p["summary"])

    c_csv, c_pdf = st.columns(2)
    with c_csv:
        st.download_button(
            "CSV",
            pd.DataFrame(papers).to_csv(index=False),
            "papers.csv",
            "text/csv",
        )
    with c_pdf:
        st.download_button("PDF", _pdf(papers),
                           "papers.pdf", "application/pdf")

    if st.button("💾 Save"):
        save_query(st.session_state.last_query, res)
        st.success("Saved to workspace")

    st.subheader("UMLS concepts")
    for c in (res["umls"] or []):
        if isinstance(c, dict) and c.get("cui"):
            st.write(f"- **{c['name']}** ({c['cui']})")

    st.subheader("OpenFDA safety signals")
    for d in (res["drug_safety"] or []):
        st.json(d)

    st.subheader("AI summary")
    st.info(res["ai_summary"])


def _render_genes_tab(res: dict) -> None:
    """Draw gene hits, DisGeNET associations and MeSH definitions."""
    st.header("Gene / Variant signals")
    genes_list = [
        g for g in (res["genes"] or [])
        if isinstance(g, dict) and (g.get("symbol") or g.get("name"))
    ]
    if not genes_list:
        st.info("No gene hits (rate-limited or none found).")
    for g in genes_list:
        st.write(f"- **{g.get('symbol') or g.get('name')}** "
                 f"{g.get('description','')}")

    if res["gene_disease"]:
        st.markdown("### DisGeNET associations")
        ok = [d for d in res["gene_disease"] if isinstance(d, dict)]
        if ok:
            # Only the first 15 associations are displayed.
            st.json(ok[:15])

    defs = [d for d in (res["mesh_defs"] or []) if isinstance(d, str) and d]
    if defs:
        st.markdown("### MeSH definitions")
        for d in defs:
            st.write("-", d)


def _render_trials_tab(res: dict) -> None:
    """Draw one title line and one phase/status line per clinical trial."""
    st.header("Clinical trials")
    trials = res["clinical_trials"] or []
    if not trials:
        st.info("No trials (rate-limited or none found).")
    for t in trials:
        if not isinstance(t, dict):
            continue  # same dict-only filtering the other tabs apply
        nct = _first_value(t, "NCTId")
        bttl = _first_value(t, "BriefTitle")
        phase = _first_value(t, "Phase")
        stat = _first_value(t, "OverallStatus")
        st.markdown(f"**{nct}** – {bttl}")
        st.write(f"Phase {phase} | Status {stat}")


def _render_graph_tab(res: dict) -> tuple:
    """Draw the knowledge graph with optional highlighting; return ``(nodes, edges)``."""
    nodes, edges, cfg = build_agraph(
        res["papers"], res["umls"], res["drug_safety"]
    )
    hl = st.text_input("Highlight node:", key="hl")
    if hl:
        # Case-insensitive literal match: hits turn yellow, the rest grey.
        pat = re.compile(re.escape(hl), re.I)
        for n in nodes:
            n.color = "#f1c40f" if pat.search(n.label) else "#d3d3d3"
    agraph(nodes, edges, cfg)
    return nodes, edges


def _render_metrics_tab(nodes: list, edges: list) -> None:
    """Show the density and the top five hubs of the graph built from *nodes*/*edges*."""
    G = build_nx(
        [n.__dict__ for n in nodes],
        [e.__dict__ for e in edges],
    )
    st.metric("Density", f"{get_density(G):.3f}")
    st.markdown("**Top hubs**")
    for nid, sc in get_top_hubs(G, k=5):
        # Fall back to the raw id when no node carries it.
        label = next((n.label for n in nodes if n.id == nid), nid)
        st.write(f"- {label} {sc:.3f}")


def _render_visuals_tab(res: dict) -> None:
    """Plot a histogram of publication years taken from each paper's ``published`` field."""
    years = [
        p["published"][:4] for p in (res["papers"] or [])
        if p.get("published") and len(p["published"]) >= 4
    ]
    if years:
        st.plotly_chart(
            px.histogram(
                years, nbins=min(15, len(set(years))),
                title="Publication Year"
            )
        )


def _on_ask_followup() -> None:
    """Button callback: send the follow-up question to the LLM and store the answer."""
    q = st.session_state.followup_input.strip()
    if not q:
        st.warning("Please type a question first.")
        return
    with st.spinner("Querying LLM …"):
        ans = asyncio.run(
            answer_ai_question(
                q,
                context=st.session_state.last_query,
                llm=st.session_state.last_llm)
        )
    # `ans or {}` keeps the fallback message reachable when nothing came back.
    st.session_state.followup_response = (
        (ans or {}).get("answer") or "LLM unavailable or quota exceeded."
    )


def _render_followup() -> None:
    """Draw the follow-up question box, its Ask button and the latest answer."""
    st.markdown("---")
    st.text_input("Ask follow-up question:",
                  key="followup_input",
                  placeholder="e.g. Any Phase III trials recruiting now?")
    st.button("Ask AI", on_click=_on_ask_followup)
    if st.session_state.followup_response:
        st.write(st.session_state.followup_response)
# ── entry-point ───────────────────────────────────────────────────────
# Build the page only when this file is executed as the main script,
# not when it is merely imported as a module.
if __name__ == "__main__":
    render_ui()