File size: 5,305 Bytes
8225d31
113401c
8225d31
113401c
1df70ca
 
 
113401c
 
 
1df70ca
 
113401c
 
 
8225d31
 
1df70ca
8225d31
1df70ca
8225d31
 
113401c
8225d31
1df70ca
113401c
8225d31
1df70ca
8225d31
1df70ca
 
 
718c260
113401c
8225d31
 
1df70ca
8225d31
113401c
8225d31
1df70ca
8225d31
113401c
8225d31
 
113401c
1df70ca
8225d31
 
 
1df70ca
8225d31
 
1df70ca
8225d31
 
1df70ca
 
 
113401c
8225d31
1df70ca
 
 
 
113401c
1df70ca
 
 
113401c
 
 
1df70ca
 
 
 
 
 
 
 
 
 
 
 
8225d31
113401c
8225d31
1df70ca
8225d31
1df70ca
8225d31
1df70ca
 
8225d31
1df70ca
8225d31
1df70ca
 
 
8225d31
1df70ca
113401c
 
8225d31
113401c
8225d31
 
 
113401c
8225d31
 
 
718c260
8225d31
1df70ca
 
113401c
8225d31
113401c
8225d31
 
1df70ca
8225d31
 
 
1df70ca
 
 
 
 
 
 
8225d31
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import streamlit as st
import os

from config import (
    OPENAI_API_KEY,
    GEMINI_API_KEY,
    DEFAULT_CHUNK_SIZE,
)
from models import configure_llms, openai_chat, gemini_chat
from pubmed_utils import (
    search_pubmed,
    fetch_pubmed_abstracts,
    chunk_and_summarize
)
from image_pipeline import load_image_model, analyze_image

###############################################################################
#                             STREAMLIT PAGE CONFIG                           #
###############################################################################
# Must be the first Streamlit command executed on each script run; sets the
# browser tab title and switches to the full-width page layout.
st.set_page_config(page_title="RAG + Image Captioning Demo", layout="wide")

###############################################################################
#                           INITIALIZE & LOAD MODELS                          #
###############################################################################
@st.cache_resource
def initialize_app():
    """
    One-time application setup, cached by Streamlit across script reruns.

    Configures the LLM clients (OpenAI/Gemini) and loads the image
    captioning model.

    Returns:
        The loaded image captioning model.
    """
    configure_llms()
    return load_image_model()

image_model = initialize_app()

###############################################################################
#                    HELPER: BUILD SYSTEM PROMPT WITH REFS                    #
###############################################################################
def build_system_prompt_with_refs(pmids, summaries):
    """
    Create a system prompt embedding each summarized article as a numbered
    reference ([Ref1], [Ref2], ...) that the LLM can cite in its answer.

    Args:
        pmids: Ordered sequence of PubMed IDs; reference numbering follows
            this order.
        summaries: Mapping of PMID -> summary text. Must contain an entry
            for every PMID in ``pmids`` (KeyError otherwise).

    Returns:
        str: The assembled system prompt.
    """
    # Collect pieces and join once — avoids quadratic string concatenation
    # as the number of references grows.
    parts = ["You have access to the following summarized PubMed articles:\n\n"]
    for idx, pmid in enumerate(pmids, start=1):
        parts.append(f"[Ref{idx}] (PMID {pmid}): {summaries[pmid]}\n\n")
    parts.append("Use this info to answer the user's question, citing references if needed.")
    return "".join(parts)

###############################################################################
#                                MAIN APP                                     #
###############################################################################
def main():
    """
    Render the Streamlit page.

    Two independent sections:
      A) optional image captioning of an uploaded file, and
      B) a PubMed retrieval-augmented-generation (RAG) pipeline: search ->
         fetch abstracts -> summarize -> LLM answer with [RefN] citations.
    """
    st.title("RAG + Image Captioning: Production Demo")

    st.markdown("""
    This demonstration shows:
    1. **PubMed RAG**: Retrieve abstracts, summarize, and feed them into an LLM.
    2. **Image Captioning**: Upload an image for analysis using a known stable model.
    """)

    # Section A: Image Upload / Caption
    st.subheader("Image Captioning")
    uploaded_img = st.file_uploader("Upload an image (optional)", type=["png", "jpg", "jpeg"])
    if uploaded_img:
        # Caption first, then show the image alongside the result.
        with st.spinner("Analyzing image..."):
            caption = analyze_image(uploaded_img, image_model)
        st.image(uploaded_img, use_column_width=True)
        st.write("**Caption**:", caption)
        st.write("---")

    # Section B: PubMed-based RAG
    st.subheader("PubMed RAG Pipeline")
    user_query = st.text_input("Enter a medical question:", "What are the latest treatments for type 2 diabetes?")
    
    # Layout: parameters get double width; LLM choice and temperature share
    # the remaining space.
    c1, c2, c3 = st.columns([2,1,1])
    with c1:
        st.markdown("**Parameters**:")
        max_papers = st.slider("Number of Articles", 1, 10, 3)
        chunk_size = st.slider("Chunk Size", 128, 1024, DEFAULT_CHUNK_SIZE)
    with c2:
        llm_choice = st.selectbox("Choose LLM", ["OpenAI: GPT-3.5", "Gemini: PaLM2"])
    with c3:
        temperature = st.slider("LLM Temperature", 0.0, 1.0, 0.3, step=0.1)

    # Streamlit reruns the whole script on every interaction; the pipeline
    # below only executes on the button-click rerun.
    if st.button("Run RAG Pipeline"):
        if not user_query.strip():
            st.warning("Please enter a query.")
            return
        
        with st.spinner("Searching PubMed..."):
            pmids = search_pubmed(user_query, max_papers)
        
        if not pmids:
            st.error("No PubMed results. Try a different query.")
            return
        
        with st.spinner("Fetching & Summarizing..."):
            abstracts_map = fetch_pubmed_abstracts(pmids)
            summarized_map = {}
            for pmid, text in abstracts_map.items():
                # NOTE(review): assumes fetch_pubmed_abstracts signals a
                # per-article failure with an "Error:"-prefixed string —
                # confirm against pubmed_utils. Failed fetches are shown
                # verbatim instead of being summarized.
                if text.startswith("Error:"):
                    summarized_map[pmid] = text
                else:
                    summarized_map[pmid] = chunk_and_summarize(text, chunk_size=chunk_size)

        # Show each summary under the same [RefN] label the LLM will cite.
        st.subheader("Retrieved & Summarized PubMed Articles")
        for idx, pmid in enumerate(pmids, start=1):
            st.markdown(f"**[Ref{idx}] PMID {pmid}**")
            st.write(summarized_map[pmid])
            st.write("---")

        st.subheader("RAG-Enhanced Final Answer")
        system_prompt = build_system_prompt_with_refs(pmids, summarized_map)
        with st.spinner("Generating LLM response..."):
            # Dispatch to the provider chosen in the selectbox above; both
            # backends take the same (system, user, temperature) signature.
            if llm_choice == "OpenAI: GPT-3.5":
                answer = openai_chat(system_prompt, user_query, temperature=temperature)
            else:
                answer = gemini_chat(system_prompt, user_query, temperature=temperature)

        st.write(answer)
        st.success("Pipeline Complete.")

    st.markdown("---")
    st.markdown("""
    **Production Tips**:
    - Vector DB for advanced retrieval
    - Precise citation parsing
    - Rate limiting on PubMed
    - Multi-lingual expansions
    - Logging & monitoring
    - Security & privacy compliance
    """)

# Script entry point; Streamlit executes this module top-to-bottom on each run.
if __name__ == "__main__":
    main()