"""Streamlit app for surgical scene segmentation (SwinUNETR) and captioning (MedGemma)."""

import contextlib
import os
import sys
from io import StringIO

import numpy as np
import streamlit as st
import torch
import torch._dynamo
import torch.nn as nn
from PIL import Image

# --- TorchDynamo Fix for Unsloth/MedGemma ---
torch._dynamo.config.capture_scalar_outputs = True

# NOTE(review): this file previously called `torch.compiler.disable()` here (and
# `torch._dynamo.disable()` inside run_captioning). Called without a function,
# both only return a decorator/context object that is immediately discarded,
# so they never disabled compilation. They were removed as dead code; runtime
# behaviour is unchanged. If compilation really must be disabled globally, use
# torch._dynamo.config.disable / the TORCH_COMPILE_DISABLE environment variable
# before anything is compiled (confirm the exact knob for your PyTorch version).

# --- Dependency Handling ---
try:
    from monai.networks.nets import SwinUNETR
    import torchvision.transforms as T
    from unsloth import FastVisionModel
    from transformers import TextStreamer
    from s2wrapper import forward as multiscale_forward
except ImportError as e:
    st.error(f"A required library is not installed. Please install dependencies. Error: {e}")
    st.stop()


# --- Config and Model Definition ---
class Config:
    """Static settings shared by the segmentation pipeline."""

    # Raw label values; LABEL_MAP maps each raw value to a contiguous class index.
    ORIGINAL_LABELS = [0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, 60]
    LABEL_MAP = {val: i for i, val in enumerate(ORIGINAL_LABELS)}
    NUM_CLASSES = len(ORIGINAL_LABELS)
    IMG_SIZE = (256, 256)
    FEATURE_SIZE = 48
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class multiscaleSwinUNETR(nn.Module):
    """2-D SwinUNETR run through s2wrapper's multi-scale forward plus a fusion head.

    Args:
        num_classes: number of segmentation classes produced.
        scales: scale factors forwarded to ``multiscale_forward``; defaults to ``[1]``.
    """

    def __init__(self, num_classes, scales=None):
        super().__init__()
        # Fresh list per instance: the former ``scales=[1]`` default was a shared mutable.
        self.scales = [1] if scales is None else list(scales)
        self.num_classes = num_classes
        self.model = SwinUNETR(
            spatial_dims=2,
            in_channels=3,
            out_channels=num_classes,
            feature_size=Config.FEATURE_SIZE,
            drop_rate=0.0,
            attn_drop_rate=0.0,
            dropout_path_rate=0.0,
            use_checkpoint=True,
            use_v2=True,
        )
        # Fuses the per-scale class maps (concatenated along dim 1) into one map.
        self.segmentation_head = nn.Sequential(
            nn.Conv2d(len(self.scales) * num_classes, num_classes, 3, padding=1),
            nn.BatchNorm2d(num_classes),
            nn.ReLU(inplace=True),
            nn.Conv2d(num_classes, num_classes, 1),
        )

    @staticmethod
    def _std_normalize(f):
        """Divide ``f`` by its std over dims 2 and 3 (epsilon-guarded)."""
        return f / (f.std(dim=(2, 3), keepdim=True) + 1e-6)

    def forward(self, x):
        """Return class logits for the batch ``x``.

        If ``multiscale_forward`` yields a list/tuple, each element is
        std-normalized, concatenated on dim 1 and fed to ``segmentation_head``.
        If it yields a single 4-D tensor and only one scale is configured, that
        tensor is returned as-is (the head is bypassed); with several scales it
        is std-normalized and fed to the head.

        Raises:
            ValueError: for any other output type/shape.
        """
        outs = multiscale_forward(self.model, x, scales=self.scales, output_shape="bchw")
        if isinstance(outs, (list, tuple)):
            feats = torch.cat([self._std_normalize(f) for f in outs], dim=1)
        elif isinstance(outs, torch.Tensor) and outs.dim() == 4:
            if len(self.scales) == 1:
                return outs
            feats = self._std_normalize(outs)
        else:
            raise ValueError(
                f"Unexpected output shape/type from multiscale_forward: {type(outs)}, {getattr(outs,'shape',None)}"
            )
        logits = self.segmentation_head(feats)
        return logits


# --- Model Loading ---
@st.cache_resource
def load_swinunetr_model():
    """Load the multiscale SwinUNETR segmentation model.

    Returns:
        ``(model, Config)`` on success, ``(None, None)`` if the weights file is
        missing or loading fails (an ``st.error`` is shown in both cases).
    """
    model_path = 's2-swinunetr-weights.pth'
    if not os.path.exists(model_path):
        st.error(f"Segmentation model file not found at {model_path}")
        return None, None
    try:
        model = multiscaleSwinUNETR(num_classes=Config.NUM_CLASSES, scales=[1])
        model.load_state_dict(torch.load(model_path, map_location=Config.DEVICE))
        model.eval()
        return model, Config
    except Exception as e:
        st.error(f"Error loading segmentation model: {e}")
        return None, None


@st.cache_resource
def load_medgemma_model():
    """Load the fine-tuned MedGemma vision-language model (not 4-bit quantized).

    Returns:
        ``(model, processor)`` on success, ``(None, None)`` on failure (an
        ``st.error`` is shown).
    """
    # NOTE(review): Unsloth documents FastVisionModel.for_inference(model) for
    # generation; not applied here to keep behaviour unchanged — consider it.
    try:
        model, processor = FastVisionModel.from_pretrained(
            "fiqqy/MedGemma-MM-OR-FT10",
            load_in_4bit=False,
            use_gradient_checkpointing="unsloth",
        )
        return model, processor
    except Exception as e:
        st.error(f"Error loading MedGemma model: {e}")
        return None, None


# --- Preprocessing ---
def preprocess_frames(frames, config):
    """Resize, tensorize and ImageNet-normalize PIL frames into one stacked batch."""
    transform = T.Compose([
        T.Resize(config.IMG_SIZE, antialias=True),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    tensors = [transform(frame.convert("RGB")) for frame in frames]
    batch = torch.stack(tensors)
    return batch


# --- Color Palette for Mask Visualization ---
def make_palette(num_classes):
    """Return a deterministic (seed 0) ``(num_classes, 3)`` uint8 palette; class 0 is black."""
    rng = np.random.default_rng(0)
    colors = rng.integers(0, 255, size=(num_classes, 3), dtype=np.uint8)
    colors[0] = np.array([0, 0, 0])
    return colors


# --- Inference ---
def run_segmentation(model, config, frames):
    """Segment the frames and return the first frame's mask as a color PIL image."""
    st.write("Running segmentation...")
    batch = preprocess_frames(frames, config)
    device = config.DEVICE
    batch = batch.to(device)
    model = model.to(device)
    with torch.no_grad():
        logits = model(batch)
    preds = torch.argmax(logits, 1).cpu().numpy()
    # NOTE(review): all frames are segmented but only the first frame's mask is
    # used — confirm this is the intended frame.
    mask = preds[0]
    st.write(f"Mask unique values: {np.unique(mask)}")
    palette = make_palette(config.NUM_CLASSES)
    color_mask = palette[mask]
    mask_img = Image.fromarray(color_mask.astype(np.uint8))
    return mask_img


# --- MedGemma Captioning ---
def run_captioning(medgemma_model, processor, frames, mask_img, instruction):
    """Run MedGemma on the frames plus the mask image and an instruction.

    Returns:
        The text emitted by ``TextStreamer`` during generation (captured stdout).
    """
    st.write("Preparing inputs for MedGemma...")
    images = [f.convert("RGB") for f in frames]
    mask_img = mask_img.convert("RGB")
    all_images = images + [mask_img]
    # One image placeholder per image handed to the processor, then the prompt.
    messages = [
        {
            "role": "user",
            "content": [{"type": "image"} for _ in all_images]
            + [{"type": "text", "text": instruction}],
        },
    ]
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(
        all_images,
        input_text,
        add_special_tokens=False,
        return_tensors="pt",
    ).to(Config.DEVICE)

    text_streamer = TextStreamer(processor, skip_prompt=True)
    st.write("Running MedGemma Analysis...")
    captured_output = StringIO()
    # TextStreamer prints to stdout, so capture it. redirect_stdout restores the
    # original stream even if generate() raises (the old manual swap did not).
    # NOTE(review): sys.stdout is process-global; concurrent Streamlit sessions
    # printing at the same time would be captured too.
    with contextlib.redirect_stdout(captured_output):
        medgemma_model.generate(
            **inputs,
            streamer=text_streamer,
            max_new_tokens=768,
            use_cache=True,
            temperature=1.0,
            top_p=0.95,
            top_k=64,
        )
    result = captured_output.getvalue()
    return result


# --- Streamlit UI ---
def _uploaded_files_key(uploaded_files):
    """Return a hashable identity of the uploader slots (None for an empty slot)."""
    return tuple(
        None if f is None else (getattr(f, "file_id", None), f.name, getattr(f, "size", None))
        for f in uploaded_files
    )


def show():
    """Render the UI: load models, upload frames / segment, then run MedGemma analysis."""
    st.title("Surgical Scene Analysis System")
    st.write("A system to test surgical scene segmentation and captioning models.")

    st.header("1. Load Models")
    if "seg_model" not in st.session_state or "seg_config" not in st.session_state:
        st.session_state.seg_model, st.session_state.seg_config = None, None
    if st.button("Load Segmentation Model"):
        with st.spinner("Loading SwinUNETR..."):
            st.session_state.seg_model, st.session_state.seg_config = load_swinunetr_model()
    if st.session_state.seg_model is not None:
        st.success("Segmentation model is loaded.")
    else:
        st.warning("Segmentation model is not loaded.")

    if "medgemma_model" not in st.session_state:
        st.session_state.medgemma_model, st.session_state.processor = None, None
    if st.button("Load MedGemma Model"):
        with st.spinner("Loading MedGemma... This can take several minutes."):
            st.session_state.medgemma_model, st.session_state.processor = load_medgemma_model()
    medgemma_ready = (
        st.session_state.get("medgemma_model") is not None
        and st.session_state.get("processor") is not None
    )
    if medgemma_ready:
        st.success("MedGemma model is loaded.")
    else:
        st.warning("MedGemma model is not loaded.")

    st.header("2. Upload Data & Generate Mask")
    st.subheader("Upload Three Sequential Surgical Video Frames")
    col1, col2, col3 = st.columns(3)
    uploaded_files = [
        col1.file_uploader("Upload Frame 1", type=["png", "jpg", "jpeg"], key="frame1"),
        col2.file_uploader("Upload Frame 2", type=["png", "jpg", "jpeg"], key="frame2"),
        col3.file_uploader("Upload Frame 3", type=["png", "jpg", "jpeg"], key="frame3"),
    ]
    frames = [Image.open(f) for f in uploaded_files if f is not None]
    display_size = (256, 256)

    if "mask_img" not in st.session_state:
        st.session_state.mask_img = None
    # Drop a mask computed from a previous set of frames so it is never shown
    # with, or sent to MedGemma alongside, frames it does not belong to.
    frames_key = _uploaded_files_key(uploaded_files)
    if st.session_state.get("mask_frames_key") != frames_key:
        st.session_state.mask_img = None
        st.session_state.mask_frames_key = frames_key

    if len(frames) == 3:
        st.success("All three frames have been uploaded successfully.")
        img_cols = st.columns(4)
        for i, frame in enumerate(frames):
            img_cols[i].image(frame.resize(display_size), caption=f"Frame {i+1}", use_container_width=True)
        seg_ready = (
            st.session_state.seg_model is not None
            and st.session_state.seg_config is not None
        )
        # The button is only rendered once the segmentation model is available.
        if seg_ready and st.button("Run Segmentation"):
            with st.spinner("Generating segmentation mask..."):
                st.session_state.mask_img = run_segmentation(
                    st.session_state.seg_model, st.session_state.seg_config, frames
                )
        if st.session_state.mask_img is not None:
            img_cols[3].image(
                st.session_state.mask_img.resize(display_size),
                caption="Segmentation Mask",
                use_container_width=True,
            )
    else:
        st.info("Please upload all three frames to proceed.")

    st.header("3. Generate Scene Analysis")
    instruction_prompt = st.text_area(
        "Enter your custom instruction prompt:",
        "Provide a detailed summary of the surgical action, noting the instruments used and their interactions.",
    )
    can_run_analysis = (
        medgemma_ready
        and len(frames) == 3
        and st.session_state.get("mask_img") is not None
        and bool(instruction_prompt)
    )
    if st.button("Run Analysis", disabled=not can_run_analysis):
        with st.spinner("Running MedGemma analysis... This may take a moment."):
            result = run_captioning(
                st.session_state.medgemma_model,
                st.session_state.processor,
                frames,
                st.session_state.mask_img,
                instruction_prompt,
            )
        st.subheader("Analysis Result")
        st.write(result)
    if not can_run_analysis:
        st.warning("Please ensure the MedGemma model is loaded, three frames are uploaded, segmentation is complete, and a prompt is provided.")


if __name__ == "__main__":
    show()