LPX55 commited on
Commit
d8b9a0f
·
1 Parent(s): 0d5a836
Files changed (1) hide show
  1. app.py +11 -4
app.py CHANGED
@@ -69,17 +69,25 @@ def predict_masks(image, points):
69
  return image # Return the original image if no points are selected
70
  PREDICTOR = SAM2ImagePredictor.from_pretrained(SAM_MODEL, device=DEVICE)
71
 
72
- image_np = np.array(image)
73
- points_list = [[point["x"], point["y"]] for point in points]
 
 
 
 
 
 
 
74
  input_labels = [1] * len(points_list)
75
 
76
  with torch.inference_mode():
77
- PREDICTOR.set_image(image_np)
78
  masks, _, _ = PREDICTOR.predict(
79
  point_coords=points_list, point_labels=input_labels, multimask_output=False
80
  )
81
 
82
  # Prepare the overlay image
 
83
  red_mask = np.zeros_like(image_np)
84
  if masks and len(masks) > 0:
85
  red_mask[:, :, 0] = masks[0].astype(np.uint8) * 255 # Apply the red channel
@@ -90,7 +98,6 @@ def predict_masks(image, points):
90
  else:
91
  return image_np
92
 
93
-
94
  def update_mask(prompts):
95
  """Update the mask based on the prompts."""
96
  image = prompts["image"]
 
69
  return image # Return the original image if no points are selected
70
  PREDICTOR = SAM2ImagePredictor.from_pretrained(SAM_MODEL, device=DEVICE)
71
 
72
+ # Debugging: Print the structure of points
73
+ print(f"Points structure: {points}")
74
+
75
+ # Ensure points is a list of lists with at least two elements
76
+ if isinstance(points, list) and all(isinstance(point, list) and len(point) >= 2 for point in points):
77
+ points_list = [[point[0], point[1]] for point in points]
78
+ else:
79
+ return image # Return the original image if points structure is unexpected
80
+
81
  input_labels = [1] * len(points_list)
82
 
83
  with torch.inference_mode():
84
+ PREDICTOR.set_image(np.array(image))
85
  masks, _, _ = PREDICTOR.predict(
86
  point_coords=points_list, point_labels=input_labels, multimask_output=False
87
  )
88
 
89
  # Prepare the overlay image
90
+ image_np = np.array(image)
91
  red_mask = np.zeros_like(image_np)
92
  if masks and len(masks) > 0:
93
  red_mask[:, :, 0] = masks[0].astype(np.uint8) * 255 # Apply the red channel
 
98
  else:
99
  return image_np
100
 
 
101
  def update_mask(prompts):
102
  """Update the mask based on the prompts."""
103
  image = prompts["image"]