|
import streamlit as st |
|
import os |
|
|
|
from config import ( |
|
OPENAI_API_KEY, |
|
OPENAI_DEFAULT_MODEL, |
|
MAX_PUBMED_RESULTS |
|
) |
|
from pubmed_rag import ( |
|
search_pubmed, fetch_pubmed_abstracts, chunk_and_summarize, |
|
upsert_documents, semantic_search |
|
) |
|
from models import chat_with_openai |
|
from image_pipeline import analyze_medical_image |
|
|
|
|
|
|
|
|
|
# Page-wide Streamlit settings: browser tab title and full-width layout.
st.set_page_config(page_title="Advanced Medical AI", layout="wide")
|
|
|
def main():
    """Top-level page: render the header, route to the selected task, show the disclaimer."""
    st.title("Advanced Medical AI: Multi-Modal RAG & Image Diagnostics")

    st.markdown("""
    **Features**:
    1. **PubMed RAG**: Retrieve and summarize medical literature, store in a vector DB,
       and use advanced semantic search for context.
    2. **LLM Q&A**: Leverage OpenAI for final question-answering with RAG context.
    3. **Medical Image Analysis**: Use `HuggingFaceTB/SmolVLM-500M-Instruct` for diagnostic insights.
    4. **Multi-Lingual & Extended Triage**: Placeholder expansions for real-time translation or advanced triage logic.
    5. **Production-Ready**: Modular, concurrent, disclaimers, and synergy across tasks.
    """)

    # Sidebar task router: label -> page-rendering function.
    tasks = {
        "PubMed RAG Q&A": pubmed_rag_qna,
        "Medical Image Analysis": medical_image_analysis,
        "Semantic Search (Vector DB)": vector_db_search_ui,
    }
    selection = st.sidebar.selectbox("Select Task", list(tasks))
    tasks[selection]()

    st.markdown("---")
    st.markdown("""
    **Disclaimer**: This is an **advanced demonstration** for educational or research purposes only.
    Always consult a qualified healthcare professional for personal medical decisions.
    """)
|
|
|
def pubmed_rag_qna():
    """PubMed retrieval-augmented Q&A tab.

    Searches PubMed for the user's question, fetches and summarizes each
    abstract, upserts the summaries into the vector DB, and finally asks the
    LLM for an evidence-based answer grounded in the numbered references.
    """
    st.subheader("PubMed Retrieval-Augmented Q&A")
    query = st.text_area(
        "Ask a medical question (e.g., 'What are the latest treatments for type 2 diabetes?'):",
        height=100
    )
    # Cap the slider at the configured retrieval limit instead of a hard-coded
    # 10, so the UI stays consistent with MAX_PUBMED_RESULTS from config
    # (previously imported but unused).
    max_art = st.slider(
        "Number of PubMed Articles to Retrieve",
        1, MAX_PUBMED_RESULTS, min(5, MAX_PUBMED_RESULTS)
    )

    if st.button("Search & Summarize"):
        if not query.strip():
            st.warning("Please enter a query.")
            return

        with st.spinner("Searching PubMed..."):
            pmids = search_pubmed(query, max_art)

        if not pmids:
            st.error("No articles found. Try another query.")
            return

        with st.spinner("Fetching and Summarizing..."):
            raw_abstracts = fetch_pubmed_abstracts(pmids)

        summarized = {}
        for pmid, text in raw_abstracts.items():
            # Fetch failures are passed through verbatim; only real abstracts
            # go through the (potentially slow) summarization step.
            if text.startswith("Error"):
                summarized[pmid] = text
            else:
                summarized[pmid] = chunk_and_summarize(text)

        # Render the summaries and build the LLM reference lines in a single
        # pass (the original iterated summarized.items() twice).
        st.subheader("Summaries")
        ref_lines = []
        for i, (pmid, summary) in enumerate(summarized.items(), start=1):
            st.markdown(f"**[Ref{i}] PMID {pmid}**")
            st.write(summary)
            ref_lines.append(f"[Ref{i}] PMID {pmid}: {summary}\n")

        # Persist the summaries so they are searchable from the
        # "Semantic Search (Vector DB)" tab.
        upsert_documents(summarized)

        system_prompt = (
            "You are an advanced medical assistant with the following references:\n"
            + "".join(ref_lines)
            + "\nUsing these references, provide an evidence-based answer."
        )

        with st.spinner("Generating final answer..."):
            final_answer = chat_with_openai(system_prompt, query)
        st.subheader("Final Answer")
        st.write(final_answer)
|
|
|
def medical_image_analysis():
    """Medical image tab: upload a PNG/JPG and run the diagnostic-insight pipeline."""
    st.subheader("Medical Image Analysis")
    image_file = st.file_uploader("Upload a Medical Image (PNG/JPG)", type=["png", "jpg", "jpeg"])
    if image_file is None:
        return
    st.image(image_file, caption="Uploaded Image", use_column_width=True)
    if not st.button("Analyze Image"):
        return
    with st.spinner("Analyzing..."):
        insight = analyze_medical_image(image_file)
        st.subheader("Diagnostic Insight")
        st.write(insight)
|
|
|
def vector_db_search_ui():
    """Semantic-search tab: free-text query against the vector DB of stored documents."""
    st.subheader("Semantic Search in Vector DB")
    search_text = st.text_input("Enter a query to find relevant documents", "")
    result_count = st.slider("Number of results", 1, 10, 3)
    if not st.button("Search"):
        return
    if not search_text.strip():
        st.warning("Please enter a query.")
        return
    with st.spinner("Performing semantic search..."):
        hits = semantic_search(search_text, top_k=result_count)
        st.subheader("Search Results")
        for rank, hit in enumerate(hits, start=1):
            st.markdown(f"**Result {rank}** - PMID {hit['pmid']} (Distance: {hit['score']:.4f})")
            st.write(hit["text"])
            st.write("---")
|
|
|
# Run the app only when this file is executed directly
# (e.g. `streamlit run <this file>`), not when imported as a module.
if __name__ == "__main__":
    main()
|
|