File size: 3,175 Bytes
6c9ac8f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmengine.structures import InstanceData
from mmdet.registry import MODELS
from mmdet.structures import DetDataSample
from mmdet.testing import get_detector_cfg
from mmdet.utils import register_all_modules
class TestConditionalDETR(TestCase):
    """Smoke tests for the Conditional DETR detector built from its config.

    Builds the model from
    ``conditional_detr/conditional-detr_r50_8xb2-50e_coco.py`` and checks
    the sign of its losses on a random image with empty and non-empty
    ground truth. It then runs ``_forward`` and ``predict`` in eval mode.
    """

    def setUp(self) -> None:
        # Register mmdet modules before each test; presumably required so
        # that ``MODELS.build(config)`` below can resolve the config's types.
        register_all_modules()

    def test_conditional_detr_head_loss(self):
        """Tests transformer head loss when truth is empty and non-empty."""
        s = 256
        metainfo = {
            'img_shape': (s, s),
            'scale_factor': (1, 1),
            'pad_shape': (s, s),
            'batch_input_shape': (s, s)
        }
        # NOTE: this single data sample is reused (and its gt_instances
        # overwritten) for both the empty-gt and the one-gt cases below.
        img_metas = DetDataSample()
        img_metas.set_metainfo(metainfo)

        config = get_detector_cfg(
            'conditional_detr/conditional-detr_r50_8xb2-50e_coco.py')
        model = MODELS.build(config)
        model.init_weights()
        random_image = torch.rand(1, 3, s, s)

        # Test that empty ground truth encourages the network to
        # predict background
        gt_instances = InstanceData()
        gt_instances.bboxes = torch.empty((0, 4))
        gt_instances.labels = torch.LongTensor([])
        img_metas.gt_instances = gt_instances
        batch_data_samples1 = [img_metas]
        empty_gt_losses = model.loss(
            random_image, batch_data_samples=batch_data_samples1)
        # When there is no truth, the cls loss should be nonzero but there
        # should be no box loss (loss keys are matched by substring).
        for key, loss in empty_gt_losses.items():
            if 'cls' in key:
                self.assertGreater(loss.item(), 0,
                                   'cls loss should be non-zero')
            elif 'bbox' in key:
                self.assertEqual(
                    loss.item(), 0,
                    'there should be no box loss when no ground truth boxes')
            elif 'iou' in key:
                self.assertEqual(
                    loss.item(), 0,
                    'there should be no iou loss when there are no true boxes')

        # When truth is non-empty then both cls and box loss should be nonzero
        # for random inputs
        gt_instances = InstanceData()
        gt_instances.bboxes = torch.Tensor(
            [[23.6667, 23.8757, 238.6326, 151.8874]])
        gt_instances.labels = torch.LongTensor([2])
        img_metas.gt_instances = gt_instances
        batch_data_samples2 = [img_metas]
        one_gt_losses = model.loss(
            random_image, batch_data_samples=batch_data_samples2)
        for loss in one_gt_losses.values():
            self.assertGreater(
                loss.item(), 0,
                'cls loss, or box loss, or iou loss should be non-zero')

        model.eval()
        # test _forward
        model._forward(random_image, batch_data_samples=batch_data_samples2)
        # test only predict
        model.predict(
            random_image, batch_data_samples=batch_data_samples2, rescale=True)
|