Dense-Captioning-Platform / custom_models /square_mask_target.py
hanszhu's picture
build(space): initial Docker Space with Gradio app, MMDet, SAM integration
eb4d305
# Copyright (c) OpenMMLab. All rights reserved.
import logging

import numpy as np
import torch
from mmdet.registry import MODELS
from mmdet.structures.mask.mask_target import mask_target as original_mask_target
from torch.nn.modules.utils import _pair
_logger = logging.getLogger(__name__)


def _fit_to_square(mask_targets, square_size):
    """Crop or zero-pad ``mask_targets`` to ``(N, square_size, square_size)``."""
    num_masks = mask_targets.size(0)
    square_targets = torch.zeros(
        num_masks, square_size, square_size,
        device=mask_targets.device, dtype=mask_targets.dtype)
    # Copy the overlapping top-left region; anything outside it is dropped
    # (when the source is larger) or left as zeros (when it is smaller).
    h_copy = min(mask_targets.size(1), square_size)
    w_copy = min(mask_targets.size(2), square_size)
    square_targets[:, :h_copy, :w_copy] = mask_targets[:, :h_copy, :w_copy]
    return square_targets


def square_mask_target(pos_proposals_list, pos_assigned_gt_inds_list, gt_masks_list, cfg):
    """Compute square mask targets for positive proposals in multiple images.

    The configured mask size is reduced to a square whose side is the smaller
    of the two configured dimensions, and the upstream ``mask_target``
    function is called with that square size. If the returned tensor is
    still not square, it is cropped / zero-padded to the square shape.

    Args:
        pos_proposals_list (list[Tensor]): Positive proposals in multiple
            images.
        pos_assigned_gt_inds_list (list[Tensor]): Assigned GT indices for
            each positive proposal.
        gt_masks_list (list[:obj:`BaseInstanceMasks`]): Ground truth masks of
            each image.
        cfg (dict): Config object exposing a ``mask_size`` attribute (an int
            or a pair of ints).

    Returns:
        Tensor: Square mask targets with shape (num_pos, size, size).
    """
    from mmengine.config import ConfigDict

    # ``_pair`` turns a scalar size into a 2-tuple and leaves a pair as-is;
    # the smaller dimension becomes the side of the square.
    square_size = min(_pair(cfg.mask_size))

    # NOTE(review): only ``mask_size`` is forwarded to the upstream function;
    # any other keys present in ``cfg`` are not passed along.
    square_cfg = ConfigDict({'mask_size': (square_size, square_size)})

    mask_targets = original_mask_target(
        pos_proposals_list, pos_assigned_gt_inds_list, gt_masks_list,
        square_cfg)
    _logger.debug(
        'square_mask_target: upstream shape %s, expected side %d',
        tuple(mask_targets.shape), square_size)

    # Defensive fallback: with a square config the upstream result is expected
    # to be square already, so reaching this branch is worth a warning because
    # cropping / padding changes the target content.
    if mask_targets.size(1) != square_size or mask_targets.size(2) != square_size:
        _logger.warning(
            'square_mask_target: forcing mask targets from shape %s to '
            '(%d, %d, %d)',
            tuple(mask_targets.shape), mask_targets.size(0), square_size,
            square_size)
        mask_targets = _fit_to_square(mask_targets, square_size)

    return mask_targets
# Add the function to the MODELS registry under the key 'square_mask_target'.
MODELS.register_module(
    module=square_mask_target,
    name='square_mask_target',
)