Spaces:
Sleeping
Sleeping
import streamlit as st | |
from PIL import Image | |
import torch | |
import numpy as np | |
import os | |
from io import StringIO | |
import sys | |
import torch.nn as nn | |
# --- TorchDynamo configuration for Unsloth/MedGemma ---
import torch._dynamo

# Let TorchDynamo capture ops that return Python scalars (e.g. `.item()`)
# inside traced graphs instead of introducing a graph break on them.
torch._dynamo.config.capture_scalar_outputs = True

# NOTE: the former bare `torch.compiler.disable()` call was removed. Called
# without a function argument it only constructs a decorator/context object
# and discards it, so it never disabled compilation. To turn TorchDynamo off
# for the whole process, launch with the environment variable
# TORCH_COMPILE_DISABLE=1; to skip it for one callable, wrap that callable
# with `torch.compiler.disable(fn)`.
# --- Dependency Handling ---
# Import the model libraries up front so that a missing package is reported
# in the UI (and the script halted) instead of surfacing as a raw traceback.
# Keep `unsloth` ahead of `transformers`: Unsloth asks to be imported first
# so that its patches to transformers take effect.
try:
    from monai.networks.nets import SwinUNETR
    import torchvision.transforms as T
    from unsloth import FastVisionModel
    from transformers import TextStreamer
    from s2wrapper import forward as multiscale_forward
except ImportError as missing_dependency:
    st.error(
        "A required library is not installed. Please install dependencies. "
        f"Error: {missing_dependency}"
    )
    st.stop()
# --- Config and Model Definition ---
class Config:
    """Static settings shared by the segmentation pipeline."""

    # Raw label values before remapping: 0, 3, 6, ..., 60 (21 values).
    ORIGINAL_LABELS = list(range(0, 61, 3))
    # Raw label value -> contiguous class index 0..NUM_CLASSES-1.
    LABEL_MAP = dict(zip(ORIGINAL_LABELS, range(len(ORIGINAL_LABELS))))
    NUM_CLASSES = len(ORIGINAL_LABELS)
    # Target size handed to T.Resize before segmentation.
    IMG_SIZE = (256, 256)
    # Passed to SwinUNETR as its `feature_size`.
    FEATURE_SIZE = 48
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class multiscaleSwinUNETR(nn.Module):
    """2-D SwinUNETR evaluated through the S2 multi-scale wrapper.

    If `multiscale_forward` returns a single 4-D tensor and only one scale is
    configured, that tensor is returned unchanged. Otherwise the per-scale
    outputs are concatenated along the channel axis, each channel is divided
    by its spatial standard deviation, and `segmentation_head` fuses them back
    to `num_classes` channels.

    Args:
        num_classes: number of output segmentation classes.
        scales: scale factors passed to `multiscale_forward`; defaults to [1].
    """

    def __init__(self, num_classes, scales=None):
        super().__init__()
        # `None` replaces the shared mutable default `[1]`; copying the list
        # keeps later mutations of the caller's list from altering the model.
        self.scales = [1] if scales is None else list(scales)
        self.num_classes = num_classes
        self.model = SwinUNETR(
            spatial_dims=2,
            in_channels=3,
            out_channels=num_classes,
            feature_size=Config.FEATURE_SIZE,
            drop_rate=0.0,
            attn_drop_rate=0.0,
            dropout_path_rate=0.0,
            use_checkpoint=True,
            use_v2=True,
        )
        # Fuses the channel-concatenated per-scale outputs to num_classes maps.
        self.segmentation_head = nn.Sequential(
            nn.Conv2d(len(self.scales) * num_classes, num_classes, 3, padding=1),
            nn.BatchNorm2d(num_classes),
            nn.ReLU(inplace=True),
            nn.Conv2d(num_classes, num_classes, 1),
        )

    @staticmethod
    def _unit_std(t):
        """Divide each channel by its spatial std over dims (2, 3), eps-guarded."""
        return t / (t.std(dim=(2, 3), keepdim=True) + 1e-6)

    def forward(self, x):
        """Return per-class maps for input batch `x` (B, C, H, W layout)."""
        outs = multiscale_forward(self.model, x, scales=self.scales, output_shape="bchw")
        if isinstance(outs, (list, tuple)):
            feats = torch.cat([self._unit_std(f) for f in outs], dim=1)
        elif isinstance(outs, torch.Tensor) and outs.dim() == 4:
            if len(self.scales) == 1:
                # Single scale: the SwinUNETR output is used directly.
                return outs
            feats = self._unit_std(outs)
        else:
            raise ValueError(f"Unexpected output shape/type from multiscale_forward: {type(outs)}, {getattr(outs,'shape',None)}")
        return self.segmentation_head(feats)
# --- Model Loading ---
def load_swinunetr_model(model_path='s2-swinunetr-weights.pth'):
    """Load the multiscale SwinUNETR segmentation model from a state-dict file.

    Args:
        model_path: path to the saved state dict; the default is resolved
            relative to the current working directory.

    Returns:
        ``(model, Config)`` with the model switched to eval mode, or
        ``(None, None)`` after showing a Streamlit error when the file is
        missing or loading fails.
    """
    if not os.path.exists(model_path):
        st.error(f"Segmentation model file not found at {model_path}")
        return None, None
    try:
        model = multiscaleSwinUNETR(num_classes=Config.NUM_CLASSES, scales=[1])
        # weights_only=True restricts unpickling to tensors and plain
        # containers, so a tampered checkpoint cannot run code on load.
        state_dict = torch.load(model_path, map_location=Config.DEVICE, weights_only=True)
        model.load_state_dict(state_dict)
        model.eval()
        return model, Config
    except Exception as e:
        # UI boundary: surface any failure to the user instead of crashing the page.
        st.error(f"Error loading segmentation model: {e}")
        return None, None
def load_medgemma_model():
    """Load the fine-tuned MedGemma vision-language model through Unsloth.

    The checkpoint "fiqqy/MedGemma-MM-OR-FT10" is loaded without 4-bit
    quantization.

    Returns:
        ``(model, processor)`` on success, or ``(None, None)`` after showing
        a Streamlit error if loading fails for any reason.
    """
    load_options = dict(load_in_4bit=False, use_gradient_checkpointing="unsloth")
    try:
        vlm, vlm_processor = FastVisionModel.from_pretrained(
            "fiqqy/MedGemma-MM-OR-FT10", **load_options
        )
    except Exception as exc:
        st.error(f"Error loading MedGemma model: {exc}")
        return None, None
    return vlm, vlm_processor
# --- Preprocessing ---
def preprocess_frames(frames, config):
    """Turn PIL frames into one normalized batch tensor for segmentation.

    Each frame is converted to RGB, resized to ``config.IMG_SIZE``, converted
    to a float tensor, and normalized with the standard ImageNet mean/std.

    Returns:
        A tensor stacked along a new leading dimension, one entry per frame.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    pipeline = T.Compose(
        [
            T.Resize(config.IMG_SIZE, antialias=True),
            T.ToTensor(),
            T.Normalize(mean=imagenet_mean, std=imagenet_std),
        ]
    )
    return torch.stack([pipeline(img.convert("RGB")) for img in frames])
# --- Color Palette for Mask Visualization ---
def make_palette(num_classes):
    """Return a deterministic ``(num_classes, 3)`` uint8 RGB palette.

    Colors come from a fixed-seed generator, so they are stable across runs.
    Each channel is drawn from [0, 255), and row 0 (class 0) is forced to
    black.
    """
    generator = np.random.default_rng(seed=0)
    palette = generator.integers(low=0, high=255, size=(num_classes, 3), dtype=np.uint8)
    palette[0] = 0
    return palette
# --- Inference ---
def run_segmentation(model, config, frames):
    """Segment the first uploaded frame and render the mask with a color palette.

    Only ``frames[0]`` is visualized, so only that frame is sent through the
    model.

    Args:
        model: segmentation network whose output is argmax-ed over dim 1.
        config: settings object providing ``DEVICE``, ``IMG_SIZE`` and
            ``NUM_CLASSES``.
        frames: non-empty list of PIL images.

    Returns:
        A PIL RGB image of the first frame's predicted class map.
    """
    st.write("Running segmentation...")
    device = config.DEVICE
    # Previously all frames were segmented and every prediction except [0]
    # was discarded. Assuming the model is in eval mode (no cross-sample
    # batch statistics), segmenting frames[:1] gives the same mask for a
    # third of the compute.
    batch = preprocess_frames(frames[:1], config).to(device)
    model = model.to(device)
    with torch.inference_mode():
        logits = model(batch)
    mask = torch.argmax(logits, dim=1)[0].cpu().numpy()
    st.write(f"Mask unique values: {np.unique(mask)}")
    palette = make_palette(config.NUM_CLASSES)
    # Index the (num_classes, 3) palette with the 2-D class map -> (H, W, 3) RGB.
    color_mask = palette[mask]
    return Image.fromarray(color_mask.astype(np.uint8))
# --- MedGemma Captioning ---
class _CollectingStreamer(TextStreamer):
    """TextStreamer that records the decoded text instead of printing it.

    ``TextStreamer.on_finalized_text`` prints each chunk with ``end=""`` and
    ends the stream with a newline. This override stores exactly the same
    characters, so ``text`` equals what would have been printed.
    """

    def __init__(self, tokenizer, **kwargs):
        super().__init__(tokenizer, **kwargs)
        self._chunks = []

    def on_finalized_text(self, text, stream_end=False):
        self._chunks.append(text)
        if stream_end:
            self._chunks.append("\n")

    @property
    def text(self):
        """All text streamed so far, joined into one string."""
        return "".join(self._chunks)


def run_captioning(medgemma_model, processor, frames, mask_img, instruction):
    """Run MedGemma on the uploaded frames plus the segmentation mask.

    Args:
        medgemma_model: vision-language model exposing ``generate``.
        processor: matching processor (chat template, image/text encoding,
            decoding for the streamer).
        frames: PIL images of the sequential video frames.
        mask_img: PIL segmentation mask, passed after the frames.
        instruction: user prompt text.

    Returns:
        The generated text, with the prompt excluded.
    """
    st.write("Preparing inputs for MedGemma...")
    images = [f.convert("RGB") for f in frames]
    mask_img = mask_img.convert("RGB")
    all_images = images + [mask_img]
    # One image placeholder per image actually passed to the processor, so
    # the prompt stays in sync with the inputs (3 frames + 1 mask in the UI).
    messages = [
        {"role": "user", "content": [
            *({"type": "image"} for _ in all_images),
            {"type": "text", "text": instruction},
        ]},
    ]
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    inputs = processor(
        all_images, input_text, add_special_tokens=False, return_tensors="pt",
    ).to(device)
    # Collect streamed text in memory instead of swapping the process-wide
    # sys.stdout. The previous redirect was never restored if generate()
    # raised, and it could capture prints from concurrent Streamlit sessions.
    streamer = _CollectingStreamer(processor, skip_prompt=True)
    st.write("Running MedGemma Analysis...")
    # torch._dynamo.disable() with no argument is a no-op. Wrapping the
    # bound generate method actually skips TorchDynamo for this call.
    generate_eager = torch._dynamo.disable(medgemma_model.generate)
    generate_eager(
        **inputs, streamer=streamer, max_new_tokens=768,
        use_cache=True, temperature=1.0, top_p=0.95, top_k=64,
    )
    return streamer.text
# --- Streamlit UI ---
def _medgemma_ready():
    """Return True when both the MedGemma model and its processor are loaded."""
    return (
        st.session_state.get("medgemma_model") is not None
        and st.session_state.get("processor") is not None
    )


def _upload_signature(uploaded_files):
    """Identify the current uploads so a mask computed for older frames can be dropped."""
    return tuple(
        None if f is None else (getattr(f, "file_id", None), f.name, getattr(f, "size", None))
        for f in uploaded_files
    )


def _render_model_loading():
    """Render section 1: buttons that load both models into session state."""
    st.header("1. Load Models")
    if "seg_model" not in st.session_state or "seg_config" not in st.session_state:
        st.session_state.seg_model, st.session_state.seg_config = None, None
    if st.button("Load Segmentation Model"):
        with st.spinner("Loading SwinUNETR..."):
            st.session_state.seg_model, st.session_state.seg_config = load_swinunetr_model()
    if st.session_state.seg_model is not None:
        st.success("Segmentation model is loaded.")
    else:
        st.warning("Segmentation model is not loaded.")

    if "medgemma_model" not in st.session_state:
        st.session_state.medgemma_model, st.session_state.processor = None, None
    if st.button("Load MedGemma Model"):
        with st.spinner("Loading MedGemma... This can take several minutes."):
            st.session_state.medgemma_model, st.session_state.processor = load_medgemma_model()
    if _medgemma_ready():
        st.success("MedGemma model is loaded.")
    else:
        st.warning("MedGemma model is not loaded.")


def _render_upload_and_segmentation():
    """Render section 2 (frame upload and segmentation); return the opened frames."""
    st.header("2. Upload Data & Generate Mask")
    st.subheader("Upload Three Sequential Surgical Video Frames")
    col1, col2, col3 = st.columns(3)
    uploaded_files = [
        col1.file_uploader("Upload Frame 1", type=["png", "jpg", "jpeg"], key="frame1"),
        col2.file_uploader("Upload Frame 2", type=["png", "jpg", "jpeg"], key="frame2"),
        col3.file_uploader("Upload Frame 3", type=["png", "jpg", "jpeg"], key="frame3"),
    ]
    frames = [Image.open(f) for f in uploaded_files if f is not None]
    display_size = (256, 256)
    if "mask_img" not in st.session_state:
        st.session_state.mask_img = None
    # A mask belongs to the frames it was computed from. Drop it whenever the
    # uploads change so analysis never pairs new frames with a stale mask.
    signature = _upload_signature(uploaded_files)
    if st.session_state.get("mask_signature") != signature:
        st.session_state.mask_img = None
        st.session_state.mask_signature = signature

    if len(frames) != 3:
        st.info("Please upload all three frames to proceed.")
        return frames

    st.success("All three frames have been uploaded successfully.")
    img_cols = st.columns(4)
    for i, frame in enumerate(frames):
        img_cols[i].image(frame.resize(display_size), caption=f"Frame {i+1}", use_container_width=True)
    seg_ready = (
        st.session_state.seg_model is not None
        and st.session_state.seg_config is not None
    )
    if seg_ready and st.button("Run Segmentation"):
        with st.spinner("Generating segmentation mask..."):
            st.session_state.mask_img = run_segmentation(
                st.session_state.seg_model, st.session_state.seg_config, frames
            )
    if st.session_state.mask_img is not None:
        img_cols[3].image(st.session_state.mask_img.resize(display_size), caption="Segmentation Mask", use_container_width=True)
    return frames


def _render_analysis(frames):
    """Render section 3: the instruction prompt and the MedGemma analysis button."""
    st.header("3. Generate Scene Analysis")
    instruction_prompt = st.text_area(
        "Enter your custom instruction prompt:",
        "Provide a detailed summary of the surgical action, noting the instruments used and their interactions."
    )
    can_run_analysis = (
        _medgemma_ready()
        and len(frames) == 3
        and st.session_state.get("mask_img") is not None
        # A whitespace-only prompt carries no instruction.
        and bool(instruction_prompt.strip())
    )
    if st.button("Run Analysis", disabled=not can_run_analysis):
        with st.spinner("Running MedGemma analysis... This may take a moment."):
            result = run_captioning(
                st.session_state.medgemma_model, st.session_state.processor,
                frames, st.session_state.mask_img, instruction_prompt
            )
        st.subheader("Analysis Result")
        st.write(result)
    if not can_run_analysis:
        st.warning("Please ensure the MedGemma model is loaded, three frames are uploaded, segmentation is complete, and a prompt is provided.")


def show():
    """Render the Surgical Scene Analysis page: model loading, segmentation, analysis."""
    st.title("Surgical Scene Analysis System")
    st.write("A system to test surgical scene segmentation and captioning models.")
    _render_model_loading()
    frames = _render_upload_and_segmentation()
    _render_analysis(frames)


if __name__ == "__main__":
    show()