|
import streamlit as st |
|
import os |
|
|
|
from config import ( |
|
OPENAI_API_KEY, |
|
GEMINI_API_KEY, |
|
DEFAULT_CHUNK_SIZE, |
|
) |
|
from models import configure_llms, openai_chat, gemini_chat |
|
from pubmed_utils import ( |
|
search_pubmed, |
|
fetch_pubmed_abstracts, |
|
chunk_and_summarize |
|
) |
|
from image_pipeline import load_image_model, analyze_image |
|
|
|
|
|
|
|
|
|
# Configure the browser tab title and enable the wide layout. Kept as the
# first Streamlit call in the script, before any other st.* usage.
st.set_page_config(page_title="RAG + Image Captioning Demo", layout="wide")
|
|
|
|
|
|
|
|
|
@st.cache_resource
def initialize_app():
    """Set up the LLM backends and load the image-captioning model.

    Decorated with ``st.cache_resource`` so this expensive setup runs once
    per server process instead of on every Streamlit script rerun.

    Returns:
        The loaded image-captioning model from ``load_image_model()``.
    """
    configure_llms()
    return load_image_model()
|
|
|
# Run the one-time setup at import time; reruns reuse the cached model
# because initialize_app is wrapped in st.cache_resource.
image_model = initialize_app()
|
|
|
|
|
|
|
|
|
def build_system_prompt_with_refs(pmids, summaries):
    """
    Build the LLM system prompt embedding numbered article references.

    Args:
        pmids: Ordered sequence of PubMed IDs; the order determines the
            [Ref1], [Ref2], ... labels.
        summaries: Mapping of pmid -> summarized abstract text. Must contain
            an entry for every pmid in ``pmids`` (raises KeyError otherwise).

    Returns:
        A single prompt string: a header line, one
        "[RefN] (PMID <id>): <summary>" paragraph per article, and a closing
        instruction telling the model to cite the references.
    """
    header = "You have access to the following summarized PubMed articles:\n\n"
    # str.join avoids the quadratic cost of repeated += string concatenation.
    refs = "".join(
        f"[Ref{idx}] (PMID {pmid}): {summaries[pmid]}\n\n"
        for idx, pmid in enumerate(pmids, start=1)
    )
    footer = "Use this info to answer the user's question, citing references if needed."
    return header + refs + footer
|
|
|
|
|
|
|
|
|
def main():
    """
    Render the Streamlit page: an image-captioning section followed by the
    PubMed RAG pipeline (search -> fetch -> summarize -> LLM answer).

    Widget declaration order is significant in Streamlit; the statement
    order below defines the page layout top-to-bottom.
    """
    st.title("RAG + Image Captioning: Production Demo")

    st.markdown("""
    This demonstration shows:
    1. **PubMed RAG**: Retrieve abstracts, summarize, and feed them into an LLM.
    2. **Image Captioning**: Upload an image for analysis using a known stable model.
    """)

    # --- Image captioning section (optional, independent of the RAG flow) ---
    st.subheader("Image Captioning")
    uploaded_img = st.file_uploader("Upload an image (optional)", type=["png", "jpg", "jpeg"])
    if uploaded_img:
        with st.spinner("Analyzing image..."):
            # image_model is the module-level cached model loaded at startup.
            caption = analyze_image(uploaded_img, image_model)
        # NOTE(review): use_column_width is deprecated in recent Streamlit
        # releases in favor of use_container_width — confirm installed version.
        st.image(uploaded_img, use_column_width=True)
        st.write("**Caption**:", caption)
        st.write("---")

    # --- PubMed RAG section ---
    st.subheader("PubMed RAG Pipeline")
    user_query = st.text_input("Enter a medical question:", "What are the latest treatments for type 2 diabetes?")

    # Three-column parameter row: retrieval params | model choice | temperature.
    c1, c2, c3 = st.columns([2, 1, 1])
    with c1:
        st.markdown("**Parameters**:")
        max_papers = st.slider("Number of Articles", 1, 10, 3)
        chunk_size = st.slider("Chunk Size", 128, 1024, DEFAULT_CHUNK_SIZE)
    with c2:
        llm_choice = st.selectbox("Choose LLM", ["OpenAI: GPT-3.5", "Gemini: PaLM2"])
    with c3:
        temperature = st.slider("LLM Temperature", 0.0, 1.0, 0.3, step=0.1)

    if st.button("Run RAG Pipeline"):
        # Guard clause: empty query means nothing to search for.
        if not user_query.strip():
            st.warning("Please enter a query.")
            return

        # Step 1: retrieve candidate article IDs.
        with st.spinner("Searching PubMed..."):
            pmids = search_pubmed(user_query, max_papers)

        if not pmids:
            st.error("No PubMed results. Try a different query.")
            return

        # Step 2: fetch abstracts and summarize each one.
        with st.spinner("Fetching & Summarizing..."):
            abstracts_map = fetch_pubmed_abstracts(pmids)
            summarized_map = {}
            for pmid, text in abstracts_map.items():
                # Fetch failures are surfaced as "Error: ..." strings by
                # fetch_pubmed_abstracts; pass them through unsummarized.
                if text.startswith("Error:"):
                    summarized_map[pmid] = text
                else:
                    summarized_map[pmid] = chunk_and_summarize(text, chunk_size=chunk_size)

        # Step 3: show each summary with its [RefN] label.
        # NOTE(review): assumes fetch_pubmed_abstracts returns an entry for
        # every pmid; a missing one would raise KeyError here — confirm.
        st.subheader("Retrieved & Summarized PubMed Articles")
        for idx, pmid in enumerate(pmids, start=1):
            st.markdown(f"**[Ref{idx}] PMID {pmid}**")
            st.write(summarized_map[pmid])
            st.write("---")

        # Step 4: feed the summaries to the chosen LLM as a system prompt.
        st.subheader("RAG-Enhanced Final Answer")
        system_prompt = build_system_prompt_with_refs(pmids, summarized_map)
        with st.spinner("Generating LLM response..."):
            if llm_choice == "OpenAI: GPT-3.5":
                answer = openai_chat(system_prompt, user_query, temperature=temperature)
            else:
                answer = gemini_chat(system_prompt, user_query, temperature=temperature)

        st.write(answer)
        st.success("Pipeline Complete.")

    st.markdown("---")
    st.markdown("""
    **Production Tips**:
    - Vector DB for advanced retrieval
    - Precise citation parsing
    - Rate limiting on PubMed
    - Multi-lingual expansions
    - Logging & monitoring
    - Security & privacy compliance
    """)
|
|
|
# Standard script entry guard: render the page only when executed directly
# (streamlit run app.py), not when imported as a module.
if __name__ == "__main__":
    main()
|
|