File size: 1,303 Bytes
6c9ac8f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
import unittest
import torch
from mmdet.models.losses import GaussianFocalLoss
class TestGaussianFocalLoss(unittest.TestCase):
    """Unit tests for ``GaussianFocalLoss``."""

    def test_forward(self):
        """Exercise ``GaussianFocalLoss`` across its call modes.

        Covers the default call, an explicit ``avg_factor``,
        ``reduction='none'``, ``reduction_override`` with a valid and an
        invalid value, and the call that additionally passes
        ``pos_inds`` / ``pos_labels``.
        """
        # A private, seeded generator keeps the random inputs reproducible
        # (so a failure can be replayed) without touching torch's global
        # RNG state, which other tests may rely on.
        gen = torch.Generator().manual_seed(0)
        pred = torch.rand((10, 4), generator=gen)
        target = torch.rand((10, 4), generator=gen)

        gaussian_focal_loss = GaussianFocalLoss()
        loss1 = gaussian_focal_loss(pred, target)
        self.assertIsInstance(loss1, torch.Tensor)
        loss2 = gaussian_focal_loss(pred, target, avg_factor=0.5)
        self.assertIsInstance(loss2, torch.Tensor)

        # test reduction: 'none' keeps the element-wise shape of the inputs.
        # assertEqual (not assertTrue(a == b)) reports the actual shape on
        # failure.
        gaussian_focal_loss = GaussianFocalLoss(reduction='none')
        loss = gaussian_focal_loss(pred, target)
        self.assertEqual(loss.shape, (10, 4))

        # test reduction_override: 'mean' overrides 'none' and yields a
        # 0-dim (scalar) tensor.
        loss = gaussian_focal_loss(pred, target, reduction_override='mean')
        self.assertEqual(loss.ndim, 0)

        # Only supports None, 'none', 'mean', 'sum'
        with self.assertRaises(AssertionError):
            gaussian_focal_loss(pred, target, reduction_override='max')

        # test pos_inds: pos_inds values lie in [0, 8) and pos_labels in
        # [0, 2), both within the (10, 4) extent of ``pred``.
        pos_inds = (torch.rand(5, generator=gen) * 8).long()
        pos_labels = (torch.rand(5, generator=gen) * 2).long()
        gaussian_focal_loss = GaussianFocalLoss()
        loss = gaussian_focal_loss(pred, target, pos_inds, pos_labels)
        self.assertIsInstance(loss, torch.Tensor)
|