Spaces:
Build error
Build error
| import numpy.typing as npt | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import cv2 | |
| from torchvision.ops.boxes import batched_nms | |
| from app.mobile_sam import SamPredictor | |
| from app.mobile_sam.utils import batched_mask_to_box | |
| from app.sam.postprocess import clean_mask_torch | |
def point_selection(mask_sim, topk: int = 1):
    """Pick the ``topk`` highest-valued locations of a 2-D similarity map.

    Args:
        mask_sim: 2-D tensor of similarity values, indexed ``[row, column]``.
        topk: Number of locations to return.

    Returns:
        A ``(coords, labels)`` pair of NumPy arrays. ``coords`` has shape
        ``(topk, 2)`` and stores each location as ``(column, row)``, i.e.
        ``(x, y)``, ordered from highest to lowest value. ``labels`` holds
        ``topk`` ones.
    """
    _, n_cols = mask_sim.shape

    # Indices of the best entries in the row-major flattened map.
    flat_idx = mask_sim.flatten(0).topk(topk).indices

    # Decompose each flat index into its row and column.
    rows = flat_idx // n_cols
    cols = flat_idx - rows * n_cols

    # One (x, y) pair per selected location.
    coords = torch.stack((cols, rows), dim=1)
    labels = np.array([1] * topk)
    return coords.cpu().numpy(), labels
def mask_nms(
    masks: list[npt.NDArray], scores: list[float], iou_thresh: float = 0.2
) -> tuple[list[npt.NDArray], list[float]]:
    """Suppress redundant masks, keeping the best-scoring one of each group.

    A mask is dropped when either rule finds a better-scoring competitor:

    1. Containment: some other mask has at least 90% of its area inside
       this mask and a higher score.
    2. IoU: some other mask overlaps this one with IoU above ``iou_thresh``
       and a higher score.

    Both rules are evaluated against *all* input masks, including ones that
    end up being removed themselves (the suppression is not greedy).
    All-zero masks never raise: their IoU with everything is treated as 0
    and they are not regarded as "contained" in other masks, so they neither
    suppress nor get suppressed by non-empty masks.

    Args:
        masks: Equally shaped 2-D masks; non-zero entries are foreground.
        scores: One score per mask, higher is better.
        iou_thresh: IoU above which two masks count as duplicates.

    Returns:
        ``(kept_masks, kept_scores)`` in the original input order.
    """
    n_masks = len(masks)
    if n_masks == 0:
        return [], []

    np_masks = np.array(masks).astype(bool)
    np_scores = np.array(scores)
    areas = np_masks.sum(axis=(1, 2))  # loop-invariant, computed once
    remove_indices = set()

    for i in range(n_masks):
        mask_i = np_masks[i]
        intersection = np.logical_and(mask_i, np_masks).sum(axis=(1, 2))
        union = np.logical_or(mask_i, np_masks).sum(axis=(1, 2))
        # Where the union is empty the IoU stays at 0 instead of becoming NaN.
        iou_row = np.divide(
            intersection, union, out=np.zeros(n_masks), where=union > 0
        )

        # Rule 1: masks (almost) fully covered by mask i compete with it.
        contained = (intersection >= areas * 0.90) & (areas > 0)
        # Rule 2: masks whose IoU with mask i exceeds the threshold.
        duplicates = iou_row > iou_thresh

        for group in (contained, duplicates):
            # A mask always competes with itself, so the group is never
            # empty and argmax cannot fail.
            group[i] = True
            members = np.flatnonzero(group)
            winner = members[np.argmax(np_scores[members])]
            if winner != i:
                remove_indices.add(i)

    kept = [i for i in range(n_masks) if i not in remove_indices]
    return [masks[i] for i in kept], [scores[i] for i in kept]
class MaskWeights(nn.Module):
    """Module holding a single learnable ``(2, 1)`` weight tensor.

    Every entry of ``weights`` starts at ``1 / 3``.
    NOTE(review): presumably these are two of three per-scale mask weights,
    with the third derived by the caller - confirm against the training code.
    """

    def __init__(self):
        super().__init__()
        initial = torch.ones(2, 1).div(3)
        # nn.Parameter registers the tensor as a trainable leaf of the module.
        self.weights = nn.Parameter(initial, requires_grad=True)
class PerSAM:
    """Callable that runs ``fast_inference`` with a fixed configuration.

    The constructor only stores its arguments; calling the instance with an
    image forwards the image together with the stored settings to
    ``fast_inference``.
    """

    def __init__(
        self,
        sam: SamPredictor,
        target_feat: torch.Tensor,
        max_objects: int,
        score_thresh: float,
        nms_iou_thresh: float,
        mask_weights: torch.Tensor,
    ) -> None:
        """Store the predictor and inference settings.

        Args:
            sam: Predictor used to encode the image and decode masks.
            target_feat: Feature of the target, matched against the image
                features.
            max_objects: Upper bound on the number of objects searched for.
            score_thresh: Similarity score below which the search stops.
            nms_iou_thresh: IoU threshold of the final box NMS.
            mask_weights: Weights used to combine the multi-mask outputs.
        """
        super().__init__()
        self.sam = sam
        self.weights = mask_weights
        self.target_feat = target_feat
        self.max_objects = max_objects
        self.score_thresh = score_thresh
        self.nms_iou_thresh = nms_iou_thresh

    def __call__(
        self, x: npt.NDArray
    ) -> tuple[npt.NDArray | None, npt.NDArray | None, npt.NDArray | None]:
        """Run inference on image ``x``.

        Returns:
            Whatever ``fast_inference`` returns: ``(masks, bboxes, scores)``,
            or ``(None, None, None)`` when no object was found.
        """
        return fast_inference(
            self.sam,
            x,
            self.target_feat,
            self.weights,
            self.max_objects,
            self.score_thresh,
            self.nms_iou_thresh,
        )
def _mask_to_box(mask: npt.NDArray) -> npt.NDArray | None:
    """Return the tight ``[x_min, y_min, x_max, y_max]`` box of a 2-D mask, or ``None`` if it is empty."""
    ys, xs = np.nonzero(mask)
    if xs.size == 0:
        return None
    return np.array([xs.min(), ys.min(), xs.max(), ys.max()])


def fast_inference(
    predictor: SamPredictor,
    image: npt.NDArray,
    target_feat: torch.Tensor,
    weights: torch.Tensor,
    max_objects: int,
    score_thresh: float,
    nms_iou_thresh: float = 0.2,
) -> tuple[npt.NDArray | None, npt.NDArray | None, npt.NDArray | None]:
    """Segment up to ``max_objects`` instances that match ``target_feat``.

    A similarity map between ``target_feat`` and the image features is
    built once. Then, repeatedly, the most similar remaining pixel is used
    as a positive point prompt, the resulting mask is refined by two further
    box + mask prompted predictions, and the (dilated) final mask is zeroed
    in the similarity map so the next iteration looks elsewhere. The search
    stops as soon as the best remaining similarity falls below
    ``score_thresh``. Surviving candidates are de-duplicated with box NMS.

    Args:
        predictor: Predictor providing ``set_image``, ``features``,
            ``predict`` and ``model.postprocess_masks``.
        image: Image handed to ``predictor.set_image``.
        target_feat: Target feature multiplied with the ``(C, h*w)`` image
            features. NOTE(review): presumably shape ``(1, C)`` and already
            L2-normalised - confirm with the caller.
        weights: Weights multiplied onto the multi-mask logits before they
            are summed. NOTE(review): presumably shape ``(3, 1)`` - confirm.
        max_objects: Maximum number of search iterations.
        score_thresh: Minimum similarity at the prompt point for a candidate
            to be accepted.
        nms_iou_thresh: IoU threshold for the final box NMS.

    Returns:
        ``(masks, bboxes, scores)`` as NumPy arrays restricted to the
        candidates kept by NMS, or ``(None, None, None)`` when no candidate
        was accepted.
    """
    weights_np = weights.detach().cpu().numpy()
    pred_masks = []
    pred_scores = []

    # Image feature encoding
    predictor.set_image(image)
    test_feat = predictor.features.squeeze()

    # Cosine similarity: normalise each spatial feature vector, then project
    # the target feature onto every location.
    C, h, w = test_feat.shape
    test_feat = test_feat / test_feat.norm(dim=0, keepdim=True)
    test_feat = test_feat.reshape(C, h * w)
    sim = target_feat @ test_feat
    sim = sim.reshape(1, 1, h, w)
    sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
    sim = predictor.model.postprocess_masks(
        sim, input_size=predictor.input_size, original_size=predictor.original_size
    ).squeeze()

    dilate_kernel = np.ones((5, 5), np.uint8)

    for _ in range(max_objects):
        # Positive location prior: the most similar pixel still available.
        topk_xy, topk_label = point_selection(sim, topk=1)
        seed_x, seed_y = int(topk_xy[0][0]), int(topk_xy[0][1])

        # The score depends only on the similarity map, so decide whether to
        # stop before paying for three decoder passes.
        score = sim[seed_y, seed_x].item()
        if score < score_thresh:
            break

        # First-step prediction
        logits_high, _, logits = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            multimask_output=True,
            return_logits=True,
            return_numpy=False,
        )
        logits = logits.detach().cpu().numpy()

        # Weighted sum of the multi-scale masks
        logit_high = (logits_high * weights.unsqueeze(-1)).sum(0)
        mask = clean_mask_torch(logit_high > 0).bool()[0, 0, :, :].detach().cpu().numpy()
        logit = (logits * weights_np[..., None]).sum(0)

        # Cascaded Post-refinement-1
        input_box = _mask_to_box(mask)
        if input_box is None:
            # Nothing was segmented at this seed: retire the seed pixel and
            # try the next-best location instead of failing on an empty box.
            sim[seed_y, seed_x] = 0
            continue
        masks, scores, logits = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            box=input_box[None, :],
            mask_input=logit[None, :, :],
            multimask_output=True,
        )
        best_idx = np.argmax(scores)

        # Cascaded Post-refinement-2
        input_box = _mask_to_box(masks[best_idx])
        if input_box is None:
            sim[seed_y, seed_x] = 0
            continue
        masks, scores, logits = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            box=input_box[None, :],
            mask_input=logits[best_idx : best_idx + 1, :, :],
            multimask_output=True,
            return_numpy=False,
        )
        best_idx = np.argmax(scores.detach().cpu().numpy())
        final_mask = masks[best_idx]

        # Zero the similarity under the slightly grown mask so later
        # iterations pick a different object. An explicit boolean mask is
        # used because indexing a tensor with uint8 data is deprecated.
        final_mask_dilate = cv2.dilate(
            final_mask.detach().cpu().numpy().astype(np.uint8), dilate_kernel, iterations=1
        )
        suppress = torch.from_numpy(final_mask_dilate.astype(bool)).to(sim.device)
        sim[suppress] = 0

        pred_masks.append(final_mask)
        pred_scores.append(score)

    if not pred_masks:
        return None, None, None

    stacked_masks = torch.stack(pred_masks)
    bboxes = batched_mask_to_box(stacked_masks)
    keep_by_nms = batched_nms(
        bboxes.float(),
        # Scores must match the boxes' dtype and device for NMS.
        torch.as_tensor(pred_scores, dtype=torch.float32, device=bboxes.device),
        # A single category: every box competes with every other box.
        torch.zeros_like(bboxes[:, 0]),
        iou_threshold=nms_iou_thresh,
    )
    kept_masks = stacked_masks[keep_by_nms].cpu().numpy()
    kept_scores = np.array(pred_scores)[keep_by_nms.cpu().numpy()]
    kept_bboxes = bboxes[keep_by_nms].int().cpu().numpy()
    return kept_masks, kept_bboxes, kept_scores