|
import copy |
|
|
|
import pytest |
|
import torch |
|
from mmengine.structures import InstanceData |
|
|
|
from mmdet.models.utils import (empty_instances, filter_gt_instances, |
|
rename_loss_dict, reweight_loss_dict, |
|
unpack_gt_instances) |
|
from mmdet.testing import demo_mm_inputs |
|
|
|
|
|
def test_parse_gt_instance_info():
    """Check that ``unpack_gt_instances`` yields one entry per data sample.

    The helper returns three parallel lists (gt instances, ignored gt
    instances, image metas); each must be as long as the input batch.
    """
    data_samples = demo_mm_inputs()['data_samples']
    gt_instances, gt_instances_ignore, img_metas = unpack_gt_instances(
        data_samples)

    expected_len = len(data_samples)
    for unpacked in (gt_instances, gt_instances_ignore, img_metas):
        assert len(unpacked) == expected_len
|
|
|
|
|
def test_process_empty_roi():
    """Check that ``empty_instances`` builds empty per-image results.

    Covers the ``'bbox'`` task, the ``'mask'`` task reusing previously
    built results, and the assertion raised when ``instance_results`` has a
    different length than ``batch_img_metas``.
    """
    batch_size = 2
    # (height, width) = (10, 12); the mask-shape check below relies on it.
    batch_img_metas = [{'ori_shape': (10, 12)}] * batch_size
    device = torch.device('cpu')

    results_list = empty_instances(batch_img_metas, device, task_type='bbox')
    assert len(results_list) == batch_size
    for results in results_list:
        assert isinstance(results, InstanceData)
        assert len(results) == 0
        # ``torch.allclose`` is vacuously true for empty tensors and
        # broadcasts its inputs, so it would also accept e.g. a (0, 1)
        # tensor. Check the box tensor's shape and device explicitly.
        assert results.bboxes.shape == (0, 4)
        assert results.bboxes.device == device

    results_list = empty_instances(
        batch_img_metas,
        device,
        task_type='mask',
        instance_results=results_list,
        mask_thr_binary=0.5)
    assert len(results_list) == batch_size
    for results in results_list:
        assert isinstance(results, InstanceData)
        assert len(results) == 0
        assert results.masks.shape == (0, 10, 12)

    # Passing 3 results for 2 images is expected to raise.
    with pytest.raises(AssertionError):
        empty_instances(
            batch_img_metas,
            device,
            task_type='mask',
            instance_results=[results_list[0]] * 3)
|
|
|
|
|
def test_filter_gt_instances():
    """Check score- and size-based filtering in ``filter_gt_instances``."""
    packed_inputs = demo_mm_inputs()['data_samples']
    score_thr = 0.7

    # Before any ``scores`` are attached (they are added below), score
    # filtering is expected to raise.
    with pytest.raises(AssertionError):
        filter_gt_instances(packed_inputs, score_thr=score_thr)

    # Every score above the threshold: all instances are kept.
    for inputs in packed_inputs:
        inputs.gt_instances.scores = torch.ones_like(
            inputs.gt_instances.labels).float()
    # Filter a deep copy so ``packed_inputs`` stays intact for comparison.
    filtered_packed_inputs = filter_gt_instances(
        copy.deepcopy(packed_inputs), score_thr=score_thr)
    for filtered_inputs, inputs in zip(filtered_packed_inputs, packed_inputs):
        assert len(filtered_inputs.gt_instances) == len(inputs.gt_instances)

    # Every score below the threshold: all instances are dropped.
    for inputs in packed_inputs:
        inputs.gt_instances.scores = torch.zeros_like(
            inputs.gt_instances.labels).float()
    filtered_packed_inputs = filter_gt_instances(
        copy.deepcopy(packed_inputs), score_thr=score_thr)
    for filtered_inputs in filtered_packed_inputs:
        assert len(filtered_inputs.gt_instances) == 0

    # Fresh inputs (no scores) for the size-threshold checks.
    packed_inputs = demo_mm_inputs()['data_samples']

    # A zero size threshold keeps every instance.
    wh_thr = (0, 0)
    filtered_packed_inputs = filter_gt_instances(
        copy.deepcopy(packed_inputs), wh_thr=wh_thr)
    for filtered_inputs, inputs in zip(filtered_packed_inputs, packed_inputs):
        assert len(filtered_inputs.gt_instances) == len(inputs.gt_instances)

    # A threshold at least as large as every image removes every instance.
    # ``wh_thr`` is (width, height) while ``img_shape`` is presumably
    # (height, width), matching the ``ori_shape`` / mask convention used in
    # ``test_process_empty_roi``. Take each axis from the matching index.
    max_w = max(inputs.img_shape[1] for inputs in packed_inputs)
    max_h = max(inputs.img_shape[0] for inputs in packed_inputs)
    wh_thr = (max(wh_thr[0], max_w), max(wh_thr[1], max_h))
    filtered_packed_inputs = filter_gt_instances(
        copy.deepcopy(packed_inputs), wh_thr=wh_thr)
    for filtered_inputs in filtered_packed_inputs:
        assert len(filtered_inputs.gt_instances) == 0
|
|
|
|
|
def test_rename_loss_dict():
    """Check that ``rename_loss_dict`` prefixes every key, keeping values."""
    prefix = 'sup_'
    losses = dict(cls_loss=torch.tensor(2.), reg_loss=torch.tensor(1.))
    renamed = rename_loss_dict(prefix, losses)
    for key, value in losses.items():
        assert renamed[f'{prefix}{key}'] == value
|
|
|
|
|
def test_reweight_loss_dict():
    """Check that ``reweight_loss_dict`` scales each loss by ``weight``."""
    weight = 4
    losses = dict(cls_loss=torch.tensor(2.), reg_loss=torch.tensor(1.))
    # Hand the helper a deep copy so the expected values below come from the
    # original, unweighted losses even if the helper modifies its argument.
    weighted = reweight_loss_dict(copy.deepcopy(losses), weight)
    for key, value in losses.items():
        assert weighted[key] == value * weight
|
|