# hackathon_acvss / system_test_page.py
# ImedHa's picture
# Upload 7 files
# ee412eb verified
import streamlit as st
from PIL import Image
import torch
import numpy as np
import os
from io import StringIO
import sys
import torch.nn as nn
# --- TorchDynamo Fix for Unsloth/MedGemma ---
import torch._dynamo
# Set once at import time, before any model is loaded.
# NOTE(review): presumably this lets Dynamo trace scalar reads instead of
# graph-breaking on them — confirm against the torch._dynamo.config docs.
torch._dynamo.config.capture_scalar_outputs = True
# --- DEFINITIVE FIX FOR JIT COMPILER ERRORS ---
# NOTE(review): the return value of this call is discarded. torch.compiler.disable
# is documented as a decorator for a function; called bare like this it may not
# switch compilation off globally — confirm, and use an explicit global switch if
# that is the intent.
torch.compiler.disable()
# --- Dependency Handling ---
# Third-party packages the page needs. If any of them fails to import, the
# ImportError is shown in the page via st.error and st.stop() is called.
try:
    from monai.networks.nets import SwinUNETR
    import torchvision.transforms as T
    from unsloth import FastVisionModel
    from transformers import TextStreamer
    from s2wrapper import forward as multiscale_forward
except ImportError as e:
    st.error(f"A required library is not installed. Please install dependencies. Error: {e}")
    st.stop()
# --- Config and Model Definition ---
class Config:
    """Static settings shared by the segmentation model and its preprocessing."""

    # Label values: every multiple of 3 from 0 through 60 (21 entries).
    # NOTE(review): presumably the values found in the annotation masks — confirm.
    ORIGINAL_LABELS = list(range(0, 61, 3))
    # Label value -> its position in ORIGINAL_LABELS (a contiguous class index).
    LABEL_MAP = dict(zip(ORIGINAL_LABELS, range(len(ORIGINAL_LABELS))))
    NUM_CLASSES = len(ORIGINAL_LABELS)
    # Target size handed to the resize transform in preprocessing.
    IMG_SIZE = (256, 256)
    # Passed to SwinUNETR as ``feature_size``.
    FEATURE_SIZE = 48
    # Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class multiscaleSwinUNETR(nn.Module):
    """2-D SwinUNETR evaluated through ``multiscale_forward`` with a small fusion head.

    The backbone emits ``num_classes`` channels. The fusion head expects
    ``len(scales) * num_classes`` input channels and maps them back to
    ``num_classes``.

    Args:
        num_classes: Number of output channels (segmentation classes).
        scales: Scale factors forwarded to ``multiscale_forward``. Defaults to
            ``[1]`` when omitted. The sequence is copied, so later changes to
            the caller's list do not affect the module.
    """

    def __init__(self, num_classes, scales=None):
        super().__init__()
        # A ``None`` sentinel avoids sharing one mutable default list between
        # all instances; the copy avoids aliasing a caller-owned list.
        self.scales = [1] if scales is None else list(scales)
        self.num_classes = num_classes
        self.model = SwinUNETR(
            spatial_dims=2,
            in_channels=3,
            out_channels=num_classes,
            feature_size=Config.FEATURE_SIZE,
            drop_rate=0.0,
            attn_drop_rate=0.0,
            dropout_path_rate=0.0,
            use_checkpoint=True,
            use_v2=True,
        )
        self.segmentation_head = nn.Sequential(
            nn.Conv2d(len(self.scales) * num_classes, num_classes, 3, padding=1),
            nn.BatchNorm2d(num_classes),
            nn.ReLU(inplace=True),
            nn.Conv2d(num_classes, num_classes, 1),
        )

    @staticmethod
    def _scale_by_std(feature_map):
        """Divide a 4-D map by its per-channel std over dims 2 and 3 (plus 1e-6)."""
        return feature_map / (feature_map.std(dim=(2, 3), keepdim=True) + 1e-6)

    def forward(self, x):
        """Return per-pixel class logits for the batch ``x``.

        With a single scale and a 4-D tensor coming back from
        ``multiscale_forward``, that tensor is returned as-is and the fusion
        head is not applied. Otherwise the output is std-normalised (each
        element separately when a list/tuple is returned, then concatenated
        on dim 1) and passed through ``segmentation_head``.

        Raises:
            ValueError: If ``multiscale_forward`` returns neither a list/tuple
                nor a 4-D tensor.
        """
        outs = multiscale_forward(self.model, x, scales=self.scales, output_shape="bchw")
        if isinstance(outs, (list, tuple)):
            feats = torch.cat([self._scale_by_std(f) for f in outs], dim=1)
        elif isinstance(outs, torch.Tensor) and outs.dim() == 4:
            if len(self.scales) == 1:
                # Single-scale shortcut: the backbone output already has
                # ``num_classes`` channels, so it is used directly.
                return outs
            feats = self._scale_by_std(outs)
        else:
            raise ValueError(f"Unexpected output shape/type from multiscale_forward: {type(outs)}, {getattr(outs,'shape',None)}")
        return self.segmentation_head(feats)
# --- Model Loading ---
@st.cache_resource
def load_swinunetr_model():
    """Build the segmentation network and restore its weights from disk.

    Returns:
        ``(model, Config)`` with the model switched to eval mode, or
        ``(None, None)`` when the weights file is missing or loading raises.
        In both failure cases the reason is reported with ``st.error``.
    """
    # NOTE(review): the function is wrapped in st.cache_resource, so a
    # (None, None) failure result is presumably cached as well — confirm
    # whether a retry after fixing the problem needs the cache to be cleared.
    weights_file = 's2-swinunetr-weights.pth'
    failure = (None, None)

    if not os.path.exists(weights_file):
        st.error(f"Segmentation model file not found at {weights_file}")
        return failure

    try:
        net = multiscaleSwinUNETR(num_classes=Config.NUM_CLASSES, scales=[1])
        # NOTE(review): torch.load unpickles the file; only load trusted weights.
        state_dict = torch.load(weights_file, map_location=Config.DEVICE)
        net.load_state_dict(state_dict)
        net.eval()
    except Exception as e:
        st.error(f"Error loading segmentation model: {e}")
        return failure

    return net, Config
@st.cache_resource
def load_medgemma_model():
    """Load the MedGemma checkpoint and its processor via ``FastVisionModel``.

    Returns:
        The ``(model, processor)`` pair produced by
        ``FastVisionModel.from_pretrained``, or ``(None, None)`` after the
        error has been reported with ``st.error``.
    """
    checkpoint = "fiqqy/MedGemma-MM-OR-FT10"
    load_options = {
        "load_in_4bit": False,
        "use_gradient_checkpointing": "unsloth",
    }
    try:
        vlm, processor = FastVisionModel.from_pretrained(checkpoint, **load_options)
    except Exception as e:
        st.error(f"Error loading MedGemma model: {e}")
        return None, None
    return vlm, processor
# --- Preprocessing ---
def preprocess_frames(frames, config):
    """Convert image frames into a single normalised batch tensor.

    Args:
        frames: Iterable of images exposing ``convert("RGB")`` (PIL images in
            this file).
        config: Object whose ``IMG_SIZE`` attribute is the resize target.

    Returns:
        The per-frame tensors stacked along a new leading dimension.
    """
    # NOTE(review): presumably the usual ImageNet channel statistics — confirm
    # they match what the segmentation model was trained with.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        T.Resize(config.IMG_SIZE, antialias=True),
        T.ToTensor(),
        T.Normalize(mean=channel_mean, std=channel_std),
    ]
    pipeline = T.Compose(steps)

    prepared = []
    for frame in frames:
        # Force three channels so grayscale/RGBA uploads fit the 3-channel model.
        prepared.append(pipeline(frame.convert("RGB")))
    return torch.stack(prepared)
# --- Color Palette for Mask Visualization ---
def make_palette(num_classes):
    """Build a reproducible RGB lookup table with one row per class.

    The generator is seeded with 0, so the same ``num_classes`` always yields
    the same colours. Row 0 is overwritten with black.

    Args:
        num_classes: Number of rows in the table.

    Returns:
        ``uint8`` array of shape ``(num_classes, 3)``.
    """
    generator = np.random.default_rng(0)
    palette = generator.integers(0, 255, size=(num_classes, 3), dtype=np.uint8)
    # Class 0 is always drawn in black.
    palette[0] = 0
    return palette
# --- Inference ---
def run_segmentation(model, config, frames):
    """Segment the frames and return the first frame's mask as a colour image.

    All frames are pushed through the model as one batch, but only the
    prediction for the first frame is coloured and returned.

    Args:
        model: Segmentation network; it is moved to ``config.DEVICE``.
        config: Object providing ``DEVICE``, ``NUM_CLASSES`` and whatever
            ``preprocess_frames`` reads from it.
        frames: Images accepted by ``preprocess_frames``.

    Returns:
        A PIL image in which every pixel carries its predicted class colour.
    """
    st.write("Running segmentation...")

    target = config.DEVICE
    inputs = preprocess_frames(frames, config).to(target)
    net = model.to(target)

    with torch.no_grad():
        class_maps = net(inputs).argmax(dim=1)

    first_mask = class_maps.cpu().numpy()[0]
    st.write(f"Mask unique values: {np.unique(first_mask)}")

    # Fancy-index the palette with the class map to get an (H, W, 3) image.
    lookup = make_palette(config.NUM_CLASSES)
    rgb = lookup[first_mask].astype(np.uint8)
    return Image.fromarray(rgb)
# --- MedGemma Captioning ---
def run_captioning(medgemma_model, processor, frames, mask_img, instruction):
    """Run MedGemma on the frames, the segmentation mask and a text instruction.

    The frames and the mask are converted to RGB and sent as one user message
    containing one image placeholder per image, followed by the instruction
    text. Generation is streamed through a ``TextStreamer``; its output is
    collected by temporarily replacing ``sys.stdout``.

    Args:
        medgemma_model: Model exposing ``generate``.
        processor: Processor exposing ``apply_chat_template`` and callable on
            ``(images, text, ...)``.
        frames: Frame images (the UI passes three).
        mask_img: Segmentation mask image, appended after the frames.
        instruction: Prompt text placed after the image placeholders.

    Returns:
        The text captured from stdout while ``generate`` was running.
    """
    st.write("Preparing inputs for MedGemma...")
    all_images = [f.convert("RGB") for f in frames] + [mask_img.convert("RGB")]

    # One placeholder per image actually passed to the processor, so the chat
    # template and the image list cannot drift apart (4 for three frames + mask).
    content = [{"type": "image"} for _ in all_images]
    content.append({"type": "text", "text": instruction})
    messages = [{"role": "user", "content": content}]

    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    inputs = processor(
        all_images, input_text, add_special_tokens=False, return_tensors="pt",
    ).to(device)
    text_streamer = TextStreamer(processor, skip_prompt=True)

    st.write("Running MedGemma Analysis...")
    captured_output = StringIO()
    old_stdout = sys.stdout
    sys.stdout = captured_output
    try:
        # NOTE(review): the return value is discarded; called like this,
        # torch._dynamo.disable() may have no effect — confirm the intent.
        torch._dynamo.disable()
        medgemma_model.generate(
            **inputs, streamer=text_streamer, max_new_tokens=768,
            use_cache=True, temperature=1.0, top_p=0.95, top_k=64
        )
    finally:
        # Always restore stdout, even if generation raises; otherwise every
        # later print in the process would go into the discarded buffer.
        sys.stdout = old_stdout
    return captured_output.getvalue()
# --- Streamlit UI ---
def show():
    """Render the page: load models, upload frames and segment, then caption.

    Loaded models, the processor and the latest mask are kept in
    ``st.session_state`` so they are still available on later reruns.
    """
    state = st.session_state

    st.title("Surgical Scene Analysis System")
    st.write("A system to test surgical scene segmentation and captioning models.")

    # ---- Step 1: model loading -------------------------------------------
    st.header("1. Load Models")

    if not ("seg_model" in state and "seg_config" in state):
        state.seg_model = None
        state.seg_config = None
    if st.button("Load Segmentation Model"):
        with st.spinner("Loading SwinUNETR..."):
            state.seg_model, state.seg_config = load_swinunetr_model()
    if state.seg_model is None:
        st.warning("Segmentation model is not loaded.")
    else:
        st.success("Segmentation model is loaded.")

    if "medgemma_model" not in state:
        state.medgemma_model = None
        state.processor = None
    if st.button("Load MedGemma Model"):
        with st.spinner("Loading MedGemma... This can take several minutes."):
            state.medgemma_model, state.processor = load_medgemma_model()
    if state.get("medgemma_model") and state.get("processor"):
        st.success("MedGemma model is loaded.")
    else:
        st.warning("MedGemma model is not loaded.")

    # ---- Step 2: frame upload and segmentation ---------------------------
    st.header("2. Upload Data & Generate Mask")
    st.subheader("Upload Three Sequential Surgical Video Frames")

    # One uploader per column, labelled and keyed 1..3.
    uploaded_files = [
        column.file_uploader(
            f"Upload Frame {number}", type=["png", "jpg", "jpeg"], key=f"frame{number}"
        )
        for number, column in enumerate(st.columns(3), start=1)
    ]
    frames = [Image.open(upload) for upload in uploaded_files if upload is not None]
    display_size = (256, 256)

    if "mask_img" not in state:
        state.mask_img = None

    if len(frames) != 3:
        st.info("Please upload all three frames to proceed.")
    else:
        st.success("All three frames have been uploaded successfully.")
        # Three slots for the frames, the fourth for the mask.
        img_cols = st.columns(4)
        for slot, (number, frame) in zip(img_cols, enumerate(frames, start=1)):
            slot.image(frame.resize(display_size), caption=f"Frame {number}", use_container_width=True)

        # The button is only offered once the segmentation model is available.
        if state.seg_model and state.seg_config:
            if st.button("Run Segmentation"):
                with st.spinner("Generating segmentation mask..."):
                    state.mask_img = run_segmentation(state.seg_model, state.seg_config, frames)

        if state.mask_img is not None:
            img_cols[3].image(state.mask_img.resize(display_size), caption="Segmentation Mask", use_container_width=True)

    # ---- Step 3: captioning ----------------------------------------------
    st.header("3. Generate Scene Analysis")
    instruction_prompt = st.text_area(
        "Enter your custom instruction prompt:",
        "Provide a detailed summary of the surgical action, noting the instruments used and their interactions."
    )

    prerequisites = (
        state.get("medgemma_model") is not None,
        len(frames) == 3,
        state.get("mask_img") is not None,
        bool(instruction_prompt),
    )
    can_run_analysis = all(prerequisites)

    if st.button("Run Analysis", disabled=not can_run_analysis):
        with st.spinner("Running MedGemma analysis... This may take a moment."):
            result = run_captioning(
                state.medgemma_model, state.processor,
                frames, state.mask_img, instruction_prompt
            )
        st.subheader("Analysis Result")
        st.write(result)

    if not can_run_analysis:
        st.warning("Please ensure the MedGemma model is loaded, three frames are uploaded, segmentation is complete, and a prompt is provided.")


if __name__ == "__main__":
    show()