|
"""
|
|
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
|
|
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
|
|
"""
|
|
|
|
import torch
|
|
import torch.distributed
|
|
import torch.nn.functional as F
|
|
import torchvision
|
|
|
|
from ...core import register
|
|
from ...misc import box_ops, dist_utils
|
|
|
|
|
|
@register()
class DetCriterion(torch.nn.Module):
    """Default detection criterion.

    Runs the injected ``matcher`` on the model outputs and targets, then
    computes every loss named in ``losses`` over the matched pairs and scales
    each returned term by its entry in ``weight_dict``.
    """

    # NOTE(review): presumably consumed by the project's `register` machinery
    # (defined outside this file) to share / inject config values -- confirm.
    __share__ = ["num_classes"]
    __inject__ = ["matcher"]

    def __init__(
        self,
        losses,
        weight_dict,
        num_classes=80,
        alpha=0.75,
        gamma=2.0,
        box_fmt="cxcywh",
        matcher=None,
    ):
        """
        Args:
            losses (list[str]): requested losses, support
                ['boxes', 'giou', 'vfl', 'focal'] (the keys of `get_loss`)
            weight_dict (dict[str, float]): corresponding losses weight, including
                ['loss_bbox', 'loss_giou', 'loss_vfl', 'loss_focal'];
                loss terms without an entry here are dropped by `forward`
            num_classes (int): number of object classes; also used as the
                class index assigned to unmatched predictions
            alpha (float): alpha passed to the focal loss and used as the
                negative-sample weight factor in the VFL loss
            gamma (float): gamma passed to the focal loss and used as the
                score exponent in the VFL loss
            box_fmt (str): in box format, 'cxcywh' or 'xyxy'
            matcher (Matcher): matcher used to match source to target

        Raises:
            AssertionError: if ``matcher`` is None.
        """
        super().__init__()
        self.losses = losses
        self.weight_dict = weight_dict
        self.alpha = alpha
        self.gamma = gamma
        self.num_classes = num_classes
        self.box_fmt = box_fmt
        assert matcher is not None, "DetCriterion requires a matcher, got None"
        self.matcher = matcher

    def forward(self, outputs, targets, **kwargs):
        """
        Args:
            outputs: Dict[Tensor], 'pred_boxes', 'pred_logits', 'meta'.
            targets, List[Dict[str, Tensor]], len(targets) == batch_size.
            kwargs, store other information such as current epoch id
                (accepted but not read by this method).
        Return:
            losses, Dict[str, Tensor]: weighted loss terms; only keys that are
                present in ``self.weight_dict`` are kept.
        """
        matched = self.matcher(outputs, targets)
        indices = matched["indices"]
        num_boxes = self._get_positive_nums(indices)

        losses = {}
        for loss in self.losses:
            l_dict = self.get_loss(loss, outputs, targets, indices, num_boxes)
            # Scale each term by its configured weight; unweighted terms are skipped.
            losses.update(
                {k: v * self.weight_dict[k] for k, v in l_dict.items() if k in self.weight_dict}
            )
        return losses

    def _get_src_permutation_idx(self, indices):
        """Flatten per-sample matched prediction indices into (batch_idx, src_idx)."""
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        """Flatten per-sample matched target indices into (batch_idx, tgt_idx)."""
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def _get_positive_nums(self, indices):
        """Return the loss normalizer: the number of matched pairs, at least 1.

        The local count is all-reduced across processes when distributed
        training is initialized, then divided by the world size. The result is
        a Python float (``.item()``), clamped to a minimum of 1 so that the
        losses never divide by zero.
        """
        num_pos = sum(len(i) for (i, _) in indices)
        num_pos = torch.as_tensor([num_pos], dtype=torch.float32, device=indices[0][0].device)
        if dist_utils.is_dist_available_and_initialized():
            torch.distributed.all_reduce(num_pos)
        num_pos = torch.clamp(num_pos / dist_utils.get_world_size(), min=1).item()
        return num_pos

    def _to_xyxy(self, boxes):
        """Convert boxes from ``self.box_fmt`` to 'xyxy'."""
        return torchvision.ops.box_convert(boxes, in_fmt=self.box_fmt, out_fmt="xyxy")

    def _get_matched_boxes(self, outputs, targets, indices):
        """Gather matched predicted / target boxes.

        Returns:
            (idx, src_boxes, target_boxes): ``idx`` is the (batch_idx, src_idx)
            pair used to index the predictions; the two box tensors are aligned
            row by row and are still in ``self.box_fmt``.
        """
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs["pred_boxes"][idx]
        target_boxes = torch.cat([t["boxes"][j] for t, (_, j) in zip(targets, indices)], dim=0)
        return idx, src_boxes, target_boxes

    def _get_onehot_targets(self, src_logits, targets, indices, idx):
        """Build per-prediction class targets.

        Every prediction starts with class index ``self.num_classes``; matched
        predictions are overwritten with their target label. The extra one-hot
        column for that index is sliced off, so unmatched predictions end up
        with an all-zero row.

        Returns:
            (target_classes, target): int64 class indices shaped like
            ``src_logits.shape[:2]`` and the int64 one-hot encoding with
            ``self.num_classes`` columns.
        """
        target_classes_o = torch.cat([t["labels"][j] for t, (_, j) in zip(targets, indices)])
        target_classes = torch.full(
            src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
        )
        target_classes[idx] = target_classes_o
        target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1]
        return target_classes, target

    def _giou_loss(self, src_boxes, target_boxes, num_boxes):
        """Return ``sum(1 - GIoU) / num_boxes`` for row-aligned boxes in ``self.box_fmt``."""
        loss_giou = 1 - box_ops.elementwise_generalized_box_iou(
            self._to_xyxy(src_boxes), self._to_xyxy(target_boxes)
        )
        return loss_giou.sum() / num_boxes

    def loss_labels_focal(self, outputs, targets, indices, num_boxes):
        """Sigmoid focal classification loss, summed and divided by ``num_boxes``.

        Returns:
            dict with the single key 'loss_focal'.
        """
        assert "pred_logits" in outputs
        src_logits = outputs["pred_logits"]

        idx = self._get_src_permutation_idx(indices)
        _, target = self._get_onehot_targets(src_logits, targets, indices, idx)
        target = target.to(src_logits.dtype)

        loss = torchvision.ops.sigmoid_focal_loss(
            src_logits, target, self.alpha, self.gamma, reduction="none"
        )
        loss = loss.sum() / num_boxes
        return {"loss_focal": loss}

    def loss_labels_vfl(self, outputs, targets, indices, num_boxes):
        """VFL classification loss, summed and divided by ``num_boxes``.

        The target score of a matched prediction is the first value returned
        by ``box_ops.elementwise_box_iou`` for its box pair, placed on the
        matched class; all other entries are zero. Entries outside the target
        class are weighted by ``alpha * sigmoid(logit) ** gamma``, the target
        class entry by its target score.

        Returns:
            dict with the single key 'loss_vfl'.
        """
        assert "pred_boxes" in outputs
        assert "pred_logits" in outputs

        idx, src_boxes, target_boxes = self._get_matched_boxes(outputs, targets, indices)
        # The overlap only serves as a soft label, so the predicted boxes are detached.
        iou, _ = box_ops.elementwise_box_iou(
            self._to_xyxy(src_boxes).detach(), self._to_xyxy(target_boxes)
        )

        src_logits: torch.Tensor = outputs["pred_logits"]
        target_classes, target = self._get_onehot_targets(src_logits, targets, indices, idx)

        target_score_o = torch.zeros_like(target_classes, dtype=src_logits.dtype)
        target_score_o[idx] = iou.to(src_logits.dtype)
        target_score = target_score_o.unsqueeze(-1) * target

        # The weight is treated as a constant: it is built from detached logits.
        src_score = F.sigmoid(src_logits.detach())
        weight = self.alpha * src_score.pow(self.gamma) * (1 - target) + target_score

        loss = F.binary_cross_entropy_with_logits(
            src_logits, target_score, weight=weight, reduction="none"
        )
        loss = loss.sum() / num_boxes
        return {"loss_vfl": loss}

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """L1 and GIoU box regression losses over the matched pairs.

        The L1 term is computed on the boxes as given (``self.box_fmt``); the
        GIoU term is computed after converting both sides to 'xyxy'.

        Returns:
            dict with the keys 'loss_bbox' and 'loss_giou'.
        """
        assert "pred_boxes" in outputs
        _, src_boxes, target_boxes = self._get_matched_boxes(outputs, targets, indices)

        losses = {}
        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")
        losses["loss_bbox"] = loss_bbox.sum() / num_boxes
        losses["loss_giou"] = self._giou_loss(src_boxes, target_boxes, num_boxes)
        return losses

    def loss_boxes_giou(self, outputs, targets, indices, num_boxes):
        """GIoU box regression loss only (no L1 term).

        Returns:
            dict with the single key 'loss_giou'.
        """
        assert "pred_boxes" in outputs
        _, src_boxes, target_boxes = self._get_matched_boxes(outputs, targets, indices)
        return {"loss_giou": self._giou_loss(src_boxes, target_boxes, num_boxes)}

    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        """Dispatch to the loss function registered under ``loss``.

        Args:
            loss (str): one of 'boxes', 'giou', 'vfl', 'focal'.
            outputs, targets, indices, num_boxes: forwarded unchanged.
            kwargs: forwarded verbatim to the selected loss function. The loss
                functions defined on this class take no extra keyword
                arguments, so passing any raises ``TypeError``.

        Returns:
            Dict[str, Tensor]: the unweighted loss terms.

        Raises:
            AssertionError: if ``loss`` is not a supported name.
        """
        loss_map = {
            "boxes": self.loss_boxes,
            "giou": self.loss_boxes_giou,
            "vfl": self.loss_labels_vfl,
            "focal": self.loss_labels_focal,
        }
        assert loss in loss_map, f"do you really want to compute {loss} loss?"
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
|
|
|