# Medapp / app.py — Streamlit RAG + image-analysis demo
import streamlit as st
import os
from config import (
OPENAI_API_KEY,
GEMINI_API_KEY,
DEFAULT_CHUNK_SIZE
)
from models import configure_llms, openai_chat, gemini_chat
from pubmed_utils import (
search_pubmed,
fetch_pubmed_abstracts,
chunk_and_summarize
)
from image_pipeline import load_image_model, analyze_image
###############################################################################
# PAGE CONFIG FIRST #
###############################################################################
st.set_page_config(page_title="RAG + Image: Production Scenario", layout="wide")
###############################################################################
# INITIALIZE & LOAD MODELS #
###############################################################################
@st.cache_resource
def initialize_app():
    """
    One-time application initialization.

    Configures the LLM clients (OpenAI / Gemini, when API keys are present)
    and loads the image-captioning model.

    The original docstring asked for these calls to be cached for
    performance on HF Spaces, but nothing was cached: Streamlit re-executes
    the whole script on every interaction, reloading the model each rerun.
    ``@st.cache_resource`` runs this once per server process and reuses the
    returned model afterwards.

    Returns:
        The loaded image model, later passed to ``analyze_image``.
    """
    configure_llms()  # sets openai.api_key and genai.configure if keys are present
    image_model = load_image_model()
    return image_model


# Loaded once at import; main() references this module-level handle.
image_model = initialize_app()
###############################################################################
# HELPER: BUILD SYSTEM PROMPT WITH REFERENCES #
###############################################################################
def build_system_prompt_with_refs(pmids, summaries):
    """
    Build a system prompt embedding numbered reference labels.

    Parameters
    ----------
    pmids : sequence
        PubMed IDs in presentation order; [RefN] numbering follows this order.
    summaries : mapping
        Maps each pmid to its summarized abstract text. Every pmid in
        ``pmids`` must be a key here (KeyError otherwise).

    Returns
    -------
    str
        Prompt listing each article as "[RefN] (PMID <pmid>): <summary>",
        followed by an instruction to cite references.
    """
    # Accumulate parts and join once instead of repeated `+=` string
    # concatenation (avoids quadratic behavior on many/long summaries).
    parts = ["You have access to the following summarized PubMed articles:\n\n"]
    for idx, pmid in enumerate(pmids, start=1):
        parts.append(f"[Ref{idx}] (PMID {pmid}): {summaries[pmid]}\n\n")
    parts.append(
        "Use this info to answer the user's question. Cite references as needed."
    )
    return "".join(parts)
###############################################################################
# MAIN APP #
###############################################################################
def main() -> None:
    """
    Render the Streamlit UI and drive the two pipelines:

    1. Optional image captioning via the module-level ``image_model``.
    2. PubMed retrieval -> summarization -> final LLM answer (RAG),
       gated behind the "Run RAG Pipeline" button.

    Returns early (no answer rendered) when the query is blank or when
    PubMed returns no hits.
    """
    st.title("RAG + Image: Production-Ready Medical AI")
    st.markdown("""
**Features**:
1. *PubMed RAG Pipeline*: Search, fetch, summarize, then generate a final answer with LLM.
2. *Optional Image Analysis*: Upload an image for a simple caption or interpretive text.
3. *Separation of Concerns*: Each major function is in its own module for maintainability.
**Disclaimer**: Not a substitute for professional medical advice.
""")
    # Section A: Image pipeline (independent of the RAG flow below)
    st.subheader("Image Analysis")
    uploaded_image = st.file_uploader("Upload an image (optional)", type=["png", "jpg", "jpeg"])
    if uploaded_image:
        with st.spinner("Analyzing image..."):
            # image_model is the module-level handle loaded by initialize_app()
            caption = analyze_image(uploaded_image, image_model)
        st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
        st.write("**Model Output:**", caption)
    st.write("---")
    # Section B: PubMed-based RAG
    st.subheader("PubMed Retrieval & Summarization")
    user_query = st.text_input("Enter your medical question:", "What are the latest treatments for type 2 diabetes complications?")
    # Pipeline parameters laid out across three columns
    col1, col2, col3 = st.columns([2, 1, 1])
    with col1:
        st.markdown("**Set Pipeline Params**")
        max_papers = st.slider("PubMed Articles to Retrieve", 1, 10, 3)
        chunk_size = st.slider("Summarization Chunk Size", 256, 1024, DEFAULT_CHUNK_SIZE)
    with col2:
        selected_llm = st.selectbox("Select LLM", ["OpenAI GPT-3.5", "Gemini PaLM2"])
    with col3:
        temperature = st.slider("LLM Temperature", 0.0, 1.0, 0.3, 0.1)
    if st.button("Run RAG Pipeline"):
        # Guard clause: nothing to search for
        if not user_query.strip():
            st.warning("Please enter a question.")
            return
        # 1) PubMed retrieval
        with st.spinner("Searching PubMed..."):
            pmids = search_pubmed(user_query, max_results=max_papers)
        if not pmids:
            st.error("No relevant results found. Try a different query.")
            return
        # 2) Fetch & Summarize each abstract
        with st.spinner("Fetching & Summarizing abstracts..."):
            abs_map = fetch_pubmed_abstracts(pmids)
            summarized_map = {}
            for pmid, text in abs_map.items():
                # Fetch failures are passed through verbatim rather than summarized
                if text.startswith("Error:"):
                    summarized_map[pmid] = text
                else:
                    summarized_map[pmid] = chunk_and_summarize(text, chunk_size=chunk_size)
        # 3) Display Summaries, numbered the same way as the prompt references
        st.subheader("Retrieved & Summarized PubMed Articles")
        for idx, pmid in enumerate(pmids, start=1):
            st.markdown(f"**[Ref{idx}] PMID {pmid}**")
            st.write(summarized_map[pmid])
            st.write("---")
        # 4) Final LLM Answer, grounded in the [RefN] summaries above
        st.subheader("RAG-Enhanced Final Answer")
        system_prompt = build_system_prompt_with_refs(pmids, summarized_map)
        with st.spinner("Generating answer..."):
            if selected_llm == "OpenAI GPT-3.5":
                answer = openai_chat(system_prompt, user_query, temperature=temperature)
            else:
                answer = gemini_chat(system_prompt, user_query, temperature=temperature)
            st.write(answer)
        st.success("RAG Pipeline Complete.")
    # Production tips (static footer)
    st.markdown("---")
    st.markdown("""
### Production Enhancements
- **Vector Database** for advanced retrieval
- **Citation Parsing** for accurate referencing
- **Multi-Lingual** expansions
- **Rate Limiting** for PubMed (max ~3 requests/sec)
- **Robust Logging / Monitoring**
- **Security & Privacy** if patient data is integrated
""")
# Script entry point (Streamlit executes this module top-to-bottom on each rerun).
if __name__ == "__main__":
    main()