import math
from typing import Any, Dict, Optional

import torch
from mmengine.hooks import Hook
from mmengine.runner import Runner

from mmdet.registry import HOOKS


@HOOKS.register_module()
class NanRecoveryHook(Hook):
"""Hook to handle NaN losses gracefully without crashing training. |
|
|
|
This hook detects NaN losses and handles them by: |
|
1. Replacing NaN losses with the last valid loss value |
|
2. Skipping gradient updates for that iteration |
|
3. Logging the recovery for monitoring |
|
4. Allowing training to continue normally |
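
    Example:
        A minimal sketch (assuming this module is importable so that the
        registry decorator runs) of enabling the hook from a config file
        via ``custom_hooks``::

            custom_hooks = [
                dict(
                    type='NanRecoveryHook',
                    fallback_loss=0.5,
                    max_consecutive_nans=100,
                    log_interval=50)
            ]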
    """

    def __init__(self,
                 fallback_loss: float = 0.5,
                 max_consecutive_nans: int = 100,
                 log_interval: int = 50) -> None:
        self.fallback_loss = fallback_loss
        self.max_consecutive_nans = max_consecutive_nans
        self.log_interval = log_interval

        # Recovery state.
        self.last_valid_loss = fallback_loss
        self.consecutive_nans = 0
        self.total_nans = 0
        self.nan_iterations = []

    def before_train_iter(self,
                          runner: Runner,
                          batch_idx: int,
                          data_batch: Optional[dict] = None) -> None:
        """No per-iteration state needs resetting; kept as an explicit no-op."""
        pass

    def after_train_iter(self,
                         runner: Runner,
                         batch_idx: int,
                         data_batch: Optional[dict] = None,
                         outputs: Optional[Dict[str, Any]] = None) -> None:
        """Detect and handle NaN/Inf losses after a training iteration."""
        if outputs is None:
            return

        has_nan = False

        # Check the aggregated loss first.
        total_loss = outputs.get('loss')
        if total_loss is not None and not torch.isfinite(total_loss).all():
            has_nan = True

        # Then check every individual loss component.
        for key, value in outputs.items():
            if isinstance(value, torch.Tensor) and 'loss' in key.lower():
                if not torch.isfinite(value).all():
                    has_nan = True
                    break

        if has_nan:
            self._handle_nan_loss(runner, batch_idx, outputs)
        else:
            # Remember the most recent finite loss for future recoveries.
            if total_loss is not None:
                self.last_valid_loss = float(total_loss.item())
            if self.consecutive_nans > 0:
                runner.logger.info(
                    f'Loss recovered after {self.consecutive_nans} '
                    'NaN iterations')
            self.consecutive_nans = 0

    def _handle_nan_loss(self, runner: Runner, batch_idx: int,
                         outputs: Dict[str, Any]) -> None:
        """Replace a NaN/Inf loss with a detached fallback and update state."""
        self.consecutive_nans += 1
        self.total_nans += 1
        self.nan_iterations.append(batch_idx)

        # Fall back to this hook's own bookkeeping, but prefer another
        # registered hook that exposes ``last_good_iteration`` and
        # ``last_good_loss``.
        last_good_iteration = batch_idx
        last_good_loss = self.last_valid_loss

        for hook in runner.hooks:
            if (hasattr(hook, 'last_good_iteration')
                    and hasattr(hook, 'last_good_loss')):
                if hook.last_good_loss is not None:
                    last_good_iteration = hook.last_good_iteration
                    last_good_loss = hook.last_good_loss
                    break

        # Replace the aggregated loss with a finite, detached fallback tensor.
        if 'loss' in outputs and outputs['loss'] is not None:
            fallback_tensor = torch.tensor(
                last_good_loss,
                device=outputs['loss'].device,
                dtype=outputs['loss'].dtype)
            outputs['loss'] = fallback_tensor

        # Also sanitise every individual loss component.
        self._fix_loss_components(outputs, last_good_loss)

        # Log the first few recoveries, then only every ``log_interval`` NaNs.
        if (self.consecutive_nans <= 5
                or self.consecutive_nans % self.log_interval == 0):
            runner.logger.warning(
                f'NaN recovery at iteration {batch_idx}: '
                f'using last good loss {last_good_loss:.4f} from iteration '
                f'{last_good_iteration}. '
                f'Consecutive NaNs: {self.consecutive_nans}, '
                f'total: {self.total_nans}')

        if self.consecutive_nans >= self.max_consecutive_nans:
            self._reset_nan_state(runner, last_good_iteration)

    def _reset_nan_state(self, runner: Runner,
                         last_good_iteration: int) -> None:
        """Reset training state after too many consecutive NaN losses."""
        runner.logger.error(
            f'Too many consecutive NaN losses ({self.consecutive_nans}). '
            f'Resetting to last good state from iteration {last_good_iteration}')

        try:
            # Clear any gradients left over from the failed iteration.
            if hasattr(runner.model, 'zero_grad'):
                runner.model.zero_grad()

            # Release cached GPU memory.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            self.consecutive_nans = 0

            runner.logger.info('NaN state reset. Resuming training...')
        except Exception as e:
            runner.logger.error(f'Failed to reset NaN state: {e}')

    def _fix_loss_components(self,
                             outputs: Dict[str, Any],
                             fallback_loss: Optional[float] = None) -> None:
        """Replace every non-finite loss component with a detached tensor."""
        if fallback_loss is None:
            fallback_loss = self.last_valid_loss

        # Use a small positive value so individual components stay plausible.
        fallback_small = max(0.01, fallback_loss * 0.1)

        # Tensor-valued loss components.
        for key, value in outputs.items():
            if isinstance(value, torch.Tensor) and 'loss' in key.lower():
                if not torch.isfinite(value).all():
                    replacement = torch.tensor(
                        fallback_small,
                        device=value.device,
                        dtype=value.dtype)
                    outputs[key] = replacement
                    print(f'  Fixed {key}: {value.item():.4f} -> '
                          f'detached {fallback_small:.4f}')

        # Scalar (Python int/float) loss components.
        for key, value in list(outputs.items()):
            if isinstance(value, (int, float)) and 'loss' in key.lower():
                if not math.isfinite(value):
                    outputs[key] = fallback_small
                    print(f'  Fixed scalar {key}: {value} -> '
                          f'{fallback_small:.4f}')

    def after_train_epoch(self, runner: Runner) -> None:
        """Log summary statistics and reset counters after each epoch."""
        if self.total_nans > 0:
            runner.logger.info(
                f'NaN recovery summary for epoch: {self.total_nans} NaN '
                'losses recovered; training continued.')

        # Reset per-epoch statistics.
        self.consecutive_nans = 0
        self.total_nans = 0
        self.nan_iterations.clear()
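

if __name__ == '__main__':
    # Minimal smoke test: an illustrative sketch, not part of the hook's
    # public API. It feeds the hook a fake ``outputs`` dict containing NaN
    # losses and checks that the aggregated loss is replaced by a finite
    # fallback. The ``SimpleNamespace`` runner stand-in is an assumption
    # made only for this demo; it provides just the two attributes the hook
    # touches (``logger`` and ``hooks``). Running it still requires
    # mmengine and mmdet to be installed because of the module imports.
    import logging
    from types import SimpleNamespace

    fake_runner = SimpleNamespace(
        logger=logging.getLogger('nan_recovery_demo'), hooks=[])
    hook = NanRecoveryHook(fallback_loss=0.5)

    # A healthy iteration records the last valid loss (0.8).
    hook.after_train_iter(fake_runner, batch_idx=0, data_batch=None,
                          outputs={'loss': torch.tensor(0.8)})

    # A NaN iteration is caught and both losses are replaced.
    outputs = {'loss': torch.tensor(float('nan')),
               'loss_cls': torch.tensor(float('nan'))}
    hook.after_train_iter(fake_runner, batch_idx=1, data_batch=None,
                          outputs=outputs)
    assert torch.isfinite(outputs['loss'])
    print('loss after recovery:', outputs['loss'].item())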