mgbam committed on
Commit
fd2b30f
·
verified ·
1 Parent(s): 5090b92

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -104
app.py CHANGED
@@ -3,134 +3,126 @@ import os
3
 
4
  from config import (
5
  OPENAI_API_KEY,
6
- GEMINI_API_KEY,
7
- DEFAULT_CHUNK_SIZE,
8
  )
9
- from models import configure_llms, openai_chat, gemini_chat
10
- from pubmed_utils import (
11
- search_pubmed,
12
- fetch_pubmed_abstracts,
13
- chunk_and_summarize
14
  )
15
- from image_pipeline import load_image_model, analyze_image
 
16
 
17
  ###############################################################################
18
- # STREAMLIT PAGE CONFIG #
19
  ###############################################################################
20
- st.set_page_config(page_title="RAG + Image Captioning Demo", layout="wide")
21
 
22
- ###############################################################################
23
- # INITIALIZE & LOAD MODELS #
24
- ###############################################################################
25
@st.cache_resource
def initialize_app():
    """Run the one-time startup work and hand back the image model.

    Calls ``configure_llms()`` and then ``load_image_model()``. The function
    is wrapped in ``st.cache_resource`` so the returned model object is
    reused instead of being rebuilt on every call.

    Returns:
        Whatever ``load_image_model()`` returns (the image captioning model).
    """
    # LLM configuration happens first, exactly as before; only the temporary
    # variable for the model was dropped.
    configure_llms()
    return load_image_model()
33
 
34
- image_model = initialize_app()
 
 
 
 
 
 
 
 
35
 
36
- ###############################################################################
37
- # HELPER: BUILD SYSTEM PROMPT WITH REFS #
38
- ###############################################################################
39
def build_system_prompt_with_refs(pmids, summaries):
    """Build the LLM system prompt listing each article as [Ref1], [Ref2], ...

    Args:
        pmids: Ordered iterable of PubMed IDs; the position (starting at 1)
            determines the ``[RefN]`` label.
        summaries: Mapping of PMID -> summary text.

    Returns:
        A single string: an introductory line, one ``[RefN] (PMID x): text``
        paragraph per PMID, and a closing instruction to cite references.

    A PMID that has no entry in ``summaries`` no longer raises ``KeyError``;
    it is listed with a "No summary available." placeholder so reference
    numbering stays aligned with ``pmids``.
    """
    parts = ["You have access to the following summarized PubMed articles:\n\n"]
    for idx, pmid in enumerate(pmids, start=1):
        summary = summaries.get(pmid, "No summary available.")
        parts.append(f"[Ref{idx}] (PMID {pmid}): {summary}\n\n")
    parts.append("Use this info to answer the user's question, citing references if needed.")
    # Join once instead of growing the string with += inside the loop.
    return "".join(parts)
49
 
50
- ###############################################################################
51
- # MAIN APP #
52
- ###############################################################################
53
def _run_rag_pipeline(user_query, max_papers, chunk_size, llm_choice, temperature):
    """Search PubMed, summarize the hits, and display an LLM answer.

    Returns early (after showing a warning or error) when the query is blank
    or the search yields nothing.
    """
    if not user_query.strip():
        st.warning("Please enter a query.")
        return

    with st.spinner("Searching PubMed..."):
        pmids = search_pubmed(user_query, max_papers)

    if not pmids:
        st.error("No PubMed results. Try a different query.")
        return

    with st.spinner("Fetching & Summarizing..."):
        abstracts_map = fetch_pubmed_abstracts(pmids)
        summarized_map = {}
        for pmid, text in abstracts_map.items():
            # A missing abstract is treated like a fetch error instead of
            # crashing on ``None.startswith``.
            if not text:
                summarized_map[pmid] = "Error: no abstract returned."
            elif text.startswith("Error:"):
                summarized_map[pmid] = text
            else:
                summarized_map[pmid] = chunk_and_summarize(text, chunk_size=chunk_size)
        # Guarantee an entry for every searched PMID so the lookups below
        # (and in the prompt builder) cannot raise KeyError.
        for pmid in pmids:
            summarized_map.setdefault(pmid, "Error: no abstract returned.")

    st.subheader("Retrieved & Summarized PubMed Articles")
    for idx, pmid in enumerate(pmids, start=1):
        st.markdown(f"**[Ref{idx}] PMID {pmid}**")
        st.write(summarized_map[pmid])
        st.write("---")

    st.subheader("RAG-Enhanced Final Answer")
    system_prompt = build_system_prompt_with_refs(pmids, summarized_map)
    with st.spinner("Generating LLM response..."):
        if llm_choice == "OpenAI: GPT-3.5":
            answer = openai_chat(system_prompt, user_query, temperature=temperature)
        else:
            answer = gemini_chat(system_prompt, user_query, temperature=temperature)

    st.write(answer)
    st.success("Pipeline Complete.")


def main():
    """Render the demo page: image captioning, PubMed RAG pipeline, and tips.

    The RAG pipeline lives in ``_run_rag_pipeline`` so that its early exits
    (blank query, no search results) no longer return from ``main`` itself
    and skip the "Production Tips" footer.
    """
    st.title("RAG + Image Captioning: Production Demo")

    st.markdown("""
    This demonstration shows:
    1. **PubMed RAG**: Retrieve abstracts, summarize, and feed them into an LLM.
    2. **Image Captioning**: Upload an image for analysis using a known stable model.
    """)

    # Section A: Image Upload / Caption
    st.subheader("Image Captioning")
    uploaded_img = st.file_uploader("Upload an image (optional)", type=["png", "jpg", "jpeg"])
    if uploaded_img:
        with st.spinner("Analyzing image..."):
            caption = analyze_image(uploaded_img, image_model)
        st.image(uploaded_img, use_column_width=True)
        st.write("**Caption**:", caption)
        st.write("---")

    # Section B: PubMed-based RAG
    st.subheader("PubMed RAG Pipeline")
    user_query = st.text_input("Enter a medical question:", "What are the latest treatments for type 2 diabetes?")

    c1, c2, c3 = st.columns([2,1,1])
    with c1:
        st.markdown("**Parameters**:")
        max_papers = st.slider("Number of Articles", 1, 10, 3)
        chunk_size = st.slider("Chunk Size", 128, 1024, DEFAULT_CHUNK_SIZE)
    with c2:
        llm_choice = st.selectbox("Choose LLM", ["OpenAI: GPT-3.5", "Gemini: PaLM2"])
    with c3:
        temperature = st.slider("LLM Temperature", 0.0, 1.0, 0.3, step=0.1)

    if st.button("Run RAG Pipeline"):
        _run_rag_pipeline(user_query, max_papers, chunk_size, llm_choice, temperature)

    # Footer is rendered on every run, including after a pipeline early exit.
    st.markdown("---")
    st.markdown("""
    **Production Tips**:
    - Vector DB for advanced retrieval
    - Precise citation parsing
    - Rate limiting on PubMed
    - Multi-lingual expansions
    - Logging & monitoring
    - Security & privacy compliance
    """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  if __name__ == "__main__":
136
  main()
 
3
 
4
  from config import (
5
  OPENAI_API_KEY,
6
+ OPENAI_DEFAULT_MODEL,
7
+ MAX_PUBMED_RESULTS
8
  )
9
+ from pubmed_rag import (
10
+ search_pubmed, fetch_pubmed_abstracts, chunk_and_summarize,
11
+ upsert_documents, semantic_search
 
 
12
  )
13
+ from models import chat_with_openai
14
+ from image_pipeline import analyze_medical_image
15
 
16
  ###############################################################################
17
+ # STREAMLIT SETUP #
18
  ###############################################################################
19
+ st.set_page_config(page_title="Advanced Medical AI", layout="wide")
20
 
21
def main():
    """Render the landing page and run the task picked in the sidebar.

    Shows the title and feature overview, offers three tasks in a sidebar
    select box, calls the page function for the selected task, and closes
    with a horizontal rule and the medical disclaimer.
    """
    st.title("Advanced Medical AI: Multi-Modal RAG & Image Diagnostics")

    st.markdown("""
    **Features**:
    1. **PubMed RAG**: Retrieve and summarize medical literature, store in a vector DB,
    and use advanced semantic search for context.
    2. **LLM Q&A**: Leverage OpenAI for final question-answering with RAG context.
    3. **Medical Image Analysis**: Use `HuggingFaceTB/SmolVLM-500M-Instruct` for diagnostic insights.
    4. **Multi-Lingual & Extended Triage**: Placeholder expansions for real-time translation or advanced triage logic.
    5. **Production-Ready**: Modular, concurrent, disclaimers, and synergy across tasks.
    """)

    # Menu label -> page function; insertion order is the order shown in the
    # sidebar select box.
    pages = {
        "PubMed RAG Q&A": pubmed_rag_qna,
        "Medical Image Analysis": medical_image_analysis,
        "Semantic Search (Vector DB)": vector_db_search_ui,
    }
    selected = st.sidebar.selectbox("Select Task", list(pages))

    # Any label other than the first two falls through to the vector DB
    # search page, as in the if/elif/else form.
    render_page = pages.get(selected, vector_db_search_ui)
    render_page()

    st.markdown("---")
    st.markdown("""
    **Disclaimer**: This is an **advanced demonstration** for educational or research purposes only.
    Always consult a qualified healthcare professional for personal medical decisions.
    """)
49
 
50
def _is_fetch_error(text):
    """Return True when *text* is missing or is an "Error..." placeholder.

    NOTE(review): relies on ``fetch_pubmed_abstracts`` reporting failures as
    strings that start with "Error" -- confirm against that helper.
    """
    return not text or text.startswith("Error")


def _build_reference_prompt(summaries, skip=()):
    """Assemble the system prompt from ``pmid -> summary`` entries.

    Entries whose PMID is in *skip* are left out, but numbering still counts
    them so ``[RefN]`` labels match the labels shown to the user.
    """
    parts = ["You are an advanced medical assistant with the following references:\n"]
    for i, (pmid, summary) in enumerate(summaries.items(), start=1):
        if pmid in skip:
            continue
        parts.append(f"[Ref{i}] PMID {pmid}: {summary}\n")
    parts.append("\nUsing these references, provide an evidence-based answer.")
    return "".join(parts)


def pubmed_rag_qna():
    """Streamlit page: PubMed search, summarization, vector DB upsert, LLM answer.

    Flow after the "Search & Summarize" button is pressed:
      1. Reject a blank query.
      2. Search PubMed for up to the selected number of articles.
      3. Fetch abstracts and summarize each one.
      4. Show every summary (including failed fetches) as ``[RefN]``.
      5. Upsert the usable summaries into the vector DB.
      6. Ask the OpenAI chat helper for an answer grounded in those summaries.

    Abstracts that failed to download are displayed but are no longer stored
    in the vector DB or passed to the LLM as references; previously the error
    text itself was indexed and cited.
    """
    st.subheader("PubMed Retrieval-Augmented Q&A")
    query = st.text_area(
        "Ask a medical question (e.g., 'What are the latest treatments for type 2 diabetes?'):",
        height=100
    )
    max_art = st.slider("Number of PubMed Articles to Retrieve", 1, 10, 5)

    if not st.button("Search & Summarize"):
        return

    if not query.strip():
        st.warning("Please enter a query.")
        return

    with st.spinner("Searching PubMed..."):
        pmids = search_pubmed(query, max_art)

    if not pmids:
        st.error("No articles found. Try another query.")
        return

    with st.spinner("Fetching and Summarizing..."):
        raw_abstracts = fetch_pubmed_abstracts(pmids)
        summarized = {}
        failed = set()
        for pmid, text in raw_abstracts.items():
            if _is_fetch_error(text):
                failed.add(pmid)
                # Keep the original error text for display; substitute a
                # placeholder only when nothing was returned at all.
                summarized[pmid] = text or "Error: no abstract returned."
            else:
                summarized[pmid] = chunk_and_summarize(text)

    st.subheader("Summaries")
    for i, (pmid, summary) in enumerate(summarized.items(), start=1):
        st.markdown(f"**[Ref{i}] PMID {pmid}**")
        st.write(summary)

    usable = {pmid: summary for pmid, summary in summarized.items() if pmid not in failed}
    if not usable:
        # Nothing trustworthy to index or cite, so stop before calling the LLM.
        st.error("None of the retrieved articles could be summarized. Try another query.")
        return

    # Upsert only real summaries so error placeholders never become
    # searchable documents.
    upsert_documents(usable)

    system_prompt = _build_reference_prompt(summarized, skip=failed)

    with st.spinner("Generating final answer..."):
        final_answer = chat_with_openai(system_prompt, query)
    st.subheader("Final Answer")
    st.write(final_answer)
99
 
100
def medical_image_analysis():
    """Streamlit page: upload a medical image and display the model's analysis.

    The uploaded image is previewed immediately. The analysis only runs once
    the "Analyze Image" button is pressed, and its result is written under a
    "Diagnostic Insight" heading.
    """
    st.subheader("Medical Image Analysis")
    image_file = st.file_uploader("Upload a Medical Image (PNG/JPG)", type=["png", "jpg", "jpeg"])

    # Nothing to preview or analyze until a file has been provided.
    if image_file is None:
        return

    st.image(image_file, caption="Uploaded Image", use_column_width=True)

    if not st.button("Analyze Image"):
        return

    with st.spinner("Analyzing..."):
        insight = analyze_medical_image(image_file)
    st.subheader("Diagnostic Insight")
    st.write(insight)
110
+
111
def vector_db_search_ui():
    """Streamlit page: run a semantic search against the vector DB.

    Reads a free-text query and a result count, calls ``semantic_search`` when
    the "Search" button is pressed, and lists each hit with its PMID, score
    and text. Each hit is read as a mapping with ``pmid``, ``score`` and
    ``text`` keys.

    An empty result set now produces an explicit message instead of a bare
    "Search Results" heading with nothing beneath it.
    """
    st.subheader("Semantic Search in Vector DB")
    user_query = st.text_input("Enter a query to find relevant documents", "")
    top_k = st.slider("Number of results", 1, 10, 3)

    if not st.button("Search"):
        return

    if not user_query.strip():
        st.warning("Please enter a query.")
        return

    with st.spinner("Performing semantic search..."):
        results = semantic_search(user_query, top_k=top_k)

    st.subheader("Search Results")
    if not results:
        # The vector DB is only filled by the PubMed RAG page's upsert step,
        # so an empty store is the likely cause on a fresh session.
        st.info("No matching documents found. Run a PubMed RAG query first or try a different search.")
        return

    for i, doc in enumerate(results, start=1):
        st.markdown(f"**Result {i}** - PMID {doc['pmid']} (Distance: {doc['score']:.4f})")
        st.write(doc["text"])
        st.write("---")
126
 
127
  if __name__ == "__main__":
128
  main()