# Saurabh1105's picture
# MMdet Model for Image Segmentation
# 6c9ac8f
import copy
import pytest
import torch
from mmengine.structures import InstanceData
from mmdet.models.utils import (empty_instances, filter_gt_instances,
rename_loss_dict, reweight_loss_dict,
unpack_gt_instances)
from mmdet.testing import demo_mm_inputs
def test_parse_gt_instance_info():
    """Check that ``unpack_gt_instances`` yields one entry per data sample.

    The demo data samples are unpacked into three sequences (ground-truth
    instances, ignored ground-truth instances and image metas); each of them
    must be as long as the input batch.
    """
    data_samples = demo_mm_inputs()['data_samples']
    num_samples = len(data_samples)

    unpacked = unpack_gt_instances(data_samples)
    gt_instances, gt_instances_ignore, img_metas = unpacked

    for part in (gt_instances, gt_instances_ignore, img_metas):
        assert len(part) == num_samples
def test_process_empty_roi():
    """Check ``empty_instances`` for the ``'bbox'`` and ``'mask'`` task types.

    Empty bbox results are built for a small batch first. They are then
    passed back in as ``instance_results`` to obtain empty mask results.
    Finally, a length mismatch between the image metas and
    ``instance_results`` must raise an ``AssertionError``.
    """
    batch_size = 2
    img_meta = {'ori_shape': (10, 12)}
    batch_img_metas = [img_meta] * batch_size
    cpu = torch.device('cpu')

    # Empty bbox results: one empty InstanceData per image, each holding a
    # (0, 4) bbox tensor.
    bbox_results = empty_instances(batch_img_metas, cpu, task_type='bbox')
    assert len(bbox_results) == batch_size
    expected_bboxes = torch.zeros(0, 4, device=cpu)
    for res in bbox_results:
        assert isinstance(res, InstanceData)
        assert len(res) == 0
        assert torch.allclose(res.bboxes, expected_bboxes)

    # Feed the bbox results back in; each result is expected to carry an
    # empty mask whose spatial size comes from ``ori_shape``.
    mask_results = empty_instances(
        batch_img_metas,
        cpu,
        task_type='mask',
        instance_results=bbox_results,
        mask_thr_binary=0.5)
    assert len(mask_results) == batch_size
    for res in mask_results:
        assert isinstance(res, InstanceData)
        assert len(res) == 0
        assert res.masks.shape == (0, 10, 12)

    # batch_img_metas and instance_results must have the same length.
    mismatched = [mask_results[0]] * (batch_size + 1)
    with pytest.raises(AssertionError):
        empty_instances(
            batch_img_metas,
            cpu,
            task_type='mask',
            instance_results=mismatched)
def test_filter_gt_instances():
    """Check ``filter_gt_instances`` filtering by score and by box size.

    Score-based filtering needs ``gt_instances.scores``; without that field
    an ``AssertionError`` is expected. Every filtering call receives a deep
    copy, so ``packed_inputs`` still holds the unfiltered instances to
    compare against. Each check also asserts that the batch length is
    preserved, because ``zip`` and the plain ``for`` loops would otherwise
    pass vacuously on a truncated or empty result.
    """
    packed_inputs = demo_mm_inputs()['data_samples']
    score_thr = 0.7
    # The demo ground truth has no ``scores`` field yet, so filtering by
    # score must be rejected.
    with pytest.raises(AssertionError):
        filter_gt_instances(packed_inputs, score_thr=score_thr)

    # filter no instances by score: every score (1.0) is above score_thr
    for inputs in packed_inputs:
        inputs.gt_instances.scores = torch.ones_like(
            inputs.gt_instances.labels).float()
    filtered_packed_inputs = filter_gt_instances(
        copy.deepcopy(packed_inputs), score_thr=score_thr)
    assert len(filtered_packed_inputs) == len(packed_inputs)
    for filtered_inputs, inputs in zip(filtered_packed_inputs, packed_inputs):
        assert len(filtered_inputs.gt_instances) == len(inputs.gt_instances)

    # filter all instances by score: every score (0.0) is below score_thr
    for inputs in packed_inputs:
        inputs.gt_instances.scores = torch.zeros_like(
            inputs.gt_instances.labels).float()
    filtered_packed_inputs = filter_gt_instances(
        copy.deepcopy(packed_inputs), score_thr=score_thr)
    assert len(filtered_packed_inputs) == len(packed_inputs)
    for filtered_inputs in filtered_packed_inputs:
        assert len(filtered_inputs.gt_instances) == 0

    # Start again from fresh demo data (no scores) for the size checks.
    packed_inputs = demo_mm_inputs()['data_samples']
    # filter no instances by size
    wh_thr = (0, 0)
    filtered_packed_inputs = filter_gt_instances(
        copy.deepcopy(packed_inputs), wh_thr=wh_thr)
    assert len(filtered_packed_inputs) == len(packed_inputs)
    for filtered_inputs, inputs in zip(filtered_packed_inputs, packed_inputs):
        assert len(filtered_inputs.gt_instances) == len(inputs.gt_instances)

    # filter all instances by size: raise the thresholds to the largest
    # image extent in the batch. ``wh_thr`` is ordered (width, height)
    # while ``img_shape`` is ordered (height, width), so each threshold
    # must take the matching axis; the original swapped them, which only
    # worked because the demo images are square.
    for inputs in packed_inputs:
        img_h, img_w = inputs.img_shape[:2]
        wh_thr = (max(wh_thr[0], img_w), max(wh_thr[1], img_h))
    filtered_packed_inputs = filter_gt_instances(
        copy.deepcopy(packed_inputs), wh_thr=wh_thr)
    assert len(filtered_packed_inputs) == len(packed_inputs)
    for filtered_inputs in filtered_packed_inputs:
        assert len(filtered_inputs.gt_instances) == 0
def test_rename_loss_dict():
    """Check that each loss is found under its prefixed name.

    Every entry of ``losses`` must appear in the dict returned by
    ``rename_loss_dict`` under ``prefix + key``, with an equal value.
    """
    prefix = 'sup_'
    losses = {'cls_loss': torch.tensor(2.), 'reg_loss': torch.tensor(1.)}
    renamed = rename_loss_dict(prefix, losses)
    for key, value in losses.items():
        assert renamed[f'{prefix}{key}'] == value
def test_reweight_loss_dict():
    """Check that ``reweight_loss_dict`` multiplies each loss by ``weight``.

    A deep copy is handed to ``reweight_loss_dict``, so ``losses`` keeps the
    unweighted values used to compute the expected results.
    """
    weight = 4
    losses = {'cls_loss': torch.tensor(2.), 'reg_loss': torch.tensor(1.)}
    losses_copy = copy.deepcopy(losses)
    weighted_losses = reweight_loss_dict(losses_copy, weight)
    for key in losses:
        expected = losses[key] * weight
        assert weighted_losses[key] == expected