# --- Imports ---
import io
import logging
import time
from typing import Any, Dict, List, Optional, Tuple

import chromadb
import google.generativeai as genai
import streamlit as st
from chromadb.utils import embedding_functions
from PIL import Image

# --- Set Page Config FIRST ---
# This MUST be the first Streamlit command executed in the script.
st.set_page_config(layout="wide", page_title="Medical Image Analysis & RAG (HF/BioBERT)")

# --- Basic Logging Setup ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# --- Application Configuration ---
# Secrets management: credentials are read from Streamlit secrets
# (on a Hugging Face Space these are configured as Space secrets).
try:
    GOOGLE_API_KEY = st.secrets["GOOGLE_API_KEY"]  # required; KeyError if absent
    HF_TOKEN = st.secrets.get("HF_TOKEN")  # optional; None if absent
except KeyError as e:
    err_msg = f"❌ Missing Secret: {e}. Please add it to your Hugging Face Space secrets."
    # Safe to call st.error here: set_page_config has already run.
    st.error(err_msg)
    logger.error(err_msg)
    st.stop()  # halts this script run; nothing below can work without the key
except Exception as e:
    # Any other failure while reading secrets (e.g. no secrets store at all).
    err_msg = f"❌ Error loading secrets: {e}"
    st.error(err_msg)
    logger.error(err_msg)
    st.stop()
# --- Gemini Configuration ---
# NOTE(review): Google has deprecated the `gemini-pro-vision` model; confirm it is
# still served for this API key, otherwise switch to a current multimodal Gemini model.
VISION_MODEL_NAME = "gemini-pro-vision"

# Sampling parameters passed to genai.GenerativeModel(generation_config=...).
GENERATION_CONFIG = {
    "temperature": 0.2,
    "top_p": 0.95,
    "top_k": 40,
    "max_output_tokens": 1024,
}

# Block content rated medium-or-above in each of these four harm categories.
SAFETY_SETTINGS = [
    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
]

# Instruction sent to Gemini together with the uploaded image.
GEMINI_ANALYSIS_PROMPT = """Analyze this medical image (e.g., pathology slide, diagram, scan).
Describe the key visual features relevant to a medical context.
Identify potential:
- Diseases or conditions indicated
- Pathological findings (e.g., cellular morphology, tissue structure, staining patterns)
- Visible cell types
- Relevant biomarkers (if inferable from staining or morphology)
- Anatomical context (if discernible)
Be concise and focus primarily on visually evident information. Avoid definitive diagnoses.
Structure the output clearly, perhaps using bullet points for findings.
"""

# --- Chroma DB Configuration ---
# Both names include "biobert" to reflect the embedding model the stored vectors came from.
CHROMA_PATH = "chroma_data_biobert"
COLLECTION_NAME = "medical_docs_biobert"

# --- Embedding Model Selection ---
# BioBERT v1.1 has strong biomedical domain knowledge but was not trained for
# sentence-level similarity, so retrieval quality may be limited. If it is,
# consider a model fine-tuned for similarity search, e.g.
# 'dmis-lab/sapbert-from-pubmedbert-sentencetransformer'.
# NOTE(review): the model is used through chromadb's HuggingFaceEmbeddingFunction,
# which (per chromadb's docs) calls the hosted HF Inference API rather than a local
# sentence-transformers model, so token pooling is decided there — confirm.
EMBEDDING_MODEL_NAME = "dmis-lab/biobert-v1.1"

# HNSW distance used by the Chroma collection; with "cosine" the UI reports
# similarity as 1 - distance.
CHROMA_DISTANCE_METRIC = "cosine"
# --- Cached Resource Initialization ---


@st.cache_resource(show_spinner=False)
def _load_gemini_model() -> genai.GenerativeModel:
    """Configure the Gemini client and construct the vision model.

    Cached with ``st.cache_resource`` so the model is built once per server
    process instead of on every Streamlit rerun. Exceptions propagate and are
    therefore not cached, so a failed initialization is retried on the next rerun.
    """
    genai.configure(api_key=GOOGLE_API_KEY)
    model = genai.GenerativeModel(
        model_name=VISION_MODEL_NAME,
        generation_config=GENERATION_CONFIG,
        safety_settings=SAFETY_SETTINGS,
    )
    logger.info(f"Successfully initialized Gemini Model: {VISION_MODEL_NAME}")
    return model


def initialize_gemini_model() -> Optional[genai.GenerativeModel]:
    """Return the shared Gemini Generative Model, or None if it cannot be created.

    On failure the error is shown in the UI (``st.error``) and logged with a
    traceback; callers must handle the None return.
    """
    try:
        return _load_gemini_model()
    except Exception as e:
        err_msg = f"❌ Error initializing Gemini Model ({VISION_MODEL_NAME}): {e}"
        st.error(err_msg)
        logger.error(err_msg, exc_info=True)
        return None
@st.cache_resource(show_spinner=False)
def _load_embedding_function() -> embedding_functions.HuggingFaceEmbeddingFunction:
    """Build chromadb's Hugging Face embedding function for EMBEDDING_MODEL_NAME.

    Cached with ``st.cache_resource`` so it is created once per server process
    rather than on every Streamlit rerun. Exceptions propagate and are not
    cached, so a failed attempt is retried on the next rerun.
    """
    # NOTE(review): chromadb's HuggingFaceEmbeddingFunction sends texts to the hosted
    # HF Inference API, so HF_TOKEN is likely required even for public models — confirm.
    embed_func = embedding_functions.HuggingFaceEmbeddingFunction(
        api_key=HF_TOKEN,
        model_name=EMBEDDING_MODEL_NAME,
    )
    logger.info(f"Successfully initialized HuggingFace Embedding Function: {EMBEDDING_MODEL_NAME}")
    return embed_func


def initialize_embedding_function() -> Optional[embedding_functions.HuggingFaceEmbeddingFunction]:
    """Return the shared Hugging Face embedding function, or None on failure.

    Shows progress and success messages in the UI. On failure it reports the
    error (``st.error`` plus a troubleshooting hint), logs it with a traceback,
    and returns None.
    """
    st.info(f"Initializing Embedding Model: {EMBEDDING_MODEL_NAME} (this may take a moment)...")
    try:
        embed_func = _load_embedding_function()
    except Exception as e:
        err_msg = f"❌ Error initializing HuggingFace Embedding Function ({EMBEDDING_MODEL_NAME}): {e}"
        st.error(err_msg)
        logger.error(err_msg, exc_info=True)
        st.info("ℹ️ Make sure the embedding model name is correct and you have network access. "
                "If using a private model, ensure HF_TOKEN is set in secrets. Check Space logs for details.")
        return None
    st.success(f"Embedding Model {EMBEDDING_MODEL_NAME} initialized.")
    return embed_func
@st.cache_resource(show_spinner=False)
def _load_chroma_collection(_embedding_func: embedding_functions.EmbeddingFunction) -> chromadb.Collection:
    """Open (or create) the persistent Chroma collection.

    Cached with ``st.cache_resource``. The leading underscore on
    ``_embedding_func`` tells Streamlit not to hash that argument, so the
    collection is opened once per server process. Exceptions propagate and are
    not cached, so a failed attempt is retried on the next rerun.
    """
    chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)
    collection = chroma_client.get_or_create_collection(
        name=COLLECTION_NAME,
        embedding_function=_embedding_func,  # Chroma embeds documents/queries with this
        metadata={"hnsw:space": CHROMA_DISTANCE_METRIC},
    )
    logger.info(f"Chroma DB collection '{COLLECTION_NAME}' loaded/created at '{CHROMA_PATH}' using {CHROMA_DISTANCE_METRIC}.")
    return collection


def initialize_chroma_collection(_embedding_func: embedding_functions.EmbeddingFunction) -> Optional[chromadb.Collection]:
    """Return the shared Chroma collection, or None on failure.

    Args:
        _embedding_func: Embedding function for the collection. None (a failed
            embedding initialization) short-circuits with an error message.

    Returns:
        The collection, or None after reporting the error in the UI and logs.
    """
    if _embedding_func is None:
        st.error("❌ Cannot initialize Chroma DB without a valid embedding function.")
        return None
    st.info(f"Initializing Chroma DB collection '{COLLECTION_NAME}'...")
    try:
        collection = _load_chroma_collection(_embedding_func)
    except Exception as e:
        err_msg = f"❌ Error initializing Chroma DB at '{CHROMA_PATH}': {e}"
        st.error(err_msg)
        logger.error(err_msg, exc_info=True)
        st.info(f"ℹ️ Ensure the path '{CHROMA_PATH}' is writable. Check Space logs.")
        return None
    st.success(f"Chroma DB collection '{COLLECTION_NAME}' ready.")
    return collection
# --- Core Logic Functions ---
# Successful Gemini analyses are cached per image with st.cache_data. Streamlit's
# own spinner is disabled because the UI shows progress with st.status instead.


class _GeminiAnalysisError(Exception):
    """Blocked or empty Gemini response; raised (not returned) so it is never cached."""


@st.cache_data(show_spinner=False)
def _cached_gemini_analysis(_gemini_model: genai.GenerativeModel, image_bytes: bytes) -> str:
    """Run the Gemini analysis for ``image_bytes`` and return the response text.

    The cache is keyed on ``image_bytes`` only (the leading underscore excludes
    ``_gemini_model`` from hashing), so reruns with the same upload do not call
    the API again. Only successes are cached: every failure raises, and
    Streamlit does not cache exceptions.

    Raises:
        _GeminiAnalysisError: the response contained no parts (blocked or empty).
    """
    img = Image.open(io.BytesIO(image_bytes))
    response = _gemini_model.generate_content([GEMINI_ANALYSIS_PROMPT, img])
    if not response.parts:
        feedback = response.prompt_feedback
        if feedback and feedback.block_reason:
            msg = f"Analysis blocked by safety settings: {feedback.block_reason}"
            logger.warning(msg)
        else:
            msg = "Error: Gemini analysis returned no content (empty or invalid response)."
            logger.error(msg)
        raise _GeminiAnalysisError(msg)
    logger.info("Gemini analysis successful.")
    return response.text


def analyze_image_with_gemini(_gemini_model: genai.GenerativeModel, image_bytes: bytes) -> Tuple[str, bool]:
    """Analyze image bytes with Gemini.

    Successful results are cached per ``image_bytes``; failures are not cached,
    so transient errors are retried on the next rerun.

    Returns:
        ``(analysis_text, is_error)``. When ``is_error`` is True,
        ``analysis_text`` is a human-readable error or block message.
    """
    if not _gemini_model:
        return "Error: Gemini model not initialized.", True
    try:
        return _cached_gemini_analysis(_gemini_model, image_bytes), False
    except _GeminiAnalysisError as e:
        # Already logged inside the helper.
        return str(e), True
    except genai.types.BlockedPromptException as e:
        msg = f"Analysis blocked (prompt issue): {e}"
        logger.warning(msg)
        return msg, True
    except Exception as e:
        msg = f"Error during Gemini analysis: {e}"
        logger.error(msg, exc_info=True)
        return msg, True
def query_chroma(_collection: chromadb.Collection, query_text: str, n_results: int = 5) -> Optional[Dict[str, List[Any]]]:
    """Query the Chroma collection with ``query_text``.

    Deliberately not cached: the knowledge base can change at runtime (e.g. via
    the dummy-data button), and cached results would go stale.

    Args:
        _collection: Collection to search; None is treated as an error.
        query_text: Text to embed and search with; empty text is rejected.
        n_results: Maximum number of hits to return.

    Returns:
        Chroma's result dict (one inner list per query text, with 'documents',
        'metadatas' and 'distances' included). If the collection is empty, the
        same shape is returned with empty inner lists. Returns None on error
        (the error is also shown in the UI).
    """
    if not _collection:
        logger.error("Query attempt failed: Chroma collection is not available.")
        return None
    if not query_text:
        logger.warning("Attempted to query Chroma with empty text.")
        return None
    try:
        # Clamp to the number of stored documents. Depending on the chromadb
        # version, asking for more hits than exist warns or fails, and an empty
        # collection (the app's initial state) has nothing to search.
        available = _collection.count()
        if available == 0:
            logger.warning("Chroma collection is empty; no documents to retrieve.")
            return {"ids": [[]], "documents": [[]], "metadatas": [[]], "distances": [[]]}
        results = _collection.query(
            query_texts=[query_text],
            n_results=min(n_results, available),
            include=['documents', 'metadatas', 'distances'],
        )
        logger.info(f"Chroma query successful for text snippet: '{query_text[:50]}...'")
        return results
    except Exception as e:
        # Show the error in the UI as well; the icon must be a real emoji or
        # Streamlit raises while reporting the error.
        st.error(f"❌ Error querying Chroma DB: {e}", icon="🚨")
        logger.error(f"Error querying Chroma DB: {e}", exc_info=True)
        return None
def add_dummy_data_to_chroma(collection: chromadb.Collection, embedding_func: embedding_functions.EmbeddingFunction):
    """Add example medical text snippets to the Chroma collection, unless already present.

    Embeddings are computed by the collection's own embedding function during
    ``collection.add``; ``embedding_func`` is only checked for availability.
    Progress and errors are reported in the UI via ``st.status``/``st.error``.
    """
    if not collection or not embedding_func:
        st.error("❌ Cannot add dummy data: Chroma Collection or Embedding Function not available.")
        return

    # --- Dummy Data Definition ---
    docs = [
        "Figure 1A shows adenocarcinoma of the lung, papillary subtype. Note the glandular structures and nuclear atypia. TTF-1 staining was positive.",
        "Pathology slide 34B demonstrates high-grade glioma (glioblastoma) with significant necrosis and microvascular proliferation. Ki-67 index was high.",
        "This diagram illustrates the EGFR signaling pathway and common mutation sites targeted by tyrosine kinase inhibitors in non-small cell lung cancer.",
        "Micrograph showing chronic gastritis with Helicobacter pylori organisms (visible with special stain, not shown here). Mild intestinal metaplasia is present.",
        "Slide CJD-Sample-02: Spongiform changes characteristic of prion disease are evident in the cerebral cortex. Gliosis is also noted.",
    ]
    metadatas = [
        {"source": "Example Paper 1", "topic": "Lung Cancer Pathology", "entities": "adenocarcinoma, lung cancer, glandular structures, nuclear atypia, papillary subtype, TTF-1", "IMAGE_ID": "fig_1a_adeno_lung.png"},
        {"source": "Path Report 789", "topic": "Brain Tumor Pathology", "entities": "high-grade glioma, glioblastoma, necrosis, microvascular proliferation, Ki-67", "IMAGE_ID": "slide_34b_gbm.tiff"},
        {"source": "Textbook Chapter 5", "topic": "Molecular Oncology Pathways", "entities": "EGFR, tyrosine kinase inhibitors, non-small cell lung cancer", "IMAGE_ID": "diagram_egfr_pathway.svg"},
        {"source": "Path Report 101", "topic": "Gastrointestinal Pathology", "entities": "chronic gastritis, Helicobacter pylori, intestinal metaplasia", "IMAGE_ID": "micrograph_h_pylori_gastritis.jpg"},
        {"source": "Case Study CJD", "topic": "Neuropathology", "entities": "prion disease, Spongiform changes, Gliosis, cerebral cortex", "IMAGE_ID": "slide_cjd_sample_02.jpg"},
    ]

    # Cheap existence check on the first snippet only. `where=` filters on
    # *metadata* (which has no "document" key), so matching on the stored text
    # needs `where_document`; otherwise the check never matches and every click
    # inserts duplicates.
    try:
        existing_check = collection.get(where_document={"$contains": docs[0]}, limit=1, include=[])
        if existing_check and existing_check.get('ids'):
            st.info("Dummy data seems to already exist. Skipping add.")
            logger.info("Skipping dummy data addition as it likely exists.")
            return
    except Exception as e:
        # Best effort: if the check itself fails, still try to add.
        logger.warning(f"Could not efficiently check for existing dummy data: {e}. Proceeding with add attempt.")

    status = st.status(f"Adding dummy data (using {EMBEDDING_MODEL_NAME})...", expanded=True)
    try:
        # Millisecond timestamp prefix keeps IDs unique even if run close together.
        base_id = f"doc_biobert_{int(time.time() * 1000)}"
        ids = [f"{base_id}_{i}" for i in range(len(docs))]
        status.update(label=f"Generating embeddings & adding {len(docs)} documents (this uses BioBERT and may take time)...")
        # Embeddings are generated implicitly by ChromaDB during .add().
        collection.add(documents=docs, metadatas=metadatas, ids=ids)
        status.update(label=f"✅ Added {len(docs)} dummy documents.", state="complete", expanded=False)
        logger.info(f"Added {len(docs)} dummy documents to collection '{COLLECTION_NAME}'.")
    except Exception as e:
        err_msg = f"Error adding dummy data to Chroma: {e}"
        status.update(label=f"❌ Error: {err_msg}", state="error", expanded=True)
        logger.error(err_msg, exc_info=True)
# --- Initialize Resources ---
# Order matters: the Chroma collection is built with the embedding function, so
# the embedding function must be initialized first. Each initializer reports its
# own errors in the UI and returns None on failure; the UI below checks for that.
gemini_model = initialize_gemini_model()
embedding_func = initialize_embedding_function()
collection = initialize_chroma_collection(embedding_func)
# --- Streamlit UI ---
# set_page_config() is already called at the top
st.title("⚕️ Medical Image Analysis & RAG (BioBERT Embeddings)")

# --- DISCLAIMER ---
# `icon` must be a single valid emoji; Streamlit raises StreamlitAPIException otherwise.
st.warning("""
**⚠️ Disclaimer:** This tool is for demonstration and informational purposes ONLY.
It is **NOT** a medical device and should **NOT** be used for actual medical diagnosis, treatment, or decision-making.
AI analysis can be imperfect. Always consult with qualified healthcare professionals for any medical concerns.
Do **NOT** upload identifiable patient data (PHI). Analysis quality depends heavily on the chosen embedding model.
""", icon="☣️")

st.markdown(f"""
Upload a medical image. Gemini Vision will analyze it. Related information
will be retrieved from a Chroma DB knowledge base using **{EMBEDDING_MODEL_NAME}** embeddings.
""")
# --- Sidebar ---
with st.sidebar:
    st.header("⚙️ Controls")
    uploaded_file = st.file_uploader(
        "Choose an image...",
        type=["jpg", "jpeg", "png", "tiff", "webp"],
        help="Upload a medical image file (e.g., pathology, diagram).",
    )
    st.divider()

    if st.button("➕ Add/Verify Dummy KB Data", help=f"Adds example text data to Chroma DB ({COLLECTION_NAME}) if it doesn't exist."):
        # Both resources are None when their initialization failed.
        if collection is not None and embedding_func is not None:
            add_dummy_data_to_chroma(collection, embedding_func)
        else:
            st.error("❌ Cannot add dummy data: Chroma Collection or Embedding Function failed to initialize.")
    st.divider()

    # Read-only summary of the active configuration.
    st.header("ℹ️ System Info")
    st.caption(f"**Gemini Model:** `{VISION_MODEL_NAME}`")
    st.caption(f"**Embedding Model:** `{EMBEDDING_MODEL_NAME}`")
    st.caption(f"**Chroma Collection:** `{COLLECTION_NAME}`")
    st.caption(f"**Chroma Path:** `{CHROMA_PATH}`")
    st.caption(f"**Distance Metric:** `{CHROMA_DISTANCE_METRIC}`")
    st.caption(f"**Google API Key:** {'Set' if GOOGLE_API_KEY else 'Not Set'}")
    st.caption(f"**HF Token:** {'Provided' if HF_TOKEN else 'Not Provided'}")
# --- Main Display Area ---
col1, col2 = st.columns(2)

with col1:
    st.subheader("🖼️ Uploaded Image")
    if uploaded_file is not None:
        # Module-level name: the analysis column below reuses these bytes.
        image_bytes = uploaded_file.getvalue()
        # NOTE(review): newer Streamlit releases deprecate use_column_width in favour of
        # use_container_width; switch once the Space's Streamlit version supports it.
        st.image(image_bytes, caption=f"Uploaded: {uploaded_file.name}", use_column_width=True)
    else:
        st.info("Upload an image using the sidebar to begin.")
with col2:
    st.subheader("🔬 Analysis & Retrieval")
    if uploaded_file is not None and gemini_model is not None and collection is not None:
        # 1. Analyze the image. The status box only tracks progress; results are
        #    rendered after it so they stay visible once the box collapses.
        with st.status("🧠 Analyzing image with Gemini Vision...", expanded=True) as status_gemini:
            analysis_text, analysis_error = analyze_image_with_gemini(gemini_model, image_bytes)
            if analysis_error:
                # Keep the status label short: only the prefix before the first colon.
                status_label = f"⚠️ Analysis Failed/Blocked: {analysis_text.split(':')[0]}"
                status_gemini.update(label=status_label, state="error")
            else:
                status_gemini.update(label="✅ Analysis Complete", state="complete", expanded=False)

        if analysis_error:
            st.error(f"**Analysis Output:** {analysis_text}", icon="🚨")
        else:
            st.markdown("**Gemini Vision Analysis:**")
            st.markdown(analysis_text)

        # 2. Query Chroma only if the analysis succeeded and produced text.
        if not analysis_error and analysis_text:
            st.markdown("---")
            st.subheader("📚 Related Information (RAG)")
            hits = []
            with st.status("🔍 Searching knowledge base (Chroma DB w/ BioBERT)...", expanded=True) as status_chroma:
                chroma_results = query_chroma(collection, analysis_text, n_results=3)
                if chroma_results is None:
                    # query_chroma has already shown the error via st.error.
                    status_chroma.update(label="❌ Failed to retrieve results.", state="error", expanded=True)
                else:
                    # Chroma returns one inner list per query text; exactly one was sent.
                    docs = (chroma_results.get('documents') or [[]])[0] or []
                    metas = (chroma_results.get('metadatas') or [[]])[0] or []
                    dists = (chroma_results.get('distances') or [[]])[0] or []
                    hits = list(zip(docs, metas, dists))
                    if hits:
                        status_chroma.update(label=f"✅ Found {len(hits)} related entries.", state="complete", expanded=False)
                    else:
                        # st.status only accepts "running", "complete" or "error".
                        status_chroma.update(label="⚠️ No relevant information found.", state="complete", expanded=False)

            if chroma_results is not None and not hits:
                st.warning("No relevant documents found in the knowledge base for this analysis.", icon="⚠️")

            # Rendered outside the status container: expanders may not be nested in it.
            for i, (doc, meta, dist) in enumerate(hits, start=1):
                meta = meta or {}  # Chroma yields None for documents stored without metadata
                # With cosine distance, similarity = 1 - distance.
                similarity = 1.0 - float(dist) if dist is not None else 0.0
                expander_title = f"Result {i} (Similarity: {similarity:.4f}) | Source: {meta.get('source', 'N/A')}"
                with st.expander(expander_title):
                    st.markdown("**Retrieved Text:**")
                    st.markdown(f"> {doc}")  # blockquote
                    st.markdown("**Metadata:**")
                    for key, value in meta.items():
                        st.markdown(f"- **{key.replace('_', ' ').title()}:** `{value}`")
                    if meta.get("IMAGE_ID"):
                        st.info(f"ℹ️ Associated visual asset ID: `{meta['IMAGE_ID']}`")
    elif uploaded_file is None:
        st.info("Analysis results will appear here once an image is uploaded.")
    else:
        # An image was uploaded, so a required resource failed to initialize earlier.
        st.error("❌ Analysis cannot proceed. Check if Gemini model or Chroma DB failed to initialize (see sidebar info & Space logs).")
# --- Footer ---
st.markdown("---")
st.markdown(
    "<div style='text-align: center; font-size: small;'>"
    "Powered by Google Gemini, Chroma DB, Hugging Face, and Streamlit"
    "</div>",
    unsafe_allow_html=True,  # static, trusted HTML used only to centre the credit line
)