|
|
|
import torch |
|
from mmengine.hooks import Hook |
|
from mmdet.registry import HOOKS |
|
from mmdet.models.losses import SmoothL1Loss, GIoULoss, CIoULoss, DIoULoss |
|
|
|
@HOOKS.register_module()
class ProgressiveLossHook(Hook):
    """Progressive loss-switching hook for Cascade R-CNN.

    Starts with SmoothL1Loss for all stages, then switches the final
    cascade stage (``bbox_head[2]``, i.e. "stage 3") to an IoU-based loss
    once training has stabilized.

    Args:
        switch_epoch (int): Epoch at which stage 3 switches from SmoothL1
            to the target loss.
        target_loss_type (str): Target loss for stage 3. One of
            'GIoULoss', 'CIoULoss' or 'DIoULoss'.
        loss_weight (float): Loss weight for the new loss function.
        warmup_epochs (int): Number of epochs after the switch during which
            the loss configuration is logged. NOTE: no actual loss blending
            is performed — this window only controls logging.
        monitor_stage_weights (bool): Whether to log per-stage bbox losses.
        nan_detection (bool): Enable non-finite loss detection and rollback.
        max_nan_tolerance (int): Maximum consecutive non-finite losses
            tolerated before rolling back to SmoothL1Loss.
        log_interval (int): Iteration interval for stage-loss logging
            (previously hard-coded to 100).
    """

    # Maps the config string to the loss class. All of these losses operate
    # on decoded (x1, y1, x2, y2) boxes, so ``reg_decoded_bbox`` must be
    # enabled on the head whenever one of them is installed.
    _LOSS_CLASSES = {
        'GIoULoss': GIoULoss,
        'CIoULoss': CIoULoss,
        'DIoULoss': DIoULoss,
    }

    def __init__(self,
                 switch_epoch=5,
                 target_loss_type='GIoULoss',
                 loss_weight=1.0,
                 warmup_epochs=2,
                 monitor_stage_weights=True,
                 nan_detection=False,
                 max_nan_tolerance=5,
                 log_interval=100):
        super().__init__()
        self.switch_epoch = switch_epoch
        self.target_loss_type = target_loss_type
        self.loss_weight = loss_weight
        self.warmup_epochs = warmup_epochs
        self.monitor_stage_weights = monitor_stage_weights
        self.nan_detection = nan_detection
        self.max_nan_tolerance = max_nan_tolerance
        self.log_interval = log_interval
        # Runtime state.
        self.switched = False              # whether the switch has happened
        self.original_loss = None          # saved loss module, for rollback
        self.consecutive_nans = 0          # consecutive non-finite losses seen
        self.rollback_performed = False    # guards against repeated rollbacks

    @staticmethod
    def _get_final_stage_head(model):
        """Return the bbox head of the final cascade stage (index 2),
        unwrapping a DDP/DataParallel ``.module`` wrapper if present."""
        if hasattr(model, 'module'):
            model = model.module
        return model.roi_head.bbox_head[2]

    def before_train_epoch(self, runner):
        """Switch the loss at ``switch_epoch``; log config during warmup."""
        current_epoch = runner.epoch

        if current_epoch >= self.switch_epoch and not self.switched:
            self._switch_stage2_loss(runner)
            self.switched = True
            runner.logger.info(
                f"Epoch {current_epoch}: Switched Stage 3 loss to {self.target_loss_type}")
        elif (self.switch_epoch <= current_epoch
              < self.switch_epoch + self.warmup_epochs):
            if self.monitor_stage_weights:
                self._log_loss_info(runner, current_epoch)

    def _switch_stage2_loss(self, runner):
        """Switch the final-stage bbox loss from SmoothL1 to the target loss.

        NOTE: the name says "stage2" but the hook operates on
        ``bbox_head[2]`` (the third cascade stage); the method name is kept
        for backward compatibility with existing configs/subclasses.
        """
        try:
            head = self._get_final_stage_head(runner.model)
            # Keep the original loss so a NaN rollback can restore it.
            self.original_loss = head.loss_bbox

            loss_cls = self._LOSS_CLASSES.get(self.target_loss_type)
            if loss_cls is None:
                raise ValueError(
                    f"Unsupported target loss type: {self.target_loss_type}")
            new_loss = loss_cls(loss_weight=self.loss_weight)
            # IoU-based losses require decoded box predictions.
            head.reg_decoded_bbox = True

            self._log_switch_benefits(runner)

            head.loss_bbox = new_loss
            runner.logger.info(
                f"Progressive Loss Switch: Stage 3 changed from "
                f"{type(self.original_loss).__name__} to {self.target_loss_type}")
        except Exception as e:
            # Best-effort: a failed switch should not kill the training run.
            runner.logger.error(f"Failed to switch loss function: {e}")

    def _log_switch_benefits(self, runner):
        """Log a short rationale for the chosen loss (informational only)."""
        if self.target_loss_type == 'CIoULoss':
            runner.logger.info("🎯 CIoU Loss Benefits for Data Points:")
            runner.logger.info("   • Directly optimizes center point distance")
            runner.logger.info("   • Enforces aspect ratio consistency (square-ish data points)")
            runner.logger.info("   • Better convergence for small objects")
            runner.logger.info("   • Most complete bounding box quality metric")
        elif self.target_loss_type == 'DIoULoss':
            runner.logger.info("🎯 DIoU Loss Benefits for Data Points:")
            runner.logger.info("   • Directly optimizes center point distance")
            runner.logger.info("   • Better convergence for small objects")
            runner.logger.info("   • More precise localization for data points")
        elif self.target_loss_type == 'GIoULoss':
            runner.logger.info("🎯 GIoU Loss Benefits:")
            runner.logger.info("   • Improved IoU-based optimization")
            runner.logger.info("   • Better than standard IoU loss")

    def _log_loss_info(self, runner, epoch):
        """Log the loss type and weight configured on every cascade stage."""
        try:
            model = runner.model
            if hasattr(model, 'module'):
                bbox_heads = model.module.roi_head.bbox_head
            else:
                bbox_heads = model.roi_head.bbox_head

            loss_info = {
                f'stage_{i + 1}':
                f"{type(head.loss_bbox).__name__}(w={head.loss_bbox.loss_weight})"
                for i, head in enumerate(bbox_heads)
            }
            runner.logger.info(f"Epoch {epoch} Loss Configuration: {loss_info}")
        except Exception as e:
            runner.logger.warning(f"Could not log loss info: {e}")

    def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
        """Monitor loss values after each iteration; detect non-finite losses."""
        if not (self.switched and outputs is not None and isinstance(outputs, dict)):
            return

        if self.nan_detection and not self.rollback_performed:
            total_loss = outputs.get('loss', None)
            if total_loss is not None:
                # Treat inf the same as NaN: both indicate a diverging loss.
                # (The original isnan-only check let inf losses slip through
                # without incrementing or resetting the counter.)
                if not torch.isfinite(total_loss):
                    self.consecutive_nans += 1
                    runner.logger.warning(
                        f"🚨 Non-finite total loss detected! "
                        f"Consecutive: {self.consecutive_nans}/{self.max_nan_tolerance}")
                    if self.consecutive_nans >= self.max_nan_tolerance:
                        self._rollback_loss(runner)
                        self.consecutive_nans = 0
                        self.rollback_performed = True
                        runner.logger.error(
                            f"🔄 EMERGENCY ROLLBACK: Switched back to SmoothL1Loss "
                            f"due to {self.max_nan_tolerance} consecutive NaN losses")
                        return
                else:
                    self.consecutive_nans = 0

        # Periodically summarize per-stage bbox losses from log_vars.
        log_vars = outputs.get('log_vars', {})
        stage_losses = {
            key: value
            for key, value in log_vars.items()
            if 'loss_bbox' in key and isinstance(value, (int, float))
        }
        if stage_losses and self.monitor_stage_weights:
            if runner.iter % self.log_interval == 0:
                loss_summary = ", ".join(
                    f"{k}: {v:.4f}" for k, v in stage_losses.items())
                runner.logger.info(f"Stage Losses - {loss_summary}")

    def after_train_epoch(self, runner):
        """Report end-of-epoch health of the switched loss."""
        if self.nan_detection and self.switched:
            if self.consecutive_nans > 0:
                runner.logger.warning(
                    f"Epoch {runner.epoch} completed with "
                    f"{self.consecutive_nans} NaN occurrences")
            else:
                runner.logger.info(
                    f"Epoch {runner.epoch} completed successfully with "
                    f"{self.target_loss_type}")

    def _rollback_loss(self, runner):
        """Roll the final stage back to its pre-switch loss (SmoothL1)."""
        try:
            head = self._get_final_stage_head(runner.model)
            # Prefer restoring the exact loss module saved at switch time;
            # fall back to a default SmoothL1Loss if nothing was saved.
            if self.original_loss is not None:
                head.loss_bbox = self.original_loss
            else:
                head.loss_bbox = SmoothL1Loss(beta=1.0, loss_weight=1.0)
            # SmoothL1 operates on encoded deltas again.
            head.reg_decoded_bbox = False

            runner.logger.info(
                f"✅ Successfully rolled back Stage 3 from "
                f"{self.target_loss_type} to SmoothL1Loss")
        except Exception as e:
            runner.logger.error(f"❌ Failed to rollback loss function: {e}")
|
|
|
|
|
@HOOKS.register_module()
class AdaptiveLossHook(Hook):
    """Adaptive loss-switching hook driven by training-stability metrics.

    Watches IoU-like entries in ``log_vars`` and switches the final cascade
    stage (``bbox_head[2]``) to an IoU-based loss once the recent average
    IoU clears a threshold.

    Args:
        min_epoch (int): Earliest epoch at which switching is considered.
        min_avg_iou (float): Minimum average IoU (over the last 5 sampled
            checks) required to trigger the switch.
        target_loss_type (str): Target loss; one of 'GIoULoss', 'CIoULoss'
            or 'DIoULoss'.
        loss_weight (float): Loss weight for the new loss function.
        check_interval (int): Iteration interval between stability checks.
    """

    # Maps the config string to the loss class; all require decoded boxes.
    _LOSS_CLASSES = {
        'GIoULoss': GIoULoss,
        'CIoULoss': CIoULoss,
        'DIoULoss': DIoULoss,
    }

    # Sliding-window sizes for the IoU stability check.
    _HISTORY_MAX = 10   # samples retained
    _WINDOW = 5         # samples averaged for the switch decision

    def __init__(self,
                 min_epoch=3,
                 min_avg_iou=0.4,
                 target_loss_type='GIoULoss',
                 loss_weight=1.0,
                 check_interval=100):
        super().__init__()
        self.min_epoch = min_epoch
        self.min_avg_iou = min_avg_iou
        self.target_loss_type = target_loss_type
        self.loss_weight = loss_weight
        self.check_interval = check_interval
        self.switched = False
        self.iou_history = []  # recent sampled IoU averages, oldest first

    def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
        """Sample IoU metrics periodically and switch once training is stable."""
        if (self.switched
                or runner.epoch < self.min_epoch
                or runner.iter % self.check_interval != 0):
            return

        if not (outputs and isinstance(outputs, dict)):
            return

        log_vars = outputs.get('log_vars', {})
        # Any numeric log entry whose key mentions "iou" is treated as an
        # IoU-quality signal (e.g. bbox-head IoU logs).
        iou_metrics = [v for k, v in log_vars.items()
                       if 'iou' in k.lower() and isinstance(v, (int, float))]
        if not iou_metrics:
            return

        self.iou_history.append(sum(iou_metrics) / len(iou_metrics))
        # Keep only the most recent samples.
        if len(self.iou_history) > self._HISTORY_MAX:
            self.iou_history.pop(0)

        if len(self.iou_history) < self._WINDOW:
            return
        recent_iou = sum(self.iou_history[-self._WINDOW:]) / self._WINDOW
        if recent_iou >= self.min_avg_iou:
            self._switch_stage2_loss(runner)
            self.switched = True
            runner.logger.info(
                f"Adaptive switch at epoch {runner.epoch}, iter {runner.iter}: "
                f"avg IoU {recent_iou:.3f} >= {self.min_avg_iou}")

    def _switch_stage2_loss(self, runner):
        """Install the target loss on the final cascade stage.

        NOTE: the name says "stage2" but this operates on ``bbox_head[2]``
        (the third cascade stage); kept for naming parity with
        ``ProgressiveLossHook``.
        """
        model = runner.model
        try:
            # Unwrap DDP/DataParallel if present.
            if hasattr(model, 'module'):
                head = model.module.roi_head.bbox_head[2]
            else:
                head = model.roi_head.bbox_head[2]

            loss_cls = self._LOSS_CLASSES.get(self.target_loss_type)
            if loss_cls is None:
                raise ValueError(
                    f"Unsupported target loss type: {self.target_loss_type}")
            # IoU-based losses require decoded box predictions.
            head.reg_decoded_bbox = True
            head.loss_bbox = loss_cls(loss_weight=self.loss_weight)

            runner.logger.info(
                f"Adaptive Loss Switch: Stage 3 → {self.target_loss_type}")
        except Exception as e:
            # Best-effort: a failed switch should not kill the training run.
            runner.logger.error(f"Failed to switch loss function: {e}")