LPX55 committed
Commit 7a13abb · 1 Parent(s): 3a8ed1d

kiss: simplifying everything

Files changed (2)
  1. app.py +4 -4
  2. sam2_mask.py +54 -196
app.py CHANGED
@@ -9,10 +9,11 @@ from controlnet_union import ControlNetModel_Union
 from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
 from PIL import Image, ImageDraw
 import numpy as np
-from sam2_mask import create_sam2_tab, sam_process
+from sam2_mask import create_sam2_mask_interface
 
 #from sam2.sam2_image_predictor import SAM2ImagePredictor
 
+sam2_mask_tab = create_sam2_mask_interface()
 
 MODELS = {
     "RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
@@ -486,9 +487,8 @@ with gr.Blocks(css=css, fill_height=True) as demo:
             use_as_input_button_outpaint = gr.Button("Use as Input Image", visible=False)
             history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False)
             preview_image = gr.Image(label="Preview")
-        with gr.TabItem("SAM2 Masking"):
-            input_image, points_map, output_result_mask = create_sam2_tab()
-
+        with gr.TabItem("SAM2 Mask"):
+            sam2_mask_tab
         with gr.TabItem("Misc"):
             with gr.Column():
                 clear_cache_button = gr.Button("Clear CUDA Cache")
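
For orientation, here is a minimal sketch of how the simplified wiring is presumably meant to compose (module and function names are taken from this commit; the `.render()` call is an assumption based on Gradio's usual pattern for embedding an already-built Blocks inside another, since a bare `sam2_mask_tab` expression under the tab renders nothing on its own):

```python
# Hypothetical composition sketch, not the actual app.py.
import gradio as gr
from sam2_mask import create_sam2_mask_interface  # name from this commit

# Build the standalone SAM2 mask UI once at import time, as the diff does.
sam2_mask_tab = create_sam2_mask_interface()

with gr.Blocks() as demo:
    with gr.Tabs():
        with gr.TabItem("SAM2 Mask"):
            # Assumption: Blocks.render() is how Gradio embeds a
            # pre-constructed Blocks; a bare expression is a no-op here.
            sam2_mask_tab.render()

demo.launch()
```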
sam2_mask.py CHANGED
@@ -1,204 +1,62 @@
-import spaces
+# K-I-S-S
+
 import gradio as gr
-import os
-os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"
+from gradio_image_prompter import ImagePrompter
+from sam2.sam2_image_predictor import SAM2ImagePredictor
 import torch
 import numpy as np
-import cv2
-import matplotlib.pyplot as plt
-from PIL import Image, ImageFilter
-from sam2.build_sam import build_sam2
-from sam2.sam2_image_predictor import SAM2ImagePredictor
-from gradio_image_prompter import ImagePrompter
-
-def preprocess_image(image):
-    return image, gr.State([]), gr.State([]), image
-
-def get_point(point_type, tracking_points, trackings_input_label, first_frame_path, evt: gr.SelectData):
-    print(f"You selected {evt.value} at {evt.index} from {evt.target}")
-    tracking_points.append(evt.index)
-    print(f"TRACKING POINTS: {tracking_points}")
-    if point_type == "include":
-        trackings_input_label.append(1)
-    elif point_type == "exclude":
-        trackings_input_label.append(0)
-    print(f"TRACKING INPUT LABELS: {trackings_input_label}")
-    # Open the image and get its dimensions
-    transparent_background = Image.open(first_frame_path).convert('RGBA')
-    w, h = transparent_background.size
-    # Define the circle radius as a fraction of the smaller dimension
-    fraction = 0.02 # You can adjust this value as needed
-    radius = int(fraction * min(w, h))
-    # Create a transparent layer to draw on
-    transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
-    for index, track in enumerate(tracking_points):
-        if trackings_input_label[index] == 1:
-            cv2.circle(transparent_layer, tuple(track), radius, (0, 255, 0, 255), -1)
-        else:
-            cv2.circle(transparent_layer, tuple(track), radius, (255, 0, 0, 255), -1)
-    # Convert the transparent layer back to an image
-    transparent_layer = Image.fromarray(transparent_layer, 'RGBA')
-    selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)
-    return tracking_points, trackings_input_label, selected_point_map
-
-# use bfloat16 for the entire notebook
-torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
-
-if torch.cuda.get_device_properties(0).major >= 8:
-    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
-    torch.backends.cuda.matmul.allow_tf32 = True
-    torch.backends.cudnn.allow_tf32 = True
-
-def show_mask(mask, ax, random_color=False, borders=True):
-    if random_color:
-        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
+from PIL import Image as PILImage
+
+# Initialize SAM2 predictor
+MODEL = "facebook/sam2-hiera-large"
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+PREDICTOR = SAM2ImagePredictor.from_pretrained(MODEL, device=DEVICE)
+
+def predict_masks(image, points):
+    """Predict a single mask from the image based on selected points."""
+    image_np = np.array(image)
+    points_list = [[point["x"], point["y"]] for point in points]
+    input_labels = [1] * len(points_list)
+
+    with torch.inference_mode():
+        PREDICTOR.set_image(image_np)
+        masks, _, _ = PREDICTOR.predict(
+            point_coords=points_list, point_labels=input_labels, multimask_output=False
+        )
+
+    # Prepare the overlay image
+    red_mask = np.zeros_like(image_np)
+    if masks and len(masks) > 0:
+        red_mask[:, :, 0] = masks[0].astype(np.uint8) * 255 # Apply the red channel
+        red_mask = PILImage.fromarray(red_mask)
+        original_image = PILImage.fromarray(image_np)
+        blended_image = PILImage.blend(original_image, red_mask, alpha=0.5)
+        return np.array(blended_image)
     else:
-        color = np.array([30/255, 144/255, 255/255, 0.6])
-    h, w = mask.shape[-2:]
-    mask = mask.astype(np.uint8)
-    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
-    if borders:
-        import cv2
-        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
-        # Try to smooth contours
-        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
-        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
-    ax.imshow(mask_image)
-
-def show_points(coords, labels, ax, marker_size=375):
-    pos_points = coords[labels == 1]
-    neg_points = coords[labels == 0]
-    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
-    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
-
-def show_box(box, ax):
-    x0, y0 = box[0], box[1]
-    w, h = box[2] - box[0], box[3] - box[1]
-    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
-
-def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
-    combined_images = [] # List to store filenames of images with masks overlaid
-    mask_images = [] # List to store filenames of separate mask images
-    for i, (mask, score) in enumerate(zip(masks, scores)):
-        # ---- Original Image with Mask Overlaid ----
-        plt.figure(figsize=(10, 10))
-        plt.imshow(image)
-        show_mask(mask, plt.gca(), borders=borders) # Draw the mask with borders
-        if box_coords is not None:
-            show_box(box_coords, plt.gca())
-        if len(scores) > 1:
-            plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
-        plt.axis('off')
-        # Save the figure as a JPG file
-        combined_filename = f"combined_image_{i+1}.jpg"
-        plt.savefig(combined_filename, format='jpg', bbox_inches='tight')
-        combined_images.append(combined_filename)
-        plt.close() # Close the figure to free up memory
-        # ---- Separate Mask Image (White Mask on Black Background) ----
-        # Create a black image
-        mask_image = np.zeros_like(image, dtype=np.uint8)
-        # The mask is a binary array where the masked area is 1, else 0.
-        # Convert the mask to a white color in the mask_image
-        mask_layer = (mask > 0).astype(np.uint8) * 255
-        for c in range(3): # Assuming RGB, repeat mask for all channels
-            mask_image[:, :, c] = mask_layer
-        # Save the mask image
-        mask_filename = f"mask_image_{i+1}.png"
-        Image.fromarray(mask_image).save(mask_filename)
-        mask_images.append(mask_filename)
-        plt.close() # Close the figure to free up memory
-    return combined_images, mask_images
-
-@spaces.GPU()
-def sam_process(original_image, points, labels):
-
-    print(f"Points: {points}")
-    print(f"Labels: {labels}")
-    image = Image.open(original_image)
-    image = np.array(image.convert("RGB"))
-
-    if not points or not labels:
-        print("No points or labels provided, returning None")
-        return None
-    # Convert image to numpy array for SAM2 processing
-    # image = np.array(original_image)
-    predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-large")
-    predictor.set_image(image)
-    input_point = np.array(points.value)
-    input_label = np.array(labels.value)
-
-    print(predictor._features["image_embed"].shape, predictor._features["image_embed"][-1].shape)
-
-    masks, scores, logits = predictor.predict(
-        point_coords=input_point,
-        point_labels=input_label,
-        multimask_output=False,
-    )
-    sorted_indices = np.argsort(scores)[::-1]
-    masks = masks[sorted_indices]
-    scores = scores[sorted_indices]
-    logits = logits[sorted_indices]
-    print(masks.shape)
-
-    results, mask_results = show_masks(image, masks, scores, point_coords=input_point, input_labels=input_label, borders=True)
-    print(results)
-
-    return results[0], mask_results[0]
-
-def create_sam2_tab():
-    first_frame = gr.State() # Tracks original image
-    tracking_points = gr.State([])
-    trackings_input_label = gr.State([])
-
-    with gr.Column():
+        return image_np
+
+def create_sam2_mask_interface():
+    """Create the Gradio interface for SAM2 mask generation."""
+    with gr.Blocks() as sam2_mask_tab:
+        gr.Markdown("# Object Segmentation with SAM2")
+        gr.Markdown(
+            """
+            This application utilizes **Segment Anything V2 (SAM2)** to allow you to upload an image and interactively generate a segmentation mask based on multiple points you select on the image.
+            """
+        )
         with gr.Row():
             with gr.Column():
-                sam_input_image = gr.Image(label="input image", interactive=False, type="filepath", visible=False)
-                img_prompter = ImagePrompter(show_label=False)
-                points_map = gr.Image(
-                    label="points map",
-                    type="filepath",
-                    interactive=True
-                )
-                with gr.Row():
-                    point_type = gr.Radio(["include", "exclude"], value="include", label="Point Type")
-                    clear_button = gr.Button("Clear Points")
+                upload_image_input = ImagePrompter(show_label=False)
                 submit_button = gr.Button("Submit")
-
             with gr.Column():
-                output_image = gr.Image("Segmented Output")
-                output_result_mask = gr.Image()
-                prompted_output = gr.Image(show_label=False)
-                prompted_data = gr.Dataframe(label="Points")
-
-    # lambda prompts: (prompts["image"], prompts["points"]),
-    # Event handlers
-    points_map.upload(
-        fn = preprocess_image,
-        inputs = [points_map],
-        # outputs=[sam_input_image, first_frame, tracking_points, trackings_input_label],
-        outputs = [first_frame, tracking_points, trackings_input_label, sam_input_image],
-        queue=False
-    )
-
-    clear_button.click(
-        lambda img: ([], [], img),
-        inputs=first_frame,
-        outputs=[tracking_points, trackings_input_label, points_map],
-        queue=False
-    )
-
-    points_map.select(
-        get_point,
-        inputs=[point_type, tracking_points, trackings_input_label, first_frame],
-        outputs=[tracking_points, trackings_input_label, points_map],
-        queue = False
-    )
-
-    submit_button.click(
-        sam_process,
-        inputs=[sam_input_image, tracking_points, trackings_input_label],
-        outputs = [output_image, output_result_mask]
-    )
-
-    return sam_input_image, points_map, output_image
+                image_output = gr.Image(label="Segmented Image", type="pil").style(height=400)
+
+        # Define the action triggered by the submit button
+        submit_button.click(
+            fn=predict_masks,
+            inputs=[upload_image_input.image, upload_image_input.points],
+            outputs=image_output,
+            show_progress=True,
+        )
+
+    return sam2_mask_tab
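
A few spots in the new module are worth double-checking against the libraries' current APIs: `gradio_image_prompter`'s ImagePrompter emits a single dict value (commonly `{"image": ..., "points": [...]}`, with each point a flat `[x, y, ...]` list), so attribute-style inputs like `upload_image_input.image` and dict-style access like `point["x"]` may not match what the component actually delivers; `if masks:` on a multi-element NumPy array raises a truthiness error; and `.style(height=400)` was removed in Gradio 4 in favor of a `height` constructor argument. A defensive sketch under those assumptions:

```python
# Hedged sketch of the same handler; the ImagePrompter value format below is an
# assumption and should be verified against the gradio_image_prompter docs.
import gradio as gr
import numpy as np
import torch
from PIL import Image as PILImage
from gradio_image_prompter import ImagePrompter
from sam2.sam2_image_predictor import SAM2ImagePredictor

PREDICTOR = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")

def predict_masks_from_prompts(prompts):
    # Assumption: the component's value is {"image": ndarray, "points": [[x, y, ...], ...]}.
    image_np = np.array(prompts["image"])
    coords = np.array([[p[0], p[1]] for p in prompts["points"]], dtype=np.float32)
    labels = np.ones(len(coords), dtype=np.int32)  # treat every click as "include"
    with torch.inference_mode():
        PREDICTOR.set_image(image_np)
        masks, _, _ = PREDICTOR.predict(
            point_coords=coords, point_labels=labels, multimask_output=False
        )
    # Test length explicitly; `if masks:` on a multi-element ndarray raises.
    if masks is None or len(masks) == 0:
        return image_np
    red_mask = np.zeros_like(image_np)
    red_mask[:, :, 0] = (masks[0] > 0).astype(np.uint8) * 255
    blended = PILImage.blend(PILImage.fromarray(image_np), PILImage.fromarray(red_mask), alpha=0.5)
    return np.array(blended)

with gr.Blocks() as demo:
    prompter = ImagePrompter(show_label=False)
    submit = gr.Button("Submit")
    # Gradio 4 dropped .style(); height is now a constructor keyword.
    out = gr.Image(label="Segmented Image", height=400)
    # Pass the component itself; the handler unpacks the dict value.
    submit.click(predict_masks_from_prompts, inputs=prompter, outputs=out)
```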