# SST / sam_utils.py
# Daniel-F's picture
# demo v1
# ccb7bbe
import cv2
import torch
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.build_sam import build_sam2
from sam2.build_sam import build_sam2_video_predictor
import sam2
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
import argparse
def area(mask):
    """Return the fraction of elements in ``mask`` that are non-zero.

    An array with no elements yields the integer ``0`` so that the division
    is never attempted with a zero denominator.
    """
    total = mask.size
    if total == 0:
        return 0
    return np.count_nonzero(mask) / total
def show_mask(mask, ax, obj_id=None, random_color=False, borders=True, alpha=0.5):
    """Draw one mask on ``ax`` as a semi-transparent RGBA overlay.

    Args:
        mask: array whose last two dimensions are (height, width). It is cast
            to uint8 and reshaped to (height, width, 1), so any leading
            dimensions must have size 1.
        ax: object exposing ``imshow`` (e.g. a matplotlib Axes).
        obj_id: optional index into matplotlib's "tab10" colormap; only used
            when ``random_color`` is false.
        random_color: pick a random RGB fill instead of a fixed one.
        borders: when true, outline the mask's external contours with OpenCV.
        alpha: opacity stored in the fourth channel of the fill color.
    """
    # Color priority: random > per-object colormap entry > fixed default blue.
    if random_color:
        rgba = np.concatenate([np.random.random(3), np.array([alpha])], axis=0)
    elif obj_id is not None:
        rgba = np.array([*plt.get_cmap("tab10")(obj_id)[:3], alpha])
    else:
        rgba = np.array([30/255, 144/255, 255/255, alpha])

    height, width = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    # Broadcast the (H, W, 1) mask against the (1, 1, 4) color.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)

    if borders:
        import cv2
        outlines, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Try to smooth contours
        outlines = [
            cv2.approxPolyDP(outline, epsilon=0.01, closed=True)
            for outline in outlines
        ]
        overlay = cv2.drawContours(overlay, outlines, -1, (1, 1, 1, 0.5), thickness=2)

    ax.imshow(overlay)
def area(mask):
    """Compute mask coverage: non-zero element count over total element count.

    NOTE: ``area`` is defined twice in this file with the same logic; this
    later definition is the one that stays bound after import.

    Returns the integer ``0`` for an array that has no elements.
    """
    return np.count_nonzero(mask) / mask.size if mask.size else 0
def nms_bbox_removal(boxes_xyxy, iou_thresh=0.25):
    """Drop one box from every pair of boxes that overlap too much.

    For each pair (i, j) with i < j the overlap is measured twice with
    ``compute_iou``: once relative to box i and once relative to box j. If
    either ratio exceeds ``iou_thresh``, one box of the pair is marked for
    removal: box j when box i's ratio is strictly larger, otherwise box i.
    Boxes already marked still take part in later comparisons.

    Args:
        boxes_xyxy: indexable sequence of boxes, each unpackable as
            (x1, y1, x2, y2).
        iou_thresh: overlap ratio above which a pair is considered redundant.

    Returns:
        A list with the surviving boxes in their original order.
    """
    # A set gives O(1) membership for the final filter and ignores the
    # duplicate indices that arise when a box loses against several others.
    suppressed = set()
    for i, box in enumerate(boxes_xyxy):
        for j in range(i + 1, len(boxes_xyxy)):
            other = boxes_xyxy[j]
            ratio_i = compute_iou(box, other)
            ratio_j = compute_iou(other, box)
            if ratio_i > iou_thresh or ratio_j > iou_thresh:
                suppressed.add(j if ratio_i > ratio_j else i)
    return [box for idx, box in enumerate(boxes_xyxy) if idx not in suppressed]
def load_SAM2(ckpt_path, model_cfg_path):
    """Build a SAM2 model on CUDA when available, otherwise on the CPU.

    Prints which device was selected, then forwards the config path,
    checkpoint path and device to ``build_sam2`` with post-processing
    disabled.

    Args:
        ckpt_path: checkpoint path passed through to ``build_sam2``.
        model_cfg_path: model config passed through to ``build_sam2``.

    Returns:
        Whatever ``build_sam2`` returns.
    """
    has_cuda = torch.cuda.is_available()
    print("Using CUDA" if has_cuda else "CUDA device not found, using CPU instead")
    device = "cuda" if has_cuda else "cpu"
    return build_sam2(model_cfg_path, ckpt_path, device=device, apply_postprocessing=False)
def compute_iou(box1, box2):
    """Return the overlap of two boxes as a fraction of ``box1``'s area.

    Despite the name this is not a symmetric intersection-over-union: the
    denominator is the area of ``box1`` alone, so swapping the arguments
    generally changes the result.

    Args:
        box1: (x_min, y_min, x_max, y_max) of the reference box.
        box2: (x_min, y_min, x_max, y_max) of the other box.

    Returns:
        intersection area / ``box1`` area, or the integer ``0`` when the
        boxes do not overlap (boxes that merely touch count as no overlap).
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2

    # Corners of the intersection rectangle.
    left = max(ax1, bx1)
    top = max(ay1, by1)
    right = min(ax2, bx2)
    bottom = min(ay2, by2)
    if left >= right or top >= bottom:
        return 0

    overlap = (right - left) * (bottom - top)
    box1_area = (ax2 - ax1) * (ay2 - ay1)
    return overlap / box1_area
def show_anns(anns, color=None, borders=True):
    """Paint a list of mask annotations onto the current matplotlib axes.

    Annotations are drawn in descending ``'area'`` order onto one RGBA canvas,
    so smaller masks overwrite larger ones where they overlap.

    Args:
        anns: sequence of dicts with ``'segmentation'`` and ``'area'`` keys.
            Each squeezed segmentation is used to index the canvas, so it is
            presumably a boolean (H, W) array - confirm against callers.
        color: RGBA value used for every mask; when ``None`` each mask gets a
            random RGB color with alpha 0.75.
        borders: when true, outline each mask's external contours.

    Returns:
        None. Nothing is drawn when ``anns`` is empty.
    """
    if len(anns) == 0:
        return

    largest_first = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    axes = plt.gca()
    axes.set_autoscale_on(False)

    # Fully transparent white canvas sized like the largest annotation's mask.
    first_seg = largest_first[0]['segmentation'].squeeze()
    canvas = np.ones((first_seg.shape[0], first_seg.shape[1], 4))
    canvas[:, :, 3] = 0

    for ann in largest_first:
        seg = ann['segmentation'].squeeze()
        fill = np.concatenate([np.random.random(3), [0.75]]) if color is None else color
        canvas[seg] = fill
        if borders:
            import cv2
            outlines, _ = cv2.findContours(seg.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
            # Try to smooth contours
            outlines = [
                cv2.approxPolyDP(outline, epsilon=0.01, closed=True)
                for outline in outlines
            ]
            cv2.drawContours(canvas, outlines, -1, (0, 0, 1, 0.4), thickness=2)

    axes.imshow(canvas)
def build_sam2_predictor(checkpoint="checkpoints/sam2_hiera_large.pt", model_cfg="sam2_hiera_l"):
    """Create a SAM2 video predictor on CUDA if available, else on the CPU.

    Args:
        checkpoint: checkpoint path forwarded to ``build_sam2_video_predictor``.
        model_cfg: model config name forwarded to ``build_sam2_video_predictor``.

    Returns:
        The predictor returned by ``build_sam2_video_predictor``, built with
        post-processing disabled.
    """
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    return build_sam2_video_predictor(model_cfg, checkpoint, device=device, apply_postprocessing=False)
def load_masks(video_predictor, query_images, support_image, support_masks, offload_video_to_cpu=True, offload_state_to_cpu=True, verbose=False):
    """Initialise a predictor state and register the support masks on frame 0.

    The support image is inserted at the front of ``query_images`` - the
    caller's list is modified in place - so it becomes frame 0 of the
    sequence handed to the predictor. Every support mask is then added on
    that frame as its own object, with the mask's position in
    ``support_masks`` as the object id.

    Args:
        video_predictor: sam2 predictor
        query_images: list of np.array of shape (H, W, 3)
        support_image: np.array of shape (H, W, 3)
        support_masks: list of np.array of shape (H, W)
        offload_video_to_cpu: for long video sequences, offload the video to
            the CPU to save GPU memory
        offload_state_to_cpu: save GPU memory by offloading the state to the CPU
        verbose: forwarded to ``video_predictor.init_state``.

    Returns:
        The inference state created by ``video_predictor.init_state``.
    """
    support_frame_idx = 0  # the support image always sits at the front
    query_images.insert(0, support_image)

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        state = video_predictor.init_state(
            None,
            image_inputs=query_images,
            async_loading_frames=False,
            offload_video_to_cpu=offload_video_to_cpu,
            offload_state_to_cpu=offload_state_to_cpu,
            verbose=verbose,
        )
        video_predictor.reset_state(state)

        # One object per support mask; the enumeration index is its unique id.
        for obj_id, support_mask in enumerate(support_masks):
            resized_mask = cv2.resize(np.array(support_mask, dtype=np.uint8), (1024, 1024))
            _, _, _ = video_predictor.add_new_mask(
                inference_state=state,
                frame_idx=support_frame_idx,
                obj_id=obj_id,
                mask=resized_mask,
            )

    return state
def propagate_masks(video_predictor, state, verbose=False):
    """Run mask propagation over the video and collect one record per frame.

    Args:
        video_predictor: predictor exposing ``propagate_in_video``.
        state: inference state to propagate.
        verbose: forwarded to ``propagate_in_video``.

    Returns:
        list[dict], one dict per item yielded by the predictor, with keys:
            'obj_ids': the object ids yielded for that frame.
            'segmentation': boolean array obtained by thresholding the mask
                logits at 0 and squeezing size-1 axes; a 2-D result gets a
                leading axis, so the array has at least one axis before
                (H, W) - presumably the object axis, confirm against the
                predictor's output shape.
            'area': fraction of True values over the whole 'segmentation'
                array, as computed by ``area``.
    """
    per_frame = []
    # run propagation throughout the video and collect the results
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        outputs = video_predictor.propagate_in_video(state, verbose=verbose)
        for _, obj_ids, mask_logits in outputs:
            binary = (mask_logits > 0).cpu().numpy().squeeze()
            if binary.ndim == 2:
                # A single mask loses its leading axis in squeeze(); restore it.
                binary = binary[np.newaxis, ...]
            per_frame.append({
                'obj_ids': obj_ids,
                'segmentation': binary,
                'area': area(binary),
            })
    return per_frame
def show_video_masks(image, frame_info):
    """Display ``image`` with the masks of one frame overlaid.

    The image and every mask are resized to 1024x1024 before plotting, and
    each mask is drawn with ``show_mask`` using its object id for the color.

    Args:
        image: image array accepted by ``cv2.resize``.
        frame_info: dict with an ``'obj_ids'`` entry and the per-object masks
            under ``'segmentation'`` (the key written by ``propagate_masks``).
            The older ``'masks'`` key is still accepted as a fallback.

    Raises:
        KeyError: if ``frame_info`` has neither ``'segmentation'`` nor
            ``'masks'``, or lacks ``'obj_ids'``.
    """
    plt.imshow(cv2.resize(image, (1024, 1024)))

    # propagate_masks stores the masks under 'segmentation'; reading only
    # 'masks' raised KeyError on its output. Keep 'masks' as a fallback so
    # callers that build the dict themselves continue to work.
    if 'segmentation' in frame_info:
        masks = frame_info['segmentation']
    else:
        masks = frame_info['masks']

    for obj_id, mask in zip(frame_info['obj_ids'], masks):
        resized_mask = cv2.resize(mask.astype(np.uint8), (1024, 1024))
        show_mask(resized_mask, plt.gca(), obj_id=obj_id, borders=True, alpha=0.75)

    plt.axis('off')
    plt.show()
def get_parser(inputs):
    """Parse ``inputs`` with the demo's command-line options.

    Despite its name, this returns the parsed arguments, not the parser.

    Args:
        inputs: list of argument strings handed to ``parse_args``.

    Returns:
        argparse.Namespace with ``config_file`` (str) and ``opts`` (list of
        the remaining 'KEY VALUE' tokens following ``--opts``).
    """
    cli = argparse.ArgumentParser(description="Detectron2 demo for builtin configs")
    cli.add_argument(
        "--config-file",
        metavar="FILE",
        default="configs/quick_schedules/mask_rcnn_R_50_FPN_inference_acc_test.yaml",
        help="path to config file",
    )
    # REMAINDER: everything after --opts is collected verbatim.
    cli.add_argument(
        "--opts",
        nargs=argparse.REMAINDER,
        default=[],
        help="Modify config options using the command-line 'KEY VALUE' pairs",
    )
    return cli.parse_args(inputs)
def auto_segment_SAM(boxes_xyxy, img, iou_thresh=0.9, stability_score_thresh=0.95, min_mask_region_area=10000, verbose=False,
                     checkpoint="../../checkpoints/sam2_hiera_large.pt",
                     model_cfg="../../sam2_configs/sam2_hiera_l.yaml"):
    """Run SAM2 automatic mask generation inside each box of an image.

    A SAM2 model is loaded on every call. For each box the image is cropped,
    masks are generated on the crop, and each mask is pasted back into a
    full-image-sized boolean array at the box position.

    Args:
        boxes_xyxy: iterable of boxes indexable as (x1, y1, x2, y2); the
            coordinates are truncated to int before cropping.
        img: image array indexed as ``img[y, x]``.
        iou_thresh: passed to the generator as ``pred_iou_thresh``.
        stability_score_thresh: passed through to the generator.
        min_mask_region_area: passed through to the generator.
        verbose: when true, plot each crop with its generated masks.
        checkpoint: SAM2 checkpoint path handed to ``load_SAM2``.
        model_cfg: SAM2 config path handed to ``load_SAM2``.

    Returns:
        list of dicts, one per generated mask over all boxes, with keys
        ``'segmentation'`` (bool array of shape ``img.shape[:2]``) and
        ``'area'`` (fraction of the full image covered, via ``area``).
    """
    model = load_SAM2(checkpoint, model_cfg)
    mask_generator = SAM2AutomaticMaskGenerator(
        model,
        points_per_batch=128,
        pred_iou_thresh=iou_thresh,
        stability_score_thresh=stability_score_thresh,
        min_mask_region_area=min_mask_region_area,
        multimask_output=True,
    )

    masks_list = []
    for box_xyxy in boxes_xyxy:
        # Convert once so the crop and the paste-back use the same slice.
        x_min, y_min = int(box_xyxy[0]), int(box_xyxy[1])
        x_max, y_max = int(box_xyxy[2]), int(box_xyxy[3])

        crop = img[y_min:y_max, x_min:x_max]
        crop_anns = mask_generator.generate(crop)

        if verbose:
            plt.imshow(crop)
            show_anns(crop_anns)
            plt.axis('off')
            plt.show()

        # translate each crop-sized mask to the original image frame
        for ann in crop_anns:
            full_mask = np.zeros((img.shape[0], img.shape[1]), dtype=bool)
            full_mask[y_min:y_max, x_min:x_max] = ann['segmentation']
            masks_list.append({
                'segmentation': full_mask,
                'area': area(full_mask),
            })
    return masks_list
def show_masks(masks_list, img, verbose=True, imshow=True, grey=False):
    """Plot mask annotations, optionally on top of ``img``.

    Args:
        masks_list: annotations forwarded to ``show_anns``.
        img: background image; only drawn when ``imshow`` is true.
        verbose: when true, call ``plt.show()`` at the end.
        imshow: draw ``img`` before the masks.
        grey: convert ``img`` from BGR to greyscale before drawing it
            (only relevant when ``imshow`` is true).
    """
    if imshow and grey:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        plt.imshow(img, cmap='gray')
    elif imshow:
        plt.imshow(img)

    plt.axis('off')
    show_anns(masks_list)

    if verbose:
        plt.show()
def show_individual_masks(masks_list, img):
    """Show each annotation in its own figure on top of ``img``.

    For every entry of ``masks_list`` the image is drawn, the axes are hidden,
    the single annotation is overlaid with ``show_anns`` and the figure is
    shown before moving on to the next one.
    """
    for single_ann in masks_list:
        plt.imshow(img)
        plt.axis('off')
        show_anns([single_ann])
        plt.show()