# derm_maskHG / app.py
import os
import torch
import numpy as np
import cv2 # Using OpenCV for image loading/processing
import albumentations as A
from albumentations.pytorch import ToTensorV2
import gradio as gr
import spaces
import segmentation_models_pytorch as smp
from train_unet import UNetLitModule  # Lightning Module definition (only needed for Option 1 below)
# --- Configuration ---
# Option 1: Load from the Lightning Checkpoint
# CHECKPOINT_PATH = "checkpoints/unet-derm-epoch=XX-val_iou=Y.YYYY.ckpt" # Find the best checkpoint path from training output
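# Minimal sketch for Option 1 (assumes the checkpoint was written by UNetLitModule during
# training and that the wrapped SMP network is stored in its `model` attribute, as the
# loading comments further below suggest):
# lit_module = UNetLitModule.load_from_checkpoint(CHECKPOINT_PATH, map_location="cpu")
# model = lit_module.model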
# Option 2: Load from the saved state_dict
MODEL_STATE_DICT_PATH = "unet_derm_final_model.pth"
IMG_SIZE = (256, 256) # MUST match training image size
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# --- Load Model ---
print(f"Loading model from: {MODEL_STATE_DICT_PATH}")
print(f"Using device: {DEVICE}")
# Instantiate the base SMP model architecture
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,  # Don't load pretrained weights, we load our trained ones
    in_channels=3,
    classes=1,
)
# Load the state dict saved at the end of training
try:
    state_dict = torch.load(MODEL_STATE_DICT_PATH, map_location=DEVICE)
    # If the state_dict was saved directly from the `model.model` attribute in the LitModule:
    model.load_state_dict(state_dict)
    # If you saved the entire Lightning Module state_dict, you might need to extract the model part:
    # state_dict = torch.load(MODEL_STATE_DICT_PATH, map_location=DEVICE)['state_dict']
    # # Filter keys if they have a prefix like 'model.'
    # state_dict = {k.replace('model.', ''): v for k, v in state_dict.items() if k.startswith('model.')}
    # model.load_state_dict(state_dict)
except FileNotFoundError:
    print(f"Error: Model file not found at {MODEL_STATE_DICT_PATH}")
    print("Please ensure the training script ran successfully and the path is correct.")
    exit()
except Exception as e:
    print(f"Error loading model state_dict: {e}")
    print("Ensure the saved state_dict matches the current model architecture.")
    exit()
model.to(DEVICE)
model.eval() # Set model to evaluation mode (disables dropout, batchnorm updates)
# --- Inference Transforms ---
# Should match the validation/test transforms from training (resize, normalize)
inference_transform = A.Compose([
    A.Resize(height=IMG_SIZE[0], width=IMG_SIZE[1]),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])
# --- Segmentation Function ---
@spaces.GPU
def segment_image(input_image_np):
    """
    Takes a NumPy image, performs segmentation, and returns the predicted mask
    and an overlay image for display in Gradio.
    """
    # 0. Input validation
    if input_image_np is None:
        return None, None
    # Ensure the image is RGB (Gradio may provide grayscale or RGBA)
    if len(input_image_np.shape) == 2:  # Grayscale
        input_image_np = cv2.cvtColor(input_image_np, cv2.COLOR_GRAY2RGB)
    elif input_image_np.shape[2] == 4:  # RGBA
        input_image_np = cv2.cvtColor(input_image_np, cv2.COLOR_RGBA2RGB)
    # Gradio's Image component already supplies RGB, so no BGR -> RGB conversion is needed here
    # (if loading with cv2.imread instead, convert first with cv2.COLOR_BGR2RGB)
    input_image_rgb = input_image_np.copy()

    # 1. Preprocess the image
    transformed = inference_transform(image=input_image_rgb)
    image_tensor = transformed['image'].unsqueeze(0).to(DEVICE)  # Add batch dim and send to device
    # 2. Perform inference
    with torch.no_grad():
        logits = model(image_tensor)  # Output is [1, 1, H, W] logits
        # Apply sigmoid to get probabilities in [0, 1]
        probabilities = torch.sigmoid(logits)

    # 3. Postprocess the output mask
    # Remove batch dimension, move to CPU, convert to NumPy
    mask_pred_np = probabilities.squeeze().cpu().numpy()  # Shape: [H, W]
    # Threshold probabilities to get a binary mask (0 or 1)
    binary_mask_np = (mask_pred_np > 0.5).astype(np.uint8)
    # Convert the binary mask to a displayable format (0 or 255)
    display_mask = binary_mask_np * 255  # Shape: [H, W]
    # Resize the mask back to the original image size for a better-quality overlay
    orig_h, orig_w = input_image_rgb.shape[:2]
    display_mask_resized = cv2.resize(display_mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
    # 4. Create Overlay
    # Convert the single-channel mask to 3 channels so it can be blended with the RGB image
    mask_rgb = cv2.cvtColor(display_mask_resized, cv2.COLOR_GRAY2RGB)
    # Keep only the Red channel so the mask reads as red (RGB order: 0=R, 1=G, 2=B)
    mask_rgb[:, :, 1] = 0  # Zero out Green channel
    mask_rgb[:, :, 2] = 0  # Zero out Blue channel
    # Alternative overlay: blend a translucent red tint over the whole image
    # overlay_image = cv2.addWeighted(input_image_rgb, 0.7, mask_rgb, 0.3, 0)
    # Overlay returned below: brighten only the segmented area
    highlighted_area = cv2.bitwise_and(input_image_rgb, input_image_rgb, mask=display_mask_resized)
    overlay_image = cv2.addWeighted(input_image_rgb, 0.7, highlighted_area, 0.9, 0)  # Blend original with highlighted region

    # Return the mask (resized to the input size) and the overlay as NumPy arrays for Gradio
    return display_mask_resized, overlay_image
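
# Quick local sanity check (a commented-out sketch, assuming a placeholder image file
# "sample.jpg" exists next to this script). cv2.imread returns BGR, so convert to RGB
# before calling segment_image, and back to BGR before cv2.imwrite.
# if os.path.exists("sample.jpg"):
#     bgr = cv2.imread("sample.jpg")
#     mask, overlay = segment_image(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
#     cv2.imwrite("sample_mask.png", mask)
#     cv2.imwrite("sample_overlay.png", cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))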
# --- Gradio Interface ---
print("Launching Gradio Interface...")
with gr.Blocks() as demo:
    gr.Markdown("# Dermatology Image Segmentation (UNet ResNet34)")
    gr.Markdown("Upload a dermatology image to see the predicted segmentation mask using a trained UNet model.")
    with gr.Row():  # Horizontal container for the input and output images
        inp = gr.Image(type="numpy", label="Input Image")
        out_mask = gr.Image(type="numpy", label="Segmentation Mask")
        out_overlay = gr.Image(type="numpy", label="Overlay")
    # Hook up the function
    inp.change(fn=segment_image, inputs=inp, outputs=[out_mask, out_overlay])
    # (Optional) add example images
    # gr.Examples(examples=[["examples/img1.jpg"], ["examples/img2.jpg"]],
    #             inputs=inp, outputs=[out_mask, out_overlay])
if __name__ == "__main__":
    demo.launch(share=True)  # share=True to create a public link