|
|
|
from typing import Optional |
|
|
|
import torch |
|
from mmengine.hooks import Hook |
|
from mmengine.runner import Runner |
|
|
|
from mmdet.registry import HOOKS |
|
|
|
|
|
@HOOKS.register_module()
class CheckInvalidLossHook(Hook):
    """Check invalid loss hook.

    This hook will regularly check whether the loss is valid
    during training.

    Args:
        interval (int): Checking interval (every k iterations).
            Default: 50.
    """

    def __init__(self, interval: int = 50) -> None:
        self.interval = interval

    def after_train_iter(self,
                         runner: Runner,
                         batch_idx: int,
                         data_batch: Optional[dict] = None,
                         outputs: Optional[dict] = None) -> None:
        """Regularly check whether the loss is valid every n iterations.

        Args:
            runner (:obj:`Runner`): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict, Optional): Data from dataloader.
                Defaults to None.
            outputs (dict, Optional): Outputs from model. On checked
                iterations it must contain a ``'loss'`` entry.
                Defaults to None.

        Raises:
            AssertionError: If the loss contains an infinite or NaN value
                on a checked iteration.
        """
        if not self.every_n_train_iters(runner, self.interval):
            return

        loss = outputs['loss']
        # ``as_tensor`` also accepts plain Python numbers, and ``.all()``
        # reduces non-scalar losses, so ``if`` never sees an ambiguous
        # multi-element boolean tensor.
        if torch.isfinite(torch.as_tensor(loss)).all():
            return

        msg = 'loss become infinite or NaN!'
        runner.logger.error(msg)
        # Raise explicitly instead of using ``assert``: asserts are stripped
        # under ``python -O``, which would silently disable this hook.
        # AssertionError is kept so existing callers catching it still work.
        raise AssertionError(msg)
|
|