Spaces:
Sleeping
Sleeping
File size: 8,087 Bytes
eb4d305 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Any, Dict, List, Union

import torch
from mmengine.config import ConfigDict
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.models.roi_heads.mask_heads.fcn_mask_head import FCNMaskHead
from mmdet.registry import MODELS
from mmdet.structures.mask.mask_target import mask_target
from .square_mask_target import square_mask_target
logger = logging.getLogger(__name__)


@MODELS.register_module()
class SquareFCNMaskHead(FCNMaskHead):
    """FCN mask head whose training targets come from ``square_mask_target``.

    The head reuses the layers built by :class:`FCNMaskHead` and only replaces
    target generation: :meth:`get_targets` delegates to the project-local
    ``square_mask_target`` helper instead of MMDetection's ``mask_target``.
    The stated intent is that all mask targets are square regardless of the
    original aspect ratio, to avoid tensor size mismatches during training.

    Diagnostics are emitted through the module logger (``DEBUG`` for shapes,
    ``WARNING`` for suspicious inputs) rather than printed unconditionally.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Build the head; all arguments are forwarded to ``FCNMaskHead``."""
        logger.debug('Initializing SquareFCNMaskHead: args=%s kwargs=%s',
                     args, kwargs)
        super().__init__(*args, **kwargs)

    def forward(self, x: Tensor) -> Tensor:
        """Forward features from the upstream network.

        Args:
            x (Tensor): Extract mask RoI features.

        Returns:
            Tensor: Predicted foreground masks.
        """
        logger.debug('forward: input shape %s', tuple(x.shape))
        for conv in self.convs:
            x = conv(x)
        if self.upsample is not None:
            x = self.upsample(x)
            # Only the deconv upsampler is followed by an activation.
            if self.upsample_method == 'deconv':
                x = self.relu(x)
        mask_preds = self.conv_logits(x)
        logger.debug('forward: mask_preds shape %s',
                     tuple(mask_preds.shape))
        return mask_preds

    def loss_and_target(
            self, mask_preds: Tensor, sampling_results: List[Any],
            batch_gt_instances: List[InstanceData],
            rcnn_train_cfg: Union[Dict[str, Any], ConfigDict]) -> dict:
        """Calculate the loss based on the features extracted by the mask head.

        Args:
            mask_preds (Tensor): Predicted foreground masks, has shape
                (num_pos, num_classes, mask_h, mask_w).
            sampling_results (List[:obj:`SamplingResult`]): Assign results of
                all images in a batch after sampling.
            batch_gt_instances (List[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``labels``,
                and ``masks`` attributes.
            rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.

        Returns:
            dict: ``loss_mask`` holds the nested ``{'loss_mask': Tensor}``
            dict and ``mask_targets`` holds the targets used for the loss.
        """
        mask_targets = self.get_targets(sampling_results, batch_gt_instances,
                                        rcnn_train_cfg)
        pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
        logger.debug(
            'loss_and_target: mask_preds %s, mask_targets %s, pos_labels %s',
            tuple(mask_preds.shape), tuple(mask_targets.shape),
            tuple(pos_labels.shape))

        if mask_preds.size(0) == 0:
            # No positive proposals in this batch. Reductions such as
            # ``max()`` raise on empty tensors, so skip every check and
            # return a zero loss that is still attached to the graph.
            loss_mask = mask_preds.sum()
        else:
            pos_labels = self._clamp_labels(pos_labels)
            if mask_preds.shape[-2:] != mask_targets.shape[-2:]:
                # Reported for diagnosis only; the loss call below is still
                # attempted, exactly as before.
                logger.warning(
                    'Mask size mismatch: mask_preds %s vs mask_targets %s',
                    tuple(mask_preds.shape), tuple(mask_targets.shape))
            if self.class_agnostic:
                # A class-agnostic head predicts a single mask channel, so
                # every sample must select channel 0.
                loss_labels = torch.zeros_like(pos_labels)
            else:
                loss_labels = pos_labels
            loss_mask = self.loss_mask(mask_preds, mask_targets, loss_labels)

        # Nested dict layout matches what FCNMaskHead returns, so the RoI
        # head can merge ``loss_mask`` into its loss dict unchanged.
        return dict(
            loss_mask={'loss_mask': loss_mask}, mask_targets=mask_targets)

    def _clamp_labels(self, pos_labels: Tensor) -> Tensor:
        """Clamp labels into ``[0, num_classes - 1]`` if any exceed the range.

        Labels are returned untouched unless the largest one is
        ``>= self.num_classes``; in that case a warning is logged and the
        whole tensor is clamped.
        """
        if pos_labels.numel() == 0:
            return pos_labels
        max_label = int(pos_labels.max())
        if max_label < self.num_classes:
            return pos_labels
        logger.warning(
            'Found label %d >= num_classes %d; clamping labels to valid range',
            max_label, self.num_classes)
        return torch.clamp(pos_labels, 0, self.num_classes - 1)

    def get_targets(
            self, sampling_results: List[Any],
            batch_gt_instances: List[InstanceData],
            rcnn_train_cfg: Union[Dict[str, Any], ConfigDict]) -> Tensor:
        """Calculate the ground truth for all samples in a batch according to
        the sampling_results.

        Args:
            sampling_results (List[:obj:`SamplingResult`]): Assign results of
                all images in a batch after sampling.
            batch_gt_instances (List[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``labels``,
                and ``masks`` attributes.
            rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.

        Returns:
            Tensor: Mask targets of each positive proposals in the image,
            has shape (num_pos, mask_h, mask_w).
        """
        pos_proposals_list = [res.pos_priors for res in sampling_results]
        pos_assigned_gt_inds_list = [
            res.pos_assigned_gt_inds for res in sampling_results
        ]
        gt_masks_list = [res.masks for res in batch_gt_instances]
        logger.debug('get_targets: %d sampling results, rcnn_train_cfg=%s',
                     len(sampling_results), rcnn_train_cfg)
        # Project-local replacement for mmdet's ``mask_target``.
        mask_targets = square_mask_target(pos_proposals_list,
                                          pos_assigned_gt_inds_list,
                                          gt_masks_list, rcnn_train_cfg)
        logger.debug('get_targets: mask_targets shape %s',
                     tuple(mask_targets.shape))
        return mask_targets