Saurabh1105's picture
MMdet Model for Image Segmentation
6c9ac8f
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import Mock
import torch
from mmdet.engine.hooks import CheckInvalidLossHook
class TestCheckInvalidLossHook(TestCase):
    """Unit tests for ``CheckInvalidLossHook``."""

    def test_after_train_iter(self):
        """Non-finite losses only fail on the hook's checking iteration.

        The hook is built with ``50`` (presumably the check interval --
        the test treats iteration ``50 - 1`` as the one that is checked).

        * With ``runner.iter == 10``, a finite, a NaN and an inf loss are
          all passed to ``after_train_iter`` and none may raise.
        * With ``runner.iter == 50 - 1``, a finite loss must still pass,
          while NaN and inf losses must raise ``AssertionError``.
        """
        interval = 50
        hook = CheckInvalidLossHook(interval)

        # Stand-in runner. The logger and its ``info`` method are Mocks, so
        # any logging call the hook may make is simply recorded.
        runner = Mock()
        runner.logger = Mock()
        runner.logger.info = Mock()

        finite_loss = torch.LongTensor([2])
        non_finite_losses = (
            torch.tensor(float('nan')),
            torch.tensor(float('inf')),
        )

        # Off the checking iteration: every loss value is accepted. The
        # second positional argument (presumably the batch index) mirrors
        # ``runner.iter``.
        runner.iter = 10
        for loss in (finite_loss, *non_finite_losses):
            hook.after_train_iter(runner, 10, outputs=dict(loss=loss))

        # On the checking iteration: finite passes, non-finite must raise.
        check_iter = interval - 1
        runner.iter = check_iter
        hook.after_train_iter(
            runner, check_iter, outputs=dict(loss=finite_loss))
        for loss in non_finite_losses:
            with self.assertRaises(AssertionError):
                hook.after_train_iter(
                    runner, check_iter, outputs=dict(loss=loss))