Johnyquest7 committed
Commit 821b618 · verified · 1 Parent(s): 8e3a6e5

Update app.py

Files changed (1)
  1. app.py +147 -146
app.py CHANGED
@@ -1,147 +1,148 @@
- import os
- import torch
- import numpy as np
- import cv2  # Using OpenCV for image loading/processing
- import albumentations as A
- from albumentations.pytorch import ToTensorV2
- import gradio as gr
-
- import segmentation_models_pytorch as smp
- from train_unet import UNetLitModule  # Import the Lightning Module definition
-
- # --- Configuration ---
- # Option 1: Load from the Lightning Checkpoint
- # CHECKPOINT_PATH = "checkpoints/unet-derm-epoch=XX-val_iou=Y.YYYY.ckpt"  # Find the best checkpoint path from training output
- # Option 2: Load from the saved state_dict
- MODEL_STATE_DICT_PATH = "unet_derm_final_model.pth"
- IMG_SIZE = (256, 256)  # MUST match training image size
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- # --- Load Model ---
- print(f"Loading model from: {MODEL_STATE_DICT_PATH}")
- print(f"Using device: {DEVICE}")
-
- # Instantiate the base SMP model architecture
- model = smp.Unet(
-     encoder_name="resnet34",
-     encoder_weights=None,  # Don't load pretrained weights, we load our trained ones
-     in_channels=3,
-     classes=1,
- )
-
- # Load the state dict saved at the end of training
- try:
-     state_dict = torch.load(MODEL_STATE_DICT_PATH, map_location=DEVICE)
-     # If the state_dict was saved directly from the `model.model` attribute in LitModule:
-     model.load_state_dict(state_dict)
-     # If you saved the entire Lightning Module state_dict, you might need to extract the model part:
-     # state_dict = torch.load(MODEL_STATE_DICT_PATH, map_location=DEVICE)['state_dict']
-     # # Filter keys if they have a prefix like 'model.'
-     # state_dict = {k.replace('model.', ''): v for k, v in state_dict.items() if k.startswith('model.')}
-     # model.load_state_dict(state_dict)
-
- except FileNotFoundError:
-     print(f"Error: Model file not found at {MODEL_STATE_DICT_PATH}")
-     print("Please ensure the training script ran successfully and the path is correct.")
-     exit()
- except Exception as e:
-     print(f"Error loading model state_dict: {e}")
-     print("Ensure the saved state_dict matches the current model architecture.")
-     exit()
-
-
- model.to(DEVICE)
- model.eval()  # Set model to evaluation mode (disables dropout, batchnorm updates)
-
- # --- Inference Transforms ---
- # Should match the validation/test transforms from training (resize, normalize)
- inference_transform = A.Compose([
-     A.Resize(height=IMG_SIZE[0], width=IMG_SIZE[1]),
-     A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
-     ToTensorV2(),
- ])
-
- # --- Segmentation Function ---
- def segment_image(input_image_np):
-     """
-     Takes a NumPy image, performs segmentation, and returns images for Gradio.
-     """
-     # 0. Input validation
-     if input_image_np is None:
-         return None, None, None
-
-     # Ensure image is RGB (Gradio might provide BGR or grayscale)
-     if len(input_image_np.shape) == 2:  # Grayscale
-         input_image_np = cv2.cvtColor(input_image_np, cv2.COLOR_GRAY2RGB)
-     elif input_image_np.shape[2] == 4:  # RGBA
-         input_image_np = cv2.cvtColor(input_image_np, cv2.COLOR_RGBA2RGB)
-     # Assume BGR if 3 channels, convert to RGB for consistency with training
-     # input_image_rgb = cv2.cvtColor(input_image_np, cv2.COLOR_BGR2RGB)  # PIL/Gradio usually loads RGB
-     input_image_rgb = input_image_np.copy()
-
-
-     # 1. Preprocess the image
-     transformed = inference_transform(image=input_image_rgb)
-     image_tensor = transformed['image'].unsqueeze(0).to(DEVICE)  # Add batch dim and send to device
-
-     # 2. Perform inference
-     with torch.no_grad():
-         logits = model(image_tensor)  # Output is [1, 1, H, W] logits
-         # Apply sigmoid to get probabilities [0, 1]
-         probabilities = torch.sigmoid(logits)
-
-     # 3. Postprocess the output mask
-     # Remove batch dimension, move to CPU, convert to NumPy
-     mask_pred_np = probabilities.squeeze().cpu().numpy()  # Shape: [H, W]
-
-     # Threshold probabilities to get binary mask (0 or 1)
-     binary_mask_np = (mask_pred_np > 0.5).astype(np.uint8)
-
-     # Convert binary mask to a displayable format (e.g., 0 or 255)
-     display_mask = (binary_mask_np * 255)  # Shape: [H, W]
-
-     # Resize mask back to original image size for overlay (optional, better overlay quality)
-     orig_h, orig_w = input_image_rgb.shape[:2]
-     display_mask_resized = cv2.resize(display_mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
-
-     # 4. Create Overlay
-     # Convert single-channel mask to 3 channels to overlay on RGB image
-     mask_rgb = cv2.cvtColor(display_mask_resized, cv2.COLOR_GRAY2RGB)
-     # Make the mask red where segmentation is present
-     mask_rgb[:, :, 0] = 0  # Zero out Blue channel
-     mask_rgb[:, :, 1] = 0  # Zero out Green channel
-     # Where mask_rgb is red (255), keep original image pixel, otherwise blend
-     overlay_image = cv2.addWeighted(input_image_rgb, 0.7, mask_rgb, 0.3, 0)
-     # Highlight only the segmented area more distinctly
-     highlighted_area = cv2.bitwise_and(input_image_rgb, input_image_rgb, mask=display_mask_resized)
-     overlay_image = cv2.addWeighted(input_image_rgb, 0.7, highlighted_area, 0.9, 0)  # Blend original with highlighted
-
-     # Return original, mask (resized), overlay
-     # Gradio expects NumPy arrays
-     # return input_image_rgb, display_mask_resized, overlay_image
-     return display_mask_resized, overlay_image
-
-
- # --- Gradio Interface ---
- print("Launching Gradio Interface...")
-
- with gr.Blocks() as demo:
-     gr.Markdown("# Dermatology Image Segmentation (UNet ResNet34)")
-     gr.Markdown("Upload a dermatology image to see the predicted segmentation mask using a trained UNet model.")
-
-     with gr.Row():  # Creates a horizontal container
-         inp = gr.Image(type="numpy", label="Input Image")
-         out_mask = gr.Image(type="numpy", label="Segmentation Mask")
-         out_overlay = gr.Image(type="numpy", label="Overlay")
-
-     # Hook up the function
-     inp.change(fn=segment_image, inputs=inp, outputs=[out_mask, out_overlay])
-
-     # (Optional) add example images
-     # gr.Examples(examples=[["examples/img1.jpg"], ["examples/img2.jpg"]],
-     #             inputs=inp, outputs=[out_mask, out_overlay])
-
-     # Disable flagging
-
- if __name__ == "__main__":
+ import os
+ import torch
+ import numpy as np
+ import cv2  # Using OpenCV for image loading/processing
+ import albumentations as A
+ from albumentations.pytorch import ToTensorV2
+ import gradio as gr
+
+ import segmentation_models_pytorch as smp
+ from train_unet import UNetLitModule  # Import the Lightning Module definition
+
+ # --- Configuration ---
+ # Option 1: Load from the Lightning Checkpoint
+ # CHECKPOINT_PATH = "checkpoints/unet-derm-epoch=XX-val_iou=Y.YYYY.ckpt"  # Find the best checkpoint path from training output
+ # Option 2: Load from the saved state_dict
+ MODEL_STATE_DICT_PATH = "unet_derm_final_model.pth"
+ IMG_SIZE = (256, 256)  # MUST match training image size
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # --- Load Model ---
+ print(f"Loading model from: {MODEL_STATE_DICT_PATH}")
+ print(f"Using device: {DEVICE}")
+
+ # Instantiate the base SMP model architecture
+ model = smp.Unet(
+     encoder_name="resnet34",
+     encoder_weights=None,  # Don't load pretrained weights, we load our trained ones
+     in_channels=3,
+     classes=1,
+ )
+
+ # Load the state dict saved at the end of training
+ try:
+     state_dict = torch.load(MODEL_STATE_DICT_PATH, map_location=DEVICE)
+     # If the state_dict was saved directly from the `model.model` attribute in LitModule:
+     model.load_state_dict(state_dict)
+     # If you saved the entire Lightning Module state_dict, you might need to extract the model part:
+     # state_dict = torch.load(MODEL_STATE_DICT_PATH, map_location=DEVICE)['state_dict']
+     # # Filter keys if they have a prefix like 'model.'
+     # state_dict = {k.replace('model.', ''): v for k, v in state_dict.items() if k.startswith('model.')}
+     # model.load_state_dict(state_dict)
+
+ except FileNotFoundError:
+     print(f"Error: Model file not found at {MODEL_STATE_DICT_PATH}")
+     print("Please ensure the training script ran successfully and the path is correct.")
+     exit()
+ except Exception as e:
+     print(f"Error loading model state_dict: {e}")
+     print("Ensure the saved state_dict matches the current model architecture.")
+     exit()
+
+
+ model.to(DEVICE)
+ model.eval()  # Set model to evaluation mode (disables dropout, batchnorm updates)
+
+ # --- Inference Transforms ---
+ # Should match the validation/test transforms from training (resize, normalize)
+ inference_transform = A.Compose([
+     A.Resize(height=IMG_SIZE[0], width=IMG_SIZE[1]),
+     A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+     ToTensorV2(),
+ ])
+
+ # --- Segmentation Function ---
+ @spaces.GPU
+ def segment_image(input_image_np):
+     """
+     Takes a NumPy image, performs segmentation, and returns images for Gradio.
+     """
+     # 0. Input validation
+     if input_image_np is None:
+         return None, None, None
+
+     # Ensure image is RGB (Gradio might provide BGR or grayscale)
+     if len(input_image_np.shape) == 2:  # Grayscale
+         input_image_np = cv2.cvtColor(input_image_np, cv2.COLOR_GRAY2RGB)
+     elif input_image_np.shape[2] == 4:  # RGBA
+         input_image_np = cv2.cvtColor(input_image_np, cv2.COLOR_RGBA2RGB)
+     # Assume BGR if 3 channels, convert to RGB for consistency with training
+     # input_image_rgb = cv2.cvtColor(input_image_np, cv2.COLOR_BGR2RGB)  # PIL/Gradio usually loads RGB
+     input_image_rgb = input_image_np.copy()
+
+
+     # 1. Preprocess the image
+     transformed = inference_transform(image=input_image_rgb)
+     image_tensor = transformed['image'].unsqueeze(0).to(DEVICE)  # Add batch dim and send to device
+
+     # 2. Perform inference
+     with torch.no_grad():
+         logits = model(image_tensor)  # Output is [1, 1, H, W] logits
+         # Apply sigmoid to get probabilities [0, 1]
+         probabilities = torch.sigmoid(logits)
+
+     # 3. Postprocess the output mask
+     # Remove batch dimension, move to CPU, convert to NumPy
+     mask_pred_np = probabilities.squeeze().cpu().numpy()  # Shape: [H, W]
+
+     # Threshold probabilities to get binary mask (0 or 1)
+     binary_mask_np = (mask_pred_np > 0.5).astype(np.uint8)
+
+     # Convert binary mask to a displayable format (e.g., 0 or 255)
+     display_mask = (binary_mask_np * 255)  # Shape: [H, W]
+
+     # Resize mask back to original image size for overlay (optional, better overlay quality)
+     orig_h, orig_w = input_image_rgb.shape[:2]
+     display_mask_resized = cv2.resize(display_mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
+
+     # 4. Create Overlay
+     # Convert single-channel mask to 3 channels to overlay on RGB image
+     mask_rgb = cv2.cvtColor(display_mask_resized, cv2.COLOR_GRAY2RGB)
+     # Make the mask red where segmentation is present
+     mask_rgb[:, :, 0] = 0  # Zero out Blue channel
+     mask_rgb[:, :, 1] = 0  # Zero out Green channel
+     # Where mask_rgb is red (255), keep original image pixel, otherwise blend
+     overlay_image = cv2.addWeighted(input_image_rgb, 0.7, mask_rgb, 0.3, 0)
+     # Highlight only the segmented area more distinctly
+     highlighted_area = cv2.bitwise_and(input_image_rgb, input_image_rgb, mask=display_mask_resized)
+     overlay_image = cv2.addWeighted(input_image_rgb, 0.7, highlighted_area, 0.9, 0)  # Blend original with highlighted
+
+     # Return original, mask (resized), overlay
+     # Gradio expects NumPy arrays
+     # return input_image_rgb, display_mask_resized, overlay_image
+     return display_mask_resized, overlay_image
+
+
+ # --- Gradio Interface ---
+ print("Launching Gradio Interface...")
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# Dermatology Image Segmentation (UNet ResNet34)")
+     gr.Markdown("Upload a dermatology image to see the predicted segmentation mask using a trained UNet model.")
+
+     with gr.Row():  # Creates a horizontal container
+         inp = gr.Image(type="numpy", label="Input Image")
+         out_mask = gr.Image(type="numpy", label="Segmentation Mask")
+         out_overlay = gr.Image(type="numpy", label="Overlay")
+
+     # Hook up the function
+     inp.change(fn=segment_image, inputs=inp, outputs=[out_mask, out_overlay])
+
+     # (Optional) add example images
+     # gr.Examples(examples=[["examples/img1.jpg"], ["examples/img2.jpg"]],
+     #             inputs=inp, outputs=[out_mask, out_overlay])
+
+     # Disable flagging
+
+ if __name__ == "__main__":
      demo.launch(share=True)  # Share=True to create public link
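
Note on the change: the functional difference introduced by this commit is the @spaces.GPU decorator on segment_image (new line 65), which on Hugging Face ZeroGPU Spaces requests a GPU for the duration of each call. The decorator needs the spaces package to be imported, and no "import spaces" appears in this diff, so the sketch below shows the usual pattern. It is a minimal sketch under the assumption that the Space runs on ZeroGPU hardware where the spaces package is available; it is not part of the committed file.

    import spaces  # Hugging Face ZeroGPU helper; assumed available, not added by this commit
    import torch

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @spaces.GPU  # allocate a GPU for each call when running on a ZeroGPU Space
    def segment_image(input_image_np):
        # preprocessing, UNet inference, and overlay creation as in app.py above
        ...

Without that import, Python raises a NameError for "spaces" as soon as app.py is loaded, because the decorator line is evaluated at import time.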