File size: 13,484 Bytes
eb4d305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
# progressive_loss_hook.py - Progressive Loss Switching Hook for Cascade R-CNN
import torch
from mmengine.hooks import Hook
from mmdet.registry import HOOKS
from mmdet.models.losses import SmoothL1Loss, GIoULoss, CIoULoss, DIoULoss

@HOOKS.register_module()
class ProgressiveLossHook(Hook):
    """Progressive loss-switching hook for Cascade R-CNN.

    Training starts with the configured (typically SmoothL1) bbox loss for
    all cascade stages.  At ``switch_epoch`` the final refinement stage
    (``roi_head.bbox_head[2]``, logged as "Stage 3") is switched to an
    IoU-based loss (GIoU/CIoU/DIoU).  Optionally, non-finite total losses
    after the switch are detected and the original loss is restored.

    Args:
        switch_epoch (int): Epoch at which the final stage switches from
            SmoothL1 to the target loss.
        target_loss_type (str): Target loss for the final stage; one of
            ``'GIoULoss'``, ``'CIoULoss'`` or ``'DIoULoss'``.
        loss_weight (float): Loss weight for the new loss function.
        warmup_epochs (int): Number of epochs after the switch during which
            the loss configuration is logged once per epoch.
        monitor_stage_weights (bool): Whether to log per-stage loss info.
        nan_detection (bool): Whether to enable NaN/Inf detection and
            automatic rollback.
        max_nan_tolerance (int): Maximum consecutive non-finite total losses
            tolerated before rolling back to the original loss.
    """

    # Supported IoU-style target losses.  All of them operate on decoded
    # boxes, so ``reg_decoded_bbox`` must be enabled when one is installed.
    _IOU_LOSSES = {
        'GIoULoss': GIoULoss,
        'CIoULoss': CIoULoss,
        'DIoULoss': DIoULoss,
    }

    # Loss-specific rationale logged once at switch time: (header, bullets).
    _LOSS_BENEFITS = {
        'CIoULoss': ("🎯 CIoU Loss Benefits for Data Points:", (
            "   • Directly optimizes center point distance",
            "   • Enforces aspect ratio consistency (square-ish data points)",
            "   • Better convergence for small objects",
            "   • Most complete bounding box quality metric")),
        'DIoULoss': ("🎯 DIoU Loss Benefits for Data Points:", (
            "   • Directly optimizes center point distance",
            "   • Better convergence for small objects",
            "   • More precise localization for data points")),
        'GIoULoss': ("🎯 GIoU Loss Benefits:", (
            "   • Improved IoU-based optimization",
            "   • Better than standard IoU loss")),
    }

    def __init__(self,
                 switch_epoch=5,
                 target_loss_type='GIoULoss',
                 loss_weight=1.0,
                 warmup_epochs=2,
                 monitor_stage_weights=True,
                 nan_detection=False,
                 max_nan_tolerance=5):
        super().__init__()
        self.switch_epoch = switch_epoch
        self.target_loss_type = target_loss_type
        self.loss_weight = loss_weight
        self.warmup_epochs = warmup_epochs
        self.monitor_stage_weights = monitor_stage_weights
        self.nan_detection = nan_detection
        self.max_nan_tolerance = max_nan_tolerance
        self.switched = False            # True once the loss has been swapped
        self.original_loss = None        # saved at switch time for rollback
        self.consecutive_nans = 0        # consecutive non-finite total losses
        self.rollback_performed = False  # rollback happens at most once

    @staticmethod
    def _final_bbox_head(model):
        """Return the final-stage bbox head (index 2), unwrapping DDP."""
        if hasattr(model, 'module'):  # DDP / model wrapper
            model = model.module
        return model.roi_head.bbox_head[2]

    def before_train_epoch(self, runner):
        """Switch the loss at ``switch_epoch``; log config during warmup."""
        current_epoch = runner.epoch

        if current_epoch >= self.switch_epoch and not self.switched:
            self._switch_final_stage_loss(runner)
            self.switched = True
            runner.logger.info(
                f"Epoch {current_epoch}: Switched Stage 3 loss to {self.target_loss_type}")
        elif (self.switch_epoch <= current_epoch < self.switch_epoch + self.warmup_epochs
              and self.monitor_stage_weights):
            self._log_loss_info(runner, current_epoch)

    def _switch_final_stage_loss(self, runner):
        """Replace the final stage's bbox loss with the target IoU loss.

        Best-effort: any failure is logged, never raised, so a bad config
        cannot crash the training loop.
        """
        try:
            bbox_head = self._final_bbox_head(runner.model)

            # Keep the old loss object so a NaN rollback can restore it exactly.
            self.original_loss = bbox_head.loss_bbox

            loss_cls = self._IOU_LOSSES.get(self.target_loss_type)
            if loss_cls is None:
                raise ValueError(
                    f"Unsupported target loss type: {self.target_loss_type}")
            new_loss = loss_cls(loss_weight=self.loss_weight)
            # IoU losses are computed on decoded boxes, not regression deltas.
            bbox_head.reg_decoded_bbox = True

            header, bullets = self._LOSS_BENEFITS[self.target_loss_type]
            runner.logger.info(header)
            for bullet in bullets:
                runner.logger.info(bullet)

            bbox_head.loss_bbox = new_loss

            runner.logger.info(
                f"Progressive Loss Switch: Stage 3 changed from "
                f"{type(self.original_loss).__name__} to {self.target_loss_type}")

        except Exception as e:
            runner.logger.error(f"Failed to switch loss function: {e}")

    # Backwards-compatible alias: the old name said "stage2" but the code has
    # always operated on bbox_head[2], i.e. the third (final) stage.
    _switch_stage2_loss = _switch_final_stage_loss

    def _log_loss_info(self, runner, epoch):
        """Log the loss type and weight currently installed in each stage."""
        try:
            model = runner.model
            if hasattr(model, 'module'):  # unwrap DDP
                model = model.module
            loss_info = {
                f'stage_{i + 1}':
                    f"{type(head.loss_bbox).__name__}"
                    f"(w={getattr(head.loss_bbox, 'loss_weight', 'n/a')})"
                for i, head in enumerate(model.roi_head.bbox_head)
            }
            runner.logger.info(f"Epoch {epoch} Loss Configuration: {loss_info}")
        except Exception as e:
            runner.logger.warning(f"Could not log loss info: {e}")

    def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
        """Monitor post-switch losses; detect non-finite values, roll back."""
        if not (self.switched and isinstance(outputs, dict)):
            return

        if self.nan_detection and not self.rollback_performed:
            total_loss = outputs.get('loss', None)
            if total_loss is not None:
                # Treat Inf like NaN: either one means training has diverged.
                if not bool(torch.isfinite(total_loss)):
                    self.consecutive_nans += 1
                    runner.logger.warning(
                        f"🚨 Non-finite total loss detected! Consecutive: "
                        f"{self.consecutive_nans}/{self.max_nan_tolerance}")
                    if self.consecutive_nans >= self.max_nan_tolerance:
                        self._rollback_loss(runner)
                        self.consecutive_nans = 0
                        self.rollback_performed = True
                        runner.logger.error(
                            f"🔄 EMERGENCY ROLLBACK: Switched back to SmoothL1Loss "
                            f"due to {self.max_nan_tolerance} consecutive NaN losses")
                        return
                else:
                    # A healthy iteration resets the failure streak.
                    self.consecutive_nans = 0

        # Log individual stage bbox losses every 100 iterations to avoid spam.
        if self.monitor_stage_weights and runner.iter % 100 == 0:
            log_vars = outputs.get('log_vars', {})
            stage_losses = {k: v for k, v in log_vars.items()
                            if 'loss_bbox' in k and isinstance(v, (int, float))}
            if stage_losses:
                loss_summary = ", ".join(
                    f"{k}: {v:.4f}" for k, v in stage_losses.items())
                runner.logger.info(f"Stage Losses - {loss_summary}")

    def after_train_epoch(self, runner):
        """Log a per-epoch summary of non-finite-loss occurrences."""
        if self.nan_detection and self.switched:
            if self.consecutive_nans > 0:
                runner.logger.warning(
                    f"Epoch {runner.epoch} completed with "
                    f"{self.consecutive_nans} NaN occurrences")
            else:
                runner.logger.info(
                    f"Epoch {runner.epoch} completed successfully with "
                    f"{self.target_loss_type}")

    def _rollback_loss(self, runner):
        """Restore the final stage's original loss (fallback: SmoothL1Loss)."""
        try:
            bbox_head = self._final_bbox_head(runner.model)

            # Prefer restoring the exact loss object saved at switch time;
            # fall back to a fresh SmoothL1Loss if none was recorded.
            if self.original_loss is not None:
                bbox_head.loss_bbox = self.original_loss
            else:
                bbox_head.loss_bbox = SmoothL1Loss(beta=1.0, loss_weight=1.0)
            # Delta-based losses (SmoothL1) do not use decoded boxes.
            bbox_head.reg_decoded_bbox = False

            runner.logger.info(
                f"✅ Successfully rolled back Stage 3 from "
                f"{self.target_loss_type} to SmoothL1Loss")

        except Exception as e:
            runner.logger.error(f"❌ Failed to rollback loss function: {e}")


@HOOKS.register_module()
class AdaptiveLossHook(Hook):
    """Adaptive loss-switching hook driven by training-stability metrics.

    Instead of switching at a fixed epoch, this hook samples IoU-related
    entries from ``log_vars`` every ``check_interval`` iterations and, once
    the recent average IoU clears ``min_avg_iou``, switches the final
    cascade stage (``roi_head.bbox_head[2]``) to the target IoU-based loss.

    Args:
        min_epoch (int): Earliest epoch at which switching is considered.
        min_avg_iou (float): Minimum recent average IoU required to switch.
        target_loss_type (str): ``'GIoULoss'``, ``'CIoULoss'`` or
            ``'DIoULoss'``.
        loss_weight (float): Loss weight for the new loss function.
        check_interval (int): Sample the metrics every this many iterations.
    """

    # Supported IoU-style losses; all require decoded-bbox regression.
    _IOU_LOSSES = {
        'GIoULoss': GIoULoss,
        'CIoULoss': CIoULoss,
        'DIoULoss': DIoULoss,
    }

    # Sliding-window sizes for the IoU history.
    _HISTORY_LEN = 10  # most recent samples kept
    _WINDOW = 5        # samples averaged for the switch decision

    def __init__(self,
                 min_epoch=3,
                 min_avg_iou=0.4,
                 target_loss_type='GIoULoss',
                 loss_weight=1.0,
                 check_interval=100):
        super().__init__()
        self.min_epoch = min_epoch
        self.min_avg_iou = min_avg_iou
        self.target_loss_type = target_loss_type
        self.loss_weight = loss_weight
        self.check_interval = check_interval
        self.switched = False   # True once the loss has been swapped
        self.iou_history = []   # recent IoU samples, oldest first

    def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
        """Sample IoU metrics periodically; switch once training is stable."""
        # Guard clauses: already switched, too early, or off-interval.
        if (self.switched or runner.epoch < self.min_epoch
                or runner.iter % self.check_interval != 0):
            return
        if not (outputs and isinstance(outputs, dict)):
            return

        log_vars = outputs.get('log_vars', {})
        # Collect any scalar metric whose name mentions IoU.
        iou_metrics = [v for k, v in log_vars.items()
                       if 'iou' in k.lower() and isinstance(v, (int, float))]
        if not iou_metrics:
            return

        self.iou_history.append(sum(iou_metrics) / len(iou_metrics))
        # Bound the history to a short sliding window.
        if len(self.iou_history) > self._HISTORY_LEN:
            self.iou_history.pop(0)

        if len(self.iou_history) < self._WINDOW:
            return
        recent_iou = sum(self.iou_history[-self._WINDOW:]) / self._WINDOW
        if recent_iou >= self.min_avg_iou:
            self._switch_final_stage_loss(runner)
            self.switched = True
            runner.logger.info(
                f"Adaptive switch at epoch {runner.epoch}, iter {runner.iter}: "
                f"avg IoU {recent_iou:.3f} >= {self.min_avg_iou}")

    def _switch_final_stage_loss(self, runner):
        """Replace the final stage's bbox loss with the target IoU loss.

        Best-effort: any failure is logged, never raised.
        """
        try:
            model = runner.model
            if hasattr(model, 'module'):  # unwrap DDP
                model = model.module
            bbox_head = model.roi_head.bbox_head[2]

            loss_cls = self._IOU_LOSSES.get(self.target_loss_type)
            if loss_cls is None:
                raise ValueError(
                    f"Unsupported target loss type: {self.target_loss_type}")
            # IoU losses are computed on decoded boxes, not regression deltas.
            bbox_head.reg_decoded_bbox = True
            bbox_head.loss_bbox = loss_cls(loss_weight=self.loss_weight)

            runner.logger.info(
                f"Adaptive Loss Switch: Stage 3 → {self.target_loss_type}")

        except Exception as e:
            runner.logger.error(f"Failed to switch loss function: {e}")

    # Backwards-compatible alias: the old name said "stage2" but the code has
    # always operated on bbox_head[2], i.e. the third (final) stage.
    _switch_stage2_loss = _switch_final_stage_loss