# nan_recovery_hook.py - Graceful NaN loss recovery for Cascade R-CNN
import math
import torch
from mmengine.hooks import Hook
from mmengine.runner import Runner
from mmdet.registry import HOOKS
from typing import Optional, Dict, Any


@HOOKS.register_module()
class NanRecoveryHook(Hook):
    """Hook to handle NaN losses gracefully without crashing training.

    This hook detects NaN/Inf losses and handles them by:
    1. Replacing non-finite loss values with the last valid loss, using detached tensors
    2. Logging each recovery for monitoring
    3. Resetting gradients and clearing the CUDA cache after too many consecutive NaNs
    4. Allowing training to continue normally

    A usage sketch is given at the end of this file.
    """
    
    def __init__(self, 
                 fallback_loss: float = 0.5,
                 max_consecutive_nans: int = 100,  # Increased from 10
                 log_interval: int = 50):  # Log less frequently
        self.fallback_loss = fallback_loss
        self.max_consecutive_nans = max_consecutive_nans
        self.log_interval = log_interval
        
        # State tracking
        self.last_valid_loss = fallback_loss
        self.consecutive_nans = 0
        self.total_nans = 0
        self.nan_iterations = []
    
    def before_train_iter(self,
                          runner: Runner,
                          batch_idx: int,
                          data_batch: Optional[dict] = None) -> None:
        """No per-iteration setup is needed; all handling happens in after_train_iter."""
        pass
    
    def after_train_iter(self,
                        runner: Runner,
                        batch_idx: int,
                        data_batch: Optional[dict] = None,
                        outputs: Optional[Dict[str, Any]] = None) -> None:
        """Handle NaN losses after training iteration."""
        if outputs is None:
            return
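        # Note (assumption: default MMEngine train loop): `outputs` is the dict
        # returned by model.train_step(), i.e. the parsed log variables, so the
        # loss entries are normally scalar tensors.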
        
        # Check ALL loss components for NaN/Inf, not just the main loss
        has_nan = False

        # Check the main loss (isfinite().all() also handles non-scalar tensors)
        total_loss = outputs.get('loss')
        if total_loss is not None and not torch.isfinite(total_loss).all():
            has_nan = True

        # Check every individual loss component
        for key, value in outputs.items():
            if isinstance(value, torch.Tensor) and 'loss' in key.lower():
                if not torch.isfinite(value).all():
                    has_nan = True
                    break
        
        if has_nan:
            self._handle_nan_loss(runner, batch_idx, outputs)
        else:
            # Valid loss - update tracking
            if total_loss is not None:
                self.last_valid_loss = float(total_loss.item())
                if self.consecutive_nans > 0:
                    runner.logger.info(f"🎉 Loss recovered after {self.consecutive_nans} NaN iterations")
                self.consecutive_nans = 0
    
    def _handle_nan_loss(self, runner: Runner, batch_idx: int, outputs: Dict[str, Any]) -> None:
        """Handle NaN loss by replacing with detached fallback and managing state."""
        self.consecutive_nans += 1
        self.total_nans += 1
        self.nan_iterations.append(batch_idx)
        
        # Try to get last good state from SkipBadSamplesHook if available
        last_good_iteration = batch_idx
        last_good_loss = self.last_valid_loss
        
        for hook in runner.hooks:
            if hasattr(hook, 'last_good_iteration') and hasattr(hook, 'last_good_loss'):
                if hook.last_good_loss is not None:
                    last_good_iteration = hook.last_good_iteration
                    last_good_loss = hook.last_good_loss
                break
        
        # Replace NaN loss with detached fallback (no gradients = true no-op)
        if 'loss' in outputs and outputs['loss'] is not None:
            fallback_tensor = torch.tensor(
                last_good_loss, 
                device=outputs['loss'].device,
                dtype=outputs['loss'].dtype
                # NOTE: No requires_grad=True - this makes it detached
            )
            outputs['loss'] = fallback_tensor
        
        # Also fix individual loss components with detached tensors
        self._fix_loss_components(outputs, last_good_loss)
        
        # Log recovery with state info
        if self.consecutive_nans <= 5 or self.consecutive_nans % self.log_interval == 0:
            runner.logger.warning(
                f"🔄 NaN Recovery at iteration {batch_idx}: "
                f"Using last good loss {last_good_loss:.4f} from iteration {last_good_iteration}. "
                f"Consecutive NaNs: {self.consecutive_nans}, Total: {self.total_nans}"
            )
        
        # Reset training state if too many consecutive NaNs
        if self.consecutive_nans >= self.max_consecutive_nans:
            self._reset_nan_state(runner, last_good_iteration)
    
    def _reset_nan_state(self, runner: Runner, last_good_iteration: int) -> None:
        """Reset training state when too many consecutive NaNs."""
        runner.logger.error(
            f"🔄 Too many consecutive NaN losses ({self.consecutive_nans}). "
            f"Resetting to last good state from iteration {last_good_iteration}"
        )
        
        try:
            # Clear model gradients
            if hasattr(runner.model, 'zero_grad'):
                runner.model.zero_grad()
            
            # Clear CUDA cache
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            
            # Reset consecutive counter
            self.consecutive_nans = 0
            
            runner.logger.info("✅ NaN state reset. Resuming training...")
            
        except Exception as e:
            runner.logger.error(f"❌ Failed to reset NaN state: {e}")
    
    def _fix_loss_components(self, outputs: Dict[str, Any], fallback_loss: Optional[float] = None) -> None:
        """Fix ALL loss components with detached tensors (no gradients)."""
        if fallback_loss is None:
            fallback_loss = self.last_valid_loss
            
        fallback_small = max(0.01, fallback_loss * 0.1)  # Ensure non-zero minimum
        
        # Fix ALL tensors with 'loss' in the key name using detached tensors
        for key, value in outputs.items():
            if isinstance(value, torch.Tensor) and 'loss' in key.lower():
                if not torch.isfinite(value).all():
                    # Create detached replacement tensor (no gradients)
                    replacement = torch.tensor(
                        fallback_small,
                        device=value.device,
                        dtype=value.dtype
                        # NOTE: No requires_grad=True - detached for true no-op
                    )
                    outputs[key] = replacement
                    print(f"   🔧 Fixed {key}: non-finite -> detached {fallback_small:.4f}")
        
        # Also fix any plain Python scalar values that might be NaN/Inf
        for key, value in list(outputs.items()):
            if isinstance(value, (int, float)) and 'loss' in key.lower():
                if not math.isfinite(value):
                    outputs[key] = fallback_small
                    print(f"   🔧 Fixed scalar {key}: {value} -> {fallback_small:.4f}")
    
    def after_train_epoch(self, runner: Runner) -> None:
        """Summary statistics after each epoch."""
        if self.total_nans > 0:
            runner.logger.info(
                f"📊 NaN Recovery Summary for Epoch: "
                f"{self.total_nans} NaN losses recovered. "
                f"Training continued successfully."
            )
            
            # Reset for next epoch
            self.consecutive_nans = 0
            self.total_nans = 0
            self.nan_iterations.clear()
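

# ---------------------------------------------------------------------------
# Usage sketch (assumption: a standard MMEngine/MMDetection config-based run).
# Because the class is registered through @HOOKS.register_module(), it can be
# enabled from a training config via `custom_hooks`; the values below are the
# defaults from __init__, not tuned recommendations.
#
#   custom_hooks = [
#       dict(
#           type='NanRecoveryHook',
#           fallback_loss=0.5,
#           max_consecutive_nans=100,
#           log_interval=50,
#       ),
#   ]
# ---------------------------------------------------------------------------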