# MCP_Res / app.py (by mgbam)
# Commit 0f74db4 "Update app.py" (verified), 6.66 kB
# app.py • MedGenesis AI – CPU-only powerhouse
import asyncio, os, re, httpx
from pathlib import Path
import streamlit as st
import pandas as pd
import plotly.express as px
from fpdf import FPDF
from streamlit_agraph import agraph
from mcp.orchestrator import orchestrate_search, answer_ai_question
from mcp.workspace import get_workspace, save_query
from mcp.knowledge_graph import build_agraph
from mcp.graph_metrics import build_nx, get_top_hubs, get_density
from mcp.alerts import check_alerts
# Directory holding this script; bundled assets are resolved relative to it.
ROOT = Path(__file__).parent
# Optional header logo; the UI only displays it when the file exists.
LOGO = ROOT.joinpath("assets", "logo.png")
# ---------- utilities ----------
# Typographic characters common in paper metadata that FPDF's built-in
# (Latin-1) fonts cannot encode, mapped to plain-ASCII look-alikes.
_PDF_ASCII_MAP = str.maketrans({
    "\u2013": "-",    # en dash
    "\u2014": "-",    # em dash
    "\u2018": "'",    # left single quote
    "\u2019": "'",    # right single quote
    "\u201c": '"',    # left double quote
    "\u201d": '"',    # right double quote
    "\u2022": "*",    # bullet
    "\u2026": "...",  # ellipsis
})


def _pdf_text(value):
    """Coerce *value* to text that FPDF's Latin-1 core fonts can encode.

    ``None`` becomes ``""``; a list/tuple (e.g. if authors arrive as a list)
    is joined with ", "; common typographic punctuation is mapped to ASCII;
    any remaining non-Latin-1 character is replaced with "?" instead of
    raising an encoding error.
    """
    if value is None:
        return ""
    if isinstance(value, (list, tuple)):
        value = ", ".join(str(v) for v in value)
    text = str(value).translate(_PDF_ASCII_MAP)
    return text.encode("latin-1", "replace").decode("latin-1")


def gen_pdf(papers):
    """Render *papers* as a simple PDF report and return the PDF bytes.

    Args:
        papers: iterable of mappings; the ``title``, ``authors``,
            ``summary`` and ``link`` keys are printed when present
            (missing keys print as blank text).

    Returns:
        The finished PDF document as ``bytes``.

    All text goes through ``_pdf_text`` because the built-in "Arial" font
    only covers Latin-1; without it the en dash in the title (and any
    Greek letter or curly quote in a paper) makes the export fail.
    """
    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Arial", size=12)
    pdf.cell(200, 10, _pdf_text("MedGenesis AI – Results"), ln=True, align="C")
    pdf.ln(10)
    for i, p in enumerate(papers, 1):
        pdf.set_font("Arial", "B", 12)
        pdf.multi_cell(0, 10, f"{i}. {_pdf_text(p.get('title'))}")
        pdf.set_font("Arial", "", 9)
        pdf.multi_cell(
            0,
            7,
            f"Authors: {_pdf_text(p.get('authors'))}\n"
            f"{_pdf_text(p.get('summary'))}\n"
            f"{_pdf_text(p.get('link'))}\n",
        )
        pdf.ln(2)
    out = pdf.output(dest="S")
    # Legacy PyFPDF returns a Latin-1 ``str`` here; fpdf2 returns a
    # bytes-like object, which has no ``.encode``.
    if isinstance(out, str):
        return out.encode("latin-1")
    return bytes(out)
# ---------- UI ----------
# st.session_state keys holding the most recent search, so results survive
# the reruns that every widget interaction triggers.
_QUERY_KEY = "medgenesis_last_query"
_RESULT_KEY = "medgenesis_last_result"


def _render_alerts(workspace):
    """Show a sidebar notice listing saved queries that have new papers."""
    saved_qs = [w["query"] for w in workspace]
    if not saved_qs:
        return
    try:
        news = asyncio.run(check_alerts(saved_qs))
    except Exception as e:  # best-effort: an alert failure must not break the page
        st.sidebar.error(f"Alert check error: {e}")
        return
    if news:
        with st.sidebar:
            st.subheader("πŸ”” New Papers")
            for q, links in news.items():
                st.write(f"**{q}** – {len(links)} new")


def _render_workspace_sidebar(workspace):
    """List saved queries in the sidebar with their AI summary and a CSV export."""
    with st.sidebar:
        st.header("πŸ—‚οΈ Workspace")
        for i, itm in enumerate(workspace, 1):
            with st.expander(f"{i}. {itm['query']}"):
                st.write("AI summary:", itm["result"]["ai_summary"])
                st.download_button(
                    "CSV",
                    pd.DataFrame(itm["result"]["papers"]).to_csv(index=False),
                    f"ws_{i}.csv",
                    "text/csv",
                    key=f"ws_csv_{i}",
                )
        if not workspace:
            st.info("No saved queries.")


def _render_header():
    """Draw the logo (if present), title, data-source caption and a divider."""
    col1, col2 = st.columns([0.15, 0.85])
    with col1:
        if LOGO.exists():
            st.image(str(LOGO), width=100)
    with col2:
        st.markdown("## 🧬 **MedGenesis AI**")
        st.caption("PubMed β€’ ArXiv β€’ OpenFDA β€’ UMLS β€’ NCBI β€’ DisGeNET β€’ ClinicalTrials β€’ GPT-4o")
    st.markdown("---")


def _render_papers_tab(query, res):
    """Papers list, save/export actions, UMLS concepts, drug safety and AI summary."""
    import html  # abstracts come from external APIs and are shown as raw HTML

    st.header("πŸ“š Top Papers")
    for i, p in enumerate(res["papers"], 1):
        st.markdown(f"**{i}. [{p['title']}]({p['link']})** – *{p['authors']}*")
        # Escape before unsafe_allow_html so third-party text cannot inject markup.
        summary = html.escape(str(p["summary"]))
        st.markdown(f"<span style='color:gray'>{summary}</span>", unsafe_allow_html=True)
    if st.button("Save Query"):
        save_query(query, res)
        st.success("Saved to workspace")
    csv = pd.DataFrame(res["papers"]).to_csv(index=False)
    st.download_button("CSV", csv, "papers.csv", "text/csv")
    st.download_button("PDF", gen_pdf(res["papers"]), "papers.pdf", "application/pdf")
    st.subheader("🧠 Key UMLS Concepts")
    for c in res["umls"]:
        if c.get("cui"):
            st.write(f"- **{c['name']}** ({c['cui']})")
    st.subheader("πŸ’Š Drug Safety (OpenFDA)")
    for d in res["drug_safety"]:
        st.json(d)
    st.subheader("πŸ€– AI Synthesis")
    st.info(res["ai_summary"])


def _render_genes_tab(res):
    """Gene list, first 15 DisGeNET links and MeSH definitions."""
    st.header("🧬 Gene & Variant Signals")
    for g in res["genes"]:
        st.write(f"- **{g.get('name', g.get('geneid'))}** – {g.get('description','')}")
    if res["gene_disease"]:
        st.write("### DisGeNET Links")
        st.json(res["gene_disease"][:15])
    if res["mesh_defs"]:
        st.write("### MeSH Definitions")
        for d in res["mesh_defs"]:
            st.write("-", d)


def _first(record, field, default=""):
    """Return the first element of a list-valued trial field, or *default*.

    Guards against the field being absent, ``None`` or an empty list, any
    of which made direct ``record[field][0]`` indexing raise.
    """
    values = record.get(field) or [default]
    return values[0]


def _render_trials_tab(res):
    """Registered clinical trials with ID, title, phase and status."""
    st.header("πŸ’Š Registered Clinical Trials")
    trials = res["clinical_trials"]
    if not trials:
        st.info("No trials (API rate-limited or none found).")
    for t in trials:
        st.markdown(f"**{_first(t, 'NCTId')}** – {_first(t, 'BriefTitle')}")
        st.write(f"Phase: {_first(t, 'Phase')} | Status: {_first(t, 'OverallStatus')}")


def _render_graph_tab(res):
    """Draw the knowledge graph, recolouring nodes by the highlight filter.

    Returns the (possibly recoloured) node and edge lists so the Metrics
    tab can reuse them.
    """
    st.header("πŸ—ΊοΈ Knowledge Graph")
    nodes, edges, cfg = build_agraph(res["papers"], res["umls"], res["drug_safety"])
    highlight = st.text_input("Highlight nodes:", key="hl")
    if highlight:
        pat = re.compile(re.escape(highlight), re.I)
        for n in nodes:
            if pat.search(n.label):
                n.color, n.size = "#f1c40f", 30
            else:
                n.color = "#d3d3d3"
    agraph(nodes=nodes, edges=edges, config=cfg)
    return nodes, edges


def _render_metrics_tab(nodes, edges):
    """Graph density and top hub nodes, labelled via the node list."""
    st.header("πŸ“ˆ Graph Metrics")
    G = build_nx([n.__dict__ for n in nodes], [e.__dict__ for e in edges])
    st.metric("Density", f"{get_density(G):.3f}")
    st.markdown("#### Hub Nodes")
    # id -> label map built once; setdefault keeps the first node per id,
    # matching a first-match linear search.
    labels = {}
    for n in nodes:
        labels.setdefault(n.id, n.label)
    for nid, sc in get_top_hubs(G):
        st.write(f"- **{labels.get(nid, nid)}** – {sc:.3f}")


def _render_visuals_tab(res):
    """Histogram of the papers' ``published`` values, when any are present."""
    yrs = [p["published"] for p in res["papers"] if p.get("published")]
    if yrs:
        st.plotly_chart(px.histogram(yrs, nbins=10, title="Publication Year"))


def _render_results(query, res):
    """Render the result tabs and follow-up Q&A for one stored search."""
    tabs = st.tabs([
        "Results", "Genes", "Trials", "Graph", "Metrics", "Visuals"
    ])
    with tabs[0]:
        _render_papers_tab(query, res)
    with tabs[1]:
        _render_genes_tab(res)
    with tabs[2]:
        _render_trials_tab(res)
    with tabs[3]:
        nodes, edges = _render_graph_tab(res)
    with tabs[4]:
        _render_metrics_tab(nodes, edges)
    with tabs[5]:
        _render_visuals_tab(res)
    st.markdown("---")
    q = st.text_input("Ask follow-up question:")
    if st.button("Ask AI") and q:
        st.write(asyncio.run(answer_ai_question(q, context=query))["answer"])


def render_ui():
    """Build the whole MedGenesis AI Streamlit page.

    Streamlit re-executes the script on every widget interaction, and
    ``st.button`` is only True during the rerun caused by its own click.
    The latest search is therefore stored in ``st.session_state`` and the
    result view is drawn from there on every run. Otherwise the nested
    widgets (Save Query, downloads, highlight box, Ask AI) would make the
    results disappear and never perform their own action.
    """
    st.set_page_config(page_title="MedGenesis AI", layout="wide")
    workspace = get_workspace()
    _render_alerts(workspace)
    _render_workspace_sidebar(workspace)
    _render_header()

    query = st.text_input("πŸ” Ask a biomedical research question:",
                          placeholder="e.g. CRISPR glioblastoma treatment")
    if st.button("Run Search πŸš€") and query:
        with st.spinner("Crunching literature & biomedical databases…"):
            res = asyncio.run(orchestrate_search(query))
        # Keep the query alongside its result so "Save Query" and follow-up
        # questions use the text that produced these results, even if the
        # input box has since been edited.
        st.session_state[_RESULT_KEY] = res
        st.session_state[_QUERY_KEY] = query
        st.success("Done!")

    res = st.session_state.get(_RESULT_KEY)
    if res is None:
        st.info("Enter a question and press **Run Search πŸš€**")
        return
    _render_results(st.session_state[_QUERY_KEY], res)
# Build the page only when this file is executed as the main script,
# not when it is imported as a module.
if __name__ == "__main__":
    render_ui()