File size: 8,087 Bytes
eb4d305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Any, Dict, List, Union

import torch
from mmengine.config import ConfigDict
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.models.roi_heads.mask_heads.fcn_mask_head import FCNMaskHead
from mmdet.registry import MODELS
from mmdet.structures.mask.mask_target import mask_target

from .square_mask_target import square_mask_target


# Module-level logger: diagnostics are emitted at DEBUG/WARNING level instead
# of unconditional ``print`` calls, so training output stays clean by default.
_LOGGER = logging.getLogger(__name__)


@MODELS.register_module()
class SquareFCNMaskHead(FCNMaskHead):
    """FCN mask head that forces square mask targets.

    The head reuses the layers built by :class:`FCNMaskHead` and only swaps
    the target generation: :meth:`get_targets` delegates to
    ``square_mask_target`` instead of the default mask target function, with
    the intent of avoiding size mismatches between predictions and targets
    during training.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Build the head; all arguments are forwarded to ``FCNMaskHead``."""
        _LOGGER.debug('Initializing SquareFCNMaskHead (args=%s, kwargs=%s)',
                      args, kwargs)
        super().__init__(*args, **kwargs)

    def forward(self, x: Tensor) -> Tensor:
        """Forward features from the upstream network.

        Args:
            x (Tensor): Extract mask RoI features.

        Returns:
            Tensor: Predicted foreground masks.
        """
        _LOGGER.debug('forward: input shape %s', tuple(x.shape))

        for conv in self.convs:
            x = conv(x)

        if self.upsample is not None:
            x = self.upsample(x)
            # Only the deconv upsampler is followed by a ReLU.
            if self.upsample_method == 'deconv':
                x = self.relu(x)

        mask_preds = self.conv_logits(x)
        _LOGGER.debug('forward: mask_preds shape %s, device %s, dtype %s',
                      tuple(mask_preds.shape), mask_preds.device,
                      mask_preds.dtype)
        return mask_preds

    def loss_and_target(self,
                        mask_preds: Tensor,
                        sampling_results: List[Any],
                        batch_gt_instances: List[InstanceData],
                        rcnn_train_cfg: Union[Dict[str, Any],
                                              ConfigDict]) -> dict:
        """Calculate the loss based on the features extracted by the mask head.

        Args:
            mask_preds (Tensor): Predicted foreground masks, has shape
                (num_pos, num_classes, mask_h, mask_w).
            sampling_results (List[:obj:`SamplingResult`]): Assign results of
                all images in a batch after sampling.
            batch_gt_instances (List[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``labels``,
                and ``masks`` attributes.
            rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.

        Returns:
            dict: ``loss_mask`` holds the nested loss dict
            ``{'loss_mask': Tensor}``; ``mask_targets`` holds the targets the
            loss was computed against.
        """
        mask_targets = self.get_targets(sampling_results, batch_gt_instances,
                                        rcnn_train_cfg)
        pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])

        if mask_preds.size(0) == 0:
            # No positive RoIs in this batch: reductions such as ``max()``
            # fail on empty tensors, so return a zero loss that still keeps
            # ``mask_preds`` attached to the autograd graph.
            loss_mask = mask_preds.sum()
        else:
            pos_labels = self._sanitize_labels(pos_labels)

            if mask_preds.shape[-2:] != mask_targets.shape[-2:]:
                _LOGGER.warning(
                    'Mask size mismatch: mask_preds %s vs mask_targets %s',
                    tuple(mask_preds.shape), tuple(mask_targets.shape))

            # A class-agnostic head predicts a single mask channel, so every
            # positive sample must index channel 0.
            if self.class_agnostic:
                loss_labels = torch.zeros_like(pos_labels)
            else:
                loss_labels = pos_labels
            loss_mask = self.loss_mask(mask_preds, mask_targets, loss_labels)

        # Return the nested loss dict under ``loss_mask``; ``mask_targets`` is
        # passed along for callers that want to inspect the targets.
        return dict(
            loss_mask={'loss_mask': loss_mask}, mask_targets=mask_targets)

    def _sanitize_labels(self, pos_labels: Tensor) -> Tensor:
        """Clamp labels into ``[0, num_classes - 1]`` if any exceed the range.

        Args:
            pos_labels (Tensor): Non-empty labels of the positive proposals.

        Returns:
            Tensor: ``pos_labels`` unchanged when the largest label is valid,
            otherwise a clamped copy. A warning is logged when clamping
            happens because it usually indicates a dataset/config mismatch.
        """
        if _LOGGER.isEnabledFor(logging.DEBUG):
            # Guarded: min()/max() on a device tensor forces a sync.
            _LOGGER.debug('pos_labels: shape %s, min %s, max %s, '
                          'num_classes %s', tuple(pos_labels.shape),
                          pos_labels.min().item(), pos_labels.max().item(),
                          self.num_classes)

        max_label = pos_labels.max()
        if max_label >= self.num_classes:
            _LOGGER.warning(
                'Found label %s >= num_classes %s; clamping labels to the '
                'valid range', max_label.item(), self.num_classes)
            pos_labels = torch.clamp(pos_labels, 0, self.num_classes - 1)
        return pos_labels

    def get_targets(self,
                    sampling_results: List[Any],
                    batch_gt_instances: List[InstanceData],
                    rcnn_train_cfg: Union[Dict[str, Any],
                                          ConfigDict]) -> Tensor:
        """Calculate the ground truth for all samples in a batch according to
        the sampling_results.

        Args:
            sampling_results (List[:obj:`SamplingResult`]): Assign results of
                all images in a batch after sampling.
            batch_gt_instances (List[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``labels``,
                and ``masks`` attributes.
            rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.

        Returns:
            Tensor: Mask targets of each positive proposals in the image,
                has shape (num_pos, mask_h, mask_w).
        """
        pos_proposals_list = [res.pos_priors for res in sampling_results]
        pos_assigned_gt_inds_list = [
            res.pos_assigned_gt_inds for res in sampling_results
        ]
        gt_masks_list = [res.masks for res in batch_gt_instances]

        _LOGGER.debug('get_targets: %d sampling results, rcnn_train_cfg=%s',
                      len(sampling_results), rcnn_train_cfg)

        # Use the project's square mask target function instead of the
        # default mask target function.
        mask_targets = square_mask_target(pos_proposals_list,
                                          pos_assigned_gt_inds_list,
                                          gt_masks_list, rcnn_train_cfg)

        _LOGGER.debug('get_targets: mask_targets shape %s',
                      tuple(mask_targets.shape))
        return mask_targets