import os
import torch
import numpy as np
import cv2 # Using OpenCV for image loading/processing
import albumentations as A
from albumentations.pytorch import ToTensorV2
import gradio as gr
import spaces

import segmentation_models_pytorch as smp
from train_unet import UNetLitModule # Import the Lightning Module definition

# --- Configuration ---
# Option 1: Load from the Lightning Checkpoint
# CHECKPOINT_PATH = "checkpoints/unet-derm-epoch=XX-val_iou=Y.YYYY.ckpt" # Find the best checkpoint path from training output
# Option 2: Load from the saved state_dict
MODEL_STATE_DICT_PATH = "unet_derm_final_model.pth"
IMG_SIZE = (256, 256) # MUST match training image size
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Load Model ---
print(f"Loading model from: {MODEL_STATE_DICT_PATH}")
print(f"Using device: {DEVICE}")

# Instantiate the base SMP model architecture
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None, # Don't download pretrained encoder weights; we load our own trained weights below
    in_channels=3,
    classes=1,
)

# Load the state dict saved at the end of training
try:
    state_dict = torch.load(MODEL_STATE_DICT_PATH, map_location=DEVICE)
    # If the state_dict was saved directly from the `model.model` attribute in LitModule:
    model.load_state_dict(state_dict)
    # If you saved the entire Lightning Module state_dict, you might need to extract the model part:
    # state_dict = torch.load(MODEL_STATE_DICT_PATH, map_location=DEVICE)['state_dict']
    # # Filter keys if they have a prefix like 'model.'
    # state_dict = {k.replace('model.', ''): v for k, v in state_dict.items() if k.startswith('model.')}
    # model.load_state_dict(state_dict)

except FileNotFoundError:
    print(f"Error: Model file not found at {MODEL_STATE_DICT_PATH}")
    print("Please ensure the training script ran successfully and the path is correct.")
    raise SystemExit(1)
except Exception as e:
    print(f"Error loading model state_dict: {e}")
    print("Ensure the saved state_dict matches the current model architecture.")
    raise SystemExit(1)
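
# --- Alternative: Option 1, loading straight from the Lightning checkpoint ---
# A minimal sketch, not used by default. It assumes UNetLitModule called
# self.save_hyperparameters() in __init__ (so no constructor arguments are needed here)
# and that it stores the SMP Unet in a `model` attribute, as the key filtering above suggests.
# To use it, set CHECKPOINT_PATH at the top and comment out the state_dict loading.
#
# lit_module = UNetLitModule.load_from_checkpoint(CHECKPOINT_PATH, map_location=DEVICE)
# model = lit_module.model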


model.to(DEVICE)
model.eval() # Set model to evaluation mode (disables dropout, batchnorm updates)

# --- Inference Transforms ---
# Should match the validation/test transforms from training (resize, normalize)
inference_transform = A.Compose([
    A.Resize(height=IMG_SIZE[0], width=IMG_SIZE[1]),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

# --- Segmentation Function ---
@spaces.GPU
def segment_image(input_image_np):
    """
    Takes a NumPy image, performs segmentation, and returns images for Gradio.
    """
    # 0. Input validation
    if input_image_np is None:
        return None, None

    # Ensure the image has 3 RGB channels (Gradio may provide grayscale or RGBA)
    if len(input_image_np.shape) == 2: # Grayscale
        input_image_np = cv2.cvtColor(input_image_np, cv2.COLOR_GRAY2RGB)
    elif input_image_np.shape[2] == 4: # RGBA
        input_image_np = cv2.cvtColor(input_image_np, cv2.COLOR_RGBA2RGB)
    # Gradio's Image component already returns RGB, so no BGR -> RGB conversion is needed
    input_image_rgb = input_image_np.copy()


    # 1. Preprocess the image
    transformed = inference_transform(image=input_image_rgb)
    image_tensor = transformed['image'].unsqueeze(0).to(DEVICE) # Add batch dim and send to device

    # 2. Perform inference
    with torch.no_grad():
        logits = model(image_tensor) # Output is [1, 1, H, W] logits
        # Apply sigmoid to get probabilities [0, 1]
        probabilities = torch.sigmoid(logits)

    # 3. Postprocess the output mask
    # Remove batch dimension, move to CPU, convert to NumPy
    mask_pred_np = probabilities.squeeze().cpu().numpy() # Shape: [H, W]

    # Threshold probabilities to get binary mask (0 or 1)
    binary_mask_np = (mask_pred_np > 0.5).astype(np.uint8)

    # Convert binary mask to a displayable format (e.g., 0 or 255)
    display_mask = (binary_mask_np * 255) # Shape: [H, W]

    # Resize mask back to original image size for overlay (optional, better overlay quality)
    orig_h, orig_w = input_image_rgb.shape[:2]
    display_mask_resized = cv2.resize(display_mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)

    # 4. Create Overlay
    # Convert the single-channel mask to 3 channels so it can be blended with the RGB image
    mask_rgb = cv2.cvtColor(display_mask_resized, cv2.COLOR_GRAY2RGB)
    # Keep only the red channel so the segmented region is tinted red (channel order is RGB)
    mask_rgb[:, :, 1] = 0 # Zero out the green channel
    mask_rgb[:, :, 2] = 0 # Zero out the blue channel
    # Blend the red mask with the original image
    overlay_image = cv2.addWeighted(input_image_rgb, 0.7, mask_rgb, 0.3, 0)

    # 5. Return the resized mask and the overlay (Gradio expects NumPy arrays)
    return display_mask_resized, overlay_image
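
# Quick local sanity check (a sketch; the example image path below is hypothetical):
# img_bgr = cv2.imread("examples/img1.jpg")
# img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
# mask, overlay = segment_image(img_rgb)
# print(mask.shape, overlay.shape)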


# --- Gradio Interface ---
print("Launching Gradio Interface...")

with gr.Blocks() as demo:
    gr.Markdown("# Dermatology Image Segmentation (UNet ResNet34)")
    gr.Markdown("Upload a dermatology image to see the predicted segmentation mask using a trained UNet model.")
    
    with gr.Row():            # Creates a horizontal container
        inp = gr.Image(type="numpy", label="Input Image")
        out_mask = gr.Image(type="numpy", label="Segmentation Mask")
        out_overlay = gr.Image(type="numpy", label="Overlay")
    
    # Hook up the function
    inp.change(fn=segment_image, inputs=inp, outputs=[out_mask, out_overlay])
    
    # (Optional) add example images
    # gr.Examples(examples=[["examples/img1.jpg"], ["examples/img2.jpg"]], 
    #             inputs=inp, outputs=[out_mask, out_overlay])
    
    
if __name__ == "__main__":
    demo.launch(share=True) # share=True creates a temporary public link when running locally