Spaces:
Running
on
Zero
Running
on
Zero
Update sam2_mask.py
Browse files- sam2_mask.py +32 -19
sam2_mask.py
CHANGED
@@ -12,28 +12,41 @@ from sam2.sam2_image_predictor import SAM2ImagePredictor
|
|
12 |
|
13 |
def preprocess_image(image):
|
14 |
return image, gr.State([]), gr.State([]), image
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
-
#
|
22 |
-
|
23 |
-
radius = int(min(w, h)
|
24 |
-
|
25 |
-
draw
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
|
|
|
|
|
|
30 |
|
31 |
-
#
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
torch.backends.cudnn.allow_tf32 = True
|
37 |
|
38 |
def show_mask(mask, ax, random_color=False, borders=True):
|
39 |
if random_color:
|
|
|
12 |
|
13 |
def preprocess_image(image):
|
14 |
return image, gr.State([]), gr.State([]), image
|
15 |
+
|
16 |
+
def get_point(point_type, tracking_points, trackings_input_label, first_frame_path, evt: gr.SelectData):
|
17 |
+
print(f"You selected {evt.value} at {evt.index} from {evt.target}")
|
18 |
+
|
19 |
+
tracking_points.value.append(evt.index)
|
20 |
+
print(f"TRACKING POINT: {tracking_points.value}")
|
21 |
|
22 |
+
if point_type == "include":
|
23 |
+
trackings_input_label.value.append(1)
|
24 |
+
elif point_type == "exclude":
|
25 |
+
trackings_input_label.value.append(0)
|
26 |
+
print(f"TRACKING INPUT LABEL: {trackings_input_label.value}")
|
27 |
+
|
28 |
+
# Open the image and get its dimensions
|
29 |
+
transparent_background = Image.open(first_frame_path).convert('RGBA')
|
30 |
+
w, h = transparent_background.size
|
31 |
|
32 |
+
# Define the circle radius as a fraction of the smaller dimension
|
33 |
+
fraction = 0.02 # You can adjust this value as needed
|
34 |
+
radius = int(fraction * min(w, h))
|
35 |
+
|
36 |
+
# Create a transparent layer to draw on
|
37 |
+
transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
|
38 |
+
|
39 |
+
for index, track in enumerate(tracking_points.value):
|
40 |
+
if trackings_input_label.value[index] == 1:
|
41 |
+
cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
|
42 |
+
else:
|
43 |
+
cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)
|
44 |
|
45 |
+
# Convert the transparent layer back to an image
|
46 |
+
transparent_layer = Image.fromarray(transparent_layer, 'RGBA')
|
47 |
+
selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)
|
48 |
+
|
49 |
+
return tracking_points, trackings_input_label, selected_point_map
|
|
|
50 |
|
51 |
def show_mask(mask, ax, random_color=False, borders=True):
|
52 |
if random_color:
|