import torch
import numpy as np
import cv2  # Using OpenCV for image processing and the mask overlay
import albumentations as A
from albumentations.pytorch import ToTensorV2
import gradio as gr
import spaces  # Hugging Face Spaces helper (used only for the optional ZeroGPU decorator below)
import segmentation_models_pytorch as smp
from train_unet import UNetLitModule  # Only needed if loading from the Lightning checkpoint (Option 1 below)

# --- Configuration ---
# Option 1: Load from the Lightning checkpoint
# CHECKPOINT_PATH = "checkpoints/unet-derm-epoch=XX-val_iou=Y.YYYY.ckpt"  # Find the best checkpoint path from training output
# Option 2: Load from the saved state_dict
MODEL_STATE_DICT_PATH = "unet_derm_final_model.pth"
IMG_SIZE = (256, 256)  # MUST match the training image size
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Load Model ---
print(f"Loading model from: {MODEL_STATE_DICT_PATH}")
print(f"Using device: {DEVICE}")

# Instantiate the base SMP model architecture
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,  # Don't load pretrained encoder weights; we load our trained ones below
    in_channels=3,
    classes=1,
)

# Load the state dict saved at the end of training
try:
    state_dict = torch.load(MODEL_STATE_DICT_PATH, map_location=DEVICE)
    # If the state_dict was saved directly from the `model.model` attribute in the LitModule:
    model.load_state_dict(state_dict)
    # If you saved the entire Lightning Module state_dict, you may need to extract the model part:
    # state_dict = torch.load(MODEL_STATE_DICT_PATH, map_location=DEVICE)['state_dict']
    # # Filter keys if they have a prefix like 'model.'
    # state_dict = {k.replace('model.', ''): v for k, v in state_dict.items() if k.startswith('model.')}
    # model.load_state_dict(state_dict)
except FileNotFoundError:
    print(f"Error: Model file not found at {MODEL_STATE_DICT_PATH}")
    print("Please ensure the training script ran successfully and the path is correct.")
    exit()
except Exception as e:
    print(f"Error loading model state_dict: {e}")
    print("Ensure the saved state_dict matches the current model architecture.")
    exit()
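
# (Optional) Option 1 sketch: load from the Lightning checkpoint instead of the raw state_dict.
# This assumes UNetLitModule wraps the smp network as `self.model`; adjust the attribute name
# if your LightningModule differs.
# lit_module = UNetLitModule.load_from_checkpoint(CHECKPOINT_PATH, map_location=DEVICE)
# model = lit_module.model
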
model.to(DEVICE)
model.eval()  # Set model to evaluation mode (disables dropout, uses running batch-norm stats)

# --- Inference Transforms ---
# Should match the validation/test transforms from training (resize, normalize)
inference_transform = A.Compose([
    A.Resize(height=IMG_SIZE[0], width=IMG_SIZE[1]),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])
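# Note: the mean/std above are the standard ImageNet statistics, which is what the
# ImageNet-pretrained ResNet34 encoder expects; they should match the training transforms.
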
# --- Segmentation Function ---
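# If this Space runs on ZeroGPU hardware, the `spaces` package's GPU decorator can be used to
# request a GPU per inference call. Whether ZeroGPU is used here is an assumption, so the
# decorator is left commented out; uncomment it on ZeroGPU.
# @spaces.GPU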
def segment_image(input_image_np):
    """
    Takes a NumPy image, performs segmentation, and returns the mask and overlay for Gradio.
    """
    # 0. Input validation
    if input_image_np is None:
        return None, None

    # Ensure the image is 3-channel RGB (Gradio may provide grayscale or RGBA)
    if len(input_image_np.shape) == 2:  # Grayscale
        input_image_np = cv2.cvtColor(input_image_np, cv2.COLOR_GRAY2RGB)
    elif input_image_np.shape[2] == 4:  # RGBA
        input_image_np = cv2.cvtColor(input_image_np, cv2.COLOR_RGBA2RGB)
    # Gradio (like PIL) supplies RGB images, so no BGR->RGB conversion is needed
    # input_image_rgb = cv2.cvtColor(input_image_np, cv2.COLOR_BGR2RGB)
    input_image_rgb = input_image_np.copy()

    # 1. Preprocess the image
    transformed = inference_transform(image=input_image_rgb)
    image_tensor = transformed['image'].unsqueeze(0).to(DEVICE)  # Add batch dim and send to device

    # 2. Perform inference
    with torch.no_grad():
        logits = model(image_tensor)  # Output is [1, 1, H, W] logits
        # Apply sigmoid to get probabilities in [0, 1]
        probabilities = torch.sigmoid(logits)

    # 3. Postprocess the output mask
    # Remove the batch and channel dimensions, move to CPU, convert to NumPy
    mask_pred_np = probabilities.squeeze().cpu().numpy()  # Shape: [H, W]
    # Threshold probabilities to get a binary mask (0 or 1)
    binary_mask_np = (mask_pred_np > 0.5).astype(np.uint8)
    # Convert the binary mask to a displayable format (0 or 255)
    display_mask = binary_mask_np * 255  # Shape: [H, W]
    # Resize the mask back to the original image size for a better-quality overlay
    orig_h, orig_w = input_image_rgb.shape[:2]
    display_mask_resized = cv2.resize(display_mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)

    # 4. Create overlay
    # Convert the single-channel mask to 3 channels so it can be blended with the RGB image
    mask_rgb = cv2.cvtColor(display_mask_resized, cv2.COLOR_GRAY2RGB)
    # Colour the mask red: the image is RGB, so keep channel 0 (red) and zero out green and blue
    mask_rgb[:, :, 1] = 0  # Zero out green channel
    mask_rgb[:, :, 2] = 0  # Zero out blue channel
    # Blend the red mask over the original image so the segmented area stands out
    overlay_image = cv2.addWeighted(input_image_rgb, 0.7, mask_rgb, 0.3, 0)
    # Alternative: brighten only the segmented area instead of tinting it red
    # highlighted_area = cv2.bitwise_and(input_image_rgb, input_image_rgb, mask=display_mask_resized)
    # overlay_image = cv2.addWeighted(input_image_rgb, 0.7, highlighted_area, 0.9, 0)

    # Return the (resized) mask and the overlay; Gradio expects NumPy arrays
    return display_mask_resized, overlay_image

# --- Gradio Interface ---
print("Launching Gradio interface...")

with gr.Blocks() as demo:
    gr.Markdown("# Dermatology Image Segmentation (UNet ResNet34)")
    gr.Markdown("Upload a dermatology image to see the predicted segmentation mask from a trained UNet model.")

    with gr.Row():  # Horizontal container for the image components
        inp = gr.Image(type="numpy", label="Input Image")
        out_mask = gr.Image(type="numpy", label="Segmentation Mask")
        out_overlay = gr.Image(type="numpy", label="Overlay")

    # Run segmentation whenever the input image changes
    inp.change(fn=segment_image, inputs=inp, outputs=[out_mask, out_overlay])

    # (Optional) add example images
    # gr.Examples(examples=[["examples/img1.jpg"], ["examples/img2.jpg"]],
    #             inputs=inp, outputs=[out_mask, out_overlay])

if __name__ == "__main__":
    demo.launch(share=True)  # share=True creates a public link when running locally