# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
from typing import Any, Dict, Optional, Type, TypeVar, Union
import attrs
import torch
from megatron.core import ModelParallelConfig
from cosmos_predict1.utils import callback
from cosmos_predict1.utils.lazy_config import LazyCall as L
from cosmos_predict1.utils.lazy_config import LazyDict
from cosmos_predict1.utils.misc import Color
T = TypeVar("T")


def _is_attrs_instance(obj: object) -> bool:
"""
Helper function to check if an object is an instance of an attrs-defined class.
Args:
obj: The object to check.
Returns:
bool: True if the object is an instance of an attrs-defined class, False otherwise.
"""
return hasattr(obj, "__attrs_attrs__")


def make_freezable(cls: T) -> T:
"""
A decorator that adds the capability to freeze instances of an attrs-defined class.
NOTE: This requires the wrapped attrs to be defined with attrs.define(slots=False) because we need
to hack on a "_is_frozen" attribute.
This decorator enhances an attrs-defined class with the ability to be "frozen" at runtime.
Once an instance is frozen, its attributes cannot be changed. It also recursively freezes
any attrs-defined objects that are attributes of the class.
Usage:
@make_freezable
@attrs.define(slots=False)
class MyClass:
attribute1: int
attribute2: str
obj = MyClass(1, 'a')
obj.freeze() # Freeze the instance
obj.attribute1 = 2 # Raises AttributeError
Args:
cls: The class to be decorated.
Returns:
The decorated class with added freezing capability.
"""
    # Freezing stores a runtime "_is_frozen" flag on instances, so instances must have a
    # __dict__ (i.e. the wrapped class must not be slotted).
    if not any("__dict__" in vars(base) for base in cls.__mro__):
        raise TypeError(
            "make_freezable cannot be used with classes whose instances do not define __dict__. Make sure that "
            "the wrapped class was defined with `@attrs.define(slots=False)`"
        )
original_setattr = cls.__setattr__
def setattr_override(self, key, value) -> None: # noqa: ANN001
"""
Override __setattr__ to allow modifications during initialization
and prevent modifications once the instance is frozen.
"""
if hasattr(self, "_is_frozen") and self._is_frozen and key != "_is_frozen":
raise AttributeError("Cannot modify frozen instance")
original_setattr(self, key, value) # type: ignore
cls.__setattr__ = setattr_override # type: ignore
def freeze(self: object) -> None:
"""
Freeze the instance and all its attrs-defined attributes.
"""
for _, value in attrs.asdict(self, recurse=False).items():
if _is_attrs_instance(value) and hasattr(value, "freeze"):
value.freeze()
self._is_frozen = True # type: ignore
cls.freeze = freeze # type: ignore
return cls
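

# A minimal sketch (hypothetical Inner/Outer classes) of how make_freezable's freeze()
# recurses into nested attrs-defined attributes:
#
#     @make_freezable
#     @attrs.define(slots=False)
#     class Inner:
#         value: int = 0
#
#     @make_freezable
#     @attrs.define(slots=False)
#     class Outer:
#         inner: Inner = attrs.field(factory=Inner)
#
#     outer = Outer()
#     outer.freeze()
#     outer.inner.value = 1  # AttributeError: Cannot modify frozen instance
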
def _pretty_print_attrs_instance(obj: object, indent: int = 0, use_color: bool = False) -> str:
"""
Recursively pretty prints attrs objects with color.
"""
assert attrs.has(obj.__class__)
lines: list[str] = []
for attribute in attrs.fields(obj.__class__):
value = getattr(obj, attribute.name)
if attrs.has(value.__class__):
if use_color:
lines.append(" " * indent + Color.cyan("* ") + Color.green(attribute.name) + ":")
else:
lines.append(" " * indent + "* " + attribute.name + ":")
lines.append(_pretty_print_attrs_instance(value, indent + 1, use_color))
else:
if use_color:
                lines.append(
                    " " * indent + Color.cyan("* ") + Color.green(attribute.name) + ": " + Color.yellow(str(value))
                )
else:
lines.append(" " * indent + "* " + attribute.name + ": " + str(value))
return "\n".join(lines)
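

# Output sketch for _pretty_print_attrs_instance (hypothetical values): each attrs-valued
# field becomes a "* name:" header with its own fields indented one extra space, e.g. the
# JobConfig section of a Config renders roughly as
#
#     * job:
#      * project: my_project
#      * group: my_group
#      * name: my_run
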
def pretty_print_overrides(overrides: Optional[list[str]] = None, use_color: bool = False) -> str:
"""
Pretty prints overrides.
"""
lines: list[str] = []
    if use_color:
        lines.append(Color.cyan("* ") + Color.green("overrides") + ": ")
    else:
        lines.append("* overrides: ")
    for override in overrides or []:
        if override == "--":
            continue
        # Split only on the first "=" so that override values may themselves contain "=".
        attribute_name, attribute_value = override.split("=", 1)
if use_color:
lines.append(" " + Color.cyan("* ") + Color.green(attribute_name) + ": " + Color.yellow(attribute_value))
else:
lines.append(" " + "* " + attribute_name + ": " + str(attribute_value))
return "\n".join(lines)
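

# A small usage sketch for pretty_print_overrides (hypothetical override values): overrides
# are expected as "key=value" strings, e.g.
#
#     print(pretty_print_overrides(["trainer.seed=42", "job.name=my_run"]))
#     # * overrides:
#     #  * trainer.seed: 42
#     #  * job.name: my_run
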
@make_freezable
@attrs.define(slots=False)
class JobConfig:
# Project name.
project: str = ""
# Experiment name.
group: str = ""
# Run/job name.
name: str = ""
@property
def path(self) -> str:
return f"{self.project}/{self.group}/{self.name}"
@property
def path_local(self) -> str:
local_root = os.environ.get("OUTPUT_ROOT", "checkpoints")
return f"{local_root}/{self.path}"
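

# A quick sketch (hypothetical values) of how the two JobConfig path properties compose:
#
#     job = JobConfig(project="my_project", group="my_group", name="my_run")
#     job.path        # "my_project/my_group/my_run"
#     job.path_local  # "checkpoints/my_project/my_group/my_run" when OUTPUT_ROOT is unset
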
@make_freezable
@attrs.define(slots=False)
class EMAConfig:
# Enable tracking a set of exponential moving average (EMA) weights.
enabled: bool = False
# EMA decay rate.
beta: float = 0.9999
    # Enable removing the "_orig_mod." prefix that torch.compile adds to buffer names.
torch_compile_buffer_renaming: bool = False


@make_freezable
@attrs.define(slots=False)
class DDPConfig:
# Traverse the computation graph to find parameters that don't receive gradients.
find_unused_parameters: bool = False
# Set to True if the computation graph does not change during the whole training loop.
static_graph: bool = True
# Set to True if we want to synchronize buffers. Set to False if the sync is going to be handled elsewhere.
broadcast_buffers: bool = True


@make_freezable
@attrs.define(slots=False)
class CuDNNConfig:
# Set to True for better reproducibility of the results (only using deterministic cudnn functions).
deterministic: bool = False
# If set to True, cudnn will benchmark several algorithms and pick the fastest one.
benchmark: bool = True


@make_freezable
@attrs.define(slots=False)
class JITConfig:
# Enable exporting a JIT compiled model.
enabled: bool = False
    # Shape of the example input tensor used for compilation.
    input_shape: Union[list[int], None] = None
    # Device to compile on.
    device: str = "cuda"
    # Data type to compile with.
    dtype: str = "bfloat16"
# Strict mode for PyTorch JIT.
strict: bool = True
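

# A small sketch (hypothetical input shape) of enabling JIT export of the EMA model:
#
#     jit_cfg = JITConfig(enabled=True, input_shape=[1, 3, 224, 224], device="cuda", dtype="bfloat16")
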
@make_freezable
@attrs.define(slots=False)
class CheckpointConfig:
    # Checkpointer class to use (optional), specified as a config dict.
    type: Optional[Dict] = None
    # For distributed checkpointing (DCP), whether to use async mode.
    dcp_async_mode_enabled: bool = False
# Save the checkpoint every N iterations.
save_iter: int = 999999999
# Path of model weights to resume the checkpoint from.
load_path: str = ""
# Whether to load the training states (optimizer/scheduler/grad-scaler) from the checkpoint path.
load_training_state: bool = False
# Whether to load the scheduler state only from the checkpoint path. If load_training_state is True, this will be ignored.
only_load_scheduler_state: bool = False
# Load state_dict to the models in strict mode.
strict_resume: bool = True
# Print detailed information during checkpoint saving/loading.
verbose: bool = True
# Configs for JIT compiling EMA model.
jit: JITConfig = attrs.field(factory=JITConfig)
# keys not to resume from the checkpoint, choices: ["model", "optim", "scheduler", "trainer"]
    keys_not_to_resume: list[str] = attrs.field(factory=list)
# Whether to use the local filesystem for broadcasting checkpoint data (used for Tensor Parallel Checkpointer).
broadcast_via_filesystem: bool = False
    # Whether to load the EMA weights into the regular (non-EMA) model when resuming.
    load_ema_to_reg: bool = False
    # Whether to save checkpoints asynchronously.
    async_saving: bool = True
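

# A minimal resume sketch (hypothetical path and values): save every 1000 iterations and
# warm-start from an earlier run, including its optimizer/scheduler states:
#
#     ckpt = CheckpointConfig(
#         save_iter=1000,
#         load_path="checkpoints/my_project/my_group/my_run/checkpoints/iter_000005000.pt",
#         load_training_state=True,
#     )
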
@make_freezable
@attrs.define(slots=False)
class TrainerConfig:
from cosmos_predict1.utils.trainer import Trainer
type: Type[Trainer] = Trainer
# Set the callback class.
# Defaults to the callbacks below.
callbacks: LazyDict = LazyDict(
dict(
ema=L(callback.EMAModelCallback)(),
progress_bar=L(callback.ProgressBarCallback)(),
)
)
    # Distributed parallelism strategy.
distributed_parallelism: str = "ddp"
# Distributed data parallel configs.
ddp: DDPConfig = attrs.field(factory=DDPConfig)
# cuDNN configs.
cudnn: CuDNNConfig = attrs.field(factory=CuDNNConfig)
# Set the random seed.
seed: int = 0
# Gradient scaler arguments (for torch.amp.GradScaler).
grad_scaler_args: dict = attrs.field(factory=lambda: dict(enabled=False))
# Maximum number of iterations to train the model.
max_iter: int = 999999999
# Maximum number of iterations to validate the model. If None, validate on the entire dataset.
max_val_iter: int | None = None
# How often we log the training stats.
logging_iter: int = 100
# Whether we want to run the validation routines.
run_validation: bool = True
# How often we evaluate on the validation set.
validation_iter: int = 999999999
# Kill the process after N seconds since the last iteration (usually means dead job).
timeout_period: int = 999999999
# Tensor memory organization format.
memory_format: torch.memory_format = torch.preserve_format
# Gradient accumulation (update step every N iteration).
grad_accum_iter: int = 1
# # Profiling config
# profiling: Profiling = attrs.field(factory=Profiling)
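

# A usage sketch (hypothetical values): a short debugging schedule that logs frequently,
# validates often, and accumulates gradients over two steps:
#
#     trainer_cfg = TrainerConfig(
#         max_iter=2000,
#         logging_iter=10,
#         validation_iter=500,
#         grad_accum_iter=2,
#         ddp=DDPConfig(find_unused_parameters=True),
#     )
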
@make_freezable
@attrs.define(slots=False)
class Config:
"""Config for a job.
See /README.md/Configuration System for more info.
"""
# Model configs.
model: LazyDict
# Optimizer configs.
optimizer: LazyDict = LazyDict(dict(dummy=None))
# Scheduler configs.
scheduler: LazyDict = LazyDict(dict(dummy=None))
# Training data configs.
dataloader_train: LazyDict = LazyDict(dict(dummy=None))
# Validation data configs.
dataloader_val: LazyDict = LazyDict(dict(dummy=None))
# Training job configs.
job: JobConfig = attrs.field(factory=JobConfig)
# Trainer configs.
trainer: TrainerConfig = attrs.field(factory=TrainerConfig)
# Megatron-Core configs
model_parallel: ModelParallelConfig = attrs.field(factory=ModelParallelConfig)
# Checkpointer configs.
checkpoint: CheckpointConfig = attrs.field(factory=CheckpointConfig)
def pretty_print(self, use_color: bool = False) -> str:
return _pretty_print_attrs_instance(self, 0, use_color)
def to_dict(self) -> dict[str, Any]:
return attrs.asdict(self)
def validate(self) -> None:
"""Validate that the config has all required fields."""
assert self.job.project != "", "Project name is required."
assert self.job.group != "", "Group name is required."
assert self.job.name != "", "Job name is required."
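

# A minimal smoke-test sketch (hypothetical project/group/run names; the model LazyDict is a
# placeholder, not a working model config) exercising validate(), freeze(), and pretty_print().
if __name__ == "__main__":
    config = Config(
        model=LazyDict(dict(dummy=None)),
        job=JobConfig(project="my_project", group="my_group", name="my_run"),
    )
    config.validate()
    config.freeze()  # added by @make_freezable; subsequent attribute writes raise AttributeError
    print(config.pretty_print(use_color=False))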