|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from scipy.optimize import linear_sum_assignment |
|
from torch.cuda.amp import autocast |
|
|
|
from mmdet.registry import MODELS, TASK_UTILS |
|
from mmdet.utils import reduce_mean |
|
|
|
|
|
def compute_mask_iou(inputs, targets):
    """Row-wise IoU between predicted mask logits and GT masks.

    Args:
        inputs (Tensor): Predicted mask logits, shape (..., L); passed
            through sigmoid and binarized at 0.4.
        targets (Tensor): Target masks, shape (..., L); binarized at 0.5.

    Returns:
        Tensor: IoU per row, reduced over the last dimension.
    """
    # Binarize both sides; 0.4 on probabilities matches the original
    # SparseInst threshold.
    pred_bin = (inputs.sigmoid() >= 0.4).float()
    gt_bin = (targets > 0.5).float()
    inter = (pred_bin * gt_bin).sum(-1)
    union = gt_bin.sum(-1) + pred_bin.sum(-1) - inter
    # Small eps guards against empty-union division by zero.
    return inter / (union + 1e-6)
|
|
|
|
|
def dice_score(inputs, targets):
    """Pairwise soft Dice score between mask logits and GT masks.

    Args:
        inputs (Tensor): Predicted mask logits, shape (N, L); sigmoid is
            applied internally.
        targets (Tensor): Target masks, shape (M, L).

    Returns:
        Tensor: Dice scores of shape (N, M) — every prediction against
        every target.
    """
    probs = inputs.sigmoid()
    # (N, L) @ (L, M) -> all prediction/target overlaps at once.
    overlap = 2 * (probs @ targets.t())
    # Broadcast (N, 1) + (M,) -> (N, M) squared-norm sums.
    norms = probs.pow(2).sum(-1).unsqueeze(1) + targets.pow(2).sum(-1)
    # Eps avoids division by zero when both masks are empty.
    return overlap / (norms + 1e-4)
|
|
|
|
|
@MODELS.register_module()
class SparseInstCriterion(nn.Module):
    """Training criterion for SparseInst.

    Pairs each predicted instance with a ground-truth instance via the
    configured ``assigner`` and computes four losses: classification,
    IoU-aware objectness, per-pixel mask loss and dice loss.

    This part is partially derived from:

    https://github.com/facebookresearch/detr/blob/main/models/detr.py.

    Args:
        num_classes (int): Number of foreground classes. The value
            ``num_classes`` itself is used as the background label.
        assigner (dict): Matcher config, built through ``TASK_UTILS``.
        loss_cls (dict): Classification loss config.
        loss_obj (dict): Objectness (IoU regression) loss config.
        loss_mask (dict): Per-pixel mask loss config.
        loss_dice (dict): Dice loss config.
    """

    def __init__(
        self,
        num_classes,
        assigner,
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            alpha=0.25,
            gamma=2.0,
            reduction='sum',
            loss_weight=2.0),
        loss_obj=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=1.0),
        loss_mask=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=5.0),
        loss_dice=dict(
            type='DiceLoss',
            use_sigmoid=True,
            reduction='sum',
            eps=5e-5,
            loss_weight=2.0),
    ):
        super().__init__()
        self.matcher = TASK_UTILS.build(assigner)
        self.num_classes = num_classes
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_obj = MODELS.build(loss_obj)
        self.loss_mask = MODELS.build(loss_mask)
        self.loss_dice = MODELS.build(loss_dice)

    def _get_src_permutation_idx(self, indices):
        """Flatten matcher results into (batch_idx, pred_idx) tensors.

        The returned pair can be used to fancy-index a (B, N, ...)
        prediction tensor and gather all matched predictions at once.
        """
        # One batch index repeated per matched prediction in each image.
        batch_idx = torch.cat(
            [torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        """Flatten matcher results into (batch_idx, gt_idx) tensors."""
        batch_idx = torch.cat(
            [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def loss_classification(self, outputs, batch_gt_instances, indices,
                            num_instances):
        """Classification loss over all predictions.

        Matched predictions are supervised with their assigned GT label;
        every unmatched prediction gets the background label
        ``self.num_classes``. The summed loss is normalized by
        ``num_instances``.
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']
        idx = self._get_src_permutation_idx(indices)
        # GT labels reordered to follow the matched-prediction order.
        target_classes_o = torch.cat(
            [gt.labels[J] for gt, (_, J) in zip(batch_gt_instances, indices)])
        # Default all (B, N) slots to background, then fill in matches.
        target_classes = torch.full(
            src_logits.shape[:2],
            self.num_classes,
            dtype=torch.int64,
            device=src_logits.device)
        target_classes[idx] = target_classes_o

        src_logits = src_logits.flatten(0, 1)
        target_classes = target_classes.flatten(0, 1)

        class_loss = self.loss_cls(
            src_logits,
            target_classes,
        ) / num_instances
        return class_loss

    def loss_masks_with_iou_objectness(self, outputs, batch_gt_instances,
                                       indices, num_instances):
        """Mask and objectness losses for the matched predictions.

        The objectness branch regresses the (gradient-free) mask IoU
        between each matched predicted mask and its GT mask.

        Returns:
            tuple: (loss_objectness, loss_dice, loss_mask).
        """
        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)

        assert 'pred_masks' in outputs
        assert 'pred_scores' in outputs
        src_iou_scores = outputs['pred_scores']
        src_masks = outputs['pred_masks']
        with torch.no_grad():
            # Rasterize GT masks; no gradients flow into targets.
            target_masks = torch.cat([
                gt.masks.to_tensor(
                    dtype=src_masks.dtype, device=src_masks.device)
                for gt in batch_gt_instances
            ])
        num_masks = [len(gt.masks) for gt in batch_gt_instances]
        target_masks = target_masks.to(src_masks)
        if len(target_masks) == 0:
            # No GT in the whole batch: return zero-valued losses that
            # stay connected to the prediction graph.
            loss_dice = src_masks.sum() * 0.0
            loss_mask = src_masks.sum() * 0.0
            loss_objectness = src_iou_scores.sum() * 0.0

            return loss_objectness, loss_dice, loss_mask

        # Gather matched predicted masks; resize GT masks to the
        # prediction resolution.
        src_masks = src_masks[src_idx]
        target_masks = F.interpolate(
            target_masks[:, None],
            size=src_masks.shape[-2:],
            mode='bilinear',
            align_corners=False).squeeze(1)

        src_masks = src_masks.flatten(1)

        # tgt_idx[1] indexes GT *within* each image; add each image's
        # cumulative GT-count offset so it indexes the concatenated
        # ``target_masks`` tensor instead.
        mix_tgt_idx = torch.zeros_like(tgt_idx[1])
        cum_sum = 0
        for num_mask in num_masks:
            mix_tgt_idx[cum_sum:cum_sum + num_mask] = cum_sum
            cum_sum += num_mask
        mix_tgt_idx += tgt_idx[1]

        target_masks = target_masks[mix_tgt_idx].flatten(1)

        # IoU is only a regression target for the objectness branch.
        with torch.no_grad():
            ious = compute_mask_iou(src_masks, target_masks)

        tgt_iou_scores = ious
        src_iou_scores = src_iou_scores[src_idx]
        tgt_iou_scores = tgt_iou_scores.flatten(0)
        src_iou_scores = src_iou_scores.flatten(0)

        loss_objectness = self.loss_obj(src_iou_scores, tgt_iou_scores)
        loss_dice = self.loss_dice(src_masks, target_masks) / num_instances
        loss_mask = self.loss_mask(src_masks, target_masks)

        return loss_objectness, loss_dice, loss_mask

    def forward(self, outputs, batch_gt_instances, batch_img_metas,
                batch_gt_instances_ignore):
        """Compute all SparseInst losses.

        Args:
            outputs (dict): Network outputs with at least
                'pred_logits', 'pred_masks' and 'pred_scores'.
            batch_gt_instances (list): Per-image GT instances with
                ``labels`` and ``masks`` fields.
            batch_img_metas (list): Unused here; kept for the mmdet
                loss-interface signature.
            batch_gt_instances_ignore: Unused here; kept for signature.

        Returns:
            dict: 'loss_cls', 'loss_obj', 'loss_dice', 'loss_mask'.
        """
        # Hungarian-style matching between predictions and GT.
        indices = self.matcher(outputs, batch_gt_instances)

        # Average GT count across workers so every rank normalizes by
        # the same value; clamp avoids division by zero on empty GT.
        num_instances = sum(gt.labels.shape[0] for gt in batch_gt_instances)
        num_instances = torch.as_tensor([num_instances],
                                        dtype=torch.float,
                                        device=next(iter(
                                            outputs.values())).device)
        num_instances = reduce_mean(num_instances).clamp_(min=1).item()

        loss_cls = self.loss_classification(outputs, batch_gt_instances,
                                            indices, num_instances)
        loss_obj, loss_dice, loss_mask = self.loss_masks_with_iou_objectness(
            outputs, batch_gt_instances, indices, num_instances)

        return dict(
            loss_cls=loss_cls,
            loss_obj=loss_obj,
            loss_dice=loss_dice,
            loss_mask=loss_mask)
|
|
|
|
|
@TASK_UTILS.register_module()
class SparseInstMatcher(nn.Module):
    """Hungarian matcher for SparseInst.

    The matching score between a prediction and a GT instance is
    ``dice_score ** alpha * class_prob ** beta``; a linear-sum
    assignment maximizing the total score pairs predictions with GT.

    Args:
        alpha (float): Exponent on the pairwise dice score.
            Defaults to 0.8.
        beta (float): Exponent on the predicted class probability.
            Defaults to 0.2.
    """

    def __init__(self, alpha=0.8, beta=0.2):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.mask_score = dice_score

    def forward(self, outputs, batch_gt_instances):
        """Match predictions to GT instances for a whole batch.

        Returns:
            list[tuple[Tensor, Tensor]]: One (pred_indices, gt_indices)
            pair per image.
        """
        # Matching is a discrete decision; never needs gradients.
        with torch.no_grad():
            B, N, H, W = outputs['pred_masks'].shape
            pred_masks = outputs['pred_masks']
            pred_logits = outputs['pred_logits'].sigmoid()
            device = pred_masks.device

            tgt_ids = torch.cat([gt.labels for gt in batch_gt_instances])

            # No GT anywhere in the batch: empty match per image.
            # NOTE(review): these empty tensors inherit pred_logits'
            # float dtype rather than int64 like the normal return path
            # — confirm downstream consumers tolerate that in the
            # empty-GT case.
            if tgt_ids.shape[0] == 0:
                return [(torch.as_tensor([]).to(pred_logits),
                         torch.as_tensor([]).to(pred_logits))] * B
            tgt_masks = torch.cat([
                gt.masks.to_tensor(dtype=pred_masks.dtype, device=device)
                for gt in batch_gt_instances
            ])

            # Resize GT masks to the prediction resolution.
            tgt_masks = F.interpolate(
                tgt_masks[:, None],
                size=pred_masks.shape[-2:],
                mode='bilinear',
                align_corners=False).squeeze(1)

            pred_masks = pred_masks.view(B * N, -1)
            tgt_masks = tgt_masks.flatten(1)
            # Compute scores in fp32 even under AMP for stability.
            with autocast(enabled=False):
                pred_masks = pred_masks.float()
                tgt_masks = tgt_masks.float()
                pred_logits = pred_logits.float()
                # (B*N, total_gt) pairwise dice scores.
                mask_score = self.mask_score(pred_masks, tgt_masks)

                # Probability each prediction assigns to each GT's class.
                matching_prob = pred_logits.view(B * N, -1)[:, tgt_ids]
                C = (mask_score**self.alpha) * (matching_prob**self.beta)

            # scipy assignment runs on CPU numpy.
            C = C.view(B, N, -1).cpu()

            # Split the GT axis per image, then solve one assignment per
            # image on its own (N, num_gt_i) score matrix (c[i] picks
            # image i's predictions from the batch axis).
            sizes = [len(gt.masks) for gt in batch_gt_instances]
            indices = [
                linear_sum_assignment(c[i], maximize=True)
                for i, c in enumerate(C.split(sizes, -1))
            ]
            indices = [(torch.as_tensor(i, dtype=torch.int64),
                        torch.as_tensor(j, dtype=torch.int64))
                       for i, j in indices]
            return indices
|
|