Spaces:
Running
on
Zero
Running
on
Zero
""" | |
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
""" | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torchvision | |
from ...core import register | |
__all__ = ["DFINEPostProcessor"] | |
def mod(a, b):
    """Return the remainder of ``a`` divided by ``b`` without using ``%``.

    The result is computed from floor division as ``a - (a // b) * b``, so it
    works for any operands that support ``//``, ``*`` and ``-``.
    """
    quotient = a // b
    return a - quotient * b
class DFINEPostProcessor(nn.Module):
    """Turn raw detector outputs into labels, boxes and scores.

    The module reads ``pred_logits`` and ``pred_boxes`` from the model output,
    converts the boxes from ``cxcywh`` to ``xyxy``, scales them by the
    supplied target sizes and keeps at most ``num_top_queries`` detections
    per batch element.

    Args:
        num_classes: Number of classes; used to split a flattened
            (query, class) index back into a query index and a class label.
        use_focal_loss: If true, score with a per-class sigmoid and take the
            top-k over all (query, class) pairs. Otherwise score with a
            softmax over classes, dropping the last class column.
        num_top_queries: Maximum number of detections kept per batch element.
        remap_mscoco_category: If true, map labels through
            ``mscoco_label2category`` before returning (non-deploy mode only).
    """

    __share__ = ["num_classes", "use_focal_loss", "num_top_queries", "remap_mscoco_category"]

    def __init__(
        self, num_classes=80, use_focal_loss=True, num_top_queries=300, remap_mscoco_category=False
    ) -> None:
        super().__init__()
        self.use_focal_loss = use_focal_loss
        self.num_top_queries = num_top_queries
        self.num_classes = int(num_classes)
        self.remap_mscoco_category = remap_mscoco_category
        # Switched on by ``deploy()``; makes ``forward`` return plain tensors.
        self.deploy_mode = False

    def extra_repr(self) -> str:
        """Return the configuration summary shown in the module's repr."""
        return f"use_focal_loss={self.use_focal_loss}, num_classes={self.num_classes}, num_top_queries={self.num_top_queries}"

    def forward(self, outputs, orig_target_sizes: torch.Tensor):
        """Post-process one batch of predictions.

        Args:
            outputs: Mapping with ``"pred_logits"`` and ``"pred_boxes"``.
                Boxes are interpreted as ``cxcywh``.
            orig_target_sizes: 2-D tensor of per-image sizes. It is repeated
                twice along the last dim and broadcast over the queries to
                scale the box coordinates (assumes shape ``(batch, 2)`` in
                (width, height) order -- TODO confirm against the caller).

        Returns:
            In deploy mode, the tuple ``(labels, boxes, scores)`` of batched
            tensors. Otherwise a list with one dict per batch element holding
            the keys ``labels``, ``boxes`` and ``scores``.
        """
        logits, boxes = outputs["pred_logits"], outputs["pred_boxes"]

        # Scaled xyxy boxes. ``box_convert`` returns a new tensor, so the
        # in-place multiply does not modify ``outputs["pred_boxes"]``.
        bbox_pred = torchvision.ops.box_convert(boxes, in_fmt="cxcywh", out_fmt="xyxy")
        bbox_pred *= orig_target_sizes.repeat(1, 2).unsqueeze(1)

        if self.use_focal_loss:
            scores = F.sigmoid(logits)
            # Top-k over every (query, class) pair at once.
            scores, index = torch.topk(scores.flatten(1), self.num_top_queries, dim=-1)
            # TODO for older tensorrt: ``mod`` is used instead of the ``%`` operator.
            labels = mod(index, self.num_classes)
            index = index // self.num_classes
            boxes = bbox_pred.gather(
                dim=1, index=index.unsqueeze(-1).repeat(1, 1, bbox_pred.shape[-1])
            )
        else:
            # Normalize over the class dimension explicitly: the implicit-dim
            # fallback of ``F.softmax`` would pick dim 0 for a 3-D input.
            scores = F.softmax(logits, dim=-1)[:, :, :-1]
            scores, labels = scores.max(dim=-1)
            # Start from the scaled xyxy boxes so both scoring branches return
            # boxes in the same format, whether or not top-k filtering runs.
            boxes = bbox_pred
            if scores.shape[1] > self.num_top_queries:
                scores, index = torch.topk(scores, self.num_top_queries, dim=-1)
                labels = torch.gather(labels, dim=1, index=index)
                boxes = torch.gather(
                    boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1])
                )

        # TODO for onnx export
        if self.deploy_mode:
            return labels, boxes, scores

        # TODO
        if self.remap_mscoco_category:
            from ...data.dataset import mscoco_label2category

            # One ``tolist()`` instead of a per-element ``.item()`` call.
            remapped = [mscoco_label2category[lab] for lab in labels.flatten().tolist()]
            labels = torch.tensor(remapped, dtype=labels.dtype, device=boxes.device).reshape(
                labels.shape
            )

        return [
            {"labels": lab, "boxes": box, "scores": sco}
            for lab, box, sco in zip(labels, boxes, scores)
        ]

    def deploy(
        self,
    ):
        """Switch to eval mode, enable deploy-mode output and return ``self``."""
        self.eval()
        self.deploy_mode = True
        return self