# nan_recovery_hook.py - Graceful NaN loss recovery for Cascade R-CNN
from typing import Any, Dict, Optional

import torch
from mmengine.hooks import Hook
from mmengine.runner import Runner
from mmdet.registry import HOOKS


@HOOKS.register_module()
class NanRecoveryHook(Hook):
"""Hook to handle NaN losses gracefully without crashing training.
This hook detects NaN losses and handles them by:
1. Replacing NaN losses with the last valid loss value
2. Skipping gradient updates for that iteration
3. Logging the recovery for monitoring
4. Allowing training to continue normally
"""

    def __init__(self,
                 fallback_loss: float = 0.5,
                 max_consecutive_nans: int = 100,  # Increased from 10
                 log_interval: int = 50):  # Log less frequently
        self.fallback_loss = fallback_loss
        self.max_consecutive_nans = max_consecutive_nans
        self.log_interval = log_interval

        # State tracking
        self.last_valid_loss = fallback_loss
        self.consecutive_nans = 0
        self.total_nans = 0
        self.nan_iterations = []

    def before_train_iter(self,
                          runner: Runner,
                          batch_idx: int,
                          data_batch: Optional[dict] = None) -> None:
        """Reset any state before the training iteration."""
        pass

    def after_train_iter(self,
                         runner: Runner,
                         batch_idx: int,
                         data_batch: Optional[dict] = None,
                         outputs: Optional[Dict[str, Any]] = None) -> None:
        """Handle NaN losses after a training iteration."""
        if outputs is None:
            return

        # Check ALL loss components for NaN/Inf, not just the main loss
        has_nan = False

        # Check the main loss
        total_loss = outputs.get('loss')
        if total_loss is not None and (torch.isnan(total_loss).any() or torch.isinf(total_loss).any()):
            has_nan = True

        # Check all individual loss components
        for key, value in outputs.items():
            if isinstance(value, torch.Tensor) and 'loss' in key.lower():
                if torch.isnan(value).any() or torch.isinf(value).any():
                    has_nan = True
                    break

        if has_nan:
            self._handle_nan_loss(runner, batch_idx, outputs)
        else:
            # Valid loss - update tracking
            if total_loss is not None:
                self.last_valid_loss = float(total_loss.item())
            if self.consecutive_nans > 0:
                runner.logger.info(
                    f"Loss recovered after {self.consecutive_nans} NaN iterations")
            self.consecutive_nans = 0

    def _handle_nan_loss(self, runner: Runner, batch_idx: int, outputs: Dict[str, Any]) -> None:
        """Handle a NaN loss by replacing it with a detached fallback and managing state."""
        self.consecutive_nans += 1
        self.total_nans += 1
        self.nan_iterations.append(batch_idx)

        # Try to get the last good state from SkipBadSamplesHook if available
        last_good_iteration = batch_idx
        last_good_loss = self.last_valid_loss
        for hook in runner.hooks:
            if hasattr(hook, 'last_good_iteration') and hasattr(hook, 'last_good_loss'):
                if hook.last_good_loss is not None:
                    last_good_iteration = hook.last_good_iteration
                    last_good_loss = float(hook.last_good_loss)
                    break

        # Replace the NaN loss with a detached fallback (no gradients = true no-op)
        if 'loss' in outputs and outputs['loss'] is not None:
            fallback_tensor = torch.tensor(
                last_good_loss,
                device=outputs['loss'].device,
                dtype=outputs['loss'].dtype
                # NOTE: no requires_grad=True - this keeps the replacement detached
            )
            outputs['loss'] = fallback_tensor

        # Also fix individual loss components with detached tensors
        self._fix_loss_components(outputs, last_good_loss)

        # Log the recovery with state info
        if self.consecutive_nans <= 5 or self.consecutive_nans % self.log_interval == 0:
            runner.logger.warning(
                f"NaN recovery at iteration {batch_idx}: "
                f"using last good loss {last_good_loss:.4f} from iteration {last_good_iteration}. "
                f"Consecutive NaNs: {self.consecutive_nans}, total: {self.total_nans}"
            )

        # Reset training state if there are too many consecutive NaNs
        if self.consecutive_nans >= self.max_consecutive_nans:
            self._reset_nan_state(runner, last_good_iteration)

    def _reset_nan_state(self, runner: Runner, last_good_iteration: int) -> None:
        """Reset training state when there are too many consecutive NaNs."""
        runner.logger.error(
            f"Too many consecutive NaN losses ({self.consecutive_nans}). "
            f"Resetting to the last good state from iteration {last_good_iteration}"
        )
        try:
            # Clear model gradients
            if hasattr(runner.model, 'zero_grad'):
                runner.model.zero_grad()

            # Clear the CUDA cache
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Reset the consecutive counter
            self.consecutive_nans = 0
            runner.logger.info("NaN state reset. Resuming training...")
        except Exception as e:
            runner.logger.error(f"Failed to reset NaN state: {e}")

    def _fix_loss_components(self, outputs: Dict[str, Any], fallback_loss: Optional[float] = None) -> None:
        """Fix ALL loss components with detached tensors (no gradients)."""
        if fallback_loss is None:
            fallback_loss = self.last_valid_loss
        fallback_small = max(0.01, fallback_loss * 0.1)  # Ensure a non-zero minimum

        # Fix ALL tensors with 'loss' in the key name using detached tensors
        for key, value in outputs.items():
            if isinstance(value, torch.Tensor) and 'loss' in key.lower():
                if torch.isnan(value).any() or torch.isinf(value).any():
                    # Create a detached replacement tensor (no gradients)
                    replacement = torch.tensor(
                        fallback_small,
                        device=value.device,
                        dtype=value.dtype
                        # NOTE: no requires_grad=True - detached for a true no-op
                    )
                    outputs[key] = replacement
                    print(f"  Fixed {key}: NaN/Inf -> detached {fallback_small:.4f}")

        # Also fix any scalar values that might be NaN
        for key, value in list(outputs.items()):
            if isinstance(value, (int, float)) and 'loss' in key.lower():
                if not torch.isfinite(torch.tensor(value)):
                    outputs[key] = fallback_small
                    print(f"  Fixed scalar {key}: {value} -> {fallback_small:.4f}")

    def after_train_epoch(self, runner: Runner) -> None:
        """Log summary statistics after each epoch."""
        if self.total_nans > 0:
            runner.logger.info(
                f"NaN recovery summary for epoch: "
                f"{self.total_nans} NaN losses recovered. "
                "Training continued successfully."
            )
        # Reset for the next epoch
        self.consecutive_nans = 0
        self.total_nans = 0
        self.nan_iterations.clear()
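

# Usage sketch: how a hook registered via @HOOKS.register_module() is typically
# enabled in an MMEngine/MMDetection config. The module path in `custom_imports`
# is an assumption - adjust it to wherever this file lives in your project; the
# parameter values simply mirror the defaults above.
#
# custom_imports = dict(imports=['nan_recovery_hook'], allow_failed_imports=False)
#
# custom_hooks = [
#     dict(
#         type='NanRecoveryHook',
#         fallback_loss=0.5,
#         max_consecutive_nans=100,
#         log_interval=50,
#     )
# ]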