import streamlit as st
import os
from config import (
    OPENAI_API_KEY,
    GEMINI_API_KEY,
    DEFAULT_CHUNK_SIZE,
)
from models import configure_llms, openai_chat, gemini_chat
from pubmed_utils import (
    search_pubmed,
    fetch_pubmed_abstracts,
    chunk_and_summarize,
)
from image_pipeline import load_image_model, analyze_image

###############################################################################
#                          STREAMLIT PAGE CONFIG                              #
###############################################################################
st.set_page_config(page_title="RAG + Image Captioning Demo", layout="wide")

###############################################################################
#                        INITIALIZE & LOAD MODELS                             #
###############################################################################
@st.cache_resource
def initialize_app():
    """
    Configure LLMs (OpenAI/Gemini) and load the image captioning model once.

    Returns:
        The loaded image-captioning model. Decorated with
        ``st.cache_resource`` so configuration and model loading run only
        once per server process, not on every Streamlit rerun.
    """
    configure_llms()
    model = load_image_model()
    return model


image_model = initialize_app()

###############################################################################
#                   HELPER: BUILD SYSTEM PROMPT WITH REFS                     #
###############################################################################
def build_system_prompt_with_refs(pmids, summaries):
    """
    Creates a system prompt for the LLM that includes references [Ref1],
    [Ref2], etc.

    Args:
        pmids: Ordered iterable of PubMed IDs; the order fixes the [RefN]
            labels.
        summaries: Mapping of PMID -> summarized abstract text.

    Returns:
        A single prompt string listing each labeled summary, followed by an
        instruction to cite the references when answering.
    """
    # Accumulate parts and join once instead of quadratic string +=.
    parts = ["You have access to the following summarized PubMed articles:\n\n"]
    for idx, pmid in enumerate(pmids, start=1):
        ref_label = f"[Ref{idx}]"
        # .get guards against a PMID whose fetch/summarize step produced no
        # entry, which would otherwise raise KeyError here.
        summary = summaries.get(pmid, "No summary available.")
        parts.append(f"{ref_label} (PMID {pmid}): {summary}\n\n")
    parts.append("Use this info to answer the user's question, citing references if needed.")
    return "".join(parts)

###############################################################################
#                                 MAIN APP                                    #
###############################################################################
def main():
    """Render the Streamlit UI: image captioning plus a PubMed RAG pipeline."""
    st.title("RAG + Image Captioning: Production Demo")
    st.markdown("""
    This demonstration shows:
    1. **PubMed RAG**: Retrieve abstracts, summarize, and feed them into an LLM.
    2. **Image Captioning**: Upload an image for analysis using a known stable model.
    """)

    # Section A: Image Upload / Caption
    st.subheader("Image Captioning")
    uploaded_img = st.file_uploader("Upload an image (optional)", type=["png", "jpg", "jpeg"])
    if uploaded_img:
        with st.spinner("Analyzing image..."):
            caption = analyze_image(uploaded_img, image_model)
        st.image(uploaded_img, use_column_width=True)
        st.write("**Caption**:", caption)

    st.write("---")

    # Section B: PubMed-based RAG
    st.subheader("PubMed RAG Pipeline")
    user_query = st.text_input("Enter a medical question:", "What are the latest treatments for type 2 diabetes?")

    c1, c2, c3 = st.columns([2, 1, 1])
    with c1:
        st.markdown("**Parameters**:")
        max_papers = st.slider("Number of Articles", 1, 10, 3)
        chunk_size = st.slider("Chunk Size", 128, 1024, DEFAULT_CHUNK_SIZE)
    with c2:
        llm_choice = st.selectbox("Choose LLM", ["OpenAI: GPT-3.5", "Gemini: PaLM2"])
    with c3:
        temperature = st.slider("LLM Temperature", 0.0, 1.0, 0.3, step=0.1)

    if st.button("Run RAG Pipeline"):
        # Guard clause: nothing to search for.
        if not user_query.strip():
            st.warning("Please enter a query.")
            return

        with st.spinner("Searching PubMed..."):
            pmids = search_pubmed(user_query, max_papers)
        if not pmids:
            st.error("No PubMed results. Try a different query.")
            return

        with st.spinner("Fetching & Summarizing..."):
            abstracts_map = fetch_pubmed_abstracts(pmids)
            summarized_map = {}
            for pmid, text in abstracts_map.items():
                # Fetch errors are passed through verbatim rather than being
                # fed to the summarizer.
                if text.startswith("Error:"):
                    summarized_map[pmid] = text
                else:
                    summarized_map[pmid] = chunk_and_summarize(text, chunk_size=chunk_size)

        st.subheader("Retrieved & Summarized PubMed Articles")
        for idx, pmid in enumerate(pmids, start=1):
            st.markdown(f"**[Ref{idx}] PMID {pmid}**")
            # Robustness fix: a searched PMID may be absent from the fetch
            # results; direct indexing would raise KeyError.
            st.write(summarized_map.get(pmid, "No summary available."))
            st.write("---")

        st.subheader("RAG-Enhanced Final Answer")
        system_prompt = build_system_prompt_with_refs(pmids, summarized_map)
        with st.spinner("Generating LLM response..."):
            if llm_choice == "OpenAI: GPT-3.5":
                answer = openai_chat(system_prompt, user_query, temperature=temperature)
            else:
                answer = gemini_chat(system_prompt, user_query, temperature=temperature)
        st.write(answer)
        st.success("Pipeline Complete.")

    st.markdown("---")
    st.markdown("""
    **Production Tips**:
    - Vector DB for advanced retrieval
    - Precise citation parsing
    - Rate limiting on PubMed
    - Multi-lingual expansions
    - Logging & monitoring
    - Security & privacy compliance
    """)


if __name__ == "__main__":
    main()