# MCP_Res / app.py
# mgbam's picture
# Update app.py
# d26962d verified
# raw
# history blame
# 6.03 kB
# app.py
import asyncio, os, re
from pathlib import Path
import streamlit as st
import pandas as pd
import plotly.express as px
from fpdf import FPDF
from streamlit_agraph import agraph
from mcp.orchestrator import orchestrate_search, answer_ai_question
from mcp.workspace import get_workspace, save_query
from mcp.knowledge_graph import build_agraph
# Directory containing this file; used as the base for bundled assets.
ROOT = Path(__file__).parent
# Path of the logo image; the header only shows it when the file exists.
LOGO = ROOT / "assets" / "logo.png"
# ------------------------- helpers -------------------------
def generate_pdf(papers):
    """Render the paper list as a PDF document and return it as bytes.

    Args:
        papers: Iterable of dicts. The ``title``, ``authors``, ``link`` and
            ``summary`` entries are printed; a missing or ``None`` entry is
            rendered as empty text instead of raising.

    Returns:
        bytes: The encoded PDF, suitable for ``st.download_button``.
    """

    def _latin1(text):
        # The built-in "Arial" core font can only encode Latin-1. Replace
        # anything else (Greek letters, typographic dashes, ...) with "?"
        # so one exotic character cannot abort the whole export.
        return str(text).encode("latin-1", "replace").decode("latin-1")

    def _field(paper, key):
        value = paper.get(key)
        return "" if value is None else value

    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Arial", size=12)
    pdf.cell(200, 10, "MedGenesis AI - Search Results", ln=True, align="C")
    pdf.ln(10)

    for i, p in enumerate(papers, 1):
        pdf.set_font("Arial", "B", 12)
        pdf.multi_cell(0, 10, _latin1(f"{i}. {_field(p, 'title')}"))
        pdf.set_font("Arial", "", 10)
        pdf.multi_cell(
            0,
            8,
            _latin1(
                f"Authors: {_field(p, 'authors')}\n"
                f"Link: {_field(p, 'link')}\n"
                f"Summary: {_field(p, 'summary')}\n"
            ),
        )
        pdf.ln(2)

    raw = pdf.output(dest="S")
    # Depending on the installed FPDF flavour, output() yields either a
    # Latin-1 str or a bytes-like object; normalise both to bytes.
    if isinstance(raw, str):
        return raw.encode("latin-1")
    return bytes(raw)
# ------------------------- UI -------------------------
# st.session_state slots that keep the last search alive across reruns.
_RESULTS_KEY = "search_results"
_QUERY_KEY = "search_query"


def _first(values, default="N/A"):
    """Return the first item of a list-valued field, or *default* if absent or empty."""
    if isinstance(values, (list, tuple)):
        return values[0] if values else default
    return default if values is None else values


def _render_sidebar():
    """Show saved workspace entries, each with its AI summary and a CSV export."""
    workspace = get_workspace()  # fetched once; reused for the empty check below
    with st.sidebar:
        st.header("πŸ—‚οΈ Workspace")
        for i, it in enumerate(workspace, 1):
            with st.expander(f"{i}. {it['query']}"):
                st.write("**AI Summary:**", it["result"]["ai_summary"])
                df = pd.DataFrame(it["result"]["papers"])
                st.download_button(
                    "CSV", df.to_csv(index=False), f"ws_{i}.csv", "text/csv",
                    key=f"ws_csv_{i}",
                )
        if not workspace:
            st.info("Run and save searches to populate workspace.")


def _render_header():
    """Draw the logo (when the file exists), the title and the tagline."""
    col1, col2 = st.columns([0.15, 0.85])
    with col1:
        if LOGO.exists():
            st.image(str(LOGO), width=100)
    with col2:
        st.markdown("## 🧬 MedGenesis AI")
        st.write("*Unified PubMed, ArXiv, OpenFDA, UMLS, NCBI, DisGeNET, ClinicalTrials & GPT-4o*")
    st.markdown("---")


def _render_results_tab(query, results):
    """Render papers, save/export controls, UMLS concepts, drug safety and the AI summary."""
    import html  # local import: only needed to escape summaries below

    st.header("πŸ“š Top Papers")
    for i, p in enumerate(results["papers"], 1):
        st.markdown(f"**{i}. [{p['title']}]({p['link']})** \n*{p['authors']}* ({p['source']})")
        # Summaries come from external sources and are injected into raw
        # HTML, so they must be escaped before rendering.
        summary = html.escape(str(p["summary"]))
        st.markdown(f"<span style='color:gray'>{summary}</span>", unsafe_allow_html=True)

    if st.button("Save to Workspace"):
        save_query(query, results)
        st.success("Saved!")

    df = pd.DataFrame(results["papers"])
    st.download_button("CSV", df.to_csv(index=False), "results.csv", "text/csv", key="results_csv")
    try:
        pdf_bytes = generate_pdf(results["papers"])
    except Exception as e:
        # A PDF rendering failure should not take the whole page down.
        st.warning(f"PDF export unavailable: {e}")
    else:
        st.download_button("PDF", pdf_bytes, "results.pdf", "application/pdf", key="results_pdf")

    st.subheader("🧠 UMLS Concepts")
    for c in results["umls"]:
        if c.get("cui"):
            st.markdown(f"- **{c['name']}** (`{c['cui']}`)")

    st.subheader("πŸ’Š Drug Safety (OpenFDA)")
    for d in results["drug_safety"]:
        st.json(d)

    st.subheader("πŸ€– AI Summary")
    st.info(results["ai_summary"])


def _render_genes_tab(results):
    """Render gene associations, gene-disease links and MeSH definitions."""
    st.header("🧬 Gene Associations (NCBI / DisGeNET)")
    for g in results["genes"]:
        st.write(f"- **{g.get('name', g.get('geneid'))}** – {g.get('description','')}")
    if results["gene_disease"]:
        st.markdown("#### DisGeNET Disease β†’ Gene links")
        st.json(results["gene_disease"][:15])
    if results["mesh_definitions"]:
        st.markdown("#### MeSH Definitions")
        for d in results["mesh_definitions"]:
            if d:
                st.write(f"- {d}")


def _render_trials_tab(results):
    """Render clinical trials, tolerating missing or empty list-valued fields."""
    st.header("πŸ’Š Registered Clinical Trials")
    for t in results["clinical_trials"]:
        st.markdown(f"**{_first(t.get('NCTId'))}** – {_first(t.get('BriefTitle'))}")
        conditions = ", ".join(t.get("Condition") or [])
        st.write(
            f"Condition: {conditions} | Phase: {_first(t.get('Phase'))} "
            f"| Status: {_first(t.get('OverallStatus'))}"
        )


def _render_graph_tab(results):
    """Render the knowledge graph, optionally highlighting nodes matching a search term."""
    search_term = st.text_input("Highlight node containing:", key="graphsearch")
    try:
        nodes, edges, cfg = build_agraph(results["papers"], results["umls"], results["drug_safety"])
        if search_term:
            pat = re.compile(re.escape(search_term), re.I)
            for n in nodes:
                if pat.search(n.label):
                    n.color, n.size = "#f1c40f", max(n.size, 30)
                else:
                    n.color = "#ddd"
        agraph(nodes=nodes, edges=edges, config=cfg)
    except Exception as e:
        st.error(f"Graph error: {e}")


def _render_visuals_tab(results):
    """Plot a histogram of the papers' ``published`` values, when any are present."""
    yrs = [p["published"] for p in results["papers"] if p.get("published")]
    if yrs:
        st.plotly_chart(px.histogram(yrs, nbins=10, title="Publication Year"))


def render_ui():
    """Build the whole Streamlit page: sidebar, header, search, result tabs and Q&A.

    Streamlit re-executes the script on every widget interaction, and a
    button only reports ``True`` during the run triggered by its own click.
    The last search is therefore kept in ``st.session_state`` so that the
    nested controls ("Save to Workspace", the graph filter, "Ask AI") still
    find the results on the rerun they cause.
    """
    # NOTE(review): several labels below contain mojibake where an emoji was
    # presumably intended; they are left untouched here — confirm and repair.
    st.set_page_config(page_title="MedGenesis AI", layout="wide")

    _render_sidebar()
    _render_header()

    query = st.text_input("πŸ” Biomedical research question:", placeholder="e.g. CRISPR glioblastoma")

    # ------------- Search -------------
    if st.button("Run Search πŸš€") and query:
        with st.spinner("Collecting literature & biomedical data…"):
            results = asyncio.run(orchestrate_search(query))
        st.session_state[_RESULTS_KEY] = results
        st.session_state[_QUERY_KEY] = query
        st.success("Completed!")

    results = st.session_state.get(_RESULTS_KEY)
    if results is None:
        st.info("Enter a question and press **Run Search πŸš€**")
        return

    # Use the query that produced the stored results, not whatever is
    # currently typed in the input box.
    searched_query = st.session_state.get(_QUERY_KEY, query)

    # -------- Tabs ---------
    tabs = st.tabs([
        "πŸ“ Results", "🧬 Genes & Variants", "πŸ’Š Clinical Trials",
        "πŸ—ΊοΈ Graph", "πŸ“Š Visuals"
    ])
    with tabs[0]:
        _render_results_tab(searched_query, results)
    with tabs[1]:
        _render_genes_tab(results)
    with tabs[2]:
        _render_trials_tab(results)
    with tabs[3]:
        _render_graph_tab(results)
    with tabs[4]:
        _render_visuals_tab(results)

    # -- Follow-up Q&A ----------
    st.markdown("---")
    q = st.text_input("Ask follow-up:", key="follow")
    # Skip the model call when the question is empty.
    if st.button("Ask AI") and q.strip():
        ans = asyncio.run(answer_ai_question(q, context=searched_query))
        st.write(ans["answer"])
# ------------- Run -------------
# Build the page when this file is executed as a script.
if __name__ == "__main__":
    render_ui()