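"""Learning rate scheduling utilities.

Provides ``AutoScheduler``, a warmup-then-main-schedule wrapper around the
standard PyTorch schedulers, and ``get_scheduler``, a helper that builds one
from training arguments.
"""
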
import math
from typing import Dict

from torch.optim import Optimizer
from torch.optim.lr_scheduler import (
    _LRScheduler,
    CosineAnnealingLR,
    StepLR,
    ExponentialLR,
    CosineAnnealingWarmRestarts,
    OneCycleLR,
)


class AutoScheduler(_LRScheduler):
    """
    Automatic learning rate scheduler with warmup and configurable main schedule.
    """
    
    # Default parameters for different scheduler types
    DEFAULT_PARAMS = {
        "cosine": {"eta_min": 1e-6},
        "cosine_restart": {"T_mult": 2, "eta_min": 1e-6, "rounds": 5},
        "step": {"gamma": 0.5, "decay_steps": 3},
        "exponential": {"gamma": 0.95},
        "one_cycle": {"lr_mult": 10.0, "div_factor": 25.0, "final_div_factor": 1e4}
    }
    
    def __init__(
        self,
        optimizer: Optimizer,
        total_steps: int,
        scheduler_type: str = "cosine_restart",
        warmup_ratio: float = 0.1,
        warmup_type: str = "linear",
        **kwargs
    ):
        self.scheduler_type = scheduler_type
        self.warmup_type = warmup_type
        self.warmup_ratio = warmup_ratio
        self.total_steps = total_steps
        self.warmup_steps = max(1, int(total_steps * warmup_ratio))
        self.current_step = 0
        self._is_warmup = True
        
        # Merge default parameters with user-provided kwargs
        self.params = self._get_merged_params(kwargs)
        
        # Validate parameters
        self._validate_parameters()
        
        # The after_scheduler attribute must exist before super().__init__ runs,
        # because super().__init__ calls step() immediately and step() checks it.
        self.after_scheduler = None
        
        # Initialize parent class
        super().__init__(optimizer)
        
        # Now create the main scheduler. Constructing a child scheduler steps it
        # once, which overwrites the lr that warmup just set on the optimizer,
        # so the warmup learning rate is re-applied afterwards.
        self.after_scheduler = self._create_main_scheduler()
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr

    def _get_merged_params(self, user_kwargs: Dict) -> Dict:
        """Merge default parameters with user-provided parameters."""
        defaults = self.DEFAULT_PARAMS.get(self.scheduler_type, {}).copy()
        defaults.update(user_kwargs)
        return defaults

    def _validate_parameters(self):
        """Validate scheduler parameters."""
        if self.warmup_type not in ["linear", "exponential"]:
            raise ValueError(f"Invalid warmup type: {self.warmup_type}")
        if self.scheduler_type not in self.DEFAULT_PARAMS:
            raise ValueError(f"Unsupported scheduler type: {self.scheduler_type}")

    def _create_main_scheduler(self) -> _LRScheduler:
        """Create the main scheduler after warmup."""
        remaining_steps = self.total_steps - self.warmup_steps
        
        if self.scheduler_type == "cosine":
            return CosineAnnealingLR(
                self.optimizer, 
                T_max=remaining_steps,
                eta_min=self.params["eta_min"]
            )
        
        elif self.scheduler_type == "cosine_restart":
            T_0 = max(1, remaining_steps // self.params["rounds"])
            return CosineAnnealingWarmRestarts(
                self.optimizer,
                T_0=T_0,
                T_mult=self.params["T_mult"],
                eta_min=self.params["eta_min"]
            )
        
        elif self.scheduler_type == "step":
            step_size = max(1, remaining_steps // self.params["decay_steps"])
            return StepLR(
                self.optimizer,
                step_size=step_size,
                gamma=self.params["gamma"]
            )
        
        elif self.scheduler_type == "exponential":
            return ExponentialLR(
                self.optimizer,
                gamma=self.params["gamma"]
            )
        
        elif self.scheduler_type == "one_cycle":
            # OneCycleLR implements its own warmup via pct_start, so it is
            # configured over the full total_steps rather than the remaining steps.
            # Get base learning rates safely.
            base_lrs = getattr(self, 'base_lrs', [group['lr'] for group in self.optimizer.param_groups])
            return OneCycleLR(
                self.optimizer,
                max_lr=[base_lr * self.params["lr_mult"] for base_lr in base_lrs],
                total_steps=self.total_steps,
                pct_start=self.warmup_ratio,
                anneal_strategy='cos',
                div_factor=self.params["div_factor"],
                final_div_factor=self.params["final_div_factor"]
            )

    def get_lr(self):
        """Get current learning rate."""
        if self._is_warmup:
            progress = min(1.0, self.current_step / self.warmup_steps)
            if self.warmup_type == "linear":
                factor = progress
            else:  # exponential
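                # Equivalent to 100 ** (progress - 1): ramps from 1/100 of the
                # base lr at the start of warmup up to the full base lr at the end.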
                factor = math.exp(progress * math.log(100)) / 100
            return [base_lr * factor for base_lr in self.base_lrs]
        
        # Return base learning rates if after_scheduler is not yet created
        if self.after_scheduler is None:
            return self.base_lrs
        
        return self.after_scheduler.get_last_lr()

    def step(self):
        """Step the scheduler."""
        self.current_step += 1
        if self._is_warmup and self.current_step >= self.warmup_steps:
            self._is_warmup = False
        
        if self._is_warmup:
            super().step()
        elif self.after_scheduler is not None:
            # after_scheduler is still None while super().__init__ is running,
            # in which case there is nothing extra to step.
            self.after_scheduler.step()
            # Keep this wrapper's get_last_lr() in sync with the main scheduler.
            self._last_lr = self.after_scheduler.get_last_lr()


def get_scheduler(args, optimizer, num_samples):
    """
    Create a learning rate scheduler from training arguments.
    
    Args:
        args: Training arguments object containing scheduler configuration
              Expected attributes:
              - batch_size: Training batch size
              - num_epoch: Number of training epochs
              - scheduler_type: Type of scheduler (default: 'cosine_restart')
              - warmup_ratio: Warmup ratio (default: 0.1)
              - warmup_type: Warmup type (default: 'linear')
              - eta_min, T_mult, rounds, gamma, decay_steps, lr_mult, div_factor,
                final_div_factor: Optional scheduler-specific params
        optimizer: PyTorch optimizer
        num_samples: Number of training samples
        
    Returns:
        AutoScheduler instance
        
    Example:
        # In any trainer class:
        self.optimizer = optim.AdamW(model.parameters(), lr=args.lr)
        self.scheduler = get_scheduler(args, self.optimizer, len(dataset))
        
        # During training:
        self.scheduler.step()
    """
    # Extract scheduler-specific parameters from args
    scheduler_kwargs = {}
    for param in ['eta_min', 'T_mult', 'rounds', 'gamma', 'decay_steps',
                  'lr_mult', 'div_factor', 'final_div_factor']:
        if hasattr(args, param):
            scheduler_kwargs[param] = getattr(args, param)
    
    # Calculate total steps and create scheduler
    total_steps = math.ceil(num_samples / args.batch_size) * args.num_epoch
    
    return AutoScheduler(
        optimizer=optimizer,
        total_steps=total_steps,
        scheduler_type=getattr(args, 'scheduler_type', 'cosine_restart'),
        warmup_ratio=getattr(args, 'warmup_ratio', 0.1),
        warmup_type=getattr(args, 'warmup_type', 'linear'),
        **scheduler_kwargs
    )
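

if __name__ == "__main__":
    # Minimal smoke-test sketch, not taken from any real training setup: the
    # model, data sizes, and hyperparameters below are illustrative assumptions.
    from types import SimpleNamespace

    from torch import nn, optim

    args = SimpleNamespace(
        batch_size=32,
        num_epoch=3,
        scheduler_type="cosine_restart",
        warmup_ratio=0.1,
        warmup_type="linear",
    )
    model = nn.Linear(16, 4)
    optimizer = optim.AdamW(model.parameters(), lr=3e-4)
    scheduler = get_scheduler(args, optimizer, num_samples=1024)

    for step in range(scheduler.total_steps):
        optimizer.step()  # normally preceded by a forward/backward pass
        scheduler.step()
        if step % 20 == 0:
            print(f"step {step:3d}  lr {optimizer.param_groups[0]['lr']:.6f}")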