mgbam commited on
Commit
1df70ca
·
verified ·
1 Parent(s): ca7c5dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -69
app.py CHANGED
@@ -2,145 +2,141 @@ import streamlit as st
2
  import os
3
 
4
  from config import (
5
- OPENAI_API_KEY,
6
- GEMINI_API_KEY,
7
- DEFAULT_CHUNK_SIZE
8
  )
9
  from models import configure_llms, openai_chat, gemini_chat
10
  from pubmed_utils import (
11
- search_pubmed,
12
- fetch_pubmed_abstracts,
13
  chunk_and_summarize
14
  )
15
  from image_pipeline import load_image_model, analyze_image
16
 
17
  ###############################################################################
18
- # PAGE CONFIG FIRST #
19
  ###############################################################################
20
- st.set_page_config(page_title="RAG + Image: Production Scenario", layout="wide")
21
 
22
  ###############################################################################
23
  # INITIALIZE & LOAD MODELS #
24
  ###############################################################################
25
-
26
@st.cache_resource
def initialize_app():
    """
    Configure the LLM clients and load the image model exactly once.

    The original docstring promised "Cache these calls for performance in
    HF Spaces" but applied no caching, so every Streamlit rerun reloaded the
    image model. ``@st.cache_resource`` makes the load happen once per
    process and reuses the same model object across reruns.

    Returns:
        The loaded image model, ready to pass to ``analyze_image``.
    """
    configure_llms()  # sets openai.api_key and genai.configure if keys are present
    image_model = load_image_model()
    return image_model
34
 
35
  image_model = initialize_app()
36
 
37
  ###############################################################################
38
- # HELPER: BUILD SYSTEM PROMPT WITH REFERENCES #
39
  ###############################################################################
40
def build_system_prompt_with_refs(pmids, summaries):
    """
    Assemble a system prompt that cites each article as [Ref1], [Ref2], ...

    Args:
        pmids: Ordered PubMed IDs; position determines the reference number.
        summaries: Mapping of PMID -> summary text; must contain every pmid.

    Returns:
        The complete system-prompt string, ending with citation instructions.
    """
    parts = ["You have access to the following summarized PubMed articles:\n\n"]
    for ref_num, pmid in enumerate(pmids, start=1):
        parts.append(f"[Ref{ref_num}] (PMID {pmid}): {summaries[pmid]}\n\n")
    # Closing instruction tells the LLM how to use the references above.
    parts.append(
        "Use this info to answer the user's question. Cite references as needed."
    )
    return "".join(parts)
52
 
53
  ###############################################################################
54
- # MAIN APP #
55
  ###############################################################################
56
def main():
    """Render the app: optional image analysis plus the PubMed RAG pipeline."""
    st.title("RAG + Image: Production-Ready Medical AI")

    st.markdown("""
    **Features**:
    1. *PubMed RAG Pipeline*: Search, fetch, summarize, then generate a final answer with LLM.
    2. *Optional Image Analysis*: Upload an image for a simple caption or interpretive text.
    3. *Separation of Concerns*: Each major function is in its own module for maintainability.

    **Disclaimer**: Not a substitute for professional medical advice.
    """)

    # --- Section A: image pipeline ------------------------------------------
    st.subheader("Image Analysis")
    image_file = st.file_uploader("Upload an image (optional)", type=["png", "jpg", "jpeg"])
    if image_file:
        with st.spinner("Analyzing image..."):
            model_caption = analyze_image(image_file, image_model)
        st.image(image_file, caption="Uploaded Image", use_column_width=True)
        st.write("**Model Output:**", model_caption)
        st.write("---")

    # --- Section B: PubMed-based RAG ----------------------------------------
    st.subheader("PubMed Retrieval & Summarization")
    user_query = st.text_input("Enter your medical question:", "What are the latest treatments for type 2 diabetes complications?")

    params_col, llm_col, temp_col = st.columns([2, 1, 1])
    with params_col:
        st.markdown("**Set Pipeline Params**")
        max_papers = st.slider("PubMed Articles to Retrieve", 1, 10, 3)
        chunk_size = st.slider("Summarization Chunk Size", 256, 1024, DEFAULT_CHUNK_SIZE)
    with llm_col:
        selected_llm = st.selectbox("Select LLM", ["OpenAI GPT-3.5", "Gemini PaLM2"])
    with temp_col:
        temperature = st.slider("LLM Temperature", 0.0, 1.0, 0.3, 0.1)

    if st.button("Run RAG Pipeline"):
        # Guard: nothing to search for on an all-whitespace query.
        if not user_query.strip():
            st.warning("Please enter a question.")
            return

        # 1) PubMed retrieval
        with st.spinner("Searching PubMed..."):
            pmids = search_pubmed(user_query, max_results=max_papers)

        if not pmids:
            st.error("No relevant results found. Try a different query.")
            return

        # 2) Fetch & summarize — fetch errors are passed through verbatim.
        with st.spinner("Fetching & Summarizing abstracts..."):
            abstract_map = fetch_pubmed_abstracts(pmids)
            summarized_map = {
                pmid: raw if raw.startswith("Error:")
                else chunk_and_summarize(raw, chunk_size=chunk_size)
                for pmid, raw in abstract_map.items()
            }

        # 3) Display summaries with their reference labels.
        st.subheader("Retrieved & Summarized PubMed Articles")
        for ref_idx, pmid in enumerate(pmids, start=1):
            st.markdown(f"**[Ref{ref_idx}] PMID {pmid}**")
            st.write(summarized_map[pmid])
            st.write("---")

        # 4) Final LLM answer, grounded in the references above.
        st.subheader("RAG-Enhanced Final Answer")
        system_prompt = build_system_prompt_with_refs(pmids, summarized_map)
        chat_by_name = {"OpenAI GPT-3.5": openai_chat}
        chat_fn = chat_by_name.get(selected_llm, gemini_chat)
        with st.spinner("Generating answer..."):
            answer = chat_fn(system_prompt, user_query, temperature=temperature)

        st.write(answer)
        st.success("RAG Pipeline Complete.")

    # Closing notes on hardening this demo for production.
    st.markdown("---")
    st.markdown("""
    ### Production Enhancements
    - **Vector Database** for advanced retrieval
    - **Citation Parsing** for accurate referencing
    - **Multi-Lingual** expansions
    - **Rate Limiting** for PubMed (max ~3 requests/sec)
    - **Robust Logging / Monitoring**
    - **Security & Privacy** if patient data is integrated
    """)
145
 
146
  if __name__ == "__main__":
 
2
  import os
3
 
4
  from config import (
5
+ OPENAI_API_KEY,
6
+ GEMINI_API_KEY,
7
+ DEFAULT_CHUNK_SIZE,
8
  )
9
  from models import configure_llms, openai_chat, gemini_chat
10
  from pubmed_utils import (
11
+ search_pubmed,
12
+ fetch_pubmed_abstracts,
13
  chunk_and_summarize
14
  )
15
  from image_pipeline import load_image_model, analyze_image
16
 
17
  ###############################################################################
18
+ # STREAMLIT PAGE CONFIG #
19
  ###############################################################################
20
+ st.set_page_config(page_title="RAG + Image Captioning Demo", layout="wide")
21
 
22
  ###############################################################################
23
  # INITIALIZE & LOAD MODELS #
24
  ###############################################################################
25
@st.cache_resource
def initialize_app():
    """
    One-time setup: wire up the LLM clients and load the captioning model.

    ``st.cache_resource`` ensures Streamlit reruns reuse the same model
    instance instead of reloading it each time.

    Returns:
        The loaded image-captioning model.
    """
    configure_llms()
    return load_image_model()
33
 
34
  image_model = initialize_app()
35
 
36
  ###############################################################################
37
+ # HELPER: BUILD SYSTEM PROMPT WITH REFS #
38
  ###############################################################################
39
def build_system_prompt_with_refs(pmids, summaries):
    """
    Build the LLM system prompt, labelling each summary [Ref1], [Ref2], ...

    Args:
        pmids: Ordered PubMed IDs; position determines the reference number.
        summaries: Mapping of PMID -> summary text; must contain every pmid.

    Returns:
        The assembled system-prompt string.
    """
    header = "You have access to the following summarized PubMed articles:\n\n"
    refs = "".join(
        f"[Ref{i}] (PMID {pmid}): {summaries[pmid]}\n\n"
        for i, pmid in enumerate(pmids, start=1)
    )
    # Trailing instruction telling the model how to use the references.
    footer = "Use this info to answer the user's question, citing references if needed."
    return header + refs + footer
49
 
50
  ###############################################################################
51
+ # MAIN APP #
52
  ###############################################################################
53
def main():
    """Render the demo UI: image captioning plus the PubMed RAG flow."""
    st.title("RAG + Image Captioning: Production Demo")

    st.markdown("""
    This demonstration shows:
    1. **PubMed RAG**: Retrieve abstracts, summarize, and feed them into an LLM.
    2. **Image Captioning**: Upload an image for analysis using a known stable model.
    3. **Separation of Concerns**: Each pipeline is in its own module.

    **Note**: If you previously encountered KeyError: 'idefics3', it's because the
    `SmolVLM-500M-Instruct` model was incompatible with your Transformers version.
    Here, we use a supported model such as `nlpconnect/vit-gpt2-image-captioning`.

    **Disclaimer**: This is a prototype, not medical advice.
    """)

    # --- Section A: image upload & caption ----------------------------------
    st.subheader("Image Captioning")
    image_file = st.file_uploader("Upload an image (optional)", type=["png", "jpg", "jpeg"])
    if image_file:
        with st.spinner("Analyzing image..."):
            generated_caption = analyze_image(image_file, image_model)
        st.image(image_file, use_column_width=True)
        st.write("**Caption**:", generated_caption)
        st.write("---")

    # --- Section B: PubMed-based RAG ----------------------------------------
    st.subheader("PubMed RAG Pipeline")
    user_query = st.text_input("Enter a medical question:", "What are the latest treatments for type 2 diabetes?")

    params_col, model_col, temp_col = st.columns([2, 1, 1])
    with params_col:
        st.markdown("**Parameters**:")
        max_papers = st.slider("Number of Articles", 1, 10, 3)
        chunk_size = st.slider("Chunk Size", 128, 1024, DEFAULT_CHUNK_SIZE)
    with model_col:
        llm_choice = st.selectbox("Choose LLM", ["OpenAI: GPT-3.5", "Gemini: PaLM2"])
    with temp_col:
        temperature = st.slider("LLM Temperature", 0.0, 1.0, 0.3, step=0.1)

    if st.button("Run RAG Pipeline"):
        # Guard: an all-whitespace query has nothing to search for.
        if not user_query.strip():
            st.warning("Please enter a query.")
            return

        with st.spinner("Searching PubMed..."):
            pmids = search_pubmed(user_query, max_papers)

        if not pmids:
            st.error("No PubMed results. Try a different query.")
            return

        # Fetch abstracts; pass fetch errors through verbatim, summarize the rest.
        with st.spinner("Fetching & Summarizing..."):
            abstracts_map = fetch_pubmed_abstracts(pmids)
            summarized_map = {
                pmid: raw if raw.startswith("Error:")
                else chunk_and_summarize(raw, chunk_size=chunk_size)
                for pmid, raw in abstracts_map.items()
            }

        # Show each summary under its reference label.
        st.subheader("Retrieved & Summarized PubMed Articles")
        for ref_idx, pmid in enumerate(pmids, start=1):
            st.markdown(f"**[Ref{ref_idx}] PMID {pmid}**")
            st.write(summarized_map[pmid])
            st.write("---")

        # Final answer grounded in the summarized references.
        st.subheader("RAG-Enhanced Final Answer")
        system_prompt = build_system_prompt_with_refs(pmids, summarized_map)
        chat_fn = openai_chat if llm_choice == "OpenAI: GPT-3.5" else gemini_chat
        with st.spinner("Generating LLM response..."):
            answer = chat_fn(system_prompt, user_query, temperature=temperature)

        st.write(answer)
        st.success("Pipeline Complete.")

    # Closing notes on hardening this demo for production.
    st.markdown("---")
    st.markdown("""
    **Production Tips**:
    - Vector DB for advanced retrieval
    - Precise citation parsing
    - Rate limiting on PubMed
    - Multi-lingual expansions
    - Logging & monitoring
    - Security & privacy compliance
    """)
141
 
142
  if __name__ == "__main__":