mgbam's picture
Update app.py
228cbf8 verified
raw
history blame
19.4 kB
# -*- coding: utf-8 -*-
"""
Streamlit application for Medical Image Analysis using Google Gemini Vision
and Retrieval-Augmented Generation (RAG) with Chroma DB.
Optimized for deployment on Hugging Face Spaces.
"""
# --- Imports ---
import io
import logging
import os
import time
from typing import Optional, Dict, List, Any, Tuple

import chromadb
import google.generativeai as genai
import streamlit as st
from chromadb.api.types import EmbeddingFunction # For type hinting
from chromadb.utils import embedding_functions
from PIL import Image
# --- Basic Logging Setup ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# --- Configuration Constants ---
# Model and API Configuration
GOOGLE_API_KEY_SECRET = "GOOGLE_API_KEY"  # Name of the secret / environment variable holding the API key
# The vision model can be overridden without a code change by setting the
# GEMINI_VISION_MODEL environment variable (e.g. as a Space variable); an unset
# or empty value falls back to the default below.
# NOTE(review): Google has deprecated "gemini-pro-vision"; if the API rejects it,
# set GEMINI_VISION_MODEL to a currently supported multimodal Gemini model.
VISION_MODEL_NAME = os.environ.get("GEMINI_VISION_MODEL") or "gemini-pro-vision"
# Low temperature for conservative, repeatable descriptions of medical images.
GENERATION_CONFIG = {
    "temperature": 0.2,
    "top_p": 0.95,
    "top_k": 40,
    "max_output_tokens": 1024,
}
SAFETY_SETTINGS = [
    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
]
# Chroma DB Configuration
# Persistent storage inside the HF Space (relative path).
# NOTE: Enable persistent storage on the Space if data must survive restarts.
CHROMA_PATH = "chroma_data_hf"
COLLECTION_NAME = "medical_docs_v2"
# Embedding model name (sentence-transformers naming; default all-MiniLM-L6-v2).
# For better medical relevance, consider models fine-tuned on biomedical text, e.g.:
# - 'sentence-transformers/all-MiniLM-L6-v2' (general purpose)
# - 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext' (usually needs an adapter)
# - 'emilyalsentzer/Bio_ClinicalBERT' (usually needs an adapter)
# The same model must be used for indexing and querying.
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
# HNSW distance metric for the collection; with "cosine", distance = 1 - cosine similarity.
CHROMA_DISTANCE_FUNCTION = "cosine"
# UI Configuration
MAX_RAG_RESULTS = 3  # Number of results to fetch from Chroma
# --- Initialization Functions with Caching ---
def _read_google_api_key() -> Optional[str]:
    """Return the Google API key from Streamlit secrets or the environment.

    ``st.secrets`` is tried first (e.g. a local ``.streamlit/secrets.toml``).
    If the key is not there, or no secrets file exists at all, the environment
    variable of the same name is used. This is how Hugging Face Spaces exposes
    Space secrets. Returns None if neither source has a non-empty value.
    """
    try:
        key = st.secrets[GOOGLE_API_KEY_SECRET]
    except Exception as e:
        # KeyError for a missing key; a missing secrets file raises a
        # different exception class depending on the Streamlit version.
        logger.info(
            f"'{GOOGLE_API_KEY_SECRET}' not available via st.secrets "
            f"({type(e).__name__}); checking environment variables."
        )
        key = None
    return key or os.environ.get(GOOGLE_API_KEY_SECRET) or None


@st.cache_resource
def configure_google_ai() -> bool:
    """Configure the Google AI SDK with the API key.

    The key is looked up in ``st.secrets`` and then in the process environment.

    Returns:
        True if ``genai.configure`` succeeded, False otherwise (after showing an
        error in the UI and logging it). Because of ``st.cache_resource``, the
        result, including a False failure result, is reused for the lifetime of
        the server process.
    """
    google_api_key = _read_google_api_key()
    if not google_api_key:
        st.error(
            f"❌ **Error:** '{GOOGLE_API_KEY_SECRET}' not found in Streamlit secrets "
            f"or environment variables (Hugging Face Secrets)."
        )
        logger.error(f"Secret '{GOOGLE_API_KEY_SECRET}' not found.")
        return False
    try:
        genai.configure(api_key=google_api_key)
    except Exception as e:
        st.error(f"❌ **Error:** Failed to configure Google AI SDK: {e}")
        logger.error(f"Error configuring Google AI SDK: {e}", exc_info=True)
        return False
    logger.info("Google AI SDK configured successfully.")
    return True
@st.cache_resource
def get_gemini_model() -> Optional[genai.GenerativeModel]:
    """Build the Gemini vision model once per server process.

    Returns:
        The configured ``GenerativeModel``, or None when the SDK could not be
        configured or the model constructor raised (an error is shown in that case).
    """
    sdk_ready = configure_google_ai()
    if not sdk_ready:
        return None

    model_kwargs = {
        "model_name": VISION_MODEL_NAME,
        "generation_config": GENERATION_CONFIG,
        "safety_settings": SAFETY_SETTINGS,
    }
    try:
        gemini = genai.GenerativeModel(**model_kwargs)
    except Exception as exc:
        st.error(f"❌ **Error:** Failed to initialize Gemini Model ({VISION_MODEL_NAME}): {exc}")
        logger.error(f"Error initializing Gemini Model: {exc}", exc_info=True)
        return None

    logger.info(f"Gemini Model '{VISION_MODEL_NAME}' initialized.")
    return gemini
@st.cache_resource
def get_embedding_function() -> Optional[EmbeddingFunction]:
    """Initialize and return the embedding function shared by indexing and querying.

    Uses Chroma's ``SentenceTransformerEmbeddingFunction`` so that
    ``EMBEDDING_MODEL_NAME`` is honoured. This requires the
    ``sentence-transformers`` package to be installed. Chroma's
    ``DefaultEmbeddingFunction`` accepts no ``model_name`` argument, so it
    cannot be used to select a model.

    Returns:
        The embedding function, or None (after showing an error in the UI and
        logging it) if it could not be created. ``st.cache_resource`` reuses the
        result for the server process.
    """
    try:
        ef = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=EMBEDDING_MODEL_NAME
        )
    except Exception as e:
        st.error(f"❌ **Error:** Failed to initialize embedding function ({EMBEDDING_MODEL_NAME}): {e}")
        logger.error(f"Error initializing embedding function: {e}", exc_info=True)
        return None
    logger.info(f"Initialized embedding function with model: {EMBEDDING_MODEL_NAME}")
    return ef
@st.cache_resource
def get_chroma_collection() -> Optional[chromadb.Collection]:
    """Open the persistent Chroma store and return the app's collection.

    The collection is created on first use with the configured embedding
    function and HNSW distance metric.

    Returns:
        The collection, or None if the embedding function is unavailable or the
        client/collection could not be initialized (an error is shown in that case).
    """
    ef = get_embedding_function()
    if not ef:
        return None

    try:
        client = chromadb.PersistentClient(path=CHROMA_PATH)
        logger.info(f"ChromaDB client initialized with path: {CHROMA_PATH}")
        coll = client.get_or_create_collection(
            name=COLLECTION_NAME,
            embedding_function=ef,
            metadata={"hnsw:space": CHROMA_DISTANCE_FUNCTION},
        )
    except Exception as err:
        st.error(f"❌ **Error:** Failed to initialize Chroma DB collection '{COLLECTION_NAME}': {err}")
        st.info(f"ℹ️ Attempted path: '{CHROMA_PATH}'. Ensure write permissions and space.")
        logger.error(f"Error initializing Chroma DB: {err}", exc_info=True)
        return None

    logger.info(f"ChromaDB collection '{COLLECTION_NAME}' loaded/created.")
    return coll
# --- Helper Functions ---
def analyze_image_with_gemini(gemini_model: genai.GenerativeModel, image_bytes: bytes) -> Tuple[Optional[str], bool]:
    """
    Analyze raw image bytes with the Gemini vision model.

    Args:
        gemini_model: The initialized Gemini model instance.
        image_bytes: The encoded image file contents (any format PIL can open).

    Returns:
        A tuple containing:
        - The analysis text (str), or None on error, block, or empty output.
        - A boolean that is True only when analysis text was returned.

    Failures are reported to the UI (st.warning / st.error) and the log; this
    function does not raise.
    """
    try:
        img = Image.open(io.BytesIO(image_bytes))
        prompt = """Analyze this medical image (e.g., pathology slide, diagram, scan).
Describe key visual features relevant for medical context (structures, cells, staining, anomalies).
Identify potential findings:
- Possible conditions or disease indicators
- Pathological features (morphology, patterns)
- Visible cell types or tissue structures
- Relevant biomarkers (if suggested by visuals)
- Anatomical context (if clear)
Focus on visual evidence. Be concise. Avoid definitive diagnosis. State uncertainties clearly.
"""
        # Non-streaming call: the response is complete on return, so no resolve() is needed.
        response = gemini_model.generate_content([prompt, img], stream=False)
        # A blocked prompt yields no candidates, and the `.parts` / `.text`
        # quick accessors raise ValueError in that case. Check candidates first.
        candidates = response.candidates
        if not candidates or not response.parts:
            reason = "Unknown reason"
            feedback = response.prompt_feedback
            if feedback and feedback.block_reason:
                # The prompt itself was blocked.
                reason = feedback.block_reason.name
            elif candidates and candidates[0].finish_reason:
                # The prompt was accepted, but generation stopped without content (e.g. SAFETY).
                finish_reason = candidates[0].finish_reason
                reason = f"finish_reason={getattr(finish_reason, 'name', finish_reason)}"
            logger.warning(f"Gemini analysis blocked or empty. Reason: {reason}")
            st.warning(f"⚠️ Analysis blocked by safety filters or returned empty. Reason: {reason}")
            return None, False
        logger.info("Gemini analysis successful.")
        return response.text, True
    except genai.types.BlockedPromptException as e:
        logger.error(f"Gemini analysis blocked due to prompt: {e}")
        st.error(f"❌ **Analysis Blocked:** The prompt content triggered safety filters: {e}")
        return None, False
    except Exception as e:
        logger.error(f"Error during Gemini analysis: {e}", exc_info=True)
        st.error(f"❌ **Error:** An unexpected error occurred during Gemini analysis: {e}")
        return None, False
def query_chroma(collection: chromadb.Collection, query_text: str, n_results: int = 3) -> Optional[Dict[str, List[Any]]]:
    """Run a single-text similarity query against ``collection``.

    Args:
        collection: The Chroma collection to search.
        query_text: Text to embed and search with; must be non-empty.
        n_results: Maximum number of neighbours to request.

    Returns:
        The raw Chroma query result (documents, metadatas and distances, each a
        list with one inner list per query text), or None when ``query_text`` is
        empty or the query raised (a warning/error is shown in the UI).
    """
    if query_text:
        try:
            hits = collection.query(
                query_texts=[query_text],
                n_results=n_results,
                include=["documents", "metadatas", "distances"],
            )
        except Exception as err:
            logger.error(f"Error querying Chroma DB: {err}", exc_info=True)
            st.error(f"❌ **Error:** Failed to query the knowledge base: {err}")
            return None
        preview = query_text[:50]
        logger.info(f"ChromaDB query executed successfully for text: '{preview}...'")
        return hits

    logger.warning("Chroma query attempted with empty text.")
    st.warning("⚠️ Cannot query knowledge base without analysis text.")
    return None
# Demo helper that seeds the knowledge base (consider moving it to a separate setup script).
def add_dummy_data_to_chroma(collection: chromadb.Collection):
    """Add a fixed set of example medical snippets to ``collection``.

    Each snippet has a deterministic ID (``dummy_doc_<n>``), so repeated calls
    are idempotent. Snippets whose ID already exists are skipped, and only the
    missing ones are added. Progress and errors are reported in the UI and the
    log; nothing is raised.

    Args:
        collection: The Chroma collection to populate.
    """
    st.info("Attempting to add dummy data to Chroma DB...")
    docs = [
        "Figure 1A shows adenocarcinoma of the lung, papillary subtype. Note the glandular structures and nuclear atypia. TTF-1 staining was positive.",
        "Pathology slide 34B demonstrates high-grade glioma (glioblastoma) with significant necrosis and microvascular proliferation. Ki-67 index was high.",
        "Diagram: EGFR signaling pathway mutations in NSCLC targeted by TKIs.",
        "Micrograph: Chronic gastritis with H. pylori organisms (special stain needed). Mild intestinal metaplasia noted.",
        "Slide CJD-02: Spongiform changes in cerebral cortex characteristic of prion disease. Gliosis present.",
    ]
    metadatas = [
        {"source": "Example Paper 1", "topic": "Lung Cancer Pathology", "entities": "adenocarcinoma, lung cancer, glandular structures, nuclear atypia, papillary subtype, TTF-1", "IMAGE_ID": "fig_1a_adeno_lung.png"},
        {"source": "Path Report 789", "topic": "Brain Tumor Pathology", "entities": "high-grade glioma, glioblastoma, necrosis, microvascular proliferation, Ki-67", "IMAGE_ID": "slide_34b_gbm.tiff"},
        {"source": "Textbook Chapter 5", "topic": "Molecular Oncology", "entities": "EGFR, TKIs, NSCLC, signaling pathway", "IMAGE_ID": "diagram_egfr_pathway.svg"},
        {"source": "Path Report 101", "topic": "Gastrointestinal Pathology", "entities": "chronic gastritis, Helicobacter pylori, intestinal metaplasia", "IMAGE_ID": "micrograph_h_pylori_gastritis.jpg"},
        {"source": "Case Study CJD", "topic": "Neuropathology", "entities": "prion disease, Spongiform changes, Gliosis, cerebral cortex", "IMAGE_ID": "slide_cjd_sample_02.jpg"},
    ]
    # Deterministic IDs let us detect what is already stored. Chroma's `where`
    # filter matches metadata fields, not document text, so dedupe by ID instead.
    ids = [f"dummy_doc_{i + 1}" for i in range(len(docs))]
    try:
        existing_ids = set(collection.get(ids=ids, include=[])["ids"])
        new_items = [
            (doc_id, doc, meta)
            for doc_id, doc, meta in zip(ids, docs, metadatas)
            if doc_id not in existing_ids
        ]
        if new_items:
            collection.add(
                ids=[doc_id for doc_id, _, _ in new_items],
                documents=[doc for _, doc, _ in new_items],
                metadatas=[meta for _, _, meta in new_items],
            )
            logger.info(f"Added {len(new_items)} dummy documents to Chroma collection '{COLLECTION_NAME}'.")
            st.success(f"✅ Added {len(new_items)} dummy documents to Chroma collection '{COLLECTION_NAME}'.")
        else:
            logger.warning("Dummy data check indicates data might already exist. Skipping addition.")
            st.warning("⚠️ Dummy data seems to already exist in the collection. No new data added.")
    except Exception as e:
        logger.error(f"Error adding dummy data to Chroma: {e}", exc_info=True)
        st.error(f"❌ **Error:** Could not add dummy data to Chroma: {e}")
# --- Streamlit UI ---
# set_page_config must be the first Streamlit command executed in the script run.
st.set_page_config(layout="wide", page_title="Medical Image RAG - HF", page_icon="⚕️")
st.title("⚕️ Medical Image Analysis & RAG")
st.markdown("""
*Powered by Google Gemini, ChromaDB, and Streamlit on Hugging Face Spaces*
""")
# --- CRITICAL DISCLAIMER ---
st.warning("""
**⚠️ Disclaimer:** This tool is for informational and illustrative purposes ONLY.
It is **NOT** a medical device and **CANNOT** provide a diagnosis. AI analysis may be
imperfect or incomplete. **ALWAYS** consult qualified medical professionals for any
health concerns or decisions. Do **NOT** rely solely on this tool for medical judgment.
""")
# --- Initialize Services ---
# The cached getters return None on failure (after showing their own error).
gemini_model = get_gemini_model()
chroma_collection = get_chroma_collection()
# Both services are required, so stop this script run if either is missing.
if gemini_model is None or chroma_collection is None:
    st.error("❌ Critical components failed to initialize. Cannot proceed. Check logs and secrets.")
    st.stop()
# --- Sidebar Controls ---
with st.sidebar:
    st.header("⚙️ Controls")
    uploaded_file = st.file_uploader(
        "1. Upload Medical Image",
        type=["jpg", "jpeg", "png", "tiff", "webp"],
        help="Upload formats like pathology slides, diagrams, scans."
    )
    st.divider()
    st.header("📚 Knowledge Base")
    if st.button("➕ Add Dummy KB Data", help="Add example text data to the Chroma vector database for demonstration."):
        if chroma_collection is not None:
            add_dummy_data_to_chroma(chroma_collection)
        else:
            st.error("❌ Chroma DB not available to add data.")
    # Static summary of the knowledge-base configuration.
    st.info(f"""
**KB Info:**
- **Collection:** `{COLLECTION_NAME}`
- **Storage:** `{CHROMA_PATH}` (in Space storage)
- **Embeddings:** `{EMBEDDING_MODEL_NAME}`
- **Similarity:** `{CHROMA_DISTANCE_FUNCTION}`
""")
    st.caption("Note: Data persists if persistent storage is enabled for this Space, otherwise it's temporary.")
# --- Main Processing Area ---
col1, col2 = st.columns(2)
with col1:
    st.subheader("🖼️ Uploaded Image")
    if uploaded_file is not None:
        image_bytes = uploaded_file.getvalue()
        st.image(image_bytes, caption=f"Uploaded: {uploaded_file.name}", use_column_width=True)
    else:
        st.info("Upload an image using the sidebar to begin analysis.")
with col2:
    st.subheader("🤖 AI Analysis & Retrieval")
    if uploaded_file is not None and gemini_model is not None and chroma_collection is not None:
        analysis_text = None
        analysis_successful = False
        # Step 1: Analyze Image with Gemini
        with st.status("🧠 Analyzing image with Gemini Vision...", expanded=False) as status_analysis:
            try:
                st.write("Sending image to Gemini...")
                # image_bytes was set in col1 under the same uploaded_file condition.
                analysis_text, analysis_successful = analyze_image_with_gemini(gemini_model, image_bytes)
                if analysis_successful:
                    st.write("Analysis complete.")
                    status_analysis.update(label="✅ Analysis Complete", state="complete")
                else:
                    # Error/block message already shown by the helper function.
                    status_analysis.update(label="⚠️ Analysis Failed or Blocked", state="error")
            except Exception as e:
                logger.error(f"Unhandled error during analysis status block: {e}", exc_info=True)
                st.error(f"❌ An unexpected error occurred during the analysis process: {e}")
                status_analysis.update(label="💥 Analysis Error", state="error")
                analysis_successful = False
        # Display the analysis and run retrieval only when analysis produced text.
        if analysis_successful and analysis_text:
            st.markdown("**🔬 Gemini Vision Analysis:**")
            st.markdown(analysis_text)
            st.divider()
            # Step 2: Query Chroma DB with Analysis Text
            st.markdown("**📚 Related Information (RAG via Chroma DB):**")
            with st.status("🔍 Searching knowledge base...", expanded=True) as status_query:
                try:
                    st.write(f"Querying with analysis summary (top {MAX_RAG_RESULTS} results)...")
                    chroma_results = query_chroma(chroma_collection, analysis_text, n_results=MAX_RAG_RESULTS)
                    if chroma_results and chroma_results.get('documents') and chroma_results['documents'][0]:
                        # Chroma returns one inner list per query text; we sent a single query.
                        rag_docs = chroma_results['documents'][0]
                        rag_metas = chroma_results['metadatas'][0]
                        rag_dists = chroma_results['distances'][0]
                        num_results = len(rag_docs)
                        st.write(f"Found {num_results} related entries.")
                        status_query.update(label=f"✅ Found {num_results} results", state="complete")
                        for rank, (doc, meta, dist) in enumerate(zip(rag_docs, rag_metas, rag_dists), start=1):
                            meta = meta or {}  # Documents stored without metadata come back as None
                            similarity = 1.0 - dist  # Valid for the collection's "cosine" distance metric
                            expander_title = f"Result {rank} (Similarity: {similarity:.3f}) - Source: {meta.get('source', 'N/A')}"
                            with st.expander(expander_title):
                                st.markdown("**Retrieved Text:**")
                                st.markdown(f"> {doc}")
                                st.markdown("**Metadata:**")
                                meta_display = {k: v for k, v in meta.items() if v}  # Hide empty values
                                st.json(meta_display, expanded=False)
                                if meta.get("IMAGE_ID"):
                                    st.info(f"ℹ️ Associated Visual: `{meta['IMAGE_ID']}`")
                    elif chroma_results is not None:
                        # The query ran but matched nothing. st.status only accepts
                        # "running" / "complete" / "error"; the label carries the warning.
                        st.warning("⚠️ No relevant information found in the knowledge base for this analysis.")
                        status_query.update(label="⚠️ No results found", state="complete")
                    else:
                        # Query failed; the error was already shown by query_chroma.
                        status_query.update(label="💥 Query Error", state="error")
                except Exception as e:
                    logger.error(f"Unhandled error during query status block: {e}", exc_info=True)
                    st.error(f"❌ An unexpected error occurred during the knowledge base search: {e}")
                    status_query.update(label="💥 Query Process Error", state="error")
        elif not analysis_successful:
            st.info("Cannot proceed to knowledge base search as image analysis failed or was blocked.")
    elif not uploaded_file:
        st.info("Analysis results and related information will appear here once an image is uploaded and processed.")
    else:
        # Initialization failed earlier; that message was already shown.
        st.info("Waiting for components to initialize...")
# --- Footer ---
st.markdown("---")
st.caption("Ensure responsible use. Verify all findings with qualified professionals.")