Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -1,8 +1,9 @@ | |
| 1 | 
             
            #!/usr/bin/env python3
         | 
| 2 | 
             
            # MedGenesis AI · CPU-only Streamlit app (OpenAI / Gemini)
         | 
| 3 |  | 
| 4 | 
            -
            # ── Streamlit telemetry dir fix ───────────────────────────────────────
         | 
| 5 | 
             
            import os, pathlib
         | 
|  | |
|  | |
| 6 | 
             
            os.environ["STREAMLIT_DATA_DIR"] = "/tmp/.streamlit"
         | 
| 7 | 
             
            os.environ["XDG_STATE_HOME"] = "/tmp"
         | 
| 8 | 
             
            os.environ["STREAMLIT_BROWSER_GATHERUSAGESTATS"] = "false"
         | 
| @@ -14,7 +15,7 @@ from pathlib import Path | |
| 14 | 
             
            import streamlit as st
         | 
| 15 | 
             
            import pandas as pd
         | 
| 16 | 
             
            import plotly.express as px
         | 
| 17 | 
            -
            from fpdf import FPDF | 
| 18 | 
             
            from streamlit_agraph import agraph
         | 
| 19 |  | 
| 20 | 
             
            # ── Internal helpers ────────────────────────────────────────────────
         | 
| @@ -29,7 +30,6 @@ LOGO = ROOT / "assets" / "logo.png" | |
| 29 |  | 
| 30 | 
             
            # ── PDF export helper (UTF-8 → Latin-1 “safe”) ──────────────────────
         | 
| 31 | 
             
            def _latin1_safe(txt: str) -> str:
         | 
| 32 | 
            -
                """Return text that FPDF(latin-1) can embed; replace unknown chars."""
         | 
| 33 | 
             
                return txt.encode("latin-1", "replace").decode("latin-1")
         | 
| 34 |  | 
| 35 | 
             
            def _pdf(papers):
         | 
| @@ -71,7 +71,6 @@ def render_ui(): | |
| 71 | 
             
                st.set_page_config("MedGenesis AI", layout="wide")
         | 
| 72 | 
             
                _workspace_sidebar()
         | 
| 73 |  | 
| 74 | 
            -
                # Header
         | 
| 75 | 
             
                c1, c2 = st.columns([0.15, 0.85])
         | 
| 76 | 
             
                with c1:
         | 
| 77 | 
             
                    if LOGO.exists():
         | 
| @@ -81,8 +80,7 @@ def render_ui(): | |
| 81 | 
             
                    st.caption("Multi-source biomedical assistant · OpenAI / Gemini")
         | 
| 82 |  | 
| 83 | 
             
                llm   = st.radio("LLM engine", ["openai", "gemini"], horizontal=True)
         | 
| 84 | 
            -
                query = st.text_input("Enter biomedical question",
         | 
| 85 | 
            -
                                      placeholder="e.g. CRISPR glioblastoma therapy")
         | 
| 86 |  | 
| 87 | 
             
                # Alert check
         | 
| 88 | 
             
                if get_workspace():
         | 
| @@ -102,10 +100,8 @@ def render_ui(): | |
| 102 | 
             
                        res = asyncio.run(orchestrate_search(query, llm=llm))
         | 
| 103 | 
             
                    st.success(f"Completed with **{res['llm_used'].title()}**")
         | 
| 104 |  | 
| 105 | 
            -
                    tabs = st.tabs(["Results", "Genes", "Trials", "Graph",
         | 
| 106 | 
            -
                                    "Metrics", "Visuals"])
         | 
| 107 |  | 
| 108 | 
            -
                    # Results
         | 
| 109 | 
             
                    with tabs[0]:
         | 
| 110 | 
             
                        for i, p in enumerate(res["papers"], 1):
         | 
| 111 | 
             
                            st.markdown(f"**{i}. [{p['title']}]({p['link']})**  *{p['authors']}*")
         | 
| @@ -113,12 +109,9 @@ def render_ui(): | |
| 113 |  | 
| 114 | 
             
                        col1, col2 = st.columns(2)
         | 
| 115 | 
             
                        with col1:
         | 
| 116 | 
            -
                            st.download_button("CSV",
         | 
| 117 | 
            -
                                               pd.DataFrame(res["papers"]).to_csv(index=False),
         | 
| 118 | 
            -
                                               "papers.csv", "text/csv")
         | 
| 119 | 
             
                        with col2:
         | 
| 120 | 
            -
                            st.download_button("PDF", _pdf(res["papers"]),
         | 
| 121 | 
            -
                                               "papers.pdf", "application/pdf")
         | 
| 122 |  | 
| 123 | 
             
                        if st.button("💾 Save"):
         | 
| 124 | 
             
                            save_query(query, res)
         | 
| @@ -136,12 +129,10 @@ def render_ui(): | |
| 136 | 
             
                        st.subheader("AI summary")
         | 
| 137 | 
             
                        st.info(res["ai_summary"])
         | 
| 138 |  | 
| 139 | 
            -
                    # Genes
         | 
| 140 | 
             
                    with tabs[1]:
         | 
| 141 | 
             
                        st.header("Gene / Variant signals")
         | 
| 142 | 
             
                        for g in res["genes"]:
         | 
| 143 | 
            -
                            st.write(f"- **{g.get('name', g.get('geneid'))}** "
         | 
| 144 | 
            -
                                     f"{g.get('description', '')}")
         | 
| 145 | 
             
                        if res["gene_disease"]:
         | 
| 146 | 
             
                            st.markdown("### DisGeNET links")
         | 
| 147 | 
             
                            st.json(res["gene_disease"][:15])
         | 
| @@ -151,21 +142,16 @@ def render_ui(): | |
| 151 | 
             
                                if d:
         | 
| 152 | 
             
                                    st.write("-", d)
         | 
| 153 |  | 
| 154 | 
            -
                    # Trials
         | 
| 155 | 
             
                    with tabs[2]:
         | 
| 156 | 
             
                        st.header("Clinical trials")
         | 
| 157 | 
             
                        if not res["clinical_trials"]:
         | 
| 158 | 
             
                            st.info("No trials (rate-limited or none found).")
         | 
| 159 | 
             
                        for t in res["clinical_trials"]:
         | 
| 160 | 
             
                            st.markdown(f"**{t['NCTId'][0]}** – {t['BriefTitle'][0]}")
         | 
| 161 | 
            -
                            st.write(f"Phase {t.get('Phase', | 
| 162 | 
            -
                                     f"| Status {t['OverallStatus'][0]}")
         | 
| 163 |  | 
| 164 | 
            -
                    # Graph
         | 
| 165 | 
             
                    with tabs[3]:
         | 
| 166 | 
            -
                        nodes, edges, cfg = build_agraph(res["papers"],
         | 
| 167 | 
            -
                                                         res["umls"],
         | 
| 168 | 
            -
                                                         res["drug_safety"])
         | 
| 169 | 
             
                        hl = st.text_input("Highlight node:", key="hl")
         | 
| 170 | 
             
                        if hl:
         | 
| 171 | 
             
                            pat = re.compile(re.escape(hl), re.I)
         | 
| @@ -173,31 +159,29 @@ def render_ui(): | |
| 173 | 
             
                                n.color = "#f1c40f" if pat.search(n.label) else "#d3d3d3"
         | 
| 174 | 
             
                        agraph(nodes, edges, cfg)
         | 
| 175 |  | 
| 176 | 
            -
                    # Metrics
         | 
| 177 | 
             
                    with tabs[4]:
         | 
| 178 | 
            -
                        G = build_nx([n.__dict__ for n in nodes],
         | 
| 179 | 
            -
                                     [e.__dict__ for e in edges])
         | 
| 180 | 
             
                        st.metric("Density", f"{get_density(G):.3f}")
         | 
| 181 | 
             
                        st.markdown("**Top hubs**")
         | 
| 182 | 
             
                        for nid, sc in get_top_hubs(G):
         | 
| 183 | 
             
                            lab = next((n.label for n in nodes if n.id == nid), nid)
         | 
| 184 | 
             
                            st.write(f"- {lab}  {sc:.3f}")
         | 
| 185 |  | 
| 186 | 
            -
                    # Visuals
         | 
| 187 | 
             
                    with tabs[5]:
         | 
| 188 | 
             
                        years = [p["published"] for p in res["papers"] if p.get("published")]
         | 
| 189 | 
             
                        if years:
         | 
| 190 | 
            -
                            st.plotly_chart(px.histogram(years, nbins=12,
         | 
| 191 | 
            -
                                                         title="Publication Year"))
         | 
| 192 |  | 
| 193 | 
            -
                    # Follow-up Q-A
         | 
| 194 | 
             
                    st.markdown("---")
         | 
| 195 | 
            -
                    follow = st.text_input("Ask follow-up:")
         | 
| 196 | 
             
                    if st.button("Ask AI"):
         | 
| 197 | 
            -
                         | 
| 198 | 
            -
             | 
| 199 | 
            -
             | 
| 200 | 
            -
             | 
|  | |
|  | |
| 201 |  | 
| 202 | 
             
                else:
         | 
| 203 | 
             
                    st.info("Enter a question and press **Run Search 🚀**")
         | 
|  | |
| 1 | 
             
            #!/usr/bin/env python3
         | 
| 2 | 
             
            # MedGenesis AI · CPU-only Streamlit app (OpenAI / Gemini)
         | 
| 3 |  | 
|  | |
| 4 | 
             
            import os, pathlib
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            # ── Streamlit telemetry dir fix ───────────────────────────────────────
         | 
| 7 | 
             
            os.environ["STREAMLIT_DATA_DIR"] = "/tmp/.streamlit"
         | 
| 8 | 
             
            os.environ["XDG_STATE_HOME"] = "/tmp"
         | 
| 9 | 
             
            os.environ["STREAMLIT_BROWSER_GATHERUSAGESTATS"] = "false"
         | 
|  | |
| 15 | 
             
            import streamlit as st
         | 
| 16 | 
             
            import pandas as pd
         | 
| 17 | 
             
            import plotly.express as px
         | 
| 18 | 
            +
            from fpdf import FPDF
         | 
| 19 | 
             
            from streamlit_agraph import agraph
         | 
| 20 |  | 
| 21 | 
             
            # ── Internal helpers ────────────────────────────────────────────────
         | 
|  | |
| 30 |  | 
| 31 | 
             
            # ── PDF export helper (UTF-8 → Latin-1 “safe”) ──────────────────────
         | 
| 32 | 
             
            def _latin1_safe(txt: str) -> str:
         | 
|  | |
| 33 | 
             
                return txt.encode("latin-1", "replace").decode("latin-1")
         | 
| 34 |  | 
| 35 | 
             
            def _pdf(papers):
         | 
|  | |
| 71 | 
             
                st.set_page_config("MedGenesis AI", layout="wide")
         | 
| 72 | 
             
                _workspace_sidebar()
         | 
| 73 |  | 
|  | |
| 74 | 
             
                c1, c2 = st.columns([0.15, 0.85])
         | 
| 75 | 
             
                with c1:
         | 
| 76 | 
             
                    if LOGO.exists():
         | 
|  | |
| 80 | 
             
                    st.caption("Multi-source biomedical assistant · OpenAI / Gemini")
         | 
| 81 |  | 
| 82 | 
             
                llm   = st.radio("LLM engine", ["openai", "gemini"], horizontal=True)
         | 
| 83 | 
            +
                query = st.text_input("Enter biomedical question", placeholder="e.g. CRISPR glioblastoma therapy")
         | 
|  | |
| 84 |  | 
| 85 | 
             
                # Alert check
         | 
| 86 | 
             
                if get_workspace():
         | 
|  | |
| 100 | 
             
                        res = asyncio.run(orchestrate_search(query, llm=llm))
         | 
| 101 | 
             
                    st.success(f"Completed with **{res['llm_used'].title()}**")
         | 
| 102 |  | 
| 103 | 
            +
                    tabs = st.tabs(["Results", "Genes", "Trials", "Graph", "Metrics", "Visuals"])
         | 
|  | |
| 104 |  | 
|  | |
| 105 | 
             
                    with tabs[0]:
         | 
| 106 | 
             
                        for i, p in enumerate(res["papers"], 1):
         | 
| 107 | 
             
                            st.markdown(f"**{i}. [{p['title']}]({p['link']})**  *{p['authors']}*")
         | 
|  | |
| 109 |  | 
| 110 | 
             
                        col1, col2 = st.columns(2)
         | 
| 111 | 
             
                        with col1:
         | 
| 112 | 
            +
                            st.download_button("CSV", pd.DataFrame(res["papers"]).to_csv(index=False), "papers.csv", "text/csv")
         | 
|  | |
|  | |
| 113 | 
             
                        with col2:
         | 
| 114 | 
            +
                            st.download_button("PDF", _pdf(res["papers"]), "papers.pdf", "application/pdf")
         | 
|  | |
| 115 |  | 
| 116 | 
             
                        if st.button("💾 Save"):
         | 
| 117 | 
             
                            save_query(query, res)
         | 
|  | |
| 129 | 
             
                        st.subheader("AI summary")
         | 
| 130 | 
             
                        st.info(res["ai_summary"])
         | 
| 131 |  | 
|  | |
| 132 | 
             
                    with tabs[1]:
         | 
| 133 | 
             
                        st.header("Gene / Variant signals")
         | 
| 134 | 
             
                        for g in res["genes"]:
         | 
| 135 | 
            +
                            st.write(f"- **{g.get('name', g.get('geneid'))}** {g.get('description', '')}")
         | 
|  | |
| 136 | 
             
                        if res["gene_disease"]:
         | 
| 137 | 
             
                            st.markdown("### DisGeNET links")
         | 
| 138 | 
             
                            st.json(res["gene_disease"][:15])
         | 
|  | |
| 142 | 
             
                                if d:
         | 
| 143 | 
             
                                    st.write("-", d)
         | 
| 144 |  | 
|  | |
| 145 | 
             
                    with tabs[2]:
         | 
| 146 | 
             
                        st.header("Clinical trials")
         | 
| 147 | 
             
                        if not res["clinical_trials"]:
         | 
| 148 | 
             
                            st.info("No trials (rate-limited or none found).")
         | 
| 149 | 
             
                        for t in res["clinical_trials"]:
         | 
| 150 | 
             
                            st.markdown(f"**{t['NCTId'][0]}** – {t['BriefTitle'][0]}")
         | 
| 151 | 
            +
                            st.write(f"Phase {t.get('Phase',[''])[0]} | Status {t['OverallStatus'][0]}")
         | 
|  | |
| 152 |  | 
|  | |
| 153 | 
             
                    with tabs[3]:
         | 
| 154 | 
            +
                        nodes, edges, cfg = build_agraph(res["papers"], res["umls"], res["drug_safety"])
         | 
|  | |
|  | |
| 155 | 
             
                        hl = st.text_input("Highlight node:", key="hl")
         | 
| 156 | 
             
                        if hl:
         | 
| 157 | 
             
                            pat = re.compile(re.escape(hl), re.I)
         | 
|  | |
| 159 | 
             
                                n.color = "#f1c40f" if pat.search(n.label) else "#d3d3d3"
         | 
| 160 | 
             
                        agraph(nodes, edges, cfg)
         | 
| 161 |  | 
|  | |
| 162 | 
             
                    with tabs[4]:
         | 
| 163 | 
            +
                        G = build_nx([n.__dict__ for n in nodes], [e.__dict__ for e in edges])
         | 
|  | |
| 164 | 
             
                        st.metric("Density", f"{get_density(G):.3f}")
         | 
| 165 | 
             
                        st.markdown("**Top hubs**")
         | 
| 166 | 
             
                        for nid, sc in get_top_hubs(G):
         | 
| 167 | 
             
                            lab = next((n.label for n in nodes if n.id == nid), nid)
         | 
| 168 | 
             
                            st.write(f"- {lab}  {sc:.3f}")
         | 
| 169 |  | 
|  | |
| 170 | 
             
                    with tabs[5]:
         | 
| 171 | 
             
                        years = [p["published"] for p in res["papers"] if p.get("published")]
         | 
| 172 | 
             
                        if years:
         | 
| 173 | 
            +
                            st.plotly_chart(px.histogram(years, nbins=12, title="Publication Year"))
         | 
|  | |
| 174 |  | 
| 175 | 
            +
                    # ── Follow-up Q-A (fixed) ───────────────────────────────────────
         | 
| 176 | 
             
                    st.markdown("---")
         | 
| 177 | 
            +
                    follow = st.text_input("Ask follow-up question:", key="followup_input")  # ✅ UPDATED
         | 
| 178 | 
             
                    if st.button("Ask AI"):
         | 
| 179 | 
            +
                        if follow.strip():  # ✅ UPDATED
         | 
| 180 | 
            +
                            with st.spinner("Generating AI response..."):
         | 
| 181 | 
            +
                                ans = asyncio.run(answer_ai_question(follow, context=query, llm=llm))
         | 
| 182 | 
            +
                            st.write(ans["answer"])
         | 
| 183 | 
            +
                        else:
         | 
| 184 | 
            +
                            st.warning("Please type a follow-up question before submitting.")  # ✅ UPDATED
         | 
| 185 |  | 
| 186 | 
             
                else:
         | 
| 187 | 
             
                    st.info("Enter a question and press **Run Search 🚀**")
         | 
