import cv2
import torch
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.build_sam import build_sam2
from sam2.build_sam import build_sam2_video_predictor
import sam2
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
import argparse


def area(mask):
    """Return the fraction of non-zero elements in ``mask``.

    Args:
        mask: numpy array of any shape (boolean or numeric).

    Returns:
        ``np.count_nonzero(mask) / mask.size``, or ``0`` for an empty array.
    """
    if mask.size == 0:
        return 0
    return np.count_nonzero(mask) / mask.size


def show_mask(mask, ax, obj_id=None, random_color=False, borders=True, alpha=0.5):
    """Overlay one binary mask on a matplotlib axis.

    Color choice:
      * ``random_color=True``: a random RGB color.
      * otherwise, when ``obj_id`` is given: color number ``obj_id`` of the
        ``tab10`` colormap.
      * otherwise: a fixed blue (30, 144, 255).

    Args:
        mask: array whose last two dims are (H, W); any leading dims must be
            singleton (it is reshaped to (H, W)).
        ax: matplotlib axis to draw on.
        obj_id: optional object id used to pick a ``tab10`` color.
        random_color: use a random color instead of blue / ``tab10``.
        borders: also draw the outer contours of the mask in white.
        alpha: opacity of the mask fill.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([alpha])], axis=0)
    elif obj_id is not None:
        color = np.array([*plt.get_cmap("tab10")(obj_id)[:3], alpha])
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, alpha])

    h, w = mask.shape[-2:]
    # cv2.findContours needs a 2-D uint8 image, so drop singleton leading dims.
    mask = mask.astype(np.uint8).reshape(h, w)
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    if borders:
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Lightly simplify the contours (epsilon = 0.01 px) before drawing.
        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
    ax.imshow(mask_image)


def nms_bbox_removal(boxes_xyxy, iou_thresh=0.25):
    """Drop boxes that overlap another box too much.

    For every pair (i, j) with i < j, the coverage of each box by the other is
    measured with ``compute_iou`` (intersection / area of the first box).
    When either coverage exceeds ``iou_thresh``, one box of the pair is
    removed:
      * box ``j`` if box ``i`` is the more covered one (i.e. ``j`` is larger);
      * otherwise box ``i``.
    Every pair is compared, including pairs whose boxes were already marked
    for removal.

    Args:
        boxes_xyxy: sequence of (x1, y1, x2, y2) boxes.
        iou_thresh: coverage threshold above which a pair is considered
            overlapping.

    Returns:
        list of the boxes that were not removed, in their original order.
    """
    remove_indices = set()  # set gives O(1) membership in the final filter
    for i, box in enumerate(boxes_xyxy):
        for j in range(i + 1, len(boxes_xyxy)):
            box2 = boxes_xyxy[j]
            iou1 = compute_iou(box, box2)
            iou2 = compute_iou(box2, box)
            if iou1 > iou_thresh or iou2 > iou_thresh:
                remove_indices.add(j if iou1 > iou2 else i)
    return [box for i, box in enumerate(boxes_xyxy) if i not in remove_indices]


def load_SAM2(ckpt_path, model_cfg_path):
    """Build a SAM2 image model on CUDA when available, else on CPU.

    Args:
        ckpt_path: path to the SAM2 checkpoint file.
        model_cfg_path: model config passed to ``build_sam2``.

    Returns:
        the model returned by ``build_sam2`` (post-processing disabled).
    """
    if torch.cuda.is_available():
        print("Using CUDA")
        device = "cuda"
    else:
        print("CUDA device not found, using CPU instead")
        device = "cpu"
    model = build_sam2(model_cfg_path, ckpt_path, device=device, apply_postprocessing=False)
    return model


def compute_iou(box1, box2):
    """Return the fraction of ``box1`` covered by ``box2``.

    NOTE: despite the name this is NOT intersection-over-union; it is
    ``intersection / area(box1)`` and therefore asymmetric.
    ``nms_bbox_removal`` relies on that asymmetry.

    Args:
        box1, box2: (x1, y1, x2, y2) boxes.

    Returns:
        0 when the boxes do not overlap (this also covers a zero-area
        ``box1``, so the division below never divides by zero); otherwise the
        intersection area divided by the area of ``box1``.
    """
    x1, y1, x2, y2 = box1
    x3, y3, x4, y4 = box2
    x5, y5 = max(x1, x3), max(y1, y3)
    x6, y6 = min(x2, x4), min(y2, y4)
    if x5 >= x6 or y5 >= y6:
        return 0
    intersection = (x6 - x5) * (y6 - y5)
    box1_area = (x2 - x1) * (y2 - y1)
    return intersection / box1_area


def show_anns(anns, color=None, borders=True):
    """Draw a list of mask annotations on the current matplotlib axis.

    Annotations are painted from largest to smallest ``'area'`` so smaller
    masks end up on top.

    Args:
        anns: list of dicts with at least ``'segmentation'`` (mask array,
            presumably boolean since it is used for boolean indexing) and
            ``'area'`` (sort key).
        color: RGBA color for every mask; a random color with alpha 0.75 is
            used per mask when ``None``.
        borders: also draw the outer contours of each mask.
    """
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    seg_shape = sorted_anns[0]['segmentation'].squeeze().shape
    # Start from a fully transparent RGBA canvas the size of the first mask.
    img = np.ones((seg_shape[0], seg_shape[1], 4))
    img[:, :, 3] = 0
    for ann in sorted_anns:
        m = ann['segmentation'].squeeze()
        if color is None:
            color_mask = np.concatenate([np.random.random(3), [0.75]])
        else:
            color_mask = color
        img[m] = color_mask
        if borders:
            contours, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
            # Lightly simplify the contours (epsilon = 0.01 px) before drawing.
            contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
            cv2.drawContours(img, contours, -1, (0, 0, 1, 0.4), thickness=2)
    ax.imshow(img)


def build_sam2_predictor(checkpoint="checkpoints/sam2_hiera_large.pt", model_cfg="sam2_hiera_l"):
    """Build a SAM2 video predictor on CUDA when available, else on CPU."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    video_predictor = build_sam2_video_predictor(model_cfg, checkpoint, device=device, apply_postprocessing=False)
    return video_predictor


def load_masks(video_predictor, query_images, support_image, support_masks,
               offload_video_to_cpu=True, offload_state_to_cpu=True, verbose=False):
    '''Initialise a video-predictor state with the support masks on frame 0.

    The support image is treated as frame 0 of a "video" whose remaining
    frames are the query images. Each support mask is resized to 1024x1024
    and registered on frame 0 with ``obj_id`` equal to its index in
    ``support_masks``.

    NOTE: ``query_images`` is modified in place: ``support_image`` is inserted
    at index 0, so afterwards ``query_images[k]`` lines up with frame ``k`` of
    the propagation output. Calling this twice with the same list inserts the
    support image twice.

    video_predictor: sam2 predictor
    query_images: list of np.array of shape (H, W, 3)
    support_image: np.array of shape (H, W, 3)
    support_masks: list of np.array of shape (H, W)
    offload_video_to_cpu: for long video sequences, offload the video to the CPU to save GPU memory
    offload_state_to_cpu: save GPU memory by offloading the state to the CPU

    Returns: the inference state produced by ``video_predictor.init_state``.
    '''
    query_images.insert(0, support_image)
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        # NOTE(review): `image_inputs=` is not an argument of upstream SAM2's
        # init_state; presumably this targets a patched predictor — confirm.
        state = video_predictor.init_state(None, image_inputs=query_images, async_loading_frames=False,
                                           offload_video_to_cpu=offload_video_to_cpu,
                                           offload_state_to_cpu=offload_state_to_cpu,
                                           verbose=verbose)
        video_predictor.reset_state(state)
        for i, patch_mask in enumerate(support_masks):
            ann_frame_idx = 0  # support masks annotate the support frame
            ann_obj_id = i  # give a unique id to each object we interact with
            patch_mask = np.array(patch_mask, dtype=np.uint8)
            patch_mask = cv2.resize(patch_mask, (1024, 1024))
            video_predictor.add_new_mask(
                inference_state=state,
                frame_idx=ann_frame_idx,
                obj_id=ann_obj_id,
                mask=patch_mask,
            )
    return state


def propagate_masks(video_predictor, state, verbose=False):
    """Propagate the registered masks through every frame of ``state``.

    returns: list[dict], one per frame yielded by ``propagate_in_video``, with keys
        'obj_ids': object ids reported by the predictor for that frame,
        'segmentation': boolean np.array with a leading object axis
            (num_objects, H, W) — assuming the predictor yields logits of
            shape (num_objects, 1, H, W); a single object is kept as (1, H, W),
        'area': fraction of positive pixels over the whole 'segmentation'
            array (all objects together).
    """
    frame_info = []
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        for _, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(state, verbose=verbose):
            # Threshold logits at 0 to get binary masks.
            out_masks = (out_mask_logits > 0).cpu().numpy().squeeze()
            if out_masks.ndim == 2:
                # squeeze() removed the object axis for a single object; restore it.
                out_masks = np.expand_dims(out_masks, axis=0)
            frame_info.append({'obj_ids': out_obj_ids, 'segmentation': out_masks, 'area': area(out_masks)})
    return frame_info


def show_video_masks(image, frame_info):
    """Show ``image`` resized to 1024x1024 with one frame's object masks on top.

    Args:
        image: image array accepted by ``cv2.resize``.
        frame_info: one element of the list returned by ``propagate_masks``
            (keys 'obj_ids' and 'segmentation'); a dict using the key
            'masks' instead of 'segmentation' is also accepted.
    """
    img_resized = cv2.resize(image, (1024, 1024))
    plt.imshow(img_resized)
    # propagate_masks stores the masks under 'segmentation'; keep 'masks' as a fallback.
    masks = frame_info['segmentation'] if 'segmentation' in frame_info else frame_info['masks']
    for obj_id, mask in zip(frame_info['obj_ids'], masks):
        mask = cv2.resize(mask.astype(np.uint8), (1024, 1024))
        show_mask(mask, plt.gca(), obj_id=obj_id, borders=True, alpha=0.75)
    plt.axis('off')
    plt.show()


def get_parser(inputs):
    """Parse ``inputs`` (a list of CLI tokens) into an ``argparse.Namespace``.

    Recognised options: ``--config-file`` and ``--opts`` (all remaining
    tokens). NOTE(review): the parser description string looks copied from a
    Detectron2 demo.
    """
    parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs")
    parser.add_argument(
        "--config-file",
        default="configs/quick_schedules/mask_rcnn_R_50_FPN_inference_acc_test.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument(
        "--opts",
        help="Modify config options using the command-line 'KEY VALUE' pairs",
        default=[],
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args(inputs)
    return args


def auto_segment_SAM(boxes_xyxy, img, iou_thresh=0.9, stability_score_thresh=0.95,
                     min_mask_region_area=10000, verbose=False,
                     checkpoint="../../checkpoints/sam2_hiera_large.pt",
                     model_cfg="../../sam2_configs/sam2_hiera_l.yaml"):
    """Run SAM2 automatic mask generation inside each box and map results back.

    A fresh SAM2 model is loaded on every call.

    Args:
        boxes_xyxy: iterable of (x1, y1, x2, y2) boxes in pixel coordinates;
            coordinates are truncated with ``int`` and clipped to the image.
            Boxes that are empty after clipping are skipped.
        img: image array of shape (H, W, ...).
        iou_thresh: ``pred_iou_thresh`` of the mask generator.
        stability_score_thresh: ``stability_score_thresh`` of the generator.
        min_mask_region_area: ``min_mask_region_area`` of the generator.
        verbose: show each crop with its masks via matplotlib.
        checkpoint: SAM2 checkpoint path (defaults to the previously
            hard-coded path).
        model_cfg: SAM2 model config (defaults to the previously hard-coded
            path).

    Returns:
        list of dicts ``{'segmentation': bool (H, W) mask in full-image
        coordinates, 'area': fraction of the image covered}``.
    """
    sam2_model = load_SAM2(checkpoint, model_cfg)
    auto_mask_predictor = SAM2AutomaticMaskGenerator(
        sam2_model,
        points_per_batch=128,
        pred_iou_thresh=iou_thresh,
        stability_score_thresh=stability_score_thresh,
        min_mask_region_area=min_mask_region_area,
        multimask_output=True,
    )

    img_h, img_w = img.shape[0], img.shape[1]
    masks_list = []
    for box_xyxy in boxes_xyxy:
        # Clip so negative coordinates don't wrap around via negative slicing.
        x0 = max(int(box_xyxy[0]), 0)
        y0 = max(int(box_xyxy[1]), 0)
        x1 = min(int(box_xyxy[2]), img_w)
        y1 = min(int(box_xyxy[3]), img_h)
        if x1 <= x0 or y1 <= y0:
            continue  # nothing to segment in an empty crop

        wing = img[y0:y1, x0:x1]
        anns = auto_mask_predictor.generate(wing)

        if verbose:
            plt.imshow(wing)
            show_anns(anns)
            plt.axis('off')
            plt.show()

        # Paste each crop-local mask into a full-image canvas.
        for ann in anns:
            new_mask = np.zeros((img_h, img_w), dtype=bool)
            new_mask[y0:y1, x0:x1] = ann['segmentation']
            masks_list.append({
                'segmentation': new_mask,
                'area': area(new_mask),
            })
    return masks_list


def show_masks(masks_list, img, verbose=True, imshow=True, grey=False):
    """Draw all masks of ``masks_list`` over ``img``.

    Args:
        masks_list: annotations accepted by ``show_anns``.
        img: background image; converted with ``cv2.COLOR_BGR2GRAY`` when
            ``grey`` is set.
        verbose: call ``plt.show()`` at the end.
        imshow: draw ``img`` as the background.
        grey: show the background in greyscale.
    """
    if imshow:
        if grey:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            plt.imshow(img, cmap='gray')
        else:
            plt.imshow(img)
    plt.axis('off')
    show_anns(masks_list)
    if verbose:
        plt.show()


def show_individual_masks(masks_list, img):
    """Show each mask of ``masks_list`` over ``img`` in its own figure."""
    for mask in masks_list:
        plt.imshow(img)
        plt.axis('off')
        show_anns([mask])
        plt.show()