|
"""
|
|
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
|
|
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
|
|
"""
|
|
|
|
from typing import Dict
|
|
|
|
import torch
|
|
import torch.distributed
|
|
import torch.nn.functional as F
|
|
import torchvision
|
|
from torch import Tensor
|
|
|
|
from ...core import register
|
|
|
|
# Public API of this module: only the NMS post-processor is exported.
__all__ = ["DetNMSPostProcessor"]
|
|
|
|
|
|
@register()
class DetNMSPostProcessor(torch.nn.Module):
    """Turn raw detector outputs into final detections using class-aware NMS.

    For every query the best-scoring class is selected, boxes are converted
    to ``xyxy`` and multiplied by the per-image sizes, low-score candidates
    are dropped, and ``torchvision.ops.batched_nms`` is run with the
    predicted labels as the grouping index.

    Args:
        iou_threshold: IoU threshold passed to ``batched_nms``.
        score_threshold: Candidates whose score is not strictly greater
            than this value are discarded before NMS.
        keep_topk: Maximum number of detections kept per image after NMS.
        box_fmt: Format of the incoming boxes, forwarded (lower-cased) as
            ``in_fmt`` to ``torchvision.ops.box_convert``.
        logit_fmt: Name of a function in ``torch.nn.functional`` used to map
            logits to scores (e.g. ``"sigmoid"``). If no function with that
            name exists, raw logits are used as scores.
    """

    # Activations that normalise across the class axis. They must see the
    # full logits (with an explicit ``dim``) rather than the per-query maxima.
    _CLASS_AXIS_FMTS = ("softmax", "log_softmax")

    def __init__(
        self,
        iou_threshold=0.7,
        score_threshold=0.01,
        keep_topk=300,
        box_fmt="cxcywh",
        logit_fmt="sigmoid",
    ) -> None:
        super().__init__()
        self.iou_threshold = iou_threshold
        self.score_threshold = score_threshold
        self.keep_topk = keep_topk
        self.box_fmt = box_fmt.lower()
        self.logit_fmt = logit_fmt.lower()
        # Resolved by name; ``None`` means "use the logits unchanged".
        self.logit_func = getattr(F, self.logit_fmt, None)
        self.deploy_mode = False

    def forward(self, outputs: Dict[str, Tensor], orig_target_sizes: Tensor):
        """Post-process one batch of predictions.

        Args:
            outputs: Mapping that provides ``"pred_logits"`` and
                ``"pred_boxes"``. Assumed shapes are ``(batch, queries,
                classes)`` and ``(batch, queries, 4)`` -- TODO confirm
                against the model head.
            orig_target_sizes: Per-image scale factors. The
                ``repeat(1, 2).unsqueeze(1)`` below requires a 2-D tensor;
                presumably ``(batch, 2)`` holding ``(width, height)`` so that
                it lines up with ``(x1, y1, x2, y2)`` -- verify with caller.

        Returns:
            In deploy mode, a dict with ``"pred_labels"``, ``"pred_boxes"``
            and ``"pred_scores"`` for the whole batch (no thresholding, no
            NMS). Otherwise a list with one dict per image containing
            ``"labels"``, ``"boxes"`` and ``"scores"`` of the kept detections.
        """
        logits, boxes = outputs["pred_logits"], outputs["pred_boxes"]

        pred_boxes = torchvision.ops.box_convert(boxes, in_fmt=self.box_fmt, out_fmt="xyxy")
        # Scale every coordinate pair by the matching image size.
        pred_boxes *= orig_target_sizes.repeat(1, 2).unsqueeze(1)

        if self.logit_fmt in self._CLASS_AXIS_FMTS:
            # softmax / log_softmax are defined over the class axis, so they
            # have to run before the max; applying them to the already-reduced
            # maxima would normalise over the wrong axis. Both are monotonic
            # per row, hence the selected labels match those of the raw logits.
            # NOTE(review): if the last class is a "no object" slot it is not
            # excluded here -- confirm against the training setup.
            pred_scores, pred_labels = torch.max(self.logit_func(logits, dim=-1), dim=-1)
        else:
            values, pred_labels = torch.max(logits, dim=-1)
            if self.logit_func is not None:
                pred_scores = self.logit_func(values)
            else:
                pred_scores = values

        if self.deploy_mode:
            # Export path: return dense tensors and leave filtering/NMS to
            # the consumer.
            return {
                "pred_labels": pred_labels,
                "pred_boxes": pred_boxes,
                "pred_scores": pred_scores,
            }

        results = []
        for img_boxes, img_labels, img_scores in zip(pred_boxes, pred_labels, pred_scores):
            confident = img_scores > self.score_threshold
            cand_boxes = img_boxes[confident]
            cand_labels = img_labels[confident]
            cand_scores = img_scores[confident]

            # Labels act as group ids, so boxes of different classes never
            # suppress each other.
            keep = torchvision.ops.batched_nms(cand_boxes, cand_scores, cand_labels, self.iou_threshold)
            keep = keep[: self.keep_topk]

            results.append(
                {
                    "labels": cand_labels[keep],
                    "boxes": cand_boxes[keep],
                    "scores": cand_scores[keep],
                }
            )

        return results

    def deploy(
        self,
    ):
        """Switch to eval mode, enable the dense deploy output, return self."""
        self.eval()
        self.deploy_mode = True
        return self
|
|
|