# Medapp / app.py
# NOTE(review): Hugging Face Hub page chrome (uploader name, commit hash,
# raw/history/blame links, file size) was pasted into the source and has been
# reduced to this comment header so the module parses as Python.
import streamlit as st
import os
from config import (
OPENAI_API_KEY,
GEMINI_API_KEY,
DEFAULT_CHUNK_SIZE,
)
from models import configure_llms, openai_chat, gemini_chat
from pubmed_utils import (
search_pubmed,
fetch_pubmed_abstracts,
chunk_and_summarize
)
from image_pipeline import load_image_model, analyze_image
###############################################################################
# STREAMLIT PAGE CONFIG #
###############################################################################
st.set_page_config(page_title="RAG + Image Captioning Demo", layout="wide")
###############################################################################
# INITIALIZE & LOAD MODELS #
###############################################################################
@st.cache_resource
def initialize_app():
    """Configure the LLM clients and load the image-captioning model once.

    Returns:
        The loaded image model. ``st.cache_resource`` guarantees this heavy
        setup executes a single time per Streamlit server process; reruns
        reuse the cached object.
    """
    configure_llms()  # wire up OpenAI / Gemini credentials from config
    return load_image_model()


# Resolved at import time; subsequent script reruns hit the resource cache.
image_model = initialize_app()
###############################################################################
# HELPER: BUILD SYSTEM PROMPT WITH REFS #
###############################################################################
def build_system_prompt_with_refs(pmids, summaries):
    """Build an LLM system prompt citing each PubMed summary as [Ref1], [Ref2], ...

    Args:
        pmids: Ordered iterable of PubMed IDs; order fixes the reference numbers.
        summaries: Mapping of PMID -> summary text (must contain every pmid).

    Returns:
        A single prompt string: header, one "[RefN] (PMID x): summary" entry
        per article (each followed by a blank line), then a closing instruction.
    """
    header = "You have access to the following summarized PubMed articles:\n\n"
    refs = "".join(
        f"[Ref{n}] (PMID {pmid}): {summaries[pmid]}\n\n"
        for n, pmid in enumerate(pmids, start=1)
    )
    footer = "Use this info to answer the user's question, citing references if needed."
    return header + refs + footer
###############################################################################
# MAIN APP #
###############################################################################
def main():
    """Render the Streamlit page: image captioning (Section A) and a
    PubMed retrieval-augmented-generation pipeline (Section B).

    Side effects only (UI rendering, network calls via project helpers);
    returns None. Indentation reconstructed from a whitespace-stripped
    paste — structure inferred from control-flow keywords.
    """
    st.title("RAG + Image Captioning: Production Demo")
    st.markdown("""
This demonstration shows:
1. **PubMed RAG**: Retrieve abstracts, summarize, and feed them into an LLM.
2. **Image Captioning**: Upload an image for analysis using a known stable model.
""")
    # --- Section A: Image Upload / Caption -----------------------------------
    st.subheader("Image Captioning")
    uploaded_img = st.file_uploader("Upload an image (optional)", type=["png", "jpg", "jpeg"])
    if uploaded_img:
        with st.spinner("Analyzing image..."):
            caption = analyze_image(uploaded_img, image_model)
        st.image(uploaded_img, use_column_width=True)
        st.write("**Caption**:", caption)
    st.write("---")
    # --- Section B: PubMed-based RAG -----------------------------------------
    st.subheader("PubMed RAG Pipeline")
    user_query = st.text_input("Enter a medical question:", "What are the latest treatments for type 2 diabetes?")
    # Three-column parameter row: retrieval params | LLM choice | temperature.
    c1, c2, c3 = st.columns([2,1,1])
    with c1:
        st.markdown("**Parameters**:")
        max_papers = st.slider("Number of Articles", 1, 10, 3)
        chunk_size = st.slider("Chunk Size", 128, 1024, DEFAULT_CHUNK_SIZE)
    with c2:
        llm_choice = st.selectbox("Choose LLM", ["OpenAI: GPT-3.5", "Gemini: PaLM2"])
    with c3:
        temperature = st.slider("LLM Temperature", 0.0, 1.0, 0.3, step=0.1)
    if st.button("Run RAG Pipeline"):
        # Guard: empty/whitespace query — warn and bail out of the whole render.
        if not user_query.strip():
            st.warning("Please enter a query.")
            return
        with st.spinner("Searching PubMed..."):
            pmids = search_pubmed(user_query, max_papers)
        if not pmids:
            st.error("No PubMed results. Try a different query.")
            return
        with st.spinner("Fetching & Summarizing..."):
            abstracts_map = fetch_pubmed_abstracts(pmids)
            summarized_map = {}
            for pmid, text in abstracts_map.items():
                # Fetch helpers signal failure in-band with an "Error:" prefix;
                # pass those through instead of summarizing them.
                if text.startswith("Error:"):
                    summarized_map[pmid] = text
                else:
                    summarized_map[pmid] = chunk_and_summarize(text, chunk_size=chunk_size)
        st.subheader("Retrieved & Summarized PubMed Articles")
        for idx, pmid in enumerate(pmids, start=1):
            st.markdown(f"**[Ref{idx}] PMID {pmid}**")
            st.write(summarized_map[pmid])
            # Separator after each article (placement inferred — TODO confirm
            # against the original indentation).
            st.write("---")
        st.subheader("RAG-Enhanced Final Answer")
        # Summaries are injected into the system prompt as [RefN] citations.
        system_prompt = build_system_prompt_with_refs(pmids, summarized_map)
        with st.spinner("Generating LLM response..."):
            if llm_choice == "OpenAI: GPT-3.5":
                answer = openai_chat(system_prompt, user_query, temperature=temperature)
            else:
                answer = gemini_chat(system_prompt, user_query, temperature=temperature)
        st.write(answer)
        st.success("Pipeline Complete.")
    # Footer shown on every render, regardless of pipeline state.
    st.markdown("---")
    st.markdown("""
**Production Tips**:
- Vector DB for advanced retrieval
- Precise citation parsing
- Rate limiting on PubMed
- Multi-lingual expansions
- Logging & monitoring
- Security & privacy compliance
""")
# Script entry point (Streamlit executes the module top-to-bottom on each rerun).
if __name__ == "__main__":
    main()