Dense-Captioning-Platform / custom_models /progressive_loss_hook.py
hanszhu's picture
build(space): initial Docker Space with Gradio app, MMDet, SAM integration
eb4d305
# progressive_loss_hook.py - Progressive Loss Switching Hook for Cascade R-CNN
import torch
from mmengine.hooks import Hook
from mmdet.registry import HOOKS
from mmdet.models.losses import SmoothL1Loss, GIoULoss, CIoULoss, DIoULoss
@HOOKS.register_module()
class ProgressiveLossHook(Hook):
    """Progressive loss switching hook for Cascade R-CNN.

    Training starts with the bbox loss configured for every stage
    (SmoothL1Loss in the intended setup). From ``switch_epoch`` onwards the
    bbox head at index 2 of ``roi_head.bbox_head`` (stage 3, the final
    refinement stage) gets its ``loss_bbox`` replaced by GIoU/CIoU/DIoU and
    ``reg_decoded_bbox`` set to ``True``.

    Args:
        switch_epoch (int): First epoch at which stage 3 is switched from its
            configured loss to the target loss.
        target_loss_type (str): Target loss type for stage 3 ('GIoULoss',
            'CIoULoss', or 'DIoULoss').
        loss_weight (float): ``loss_weight`` passed to the new loss function.
        warmup_epochs (int): Number of epochs, counted from ``switch_epoch``,
            during which the per-stage loss configuration is logged at the
            start of each epoch. No blending of the old and new loss is
            performed by this hook.
        monitor_stage_weights (bool): Whether to log per-stage loss types,
            weights and bbox-loss values.
        nan_detection (bool): Whether to enable NaN detection and rollback.
        max_nan_tolerance (int): Number of consecutive NaN total losses that
            triggers the rollback.
    """

    # Position of the stage-3 head inside ``roi_head.bbox_head``.
    _STAGE_INDEX = 2

    # Stage bbox losses are logged once every this many iterations.
    _LOG_INTERVAL = 100

    # Supported target losses, keyed by the name used in the config.
    _LOSS_CLASSES = {
        'GIoULoss': GIoULoss,
        'CIoULoss': CIoULoss,
        'DIoULoss': DIoULoss,
    }

    # Informational lines logged when switching to each loss type.
    # NOTE(review): several characters in these messages look mis-encoded;
    # they are kept verbatim because the intended glyphs cannot be confirmed.
    _LOSS_BENEFITS = {
        'CIoULoss': (
            "🎯 CIoU Loss Benefits for Data Points:",
            " β€’ Directly optimizes center point distance",
            " β€’ Enforces aspect ratio consistency (square-ish data points)",
            " β€’ Better convergence for small objects",
            " β€’ Most complete bounding box quality metric",
        ),
        'DIoULoss': (
            "🎯 DIoU Loss Benefits for Data Points:",
            " β€’ Directly optimizes center point distance",
            " β€’ Better convergence for small objects",
            " β€’ More precise localization for data points",
        ),
        'GIoULoss': (
            "🎯 GIoU Loss Benefits:",
            " β€’ Improved IoU-based optimization",
            " β€’ Better than standard IoU loss",
        ),
    }

    def __init__(self,
                 switch_epoch=5,
                 target_loss_type='GIoULoss',
                 loss_weight=1.0,
                 warmup_epochs=2,
                 monitor_stage_weights=True,
                 nan_detection=False,
                 max_nan_tolerance=5):
        super().__init__()
        self.switch_epoch = switch_epoch
        self.target_loss_type = target_loss_type
        self.loss_weight = loss_weight
        self.warmup_epochs = warmup_epochs
        self.monitor_stage_weights = monitor_stage_weights
        self.nan_detection = nan_detection
        self.max_nan_tolerance = max_nan_tolerance
        # True once the stage-3 loss has actually been replaced.
        self.switched = False
        # Loss module and ``reg_decoded_bbox`` flag found on the head before
        # the switch; used to undo the switch on rollback.
        self.original_loss = None
        self.original_reg_decoded_bbox = None
        self.consecutive_nans = 0
        self.rollback_performed = False

    @staticmethod
    def _to_scalar(value):
        """Return ``value`` as a Python number, or ``None`` if it is not one.

        Single-element tensors are unwrapped with ``.item()``; ints and floats
        are returned unchanged; anything else yields ``None``.
        """
        if isinstance(value, torch.Tensor):
            return value.item() if value.numel() == 1 else None
        if isinstance(value, (int, float)):
            return value
        return None

    def _stage3_head(self, runner):
        """Return the stage-3 bbox head, unwrapping a ``.module`` wrapper."""
        model = runner.model
        # Handle DDP-style wrappers that expose the real model as ``.module``.
        model = getattr(model, 'module', model)
        return model.roi_head.bbox_head[self._STAGE_INDEX]

    def before_train_epoch(self, runner):
        """Switch the stage-3 loss once ``switch_epoch`` is reached.

        The switch is only recorded (and announced) when it really happened;
        a failed attempt is retried at the start of the next epoch. After a
        successful switch the loss configuration is logged for the remaining
        epochs of the warmup window.
        """
        current_epoch = runner.epoch
        if current_epoch < self.switch_epoch:
            return

        if not self.switched:
            if self._switch_stage2_loss(runner):
                self.switched = True
                runner.logger.info(
                    f"Epoch {current_epoch}: Switched Stage 3 loss to {self.target_loss_type}")
            return

        # Monitor during warmup period
        in_warmup = current_epoch < self.switch_epoch + self.warmup_epochs
        if in_warmup and self.monitor_stage_weights:
            self._log_loss_info(runner, current_epoch)

    def _switch_stage2_loss(self, runner):
        """Replace the stage-3 (index 2) bbox loss with the target loss.

        The name is historical; the head that is modified is
        ``roi_head.bbox_head[2]``.

        Returns:
            bool: ``True`` if the loss was replaced, ``False`` if the attempt
            failed (the failure is logged, never raised).
        """
        loss_cls = self._LOSS_CLASSES.get(self.target_loss_type)
        if loss_cls is None:
            runner.logger.error(
                f"Failed to switch loss function: "
                f"Unsupported target loss type: {self.target_loss_type}")
            return False

        try:
            head = self._stage3_head(runner)
            # Build the new loss before touching the head so a constructor
            # failure leaves the model unchanged.
            new_loss = loss_cls(loss_weight=self.loss_weight)

            # Remember the previous configuration for rollback and logging.
            self.original_loss = head.loss_bbox
            self.original_reg_decoded_bbox = getattr(
                head, 'reg_decoded_bbox', False)

            for line in self._LOSS_BENEFITS[self.target_loss_type]:
                runner.logger.info(line)

            # Enable decoded bbox regression for IoU losses, then swap.
            head.reg_decoded_bbox = True
            head.loss_bbox = new_loss
            runner.logger.info(
                f"Progressive Loss Switch: Stage 3 changed from "
                f"{type(self.original_loss).__name__} to {self.target_loss_type}")
            return True
        except Exception as e:
            # Best effort: a failed switch must not abort training.
            runner.logger.error(f"Failed to switch loss function: {e}")
            return False

    def _log_loss_info(self, runner, epoch):
        """Log the bbox loss type and weight of every cascade stage."""
        try:
            model = getattr(runner.model, 'module', runner.model)
            loss_info = {
                f'stage_{i+1}':
                f"{type(head.loss_bbox).__name__}(w={head.loss_bbox.loss_weight})"
                for i, head in enumerate(model.roi_head.bbox_head)
            }
            runner.logger.info(f"Epoch {epoch} Loss Configuration: {loss_info}")
        except Exception as e:
            runner.logger.warning(f"Could not log loss info: {e}")

    def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
        """Watch the total loss for NaNs and periodically log stage losses.

        Does nothing until the switch has happened or when ``outputs`` is not
        a dict. Scalars may arrive as Python numbers or as single-element
        tensors, either under ``outputs['log_vars']`` or directly in
        ``outputs``.
        """
        if not self.switched or not isinstance(outputs, dict):
            return

        # NaN detection and rollback logic
        if self.nan_detection and not self.rollback_performed:
            total_loss = outputs.get('loss', None)
            if total_loss is not None:
                # as_tensor accepts both tensors and plain Python numbers.
                loss_tensor = torch.as_tensor(total_loss)
                if torch.isnan(loss_tensor).any():
                    self.consecutive_nans += 1
                    runner.logger.warning(
                        f"🚨 NaN detected in total loss! Consecutive: {self.consecutive_nans}/{self.max_nan_tolerance}")
                    if self.consecutive_nans >= self.max_nan_tolerance:
                        rolled_back = self._rollback_loss(runner)
                        self.consecutive_nans = 0
                        # Only claim the rollback when it really happened;
                        # otherwise it is attempted again after the next
                        # run of NaNs.
                        if rolled_back:
                            self.rollback_performed = True
                            runner.logger.error(
                                f"πŸ”„ EMERGENCY ROLLBACK: Switched back to SmoothL1Loss due to {self.max_nan_tolerance} consecutive NaN losses")
                        return
                elif torch.isfinite(loss_tensor).all():
                    # Reset NaN counter on successful iteration
                    self.consecutive_nans = 0

        # Log every _LOG_INTERVAL iterations to avoid spam; checking this
        # first also avoids converting tensors on iterations that log nothing.
        if not self.monitor_stage_weights or runner.iter % self._LOG_INTERVAL != 0:
            return

        log_vars = outputs.get('log_vars')
        if not isinstance(log_vars, dict):
            # No nested dict: treat ``outputs`` itself as the logged values.
            log_vars = outputs

        stage_losses = {}
        for key, value in log_vars.items():
            if 'loss_bbox' not in key:
                continue
            scalar = self._to_scalar(value)
            if scalar is not None:
                stage_losses[key] = scalar

        if stage_losses:
            loss_summary = ", ".join(
                f"{k}: {v:.4f}" for k, v in stage_losses.items())
            runner.logger.info(f"Stage Losses - {loss_summary}")

    def after_train_epoch(self, runner):
        """Log the NaN status at the end of each epoch after the switch."""
        if not (self.nan_detection and self.switched):
            return
        if self.consecutive_nans > 0:
            runner.logger.warning(
                f"Epoch {runner.epoch} completed with {self.consecutive_nans} NaN occurrences")
        else:
            runner.logger.info(
                f"Epoch {runner.epoch} completed successfully with {self.target_loss_type}")

    def _rollback_loss(self, runner):
        """Undo the switch on the stage-3 bbox head.

        Restores the loss module and ``reg_decoded_bbox`` value captured at
        switch time. If nothing was captured, falls back to a fresh
        ``SmoothL1Loss(beta=1.0, loss_weight=1.0)`` with decoded-bbox
        regression disabled.

        Returns:
            bool: ``True`` if the head was restored, ``False`` if the attempt
            failed (the failure is logged, never raised).
        """
        try:
            head = self._stage3_head(runner)
            if self.original_loss is not None:
                restored_loss = self.original_loss
                restored_decoded = bool(self.original_reg_decoded_bbox)
            else:
                restored_loss = SmoothL1Loss(beta=1.0, loss_weight=1.0)
                restored_decoded = False  # Disable decoded bbox for SmoothL1
            head.loss_bbox = restored_loss
            head.reg_decoded_bbox = restored_decoded
            runner.logger.info(
                f"βœ… Successfully rolled back Stage 3 from {self.target_loss_type} to {type(restored_loss).__name__}")
            return True
        except Exception as e:
            runner.logger.error(f"❌ Failed to rollback loss function: {e}")
            return False
@HOOKS.register_module()
class AdaptiveLossHook(Hook):
    """Adaptive variant that switches the stage-3 bbox loss on an IoU signal.

    Every ``check_interval`` iterations, once ``min_epoch`` has been reached,
    the hook averages all logged scalars whose key contains ``'iou'``
    (case-insensitive) and appends that average to a short history. When the
    mean of the five most recent entries reaches ``min_avg_iou``, the bbox
    head at index 2 of ``roi_head.bbox_head`` is switched to the target loss.

    Args:
        min_epoch (int): Earliest epoch at which monitoring starts.
        min_avg_iou (float): Threshold on the mean of the last five history
            entries that triggers the switch.
        target_loss_type (str): 'GIoULoss', 'CIoULoss' or 'DIoULoss'.
        loss_weight (float): ``loss_weight`` passed to the new loss function.
        check_interval (int): Iteration interval between checks.
    """

    # Position of the stage-3 head inside ``roi_head.bbox_head``.
    _STAGE_INDEX = 2

    # Number of recent samples averaged for the decision / kept in history.
    _WINDOW = 5
    _MAX_HISTORY = 10

    # Supported target losses, keyed by the name used in the config.
    _LOSS_CLASSES = {
        'GIoULoss': GIoULoss,
        'CIoULoss': CIoULoss,
        'DIoULoss': DIoULoss,
    }

    def __init__(self,
                 min_epoch=3,
                 min_avg_iou=0.4,
                 target_loss_type='GIoULoss',
                 loss_weight=1.0,
                 check_interval=100):
        super().__init__()
        self.min_epoch = min_epoch
        self.min_avg_iou = min_avg_iou
        self.target_loss_type = target_loss_type
        self.loss_weight = loss_weight
        self.check_interval = check_interval
        # True once the stage-3 loss has actually been replaced.
        self.switched = False
        # Most recent per-check IoU averages, oldest first.
        self.iou_history = []

    @staticmethod
    def _to_scalar(value):
        """Return ``value`` as a Python number, or ``None`` if it is not one.

        Single-element tensors are unwrapped with ``.item()``; ints and floats
        are returned unchanged; anything else yields ``None``.
        """
        if isinstance(value, torch.Tensor):
            return value.item() if value.numel() == 1 else None
        if isinstance(value, (int, float)):
            return value
        return None

    def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
        """Sample IoU-related scalars and switch the loss once they are stable.

        Scalars may arrive as Python numbers or as single-element tensors,
        either under ``outputs['log_vars']`` or directly in ``outputs``.
        """
        if self.switched or runner.epoch < self.min_epoch:
            return
        if runner.iter % self.check_interval != 0:
            return
        if not outputs or not isinstance(outputs, dict):
            return

        log_vars = outputs.get('log_vars')
        if not isinstance(log_vars, dict):
            # No nested dict: treat ``outputs`` itself as the logged values.
            log_vars = outputs

        # Look for any IoU-related metrics
        iou_metrics = []
        for key, value in log_vars.items():
            if 'iou' not in key.lower():
                continue
            scalar = self._to_scalar(value)
            if scalar is not None:
                iou_metrics.append(scalar)
        if not iou_metrics:
            return

        self.iou_history.append(sum(iou_metrics) / len(iou_metrics))
        # Keep only recent history
        del self.iou_history[:-self._MAX_HISTORY]

        if len(self.iou_history) < self._WINDOW:
            return
        recent_iou = sum(self.iou_history[-self._WINDOW:]) / self._WINDOW
        if recent_iou < self.min_avg_iou:
            return

        # Only record and announce the switch when it really happened; a
        # failed attempt is retried at the next qualifying check.
        if self._switch_stage2_loss(runner):
            self.switched = True
            runner.logger.info(
                f"Adaptive switch at epoch {runner.epoch}, iter {runner.iter}: "
                f"avg IoU {recent_iou:.3f} >= {self.min_avg_iou}")

    def _switch_stage2_loss(self, runner):
        """Replace the stage-3 (index 2) bbox loss with the target loss.

        The name is historical; the head that is modified is
        ``roi_head.bbox_head[2]``.

        Returns:
            bool: ``True`` if the loss was replaced, ``False`` if the attempt
            failed (the failure is logged, never raised).
        """
        loss_cls = self._LOSS_CLASSES.get(self.target_loss_type)
        if loss_cls is None:
            runner.logger.error(
                f"Failed to switch loss function: "
                f"Unsupported target loss type: {self.target_loss_type}")
            return False

        try:
            # Handle DDP-style wrappers that expose the real model as
            # ``.module``.
            model = getattr(runner.model, 'module', runner.model)
            head = model.roi_head.bbox_head[self._STAGE_INDEX]
            # Build the new loss before touching the head so a constructor
            # failure leaves the model unchanged.
            new_loss = loss_cls(loss_weight=self.loss_weight)
            # Enable decoded bbox regression for IoU losses, then swap.
            head.reg_decoded_bbox = True
            head.loss_bbox = new_loss
            runner.logger.info(f"Adaptive Loss Switch: Stage 3 β†’ {self.target_loss_type}")
            return True
        except Exception as e:
            # Best effort: a failed switch must not abort training.
            runner.logger.error(f"Failed to switch loss function: {e}")
            return False