|
import streamlit as st |
|
import os |
|
|
|
from config import ( |
|
OPENAI_API_KEY, |
|
GEMINI_API_KEY, |
|
DEFAULT_CHUNK_SIZE |
|
) |
|
from models import configure_llms, openai_chat, gemini_chat |
|
from pubmed_utils import ( |
|
search_pubmed, |
|
fetch_pubmed_abstracts, |
|
chunk_and_summarize |
|
) |
|
from image_pipeline import load_image_model, analyze_image |
|
|
|
|
|
|
|
|
|
# Must be the first Streamlit command executed in the script: sets the
# browser-tab title and enables full-width layout for the columns below.
st.set_page_config(page_title="RAG + Image: Production Scenario", layout="wide")
|
|
|
|
|
|
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def initialize_app():
    """
    Configure the LLM backends and load the image-captioning model.

    Decorated with ``st.cache_resource`` so the expensive model load and
    LLM configuration run once per server process rather than on every
    Streamlit rerun — the original docstring promised caching "for
    performance in HF Spaces" but never applied it.

    Returns:
        The loaded image model produced by ``load_image_model()``.
    """
    configure_llms()
    image_model = load_image_model()
    return image_model


# Module-level singleton: resolved once (cached) and reused by main().
image_model = initialize_app()
|
|
|
|
|
|
|
|
|
def build_system_prompt_with_refs(pmids, summaries):
    """
    Assemble a system prompt embedding every article summary under a
    numbered reference tag ([Ref1], [Ref2], ...) so the LLM can cite them.

    Args:
        pmids: Ordered PubMed IDs; position determines the reference number.
        summaries: Mapping of PMID -> summary text.

    Returns:
        The complete system-prompt string.
    """
    header = "You have access to the following summarized PubMed articles:\n\n"
    # One "[RefN] (PMID ...): summary" paragraph per article, in pmid order.
    body = "".join(
        f"[Ref{n}] (PMID {pmid}): {summaries[pmid]}\n\n"
        for n, pmid in enumerate(pmids, start=1)
    )
    footer = "Use this info to answer the user's question. Cite references as needed."
    return header + body + footer
|
|
|
|
|
|
|
|
|
def main():
    """
    Render the Streamlit UI and drive the full workflow:
    optional image captioning, then PubMed search -> fetch -> summarize,
    and finally an LLM answer grounded in the summarized abstracts.
    """
    st.title("RAG + Image: Production-Ready Medical AI")

    st.markdown("""
    **Features**:
    1. *PubMed RAG Pipeline*: Search, fetch, summarize, then generate a final answer with LLM.
    2. *Optional Image Analysis*: Upload an image for a simple caption or interpretive text.
    3. *Separation of Concerns*: Each major function is in its own module for maintainability.

    **Disclaimer**: Not a substitute for professional medical advice.
    """)

    # --- Optional image-captioning section (independent of the RAG flow) ---
    st.subheader("Image Analysis")
    uploaded_image = st.file_uploader("Upload an image (optional)", type=["png", "jpg", "jpeg"])
    if uploaded_image:
        # NOTE(review): re-runs on every Streamlit rerun while a file is
        # attached; consider caching by file hash if analyze_image is slow.
        with st.spinner("Analyzing image..."):
            caption = analyze_image(uploaded_image, image_model)
        st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
        st.write("**Model Output:**", caption)
        st.write("---")

    # --- PubMed RAG pipeline: query + tuning controls ---
    st.subheader("PubMed Retrieval & Summarization")
    user_query = st.text_input("Enter your medical question:", "What are the latest treatments for type 2 diabetes complications?")

    col1, col2, col3 = st.columns([2, 1, 1])
    with col1:
        st.markdown("**Set Pipeline Params**")
        max_papers = st.slider("PubMed Articles to Retrieve", 1, 10, 3)
        chunk_size = st.slider("Summarization Chunk Size", 256, 1024, DEFAULT_CHUNK_SIZE)
    with col2:
        selected_llm = st.selectbox("Select LLM", ["OpenAI GPT-3.5", "Gemini PaLM2"])
    with col3:
        temperature = st.slider("LLM Temperature", 0.0, 1.0, 0.3, 0.1)

    if st.button("Run RAG Pipeline"):
        # Guard clause: nothing to do for a blank query.
        if not user_query.strip():
            st.warning("Please enter a question.")
            return

        # Step 1: retrieve candidate article IDs.
        with st.spinner("Searching PubMed..."):
            pmids = search_pubmed(user_query, max_results=max_papers)

        if not pmids:
            st.error("No relevant results found. Try a different query.")
            return

        # Step 2: fetch abstracts, then summarize each one. Fetch failures
        # arrive as "Error:"-prefixed text and are passed through unsummarized.
        with st.spinner("Fetching & Summarizing abstracts..."):
            abs_map = fetch_pubmed_abstracts(pmids)
            summarized_map = {}
            for pmid, text in abs_map.items():
                if text.startswith("Error:"):
                    summarized_map[pmid] = text
                else:
                    summarized_map[pmid] = chunk_and_summarize(text, chunk_size=chunk_size)

        # Step 3: display each summary under its [RefN] label — numbering
        # matches the system prompt built by build_system_prompt_with_refs.
        st.subheader("Retrieved & Summarized PubMed Articles")
        for idx, pmid in enumerate(pmids, start=1):
            st.markdown(f"**[Ref{idx}] PMID {pmid}**")
            st.write(summarized_map[pmid])
            st.write("---")

        # Step 4: ground the selected LLM in the summaries and generate
        # the final answer.
        st.subheader("RAG-Enhanced Final Answer")
        system_prompt = build_system_prompt_with_refs(pmids, summarized_map)
        with st.spinner("Generating answer..."):
            if selected_llm == "OpenAI GPT-3.5":
                answer = openai_chat(system_prompt, user_query, temperature=temperature)
            else:
                answer = gemini_chat(system_prompt, user_query, temperature=temperature)

        st.write(answer)
        st.success("RAG Pipeline Complete.")

    # Static footer shown regardless of pipeline state.
    st.markdown("---")
    st.markdown("""
    ### Production Enhancements
    - **Vector Database** for advanced retrieval
    - **Citation Parsing** for accurate referencing
    - **Multi-Lingual** expansions
    - **Rate Limiting** for PubMed (max ~3 requests/sec)
    - **Robust Logging / Monitoring**
    - **Security & Privacy** if patient data is integrated
    """)
|
|
|
if __name__ == "__main__": |
|
main() |
|
|