"""Gradio demo: dermatology image segmentation with a trained UNet (ResNet34).

Loads the UNet state_dict saved by the training script, then serves a Gradio
Blocks UI that shows the predicted binary mask and a highlighted overlay for
each uploaded image.
"""
import os
import torch
import numpy as np
import cv2  # Using OpenCV for image loading/processing
import albumentations as A
from albumentations.pytorch import ToTensorV2
import gradio as gr
import spaces
import segmentation_models_pytorch as smp
# NOTE(review): UNetLitModule is not used below; importing train_unet may run
# that module's top-level code — confirm it is still needed.
from train_unet import UNetLitModule  # Import the Lightning Module definition

# --- Configuration ---
# Option 1: load from a Lightning checkpoint (take the best checkpoint path
# printed by the training run), e.g.
# CHECKPOINT_PATH = "checkpoints/unet-derm-epoch=XX-val_iou=Y.YYYY.ckpt"
# Option 2 (used below): load the plain state_dict saved at the end of training.
MODEL_STATE_DICT_PATH = "unet_derm_final_model.pth"
IMG_SIZE = (256, 256)  # (height, width); MUST match the training image size
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MASK_THRESHOLD = 0.5  # sigmoid probability above which a pixel is foreground

# --- Load Model ---
print(f"Loading model from: {MODEL_STATE_DICT_PATH}")
print(f"Using device: {DEVICE}")

# Instantiate the base SMP architecture; it must match the trained model.
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,  # Don't load pretrained weights; the trained ones are loaded next
    in_channels=3,
    classes=1,
)

try:
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    state_dict = torch.load(MODEL_STATE_DICT_PATH, map_location=DEVICE)
    # Assumes the file holds the bare `model.model` state_dict from the LitModule.
    # If the entire Lightning Module state_dict was saved instead, extract the
    # model part first:
    #   state_dict = torch.load(MODEL_STATE_DICT_PATH, map_location=DEVICE)["state_dict"]
    #   state_dict = {k.replace("model.", "", 1): v
    #                 for k, v in state_dict.items() if k.startswith("model.")}
    model.load_state_dict(state_dict)
except FileNotFoundError as err:
    # SystemExit with a message prints it to stderr and exits with status 1;
    # the bare site `exit()` used before exited with status 0 (success).
    raise SystemExit(
        f"Error: Model file not found at {MODEL_STATE_DICT_PATH}\n"
        "Please ensure the training script ran successfully and the path is correct."
    ) from err
except Exception as err:  # top-level boundary: any other load failure is fatal
    raise SystemExit(
        f"Error loading model state_dict: {err}\n"
        "Ensure the saved state_dict matches the current model architecture."
    ) from err

model.to(DEVICE)
model.eval()  # Evaluation mode (disables dropout, uses running BatchNorm stats)

# --- Inference Transforms ---
# Should match the validation/test transforms from training (resize, normalize).
inference_transform = A.Compose([
    A.Resize(height=IMG_SIZE[0], width=IMG_SIZE[1]),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])


# --- Segmentation ---
def _to_rgb(image):
    """Return a 3-channel RGB array for a grayscale, 1-channel, RGB or RGBA image."""
    if image.ndim == 3 and image.shape[2] == 1:
        # (H, W, 1) -> (H, W) so it takes the grayscale path below instead of
        # reaching the 3-channel model with a single channel.
        image = image[:, :, 0]
    if image.ndim == 2:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    if image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    # 3 channels: gr.Image(type="numpy") is expected to deliver RGB (not BGR),
    # so no channel swap is applied.
    return image.copy()


def _predict_mask(image_rgb):
    """Run the UNet on an RGB image; return a uint8 0/255 mask of size IMG_SIZE."""
    transformed = inference_transform(image=image_rgb)
    image_tensor = transformed["image"].unsqueeze(0).to(DEVICE)  # add batch dim
    with torch.no_grad():
        logits = model(image_tensor)  # [1, 1, H, W] logits (classes=1)
        probabilities = torch.sigmoid(logits)  # -> probabilities in [0, 1]
    prob_map = probabilities[0, 0].cpu().numpy()  # drop batch/channel dims -> [H, W]
    return ((prob_map > MASK_THRESHOLD) * 255).astype(np.uint8)


def _make_overlay(image_rgb, mask):
    """Darken the whole image and brighten the pixels where ``mask`` is non-zero."""
    highlighted_area = cv2.bitwise_and(image_rgb, image_rgb, mask=mask)
    # 0.7 * image everywhere, plus 0.9 * image inside the mask (saturates at 255).
    return cv2.addWeighted(image_rgb, 0.7, highlighted_area, 0.9, 0)


@spaces.GPU
def segment_image(input_image_np):
    """Segment a dermatology image and build the outputs for the Gradio UI.

    Args:
        input_image_np: array from ``gr.Image(type="numpy")`` — (H, W) grayscale
            or (H, W, C) with 1, 3 (RGB) or 4 (RGBA) channels — or None when
            the input image is cleared.

    Returns:
        ``(mask, overlay)``: a uint8 (H, W) mask with values 0/255 at the input
        resolution, and an RGB image highlighting the segmented area.
        ``(None, None)`` when there is no input, which clears both outputs.
    """
    if input_image_np is None:
        # Exactly one value per output component (out_mask, out_overlay).
        return None, None

    image_rgb = _to_rgb(input_image_np)
    mask = _predict_mask(image_rgb)

    # Resize the mask back to the original size; nearest-neighbour keeps it
    # strictly binary (0/255).
    orig_h, orig_w = image_rgb.shape[:2]
    mask_full = cv2.resize(mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)

    return mask_full, _make_overlay(image_rgb, mask_full)


# --- Gradio Interface ---
print("Launching Gradio Interface...")
with gr.Blocks() as demo:
    gr.Markdown("# Dermatology Image Segmentation (UNet ResNet34)")
    gr.Markdown("Upload a dermatology image to see the predicted segmentation mask using a trained UNet model.")
    with gr.Row():  # Creates a horizontal container
        inp = gr.Image(type="numpy", label="Input Image")
        out_mask = gr.Image(type="numpy", label="Segmentation Mask")
        out_overlay = gr.Image(type="numpy", label="Overlay")

    # Re-run segmentation whenever the input changes (including being cleared).
    inp.change(fn=segment_image, inputs=inp, outputs=[out_mask, out_overlay])

    # (Optional) add example images:
    # gr.Examples(examples=[["examples/img1.jpg"], ["examples/img2.jpg"]],
    #             inputs=inp, outputs=[out_mask, out_overlay])

if __name__ == "__main__":
    demo.launch(share=True)  # share=True creates a public link