import math
from typing import Dict

from torch.optim import Optimizer
from torch.optim.lr_scheduler import (
    _LRScheduler, CosineAnnealingLR, StepLR, ExponentialLR,
    CosineAnnealingWarmRestarts, OneCycleLR,
)


class AutoScheduler(_LRScheduler):
    """
    Automatic learning rate scheduler with warmup and a configurable main schedule.

    Applies a linear or exponential warmup over the first ``warmup_ratio`` fraction
    of ``total_steps``, then hands off to one of the standard PyTorch schedulers:
    "cosine", "cosine_restart", "step", "exponential", or "one_cycle".
    """

    # Default parameters per scheduler type; user-provided kwargs override these.
    DEFAULT_PARAMS = {
        "cosine": {"eta_min": 1e-6},
        "cosine_restart": {"T_mult": 2, "eta_min": 1e-6, "rounds": 5},
        "step": {"gamma": 0.5, "decay_steps": 3},
        "exponential": {"gamma": 0.95},
        "one_cycle": {"lr_mult": 10.0, "div_factor": 25.0, "final_div_factor": 1e4},
    }

    def __init__(
        self,
        optimizer: Optimizer,
        total_steps: int,
        scheduler_type: str = "cosine_restart",
        warmup_ratio: float = 0.1,
        warmup_type: str = "linear",
        **kwargs,
    ):
        self.scheduler_type = scheduler_type
        self.warmup_type = warmup_type
        self.warmup_ratio = warmup_ratio
        self.total_steps = total_steps
        self.warmup_steps = max(1, int(total_steps * warmup_ratio))
        self.current_step = 0
        self._is_warmup = True

        # Merge default parameters with user-provided kwargs.
        self.params = self._get_merged_params(kwargs)

        # Validate parameters before building any scheduler.
        self._validate_parameters()

        # The attribute must exist BEFORE super().__init__, because the parent
        # constructor calls step() immediately and step() reads after_scheduler.
        self.after_scheduler = None

        # Initialize the parent class (this triggers the first step()).
        super().__init__(optimizer)

        # Now create the main (post-warmup) scheduler.
        self.after_scheduler = self._create_main_scheduler()

    def _get_merged_params(self, user_kwargs: Dict) -> Dict:
        """Merge default parameters with user-provided parameters."""
        defaults = self.DEFAULT_PARAMS.get(self.scheduler_type, {}).copy()
        defaults.update(user_kwargs)
        return defaults

    def _validate_parameters(self):
        """Validate scheduler parameters."""
        if self.warmup_type not in ["linear", "exponential"]:
            raise ValueError(f"Invalid warmup type: {self.warmup_type}")
        if self.scheduler_type not in self.DEFAULT_PARAMS:
            raise ValueError(f"Unsupported scheduler type: {self.scheduler_type}")

    def _create_main_scheduler(self) -> _LRScheduler:
        """Create the main scheduler used after warmup."""
        remaining_steps = self.total_steps - self.warmup_steps

        if self.scheduler_type == "cosine":
            return CosineAnnealingLR(
                self.optimizer,
                T_max=remaining_steps,
                eta_min=self.params["eta_min"],
            )
        elif self.scheduler_type == "cosine_restart":
            # Split the post-warmup steps into `rounds` restart cycles.
            T_0 = max(1, remaining_steps // self.params["rounds"])
            return CosineAnnealingWarmRestarts(
                self.optimizer,
                T_0=T_0,
                T_mult=self.params["T_mult"],
                eta_min=self.params["eta_min"],
            )
        elif self.scheduler_type == "step":
            step_size = max(1, remaining_steps // self.params["decay_steps"])
            return StepLR(
                self.optimizer,
                step_size=step_size,
                gamma=self.params["gamma"],
            )
        elif self.scheduler_type == "exponential":
            return ExponentialLR(
                self.optimizer,
                gamma=self.params["gamma"],
            )
        elif self.scheduler_type == "one_cycle":
            # Get base learning rates safely (base_lrs is set by super().__init__).
            base_lrs = getattr(self, "base_lrs", [group["lr"] for group in self.optimizer.param_groups])
            # Note: OneCycleLR already includes its own warmup phase via pct_start,
            # which overlaps with the manual warmup performed by this class.
            return OneCycleLR(
                self.optimizer,
                max_lr=[base_lr * self.params["lr_mult"] for base_lr in base_lrs],
                total_steps=self.total_steps,
                pct_start=self.warmup_ratio,
                anneal_strategy="cos",
                div_factor=self.params["div_factor"],
                final_div_factor=self.params["final_div_factor"],
            )

    def get_lr(self):
        """Get the current learning rate(s)."""
        if self._is_warmup:
            progress = min(1.0, self.current_step / self.warmup_steps)
            if self.warmup_type == "linear":
                factor = progress
            else:
                # Exponential warmup: the factor ramps from 1/100 of the base lr
                # (at progress 0) up to the full base lr (at progress 1).
                factor = math.exp(progress * math.log(100)) / 100
            return [base_lr * factor for base_lr in self.base_lrs]

        # Fall back to the base learning rates if after_scheduler is not yet created.
        if self.after_scheduler is None:
            return self.base_lrs
        return self.after_scheduler.get_last_lr()

    def step(self):
        """Advance the scheduler by one step."""
        self.current_step += 1

        # Leave the warmup phase once warmup_steps have elapsed.
        if self._is_warmup and self.current_step >= self.warmup_steps:
            self._is_warmup = False

        if self._is_warmup:
            super().step()
        else:
            # Only step the after_scheduler if it has been created.
            if self.after_scheduler is not None:
                self.after_scheduler.step()
                # Keep get_last_lr() on this wrapper in sync with the optimizer.
                self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
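
# Illustrative sketch (not from the original module): using AutoScheduler directly
# with a standalone optimizer. The model, learning rate, and step counts below are
# assumed placeholder values.
#
#   import torch
#   model = torch.nn.Linear(16, 4)
#   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
#   scheduler = AutoScheduler(optimizer, total_steps=1_000,
#                             scheduler_type="cosine", warmup_ratio=0.1)
#   for _ in range(1_000):
#       optimizer.step()   # after the usual forward/backward pass
#       scheduler.step()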


def get_scheduler(args, optimizer, num_samples):
    """
    Create a learning rate scheduler from training arguments.

    Args:
        args: Training arguments object containing the scheduler configuration.
            Expected attributes:
            - batch_size: Training batch size
            - num_epoch: Number of training epochs
            - scheduler_type: Type of scheduler (default: 'cosine_restart')
            - warmup_ratio: Warmup ratio (default: 0.1)
            - warmup_type: Warmup type (default: 'linear')
            - eta_min, T_mult, rounds, gamma, decay_steps: Optional scheduler-specific params
        optimizer: PyTorch optimizer.
        num_samples: Number of training samples.

    Returns:
        AutoScheduler instance.

    Example:
        # In any trainer class:
        self.optimizer = optim.AdamW(model.parameters(), lr=args.lr)
        self.scheduler = get_scheduler(args, self.optimizer, len(dataset))

        # During training:
        self.scheduler.step()
    """
    # Extract scheduler-specific parameters from args.
    scheduler_kwargs = {}
    for param in ["eta_min", "T_mult", "rounds", "gamma", "decay_steps"]:
        if hasattr(args, param):
            scheduler_kwargs[param] = getattr(args, param)

    # Calculate total steps (batches per epoch * epochs) and create the scheduler.
    total_steps = math.ceil(num_samples / args.batch_size) * args.num_epoch
    return AutoScheduler(
        optimizer=optimizer,
        total_steps=total_steps,
        scheduler_type=getattr(args, "scheduler_type", "cosine_restart"),
        warmup_ratio=getattr(args, "warmup_ratio", 0.1),
        warmup_type=getattr(args, "warmup_type", "linear"),
        **scheduler_kwargs,
    )
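

# Minimal smoke test (a sketch, not part of the original training pipeline).
# The argument names mirror the attributes documented in get_scheduler's docstring;
# the dummy model, lr value, and sample count are assumptions for illustration only.
if __name__ == "__main__":
    from types import SimpleNamespace

    import torch

    args = SimpleNamespace(
        batch_size=32,
        num_epoch=3,
        scheduler_type="cosine_restart",
        warmup_ratio=0.1,
        warmup_type="linear",
    )
    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = get_scheduler(args, optimizer, num_samples=320)

    total_steps = math.ceil(320 / args.batch_size) * args.num_epoch
    for step in range(total_steps):
        optimizer.step()  # normally preceded by a forward/backward pass
        scheduler.step()
        if step % 5 == 0:
            print(f"step {step:3d}  lr {optimizer.param_groups[0]['lr']:.6g}")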