# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import time
import warnings
from typing import TYPE_CHECKING, Any, Callable, Optional

import omegaconf
import torch
import torch.utils.data
import tqdm

from cosmos_predict1.utils import distributed, log
from cosmos_predict1.utils.lazy_config import instantiate
from cosmos_predict1.utils.misc import get_local_tensor_if_DTensor

if TYPE_CHECKING:
    from cosmos_predict1.utils.config import Config
    from cosmos_predict1.utils.model import Model
    from cosmos_predict1.utils.trainer import Trainer


class CallBackGroup:
    """A class for hosting a collection of callback objects.

    It is used to execute callback functions of multiple callback objects with the same method name.
    When callbackgroup.func(args) is executed, internally it loops through the objects in self._callbacks and runs
    self._callbacks[0].func(args), self._callbacks[1].func(args), etc. The method name and arguments should match.

    Attributes:
        _callbacks (list[Callback]): List of callback objects.
    """

    def __init__(self, config: Config, trainer: Trainer) -> None:
        """Initializes the list of callback objects.

        Args:
            config (Config): The config object for the codebase.
            trainer (Trainer): The main trainer.
        """
        self._callbacks = []
        callback_configs = config.trainer.callbacks
        if callback_configs:
            if isinstance(callback_configs, (list, omegaconf.listconfig.ListConfig)):
                warnings.warn(
                    "The 'config.trainer.callbacks' parameter should be a dict instead of a list. "
                    "Please update your code.",
                    DeprecationWarning,
                    stacklevel=2,
                )
                callback_configs = {f"callback_{i}": v for i, v in enumerate(callback_configs)}
            for callback_name, current_callback_cfg in callback_configs.items():
                if "_target_" not in current_callback_cfg:
                    log.critical(
                        f"Callback {callback_name} is missing the '_target_' field.\nSkip {current_callback_cfg}"
                    )
                    continue
                log.critical(f"Instantiating callback {callback_name}: {current_callback_cfg}")
                _callback = instantiate(current_callback_cfg)
                assert isinstance(_callback, Callback), f"{current_callback_cfg} is not a valid callback."
                _callback.config = config
                _callback.trainer = trainer
                self._callbacks.append(_callback)
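
    # Illustrative sketch (not part of the original module): the kind of mapping that
    # `config.trainer.callbacks` is expected to hold. The keys and target paths below are
    # hypothetical; only the per-entry `_target_` requirement comes from the loop above.
    #
    #     callbacks = {
    #         "progress_bar": {"_target_": "cosmos_predict1.utils.callback.ProgressBarCallback"},
    #         "grad_clip": {
    #             "_target_": "cosmos_predict1.utils.callback.GradClipCallback",
    #             "grad_clip_norm": 1.0,
    #         },
    #     }
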
    def __getattr__(self, method_name: str) -> Callable:
        """Loops through the callback objects to call the corresponding callback function.

        Args:
            method_name (str): Callback method name.
        """

        def multi_callback_wrapper(*args, **kwargs) -> None:
            for callback in self._callbacks:
                assert hasattr(callback, method_name)
                method = getattr(callback, method_name)
                assert callable(method)
                _ = method(*args, **kwargs)

        return multi_callback_wrapper
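

# Usage sketch (illustrative, not taken from the original sources): how CallBackGroup fans
# a single call out to every registered callback. `config`, `trainer`, and `model` are
# hypothetical placeholders.
#
#     group = CallBackGroup(config=config, trainer=trainer)
#     group.on_train_start(model, iteration=0)  # runs on_train_start on each callback in order
#     group.on_app_end()                        # any hook name dispatches through __getattr__
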

class Callback:
    """The base class for all callbacks.

    All callbacks should inherit from this class and adhere to the established method names and signatures.
    """

    def __init__(self, config: Optional["Config"] = None, trainer: Optional["Trainer"] = None):
        """Initializes a Callback object.

        Args:
            config (Optional[Config]): The configuration object for the codebase, if available.
            trainer (Optional[Trainer]): The main trainer handling the training loop, if available.

        Notes:
            The config and trainer parameters are optional to maintain backward compatibility.
            In future releases, these parameters will be removed. Upon using these parameters, a deprecation
            warning will be issued.
        """
        if config is not None or trainer is not None:
            warnings.warn(
                "The 'config' and 'trainer' parameters are deprecated and will be removed in a future release. "
                "Please update your code to create Callback instances without these parameters.",
                DeprecationWarning,
                stacklevel=2,
            )
        del config, trainer

    def on_train_start(self, model: Model, iteration: int = 0) -> None:
        pass

    def on_training_step_start(self, model: Model, data: dict[str, torch.Tensor], iteration: int = 0) -> None:
        pass

    def on_before_forward(self, iteration: int = 0) -> None:
        pass

    def on_after_forward(self, iteration: int = 0) -> None:
        pass

    def on_before_backward(
        self, model_ddp: distributed.DistributedDataParallel, loss: torch.Tensor, iteration: int = 0
    ) -> None:
        pass

    def on_after_backward(self, model_ddp: distributed.DistributedDataParallel, iteration: int = 0) -> None:
        pass

    def on_before_dataloading(self, iteration: int = 0) -> None:
        pass

    def on_after_dataloading(self, iteration: int = 0) -> None:
        pass

    def on_optimizer_init_start(self) -> None:
        pass

    def on_optimizer_init_end(self) -> None:
        pass

    def on_before_optimizer_step(
        self,
        model_ddp: distributed.DistributedDataParallel,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        grad_scaler: torch.amp.GradScaler,
        iteration: int = 0,
    ) -> None:
        pass

    def on_before_zero_grad(
        self,
        model_ddp: distributed.DistributedDataParallel,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        iteration: int = 0,
    ) -> None:
        pass

    def on_training_step_end(
        self,
        model: Model,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor],
        loss: torch.Tensor,
        iteration: int = 0,
    ) -> None:
        pass

    def on_validation_start(
        self, model: Model, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0
    ) -> None:
        pass

    def on_validation_step_start(self, model: Model, data: dict[str, torch.Tensor], iteration: int = 0) -> None:
        pass

    def on_validation_step_end(
        self,
        model: Model,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor],
        loss: torch.Tensor,
        iteration: int = 0,
    ) -> None:
        pass

    def on_validation_end(self, model: Model, iteration: int = 0) -> None:
        pass

    def on_load_checkpoint_start(self, model: Model) -> None:
        pass

    def on_load_checkpoint_end(self, model: Model) -> None:
        pass

    def on_load_checkpoint(self, model: Model, state_dict: dict[str, Any]) -> None:
        pass

    def on_save_checkpoint_start(self, model: Model, iteration: int = 0) -> None:
        pass

    def on_save_checkpoint_end(self, model: Model, iteration: int = 0) -> None:
        pass

    def on_save_checkpoint_success(self, iteration: int = 0) -> None:
        pass

    def on_save_checkpoint(self, model: Model, state_dict: dict[str, Any]) -> None:
        pass

    def on_train_end(self, model: Model, iteration: int = 0) -> None:
        pass

    def on_app_end(self) -> None:
        pass
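

# Minimal sketch (assumption: not part of this module) of a custom hook implementation.
# `LRLoggerCallback` is a hypothetical name; the hook signature matches
# Callback.on_before_optimizer_step above.
#
#     class LRLoggerCallback(Callback):
#         def on_before_optimizer_step(self, model_ddp, optimizer, scheduler, grad_scaler, iteration=0):
#             log.info(f"iteration {iteration}: lr={scheduler.get_last_lr()[0]}")
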

class EMAModelCallback(Callback):
    """The callback class for tracking EMA model weights."""

    def on_train_start(self, model: Model, iteration: int = 0) -> None:
        # Set up the EMA model weight tracker.
        if model.config.ema.enabled:
            assert hasattr(model, "ema"), "EMA should be initialized from Model"
            # EMA model must be kept in FP32 precision.
            model.ema = model.ema.to(dtype=torch.float32)
        else:
            assert not hasattr(model, "ema"), "There should be no EMA initialized."

    def on_training_step_end(
        self,
        model: Model,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor],
        loss: torch.Tensor,
        iteration: int = 0,
    ) -> None:
        # Update the EMA model with the new regular weights.
        if model.config.ema.enabled:
            model.ema.update_average(model, iteration)
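

# Reference sketch (assumption: the exact update rule lives in Model.ema, not in this
# callback): the classic exponential-moving-average step that `update_average` is
# conceptually performing per weight, with a decay `beta` close to 1.0 (hypothetical symbol):
#
#     w_ema <- beta * w_ema + (1 - beta) * w_online
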

class ProgressBarCallback(Callback):
    """The callback class for visualizing the training/validation progress bar in the console."""

    def on_train_start(self, model: Model, iteration: int = 0) -> None:
        self.train_pbar = tqdm.trange(self.config.trainer.max_iter, initial=iteration, desc="Training")

    def on_training_step_end(
        self,
        model: Model,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor],
        loss: torch.Tensor,
        iteration: int = 0,
    ) -> None:
        self.train_pbar.update()

    def on_validation_start(
        self, model: Model, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0
    ) -> None:
        if self.config.trainer.max_val_iter is not None:
            num_iter = self.config.trainer.max_val_iter
        else:
            num_iter = len(dataloader_val)
        assert num_iter is not None and num_iter > 0, f"Invalid number of validation iterations: {num_iter}"
        self.val_pbar = tqdm.trange(num_iter, desc="Validating", position=1, leave=False)

    def on_validation_step_end(
        self,
        model: Model,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor],
        loss: torch.Tensor,
        iteration: int = 0,
    ) -> None:
        self.val_pbar.update()

    def on_validation_end(self, model: Model, iteration: int = 0) -> None:
        self.val_pbar.close()

    def on_train_end(self, model: Model, iteration: int = 0) -> None:
        self.trainer.checkpointer.finalize()
        self.train_pbar.close()
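

# Sketch (illustrative only): the nested-bar layout this callback relies on. `position=1`
# keeps the validation bar on the console row below the training bar, and `leave=False`
# erases it when it is closed.
#
#     train_pbar = tqdm.trange(100, desc="Training")
#     val_pbar = tqdm.trange(10, desc="Validating", position=1, leave=False)
#     val_pbar.close()  # the validation row is cleared; the training bar keeps its row
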

class IterationLoggerCallback(Callback):
    """The callback class for logging the training iteration time and loss to the console."""

    def on_train_start(self, model: Model, iteration: int = 0) -> None:
        self.start_iteration_time = time.time()
        self.elapsed_iteration_time = 0

    def on_training_step_start(self, model: Model, data: dict[str, torch.Tensor], iteration: int = 0) -> None:
        self.start_iteration_time = time.time()

    def on_training_step_end(
        self,
        model: Model,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor],
        loss: torch.Tensor,
        iteration: int = 0,
    ) -> None:
        self.elapsed_iteration_time += time.time() - self.start_iteration_time
        if iteration % self.config.trainer.logging_iter == 0:
            avg_time = self.elapsed_iteration_time / self.config.trainer.logging_iter
            log.info(f"Iteration: {iteration}, average iter time: {avg_time:.2f}, total loss {loss.item():.4f}")
            self.elapsed_iteration_time = 0
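

# Worked example (illustrative numbers): with `logging_iter = 100` and an accumulated
# `elapsed_iteration_time` of 25.0 s over those steps, the logged average is
# 25.0 / 100 = 0.25 s per iteration, after which the accumulator resets to 0.
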

class GradClipCallback(Callback):
    """The callback class for gradient clipping."""

    def __init__(
        self,
        config: Optional["Config"] = None,
        trainer: Optional["Trainer"] = None,
        grad_clip_norm: float = 1.0,
    ):
        super().__init__(config, trainer)
        self.grad_clip_norm = grad_clip_norm

    def on_before_optimizer_step(
        self,
        model_ddp: distributed.DistributedDataParallel,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        grad_scaler: torch.amp.GradScaler,
        iteration: int = 0,
    ) -> None:
        grad_scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model_ddp.module.parameters(), max_norm=self.grad_clip_norm)
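

# Context sketch (assumption: the surrounding trainer follows the standard torch AMP
# recipe; this loop is illustrative, not copied from the trainer):
#
#     grad_scaler.scale(loss).backward()
#     grad_scaler.unscale_(optimizer)                                   # done by this callback
#     torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # done by this callback
#     grad_scaler.step(optimizer)   # skips the step if inf/nan gradients were found
#     grad_scaler.update()
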

class LowPrecisionCallback(Callback):
    """The callback class handling low precision training"""

    def __init__(self, update_iter: int, config: Optional["Config"] = None, trainer: Optional["Trainer"] = None):
        super().__init__(config, trainer)
        self.update_iter = update_iter

    def on_train_start(self, model: Model, iteration: int = 0) -> None:
        assert model.precision in [
            torch.bfloat16,
            torch.float16,
            torch.half,
        ], "LowPrecisionCallback must use a low precision dtype."
        self.precision_type = model.precision

    def on_training_step_start(self, model: Model, data: dict[str, torch.Tensor], iteration: int = 0) -> None:
        for k, v in data.items():
            if isinstance(v, torch.Tensor) and torch.is_floating_point(data[k]):
                data[k] = v.to(dtype=self.precision_type)

    def on_validation_step_start(self, model: Model, data: dict[str, torch.Tensor], iteration: int = 0) -> None:
        for k, v in data.items():
            if isinstance(v, torch.Tensor) and torch.is_floating_point(data[k]):
                data[k] = v.to(dtype=self.precision_type)
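
    # Illustrative cast (assumption: `precision_type` is torch.bfloat16; the tensors are
    # hypothetical). Integer tensors such as labels are skipped by the floating-point check.
    #
    #     data = {"video": torch.randn(1, 3, 8, 64, 64), "label": torch.tensor([2])}
    #     # after the hook: data["video"].dtype == torch.bfloat16, data["label"].dtype == torch.int64
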
    def on_before_zero_grad(
        self,
        model_ddp: distributed.DistributedDataParallel,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        iteration: int = 0,
    ) -> None:
        if iteration % self.update_iter == 0:
            if getattr(optimizer, "master_weights", False):
                params, master_params = [], []
                for group, group_master in zip(optimizer.param_groups, optimizer.param_groups_master):
                    for p, p_master in zip(group["params"], group_master["params"]):
                        params.append(get_local_tensor_if_DTensor(p.data))
                        master_params.append(p_master.data)
                torch._foreach_copy_(params, master_params)
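

# Direction of the copy in on_before_zero_grad (sketch under the assumption that the
# optimizer keeps FP32 master weights and exposes `param_groups_master`, e.g. an
# Apex-style fused optimizer): the low-precision parameters are refreshed FROM the FP32
# master copies, element-wise and in bulk:
#
#     torch._foreach_copy_(params, master_params)  # params[i].copy_(master_params[i]) for all i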