"""Streamlit front-end for the Advanced Medical AI demo.

The sidebar exposes three tasks:

* PubMed RAG Q&A -- search PubMed, summarize the abstracts, store the
  summaries in the vector DB and answer the user's question with them as
  context via ``chat_with_openai``.
* Medical Image Analysis -- run an uploaded image through
  ``analyze_medical_image``.
* Semantic Search -- query the vector DB directly via ``semantic_search``.
"""

import os  # NOTE(review): not referenced in this module; kept in case of external reliance.

import streamlit as st

# NOTE(review): OPENAI_API_KEY, OPENAI_DEFAULT_MODEL and MAX_PUBMED_RESULTS are
# not referenced in this module; kept because importing ``config`` may be
# relied on elsewhere — confirm before removing.
from config import (
    OPENAI_API_KEY,
    OPENAI_DEFAULT_MODEL,
    MAX_PUBMED_RESULTS,
)
from pubmed_rag import (
    search_pubmed,
    fetch_pubmed_abstracts,
    chunk_and_summarize,
    upsert_documents,
    semantic_search,
)
from models import chat_with_openai
from image_pipeline import analyze_medical_image

###############################################################################
#                              STREAMLIT SETUP                                #
###############################################################################

# Page configuration is applied before any other Streamlit element is drawn.
st.set_page_config(page_title="Advanced Medical AI", layout="wide")


def main():
    """Render the app header, run the task selected in the sidebar, then
    render the disclaimer footer."""
    st.title("Advanced Medical AI: Multi-Modal RAG & Image Diagnostics")
    # st.markdown dedents and strips its body, so the indentation below is
    # not part of the rendered Markdown.
    st.markdown("""
        **Features**:
        1. **PubMed RAG**: Retrieve and summarize medical literature, store in a vector DB, and use advanced semantic search for context.
        2. **LLM Q&A**: Leverage OpenAI for final question-answering with RAG context.
        3. **Medical Image Analysis**: Use `HuggingFaceTB/SmolVLM-500M-Instruct` for diagnostic insights.
        4. **Multi-Lingual & Extended Triage**: Placeholder expansions for real-time translation or advanced triage logic.
        5. **Production-Ready**: Modular, concurrent, disclaimers, and synergy across tasks.
    """)

    menu = ["PubMed RAG Q&A", "Medical Image Analysis", "Semantic Search (Vector DB)"]
    choice = st.sidebar.selectbox("Select Task", menu)

    if choice == "PubMed RAG Q&A":
        pubmed_rag_qna()
    elif choice == "Medical Image Analysis":
        medical_image_analysis()
    else:
        vector_db_search_ui()

    st.markdown("---")
    st.markdown("""
        **Disclaimer**: This is an **advanced demonstration** for educational or research purposes only.
        Always consult a qualified healthcare professional for personal medical decisions.
    """)


def _summarize_abstracts(raw_abstracts):
    """Summarize each fetched abstract, separating usable ones from failures.

    Args:
        raw_abstracts: Mapping of PMID -> abstract text as returned by
            ``fetch_pubmed_abstracts``. A failed fetch is signalled by text
            starting with "Error" (the convention this app checks for).

    Returns:
        A ``(summaries, failures)`` pair of dicts keyed by PMID, both in the
        input's iteration order. ``summaries`` maps to the output of
        ``chunk_and_summarize``; ``failures`` maps to a human-readable reason.
    """
    summaries = {}
    failures = {}
    for pmid, text in raw_abstracts.items():
        if not isinstance(text, str) or not text.strip():
            # Guard: calling .startswith / summarizing on None or "" is useless.
            failures[pmid] = "No abstract text was returned."
        elif text.startswith("Error"):
            failures[pmid] = text
        else:
            summaries[pmid] = chunk_and_summarize(text)
    return summaries, failures


def _build_system_prompt(summaries):
    """Build the LLM system prompt listing each summary as ``[RefN] PMID ...``."""
    references = "".join(
        f"[Ref{i}] PMID {pmid}: {summary}\n"
        for i, (pmid, summary) in enumerate(summaries.items(), start=1)
    )
    return (
        "You are an advanced medical assistant with the following references:\n"
        + references
        + "\nUsing these references, provide an evidence-based answer."
    )


def pubmed_rag_qna():
    """PubMed retrieval-augmented Q&A page.

    On "Search & Summarize": searches PubMed, summarizes the abstracts,
    upserts the successful summaries into the vector DB and asks the LLM to
    answer the question using them as numbered references. Articles whose
    fetch failed are reported to the user but are never stored or sent to
    the LLM as evidence.
    """
    st.subheader("PubMed Retrieval-Augmented Q&A")
    query = st.text_area(
        "Ask a medical question (e.g., 'What are the latest treatments for type 2 diabetes?'):",
        height=100,
    )
    max_art = st.slider("Number of PubMed Articles to Retrieve", 1, 10, 5)

    if not st.button("Search & Summarize"):
        return
    if not query.strip():
        st.warning("Please enter a query.")
        return

    with st.spinner("Searching PubMed..."):
        pmids = search_pubmed(query, max_art)
    if not pmids:
        st.error("No articles found. Try another query.")
        return

    with st.spinner("Fetching and Summarizing..."):
        raw_abstracts = fetch_pubmed_abstracts(pmids)
        summaries, failures = _summarize_abstracts(raw_abstracts)

    # Numbering here matches the [RefN] labels the LLM sees in the prompt.
    st.subheader("Summaries")
    for i, (pmid, summary) in enumerate(summaries.items(), start=1):
        st.markdown(f"**[Ref{i}] PMID {pmid}**")
        st.write(summary)
    for pmid, reason in failures.items():
        st.warning(f"PMID {pmid} could not be used: {reason}")

    if not summaries:
        st.error("None of the retrieved articles could be summarized, so no answer can be generated.")
        return

    # Only genuine summaries are stored; error strings would otherwise become
    # retrievable "documents" in later semantic searches.
    upsert_documents(summaries)

    system_prompt = _build_system_prompt(summaries)
    with st.spinner("Generating final answer..."):
        final_answer = chat_with_openai(system_prompt, query)
    st.subheader("Final Answer")
    st.write(final_answer)


def medical_image_analysis():
    """Medical image page: preview an uploaded PNG/JPG and, on request, run
    it through ``analyze_medical_image`` and display the result."""
    st.subheader("Medical Image Analysis")
    uploaded_file = st.file_uploader("Upload a Medical Image (PNG/JPG)", type=["png", "jpg", "jpeg"])
    if uploaded_file is None:
        return

    # NOTE(review): newer Streamlit releases deprecate ``use_column_width`` in
    # favour of ``use_container_width``; left unchanged so older versions keep
    # working — confirm the installed version before switching.
    st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)

    if st.button("Analyze Image"):
        # Rewind in case rendering the preview advanced the stream position.
        uploaded_file.seek(0)
        with st.spinner("Analyzing..."):
            result = analyze_medical_image(uploaded_file)
        st.subheader("Diagnostic Insight")
        st.write(result)


def vector_db_search_ui():
    """Semantic search page: query the vector DB and list the top-k hits.

    Each hit is expected to be a mapping with "pmid", "score" and "text" keys
    (as read below); "score" is displayed as a distance.
    """
    st.subheader("Semantic Search in Vector DB")
    user_query = st.text_input("Enter a query to find relevant documents", "")
    top_k = st.slider("Number of results", 1, 10, 3)

    if not st.button("Search"):
        return
    if not user_query.strip():
        st.warning("Please enter a query.")
        return

    with st.spinner("Performing semantic search..."):
        results = semantic_search(user_query, top_k=top_k)

    st.subheader("Search Results")
    if not results:
        st.info("No matching documents found.")
        return
    for i, doc in enumerate(results, start=1):
        st.markdown(f"**Result {i}** - PMID {doc['pmid']} (Distance: {doc['score']:.4f})")
        st.write(doc["text"])
        st.write("---")


if __name__ == "__main__":
    main()