roll-ai's picture
Upload 381 files
b6af722 verified
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Tuple
import torch
from megatron.core import parallel_state
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from cosmos_predict1.utils import distributed
from cosmos_predict1.utils.callbacks.grad_clip import GradClip as GradClipImage
from cosmos_predict1.utils.callbacks.grad_clip import _fused_nan_to_num
from cosmos_predict1.utils.model import Model
@dataclass
class _MagnitudeRecord:
    """Running accumulator that averages a scalar magnitude between reads.

    Values passed to :meth:`update` are summed as-is, so tensors stay tensors (no
    conversion to a Python number) until :meth:`get_stat` turns the mean into a float
    and resets the accumulator.
    """

    # Running sum of the values passed to update(); becomes a tensor once a tensor is added.
    state: float = 0
    # Number of update() calls since the last reset.
    iter_count: int = 0

    def reset(self) -> None:
        """Clear the running sum and the observation count."""
        self.state = 0
        self.iter_count = 0

    def update(self, cur_state: torch.Tensor) -> None:
        """Add one observation (a one-element tensor or a plain number) to the running sum."""
        self.state += cur_state
        self.iter_count += 1

    def get_stat(self) -> float:
        """Return the mean of the observations since the last read, then reset.

        Returns:
            The average as a Python float, or ``0.0`` if nothing was recorded.
        """
        if self.iter_count > 0:
            avg_state = self.state / self.iter_count
            # Tensors need .item() to become Python numbers; plain numbers are used directly.
            if isinstance(avg_state, torch.Tensor):
                avg_state = avg_state.item()
            avg_state = float(avg_state)
        else:
            avg_state = 0.0
        self.reset()
        return avg_state
class GradClip(GradClipImage):
    """Gradient-clipping callback that also records gradient norms per batch type.

    Behaves like the base ``GradClip`` callback (imported here as ``GradClipImage``):
    before each optimizer step it optionally sanitizes gradients with
    ``_fused_nan_to_num`` and clips the gradient norm to ``self.clip_norm``. In addition,
    the total norm returned by the clipping call is accumulated into ``img_mag_log`` or
    ``video_mag_log``, chosen by ``model.is_image_batch`` at the start of each step.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Separate accumulators so image and video gradient norms are averaged independently.
        self.img_mag_log = _MagnitudeRecord()
        self.video_mag_log = _MagnitudeRecord()
        # Accumulator selected by the latest on_training_step_start call; None until then.
        self._cur_state = None

    def on_training_step_start(self, model: Model, data_batch: dict[str, torch.Tensor], iteration: int = 0) -> None:
        """Pick the norm accumulator for this step from the batch type."""
        if model.is_image_batch(data_batch):
            self._cur_state = self.img_mag_log
        else:
            self._cur_state = self.video_mag_log

    def on_before_optimizer_step(
        self,
        model_ddp: distributed.DistributedDataParallel,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        grad_scaler: torch.amp.GradScaler,
        iteration: int = 0,
    ) -> None:
        """Sanitize and clip gradients, then record the total norm for the current batch type.

        Args:
            model_ddp: The model, possibly wrapped in ``distributed.DistributedDataParallel``.
            optimizer: Unused.
            scheduler: Unused.
            grad_scaler: Unused by this method.
            iteration: Unused by this method.
        """
        del optimizer, scheduler
        # Unwrap the project DDP wrapper to reach the trainable module.
        if isinstance(model_ddp, distributed.DistributedDataParallel):
            model = model_ddp.module
        else:
            model = model_ddp
        # Top-level module used by the tensor-parallel branch below. Unlike a bare
        # `model_ddp.module`, this also works when the model is not wrapped at all.
        full_model = getattr(model_ddp, "module", model_ddp)

        # Optionally narrow clipping to a sub-module addressed by a dotted attribute path.
        if self.model_key is not None:
            for attr_name in self.model_key.split("."):
                model = getattr(model, attr_name)

        if self.force_finite:
            # One fused call over every existing gradient; the helper appears to be a fused
            # form of torch.nan_to_num(grad, nan=0, posinf=0, neginf=0) — confirm in grad_clip.
            grads = [param.grad for param in model.parameters() if param.grad is not None]
            _fused_nan_to_num(grads)

        if isinstance(model, FSDP) and self.fsdp_enabled:
            # FSDP shards parameters, so its own clip_grad_norm_ must compute the global norm.
            total_norm = model.clip_grad_norm_(self.clip_norm)
        elif parallel_state.is_initialized() and parallel_state.get_tensor_model_parallel_world_size() > 1:
            # Delegate to the model's own clip_grad_norm_ under tensor parallelism
            # (presumably to reduce the norm across TP ranks — confirm in the model class).
            total_norm = full_model.clip_grad_norm_(self.clip_norm)
        else:
            total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip_norm, foreach=True)

        # Recording is statistics-only; skip it if no step-start hook has chosen an accumulator.
        if self._cur_state is not None:
            self._cur_state.update(total_norm)