File size: 1,678 Bytes
6c9ac8f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
import torch
from mmdet.models.necks import CTResNetNeck
class TestCTResNetNeck(unittest.TestCase):
    """Unit tests for the ``CTResNetNeck`` neck (construction and forward)."""

    def test_init(self):
        """Check constructor argument validation and ``init_weights``."""
        # num_filters/num_kernels must be same length
        with self.assertRaises(AssertionError):
            CTResNetNeck(
                in_channels=10,
                num_deconv_filters=(10, 10),
                num_deconv_kernels=(4, ))

        # A valid configuration must construct and initialise without error.
        ct_resnet_neck = CTResNetNeck(
            in_channels=16,
            num_deconv_filters=(8, 8),
            num_deconv_kernels=(4, 4),
            use_dcn=False)
        ct_resnet_neck.init_weights()

    def test_forward(self):
        """Check input-type validation and the output shape of a forward pass."""
        in_channels = 16
        num_filters = (8, 8)
        num_kernels = (4, 4)
        # Tie the input's channel dimension to ``in_channels`` so the feature
        # map always matches the neck configuration under test.
        feat = torch.rand(1, in_channels, 4, 4)
        ct_resnet_neck = CTResNetNeck(
            in_channels=in_channels,
            num_deconv_filters=num_filters,
            num_deconv_kernels=num_kernels,
            use_dcn=False)

        # feat must be list or tuple
        with self.assertRaises(AssertionError):
            ct_resnet_neck(feat)

        out_feat = ct_resnet_neck([feat])[0]
        # This test expects the 4x4 input to come out at 16x16, with the
        # channel count equal to the last deconv filter size.
        self.assertEqual(out_feat.shape, (1, num_filters[-1], 16, 16))

        if torch.cuda.is_available():
            # test dcn: ``use_dcn`` is left at its default here, which this
            # test treats as the DCN path. The CUDA guard suggests the DCN ops
            # need a GPU (assumption — confirm against CTResNetNeck).
            ct_resnet_neck = CTResNetNeck(
                in_channels=in_channels,
                num_deconv_filters=num_filters,
                num_deconv_kernels=num_kernels)
            ct_resnet_neck = ct_resnet_neck.cuda()
            feat = feat.cuda()
            out_feat = ct_resnet_neck([feat])[0]
            self.assertEqual(out_feat.shape, (1, num_filters[-1], 16, 16))
|