import os
import torch
import numpy as np
import cv2 # Using OpenCV for image loading/processing
import albumentations as A
from albumentations.pytorch import ToTensorV2
import gradio as gr
import spaces
import segmentation_models_pytorch as smp
from train_unet import UNetLitModule # Import the Lightning Module definition
# --- Configuration ---
# Option 1: Load from the Lightning Checkpoint
# CHECKPOINT_PATH = "checkpoints/unet-derm-epoch=XX-val_iou=Y.YYYY.ckpt" # Find the best checkpoint path from training output
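# If using Option 1, the full LightningModule can be restored and its inner SMP
# network extracted. A minimal sketch (assuming UNetLitModule stores the network
# as `self.model`, as the state_dict handling below implies):
#   lit_model = UNetLitModule.load_from_checkpoint(CHECKPOINT_PATH, map_location="cpu")
#   model = lit_model.model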
# Option 2: Load from the saved state_dict
MODEL_STATE_DICT_PATH = "unet_derm_final_model.pth"
IMG_SIZE = (256, 256) # MUST match training image size
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# --- Load Model ---
print(f"Loading model from: {MODEL_STATE_DICT_PATH}")
print(f"Using device: {DEVICE}")
# Instantiate the base SMP model architecture
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,  # Don't load pretrained weights; we load our trained ones below
    in_channels=3,
    classes=1,
)
# Load the state dict saved at the end of training
try:
    state_dict = torch.load(MODEL_STATE_DICT_PATH, map_location=DEVICE)
    # If the state_dict was saved directly from the `model.model` attribute of the LitModule:
    model.load_state_dict(state_dict)
    # If you saved the entire Lightning Module state_dict instead, extract the model part:
    # state_dict = torch.load(MODEL_STATE_DICT_PATH, map_location=DEVICE)['state_dict']
    # # Strip the 'model.' prefix that the LitModule adds to the submodule's keys
    # state_dict = {k.replace('model.', '', 1): v for k, v in state_dict.items() if k.startswith('model.')}
    # model.load_state_dict(state_dict)
except FileNotFoundError:
print(f"Error: Model file not found at {MODEL_STATE_DICT_PATH}")
print("Please ensure the training script ran successfully and the path is correct.")
exit()
except Exception as e:
print(f"Error loading model state_dict: {e}")
print("Ensure the saved state_dict matches the current model architecture.")
exit()
model.to(DEVICE)
model.eval() # Set model to evaluation mode (disables dropout, batchnorm updates)
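# Optional sanity check: a dummy forward pass confirms the loaded weights yield a
# single-channel mask at the training resolution, e.g.:
#   with torch.no_grad():
#       assert model(torch.zeros(1, 3, *IMG_SIZE, device=DEVICE)).shape == (1, 1, *IMG_SIZE)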
# --- Inference Transforms ---
# Should match the validation/test transforms from training (resize, normalize)
inference_transform = A.Compose([
    A.Resize(height=IMG_SIZE[0], width=IMG_SIZE[1]),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])
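# For a uint8 RGB input, this pipeline yields a float32 tensor of shape
# [3, 256, 256] with ImageNet-normalized channel values.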
# --- Segmentation Function ---
@spaces.GPU
def segment_image(input_image_np):
"""
Takes a NumPy image, performs segmentation, and returns images for Gradio.
"""
# 0. Input validation
if input_image_np is None:
return None, None, None
# Ensure image is RGB (Gradio might provide BGR or grayscale)
if len(input_image_np.shape) == 2: # Grayscale
input_image_np = cv2.cvtColor(input_image_np, cv2.COLOR_GRAY2RGB)
elif input_image_np.shape[2] == 4: # RGBA
input_image_np = cv2.cvtColor(input_image_np, cv2.COLOR_RGBA2RGB)
# Assume BGR if 3 channels, convert to RGB for consistency with training
# input_image_rgb = cv2.cvtColor(input_image_np, cv2.COLOR_BGR2RGB) # PIL/Gradio usually loads RGB
input_image_rgb = input_image_np.copy()
    # 1. Preprocess the image
    transformed = inference_transform(image=input_image_rgb)
    image_tensor = transformed['image'].unsqueeze(0).to(DEVICE)  # add batch dim, move to device
    # 2. Perform inference
    with torch.no_grad():
        logits = model(image_tensor)  # shape [1, 1, H, W], raw logits
        # Apply sigmoid to get per-pixel probabilities in [0, 1]
        probabilities = torch.sigmoid(logits)
    # 3. Postprocess the output mask
    # Remove batch/channel dimensions, move to CPU, convert to NumPy
    mask_pred_np = probabilities.squeeze().cpu().numpy()  # shape [H, W]
    # Threshold probabilities to get a binary mask (0 or 1)
    binary_mask_np = (mask_pred_np > 0.5).astype(np.uint8)
    # Convert the binary mask to a displayable format (0 or 255)
    display_mask = binary_mask_np * 255
    # Resize the mask back to the original image size so the overlay aligns with the input
    orig_h, orig_w = input_image_rgb.shape[:2]
    display_mask_resized = cv2.resize(display_mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
    # 4. Create an overlay: tint the segmented region red on top of the original image
    mask_rgb = cv2.cvtColor(display_mask_resized, cv2.COLOR_GRAY2RGB)
    # Keep only the red channel so the mask renders as red (channel order here is RGB)
    mask_rgb[:, :, 1] = 0  # zero out green
    mask_rgb[:, :, 2] = 0  # zero out blue
    # Blend: unsegmented pixels stay close to the original, segmented pixels gain a red tint
    overlay_image = cv2.addWeighted(input_image_rgb, 0.7, mask_rgb, 0.3, 0)
    # Return the (resized) mask and the overlay; Gradio expects NumPy arrays
    return display_mask_resized, overlay_image
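# Example usage outside Gradio (assuming a local test image exists at the
# hypothetical path "examples/img1.jpg"):
#   img = cv2.cvtColor(cv2.imread("examples/img1.jpg"), cv2.COLOR_BGR2RGB)
#   mask, overlay = segment_image(img)
#   cv2.imwrite("mask.png", mask)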
# --- Gradio Interface ---
print("Launching Gradio Interface...")
with gr.Blocks() as demo:
gr.Markdown("# Dermatology Image Segmentation (UNet ResNet34)")
gr.Markdown("Upload a dermatology image to see the predicted segmentation mask using a trained UNet model.")
with gr.Row(): # Creates a horizontal container
inp = gr.Image(type="numpy", label="Input Image")
out_mask = gr.Image(type="numpy", label="Segmentation Mask")
out_overlay = gr.Image(type="numpy", label="Overlay")
# Hook up the function
inp.change(fn=segment_image, inputs=inp, outputs=[out_mask, out_overlay])
    # (Optional) add example images
    # gr.Examples(examples=[["examples/img1.jpg"], ["examples/img2.jpg"]],
    #             inputs=inp, outputs=[out_mask, out_overlay])
if __name__ == "__main__":
    demo.launch(share=True)  # share=True creates a public link