import math
from typing import Any, Dict

from torch.optim import Optimizer
from torch.optim.lr_scheduler import (
    _LRScheduler,
    CosineAnnealingLR,
    CosineAnnealingWarmRestarts,
    ExponentialLR,
    OneCycleLR,
    StepLR,
)


class AutoScheduler(_LRScheduler):
    """
    Learning rate scheduler that combines a warmup phase with a configurable
    main schedule.

    The first ``warmup_ratio`` fraction of ``total_steps`` ramps the learning
    rate up to its base value (linearly or exponentially); the remaining steps
    are driven by one of the supported main schedules: "cosine",
    "cosine_restart", "step", "exponential", or "one_cycle".
    """
    DEFAULT_PARAMS = {
        "cosine": {"eta_min": 1e-6},
        "cosine_restart": {"T_mult": 2, "eta_min": 1e-6, "rounds": 5},
        "step": {"gamma": 0.5, "decay_steps": 3},
        "exponential": {"gamma": 0.95},
        "one_cycle": {"lr_mult": 10.0, "div_factor": 25.0, "final_div_factor": 1e4},
    }
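    # Example of how these defaults combine with user kwargs (illustrative):
    # AutoScheduler(optimizer, 1000, "step", gamma=0.3) resolves to
    # params == {"gamma": 0.3, "decay_steps": 3}; keyword arguments override
    # the per-type defaults above.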
    def __init__(
        self,
        optimizer: Optimizer,
        total_steps: int,
        scheduler_type: str = "cosine_restart",
        warmup_ratio: float = 0.1,
        warmup_type: str = "linear",
        **kwargs
    ):
        self.scheduler_type = scheduler_type
        self.warmup_type = warmup_type
        self.warmup_ratio = warmup_ratio
        self.total_steps = total_steps
        self.warmup_steps = max(1, int(total_steps * warmup_ratio))
        self.current_step = 0
        self._is_warmup = True

        self.params = self._get_merged_params(kwargs)
        self._validate_parameters()

        # The base class __init__ performs an initial step(), which may consult
        # after_scheduler, so it must exist (as None) before super().__init__ runs.
        self.after_scheduler = None
        super().__init__(optimizer)

        # Create the post-warmup scheduler once base_lrs have been recorded.
        self.after_scheduler = self._create_main_scheduler()

    def _get_merged_params(self, user_kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Merge the per-type default parameters with user-provided overrides."""
        defaults = self.DEFAULT_PARAMS.get(self.scheduler_type, {}).copy()
        defaults.update(user_kwargs)
        return defaults

    def _validate_parameters(self):
        """Validate scheduler parameters."""
        if self.warmup_type not in ["linear", "exponential"]:
            raise ValueError(f"Invalid warmup type: {self.warmup_type}")
        if self.scheduler_type not in self.DEFAULT_PARAMS:
            raise ValueError(f"Unsupported scheduler type: {self.scheduler_type}")

    def _create_main_scheduler(self) -> _LRScheduler:
        """Create the main scheduler used after warmup."""
        # Guard against warmup consuming every step (e.g. very small total_steps),
        # which would otherwise produce a zero-length main schedule.
        remaining_steps = max(1, self.total_steps - self.warmup_steps)

        if self.scheduler_type == "cosine":
            return CosineAnnealingLR(
                self.optimizer,
                T_max=remaining_steps,
                eta_min=self.params["eta_min"]
            )

        elif self.scheduler_type == "cosine_restart":
            # "rounds" sets the length of the first restart cycle. For example,
            # total_steps=1000 and warmup_ratio=0.1 leave 900 post-warmup steps;
            # rounds=5 gives a first cycle of 180 steps, then 360, 720, ... since
            # T_mult=2 doubles each subsequent cycle.
            T_0 = max(1, remaining_steps // self.params["rounds"])
            return CosineAnnealingWarmRestarts(
                self.optimizer,
                T_0=T_0,
                T_mult=self.params["T_mult"],
                eta_min=self.params["eta_min"]
            )

        elif self.scheduler_type == "step":
            # Decay the LR roughly "decay_steps" times over the post-warmup steps.
            step_size = max(1, remaining_steps // self.params["decay_steps"])
            return StepLR(
                self.optimizer,
                step_size=step_size,
                gamma=self.params["gamma"]
            )

        elif self.scheduler_type == "exponential":
            return ExponentialLR(
                self.optimizer,
                gamma=self.params["gamma"]
            )

        elif self.scheduler_type == "one_cycle":
            # base_lrs is set by _LRScheduler.__init__ before this method runs;
            # the getattr fallback covers the case where it is missing.
            base_lrs = getattr(self, "base_lrs", [group["lr"] for group in self.optimizer.param_groups])
            return OneCycleLR(
                self.optimizer,
                max_lr=[base_lr * self.params["lr_mult"] for base_lr in base_lrs],
                total_steps=self.total_steps,
                pct_start=self.warmup_ratio,
                anneal_strategy="cos",
                div_factor=self.params["div_factor"],
                final_div_factor=self.params["final_div_factor"]
            )

    def get_lr(self):
        """Compute the learning rate for the current step."""
        if self._is_warmup:
            progress = min(1.0, self.current_step / self.warmup_steps)
            if self.warmup_type == "linear":
                factor = progress
            else:
                # Exponential warmup: the factor grows from 1% of the base LR at
                # the start to 100% at the end, i.e. 100 ** (progress - 1).
                factor = math.exp(progress * math.log(100)) / 100
            return [base_lr * factor for base_lr in self.base_lrs]

        if self.after_scheduler is None:
            return self.base_lrs

        return self.after_scheduler.get_last_lr()

    def step(self):
        """Advance the schedule by one optimizer step."""
        self.current_step += 1
        if self._is_warmup and self.current_step >= self.warmup_steps:
            self._is_warmup = False

        if self._is_warmup:
            # During warmup, the base class applies the factor from get_lr().
            super().step()
        else:
            # After warmup, delegate to the main scheduler and keep the wrapper's
            # last-LR bookkeeping in sync so get_last_lr() stays up to date.
            if self.after_scheduler is not None:
                self.after_scheduler.step()
                self._last_lr = self.after_scheduler.get_last_lr()


def get_scheduler(args, optimizer, num_samples):
    """
    Create a learning rate scheduler from training arguments.

    Args:
        args: Training arguments object containing the scheduler configuration.
            Expected attributes:
            - batch_size: Training batch size
            - num_epoch: Number of training epochs
            - scheduler_type: Type of scheduler (default: 'cosine_restart')
            - warmup_ratio: Warmup ratio (default: 0.1)
            - warmup_type: Warmup type (default: 'linear')
            - eta_min, T_mult, rounds, gamma, decay_steps, lr_mult, div_factor,
              final_div_factor: Optional scheduler-specific parameters
        optimizer: PyTorch optimizer
        num_samples: Number of training samples

    Returns:
        AutoScheduler instance

    Example:
        # In any trainer class:
        self.optimizer = optim.AdamW(model.parameters(), lr=args.lr)
        self.scheduler = get_scheduler(args, self.optimizer, len(dataset))

        # During training:
        self.scheduler.step()
    """
    # Forward only the scheduler-specific options actually present on args.
    scheduler_kwargs = {}
    for param in ['eta_min', 'T_mult', 'rounds', 'gamma', 'decay_steps',
                  'lr_mult', 'div_factor', 'final_div_factor']:
        if hasattr(args, param):
            scheduler_kwargs[param] = getattr(args, param)

    # One scheduler step per optimizer step: steps per epoch times epochs.
    total_steps = math.ceil(num_samples / args.batch_size) * args.num_epoch

    return AutoScheduler(
        optimizer=optimizer,
        total_steps=total_steps,
        scheduler_type=getattr(args, 'scheduler_type', 'cosine_restart'),
        warmup_ratio=getattr(args, 'warmup_ratio', 0.1),
        warmup_type=getattr(args, 'warmup_type', 'linear'),
        **scheduler_kwargs
    )
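

# The block below is an illustrative smoke test, not part of the training
# pipeline it sketches: it uses a SimpleNamespace to stand in for the real
# `args` object and a throwaway linear model, just to show how get_scheduler
# and AutoScheduler.step() fit together.
if __name__ == "__main__":
    from types import SimpleNamespace

    from torch import nn, optim

    model = nn.Linear(4, 2)
    optimizer = optim.AdamW(model.parameters(), lr=1e-3)
    args = SimpleNamespace(batch_size=32, num_epoch=3, scheduler_type="cosine_restart")

    # 320 samples / batch 32 = 10 steps per epoch -> 30 total steps, 3 of warmup.
    scheduler = get_scheduler(args, optimizer, num_samples=320)
    for step in range(5):
        optimizer.step()
        scheduler.step()
        print(f"step {step}: lr={scheduler.get_last_lr()}")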