import torch
import torch.distributed as dist
import torch.nn.functional as F
from fvcore.nn import sigmoid_focal_loss_jit
from torch import nn
from torch.distributed import get_world_size
from torchvision import ops


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True
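
# Sketch of the intended use (this module imports get_world_size for it): average per-process
# box counts for distributed normalization, e.g.
#   world_size = get_world_size() if is_dist_avail_and_initialized() else 1
#   num_boxes = torch.clamp(num_boxes / world_size, min=1).item()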


def get_fed_loss_classes(gt_classes, num_fed_loss_classes, num_classes, weight):
    """
    Args:
        gt_classes: a long tensor of shape (R,) that contains the gt class label of each proposal.
        num_fed_loss_classes: minimum number of classes to keep when calculating the federated loss.
            Negative classes are sampled if the number of unique gt_classes is smaller than this value.
        num_classes: number of foreground classes.
        weight: probabilities used to sample negative classes.
    Returns:
        Tensor: classes to keep when calculating the federated loss, i.e. the unique gt classes
            plus the sampled negative classes.
    """
    unique_gt_classes = torch.unique(gt_classes)
    prob = unique_gt_classes.new_ones(num_classes + 1).float()
    prob[-1] = 0  # never sample the background slot
    if len(unique_gt_classes) < num_fed_loss_classes:
        prob[:num_classes] = weight.float().clone()
        prob[unique_gt_classes] = 0  # gt classes are kept unconditionally, do not resample them
        sampled_negative_classes = torch.multinomial(
            prob, num_fed_loss_classes - len(unique_gt_classes), replacement=False
        )
        fed_loss_classes = torch.cat([unique_gt_classes, sampled_negative_classes])
    else:
        fed_loss_classes = unique_gt_classes
    return fed_loss_classes
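
# Example (illustrative values): with gt_classes = tensor([3, 7, 3]), num_fed_loss_classes=50,
# num_classes=80 and a uniform `weight`, the result contains {3, 7} plus 48 negative classes
# drawn without replacement from the remaining foreground classes.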


class CriterionDynamicK(nn.Module):
    """This class computes the loss for DiffusionDet.

    The process happens in two steps:
        1) we compute a dynamic-k assignment between the ground-truth boxes and the outputs of the model;
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box).
    """

    def __init__(self, config, num_classes, weight_dict):
        """Create the criterion.

        Parameters:
            num_classes: number of object categories, omitting the special no-object category.
            weight_dict: dict containing as key the names of the losses and as values their relative weights.
        """
        super().__init__()
        self.config = config
        self.num_classes = num_classes
        self.matcher = HungarianMatcherDynamicK(config)
        self.weight_dict = weight_dict
        self.eos_coef = config.no_object_weight
        self.use_focal = config.use_focal
        self.use_fed_loss = config.use_fed_loss

        if self.use_focal:
            self.focal_loss_alpha = config.alpha
            self.focal_loss_gamma = config.gamma
        if self.use_fed_loss:
            # loss_labels() requires these two attributes when the federated loss is enabled.
            # The config field names below are assumptions; adapt them to the actual config schema.
            self.fed_loss_num_classes = config.fed_loss_num_classes
            self.fed_loss_cls_weights = config.fed_loss_cls_weights

    def loss_labels(self, outputs, targets, indices):
        """Classification loss (sigmoid focal loss or binary cross-entropy).

        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes].
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']  # (batch_size, num_queries, num_classes)
        batch_size = len(targets)

        # Every query starts as "no object" (class index num_classes); matched queries are
        # overwritten with their assigned gt label below.
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        src_logits_list = []
        target_classes_o_list = []

        for batch_idx in range(batch_size):
            valid_query = indices[batch_idx][0]
            gt_multi_idx = indices[batch_idx][1]
            if len(gt_multi_idx) == 0:
                continue
            bz_src_logits = src_logits[batch_idx]
            target_classes_o = targets[batch_idx]["labels"]
            target_classes[batch_idx, valid_query] = target_classes_o[gt_multi_idx]

            src_logits_list.append(bz_src_logits[valid_query])
            target_classes_o_list.append(target_classes_o[gt_multi_idx])

        if self.use_focal or self.use_fed_loss:
            # Fall back to 1 to avoid dividing by zero when no query was matched.
            num_boxes = torch.cat(target_classes_o_list).shape[0] if len(target_classes_o_list) != 0 else 1

            target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], self.num_classes + 1],
                                                dtype=src_logits.dtype, layout=src_logits.layout,
                                                device=src_logits.device)
            target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)

            gt_classes = torch.argmax(target_classes_onehot, dim=-1)
            # Drop the no-object column: classification is binary per foreground class.
            target_classes_onehot = target_classes_onehot[:, :, :-1]

            src_logits = src_logits.flatten(0, 1)
            target_classes_onehot = target_classes_onehot.flatten(0, 1)
            if self.use_focal:
                cls_loss = sigmoid_focal_loss_jit(src_logits, target_classes_onehot, alpha=self.focal_loss_alpha,
                                                  gamma=self.focal_loss_gamma, reduction="none")
            else:
                cls_loss = F.binary_cross_entropy_with_logits(src_logits, target_classes_onehot, reduction="none")
            if self.use_fed_loss:
                K = self.num_classes
                N = src_logits.shape[0]
                fed_loss_classes = get_fed_loss_classes(
                    gt_classes,
                    num_fed_loss_classes=self.fed_loss_num_classes,
                    num_classes=K,
                    weight=self.fed_loss_cls_weights,
                )
                fed_loss_classes_mask = fed_loss_classes.new_zeros(K + 1)
                fed_loss_classes_mask[fed_loss_classes] = 1
                fed_loss_classes_mask = fed_loss_classes_mask[:K]
                weight = fed_loss_classes_mask.view(1, K).expand(N, K).float()

                loss_ce = torch.sum(cls_loss * weight) / num_boxes
            else:
                loss_ce = torch.sum(cls_loss) / num_boxes

            losses = {'loss_ce': loss_ce}
        else:
            raise NotImplementedError

        return losses
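
    # Federated-loss weighting in a nutshell: with K foreground classes and, say,
    # fed_loss_classes = {3, 7, ...}, `weight` zeroes the per-class loss columns of every class
    # outside that set, so only gt classes and the sampled negatives contribute to loss_ce.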

    def loss_boxes(self, outputs, targets, indices):
        """Compute the losses related to the bounding boxes: the L1 regression loss and the GIoU loss.

        targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4].
        The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
        """
        assert 'pred_boxes' in outputs

        src_boxes = outputs['pred_boxes']

        batch_size = len(targets)
        pred_box_list = []
        pred_norm_box_list = []
        tgt_box_list = []
        tgt_box_xyxy_list = []
        for batch_idx in range(batch_size):
            valid_query = indices[batch_idx][0]
            gt_multi_idx = indices[batch_idx][1]
            if len(gt_multi_idx) == 0:
                continue
            bz_image_whwh = targets[batch_idx]['image_size_xyxy']
            bz_src_boxes = src_boxes[batch_idx]
            bz_target_boxes = targets[batch_idx]["boxes"]  # normalized (cx, cy, w, h)
            bz_target_boxes_xyxy = targets[batch_idx]["boxes_xyxy"]  # absolute (x1, y1, x2, y2)
            pred_box_list.append(bz_src_boxes[valid_query])
            pred_norm_box_list.append(bz_src_boxes[valid_query] / bz_image_whwh)
            tgt_box_list.append(bz_target_boxes[gt_multi_idx])
            tgt_box_xyxy_list.append(bz_target_boxes_xyxy[gt_multi_idx])

        if len(pred_box_list) != 0:
            src_boxes = torch.cat(pred_box_list)
            src_boxes_norm = torch.cat(pred_norm_box_list)
            target_boxes = torch.cat(tgt_box_list)
            target_boxes_abs_xyxy = torch.cat(tgt_box_xyxy_list)
            num_boxes = src_boxes.shape[0]

            losses = {}

            # L1 loss on normalized xyxy coordinates.
            loss_bbox = F.l1_loss(src_boxes_norm, ops.box_convert(target_boxes, 'cxcywh', 'xyxy'), reduction='none')
            losses['loss_bbox'] = loss_bbox.sum() / num_boxes

            # GIoU loss on absolute xyxy coordinates.
            loss_giou = 1 - torch.diag(ops.generalized_box_iou(src_boxes, target_boxes_abs_xyxy))
            losses['loss_giou'] = loss_giou.sum() / num_boxes
        else:
            # No matched boxes: return zero losses that stay connected to the graph.
            losses = {'loss_bbox': outputs['pred_boxes'].sum() * 0,
                      'loss_giou': outputs['pred_boxes'].sum() * 0}

        return losses
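
    # Design note: the L1 term compares boxes normalized by image size, while the GIoU term uses
    # absolute coordinates; GIoU is scale-invariant by construction, so normalization is only
    # needed for the L1 distance.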

    def get_loss(self, loss, outputs, targets, indices):
        loss_map = {
            'labels': self.loss_labels,
            'boxes': self.loss_boxes,
        }
        assert loss in loss_map, f'Unknown loss: {loss}'
        return loss_map[loss](outputs, targets, indices)

    def forward(self, outputs, targets):
        """This performs the loss computation.

        Parameters:
            outputs: dict of tensors, see the output specification of the model for the format.
            targets: list of dicts, such that len(targets) == batch_size.
                The expected keys in each dict depend on the losses applied; this criterion uses
                "labels", "boxes" (normalized cxcywh), "boxes_xyxy" (absolute xyxy),
                "image_size_xyxy" and "image_size_xyxy_tgt".
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}

        # Retrieve the matching between the outputs of the last layer and the targets.
        indices, _ = self.matcher(outputs_without_aux, targets)

        # Compute all the requested losses.
        losses = {}
        for loss in ["labels", "boxes"]:
            losses.update(self.get_loss(loss, outputs, targets, indices))

        # In case of auxiliary losses, we repeat this process with the output of each
        # intermediate layer and suffix the loss names with the layer index.
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                indices, _ = self.matcher(aux_outputs, targets)
                for loss in ["labels", "boxes"]:
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices)
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses
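
    # Example of the returned keys: with two aux heads the dict holds loss_ce / loss_bbox /
    # loss_giou for the final layer plus loss_ce_0, ..., loss_giou_1 for the intermediate ones;
    # self.weight_dict is stored for the caller to apply.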


def get_in_boxes_info(boxes, target_gts):
    """Flag which predicted boxes have their center inside a gt box and/or near a gt center.

    Args:
        boxes: (num_queries, 4) predicted boxes in (cx, cy, w, h) format.
        target_gts: (num_gt, 4) gt boxes in (cx, cy, w, h) format.
    Returns:
        is_in_boxes_anchor: (num_queries,) bool, center inside any gt box or any gt center region.
        is_in_boxes_and_center: (num_queries, num_gt) bool, center inside both the gt box and
            its center region.
    """
    xy_target_gts = ops.box_convert(target_gts, 'cxcywh', 'xyxy')

    anchor_center_x = boxes[:, 0].unsqueeze(1)
    anchor_center_y = boxes[:, 1].unsqueeze(1)

    # Is the predicted center inside the gt box?
    b_l = anchor_center_x > xy_target_gts[:, 0].unsqueeze(0)
    b_r = anchor_center_x < xy_target_gts[:, 2].unsqueeze(0)
    b_t = anchor_center_y > xy_target_gts[:, 1].unsqueeze(0)
    b_b = anchor_center_y < xy_target_gts[:, 3].unsqueeze(0)

    is_in_boxes = ((b_l.long() + b_r.long() + b_t.long() + b_b.long()) == 4)
    is_in_boxes_all = is_in_boxes.sum(1) > 0  # inside at least one gt box

    center_radius = 2.5

    # Is the predicted center inside a window of center_radius * box extent around the gt center?
    b_l = anchor_center_x > (
            target_gts[:, 0] - (center_radius * (xy_target_gts[:, 2] - xy_target_gts[:, 0]))).unsqueeze(0)
    b_r = anchor_center_x < (
            target_gts[:, 0] + (center_radius * (xy_target_gts[:, 2] - xy_target_gts[:, 0]))).unsqueeze(0)
    b_t = anchor_center_y > (
            target_gts[:, 1] - (center_radius * (xy_target_gts[:, 3] - xy_target_gts[:, 1]))).unsqueeze(0)
    b_b = anchor_center_y < (
            target_gts[:, 1] + (center_radius * (xy_target_gts[:, 3] - xy_target_gts[:, 1]))).unsqueeze(0)

    is_in_centers = ((b_l.long() + b_r.long() + b_t.long() + b_b.long()) == 4)
    is_in_centers_all = is_in_centers.sum(1) > 0

    is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
    is_in_boxes_and_center = (is_in_boxes & is_in_centers)

    return is_in_boxes_anchor, is_in_boxes_and_center


class HungarianMatcherDynamicK(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network.

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-k (dynamic) matching of the
    best predictions, while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(self, config):
        super().__init__()
        self.use_focal = config.use_focal
        self.use_fed_loss = config.use_fed_loss
        self.cost_class = config.class_weight
        self.cost_giou = config.giou_weight
        self.cost_bbox = config.l1_weight
        self.ota_k = config.ota_k

        if self.use_focal:
            self.focal_loss_alpha = config.alpha
            self.focal_loss_gamma = config.gamma

        assert self.cost_class != 0 or self.cost_bbox != 0 or self.cost_giou != 0, "all costs can't be 0"

    def forward(self, outputs, targets):
        """SimOTA-style dynamic-k matching, as used in DiffusionDet."""
        with torch.no_grad():
            bs, num_queries = outputs["pred_logits"].shape[:2]

            if self.use_focal or self.use_fed_loss:
                out_prob = outputs["pred_logits"].sigmoid()  # per-class binary probabilities
                out_bbox = outputs["pred_boxes"]
            else:
                out_prob = outputs["pred_logits"].softmax(-1)
                out_bbox = outputs["pred_boxes"]

            indices = []
            matched_ids = []
            assert bs == len(targets)
            for batch_idx in range(bs):
                bz_boxes = out_bbox[batch_idx]
                bz_out_prob = out_prob[batch_idx]
                bz_tgt_ids = targets[batch_idx]["labels"]
                num_insts = len(bz_tgt_ids)
                if num_insts == 0:  # empty image: no query is matched
                    non_valid = torch.zeros(bz_out_prob.shape[0], dtype=torch.bool, device=bz_out_prob.device)
                    indices_batchi = (non_valid, torch.empty(0, dtype=torch.long, device=bz_out_prob.device))
                    matched_qidx = torch.empty(0, dtype=torch.long, device=bz_out_prob.device)
                    indices.append(indices_batchi)
                    matched_ids.append(matched_qidx)
                    continue

                bz_gtboxs = targets[batch_idx]['boxes']  # normalized (cx, cy, w, h)
                bz_gtboxs_abs_xyxy = targets[batch_idx]['boxes_xyxy']
                fg_mask, is_in_boxes_and_center = get_in_boxes_info(
                    ops.box_convert(bz_boxes, 'xyxy', 'cxcywh'),
                    ops.box_convert(bz_gtboxs_abs_xyxy, 'xyxy', 'cxcywh')
                )

                pair_wise_ious = ops.box_iou(bz_boxes, bz_gtboxs_abs_xyxy)

                # Classification cost.
                if self.use_focal:
                    alpha = self.focal_loss_alpha
                    gamma = self.focal_loss_gamma
                    neg_cost_class = (1 - alpha) * (bz_out_prob ** gamma) * (-(1 - bz_out_prob + 1e-8).log())
                    pos_cost_class = alpha * ((1 - bz_out_prob) ** gamma) * (-(bz_out_prob + 1e-8).log())
                    cost_class = pos_cost_class[:, bz_tgt_ids] - neg_cost_class[:, bz_tgt_ids]
                elif self.use_fed_loss:
                    # Binary cross-entropy cost without focal re-weighting.
                    neg_cost_class = (-(1 - bz_out_prob + 1e-8).log())
                    pos_cost_class = (-(bz_out_prob + 1e-8).log())
                    cost_class = pos_cost_class[:, bz_tgt_ids] - neg_cost_class[:, bz_tgt_ids]
                else:
                    cost_class = -bz_out_prob[:, bz_tgt_ids]

                # L1 cost on boxes normalized by image size.
                bz_image_size_out = targets[batch_idx]['image_size_xyxy']
                bz_image_size_tgt = targets[batch_idx]['image_size_xyxy_tgt']

                bz_out_bbox_ = bz_boxes / bz_image_size_out
                bz_tgt_bbox_ = bz_gtboxs_abs_xyxy / bz_image_size_tgt
                cost_bbox = torch.cdist(bz_out_bbox_, bz_tgt_bbox_, p=1)

                cost_giou = -ops.generalized_box_iou(bz_boxes, bz_gtboxs_abs_xyxy)

                # Total cost, with a large penalty for predictions whose center is not inside
                # both the gt box and its center region.
                cost = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou + 100.0 * (
                    ~is_in_boxes_and_center)

                # Predictions outside every gt box / center region are effectively excluded.
                cost[~fg_mask] = cost[~fg_mask] + 10000.0

                indices_batchi, matched_qidx = self.dynamic_k_matching(cost, pair_wise_ious, bz_gtboxs.shape[0])

                indices.append(indices_batchi)
                matched_ids.append(matched_qidx)

        return indices, matched_ids
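
    # Each element of `indices` is a pair (selected_query, gt_indices): a bool mask over the
    # queries plus, for every selected query, the index of its assigned gt box. `matched_ids`
    # additionally records, per gt box, its single cheapest query (see dynamic_k_matching).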

    def dynamic_k_matching(self, cost, pair_wise_ious, num_gt):
        matching_matrix = torch.zeros_like(cost)
        ious_in_boxes_matrix = pair_wise_ious
        n_candidate_k = self.ota_k

        # The dynamic k for each gt is the (clamped) sum of its top-k IoUs with the predictions.
        topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=0)
        dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1)

        for gt_idx in range(num_gt):
            _, pos_idx = torch.topk(cost[:, gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)
            matching_matrix[:, gt_idx][pos_idx] = 1.0

        del topk_ious, dynamic_ks, pos_idx

        anchor_matching_gt = matching_matrix.sum(1)

        # A query matched to several gts keeps only the gt with the minimal cost.
        if (anchor_matching_gt > 1).sum() > 0:
            _, cost_argmin = torch.min(cost[anchor_matching_gt > 1], dim=1)
            matching_matrix[anchor_matching_gt > 1] *= 0
            matching_matrix[anchor_matching_gt > 1, cost_argmin] = 1

        # Every gt must keep at least one query: re-match unassigned gts to their cheapest
        # (preferably still unmatched) query until none is left without a match.
        while (matching_matrix.sum(0) == 0).any():
            matched_query_id = matching_matrix.sum(1) > 0
            cost[matched_query_id] += 100000.0
            unmatch_id = torch.nonzero(matching_matrix.sum(0) == 0, as_tuple=False).squeeze(1)
            for gt_idx in unmatch_id:
                pos_idx = torch.argmin(cost[:, gt_idx])
                matching_matrix[:, gt_idx][pos_idx] = 1.0
            if (matching_matrix.sum(1) > 1).sum() > 0:
                # Recompute the per-query match counts; the pre-loop `anchor_matching_gt`
                # would be stale here.
                anchor_matching_gt = matching_matrix.sum(1)
                _, cost_argmin = torch.min(cost[anchor_matching_gt > 1], dim=1)
                matching_matrix[anchor_matching_gt > 1] *= 0
                matching_matrix[anchor_matching_gt > 1, cost_argmin] = 1

        assert not (matching_matrix.sum(0) == 0).any()
        selected_query = matching_matrix.sum(1) > 0
        gt_indices = matching_matrix[selected_query].max(1)[1]
        assert selected_query.sum() == len(gt_indices)

        cost[matching_matrix == 0] = cost[matching_matrix == 0] + float('inf')
        matched_query_id = torch.min(cost, dim=0)[1]

        return (selected_query, gt_indices), matched_query_id
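

# Minimal smoke test: a sketch, not part of the training pipeline. The config below is a
# stand-in built from the attribute names this module actually reads (use_focal, alpha, gamma,
# class_weight, giou_weight, l1_weight, ota_k, no_object_weight); the real project config may
# differ. Shapes and values are illustrative only.
if __name__ == "__main__":
    from types import SimpleNamespace

    torch.manual_seed(0)
    num_classes, num_queries = 5, 20
    config = SimpleNamespace(
        use_focal=True, use_fed_loss=False, no_object_weight=0.1,
        alpha=0.25, gamma=2.0,
        class_weight=2.0, giou_weight=2.0, l1_weight=5.0, ota_k=5,
    )
    criterion = CriterionDynamicK(config, num_classes,
                                  weight_dict={'loss_ce': 2.0, 'loss_bbox': 5.0, 'loss_giou': 2.0})

    # One image of size 640x480 with two gt boxes.
    w, h = 640.0, 480.0
    image_size = torch.tensor([w, h, w, h])
    boxes_xyxy = torch.tensor([[100.0, 100.0, 200.0, 250.0], [300.0, 200.0, 400.0, 300.0]])
    targets = [{
        'labels': torch.tensor([1, 3]),
        'boxes': ops.box_convert(boxes_xyxy, 'xyxy', 'cxcywh') / image_size,
        'boxes_xyxy': boxes_xyxy,
        'image_size_xyxy': image_size,
        'image_size_xyxy_tgt': image_size.unsqueeze(0).repeat(2, 1),
    }]

    # Random absolute-xyxy predictions with strictly positive widths and heights.
    rand_cxcywh = torch.rand(num_queries, 4) * torch.tensor([w, h, w / 4, h / 4]) \
                  + torch.tensor([0.0, 0.0, 2.0, 2.0])
    outputs = {
        'pred_logits': torch.randn(1, num_queries, num_classes),
        'pred_boxes': ops.box_convert(rand_cxcywh, 'cxcywh', 'xyxy').unsqueeze(0),
    }
    print({k: round(float(v), 4) for k, v in criterion(outputs, targets).items()})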