# progressive_loss_hook.py - Progressive Loss Switching Hook for Cascade R-CNN
import torch | |
from mmengine.hooks import Hook | |
from mmdet.registry import HOOKS | |
from mmdet.models.losses import SmoothL1Loss, GIoULoss, CIoULoss, DIoULoss | |
class ProgressiveLossHook(Hook):
    """Progressive loss switching hook for Cascade R-CNN.

    Training starts with whatever bbox loss the config assigned to each
    stage (SmoothL1Loss is the intended starting point).  At the first epoch
    whose index is ``>= switch_epoch`` the bbox loss of stage 3
    (``roi_head.bbox_head[2]``) is replaced by GIoU/CIoU/DIoU loss and
    ``reg_decoded_bbox`` is enabled on that head.  The pre-switch loss and
    flag are remembered so they can be restored if NaN losses appear.

    Args:
        switch_epoch (int): Epoch at which stage 3 switches to the target
            loss.
        target_loss_type (str): Target loss type for stage 3
            ('GIoULoss', 'CIoULoss', or 'DIoULoss').
        loss_weight (float): Loss weight for the new loss function.
        warmup_epochs (int): Number of epochs, counted from ``switch_epoch``,
            during which the per-stage loss configuration is logged at the
            start of each epoch.  No blending of losses is performed.
        monitor_stage_weights (bool): Whether to log the stage loss
            configuration and the per-stage bbox loss values.
        nan_detection (bool): Whether to enable NaN detection and rollback.
        max_nan_tolerance (int): Number of consecutive NaN losses that
            triggers the rollback.
    """

    def __init__(self,
                 switch_epoch=5,
                 target_loss_type='GIoULoss',
                 loss_weight=1.0,
                 warmup_epochs=2,
                 monitor_stage_weights=True,
                 nan_detection=False,
                 max_nan_tolerance=5):
        super().__init__()
        self.switch_epoch = switch_epoch
        self.target_loss_type = target_loss_type
        self.loss_weight = loss_weight
        self.warmup_epochs = warmup_epochs
        self.monitor_stage_weights = monitor_stage_weights
        self.nan_detection = nan_detection
        self.max_nan_tolerance = max_nan_tolerance
        # True only once the stage 3 loss has actually been replaced.
        self.switched = False
        # Latched after a failed switch so the error is not repeated
        # at the start of every following epoch.
        self.switch_failed = False
        # Pre-switch state of the stage 3 head, captured for rollback.
        self.original_loss = None
        self.original_reg_decoded_bbox = None
        self.consecutive_nans = 0
        self.rollback_performed = False

    @staticmethod
    def _get_bbox_heads(runner):
        """Return ``roi_head.bbox_head`` of the model, unwrapping ``.module``."""
        model = runner.model
        # Wrapped models (e.g. DDP) expose the real model as ``.module``.
        if hasattr(model, 'module'):
            model = model.module
        return model.roi_head.bbox_head

    @staticmethod
    def _as_scalar(value):
        """Return ``value`` as a Python number, or None if it is not scalar."""
        if isinstance(value, (int, float)):
            return value
        if isinstance(value, torch.Tensor) and value.numel() == 1:
            return value.item()
        return None

    def _build_target_loss(self):
        """Instantiate the configured target loss.

        Raises:
            ValueError: If ``target_loss_type`` is not a supported name.
        """
        loss_classes = {
            'GIoULoss': GIoULoss,
            'CIoULoss': CIoULoss,
            'DIoULoss': DIoULoss,
        }
        try:
            loss_cls = loss_classes[self.target_loss_type]
        except KeyError:
            raise ValueError(
                f"Unsupported target loss type: {self.target_loss_type}"
            ) from None
        return loss_cls(loss_weight=self.loss_weight)

    def _benefit_messages(self):
        """Return the informational log lines for the configured loss."""
        messages = {
            'CIoULoss': (
                "π― CIoU Loss Benefits for Data Points:",
                " β’ Directly optimizes center point distance",
                " β’ Enforces aspect ratio consistency (square-ish data points)",
                " β’ Better convergence for small objects",
                " β’ Most complete bounding box quality metric",
            ),
            'DIoULoss': (
                "π― DIoU Loss Benefits for Data Points:",
                " β’ Directly optimizes center point distance",
                " β’ Better convergence for small objects",
                " β’ More precise localization for data points",
            ),
            'GIoULoss': (
                "π― GIoU Loss Benefits:",
                " β’ Improved IoU-based optimization",
                " β’ Better than standard IoU loss",
            ),
        }
        return messages.get(self.target_loss_type, ())

    def before_train_epoch(self, runner):
        """Switch the stage 3 loss when due, or log the loss configuration."""
        current_epoch = runner.epoch
        if current_epoch < self.switch_epoch:
            return

        if not self.switched:
            if self.switch_failed:
                return
            # Only claim success when the loss was really replaced;
            # after_train_iter keys its NaN monitoring on ``switched``.
            if self._switch_stage2_loss(runner):
                self.switched = True
                runner.logger.info(
                    f"Epoch {current_epoch}: Switched Stage 3 loss to {self.target_loss_type}")
            else:
                self.switch_failed = True
            return

        # Already switched: monitor during the warmup window.
        in_warmup = current_epoch < self.switch_epoch + self.warmup_epochs
        if in_warmup and self.monitor_stage_weights:
            self._log_loss_info(runner, current_epoch)

    def _switch_stage2_loss(self, runner):
        """Switch the stage 3 bbox loss (head index 2) to the target loss.

        Failures are logged rather than raised so training can continue
        with the original loss.

        Returns:
            bool: True if the loss was replaced, False if anything failed.
        """
        try:
            # Stage 3 is bbox_head[2] - the final refinement stage.
            head = self._get_bbox_heads(runner)[2]
            # Build the new loss first so an unsupported type leaves
            # both the head and the stored rollback state untouched.
            new_loss = self._build_target_loss()

            # Remember the pre-switch configuration for a later rollback.
            self.original_loss = head.loss_bbox
            self.original_reg_decoded_bbox = getattr(
                head, 'reg_decoded_bbox', False)

            for line in self._benefit_messages():
                runner.logger.info(line)

            # Enable decoded bbox regression for IoU losses, then swap.
            head.reg_decoded_bbox = True
            head.loss_bbox = new_loss
            runner.logger.info(
                f"Progressive Loss Switch: Stage 3 changed from "
                f"{type(self.original_loss).__name__} to {self.target_loss_type}")
        except Exception as e:
            runner.logger.error(f"Failed to switch loss function: {e}")
            return False
        return True

    def _log_loss_info(self, runner, epoch):
        """Log the bbox loss type and weight of every stage."""
        try:
            loss_info = {}
            for i, head in enumerate(self._get_bbox_heads(runner)):
                loss_type = type(head.loss_bbox).__name__
                loss_weight = head.loss_bbox.loss_weight
                loss_info[f'stage_{i+1}'] = f"{loss_type}(w={loss_weight})"
            runner.logger.info(f"Epoch {epoch} Loss Configuration: {loss_info}")
        except Exception as e:
            runner.logger.warning(f"Could not log loss info: {e}")

    def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
        """Detect NaN losses after the switch and log per-stage bbox losses.

        Nothing happens until the loss has been switched.  ``outputs`` is
        expected to be a dict holding the total loss under ``'loss'``; the
        per-stage values are read from ``outputs['log_vars']`` when that
        key exists and from ``outputs`` itself otherwise.
        """
        if not self.switched or not isinstance(outputs, dict):
            return

        # NaN detection and rollback logic.
        if self.nan_detection and not self.rollback_performed:
            total_loss = outputs.get('loss', None)
            if total_loss is not None:
                # as_tensor accepts both tensors and plain Python floats.
                loss_tensor = torch.as_tensor(total_loss)
                if torch.isnan(loss_tensor).any():
                    self.consecutive_nans += 1
                    runner.logger.warning(f"π¨ NaN detected in total loss! Consecutive: {self.consecutive_nans}/{self.max_nan_tolerance}")
                    if self.consecutive_nans >= self.max_nan_tolerance:
                        restored_name = self._rollback_loss(runner)
                        self.consecutive_nans = 0
                        # Only record the rollback if it really happened;
                        # otherwise it is retried after the next NaN streak.
                        if restored_name is not None:
                            self.rollback_performed = True
                            runner.logger.error(f"π EMERGENCY ROLLBACK: Switched back to {restored_name} due to {self.max_nan_tolerance} consecutive NaN losses")
                        return
                elif torch.isfinite(loss_tensor).all():
                    # Reset NaN counter on successful iteration.
                    self.consecutive_nans = 0

        # Log every 100 iterations to avoid spam.  Checking this first also
        # avoids converting tensors to Python numbers on every iteration.
        if not self.monitor_stage_weights or runner.iter % 100 != 0:
            return

        log_vars = outputs.get('log_vars', outputs)
        if not isinstance(log_vars, dict):
            return

        stage_losses = {}
        for key, value in log_vars.items():
            if 'loss_bbox' not in key:
                continue
            scalar = self._as_scalar(value)
            if scalar is not None:
                stage_losses[key] = scalar

        if stage_losses:
            loss_summary = ", ".join(f"{k}: {v:.4f}" for k, v in stage_losses.items())
            runner.logger.info(f"Stage Losses - {loss_summary}")

    def after_train_epoch(self, runner):
        """Log the NaN status at the end of an epoch (counter is not reset)."""
        if not (self.nan_detection and self.switched):
            return
        if self.consecutive_nans > 0:
            runner.logger.warning(f"Epoch {runner.epoch} completed with {self.consecutive_nans} NaN occurrences")
        else:
            runner.logger.info(f"Epoch {runner.epoch} completed successfully with {self.target_loss_type}")

    def _rollback_loss(self, runner):
        """Restore the pre-switch bbox loss on stage 3.

        The loss object and ``reg_decoded_bbox`` value captured at switch
        time are put back.  If no loss was captured, a default
        ``SmoothL1Loss(beta=1.0, loss_weight=1.0)`` with
        ``reg_decoded_bbox=False`` is used instead.

        Returns:
            str | None: Class name of the restored loss, or None if the
            rollback failed (the failure is logged, not raised).
        """
        try:
            head = self._get_bbox_heads(runner)[2]
            if self.original_loss is not None:
                restored_loss = self.original_loss
                decoded = bool(self.original_reg_decoded_bbox)
            else:
                restored_loss = SmoothL1Loss(beta=1.0, loss_weight=1.0)
                # SmoothL1 regresses encoded deltas, not decoded boxes.
                decoded = False
            head.loss_bbox = restored_loss
            head.reg_decoded_bbox = decoded
            restored_name = type(restored_loss).__name__
            runner.logger.info(f"β Successfully rolled back Stage 3 from {self.target_loss_type} to {restored_name}")
        except Exception as e:
            runner.logger.error(f"β Failed to rollback loss function: {e}")
            return None
        return restored_name
class AdaptiveLossHook(Hook):
    """Adaptive variant that switches the stage 3 loss based on IoU metrics.

    Every ``check_interval`` iterations, once ``runner.epoch`` has reached
    ``min_epoch``, the logged values whose key contains ``'iou'`` are
    averaged and appended to a short history.  When the mean of the last
    five history entries reaches ``min_avg_iou``, the bbox loss of stage 3
    (``roi_head.bbox_head[2]``) is replaced by the target IoU loss.

    Args:
        min_epoch (int): Earliest epoch at which a switch may happen.
        min_avg_iou (float): Threshold for the mean of the last five
            recorded IoU averages.
        target_loss_type (str): 'GIoULoss', 'CIoULoss', or 'DIoULoss'.
        loss_weight (float): Loss weight for the new loss function.
        check_interval (int): Check the metrics every this many iterations.
    """

    def __init__(self,
                 min_epoch=3,
                 min_avg_iou=0.4,
                 target_loss_type='GIoULoss',
                 loss_weight=1.0,
                 check_interval=100):
        super().__init__()
        self.min_epoch = min_epoch
        self.min_avg_iou = min_avg_iou
        self.target_loss_type = target_loss_type
        self.loss_weight = loss_weight
        self.check_interval = check_interval
        # True only once the stage 3 loss has actually been replaced.
        self.switched = False
        # Latched after a failed switch so it is not retried (and the
        # error re-logged) at every following check.
        self.switch_failed = False
        self.iou_history = []

    @staticmethod
    def _as_scalar(value):
        """Return ``value`` as a Python number, or None if it is not scalar."""
        if isinstance(value, (int, float)):
            return value
        if isinstance(value, torch.Tensor) and value.numel() == 1:
            return value.item()
        return None

    def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
        """Track IoU metrics and switch the loss once they are high enough.

        Metrics are read from ``outputs['log_vars']`` when that key exists
        and from ``outputs`` itself otherwise.
        """
        if self.switched or self.switch_failed:
            return
        if runner.epoch < self.min_epoch:
            return
        if runner.iter % self.check_interval != 0:
            return
        if not outputs or not isinstance(outputs, dict):
            return

        log_vars = outputs.get('log_vars', outputs)
        if not isinstance(log_vars, dict):
            return

        # Look for any IoU-related metrics.
        iou_metrics = []
        for key, value in log_vars.items():
            if 'iou' not in key.lower():
                continue
            scalar = self._as_scalar(value)
            if scalar is not None:
                iou_metrics.append(scalar)
        if not iou_metrics:
            return

        self.iou_history.append(sum(iou_metrics) / len(iou_metrics))
        # Keep only the ten most recent entries.
        del self.iou_history[:-10]

        if len(self.iou_history) < 5:
            return
        recent_iou = sum(self.iou_history[-5:]) / 5
        if recent_iou < self.min_avg_iou:
            return

        # Only report the switch when the loss was really replaced.
        if self._switch_stage2_loss(runner):
            self.switched = True
            runner.logger.info(
                f"Adaptive switch at epoch {runner.epoch}, iter {runner.iter}: "
                f"avg IoU {recent_iou:.3f} >= {self.min_avg_iou}")
        else:
            self.switch_failed = True

    def _switch_stage2_loss(self, runner):
        """Switch the stage 3 bbox loss (head index 2) to the target loss.

        Failures are logged rather than raised so training can continue
        with the original loss.

        Returns:
            bool: True if the loss was replaced, False if anything failed.
        """
        loss_classes = {
            'GIoULoss': GIoULoss,
            'CIoULoss': CIoULoss,
            'DIoULoss': DIoULoss,
        }
        try:
            model = runner.model
            # Wrapped models (e.g. DDP) expose the real model as ``.module``.
            if hasattr(model, 'module'):
                model = model.module
            head = model.roi_head.bbox_head[2]

            if self.target_loss_type not in loss_classes:
                raise ValueError(f"Unsupported target loss type: {self.target_loss_type}")
            new_loss = loss_classes[self.target_loss_type](
                loss_weight=self.loss_weight)

            # Enable decoded bbox regression for IoU losses, then swap.
            head.reg_decoded_bbox = True
            head.loss_bbox = new_loss
            runner.logger.info(f"Adaptive Loss Switch: Stage 3 β {self.target_loss_type}")
        except Exception as e:
            runner.logger.error(f"Failed to switch loss function: {e}")
            return False
        return True