LPX55 commited on
Commit
6c176fb
·
verified ·
1 Parent(s): 0149a5e

Update sam2_mask.py

Browse files
Files changed (1) hide show
  1. sam2_mask.py +11 -19
sam2_mask.py CHANGED
@@ -12,42 +12,33 @@ from sam2.sam2_image_predictor import SAM2ImagePredictor
12
 
13
  def preprocess_image(image):
14
  return image, gr.State([]), gr.State([]), image
15
-
16
- def get_point(point_type, tracking_points, trackings_input_label, first_frame_path, evt: gr.SelectData):
17
  print(f"You selected {evt.value} at {evt.index} from {evt.target}")
18
-
19
  tracking_points.value.append(evt.index)
20
- print(f"TRACKING POINT: {tracking_points.value}")
21
-
22
  if point_type == "include":
23
  trackings_input_label.value.append(1)
24
  elif point_type == "exclude":
25
  trackings_input_label.value.append(0)
26
- print(f"TRACKING INPUT LABEL: {trackings_input_label.value}")
27
-
28
  # Open the image and get its dimensions
29
  transparent_background = Image.open(first_frame_path).convert('RGBA')
30
  w, h = transparent_background.size
31
-
32
  # Define the circle radius as a fraction of the smaller dimension
33
  fraction = 0.02 # You can adjust this value as needed
34
  radius = int(fraction * min(w, h))
35
-
36
  # Create a transparent layer to draw on
37
  transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
38
-
39
  for index, track in enumerate(tracking_points.value):
40
  if trackings_input_label.value[index] == 1:
41
  cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
42
  else:
43
  cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)
44
-
45
  # Convert the transparent layer back to an image
46
  transparent_layer = Image.fromarray(transparent_layer, 'RGBA')
47
  selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)
48
-
49
  return tracking_points, trackings_input_label, selected_point_map
50
-
51
  def show_mask(mask, ax, random_color=False, borders=True):
52
  if random_color:
53
  color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
@@ -107,26 +98,27 @@ def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_l
107
  mask_images.append(mask_filename)
108
  plt.close() # Close the figure to free up memory
109
  return combined_images, mask_images
110
-
111
  @spaces.GPU()
112
  def sam_process(original_image, points, labels):
 
 
113
  # Convert image to numpy array for SAM2 processing
114
  image = np.array(original_image)
115
  predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-large")
116
  predictor.set_image(image)
117
  input_point = np.array(points)
118
  input_label = np.array(labels)
119
- masks, scores, _ = predictor.predict(input_point, input_label, multimask_output=False)
 
 
 
120
  sorted_indices = np.argsort(scores)[::-1]
121
  masks = masks[sorted_indices]
122
-
123
  # Generate mask image
124
  mask = masks[0] * 255
125
  mask_image = Image.fromarray(mask.astype(np.uint8))
126
  return mask_image
127
- # sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
128
- # predictor = SAM2ImagePredictor(sam2_model)
129
-
130
 
131
  def create_sam2_tab():
132
  first_frame = gr.State() # Tracks original image
 
12
 
13
  def preprocess_image(image):
14
  return image, gr.State([]), gr.State([]), image
15
+ def get_point(point_type, tracking_points, trackings_input_label, first_frame_path, evt: gr.SelectData):
 
16
  print(f"You selected {evt.value} at {evt.index} from {evt.target}")
 
17
  tracking_points.value.append(evt.index)
18
+ print(f"TRACKING POINTS: {tracking_points.value}")
 
19
  if point_type == "include":
20
  trackings_input_label.value.append(1)
21
  elif point_type == "exclude":
22
  trackings_input_label.value.append(0)
23
+ print(f"TRACKING INPUT LABELS: {trackings_input_label.value}")
 
24
  # Open the image and get its dimensions
25
  transparent_background = Image.open(first_frame_path).convert('RGBA')
26
  w, h = transparent_background.size
 
27
  # Define the circle radius as a fraction of the smaller dimension
28
  fraction = 0.02 # You can adjust this value as needed
29
  radius = int(fraction * min(w, h))
 
30
  # Create a transparent layer to draw on
31
  transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
 
32
  for index, track in enumerate(tracking_points.value):
33
  if trackings_input_label.value[index] == 1:
34
  cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
35
  else:
36
  cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)
 
37
  # Convert the transparent layer back to an image
38
  transparent_layer = Image.fromarray(transparent_layer, 'RGBA')
39
  selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)
 
40
  return tracking_points, trackings_input_label, selected_point_map
41
+
42
  def show_mask(mask, ax, random_color=False, borders=True):
43
  if random_color:
44
  color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
 
98
  mask_images.append(mask_filename)
99
  plt.close() # Close the figure to free up memory
100
  return combined_images, mask_images
101
+
102
  @spaces.GPU()
103
  def sam_process(original_image, points, labels):
104
+ print(f"Points: {points}")
105
+ print(f"Labels: {labels}")
106
  # Convert image to numpy array for SAM2 processing
107
  image = np.array(original_image)
108
  predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-large")
109
  predictor.set_image(image)
110
  input_point = np.array(points)
111
  input_label = np.array(labels)
112
+ if not input_point.size or not input_label.size:
113
+ print("No points or labels provided, returning None")
114
+ return None
115
+ masks, scores, _= predictor.predict(input_point, input_label, multimask_output=False)
116
  sorted_indices = np.argsort(scores)[::-1]
117
  masks = masks[sorted_indices]
 
118
  # Generate mask image
119
  mask = masks[0] * 255
120
  mask_image = Image.fromarray(mask.astype(np.uint8))
121
  return mask_image
 
 
 
122
 
123
  def create_sam2_tab():
124
  first_frame = gr.State() # Tracks original image