|
|
|
from unittest import TestCase |
|
|
|
import torch |
|
from mmengine.structures import InstanceData |
|
|
|
from mmdet.registry import MODELS |
|
from mmdet.structures import DetDataSample |
|
from mmdet.testing import get_detector_cfg |
|
from mmdet.utils import register_all_modules |
|
|
|
|
|
class TestConditionalDETR(TestCase):
    """Smoke tests for the Conditional DETR detector built from its config."""

    def setUp(self) -> None:
        # Populate the mmdet registries so ``MODELS.build`` can resolve the
        # detector described by the config file.
        register_all_modules()

    @staticmethod
    def _make_data_sample(img_size: int, bboxes: torch.Tensor,
                          labels: torch.Tensor):
        """Build one fresh ``DetDataSample`` with square meta info and GT.

        Args:
            img_size: Side length used for every shape entry of the meta info.
            bboxes: Ground-truth boxes, one row of 4 coordinates per object
                (presumably pixel x1, y1, x2, y2 -- confirm against mmdet).
            labels: Ground-truth class indices, one per box.

        Returns:
            A new ``DetDataSample`` whose ``gt_instances`` holds ``bboxes``
            and ``labels``. A new object is returned on every call so test
            cases never share (and mutate) the same sample.
        """
        data_sample = DetDataSample()
        data_sample.set_metainfo({
            'img_shape': (img_size, img_size),
            'scale_factor': (1, 1),
            'pad_shape': (img_size, img_size),
            'batch_input_shape': (img_size, img_size)
        })
        gt_instances = InstanceData()
        gt_instances.bboxes = bboxes
        gt_instances.labels = labels
        data_sample.gt_instances = gt_instances
        return data_sample

    def test_conditional_detr_head_loss(self):
        """Tests transformer head loss when truth is empty and non-empty."""
        img_size = 256

        config = get_detector_cfg(
            'conditional_detr/conditional-detr_r50_8xb2-50e_coco.py')
        model = MODELS.build(config)
        model.init_weights()
        random_image = torch.rand(1, 3, img_size, img_size)

        # Case 1: an image without any ground-truth objects. The test expects
        # classification losses to be positive and the box / IoU losses to be
        # exactly zero.
        empty_gt_sample = self._make_data_sample(
            img_size, torch.empty((0, 4)), torch.LongTensor([]))
        empty_gt_losses = model.loss(
            random_image, batch_data_samples=[empty_gt_sample])
        for key, loss in empty_gt_losses.items():
            if 'cls' in key:
                self.assertGreater(loss.item(), 0,
                                   'cls loss should be non-zero')
            elif 'bbox' in key:
                self.assertEqual(
                    loss.item(), 0,
                    'there should be no box loss when no ground truth boxes')
            elif 'iou' in key:
                self.assertEqual(
                    loss.item(), 0,
                    'there should be no iou loss when there are no true boxes')

        # Case 2: a single ground-truth box; every reported loss term is
        # expected to be positive.
        one_gt_samples = [
            self._make_data_sample(
                img_size,
                torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
                torch.LongTensor([2]))
        ]
        one_gt_losses = model.loss(
            random_image, batch_data_samples=one_gt_samples)
        for loss in one_gt_losses.values():
            self.assertGreater(
                loss.item(), 0,
                'cls loss, or box loss, or iou loss should be non-zero')

        # Inference paths: switch to eval mode and skip autograd bookkeeping,
        # since no gradients are needed here.
        model.eval()
        with torch.no_grad():
            model._forward(random_image, batch_data_samples=one_gt_samples)
            results = model.predict(
                random_image, batch_data_samples=one_gt_samples,
                rescale=True)
        # One input image in the batch -> one prediction entry.
        self.assertEqual(len(results), 1)
|