# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hydra
import torch
from torch import nn

from cosmos_predict1.utils import log
from cosmos_predict1.utils.fused_adam import FusedAdam


def get_regular_param_group(net: nn.Module):
    """Separate the trainable parameters of the network into two groups: decay and no_decay.

    Parameters with two or more dimensions (e.g. weight matrices) go into the decay group;
    one-dimensional parameters (e.g. biases, norm scales) go into the no_decay group.
    Based on the nanoGPT codebase.
    """
    param_dict = {pn: p for pn, p in net.named_parameters()}
    # Keep only trainable parameters.
    param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
    decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
    nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
    return decay_params, nodecay_params


def get_base_optimizer(
    model: nn.Module,
    lr: float,
    weight_decay: float,
    optim_type: str = "adamw",
    sharding: bool = False,
    **kwargs,
) -> torch.optim.Optimizer:
    """Build an AdamW or FusedAdam optimizer for ``model``.

    Extra keyword arguments are forwarded to the optimizer constructor.
    ``sharding`` is accepted for API compatibility but is not used in this function.
    """
    net_decay_param, net_nodecay_param = get_regular_param_group(model)

    num_decay_params = sum(p.numel() for p in net_decay_param)
    num_nodecay_params = sum(p.numel() for p in net_nodecay_param)
    net_param_total = num_decay_params + num_nodecay_params
    log.critical(f"total num parameters : {net_param_total:,}")

    # Note: both groups are placed in a single parameter group, so they currently
    # share the same learning rate and weight decay.
    param_group = [
        {
            "params": net_decay_param + net_nodecay_param,
            "lr": lr,
            "weight_decay": weight_decay,
        },
    ]

    if optim_type == "adamw":
        opt_cls = torch.optim.AdamW
    elif optim_type == "fusedadam":
        opt_cls = FusedAdam
    else:
        raise ValueError(f"Unknown optimizer type: {optim_type}")

    return opt_cls(param_group, **kwargs)


def get_base_scheduler(
    optimizer: torch.optim.Optimizer,
    model: nn.Module,
    scheduler_config: dict,
):
    """Instantiate the scheduler described by ``scheduler_config`` via Hydra and wrap it in a LambdaLR.

    The instantiated object must accept a ``model`` attribute and expose a ``schedule``
    callable that maps the step count to a learning-rate multiplier.
    """
    net_scheduler = hydra.utils.instantiate(scheduler_config)
    net_scheduler.model = model
    return torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=[
            net_scheduler.schedule,
        ],
    )
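

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): builds a tiny
# placeholder model and requests an AdamW optimizer through get_base_optimizer.
# The model, learning rate, weight decay, and betas below are assumed values
# chosen purely for demonstration; the extra betas keyword is forwarded to
# torch.optim.AdamW via **kwargs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    demo_model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
    demo_optimizer = get_base_optimizer(
        demo_model,
        lr=1e-4,
        weight_decay=0.1,
        optim_type="adamw",
        betas=(0.9, 0.95),  # forwarded to torch.optim.AdamW
    )
    # The 2D Linear weights land in the decay group and the 1D biases in the
    # no_decay group, although both end up in the same parameter group here.
    print(demo_optimizer)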