MCP_Res / app.py
mgbam's picture
Update app.py
c4bf66f verified
raw
history blame
3.59 kB
# app.py
import asyncio, re
from pathlib import Path
import streamlit as st
import pandas as pd
import plotly.express as px
from fpdf import FPDF
from streamlit_agraph import agraph
from mcp.orchestrator import orchestrate_search, answer_ai_question
from mcp.knowledge_graph import build_agraph
from mcp.graph_metrics import build_nx, get_top_hubs, get_density
st.set_page_config(page_title="MedGenesis AI", layout="wide")

# Seed the slot that holds the most recent search result. Streamlit re-runs
# this script on every interaction, so the result must live in session_state.
if "res" not in st.session_state:
    st.session_state.res = None

st.title("🧬 MedGenesis AI")
llm = st.radio("LLM engine", ["openai", "gemini"], horizontal=True)
query = st.text_input("Enter biomedical question")
# Typographic punctuation that often appears in paper metadata (and in the
# report title) but lies outside latin-1, the only charset FPDF's built-in
# core fonts such as Helvetica can encode. Mapped to ASCII look-alikes.
_PDF_CHAR_MAP = str.maketrans({
    "\u2013": "-",    # en dash
    "\u2014": "-",    # em dash
    "\u2018": "'",    # left single quote
    "\u2019": "'",    # right single quote
    "\u201c": '"',    # left double quote
    "\u201d": '"',    # right double quote
    "\u2026": "...",  # horizontal ellipsis
})


def _pdf_text(value):
    """Return *value* as a string that FPDF's latin-1 core fonts can encode.

    ``None`` becomes ``""``. A list/tuple is joined with ``", "`` (the shape of
    fields such as ``authors`` is not visible here -- assumed to be either a
    string or a sequence of names; TODO confirm against the orchestrator).
    Known typographic characters are mapped to ASCII; any other character
    outside latin-1 is replaced with ``"?"``.
    """
    if value is None:
        return ""
    if isinstance(value, (list, tuple)):
        value = ", ".join(str(item) for item in value)
    text = str(value).translate(_PDF_CHAR_MAP)
    return text.encode("latin-1", errors="replace").decode("latin-1")


def _make_pdf(papers):
    """Build a simple PDF report listing *papers* and return it as bytes.

    Args:
        papers: iterable of mappings. ``title``, ``authors``, ``summary`` and
            ``link`` are read with ``.get``, so missing keys render as blank.

    Returns:
        The PDF document as ``bytes``, ready for ``st.download_button``.

    Every string goes through ``_pdf_text`` first. Core fonts only encode
    latin-1, and an unencodable character (the en dash in the title, or Greek
    letters common in biomedical abstracts) may otherwise make FPDF raise
    while writing the document.
    """
    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Helvetica", size=12)
    pdf.cell(0, 10, _pdf_text("MedGenesis AI – Results"), ln=True, align="C")
    pdf.ln(5)
    for i, p in enumerate(papers, 1):
        pdf.set_font("Helvetica", "B", 11)
        pdf.multi_cell(0, 7, f"{i}. {_pdf_text(p.get('title', ''))}")
        pdf.set_font("Helvetica", size=9)
        body = "\n".join(
            _pdf_text(p.get(key, "")) for key in ("authors", "summary", "link")
        )
        pdf.multi_cell(0, 6, body)
        pdf.ln(3)
    out = pdf.output(dest="S")
    # PyFPDF 1.7.x returns a latin-1 ``str`` for dest="S"; fpdf2 returns a
    # ``bytearray`` (and deprecates ``dest``). Accept either.
    if isinstance(out, str):
        return out.encode("latin-1", errors="replace")
    return bytes(out)
# Run the search only when the button is pressed with a non-blank question.
# The result is stored in session_state so later reruns (e.g. typing in the
# highlight box) reuse it instead of searching again.
if st.button("Run Search 🚀") and query.strip():
    with st.spinner("Gathering data…"):
        st.session_state.res = asyncio.run(orchestrate_search(query.strip(), llm))

res = st.session_state.res
if not res:
    # Nothing to show yet: prompt the user and halt this script run here.
    st.info("Enter a query and press Run Search")
    st.stop()
# ── Results tab
tabs = st.tabs(["Results", "Graph", "Variants", "Trials", "Metrics", "Visuals"])
with tabs[0]:
    papers = res["papers"]
    for i, p in enumerate(papers, 1):
        # Read paper fields with .get (as _make_pdf and the Visuals tab do) so
        # one incomplete record does not abort rendering of the whole tab.
        title = p.get("title") or "Untitled"
        link = p.get("link")
        heading = f"[{title}]({link})" if link else title
        st.markdown(f"**{i}. {heading}**")
        st.write(p.get("summary") or "")
    c1, c2 = st.columns(2)
    c1.download_button(
        "CSV", pd.DataFrame(papers).to_csv(index=False), "papers.csv", "text/csv"
    )
    c2.download_button(
        "PDF", _make_pdf(papers), "papers.pdf", "application/pdf"
    )
    st.subheader("AI summary")
    st.info(res["ai_summary"])
# ── Graph tab
with tabs[1]:
    nodes, edges, cfg = build_agraph(
        res["papers"], res["umls"], res["drug_safety"], res["umls_relations"]
    )
    hl = st.text_input("Highlight node:", key="hl").strip()
    if hl:
        # Case-insensitive literal match: the user's text is escaped, so it is
        # never interpreted as a regular expression.
        pat = re.compile(re.escape(hl), re.I)
        for n in nodes:
            # Recolour matches only. Non-matching nodes are left untouched
            # rather than re-assigned their own colour, which would fail for a
            # node built without a ``color`` attribute (NOTE(review): assumes
            # agraph Node objects only carry the attributes passed to them).
            label = getattr(n, "label", "") or ""
            if pat.search(str(label)):
                n.color = "#f1c40f"
    agraph(nodes, edges, cfg)
# ── Variants tab
with tabs[2]:
    if res["variants"]:
        st.json(res["variants"])
    else:
        # Suggest well-known gene symbols when the search yielded no variants.
        st.warning("No variants found. Try ‘TP53’ or ‘BRCA1’.")
# ── Trials tab
with tabs[3]:
    trials = res["clinical_trials"]
    if not trials:
        # Empty result: hint at query types that tend to match trials.
        st.warning("No trials found. Try a disease or drug.")
    else:
        st.json(trials)
# ── Metrics tab
with tabs[4]:
    # build_nx receives the raw attribute dicts of the agraph node/edge objects
    # created in the Graph tab above.
    node_dicts = [node.__dict__ for node in nodes]
    edge_dicts = [edge.__dict__ for edge in edges]
    G = build_nx(node_dicts, edge_dicts)
    st.metric("Density", f"{get_density(G):.3f}")
    st.markdown("**Top hubs**")
    for hub_id, score in get_top_hubs(G):
        # Show the first node's label for this id, or the raw id if none match.
        label = next((node.label for node in nodes if node.id == hub_id), hub_id)
        st.write(f"- {label}: {score:.3f}")
# ── Visuals tab
with tabs[5]:
    # Keep only values whose first four characters form a year, as ints, so the
    # histogram bins them numerically. str() guards against non-string
    # ``published`` values (ints, dates), which cannot be sliced directly.
    yrs = []
    for p in res["papers"]:
        prefix = str(p.get("published") or "")[:4]
        if len(prefix) == 4 and prefix.isdecimal():
            yrs.append(int(prefix))
    if yrs:
        st.plotly_chart(
            px.histogram(
                x=yrs, nbins=10, title="Publication Year", labels={"x": "Year"}
            )
        )
    else:
        st.info("No publication dates available.")
# ── Follow-up QA
st.markdown("---")
q = st.text_input("Ask follow-up question:", key="followup_input")
if st.button("Ask AI"):
    question = q.strip()
    if not question:
        # Mirror the search button: never send a blank prompt to the LLM.
        st.warning("Type a follow-up question first.")
    else:
        with st.spinner("Querying LLM…"):
            # The current AI summary is passed as the context for the answer.
            ans = asyncio.run(answer_ai_question(
                question, context=res["ai_summary"], llm=llm))
        st.write(ans["answer"])