Johnyquest7 committed
Commit 6ddf239 · verified · 1 Parent(s): efc5be1

Upload folder using huggingface_hub

Files changed (4):
  1. README.md +3 -9
  2. app.py +147 -0
  3. requirements.txt +104 -0
  4. unet_derm_final_model.pth +3 -0
README.md CHANGED
@@ -1,12 +1,6 @@
 ---
-title: Derm MaskHG
-emoji: 🔥
-colorFrom: red
-colorTo: green
-sdk: gradio
-sdk_version: 5.32.0
+title: derm_maskHG
 app_file: app.py
-pinned: false
+sdk: gradio
+sdk_version: 5.27.0
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,147 @@
+import os
+import torch
+import numpy as np
+import cv2  # OpenCV for image loading/processing
+import albumentations as A
+from albumentations.pytorch import ToTensorV2
+import gradio as gr
+
+import segmentation_models_pytorch as smp
+from train_unet import UNetLitModule  # Lightning Module definition (only needed for Option 1 below)
+
+# --- Configuration ---
+# Option 1: Load from the Lightning checkpoint
+# CHECKPOINT_PATH = "checkpoints/unet-derm-epoch=XX-val_iou=Y.YYYY.ckpt"  # Best checkpoint path from the training output
+# Option 2: Load from the saved state_dict
+MODEL_STATE_DICT_PATH = "unet_derm_final_model.pth"
+IMG_SIZE = (256, 256)  # MUST match the training image size
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# --- Load Model ---
+print(f"Loading model from: {MODEL_STATE_DICT_PATH}")
+print(f"Using device: {DEVICE}")
+
+# Instantiate the base SMP model architecture
+model = smp.Unet(
+    encoder_name="resnet34",
+    encoder_weights=None,  # Don't load ImageNet weights; the trained weights are loaded below
+    in_channels=3,
+    classes=1,
+)
+
+# Load the state dict saved at the end of training
+try:
+    state_dict = torch.load(MODEL_STATE_DICT_PATH, map_location=DEVICE)
+    # If the state_dict was saved directly from the `model.model` attribute in the LitModule:
+    model.load_state_dict(state_dict)
+    # If the entire Lightning Module state_dict was saved, extract the model part instead:
+    # state_dict = torch.load(MODEL_STATE_DICT_PATH, map_location=DEVICE)['state_dict']
+    # # Strip the 'model.' prefix from the keys
+    # state_dict = {k.replace('model.', ''): v for k, v in state_dict.items() if k.startswith('model.')}
+    # model.load_state_dict(state_dict)
+except FileNotFoundError:
+    print(f"Error: Model file not found at {MODEL_STATE_DICT_PATH}")
+    print("Please ensure the training script ran successfully and the path is correct.")
+    exit()
+except Exception as e:
+    print(f"Error loading model state_dict: {e}")
+    print("Ensure the saved state_dict matches the current model architecture.")
+    exit()
+
+model.to(DEVICE)
+model.eval()  # Evaluation mode (disables dropout, freezes batch-norm statistics)
+
+# --- Inference Transforms ---
+# Must match the validation/test transforms used in training (resize, normalize)
+inference_transform = A.Compose([
+    A.Resize(height=IMG_SIZE[0], width=IMG_SIZE[1]),
+    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+    ToTensorV2(),
+])
+
+# --- Segmentation Function ---
+def segment_image(input_image_np):
+    """
+    Takes a NumPy image, performs segmentation, and returns the mask and overlay for Gradio.
+    """
+    # 0. Input validation
+    if input_image_np is None:
+        return None, None
+
+    # Ensure the image is 3-channel RGB (Gradio may provide grayscale or RGBA)
+    if len(input_image_np.shape) == 2:  # Grayscale
+        input_image_np = cv2.cvtColor(input_image_np, cv2.COLOR_GRAY2RGB)
+    elif input_image_np.shape[2] == 4:  # RGBA
+        input_image_np = cv2.cvtColor(input_image_np, cv2.COLOR_RGBA2RGB)
+    # Gradio's Image component delivers RGB, so no BGR -> RGB conversion is needed
+    input_image_rgb = input_image_np.copy()
+
+    # 1. Preprocess the image
+    transformed = inference_transform(image=input_image_rgb)
+    image_tensor = transformed['image'].unsqueeze(0).to(DEVICE)  # Add batch dim and send to device
+
+    # 2. Perform inference
+    with torch.no_grad():
+        logits = model(image_tensor)  # Output is [1, 1, H, W] logits
+        # Apply sigmoid to get probabilities in [0, 1]
+        probabilities = torch.sigmoid(logits)
+
+    # 3. Postprocess the output mask
+    # Remove batch/channel dimensions, move to CPU, convert to NumPy
+    mask_pred_np = probabilities.squeeze().cpu().numpy()  # Shape: [H, W]
+
+    # Threshold probabilities to get a binary mask (0 or 1)
+    binary_mask_np = (mask_pred_np > 0.5).astype(np.uint8)
+
+    # Convert the binary mask to a displayable format (0 or 255)
+    display_mask = binary_mask_np * 255  # Shape: [H, W]
+
+    # Resize the mask back to the original image size for a cleaner overlay
+    orig_h, orig_w = input_image_rgb.shape[:2]
+    display_mask_resized = cv2.resize(display_mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
+
+    # 4. Create overlay
+    # Keep only the segmented region of the original image, then blend it back over
+    # the full image so the predicted area appears brighter than its surroundings
+    highlighted_area = cv2.bitwise_and(input_image_rgb, input_image_rgb, mask=display_mask_resized)
+    overlay_image = cv2.addWeighted(input_image_rgb, 0.7, highlighted_area, 0.9, 0)
+
+    # Return the mask (resized to the original size) and the overlay as NumPy arrays for Gradio
+    return display_mask_resized, overlay_image
+
+
+# --- Gradio Interface ---
+print("Launching Gradio Interface...")
+
+with gr.Blocks() as demo:
+    gr.Markdown("# Dermatology Image Segmentation (UNet ResNet34)")
+    gr.Markdown("Upload a dermatology image to see the predicted segmentation mask using a trained UNet model.")
+
+    with gr.Row():  # Horizontal container
+        inp = gr.Image(type="numpy", label="Input Image")
+        out_mask = gr.Image(type="numpy", label="Segmentation Mask")
+        out_overlay = gr.Image(type="numpy", label="Overlay")
+
+    # Run segmentation whenever the input image changes
+    inp.change(fn=segment_image, inputs=inp, outputs=[out_mask, out_overlay])
+
+    # (Optional) add example images
+    # gr.Examples(examples=[["examples/img1.jpg"], ["examples/img2.jpg"]],
+    #             inputs=inp, outputs=[out_mask, out_overlay])
+
+if __name__ == "__main__":
+    demo.launch(share=True)  # share=True creates a public link
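
For reference, the segmentation function can be exercised outside the Gradio UI with a short local smoke test. This is a minimal sketch, not part of the commit: it assumes unet_derm_final_model.pth sits next to app.py (so the module-level model load succeeds), that train_unet.py from the training run is importable, and that a test image exists at examples/img1.jpg (hypothetical; no example images are included in this commit).

import cv2
from app import segment_image  # importing app loads the model at module level

img_bgr = cv2.imread("examples/img1.jpg")           # hypothetical test image; OpenCV reads BGR
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)  # match the RGB input Gradio would supply
mask, overlay = segment_image(img_rgb)              # resized binary mask and blended overlay
cv2.imwrite("mask.png", mask)
cv2.imwrite("overlay.png", cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))  # back to BGR for imwrite
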
requirements.txt ADDED
@@ -0,0 +1,104 @@
+absl-py==2.2.2
+aiofiles==24.1.0
+aiohappyeyeballs==2.6.1
+aiohttp==3.11.18
+aiosignal==1.3.2
+albucore==0.0.23
+albumentations==2.0.5
+annotated-types==0.7.0
+anyio==4.9.0
+attrs==25.3.0
+certifi==2025.4.26
+charset-normalizer==3.4.1
+click==8.1.8
+fastapi==0.115.12
+ffmpy==0.5.0
+filelock==3.18.0
+frozenlist==1.6.0
+fsspec==2025.3.2
+gradio==5.27.0
+gradio_client==1.9.0
+groovy==0.1.2
+grpcio==1.71.0
+h11==0.16.0
+httpcore==1.0.9
+httpx==0.28.1
+huggingface-hub==0.30.2
+idna==3.10
+Jinja2==3.1.6
+lightning-utilities==0.14.3
+Markdown==3.8
+markdown-it-py==3.0.0
+MarkupSafe==3.0.2
+mdurl==0.1.2
+mpmath==1.3.0
+multidict==6.4.3
+networkx==3.4.2
+numpy==2.2.5
+nvidia-cublas-cu12==12.6.4.1
+nvidia-cuda-cupti-cu12==12.6.80
+nvidia-cuda-nvrtc-cu12==12.6.77
+nvidia-cuda-runtime-cu12==12.6.77
+nvidia-cudnn-cu12==9.5.1.17
+nvidia-cufft-cu12==11.3.0.4
+nvidia-cufile-cu12==1.11.1.6
+nvidia-curand-cu12==10.3.7.77
+nvidia-cusolver-cu12==11.7.1.2
+nvidia-cusparse-cu12==12.5.4.2
+nvidia-cusparselt-cu12==0.6.3
+nvidia-nccl-cu12==2.26.2
+nvidia-nvjitlink-cu12==12.6.85
+nvidia-nvtx-cu12==12.6.77
+opencv-python==4.11.0.86
+opencv-python-headless==4.11.0.86
+orjson==3.10.16
+packaging==25.0
+pandas==2.2.3
+pillow==11.2.1
+propcache==0.3.1
+protobuf==6.30.2
+pydantic==2.11.3
+pydantic_core==2.33.1
+pydub==0.25.1
+Pygments==2.19.1
+python-dateutil==2.9.0.post0
+python-multipart==0.0.20
+pytorch-lightning==2.5.1.post0
+pytz==2025.2
+PyYAML==6.0.2
+requests==2.32.3
+rich==14.0.0
+ruff==0.11.7
+safehttpx==0.1.6
+safetensors==0.5.3
+scipy==1.15.2
+segmentation_models_pytorch==0.5.0
+semantic-version==2.10.0
+setuptools==75.8.0
+shellingham==1.5.4
+simsimd==6.2.1
+six==1.17.0
+sniffio==1.3.1
+starlette==0.46.2
+stringzilla==3.12.5
+sympy==1.14.0
+tensorboard==2.19.0
+tensorboard-data-server==0.7.2
+timm==1.0.15
+tomlkit==0.13.2
+torch==2.7.0
+torchaudio==2.7.0
+torchmetrics==1.7.1
+torchvision==0.22.0
+tqdm==4.67.1
+triton==3.3.0
+typer==0.15.2
+typing-inspection==0.4.0
+typing_extensions==4.13.2
+tzdata==2025.2
+urllib3==2.4.0
+uvicorn==0.34.2
+websockets==15.0.1
+Werkzeug==3.1.3
+wheel==0.45.1
+yarl==1.20.0
unet_derm_final_model.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bc76b7c54afe131b0de98db1daee33bc5c5e573729e51c38d7f8adfe1d3d0ce0
+size 97923355
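
Because the weights are stored with Git LFS, a plain clone without LFS support only retrieves the three pointer lines above. Below is a minimal sketch for fetching the actual ~98 MB file programmatically with huggingface_hub; the repo id "Johnyquest7/derm_maskHG" and repo_type="space" are assumptions inferred from the author and Space title, not stated in this commit.

from huggingface_hub import hf_hub_download

# Assumed repo id and type -- adjust to the actual Space or model repository
weights_path = hf_hub_download(
    repo_id="Johnyquest7/derm_maskHG",
    repo_type="space",
    filename="unet_derm_final_model.pth",
)
print(weights_path)  # local cache path of the resolved (non-pointer) weights file
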