import streamlit as st
import os
from config import (
    OPENAI_API_KEY,
    GEMINI_API_KEY,
    DEFAULT_CHUNK_SIZE,
)
from models import configure_llms, openai_chat, gemini_chat
from pubmed_utils import (
    search_pubmed,
    fetch_pubmed_abstracts,
    chunk_and_summarize,
)
from image_pipeline import load_image_model, analyze_image
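# NOTE: Expected shapes of the project-local helpers, inferred from how this
# file calls them (verify against the actual modules):
#   search_pubmed(query, max_results)      -> list of PMID strings
#   fetch_pubmed_abstracts(pmids)          -> dict {pmid: abstract or "Error: ..."}
#   chunk_and_summarize(text, chunk_size)  -> summary string
#   openai_chat / gemini_chat(system_prompt, user_query, temperature) -> answer string
#   load_image_model() / analyze_image(uploaded_file, model) -> model / caption string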

###############################################################################
#                            STREAMLIT PAGE CONFIG                            #
###############################################################################
st.set_page_config(page_title="RAG + Image Captioning Demo", layout="wide")

###############################################################################
#                          INITIALIZE & LOAD MODELS                           #
###############################################################################
@st.cache_resource
def initialize_app():
    """
    Configure LLMs (OpenAI/Gemini) and load the image captioning model once.
    """
    configure_llms()
    model = load_image_model()
    return model


image_model = initialize_app()
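# @st.cache_resource keeps the loaded model in memory across Streamlit reruns
# and sessions, so the expensive initialization above runs once per server process.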

###############################################################################
#                   HELPER: BUILD SYSTEM PROMPT WITH REFS                     #
###############################################################################
def build_system_prompt_with_refs(pmids, summaries):
    """
    Creates a system prompt for the LLM that includes references [Ref1], [Ref2], etc.
    """
    system_context = "You have access to the following summarized PubMed articles:\n\n"
    for idx, pmid in enumerate(pmids, start=1):
        ref_label = f"[Ref{idx}]"
        system_context += f"{ref_label} (PMID {pmid}): {summaries[pmid]}\n\n"
    system_context += "Use this info to answer the user's question, citing references if needed."
    return system_context
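# Example with hypothetical values:
#   build_system_prompt_with_refs(
#       ["12345678"], {"12345678": "Metformin remains first-line therapy..."})
# produces:
#   You have access to the following summarized PubMed articles:
#
#   [Ref1] (PMID 12345678): Metformin remains first-line therapy...
#
#   Use this info to answer the user's question, citing references if needed.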

###############################################################################
#                                  MAIN APP                                   #
###############################################################################
def main():
    st.title("RAG + Image Captioning: Production Demo")
    st.markdown("""
This demonstration shows:
1. **PubMed RAG**: Retrieve abstracts, summarize, and feed them into an LLM.
2. **Image Captioning**: Upload an image for analysis using a known stable model.
""")

    # Section A: Image Upload / Caption
    st.subheader("Image Captioning")
    uploaded_img = st.file_uploader("Upload an image (optional)", type=["png", "jpg", "jpeg"])
    if uploaded_img:
        with st.spinner("Analyzing image..."):
            caption = analyze_image(uploaded_img, image_model)
        st.image(uploaded_img, use_column_width=True)
        st.write("**Caption**:", caption)

    st.write("---")

    # Section B: PubMed-based RAG
    st.subheader("PubMed RAG Pipeline")
    user_query = st.text_input("Enter a medical question:", "What are the latest treatments for type 2 diabetes?")

    c1, c2, c3 = st.columns([2, 1, 1])
    with c1:
        st.markdown("**Parameters**:")
        max_papers = st.slider("Number of Articles", 1, 10, 3)
        chunk_size = st.slider("Chunk Size", 128, 1024, DEFAULT_CHUNK_SIZE)
    with c2:
        llm_choice = st.selectbox("Choose LLM", ["OpenAI: GPT-3.5", "Google: Gemini"])
    with c3:
        temperature = st.slider("LLM Temperature", 0.0, 1.0, 0.3, step=0.1)
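    # Temperature around 0.3 biases the LLM toward consistent, factual phrasing;
    # values near 1.0 trade reproducibility for more varied output.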

    if st.button("Run RAG Pipeline"):
        if not user_query.strip():
            st.warning("Please enter a query.")
            return

        with st.spinner("Searching PubMed..."):
            pmids = search_pubmed(user_query, max_papers)
        if not pmids:
            st.error("No PubMed results. Try a different query.")
            return

        with st.spinner("Fetching & Summarizing..."):
            abstracts_map = fetch_pubmed_abstracts(pmids)
            summarized_map = {}
            for pmid, text in abstracts_map.items():
                if text.startswith("Error:"):
                    # Surface fetch errors as-is rather than summarizing them.
                    summarized_map[pmid] = text
                else:
                    summarized_map[pmid] = chunk_and_summarize(text, chunk_size=chunk_size)
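        # At this point summarized_map maps each PMID to either a summary or an
        # "Error: ..." string, e.g. (hypothetical): {"12345678": "Metformin remains..."}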
st.subheader("Retrieved & Summarized PubMed Articles")
for idx, pmid in enumerate(pmids, start=1):
st.markdown(f"**[Ref{idx}] PMID {pmid}**")
st.write(summarized_map[pmid])
st.write("---")
st.subheader("RAG-Enhanced Final Answer")
system_prompt = build_system_prompt_with_refs(pmids, summarized_map)
with st.spinner("Generating LLM response..."):
if llm_choice == "OpenAI: GPT-3.5":
answer = openai_chat(system_prompt, user_query, temperature=temperature)
else:
answer = gemini_chat(system_prompt, user_query, temperature=temperature)
st.write(answer)
st.success("Pipeline Complete.")
st.markdown("---")
st.markdown("""
**Production Tips**:
- Vector DB for advanced retrieval
- Precise citation parsing
- Rate limiting on PubMed
- Multi-lingual expansions
- Logging & monitoring
- Security & privacy compliance
""")

if __name__ == "__main__":
    main()