# nan_recovery_hook.py - Graceful NaN loss recovery for Cascade R-CNN
import math
from typing import Any, Dict, Optional

import torch
from mmengine.hooks import Hook
from mmengine.runner import Runner
from mmdet.registry import HOOKS


@HOOKS.register_module()
class NanRecoveryHook(Hook):
    """Hook to handle NaN losses gracefully without crashing training.

    This hook detects NaN/Inf losses and handles them by:
    1. Replacing bad loss values with the last valid loss (as detached tensors)
    2. Logging each recovery for monitoring
    3. Resetting gradient state after too many consecutive NaNs
    4. Allowing training to continue normally

    Note: MMEngine runs the optimizer step inside ``train_step``, so by the
    time ``after_train_iter`` fires the parameter update has usually already
    been applied; the replacement here mainly sanitizes the logged values and
    the hook's internal state rather than skipping the update itself.
    """

    def __init__(self,
                 fallback_loss: float = 0.5,
                 max_consecutive_nans: int = 100,  # Increased from 10
                 log_interval: int = 50):  # Log less frequently
        self.fallback_loss = fallback_loss
        self.max_consecutive_nans = max_consecutive_nans
        self.log_interval = log_interval
        # State tracking
        self.last_valid_loss = fallback_loss
        self.consecutive_nans = 0
        self.total_nans = 0
        self.nan_iterations = []

    def before_train_iter(self,
                          runner: Runner,
                          batch_idx: int,
                          data_batch: Optional[dict] = None) -> None:
        """Reset any state before the training iteration."""
        pass

    def after_train_iter(self,
                         runner: Runner,
                         batch_idx: int,
                         data_batch: Optional[dict] = None,
                         outputs: Optional[Dict[str, Any]] = None) -> None:
        """Handle NaN losses after a training iteration."""
        if outputs is None:
            return
        # Check ALL loss components for NaN/Inf (the main 'loss' key is
        # covered by the same loop); .any() keeps the check safe for
        # non-scalar loss tensors
        has_nan = False
        total_loss = outputs.get('loss')
        for key, value in outputs.items():
            if isinstance(value, torch.Tensor) and 'loss' in key.lower():
                if torch.isnan(value).any() or torch.isinf(value).any():
                    has_nan = True
                    break
        if has_nan:
            self._handle_nan_loss(runner, batch_idx, outputs)
        elif total_loss is not None:
            # Valid loss - update tracking
            self.last_valid_loss = float(total_loss.item())
            if self.consecutive_nans > 0:
                runner.logger.info(
                    f"Loss recovered after {self.consecutive_nans} NaN iterations")
            self.consecutive_nans = 0

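    # For reference, ``outputs`` here is the dict returned by
    # ``model.train_step``; for MMDetection models it typically holds the
    # parsed log vars, e.g. {'loss': tensor(1.23), 'loss_rpn_cls': tensor(0.10), ...}.
    # The exact keys are detector-dependent (Cascade R-CNN prefixes its
    # per-stage losses), which is why the checks above match any key
    # containing 'loss' rather than a fixed list.
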
    def _handle_nan_loss(self, runner: Runner, batch_idx: int,
                         outputs: Dict[str, Any]) -> None:
        """Handle a NaN loss by replacing it with a detached fallback."""
        self.consecutive_nans += 1
        self.total_nans += 1
        self.nan_iterations.append(batch_idx)
        # Try to get the last good state from SkipBadSamplesHook, if available
        last_good_iteration = batch_idx
        last_good_loss = self.last_valid_loss
        for hook in runner.hooks:
            if hasattr(hook, 'last_good_iteration') and hasattr(hook, 'last_good_loss'):
                if hook.last_good_loss is not None:
                    last_good_iteration = hook.last_good_iteration
                    last_good_loss = hook.last_good_loss
                    break
        # Replace the NaN loss with a detached fallback (no gradients = true no-op)
        if 'loss' in outputs and outputs['loss'] is not None:
            fallback_tensor = torch.tensor(
                last_good_loss,
                device=outputs['loss'].device,
                dtype=outputs['loss'].dtype
                # NOTE: no requires_grad=True - this keeps the tensor detached
            )
            outputs['loss'] = fallback_tensor
        # Also fix the individual loss components with detached tensors
        self._fix_loss_components(outputs, last_good_loss)
        # Log the recovery with state info
        if self.consecutive_nans <= 5 or self.consecutive_nans % self.log_interval == 0:
            runner.logger.warning(
                f"NaN recovery at iteration {batch_idx}: "
                f"using last good loss {last_good_loss:.4f} from iteration {last_good_iteration}. "
                f"Consecutive NaNs: {self.consecutive_nans}, total: {self.total_nans}"
            )
        # Reset training state if there are too many consecutive NaNs
        if self.consecutive_nans >= self.max_consecutive_nans:
            self._reset_nan_state(runner, last_good_iteration)

    def _reset_nan_state(self, runner: Runner, last_good_iteration: int) -> None:
        """Reset training state when there are too many consecutive NaNs."""
        runner.logger.error(
            f"Too many consecutive NaN losses ({self.consecutive_nans}). "
            f"Resetting to last good state from iteration {last_good_iteration}"
        )
        try:
            # Clear model gradients
            if hasattr(runner.model, 'zero_grad'):
                runner.model.zero_grad()
            # Clear CUDA cache
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            # Reset the consecutive counter
            self.consecutive_nans = 0
            runner.logger.info("NaN state reset. Resuming training...")
        except Exception as e:
            runner.logger.error(f"Failed to reset NaN state: {e}")

    def _fix_loss_components(self, outputs: Dict[str, Any],
                             fallback_loss: Optional[float] = None) -> None:
        """Fix ALL loss components with detached tensors (no gradients)."""
        if fallback_loss is None:
            fallback_loss = self.last_valid_loss
        fallback_small = max(0.01, fallback_loss * 0.1)  # Ensure a non-zero minimum
        # Fix ALL tensors with 'loss' in the key name using detached tensors
        for key, value in outputs.items():
            if isinstance(value, torch.Tensor) and 'loss' in key.lower():
                if torch.isnan(value).any() or torch.isinf(value).any():
                    # Create a detached replacement tensor (no gradients)
                    replacement = torch.tensor(
                        fallback_small,
                        device=value.device,
                        dtype=value.dtype
                        # NOTE: no requires_grad=True - detached for a true no-op
                    )
                    outputs[key] = replacement
                    print(f"  Fixed {key}: NaN/Inf -> detached {fallback_small:.4f}")
        # Also fix any plain scalar values that might be NaN
        for key, value in list(outputs.items()):
            if isinstance(value, (int, float)) and 'loss' in key.lower():
                if not math.isfinite(value):
                    outputs[key] = fallback_small
                    print(f"  Fixed scalar {key}: {value} -> {fallback_small:.4f}")

    def after_train_epoch(self, runner: Runner) -> None:
        """Log summary statistics after each epoch."""
        if self.total_nans > 0:
            runner.logger.info(
                "NaN recovery summary for epoch: "
                f"{self.total_nans} NaN losses recovered. "
                "Training continued successfully."
            )
        # Reset for the next epoch
        self.consecutive_nans = 0
        self.total_nans = 0
        self.nan_iterations.clear()
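
# ---------------------------------------------------------------------------
# Usage sketch (assumptions: this module is importable as
# ``nan_recovery_hook`` from the project root; the surrounding config layout
# is illustrative, not taken from an actual repo):
#
#   # in your MMDetection training config
#   custom_imports = dict(imports=['nan_recovery_hook'],
#                         allow_failed_imports=False)
#   custom_hooks = [
#       dict(type='NanRecoveryHook',
#            fallback_loss=0.5,
#            max_consecutive_nans=100,
#            log_interval=50),
#   ]
#
# Minimal smoke test of the replacement logic (no Runner required):
if __name__ == '__main__':
    hook = NanRecoveryHook()
    fake_outputs = {
        'loss': torch.tensor(float('nan')),
        'loss_cls': torch.tensor(0.3),
        'loss_bbox': float('inf'),
    }
    hook._fix_loss_components(fake_outputs)
    # Every loss entry should now be a finite, detached value
    assert all(
        torch.isfinite(v).all() if isinstance(v, torch.Tensor) else math.isfinite(v)
        for v in fake_outputs.values()
    )
    print('Recovered outputs:', fake_outputs)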