# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenizer callbacks extended from base callbacks.""" | |
import math | |
from typing import Any, Optional | |
import numpy as np | |
import torch | |
from torch._dynamo.eval_frame import OptimizedModule as torch_OptimizedModule | |
from cosmos_predict1.utils import callback, distributed, log | |
from cosmos_predict1.utils.config import Config | |
from cosmos_predict1.utils.model import Model | |
from cosmos_predict1.utils.trainer import Trainer | |
_UINT8_MAX_F = float(np.iinfo(np.uint8).max) | |
_VIDEO_CONSISTENCY_LOSS = "video_consistency" | |


def make_video_grid(video, nrow=None, padding=1):
    r"""Make a grid of videos for visualization.

    Args:
        video (tensor): video of size B x C x T x H x W, with values expected in [0, 1].
        nrow (int): number of rows in the grid.
        padding (int): size of the padding between videos.
    """
    b, c, t, h, w = video.shape
    video = video.permute(0, 2, 3, 4, 1)
    video = (video.cpu().detach().numpy() * _UINT8_MAX_F).astype("uint8")
    if nrow is None:
        nrow = math.ceil(math.sqrt(b))
    ncol = math.ceil(b / nrow)
    video_grid = np.zeros((t, (padding + h) * nrow + padding, (padding + w) * ncol + padding, c), dtype="uint8")
    for i in range(b):
        r = i // ncol
        c = i % ncol
        start_r = (padding + h) * r
        start_c = (padding + w) * c
        video_grid[:, start_r : start_r + h, start_c : start_c + w] = video[i]
    video = []
    for i in range(t):
        video.append(video_grid[i])
    return video
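

# Illustrative usage sketch for make_video_grid (not part of this module's API; the batch size,
# resolution, and [0, 1] value range below are assumptions consistent with the uint8 scaling above):
#
#   videos = torch.rand(6, 3, 17, 128, 128)            # B x C x T x H x W
#   frames = make_video_grid(videos, nrow=2, padding=2)
#   # "frames" is a list of T uint8 arrays of shape (grid_H, grid_W, C), ready for a video writer.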


def compute_weight_norm(model):
    weight_norm = dict()
    for layer_name, param in model.named_parameters():
        if torch.isnan(param).any():
            raise ValueError(f"[weight] {layer_name} NaN detected in weights")
        weight_norm[f"{layer_name}"] = torch.norm(param, p=2).item()
    return weight_norm


def compute_grad_norm(model):
    grad_norm = dict()
    for layer_name, param in model.named_parameters():
        if param.grad is not None:
            if torch.isnan(param.grad).any():
                raise ValueError(f"[grad] {layer_name} NaN detected in gradients")
            grad_norm[f"{layer_name}"] = torch.norm(param.grad, p=2).item()
    return grad_norm
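

# Sketch of how the per-parameter norms above can be consumed for debugging (illustrative only;
# the call site and the logged key format are assumptions, not part of this module):
#
#   weight_norms = compute_weight_norm(model.network)
#   grad_norms = compute_grad_norm(model.network)
#   for name, value in grad_norms.items():
#       log.info(f"grad_norm/{name}: {value:.4e}")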


class AdaptCkptStateDict(callback.Callback):
    def __init__(self, config: Config, trainer: Trainer):
        super().__init__(config, trainer)

    def on_save_checkpoint(self, model: Model, state_dict: dict[Any, Any]) -> None:
        """Adapt the state dict should the model be a compiled one."""
        if not isinstance(model.network, torch_OptimizedModule):
            return

        def _uncompiled_key(k):
            if k.startswith("network._orig_mod"):
                return k.replace("network._orig_mod", "network")
            elif k.startswith("ema.network-_orig_mod"):
                return k.replace("ema.network-_orig_mod", "ema.network")
            return k

        fixed_keys_state_dict = {}
        for k, v in state_dict["model"].items():
            fixed_keys_state_dict[_uncompiled_key(k)] = v
        state_dict["model"] = fixed_keys_state_dict

    def on_load_checkpoint(self, model: Model, state_dict: dict[Any, Any]) -> None:
        """Adapt the state dict should the model be a compiled one."""
        if not isinstance(model.network, torch_OptimizedModule):
            return

        def _compiled_key(k):
            if k.startswith("network."):
                return k.replace("network", "network._orig_mod")
            elif k.startswith("ema.network-"):
                return k.replace("ema.network", "ema.network-_orig_mod")
            return k

        fixed_keys_state_dict = {}
        for k, v in state_dict["model"].items():
            fixed_keys_state_dict[_compiled_key(k)] = v
        state_dict["model"] = fixed_keys_state_dict


class GradClipCallback(callback.GradClipCallback):
    """The verbose tokenizer callback for gradient clipping."""

    def __init__(self, grad_clip_norm: float, config: Config, trainer: Trainer, verbose: bool):
        super().__init__(config, trainer, grad_clip_norm)
        self.verbose = verbose

    def on_before_optimizer_step(
        self,
        model_ddp: distributed.DistributedDataParallel,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        grad_scaler: torch.amp.GradScaler,
        iteration: int = 0,
    ) -> None:
        grad_scaler.unscale_(optimizer)
        total_norm = torch.nn.utils.clip_grad_norm_(model_ddp.module.parameters(), max_norm=self.grad_clip_norm)
        if torch.isnan(total_norm):
            raise ValueError("[gradient clipping] NaN detected in gradient norms")
        if torch.isfinite(total_norm) and total_norm > self.grad_clip_norm and self.verbose:
            if model_ddp.module.network.training:
                log.warning(
                    f"[net:{iteration:07d}] Gradient norm {total_norm} > {self.grad_clip_norm}. Clipping gradients."
                )
            else:
                log.warning(
                    f"[unknown:{iteration:07d}] Gradient norm {total_norm} > {self.grad_clip_norm}. Clipping gradients."
                )
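

# Note: torch.nn.utils.clip_grad_norm_ returns the total gradient norm computed *before* clipping,
# so the comparison against grad_clip_norm above reports exactly the steps on which clipping occurred.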


class ExpandLossMask(callback.Callback):
    def __init__(self, kernel_size: int, config: Config, trainer: Trainer):
        super().__init__(config, trainer)
        self.kernel_size = kernel_size

    def on_training_step_start(self, model: Model, data: dict[str, Any], iteration: int = 0) -> None:
        """Expand loss_mask with max pooling (to cover some partial human regions)."""
        if "loss_mask" not in data.keys():
            return
        assert data["loss_mask"].ndim == 4 or data["loss_mask"].ndim == 5, "ndim of loss_mask must be 4 or 5"
        kernel_size = self.kernel_size
        if data["loss_mask"].ndim == 4:
            data["loss_mask"] = torch.nn.functional.max_pool2d(
                data["loss_mask"], kernel_size, stride=1, padding=kernel_size // 2
            )
        else:
            data["loss_mask"] = torch.nn.functional.max_pool3d(
                data["loss_mask"],
                (1, kernel_size, kernel_size),
                stride=1,
                padding=(0, kernel_size // 2, kernel_size // 2),
            )
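

# Illustrative effect of the max-pool expansion above on a binary mask (hypothetical config, trainer,
# and model objects; kernel_size=5 dilates each foreground pixel by 2 pixels in every spatial direction):
#
#   data = {"loss_mask": torch.zeros(1, 1, 64, 64)}
#   data["loss_mask"][..., 32, 32] = 1.0
#   ExpandLossMask(kernel_size=5, config=config, trainer=trainer).on_training_step_start(model, data)
#   # data["loss_mask"] now contains a 5 x 5 block of ones centered at (32, 32).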


class TorchCompile(callback.Callback):
    """
    Callback to apply torch.compile() to the network, to modules in the losses (FlowLoss and PerceptualLoss), or both.

    Compilation is deferred to a later iteration, since step times are very unstable during the first
    iterations and compiling that early can trigger NCCL timeouts.
    """

    _TORCH_DYNAMO_CACHE_SIZE = 128

    def __init__(
        self,
        compile_after_iterations: int = 8,
        compile_network: bool = False,
        compile_loss: bool = False,
        compile_loss_keys: list[str] = ["flow", "perceptual"],
    ):
        self.initial_iteration: Optional[int] = None
        self.compile_after_iterations: int = compile_after_iterations
        self.compile_network: bool = compile_network
        self.compile_loss: bool = compile_loss
        self.compile_loss_keys: list[str] = compile_loss_keys

        if self.compile_network or self.compile_loss:
            torch._dynamo.config.cache_size_limit = TorchCompile._TORCH_DYNAMO_CACHE_SIZE

            # Hack to make ".training" work on a torch.compile() module.
            # The value of ".training" is incorrectly set on the torch.compile() module when .eval() or .train()
            # is invoked, but it is correctly set on the original module, and this hack accesses that value.
            # I've created an issue about this: https://github.com/pytorch/pytorch/issues/132986
            torch_OptimizedModule.training = property(
                lambda self: self._orig_mod.training, lambda self, value: None, lambda self: None
            )

    def on_training_step_start(self, model: Model, data: dict[str, torch.Tensor], iteration: int = 0) -> None:
        if not (self.compile_network or self.compile_loss):
            return

        if self.initial_iteration is None:
            log.info(f"Compilation will be done on iteration {iteration + self.compile_after_iterations}")
            self.initial_iteration = iteration
            if self.compile_network:
                if model.config.ema.enabled is True and model.config.ema.torch_compile_buffer_renaming is False:
                    log.warning(
                        '"model.config.ema.torch_compile_buffer_renaming" should be turned on for the EMA to work with torch.compile(), network will not be compiled'
                    )

        if iteration - self.initial_iteration == self.compile_after_iterations:
            if self.compile_network:
                if model.config.ema.enabled is True and model.config.ema.torch_compile_buffer_renaming is False:
                    log.warning(
                        '"model.config.ema.torch_compile_buffer_renaming" should be turned on for the EMA to work with torch.compile(), skipping network compilation'
                    )
                else:
                    log.info("Compiling network")
                    model.network = torch.compile(model.network, dynamic=False)

            if self.compile_loss:
                for key in self.compile_loss_keys:
                    if key not in model.loss.loss_modules:
                        log.warning(f"Loss module for compilation with key: {key} not found")
                    else:
                        if (
                            hasattr(model.loss.loss_modules[key], "checkpoint_activations")
                            and getattr(model.loss.loss_modules[key], "checkpoint_activations") is True
                        ):
                            log.warning(
                                f"torch.compile() doesn't work with activation checkpointing, skipping compilation for loss with key: {key}"
                            )
                        else:
                            log.info(f"Compiling loss with key: {key}")
                            model.loss.loss_modules[key].torch_compile()
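

# Sketch of wiring this callback into a training config (illustrative; the exact config field that
# holds callbacks, and its type, are assumptions and may differ in your setup):
#
#   config.trainer.callbacks["torch_compile"] = TorchCompile(
#       compile_after_iterations=8,
#       compile_network=True,
#       compile_loss=True,
#       compile_loss_keys=["flow", "perceptual"],
#   )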