|
|
|
from unittest import TestCase |
|
from unittest.mock import Mock |
|
|
|
import torch |
|
|
|
from mmdet.engine.hooks import CheckInvalidLossHook |
|
|
|
|
|
class TestCheckInvalidLossHook(TestCase):
    """Unit test for ``CheckInvalidLossHook.after_train_iter``."""

    def test_after_train_iter(self):
        """Call the hook with finite, NaN and inf losses at two iterations.

        With the hook built from ``n = 50``:

        * at ``runner.iter == 10`` all three losses are passed and no
          exception is expected (presumably because the hook only inspects
          the loss every ``n`` iterations -- confirm against the hook);
        * at ``runner.iter == n - 1`` the finite loss must pass, while the
          NaN and inf losses must each raise ``AssertionError``.
        """
        n = 50
        hook = CheckInvalidLossHook(n)

        # Stand-in runner; the logger and its ``info`` are set explicitly.
        runner = Mock()
        runner.logger = Mock()
        runner.logger.info = Mock()

        def finite_outputs():
            # Fresh ``outputs`` dict carrying a finite loss.
            return {'loss': torch.LongTensor([2])}

        def invalid_outputs():
            # Fresh ``outputs`` dicts carrying a NaN loss, then an inf loss.
            for literal in ('nan', 'inf'):
                yield {'loss': torch.tensor(float(literal))}

        # Phase 1: iteration 10 -- none of these calls is expected to raise.
        early_iter = 10
        runner.iter = early_iter
        hook.after_train_iter(runner, early_iter, outputs=finite_outputs())
        for outputs in invalid_outputs():
            hook.after_train_iter(runner, early_iter, outputs=outputs)

        # Phase 2: iteration n - 1 -- a finite loss passes, invalid ones fail.
        check_iter = n - 1
        runner.iter = check_iter
        hook.after_train_iter(runner, check_iter, outputs=finite_outputs())
        for outputs in invalid_outputs():
            with self.assertRaises(AssertionError):
                hook.after_train_iter(runner, check_iter, outputs=outputs)
|
|