# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import torch

from cosmos_predict1.utils.lazy_config import LazyCall as L

class LambdaLinearWarmupScheduler:
    """
    A learning rate multiplier schedule with linear warm-up and cool-down.

    The schedule has three phases:
    1. Warm-up: the multiplier increases linearly from 0 to 1.
    2. Constant: the multiplier stays at 1.
    3. Cool-down: the multiplier decreases linearly from 1 to 0.

    Args:
        warmup_steps (int): Number of steps in the warm-up phase.
        warmup_offset (int, optional): Step at which warm-up begins; the
            multiplier is 0 before this step. Defaults to 0.
        max_iter (int, optional): Total number of iterations. Required if
            cooldown_steps is provided.
        cooldown_steps (int, optional): Number of steps in the cool-down
            phase, which ends at max_iter.

    Raises:
        ValueError: If cooldown_steps is provided without max_iter, or if
            __call__ receives a step that matches no phase.
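
    Example:
        With warm-up only, the multiplier ramps linearly from 0 to 1 and
        then stays constant:

        >>> scheduler = LambdaLinearWarmupScheduler(warmup_steps=10)
        >>> scheduler(0), scheduler(5), scheduler(10)
        (0.0, 0.5, 1.0)

        With a cool-down, it decays back to 0 over the final steps
        (cool-down starts at max_iter - cooldown_steps = 6 here):

        >>> scheduler = LambdaLinearWarmupScheduler(warmup_steps=2, max_iter=10, cooldown_steps=4)
        >>> scheduler(8)
        0.5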
"""

    def __init__(
        self,
        warmup_steps: int,
        warmup_offset: int = 0,
        max_iter: Optional[int] = None,
        cooldown_steps: Optional[int] = None,
    ):
        self.warmup_steps = warmup_steps
        self.warmup_offset = warmup_offset
        self.max_iter = max_iter
        self.cooldown_steps = cooldown_steps

        if cooldown_steps is not None:
            if max_iter is None:
                raise ValueError("max_iter must be specified when cooldown_steps is provided")
            # Cool-down occupies the final `cooldown_steps` iterations.
            self.cooldown_start = max_iter - cooldown_steps
        else:
            self.cooldown_start = None

    def __call__(self, step: int) -> float:
        # Before the warm-up offset: hold the multiplier at 0.
        if step < self.warmup_offset:
            return 0.0
        # Warm-up phase: linear ramp from 0 to 1.
        if step < self.warmup_steps + self.warmup_offset:
            return float(step - self.warmup_offset) / float(max(1, self.warmup_steps))
        # Constant phase (no cool-down configured).
        elif self.cooldown_steps is None:
            return 1.0
        # Constant phase (before cool-down starts).
        elif step < self.cooldown_start:
            return 1.0
        # Cool-down phase: linear decay from 1 to 0.
        elif self.cooldown_start <= step < self.max_iter:
            cooldown_progress = (step - self.cooldown_start) / self.cooldown_steps
            return 1.0 - cooldown_progress
        # After max_iter: hold at 0.
        elif step >= self.max_iter:
            return 0.0
        # Defensive: the branches above are exhaustive, so this is unreachable.
        else:
            raise ValueError(f"Invalid step {step}")


# Lazy config: `optimizer` is left as None and is expected to be populated
# (e.g., by the trainer) before the config is instantiated.
LambdaLinearLR = L(torch.optim.lr_scheduler.LambdaLR)(
    optimizer=None,
    lr_lambda=L(LambdaLinearWarmupScheduler)(warmup_steps=5000),
)
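

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the library): wire the
    # scheduler into torch.optim.lr_scheduler.LambdaLR directly. The toy
    # model, optimizer choice, and step counts below are assumptions for the
    # demo; in the real pipeline the lazy config above supplies the optimizer.
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=LambdaLinearWarmupScheduler(warmup_steps=10, max_iter=50, cooldown_steps=20),
    )
    for step in range(50):
        optimizer.step()  # a real loop would run forward/backward first
        scheduler.step()  # rescales the base lr by the phase multiplier
    # Effective lr: ramps from 0 to 1e-4 over steps 0-10, holds until the
    # cool-down starts at step 30, then decays linearly to 0 by step 50.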