Spaces:
Running
on
Zero
Running
on
Zero
Update sam2_mask.py
Browse files- sam2_mask.py +11 -19
sam2_mask.py
CHANGED
@@ -12,42 +12,33 @@ from sam2.sam2_image_predictor import SAM2ImagePredictor
|
|
12 |
|
13 |
def preprocess_image(image):
|
14 |
return image, gr.State([]), gr.State([]), image
|
15 |
-
|
16 |
-
def get_point(point_type, tracking_points, trackings_input_label, first_frame_path, evt: gr.SelectData):
|
17 |
print(f"You selected {evt.value} at {evt.index} from {evt.target}")
|
18 |
-
|
19 |
tracking_points.value.append(evt.index)
|
20 |
-
print(f"TRACKING
|
21 |
-
|
22 |
if point_type == "include":
|
23 |
trackings_input_label.value.append(1)
|
24 |
elif point_type == "exclude":
|
25 |
trackings_input_label.value.append(0)
|
26 |
-
print(f"TRACKING INPUT
|
27 |
-
|
28 |
# Open the image and get its dimensions
|
29 |
transparent_background = Image.open(first_frame_path).convert('RGBA')
|
30 |
w, h = transparent_background.size
|
31 |
-
|
32 |
# Define the circle radius as a fraction of the smaller dimension
|
33 |
fraction = 0.02 # You can adjust this value as needed
|
34 |
radius = int(fraction * min(w, h))
|
35 |
-
|
36 |
# Create a transparent layer to draw on
|
37 |
transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
|
38 |
-
|
39 |
for index, track in enumerate(tracking_points.value):
|
40 |
if trackings_input_label.value[index] == 1:
|
41 |
cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
|
42 |
else:
|
43 |
cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)
|
44 |
-
|
45 |
# Convert the transparent layer back to an image
|
46 |
transparent_layer = Image.fromarray(transparent_layer, 'RGBA')
|
47 |
selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)
|
48 |
-
|
49 |
return tracking_points, trackings_input_label, selected_point_map
|
50 |
-
|
51 |
def show_mask(mask, ax, random_color=False, borders=True):
|
52 |
if random_color:
|
53 |
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
@@ -107,26 +98,27 @@ def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_l
|
|
107 |
mask_images.append(mask_filename)
|
108 |
plt.close() # Close the figure to free up memory
|
109 |
return combined_images, mask_images
|
110 |
-
|
111 |
@spaces.GPU()
|
112 |
def sam_process(original_image, points, labels):
|
|
|
|
|
113 |
# Convert image to numpy array for SAM2 processing
|
114 |
image = np.array(original_image)
|
115 |
predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-large")
|
116 |
predictor.set_image(image)
|
117 |
input_point = np.array(points)
|
118 |
input_label = np.array(labels)
|
119 |
-
|
|
|
|
|
|
|
120 |
sorted_indices = np.argsort(scores)[::-1]
|
121 |
masks = masks[sorted_indices]
|
122 |
-
|
123 |
# Generate mask image
|
124 |
mask = masks[0] * 255
|
125 |
mask_image = Image.fromarray(mask.astype(np.uint8))
|
126 |
return mask_image
|
127 |
-
# sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
|
128 |
-
# predictor = SAM2ImagePredictor(sam2_model)
|
129 |
-
|
130 |
|
131 |
def create_sam2_tab():
|
132 |
first_frame = gr.State() # Tracks original image
|
|
|
12 |
|
13 |
def preprocess_image(image):
|
14 |
return image, gr.State([]), gr.State([]), image
|
15 |
+
def get_point(point_type, tracking_points, trackings_input_label, first_frame_path, evt: gr.SelectData):
|
|
|
16 |
print(f"You selected {evt.value} at {evt.index} from {evt.target}")
|
|
|
17 |
tracking_points.value.append(evt.index)
|
18 |
+
print(f"TRACKING POINTS: {tracking_points.value}")
|
|
|
19 |
if point_type == "include":
|
20 |
trackings_input_label.value.append(1)
|
21 |
elif point_type == "exclude":
|
22 |
trackings_input_label.value.append(0)
|
23 |
+
print(f"TRACKING INPUT LABELS: {trackings_input_label.value}")
|
|
|
24 |
# Open the image and get its dimensions
|
25 |
transparent_background = Image.open(first_frame_path).convert('RGBA')
|
26 |
w, h = transparent_background.size
|
|
|
27 |
# Define the circle radius as a fraction of the smaller dimension
|
28 |
fraction = 0.02 # You can adjust this value as needed
|
29 |
radius = int(fraction * min(w, h))
|
|
|
30 |
# Create a transparent layer to draw on
|
31 |
transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
|
|
|
32 |
for index, track in enumerate(tracking_points.value):
|
33 |
if trackings_input_label.value[index] == 1:
|
34 |
cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
|
35 |
else:
|
36 |
cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)
|
|
|
37 |
# Convert the transparent layer back to an image
|
38 |
transparent_layer = Image.fromarray(transparent_layer, 'RGBA')
|
39 |
selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)
|
|
|
40 |
return tracking_points, trackings_input_label, selected_point_map
|
41 |
+
|
42 |
def show_mask(mask, ax, random_color=False, borders=True):
|
43 |
if random_color:
|
44 |
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
|
|
98 |
mask_images.append(mask_filename)
|
99 |
plt.close() # Close the figure to free up memory
|
100 |
return combined_images, mask_images
|
101 |
+
|
102 |
@spaces.GPU()
|
103 |
def sam_process(original_image, points, labels):
|
104 |
+
print(f"Points: {points}")
|
105 |
+
print(f"Labels: {labels}")
|
106 |
# Convert image to numpy array for SAM2 processing
|
107 |
image = np.array(original_image)
|
108 |
predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-large")
|
109 |
predictor.set_image(image)
|
110 |
input_point = np.array(points)
|
111 |
input_label = np.array(labels)
|
112 |
+
if not input_point.size or not input_label.size:
|
113 |
+
print("No points or labels provided, returning None")
|
114 |
+
return None
|
115 |
+
masks, scores, _= predictor.predict(input_point, input_label, multimask_output=False)
|
116 |
sorted_indices = np.argsort(scores)[::-1]
|
117 |
masks = masks[sorted_indices]
|
|
|
118 |
# Generate mask image
|
119 |
mask = masks[0] * 255
|
120 |
mask_image = Image.fromarray(mask.astype(np.uint8))
|
121 |
return mask_image
|
|
|
|
|
|
|
122 |
|
123 |
def create_sam2_tab():
|
124 |
first_frame = gr.State() # Tracks original image
|