LPX55 committed
Commit 0149a5e · verified · Parent(s): 93bfbab

Update sam2_mask.py

Files changed (1)
  1. sam2_mask.py +32 -19
sam2_mask.py CHANGED
@@ -12,28 +12,41 @@ from sam2.sam2_image_predictor import SAM2ImagePredictor
 
 def preprocess_image(image):
     return image, gr.State([]), gr.State([]), image
 
-def get_point(point_type, tracking_points, trackings_input_label, original_image, evt: gr.SelectData):
-    x, y = evt.index
-    tracking_points.append((x, y))
-    trackings_input_label.append(1 if point_type == "include" else 0)
-
-    # Redraw all points on original image
-    w, h = original_image.size
-    radius = int(min(w, h) * 0.02)
-    img = original_image.convert("RGBA")
-    draw = ImageDraw.Draw(img)
-    for i, (cx, cy) in enumerate(tracking_points):
-        color = (0, 255, 0, 255) if trackings_input_label[i] == 1 else (255, 0, 0, 255)
-        draw.ellipse([cx-radius, cy-radius, cx+radius, cy+radius], fill=color)
-    return tracking_points, trackings_input_label, img
-
-# use bfloat16 for the entire notebook
-torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
-if torch.cuda.get_device_properties(0).major >= 8:
-    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
-    torch.backends.cuda.matmul.allow_tf32 = True
-    torch.backends.cudnn.allow_tf32 = True
+def get_point(point_type, tracking_points, trackings_input_label, first_frame_path, evt: gr.SelectData):
+    print(f"You selected {evt.value} at {evt.index} from {evt.target}")
+
+    tracking_points.value.append(evt.index)
+    print(f"TRACKING POINT: {tracking_points.value}")
+
+    if point_type == "include":
+        trackings_input_label.value.append(1)
+    elif point_type == "exclude":
+        trackings_input_label.value.append(0)
+    print(f"TRACKING INPUT LABEL: {trackings_input_label.value}")
+
+    # Open the image and get its dimensions
+    transparent_background = Image.open(first_frame_path).convert('RGBA')
+    w, h = transparent_background.size
+
+    # Define the circle radius as a fraction of the smaller dimension
+    fraction = 0.02  # You can adjust this value as needed
+    radius = int(fraction * min(w, h))
+
+    # Create a transparent layer to draw on
+    transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
+
+    for index, track in enumerate(tracking_points.value):
+        if trackings_input_label.value[index] == 1:
+            cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
+        else:
+            cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)
+
+    # Convert the transparent layer back to an image
+    transparent_layer = Image.fromarray(transparent_layer, 'RGBA')
+    selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)
+
+    return tracking_points, trackings_input_label, selected_point_map
 
 def show_mask(mask, ax, random_color=False, borders=True):
     if random_color:
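
Note: the reworked get_point is a Gradio select-event handler (it takes evt: gr.SelectData), and preprocess_image appears to reset the two point lists by returning fresh gr.State([]) objects, which is presumably why get_point reads tracking_points.value directly. Below is a minimal wiring sketch, not part of this commit; the component names (point_type, input_image, points_map, first_frame_path) and the upload/select wiring are illustrative assumptions about how sam2_mask.py might hook these handlers into a Blocks UI.

# Minimal wiring sketch (illustrative, not part of this commit). Assumes it
# lives alongside the handlers above in sam2_mask.py; component names are
# hypothetical.
import gradio as gr

with gr.Blocks() as demo:
    # Point lists are held in gr.State; preprocess_image returns fresh
    # gr.State([]) objects to clear them when a new image is loaded.
    tracking_points = gr.State([])
    trackings_input_label = gr.State([])
    first_frame_path = gr.State()

    point_type = gr.Radio(["include", "exclude"], value="include", label="Point type")
    input_image = gr.Image(type="filepath", label="Input image")
    points_map = gr.Image(label="Selected points")

    # Assumed reset wiring: order matches preprocess_image's four return values.
    input_image.upload(
        fn=preprocess_image,
        inputs=input_image,
        outputs=[first_frame_path, tracking_points, trackings_input_label, points_map],
    )

    # Clicking the image fires a select event; Gradio passes the click
    # coordinates via evt: gr.SelectData, which get_point appends and draws.
    input_image.select(
        fn=get_point,
        inputs=[point_type, tracking_points, trackings_input_label, first_frame_path],
        outputs=[tracking_points, trackings_input_label, points_map],
        queue=False,
    )

demo.launch()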