Saurabh1105's picture
MMdet Model for Image Segmentation
6c9ac8f
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmengine.structures import InstanceData
from mmdet.registry import MODELS
from mmdet.structures import DetDataSample
from mmdet.testing import get_detector_cfg
from mmdet.utils import register_all_modules
class TestDINO(TestCase):
    """Tests for the detector built from the DINO 4-scale R50 COCO config."""

    def setUp(self):
        """Call ``register_all_modules`` before each test method."""
        register_all_modules()

    @staticmethod
    def _make_gt_instances(bboxes, labels):
        """Wrap ``bboxes`` and ``labels`` in a new ``InstanceData``."""
        instances = InstanceData()
        instances.bboxes = bboxes
        instances.labels = labels
        return instances

    def test_dino_head_loss(self):
        """Tests transformer head loss when truth is empty and non-empty.

        The model is first fed a sample without any ground-truth box, then
        the same sample with a single box.  Afterwards the model is put in
        eval mode and ``_forward`` / ``predict`` are called once each; only
        the fact that they run is exercised, their outputs are not checked.
        """
        side = 256
        image_size = (side, side)

        # One data sample is shared by both phases; only its
        # ``gt_instances`` field is swapped in between.
        data_sample = DetDataSample()
        data_sample.set_metainfo(
            dict(
                img_shape=image_size,
                scale_factor=(1, 1),
                pad_shape=image_size,
                batch_input_shape=image_size))

        # Loss keys containing any of these substrings must be exactly
        # zero when the sample carries no ground-truth box.
        zero_when_empty = ('bbox', 'iou', 'dn')

        configs = [get_detector_cfg('dino/dino-4scale_r50_8xb2-12e_coco.py')]
        for config in configs:
            model = MODELS.build(config)
            model.init_weights()
            images = torch.rand(1, 3, side, side)

            # Phase 1: no ground truth at all (zero boxes, zero labels).
            data_sample.gt_instances = self._make_gt_instances(
                torch.empty((0, 4)), torch.LongTensor([]))
            empty_batch = [data_sample]
            empty_gt_losses = model.loss(
                images, batch_data_samples=empty_batch)

            # Keys matching ``zero_when_empty`` must be 0, keys containing
            # 'cls' must be positive, every other key is left unchecked.
            for key, loss in empty_gt_losses.items():
                _loss = loss.item()
                if any(tag in key for tag in zero_when_empty):
                    self.assertEqual(
                        _loss, 0, f'there should be no {key}({_loss}) '
                        f'when no ground true boxes')
                    continue
                if 'cls' in key:
                    self.assertGreater(_loss, 0,
                                       f'{key}({_loss}) should be non-zero')

            # Phase 2: a single ground-truth box with label 2; every
            # returned loss term is expected to be strictly positive.
            data_sample.gt_instances = self._make_gt_instances(
                torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
                torch.LongTensor([2]))
            single_gt_batch = [data_sample]
            one_gt_losses = model.loss(
                images, batch_data_samples=single_gt_batch)
            for loss in one_gt_losses.values():
                self.assertGreater(
                    loss.item(), 0,
                    'cls loss, or box loss, or iou loss should be non-zero')

            # Switch to eval mode and make sure both inference entry
            # points can be called with the one-box batch.
            model.eval()
            model._forward(images, batch_data_samples=single_gt_batch)
            model.predict(
                images, batch_data_samples=single_gt_batch, rescale=True)