Spaces:
Build error
Build error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import numpy as np | |
| from typing import List, Optional | |
| from segment_anything import SamAutomaticMaskGenerator | |
| from segment_anything.utils.amg import build_all_layer_point_grids | |
| from .predictor import SamPredictorHQ | |
class SamAutomaticMaskGeneratorHQ(SamAutomaticMaskGenerator):
    """Automatic mask generator configured around a ready-made ``SamPredictorHQ``."""

    def __init__(
        self,
        model: SamPredictorHQ,
        points_per_side: Optional[int] = 32,
        points_per_batch: int = 64,
        pred_iou_thresh: float = 0.88,
        stability_score_thresh: float = 0.95,
        stability_score_offset: float = 1.0,
        box_nms_thresh: float = 0.7,
        crop_n_layers: int = 0,
        crop_nms_thresh: float = 0.7,
        crop_overlap_ratio: float = 512 / 1500,
        crop_n_points_downscale_factor: int = 1,
        point_grids: Optional[List[np.ndarray]] = None,
        min_mask_region_area: int = 0,
        output_mode: str = "binary_mask",
    ) -> None:
        """
        Store the configuration used to generate masks for a whole image.

        Mask generation prompts the predictor with a grid of points over
        the image and then discards low quality and duplicate masks. The
        defaults are chosen for SAM with a ViT-H backbone.

        Note that this constructor does not invoke the parent class
        constructor; every attribute is assigned here directly.

        Arguments:
          model (SamPredictorHQ): The predictor used for mask prediction.
            It is stored unchanged as ``self.predictor``.
          points_per_side (int or None): Number of points sampled along one
            side of the image, for points_per_side**2 points in total. When
            None, 'point_grids' has to supply the sampling points.
          points_per_batch (int): Number of points the model processes at
            once. Larger values may be faster but need more GPU memory.
          pred_iou_thresh (float): Filtering threshold in [0,1] applied to
            the mask quality predicted by the model.
          stability_score_thresh (float): Filtering threshold in [0,1]
            applied to the stability of a mask under changes of the cutoff
            that binarizes the model's mask predictions.
          stability_score_offset (float): How far the cutoff is shifted when
            calculating the stability score.
          box_nms_thresh (float): Box IoU cutoff used by non-maximal
            suppression to filter duplicate masks.
          crop_n_layers (int): If >0, mask prediction is run again on crops
            of the image. Gives the number of layers to run; each layer has
            2**i_layer image crops.
          crop_nms_thresh (float): Box IoU cutoff used by non-maximal
            suppression to filter duplicate masks between different crops.
          crop_overlap_ratio (float): Degree to which crops overlap. Crops
            in the first crop layer overlap by this fraction of the image
            length; later layers with more crops scale this overlap down.
          crop_n_points_downscale_factor (int): The points-per-side sampled
            in layer n is scaled down by crop_n_points_downscale_factor**n.
          point_grids (list(np.ndarray) or None): Explicit grids of sampling
            points, normalized to [0,1]. The nth grid is used in the nth
            crop layer. Exclusive with points_per_side.
          min_mask_region_area (int): If >0, postprocessing removes
            disconnected regions and holes in masks whose area is smaller
            than min_mask_region_area. Requires opencv.
          output_mode (str): Form in which masks are returned: 'binary_mask',
            'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires
            pycocotools. For large resolutions, 'binary_mask' may consume
            large amounts of memory.

        Raises:
          AssertionError: If points_per_side and point_grids are both given
            or both None, or if output_mode is not a supported value
            (only while assertions are enabled).
          ValueError: If points_per_side and point_grids are both None and
            assertions are disabled.
          ImportError: If output_mode is 'coco_rle' and pycocotools cannot be
            imported, or min_mask_region_area > 0 and cv2 cannot be imported.
        """
        exactly_one_source = (points_per_side is None) ^ (point_grids is None)
        assert (
            exactly_one_source
        ), "Exactly one of points_per_side or point_grid must be provided."

        if points_per_side is None and point_grids is None:
            # Only reachable when assertions are stripped (e.g. `python -O`).
            raise ValueError("Can't have both points_per_side and point_grid be None.")

        if points_per_side is not None:
            # Derive one grid per crop layer from the requested density.
            grids = build_all_layer_point_grids(
                points_per_side,
                crop_n_layers,
                crop_n_points_downscale_factor,
            )
        else:
            # Caller supplied the grids explicitly; use them as given.
            grids = point_grids
        self.point_grids = grids

        supported_modes = ("binary_mask", "uncompressed_rle", "coco_rle")
        assert output_mode in supported_modes, f"Unknown output_mode {output_mode}."

        # Import optional dependencies now so a missing package surfaces at
        # construction time; the imported names are deliberately unused here.
        if output_mode == "coco_rle":
            from pycocotools import mask as mask_utils  # type: ignore # noqa: F401
        if min_mask_region_area > 0:
            import cv2  # type: ignore # noqa: F401

        # The predictor is used as passed in.
        self.predictor = model

        # Point batching and mask filtering settings.
        self.points_per_batch = points_per_batch
        self.pred_iou_thresh = pred_iou_thresh
        self.stability_score_thresh = stability_score_thresh
        self.stability_score_offset = stability_score_offset
        self.box_nms_thresh = box_nms_thresh

        # Crop layer settings.
        self.crop_n_layers = crop_n_layers
        self.crop_nms_thresh = crop_nms_thresh
        self.crop_overlap_ratio = crop_overlap_ratio
        self.crop_n_points_downscale_factor = crop_n_points_downscale_factor

        # Postprocessing and output settings.
        self.min_mask_region_area = min_mask_region_area
        self.output_mode = output_mode