Spaces:
Runtime error
Runtime error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import numpy as np | |
| import torch | |
| from torchvision.ops.boxes import batched_nms, box_area # type: ignore | |
| from typing import Any, Dict, List, Optional, Tuple | |
| from .modeling import Sam | |
| from .predictor import SamPredictor | |
| from .utils.amg import ( | |
| MaskData, | |
| area_from_rle, | |
| batch_iterator, | |
| batched_mask_to_box, | |
| box_xyxy_to_xywh, | |
| build_all_layer_point_grids, | |
| calculate_stability_score, | |
| coco_encode_rle, | |
| generate_crop_boxes, | |
| is_box_near_crop_edge, | |
| mask_to_rle_pytorch, | |
| remove_small_regions, | |
| rle_to_mask, | |
| uncrop_boxes_xyxy, | |
| uncrop_masks, | |
| uncrop_points, | |
| ) | |
class SamAutomaticMaskGenerator:
    """Generates masks for a whole image by prompting a SAM model with point grids.

    The image (and optionally a pyramid of crops of it) is prompted with a grid
    of points; the resulting masks are filtered by predicted quality and
    stability, de-duplicated with box NMS, and returned as annotation records.
    """

    def __init__(
        self,
        model: Sam,
        points_per_side: Optional[int] = 32,
        points_per_batch: int = 64,
        pred_iou_thresh: float = 0.88,
        stability_score_thresh: float = 0.95,
        stability_score_offset: float = 1.0,
        box_nms_thresh: float = 0.7,
        crop_n_layers: int = 0,
        crop_nms_thresh: float = 0.7,
        crop_overlap_ratio: float = 512 / 1500,
        crop_n_points_downscale_factor: int = 1,
        point_grids: Optional[List[np.ndarray]] = None,
        min_mask_region_area: int = 0,
        output_mode: str = "binary_mask",
    ) -> None:
        """
        Using a SAM model, generates masks for the entire image.
        Generates a grid of point prompts over the image, then filters
        low quality and duplicate masks. The default settings are chosen
        for SAM with a ViT-H backbone.

        Arguments:
          model (Sam): The SAM model to use for mask prediction.
          points_per_side (int or None): The number of points to be sampled
            along one side of the image. The total number of points is
            points_per_side**2. If None, 'point_grids' must provide explicit
            point sampling.
          points_per_batch (int): Sets the number of points run simultaneously
            by the model. Higher numbers may be faster but use more GPU memory.
          pred_iou_thresh (float): A filtering threshold in [0,1], using the
            model's predicted mask quality.
          stability_score_thresh (float): A filtering threshold in [0,1], using
            the stability of the mask under changes to the cutoff used to binarize
            the model's mask predictions.
          stability_score_offset (float): The amount to shift the cutoff when
            calculating the stability score.
          box_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks.
          crop_n_layers (int): If >0, mask prediction will be run again on
            crops of the image. Sets the number of layers to run, where each
            layer has 2**i_layer number of image crops.
          crop_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks between different crops.
          crop_overlap_ratio (float): Sets the degree to which crops overlap.
            In the first crop layer, crops will overlap by this fraction of
            the image length. Later layers with more crops scale down this overlap.
          crop_n_points_downscale_factor (int): The number of points-per-side
            sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
          point_grids (list(np.ndarray) or None): A list over explicit grids
            of points used for sampling, normalized to [0,1]. The nth grid in the
            list is used in the nth crop layer. Exclusive with points_per_side.
          min_mask_region_area (int): If >0, postprocessing will be applied
            to remove disconnected regions and holes in masks with area smaller
            than min_mask_region_area. Requires opencv.
          output_mode (str): The form masks are returned in. Can be 'binary_mask',
            'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
            For large resolutions, 'binary_mask' may consume large amounts of
            memory.

        Raises:
          AssertionError: If both or neither of points_per_side / point_grids
            are given, or if output_mode is not one of the supported values.
          ImportError: If an optional dependency needed by the chosen options
            (pycocotools, opencv) is not installed.
        """
        assert (points_per_side is None) != (
            point_grids is None
        ), "Exactly one of points_per_side or point_grid must be provided."
        if points_per_side is not None:
            self.point_grids = build_all_layer_point_grids(
                points_per_side,
                crop_n_layers,
                crop_n_points_downscale_factor,
            )
        elif point_grids is not None:
            self.point_grids = point_grids
        else:
            # Only reachable when assertions are disabled (python -O).
            raise ValueError("Can't have both points_per_side and point_grid be None.")

        assert output_mode in [
            "binary_mask",
            "uncompressed_rle",
            "coco_rle",
        ], f"Unknown output_mode {output_mode}."

        # Import optional dependencies eagerly so that a missing package is
        # reported at construction time rather than in the middle of generate().
        if output_mode == "coco_rle":
            from pycocotools import mask as mask_utils  # type: ignore # noqa: F401

        if min_mask_region_area > 0:
            import cv2  # type: ignore # noqa: F401

        self.predictor = SamPredictor(model)
        self.points_per_batch = points_per_batch
        self.pred_iou_thresh = pred_iou_thresh
        self.stability_score_thresh = stability_score_thresh
        self.stability_score_offset = stability_score_offset
        self.box_nms_thresh = box_nms_thresh
        self.crop_n_layers = crop_n_layers
        self.crop_nms_thresh = crop_nms_thresh
        self.crop_overlap_ratio = crop_overlap_ratio
        self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
        self.min_mask_region_area = min_mask_region_area
        self.output_mode = output_mode

    # Inference only: make sure no autograd state is recorded while generating.
    @torch.no_grad()
    def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
        """
        Generates masks for the given image.

        Arguments:
          image (np.ndarray): The image to generate masks for, in HWC uint8 format.

        Returns:
           list(dict(str, any)): A list over records for masks. Each record is
             a dict containing the following keys:
               segmentation (dict(str, any) or np.ndarray): The mask. If
                 output_mode='binary_mask', is an array of shape HW. Otherwise,
                 is a dictionary containing the RLE.
               bbox (list(float)): The box around the mask, in XYWH format.
               area (int): The area in pixels of the mask.
               predicted_iou (float): The model's own prediction of the mask's
                 quality. This is filtered by the pred_iou_thresh parameter.
               point_coords (list(list(float))): The point coordinates input
                 to the model to generate this mask.
               stability_score (float): A measure of the mask's quality. This
                 is filtered on using the stability_score_thresh parameter.
               crop_box (list(float)): The crop of the image used to generate
                 the mask, given in XYWH format.
               feat (list): The per-mask entry of the fourth value returned by
                 the predictor's predict_torch, converted to a (nested) list.
                 Its exact meaning is defined by the predictor (presumably a
                 feature embedding of the mask).
        """
        # Generate masks
        mask_data = self._generate_masks(image)

        # Filter small disconnected regions and holes in masks
        if self.min_mask_region_area > 0:
            mask_data = self.postprocess_small_regions(
                mask_data,
                self.min_mask_region_area,
                max(self.box_nms_thresh, self.crop_nms_thresh),
            )

        # Encode masks
        if self.output_mode == "coco_rle":
            mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
        elif self.output_mode == "binary_mask":
            mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
        else:
            mask_data["segmentations"] = mask_data["rles"]

        # Write mask records
        curr_anns = []
        for idx in range(len(mask_data["segmentations"])):
            ann = {
                "segmentation": mask_data["segmentations"][idx],
                "area": area_from_rle(mask_data["rles"][idx]),
                "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
                "predicted_iou": mask_data["iou_preds"][idx].item(),
                "point_coords": [mask_data["points"][idx].tolist()],
                "stability_score": mask_data["stability_score"][idx].item(),
                "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
                "feat": mask_data["feats"][idx].tolist(),
            }
            curr_anns.append(ann)

        return curr_anns

    def _generate_masks(self, image: np.ndarray) -> MaskData:
        """
        Runs mask generation over every crop of the image and merges the results.

        Arguments:
          image (np.ndarray): The full image; its first two dimensions are
            taken as the original (height, width).

        Returns:
          MaskData: The masks collected from all crops, de-duplicated across
            crops, with its contents converted via MaskData.to_numpy().
        """
        orig_size = image.shape[:2]
        crop_boxes, layer_idxs = generate_crop_boxes(
            orig_size, self.crop_n_layers, self.crop_overlap_ratio
        )

        # Iterate over image crops
        data = MaskData()
        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
            crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
            data.cat(crop_data)

        # Remove duplicate masks between crops
        if len(crop_boxes) > 1:
            # Prefer masks from smaller crops: score is the inverse crop area
            scores = 1 / box_area(data["crop_boxes"])
            scores = scores.to(data["boxes"].device)
            keep_by_nms = batched_nms(
                data["boxes"].float(),
                scores,
                # Single category for all boxes; zeros_like keeps the tensor on
                # the same device as the boxes (torch.zeros would be on the CPU).
                torch.zeros_like(data["boxes"][:, 0]),  # categories
                iou_threshold=self.crop_nms_thresh,
            )
            data.filter(keep_by_nms)

        data.to_numpy()
        return data

    def _process_crop(
        self,
        image: np.ndarray,
        crop_box: List[int],
        crop_layer_idx: int,
        orig_size: Tuple[int, ...],
    ) -> MaskData:
        """
        Generates masks for a single crop of the image.

        Arguments:
          image (np.ndarray): The full image, indexed as [y, x, channel].
          crop_box (list(int)): The crop as [x0, y0, x1, y1] in image pixels.
          crop_layer_idx (int): Index of the crop layer; selects which entry of
            self.point_grids is used to prompt this crop.
          orig_size (tuple(int)): (height, width) of the full image.

        Returns:
          MaskData: Masks found in this crop after within-crop NMS, with boxes
            and points mapped back to the full image frame and a 'crop_boxes'
            entry recording the crop each mask came from.
        """
        # Crop the image and calculate embeddings
        x0, y0, x1, y1 = crop_box
        cropped_im = image[y0:y1, x0:x1, :]
        cropped_im_size = cropped_im.shape[:2]
        self.predictor.set_image(cropped_im)

        # Get points for this crop. The crop size is (H, W); reversing it gives
        # (W, H) so the normalized grid is scaled along x and y respectively.
        points_scale = np.array(cropped_im_size)[None, ::-1]
        points_for_image = self.point_grids[crop_layer_idx] * points_scale

        # Generate masks for this crop in batches
        data = MaskData()
        for (points,) in batch_iterator(self.points_per_batch, points_for_image):
            batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
            data.cat(batch_data)
            del batch_data
        self.predictor.reset_image()

        # Remove duplicates within this crop.
        keep_by_nms = batched_nms(
            data["boxes"].float(),
            data["iou_preds"],
            # Same device as the boxes; a plain torch.zeros(...) lives on the CPU.
            torch.zeros_like(data["boxes"][:, 0]),  # categories
            iou_threshold=self.box_nms_thresh,
        )
        data.filter(keep_by_nms)

        # Return to the original image frame
        data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
        data["points"] = uncrop_points(data["points"], crop_box)
        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])

        return data

    def _process_batch(
        self,
        points: np.ndarray,
        im_size: Tuple[int, ...],
        crop_box: List[int],
        orig_size: Tuple[int, ...],
    ) -> MaskData:
        """
        Runs the predictor on one batch of point prompts and filters the output.

        Arguments:
          points (np.ndarray): The point prompts for this batch, in crop pixel
            coordinates.
          im_size (tuple(int)): Size of the crop the points refer to.
          crop_box (list(int)): The crop the predictor image was taken from.
          orig_size (tuple(int)): (height, width) of the full image.

        Returns:
          MaskData: Surviving predictions with keys 'feats', 'iou_preds',
            'points', 'stability_score', 'boxes' and 'rles'. The dense masks
            are dropped once they have been run-length encoded.
        """
        orig_h, orig_w = orig_size

        # Run model on this batch: one positive-labelled point per prompt
        transformed_points = self.predictor.transform.apply_coords(points, im_size)
        in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
        masks, iou_preds, _, feats = self.predictor.predict_torch(
            in_points[:, None, :],
            in_labels[:, None],
            multimask_output=True,
            return_logits=True,
        )

        # Serialize predictions and store in MaskData. The first two dimensions
        # are flattened together, so every point is repeated masks.shape[1]
        # times to stay aligned with the flattened predictions.
        data = MaskData(
            feats=feats.flatten(0, 1),
            masks=masks.flatten(0, 1),
            iou_preds=iou_preds.flatten(0, 1),
            points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
        )
        del masks

        # Filter by predicted IoU
        if self.pred_iou_thresh > 0.0:
            keep_mask = data["iou_preds"] > self.pred_iou_thresh
            data.filter(keep_mask)

        # Calculate stability score
        data["stability_score"] = calculate_stability_score(
            data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
        )
        if self.stability_score_thresh > 0.0:
            keep_mask = data["stability_score"] >= self.stability_score_thresh
            data.filter(keep_mask)

        # Threshold masks and calculate boxes
        data["masks"] = data["masks"] > self.predictor.model.mask_threshold
        data["boxes"] = batched_mask_to_box(data["masks"])

        # Filter boxes that touch crop boundaries
        keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
        if not torch.all(keep_mask):
            data.filter(keep_mask)

        # Compress to RLE
        data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
        data["rles"] = mask_to_rle_pytorch(data["masks"])
        del data["masks"]

        return data

    # The function takes no `self`; it must be a staticmethod so that the
    # `self.postprocess_small_regions(mask_data, min_area, thresh)` call in
    # generate() binds its three arguments correctly.
    @staticmethod
    def postprocess_small_regions(
        mask_data: MaskData, min_area: int, nms_thresh: float
    ) -> MaskData:
        """
        Removes small disconnected regions and holes in masks, then reruns
        box NMS to remove any new duplicates.

        Edits mask_data in place.

        Requires open-cv as a dependency.

        Arguments:
          mask_data (MaskData): Mask records; its 'rles' and 'boxes' entries
            are read and updated.
          min_area (int): Area threshold passed to remove_small_regions for
            both holes and islands.
          nms_thresh (float): The box IoU cutoff used when re-running NMS.

        Returns:
          MaskData: The same mask_data object, filtered to the masks kept by NMS.
        """
        if len(mask_data["rles"]) == 0:
            return mask_data

        # Filter small disconnected regions and holes
        new_masks = []
        scores = []
        for rle in mask_data["rles"]:
            mask = rle_to_mask(rle)

            mask, changed = remove_small_regions(mask, min_area, mode="holes")
            unchanged = not changed
            mask, changed = remove_small_regions(mask, min_area, mode="islands")
            unchanged = unchanged and not changed

            new_masks.append(torch.as_tensor(mask).unsqueeze(0))
            # Give score=0 to changed masks and score=1 to unchanged masks
            # so NMS will prefer ones that didn't need postprocessing
            scores.append(float(unchanged))

        # Recalculate boxes and remove any new duplicates
        masks = torch.cat(new_masks, dim=0)
        boxes = batched_mask_to_box(masks)
        keep_by_nms = batched_nms(
            boxes.float(),
            torch.as_tensor(scores),
            torch.zeros_like(boxes[:, 0]),  # categories
            iou_threshold=nms_thresh,
        )

        # Only recalculate RLEs for masks that have changed
        for i_mask in keep_by_nms:
            if scores[i_mask] == 0.0:
                mask_torch = masks[i_mask].unsqueeze(0)
                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
        mask_data.filter(keep_by_nms)

        return mask_data