# app.py -- Medical image analysis (Gemini Vision) with RAG retrieval (Chroma DB),
# served as a Streamlit app on Hugging Face Spaces.
import io
import json
import os
import time  # To create unique IDs for Chroma

import chromadb
import google.generativeai as genai
import streamlit as st
from chromadb.utils import embedding_functions
from PIL import Image
# --- Configuration ---
# Read the Google API key from Streamlit/HF Spaces secrets and configure the
# SDK. The app cannot function without it, so any failure halts rendering.
try:
    # Try loading secrets from Hugging Face secrets first
    GOOGLE_API_KEY = st.secrets["GOOGLE_API_KEY"]
    genai.configure(api_key=GOOGLE_API_KEY)
except KeyError:
    # st.secrets raises KeyError when the secret is not defined at all.
    st.error("GOOGLE_API_KEY not found in Hugging Face secrets!")
    st.stop()
except Exception as e:
    # Any other configuration failure (e.g. malformed key): surface and halt.
    st.error(f"Error configuring Google AI: {e}")
    st.stop()
# --- Gemini Model Setup ---
# NOTE: "gemini-pro-vision" was retired by Google (requests to it now return
# 404). "gemini-1.5-flash" is the documented multimodal replacement and
# accepts the same [prompt, image] content list used below.
# To inspect the models available to your key at runtime:
#   for m in genai.list_models():
#       if 'generateContent' in m.supported_generation_methods:
#           print(m.name)
VISION_MODEL_NAME = "gemini-1.5-flash"

GENERATION_CONFIG = {
    "temperature": 0.2,  # Lower temp for more factual descriptions
    "top_p": 0.95,
    "top_k": 40,
    "max_output_tokens": 1024,
}

SAFETY_SETTINGS = [  # Adjust safety settings as needed for medical content
    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
]

try:
    # One shared model handle for the whole app; analysis calls reuse it.
    gemini_model = genai.GenerativeModel(
        model_name=VISION_MODEL_NAME,
        generation_config=GENERATION_CONFIG,
        safety_settings=SAFETY_SETTINGS,
    )
except Exception as e:
    st.error(f"Error initializing Gemini Model ({VISION_MODEL_NAME}): {e}")
    st.stop()
# --- Chroma DB Setup ---
# Using persistent storage within the HF Space (data lost if space is wiped)
# For production, consider a hosted Chroma or other DB solution.
CHROMA_PATH = "chroma_data"
COLLECTION_NAME = "medical_docs"

# Use a default sentence transformer embedding function (runs locally on HF space CPU)
# For better domain adaptation, consider finetuned medical embeddings if possible/available.
# Make sure the model used here matches the one used when INGESTING data.
embedding_func = embedding_functions.DefaultEmbeddingFunction()

try:
    chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)
    # Get or create the collection with the specified embedding function
    collection = chroma_client.get_or_create_collection(
        name=COLLECTION_NAME,
        embedding_function=embedding_func,
        metadata={"hnsw:space": "cosine"}  # Use cosine distance
    )
except Exception as e:
    st.error(f"Error initializing Chroma DB at '{CHROMA_PATH}': {e}")
    st.info("If this is the first run, the directory will be created.")
    # Attempt creation again more robustly if needed, or guide user.
    st.stop()
# --- Helper Functions ---
def analyze_image_with_gemini(image_bytes):
    """Send raw image bytes to Gemini Vision and return its text description.

    Never raises to the caller: safety blocks and hard failures are reported
    via st.error and returned as sentinel strings ("Analysis blocked..." /
    "Error...") that the UI layer checks by prefix.
    """
    try:
        image = Image.open(io.BytesIO(image_bytes))
        prompt = """Analyze this medical image (could be a pathology slide, diagram, or other medical visual).
        Describe the key visual features relevant to a medical professional.
        Identify potential:
        - Diseases or conditions suggested
        - Pathological findings (e.g., cellular morphology, tissue structure, staining patterns)
        - Cell types visible
        - Relevant biomarkers (if inferrable from staining or morphology)
        - Anatomical context (if clear)
        Be concise and focus on visually evident information.
        """
        response = gemini_model.generate_content([prompt, image])

        # A non-empty parts list means the model produced usable content.
        if response.parts:
            return response.text

        # Empty parts: either the response was safety-blocked, or the model
        # returned nothing at all.
        feedback = response.prompt_feedback
        if feedback and feedback.block_reason:
            return f"Analysis blocked: {feedback.block_reason}"
        return "Error: Gemini analysis failed or returned no content."
    except genai.types.BlockedPromptException as e:
        st.error(f"Gemini request blocked: {e}")
        return f"Analysis blocked due to safety settings: {e}"
    except Exception as e:
        st.error(f"Error during Gemini analysis: {e}")
        return f"Error analyzing image: {e}"
def query_chroma(query_text, n_results=5):
    """Run a text query against the Chroma collection.

    Returns Chroma's result dict (documents/metadatas/distances, each as a
    per-query list), or None if the query raised; errors go to st.error.
    """
    try:
        # Distances are included so callers can rank/report relevance.
        return collection.query(
            query_texts=[query_text],
            n_results=n_results,
            include=['documents', 'metadatas', 'distances'],
        )
    except Exception as e:
        st.error(f"Error querying Chroma DB: {e}")
        return None
def add_dummy_data_to_chroma():
    """Adds some example medical text snippets to Chroma (idempotently).

    Fixes vs. the original version:
    - The old duplicate check used ``collection.get(where={"$or": [{"document": ...}]})``.
      Chroma's ``where`` clause filters *metadata* fields only (document text
      filtering needs ``where_document``), and no metadata field named
      "document" exists, so the filter never matched and every button click
      re-added all five docs under fresh time-based IDs. Deterministic IDs
      plus ``get(ids=...)`` make the insert idempotent.
    - Chroma metadata values must be scalars (str/int/float/bool); the nested
      "entities" dicts made ``collection.add`` raise. They are now stored as
      JSON strings (decode with ``json.loads`` when reading them back).

    In a real scenario, this data would come from processing actual medical
    documents (papers, reports) to extract text and metadata, including
    IMAGE_IDs. The embeddings generated here MUST match the query embedding
    function.
    """
    st.info("Adding dummy data to Chroma DB...")
    docs = [
        "Figure 1A shows adenocarcinoma of the lung, papillary subtype. Note the glandular structures and nuclear atypia. TTF-1 staining was positive.",
        "Pathology slide 34B demonstrates high-grade glioma (glioblastoma) with significant necrosis and microvascular proliferation. Ki-67 index was high.",
        "This diagram illustrates the EGFR signaling pathway and common mutation sites targeted by tyrosine kinase inhibitors in non-small cell lung cancer.",
        "Micrograph showing chronic gastritis with Helicobacter pylori organisms (visible with special stain, not shown here). Mild intestinal metaplasia is present.",
        "Slide CJD-Sample-02: Spongiform changes characteristic of prion disease are evident in the cerebral cortex. Gliosis is also noted."
    ]
    # Entity dicts are JSON-encoded because Chroma only accepts scalar
    # metadata values.
    metadatas = [
        {"source": "Example Paper 1",
         "entities": json.dumps({"DISEASES": ["adenocarcinoma", "lung cancer"], "PATHOLOGY_FINDINGS": ["glandular structures", "nuclear atypia", "papillary subtype"], "BIOMARKERS": ["TTF-1"]}),
         "IMAGE_ID": "fig_1a_adeno_lung.png"},
        {"source": "Path Report 789",
         "entities": json.dumps({"DISEASES": ["high-grade glioma", "glioblastoma"], "PATHOLOGY_FINDINGS": ["necrosis", "microvascular proliferation"], "BIOMARKERS": ["Ki-67"]}),
         "IMAGE_ID": "slide_34b_gbm.tiff"},
        {"source": "Textbook Chapter 5",
         "entities": json.dumps({"GENES": ["EGFR"], "DRUGS": ["tyrosine kinase inhibitors"], "DISEASES": ["non-small cell lung cancer"]}),
         "IMAGE_ID": "diagram_egfr_pathway.svg"},
        {"source": "Path Report 101",
         "entities": json.dumps({"DISEASES": ["chronic gastritis", "Helicobacter pylori infection"], "PATHOLOGY_FINDINGS": ["intestinal metaplasia"]}),
         "IMAGE_ID": "micrograph_h_pylori_gastritis.jpg"},
        {"source": "Case Study CJD",
         "entities": json.dumps({"DISEASES": ["prion disease"], "PATHOLOGY_FINDINGS": ["Spongiform changes", "Gliosis"], "ANATOMICAL_LOCATIONS": ["cerebral cortex"]}),
         "IMAGE_ID": "slide_cjd_sample_02.jpg"}
    ]
    # Deterministic IDs so reruns can detect existing rows instead of
    # re-adding the same documents under new time-based IDs.
    ids = [f"dummy_doc_{i}" for i in range(len(docs))]
    try:
        existing = collection.get(ids=ids)
        if existing and existing.get("ids"):
            st.warning("Dummy data seems to already exist in the collection.")
            return
        collection.add(
            documents=docs,
            metadatas=metadatas,
            ids=ids
        )
        st.success(f"Added {len(docs)} dummy documents to Chroma collection '{COLLECTION_NAME}'.")
    except Exception as e:
        st.error(f"Error adding dummy data to Chroma: {e}")
# --- Streamlit UI ---
# Page shell: wide layout, title, and a one-line usage hint.
# NOTE(review): the emoji in the strings below appear mojibake-encoded in this
# copy of the file (e.g. "โš•๏ธ"); they are reproduced verbatim -- confirm the
# file's encoding before changing them.
st.set_page_config(layout="wide")
st.title("โš•๏ธ Medical Image Analysis & RAG")
st.markdown("Upload a medical image (pathology slide, diagram, etc.). Gemini Vision will analyze it, and Chroma DB will retrieve related information from a knowledge base.")

# Sidebar for controls
with st.sidebar:
    st.header("Controls")
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "tiff", "webp"])
    if st.button("Load Dummy KB Data"):
        add_dummy_data_to_chroma()
    st.info("Note: Chroma data persists in the Space's storage but is lost if the Space is reset/deleted.")

# Main area for display
if uploaded_file is not None:
    # Read image bytes
    image_bytes = uploaded_file.getvalue()
    # Display the uploaded image
    st.image(image_bytes, caption=f"Uploaded Image: {uploaded_file.name}", use_column_width=False, width=400)
    st.markdown("---")
    st.subheader("๐Ÿ”ฌ Gemini Vision Analysis")
    # Analyze image with Gemini
    with st.spinner("Analyzing image with Gemini Vision..."):
        analysis_text = analyze_image_with_gemini(image_bytes)
        # analyze_image_with_gemini signals failure via sentinel string
        # prefixes rather than exceptions, so prefix checks drive the UI.
        if analysis_text.startswith("Error:") or analysis_text.startswith("Analysis blocked:"):
            st.error(analysis_text)
        else:
            st.markdown(analysis_text)
    st.markdown("---")
    st.subheader("๐Ÿ“š Related Information from Knowledge Base (Chroma DB)")
    # Query Chroma DB using the Gemini analysis text
    with st.spinner("Querying Chroma DB..."):
        chroma_results = query_chroma(analysis_text)
        # Chroma returns per-query lists; index [0] holds results for the
        # single query issued above.
        if chroma_results and chroma_results.get('documents') and chroma_results['documents'][0]:
            st.success(f"Found {len(chroma_results['documents'][0])} related entries:")
            for i in range(len(chroma_results['documents'][0])):
                doc = chroma_results['documents'][0][i]
                meta = chroma_results['metadatas'][0][i]
                dist = chroma_results['distances'][0][i]
                with st.expander(f"Result {i+1} (Distance: {dist:.4f}) - Source: {meta.get('source', 'N/A')}"):
                    st.markdown("**Text:**")
                    st.markdown(doc)
                    st.markdown("**Metadata:**")
                    st.json(meta)  # Display all metadata nicely
                    # Highlight if it references another image
                    if meta.get("IMAGE_ID"):
                        st.info(f"โ„น๏ธ This text describes another visual asset: `{meta['IMAGE_ID']}`")
                        # In a real app, you might fetch/display this image if available
        elif chroma_results is not None:  # Query ran but found nothing
            st.warning("No relevant information found in the knowledge base for this analysis.")
        else:  # Error occurred during query
            st.error("Failed to retrieve results from Chroma DB.")
else:
    st.info("Upload an image using the sidebar to start the analysis.")

st.markdown("---")
st.markdown("Powered by Google Gemini, Chroma DB, and Streamlit.")