# ReCEP/src/bce/model/scheduler.py
import math
from typing import Dict

from torch.optim import Optimizer
from torch.optim.lr_scheduler import (
    _LRScheduler,
    CosineAnnealingLR,
    CosineAnnealingWarmRestarts,
    ExponentialLR,
    OneCycleLR,
    StepLR,
)

class AutoScheduler(_LRScheduler):
    """
    Learning rate scheduler that combines a warmup phase with a configurable
    main schedule.

    For the first ``warmup_steps`` steps the learning rate is ramped up to the
    optimizer's base learning rate, either linearly or exponentially. After
    warmup, scheduling is delegated to a standard PyTorch scheduler selected
    via ``scheduler_type``: "cosine", "cosine_restart", "step", "exponential",
    or "one_cycle".
    """

    # Default parameters for each supported scheduler type
    DEFAULT_PARAMS = {
        "cosine": {"eta_min": 1e-6},
        "cosine_restart": {"T_mult": 2, "eta_min": 1e-6, "rounds": 5},
        "step": {"gamma": 0.5, "decay_steps": 3},
        "exponential": {"gamma": 0.95},
        "one_cycle": {"lr_mult": 10.0, "div_factor": 25.0, "final_div_factor": 1e4},
    }
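    # Any keyword argument passed to __init__ overrides the matching default
    # above. For example (illustrative values, not taken from the original
    # code): AutoScheduler(optimizer, total_steps=1000, scheduler_type="step",
    # gamma=0.8, decay_steps=5) decays the post-warmup LR by 0.8x in 5 stages.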

    def __init__(
        self,
        optimizer: Optimizer,
        total_steps: int,
        scheduler_type: str = "cosine_restart",
        warmup_ratio: float = 0.1,
        warmup_type: str = "linear",
        **kwargs
    ):
        self.scheduler_type = scheduler_type
        self.warmup_type = warmup_type
        self.warmup_ratio = warmup_ratio
        self.total_steps = total_steps
        self.warmup_steps = max(1, int(total_steps * warmup_ratio))
        self.current_step = 0
        self._is_warmup = True
        # Merge default parameters with user-provided kwargs
        self.params = self._get_merged_params(kwargs)
        # Validate parameters
        self._validate_parameters()
        # Defer creating the main scheduler: super().__init__ calls step()
        # immediately, so self.after_scheduler must already exist (as None)
        # before the parent class is initialized.
        self.after_scheduler = None
        # Initialize parent class
        super().__init__(optimizer)
        # Now create the main scheduler
        self.after_scheduler = self._create_main_scheduler()

    def _get_merged_params(self, user_kwargs: Dict) -> Dict:
        """Merge default parameters with user-provided parameters."""
        defaults = self.DEFAULT_PARAMS.get(self.scheduler_type, {}).copy()
        defaults.update(user_kwargs)
        return defaults

    def _validate_parameters(self):
        """Validate scheduler parameters."""
        if self.warmup_type not in ["linear", "exponential"]:
            raise ValueError(f"Invalid warmup type: {self.warmup_type}")
        if self.scheduler_type not in self.DEFAULT_PARAMS:
            raise ValueError(f"Unsupported scheduler type: {self.scheduler_type}")

    def _create_main_scheduler(self) -> _LRScheduler:
        """Create the main scheduler used after warmup."""
        remaining_steps = self.total_steps - self.warmup_steps
        if self.scheduler_type == "cosine":
            return CosineAnnealingLR(
                self.optimizer,
                T_max=remaining_steps,
                eta_min=self.params["eta_min"]
            )
        elif self.scheduler_type == "cosine_restart":
            # Split the post-warmup steps into `rounds` restart cycles.
            T_0 = max(1, remaining_steps // self.params["rounds"])
            return CosineAnnealingWarmRestarts(
                self.optimizer,
                T_0=T_0,
                T_mult=self.params["T_mult"],
                eta_min=self.params["eta_min"]
            )
        elif self.scheduler_type == "step":
            # Decay the learning rate `decay_steps` times over the post-warmup steps.
            step_size = max(1, remaining_steps // self.params["decay_steps"])
            return StepLR(
                self.optimizer,
                step_size=step_size,
                gamma=self.params["gamma"]
            )
        elif self.scheduler_type == "exponential":
            return ExponentialLR(
                self.optimizer,
                gamma=self.params["gamma"]
            )
        elif self.scheduler_type == "one_cycle":
            # Get base learning rates safely (base_lrs is set by super().__init__,
            # which has already run by the time this method is called).
            base_lrs = getattr(self, 'base_lrs', [group['lr'] for group in self.optimizer.param_groups])
            # Note: OneCycleLR performs its own warmup via pct_start and is
            # configured over the full schedule, while AutoScheduler's explicit
            # warmup phase still runs first, so its steps only begin after warmup.
            return OneCycleLR(
                self.optimizer,
                max_lr=[base_lr * self.params["lr_mult"] for base_lr in base_lrs],
                total_steps=self.total_steps,
                pct_start=self.warmup_ratio,
                anneal_strategy='cos',
                div_factor=self.params["div_factor"],
                final_div_factor=self.params["final_div_factor"]
            )

    def get_lr(self):
        """Compute the learning rates for the current step."""
        if self._is_warmup:
            progress = min(1.0, self.current_step / self.warmup_steps)
            if self.warmup_type == "linear":
                factor = progress
            else:  # exponential warmup: 100 ** (progress - 1), i.e. the LR
                   # ramps from 1% of the base value up to 100% over warmup
                factor = math.exp(progress * math.log(100)) / 100
            return [base_lr * factor for base_lr in self.base_lrs]
        # Fall back to the base learning rates if after_scheduler is not yet created
        if self.after_scheduler is None:
            return self.base_lrs
        return self.after_scheduler.get_last_lr()

    def step(self):
        """Step the scheduler."""
        self.current_step += 1
        if self._is_warmup and self.current_step >= self.warmup_steps:
            self._is_warmup = False
        if self._is_warmup:
            super().step()
        else:
            # Only step the after_scheduler once it has been created
            if self.after_scheduler is not None:
                self.after_scheduler.step()
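
# Illustrative sketch (not part of the original file): using AutoScheduler
# directly, without the get_scheduler() helper below. `model` and `loader`
# are placeholders.
#
#     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
#     scheduler = AutoScheduler(optimizer, total_steps=10_000,
#                               scheduler_type="cosine", warmup_ratio=0.05)
#     for batch in loader:
#         ...                      # forward / backward
#         optimizer.step()
#         scheduler.step()         # call once per optimizer step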


def get_scheduler(args, optimizer, num_samples):
    """
    Create a learning rate scheduler from training arguments.

    Args:
        args: Training arguments object containing the scheduler configuration.
            Expected attributes:
                - batch_size: Training batch size
                - num_epoch: Number of training epochs
                - scheduler_type: Type of scheduler (default: 'cosine_restart')
                - warmup_ratio: Warmup ratio (default: 0.1)
                - warmup_type: Warmup type (default: 'linear')
                - eta_min, T_mult, rounds, gamma, decay_steps, lr_mult,
                  div_factor, final_div_factor: Optional scheduler-specific params
        optimizer: PyTorch optimizer
        num_samples: Number of training samples

    Returns:
        AutoScheduler instance

    Example:
        # In any trainer class:
        self.optimizer = optim.AdamW(model.parameters(), lr=args.lr)
        self.scheduler = get_scheduler(args, self.optimizer, len(dataset))

        # During training (once per optimizer step):
        self.scheduler.step()
    """
    # Extract scheduler-specific parameters from args (only those that are present)
    scheduler_kwargs = {}
    for param in ['eta_min', 'T_mult', 'rounds', 'gamma', 'decay_steps',
                  'lr_mult', 'div_factor', 'final_div_factor']:
        if hasattr(args, param):
            scheduler_kwargs[param] = getattr(args, param)

    # Calculate total steps (optimizer steps per epoch * epochs) and create scheduler
    total_steps = math.ceil(num_samples / args.batch_size) * args.num_epoch
    return AutoScheduler(
        optimizer=optimizer,
        total_steps=total_steps,
        scheduler_type=getattr(args, 'scheduler_type', 'cosine_restart'),
        warmup_ratio=getattr(args, 'warmup_ratio', 0.1),
        warmup_type=getattr(args, 'warmup_type', 'linear'),
        **scheduler_kwargs
    )
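

# Minimal smoke test: a sketch that is not part of the original module. It builds
# a dummy model/optimizer and an argparse.Namespace standing in for the training
# args object (the attribute names mirror those documented in get_scheduler; the
# concrete values are illustrative only), then steps through the full schedule.
if __name__ == "__main__":
    import argparse

    import torch

    args = argparse.Namespace(
        batch_size=32,
        num_epoch=2,
        scheduler_type="cosine_restart",
        warmup_ratio=0.1,
        warmup_type="linear",
    )
    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = get_scheduler(args, optimizer, num_samples=1_000)

    for _ in range(scheduler.total_steps):
        optimizer.step()   # dummy optimizer step (no loss/backward in this sketch)
        scheduler.step()
    print("final lr:", optimizer.param_groups[0]["lr"])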