Spaces:
Sleeping
Sleeping
File size: 13,484 Bytes
eb4d305 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 |
# progressive_loss_hook.py - Progressive Loss Switching Hook for Cascade R-CNN
import torch
from mmengine.hooks import Hook
from mmdet.registry import HOOKS
from mmdet.models.losses import SmoothL1Loss, GIoULoss, CIoULoss, DIoULoss
@HOOKS.register_module()
class ProgressiveLossHook(Hook):
    """Progressive loss switching hook for Cascade R-CNN.

    Training starts with SmoothL1Loss for all stages; at ``switch_epoch``
    the final refinement stage (stage 3, ``roi_head.bbox_head[2]``) is
    switched to an IoU-based loss (GIoU/CIoU/DIoU). Optional NaN detection
    rolls the stage back to SmoothL1Loss if training diverges.

    Args:
        switch_epoch (int): Epoch at which stage 3 switches from SmoothL1
            to the target loss.
        target_loss_type (str): Target loss type for stage 3; one of
            ``'GIoULoss'``, ``'CIoULoss'`` or ``'DIoULoss'``.
        loss_weight (float): Loss weight for the new loss function.
        warmup_epochs (int): Number of epochs after the switch during which
            the per-stage loss configuration is logged (monitoring only;
            no actual loss blending is performed).
        monitor_stage_weights (bool): Whether to log stage loss weights.
        nan_detection (bool): Whether to enable NaN detection and rollback.
        max_nan_tolerance (int): Maximum consecutive NaN losses before
            rollback is triggered.
    """

    # Supported IoU-based target losses. All of them regress decoded
    # (absolute) boxes, so ``reg_decoded_bbox`` must be enabled on switch.
    _TARGET_LOSSES = {
        'GIoULoss': GIoULoss,
        'CIoULoss': CIoULoss,
        'DIoULoss': DIoULoss,
    }

    def __init__(self,
                 switch_epoch=5,
                 target_loss_type='GIoULoss',
                 loss_weight=1.0,
                 warmup_epochs=2,
                 monitor_stage_weights=True,
                 nan_detection=False,
                 max_nan_tolerance=5):
        super().__init__()
        self.switch_epoch = switch_epoch
        self.target_loss_type = target_loss_type
        self.loss_weight = loss_weight
        self.warmup_epochs = warmup_epochs
        self.monitor_stage_weights = monitor_stage_weights
        self.nan_detection = nan_detection
        self.max_nan_tolerance = max_nan_tolerance
        # Runtime state.
        self.switched = False
        self.original_loss = None
        self.consecutive_nans = 0
        self.rollback_performed = False

    @staticmethod
    def _get_bbox_heads(model):
        """Return the cascade bbox heads, unwrapping a DDP module if needed."""
        if hasattr(model, 'module'):
            model = model.module
        return model.roi_head.bbox_head

    def before_train_epoch(self, runner):
        """Switch the stage-3 loss at ``switch_epoch``; log during warmup."""
        current_epoch = runner.epoch
        if current_epoch >= self.switch_epoch and not self.switched:
            # BUGFIX: only mark as switched when the switch actually
            # succeeded, so a transient failure is retried next epoch
            # instead of being silently skipped for the rest of training.
            if self._switch_stage3_loss(runner):
                self.switched = True
                runner.logger.info(
                    f"Epoch {current_epoch}: Switched Stage 3 loss to {self.target_loss_type}")
        elif self.switch_epoch <= current_epoch < self.switch_epoch + self.warmup_epochs:
            # Warmup window: only observe and log the configuration.
            if self.monitor_stage_weights:
                self._log_loss_info(runner, current_epoch)

    def _switch_stage3_loss(self, runner):
        """Switch the stage-3 bbox loss from SmoothL1 to the target loss.

        Returns:
            bool: True if the switch succeeded, False otherwise.
        """
        try:
            # Stage 3 bbox head (index 2) - final refinement stage.
            bbox_head_stage3 = self._get_bbox_heads(runner.model)[2]
            # Keep the original loss so logging/rollback can reference it.
            self.original_loss = bbox_head_stage3.loss_bbox
            loss_cls = self._TARGET_LOSSES.get(self.target_loss_type)
            if loss_cls is None:
                raise ValueError(
                    f"Unsupported target loss type: {self.target_loss_type}")
            new_loss = loss_cls(loss_weight=self.loss_weight)
            # IoU-based losses operate on decoded (absolute) boxes.
            bbox_head_stage3.reg_decoded_bbox = True
            self._log_loss_benefits(runner)
            bbox_head_stage3.loss_bbox = new_loss
            runner.logger.info(
                f"Progressive Loss Switch: Stage 3 changed from "
                f"{type(self.original_loss).__name__} to {self.target_loss_type}")
            return True
        except Exception as e:
            runner.logger.error(f"Failed to switch loss function: {e}")
            return False

    def _log_loss_benefits(self, runner):
        """Log the motivation for the chosen target loss type."""
        if self.target_loss_type == 'CIoULoss':
            runner.logger.info("\U0001F3AF CIoU Loss Benefits for Data Points:")
            runner.logger.info("   \u2022 Directly optimizes center point distance")
            runner.logger.info("   \u2022 Enforces aspect ratio consistency (square-ish data points)")
            runner.logger.info("   \u2022 Better convergence for small objects")
            runner.logger.info("   \u2022 Most complete bounding box quality metric")
        elif self.target_loss_type == 'DIoULoss':
            runner.logger.info("\U0001F3AF DIoU Loss Benefits for Data Points:")
            runner.logger.info("   \u2022 Directly optimizes center point distance")
            runner.logger.info("   \u2022 Better convergence for small objects")
            runner.logger.info("   \u2022 More precise localization for data points")
        elif self.target_loss_type == 'GIoULoss':
            runner.logger.info("\U0001F3AF GIoU Loss Benefits:")
            runner.logger.info("   \u2022 Improved IoU-based optimization")
            runner.logger.info("   \u2022 Better than standard IoU loss")

    def _log_loss_info(self, runner, epoch):
        """Log the current per-stage bbox loss configuration."""
        try:
            bbox_heads = self._get_bbox_heads(runner.model)
            loss_info = {
                f'stage_{i + 1}':
                f"{type(head.loss_bbox).__name__}(w={head.loss_bbox.loss_weight})"
                for i, head in enumerate(bbox_heads)
            }
            runner.logger.info(f"Epoch {epoch} Loss Configuration: {loss_info}")
        except Exception as e:
            runner.logger.warning(f"Could not log loss info: {e}")

    def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
        """Monitor loss values during training and detect NaN."""
        if not (self.switched and outputs is not None and isinstance(outputs, dict)):
            return
        # NaN detection and rollback logic.
        if self.nan_detection and not self.rollback_performed:
            total_loss = outputs.get('loss', None)
            if total_loss is not None and torch.isnan(total_loss):
                self.consecutive_nans += 1
                runner.logger.warning(
                    f"\U0001F6A8 NaN detected in total loss! "
                    f"Consecutive: {self.consecutive_nans}/{self.max_nan_tolerance}")
                if self.consecutive_nans >= self.max_nan_tolerance:
                    self._rollback_loss(runner)
                    self.consecutive_nans = 0
                    self.rollback_performed = True
                    runner.logger.error(
                        f"\U0001F504 EMERGENCY ROLLBACK: Switched back to SmoothL1Loss "
                        f"due to {self.max_nan_tolerance} consecutive NaN losses")
                    return
            elif total_loss is not None and torch.isfinite(total_loss):
                # Reset the NaN counter on a healthy iteration.
                self.consecutive_nans = 0
        # Log individual stage bbox losses if present in the log vars.
        log_vars = outputs.get('log_vars', {})
        stage_losses = {k: v for k, v in log_vars.items()
                        if 'loss_bbox' in k and isinstance(v, (int, float))}
        if stage_losses and self.monitor_stage_weights:
            # Log every 100 iterations to avoid spamming the log.
            if runner.iter % 100 == 0:
                loss_summary = ", ".join(
                    f"{k}: {v:.4f}" for k, v in stage_losses.items())
                runner.logger.info(f"Stage Losses - {loss_summary}")

    def after_train_epoch(self, runner):
        """Check epoch completion and report NaN status."""
        if self.nan_detection and self.switched:
            if self.consecutive_nans > 0:
                runner.logger.warning(
                    f"Epoch {runner.epoch} completed with {self.consecutive_nans} NaN occurrences")
            else:
                runner.logger.info(
                    f"Epoch {runner.epoch} completed successfully with {self.target_loss_type}")

    def _rollback_loss(self, runner):
        """Roll stage 3 back to SmoothL1Loss after repeated NaN losses."""
        try:
            bbox_head_stage3 = self._get_bbox_heads(runner.model)[2]
            bbox_head_stage3.loss_bbox = SmoothL1Loss(beta=1.0, loss_weight=1.0)
            # SmoothL1 regresses encoded deltas, not decoded boxes.
            bbox_head_stage3.reg_decoded_bbox = False
            # BUGFIX: the original f-string was split across two source
            # lines (a syntax error); emit it as one message.
            runner.logger.info(
                f"\u2705 Successfully rolled back Stage 3 from "
                f"{self.target_loss_type} to SmoothL1Loss")
        except Exception as e:
            runner.logger.error(f"\u274C Failed to rollback loss function: {e}")
@HOOKS.register_module()
class AdaptiveLossHook(Hook):
    """Adaptive loss-switching hook driven by training stability metrics.

    Samples IoU-related entries from the training log every
    ``check_interval`` iterations and switches the final cascade stage
    (stage 3, ``roi_head.bbox_head[2]``) to an IoU-based loss once the
    recent average IoU reaches ``min_avg_iou``.

    Args:
        min_epoch (int): Earliest epoch at which switching is allowed.
        min_avg_iou (float): Minimum average IoU over the last 5 samples
            required to trigger the switch.
        target_loss_type (str): One of ``'GIoULoss'``, ``'CIoULoss'`` or
            ``'DIoULoss'``.
        loss_weight (float): Loss weight for the new loss function.
        check_interval (int): Sample the metrics every N iterations.
    """

    # Supported IoU-based target losses; all require decoded-bbox regression.
    _TARGET_LOSSES = {
        'GIoULoss': GIoULoss,
        'CIoULoss': CIoULoss,
        'DIoULoss': DIoULoss,
    }

    def __init__(self,
                 min_epoch=3,
                 min_avg_iou=0.4,
                 target_loss_type='GIoULoss',
                 loss_weight=1.0,
                 check_interval=100):
        super().__init__()
        self.min_epoch = min_epoch
        self.min_avg_iou = min_avg_iou
        self.target_loss_type = target_loss_type
        self.loss_weight = loss_weight
        self.check_interval = check_interval
        # Runtime state.
        self.switched = False
        self.iou_history = []

    def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
        """Sample IoU metrics and switch the loss once training is stable."""
        if (self.switched or runner.epoch < self.min_epoch
                or runner.iter % self.check_interval != 0):
            return
        if not (outputs and isinstance(outputs, dict)):
            return
        log_vars = outputs.get('log_vars', {})
        # Collect any IoU-related scalar metrics from the log vars.
        iou_metrics = [v for k, v in log_vars.items()
                       if 'iou' in k.lower() and isinstance(v, (int, float))]
        if not iou_metrics:
            return
        self.iou_history.append(sum(iou_metrics) / len(iou_metrics))
        # Keep only the most recent samples.
        if len(self.iou_history) > 10:
            self.iou_history.pop(0)
        if len(self.iou_history) < 5:
            return
        recent_iou = sum(self.iou_history[-5:]) / 5
        if recent_iou >= self.min_avg_iou:
            # BUGFIX: only mark as switched when the switch succeeded, so
            # a swallowed failure inside the switch is retried later.
            if self._switch_stage3_loss(runner):
                self.switched = True
                runner.logger.info(
                    f"Adaptive switch at epoch {runner.epoch}, iter {runner.iter}: "
                    f"avg IoU {recent_iou:.3f} >= {self.min_avg_iou}")

    def _switch_stage3_loss(self, runner):
        """Switch the stage-3 bbox loss (same logic as ProgressiveLossHook).

        Returns:
            bool: True if the switch succeeded, False otherwise.
        """
        try:
            model = runner.model
            # Unwrap a DDP module if present.
            if hasattr(model, 'module'):
                model = model.module
            bbox_head_stage3 = model.roi_head.bbox_head[2]
            loss_cls = self._TARGET_LOSSES.get(self.target_loss_type)
            if loss_cls is None:
                raise ValueError(
                    f"Unsupported target loss type: {self.target_loss_type}")
            # IoU-based losses operate on decoded (absolute) boxes.
            bbox_head_stage3.reg_decoded_bbox = True
            bbox_head_stage3.loss_bbox = loss_cls(loss_weight=self.loss_weight)
            runner.logger.info(
                f"Adaptive Loss Switch: Stage 3 \u2192 {self.target_loss_type}")
            return True
        except Exception as e:
            runner.logger.error(f"Failed to switch loss function: {e}")
            return False