Spaces:
Running
on
Zero
Running
on
Zero
app.py
CHANGED
@@ -69,17 +69,25 @@ def predict_masks(image, points):
|
|
69 |
return image # Return the original image if no points are selected
|
70 |
PREDICTOR = SAM2ImagePredictor.from_pretrained(SAM_MODEL, device=DEVICE)
|
71 |
|
72 |
-
|
73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
input_labels = [1] * len(points_list)
|
75 |
|
76 |
with torch.inference_mode():
|
77 |
-
PREDICTOR.set_image(
|
78 |
masks, _, _ = PREDICTOR.predict(
|
79 |
point_coords=points_list, point_labels=input_labels, multimask_output=False
|
80 |
)
|
81 |
|
82 |
# Prepare the overlay image
|
|
|
83 |
red_mask = np.zeros_like(image_np)
|
84 |
if masks and len(masks) > 0:
|
85 |
red_mask[:, :, 0] = masks[0].astype(np.uint8) * 255 # Apply the red channel
|
@@ -90,7 +98,6 @@ def predict_masks(image, points):
|
|
90 |
else:
|
91 |
return image_np
|
92 |
|
93 |
-
|
94 |
def update_mask(prompts):
|
95 |
"""Update the mask based on the prompts."""
|
96 |
image = prompts["image"]
|
|
|
69 |
return image # Return the original image if no points are selected
|
70 |
PREDICTOR = SAM2ImagePredictor.from_pretrained(SAM_MODEL, device=DEVICE)
|
71 |
|
72 |
+
# Debugging: Print the structure of points
|
73 |
+
print(f"Points structure: {points}")
|
74 |
+
|
75 |
+
# Ensure points is a list of lists with at least two elements
|
76 |
+
if isinstance(points, list) and all(isinstance(point, list) and len(point) >= 2 for point in points):
|
77 |
+
points_list = [[point[0], point[1]] for point in points]
|
78 |
+
else:
|
79 |
+
return image # Return the original image if points structure is unexpected
|
80 |
+
|
81 |
input_labels = [1] * len(points_list)
|
82 |
|
83 |
with torch.inference_mode():
|
84 |
+
PREDICTOR.set_image(np.array(image))
|
85 |
masks, _, _ = PREDICTOR.predict(
|
86 |
point_coords=points_list, point_labels=input_labels, multimask_output=False
|
87 |
)
|
88 |
|
89 |
# Prepare the overlay image
|
90 |
+
image_np = np.array(image)
|
91 |
red_mask = np.zeros_like(image_np)
|
92 |
if masks and len(masks) > 0:
|
93 |
red_mask[:, :, 0] = masks[0].astype(np.uint8) * 255 # Apply the red channel
|
|
|
98 |
else:
|
99 |
return image_np
|
100 |
|
|
|
101 |
def update_mask(prompts):
|
102 |
"""Update the mask based on the prompts."""
|
103 |
image = prompts["image"]
|