# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import os
from typing import Any, Dict, Optional, Type, TypeVar, Union

import attrs
import torch
from megatron.core import ModelParallelConfig

from cosmos_predict1.utils import callback
from cosmos_predict1.utils.lazy_config import LazyCall as L
from cosmos_predict1.utils.lazy_config import LazyDict
from cosmos_predict1.utils.misc import Color

T = TypeVar("T")


def _is_attrs_instance(obj: object) -> bool:
    """
    Helper function to check if an object is an instance of an attrs-defined class.

    Args:
        obj: The object to check.

    Returns:
        bool: True if the object is an instance of an attrs-defined class, False otherwise.
    """
    return hasattr(obj, "__attrs_attrs__")


def make_freezable(cls: T) -> T:
    """
    A decorator that adds the capability to freeze instances of an attrs-defined class.

    NOTE: This requires the wrapped attrs class to be defined with attrs.define(slots=False) because we need
    to hack on a "_is_frozen" attribute.

    This decorator enhances an attrs-defined class with the ability to be "frozen" at runtime.
    Once an instance is frozen, its attributes cannot be changed. It also recursively freezes
    any attrs-defined objects that are attributes of the class.

    Usage:
        @make_freezable
        @attrs.define(slots=False)
        class MyClass:
            attribute1: int
            attribute2: str

        obj = MyClass(1, 'a')
        obj.freeze()  # Freeze the instance
        obj.attribute1 = 2  # Raises AttributeError

    Args:
        cls: The class to be decorated.

    Returns:
        The decorated class with added freezing capability.
    """
    if not hasattr(cls, "__dict__"):
        raise TypeError(
            "make_freezable cannot be used with classes that do not define __dict__. Make sure that the wrapped "
            "class was defined with `@attrs.define(slots=False)`"
        )

    original_setattr = cls.__setattr__

    def setattr_override(self, key, value) -> None:  # noqa: ANN001
        """
        Override __setattr__ to allow modifications during initialization
        and prevent modifications once the instance is frozen.
        """
        if hasattr(self, "_is_frozen") and self._is_frozen and key != "_is_frozen":
            raise AttributeError("Cannot modify frozen instance")
        original_setattr(self, key, value)  # type: ignore

    cls.__setattr__ = setattr_override  # type: ignore

    def freeze(self: object) -> None:
        """
        Freeze the instance and all its attrs-defined attributes.
        """
        for _, value in attrs.asdict(self, recurse=False).items():
            if _is_attrs_instance(value) and hasattr(value, "freeze"):
                value.freeze()
        self._is_frozen = True  # type: ignore

    cls.freeze = freeze  # type: ignore

    return cls


def _pretty_print_attrs_instance(obj: object, indent: int = 0, use_color: bool = False) -> str:
    """
    Recursively pretty prints attrs objects with color.
    """
    assert attrs.has(obj.__class__)

    lines: list[str] = []
    for attribute in attrs.fields(obj.__class__):
        value = getattr(obj, attribute.name)
        if attrs.has(value.__class__):
            if use_color:
                lines.append(" " * indent + Color.cyan("* ") + Color.green(attribute.name) + ":")
            else:
                lines.append(" " * indent + "* " + attribute.name + ":")
            lines.append(_pretty_print_attrs_instance(value, indent + 1, use_color))
        else:
            if use_color:
                lines.append(
                    " " * indent + Color.cyan("* ") + Color.green(attribute.name) + ": " + Color.yellow(value)
                )
            else:
                lines.append(" " * indent + "* " + attribute.name + ": " + str(value))
    return "\n".join(lines)


def pretty_print_overrides(overrides: Optional[list[str]] = None, use_color: bool = False) -> str:
    """
    Pretty prints overrides.
    """
    overrides = overrides or []

    lines: list[str] = []
    if use_color:
        lines.append(Color.cyan("* ") + Color.green("overrides") + ": ")
    else:
        lines.append("* overrides: ")
    for override in overrides:
        if override == "--":
            continue
        attribute_name, attribute_value = override.split("=", 1)
        if use_color:
            lines.append(" " + Color.cyan("* ") + Color.green(attribute_name) + ": " + Color.yellow(attribute_value))
        else:
            lines.append(" " + "* " + attribute_name + ": " + str(attribute_value))
    return "\n".join(lines)
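

# Illustrative sketch (not part of the original module): the overrides are expected as
# "dotted.key=value" strings, e.g. as produced by a command-line config-override parser.
# With use_color=False, a call such as
#
#     pretty_print_overrides(["job.name=my_run", "trainer.seed=42"])
#
# would return roughly:
#
#     * overrides:
#      * job.name: my_run
#      * trainer.seed: 42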


@make_freezable
@attrs.define(slots=False)
class JobConfig:
    # Project name.
    project: str = ""
    # Experiment name.
    group: str = ""
    # Run/job name.
    name: str = ""

    @property
    def path(self) -> str:
        return f"{self.project}/{self.group}/{self.name}"

    @property
    def path_local(self) -> str:
        local_root = os.environ.get("OUTPUT_ROOT", "checkpoints")
        return f"{local_root}/{self.path}"


@make_freezable
@attrs.define(slots=False)
class EMAConfig:
    # Enable tracking a set of exponential moving average (EMA) weights.
    enabled: bool = False
    # EMA decay rate.
    beta: float = 0.9999
    # Enable removing the "_orig_mod." prefix that torch.compile adds to buffer names.
    torch_compile_buffer_renaming: bool = False


@make_freezable
@attrs.define(slots=False)
class DDPConfig:
    # Traverse the computation graph to find parameters that don't receive gradients.
    find_unused_parameters: bool = False
    # Set to True if the computation graph does not change during the whole training loop.
    static_graph: bool = True
    # Set to True if we want to synchronize buffers. Set to False if the sync is going to be handled elsewhere.
    broadcast_buffers: bool = True


@make_freezable
@attrs.define(slots=False)
class CuDNNConfig:
    # Set to True for better reproducibility of the results (only using deterministic cudnn functions).
    deterministic: bool = False
    # If set to True, cudnn will benchmark several algorithms and pick the fastest one.
    benchmark: bool = True


@make_freezable
@attrs.define(slots=False)
class JITConfig:
    # Enable exporting a JIT compiled model.
    enabled: bool = False
    # Input tensor shape of an example input.
    input_shape: Union[list[int], None] = None
    # Device to compile on.
    device: str = "cuda"
    # Data type to compile with.
    dtype: str = "bfloat16"
    # Strict mode for PyTorch JIT.
    strict: bool = True


@make_freezable
@attrs.define(slots=False)
class CheckpointConfig:
    # Checkpointer class to use.
    type: Optional[Dict] = None
    # For DCP, whether to use async mode.
    dcp_async_mode_enabled: bool = False
    # Save the checkpoint every N iterations.
    save_iter: int = 999999999
    # Path of model weights to resume the checkpoint from.
    load_path: str = ""
    # Whether to load the training states (optimizer/scheduler/grad-scaler) from the checkpoint path.
    load_training_state: bool = False
    # Whether to load the scheduler state only from the checkpoint path. If load_training_state is True, this will be ignored.
    only_load_scheduler_state: bool = False
    # Load state_dict to the models in strict mode.
    strict_resume: bool = True
    # Print detailed information during checkpoint saving/loading.
    verbose: bool = True
    # Configs for JIT compiling EMA model.
    jit: JITConfig = attrs.field(factory=JITConfig)
    # Keys not to resume from the checkpoint, choices: ["model", "optim", "scheduler", "trainer"].
    keys_not_to_resume: list[str] = []
    # Whether to use the local filesystem for broadcasting checkpoint data (used for Tensor Parallel Checkpointer).
    broadcast_via_filesystem: bool = False
    # Whether to load the EMA weights into the regular (non-EMA) model when resuming.
    load_ema_to_reg: bool = False
    # Whether to save checkpoints asynchronously.
    async_saving: bool = True


@make_freezable
@attrs.define(slots=False)
class TrainerConfig:
    from cosmos_predict1.utils.trainer import Trainer

    type: Type[Trainer] = Trainer
    # Set the callback class.
    # Defaults to the callbacks below.
    callbacks: LazyDict = LazyDict(
        dict(
            ema=L(callback.EMAModelCallback)(),
            progress_bar=L(callback.ProgressBarCallback)(),
        )
    )
    # Distributed parallelism strategy.
    distributed_parallelism: str = "ddp"
    # Distributed data parallel configs.
    ddp: DDPConfig = attrs.field(factory=DDPConfig)
    # cuDNN configs.
    cudnn: CuDNNConfig = attrs.field(factory=CuDNNConfig)
    # Set the random seed.
    seed: int = 0
    # Gradient scaler arguments (for torch.amp.GradScaler).
    grad_scaler_args: dict = attrs.field(factory=lambda: dict(enabled=False))
    # Maximum number of iterations to train the model.
    max_iter: int = 999999999
    # Maximum number of iterations to validate the model. If None, validate on the entire dataset.
    max_val_iter: int | None = None
    # How often we log the training stats.
    logging_iter: int = 100
    # Whether we want to run the validation routines.
    run_validation: bool = True
    # How often we evaluate on the validation set.
    validation_iter: int = 999999999
    # Kill the process after N seconds since the last iteration (usually means a dead job).
    timeout_period: int = 999999999
    # Tensor memory organization format.
    memory_format: torch.memory_format = torch.preserve_format
    # Gradient accumulation (update step every N iterations).
    grad_accum_iter: int = 1
    # # Profiling config
    # profiling: Profiling = attrs.field(factory=Profiling)


@make_freezable
@attrs.define(slots=False)
class Config:
    """Config for a job.

    See /README.md/Configuration System for more info.
    """

    # Model configs.
    model: LazyDict
    # Optimizer configs.
    optimizer: LazyDict = LazyDict(dict(dummy=None))
    # Scheduler configs.
    scheduler: LazyDict = LazyDict(dict(dummy=None))
    # Training data configs.
    dataloader_train: LazyDict = LazyDict(dict(dummy=None))
    # Validation data configs.
    dataloader_val: LazyDict = LazyDict(dict(dummy=None))
    # Training job configs.
    job: JobConfig = attrs.field(factory=JobConfig)
    # Trainer configs.
    trainer: TrainerConfig = attrs.field(factory=TrainerConfig)
    # Megatron-Core configs.
    model_parallel: ModelParallelConfig = attrs.field(factory=ModelParallelConfig)
    # Checkpointer configs.
    checkpoint: CheckpointConfig = attrs.field(factory=CheckpointConfig)

    def pretty_print(self, use_color: bool = False) -> str:
        return _pretty_print_attrs_instance(self, 0, use_color)

    def to_dict(self) -> dict[str, Any]:
        return attrs.asdict(self)

    def validate(self) -> None:
        """Validate that the config has all required fields."""
        assert self.job.project != "", "Project name is required."
        assert self.job.group != "", "Group name is required."
        assert self.job.name != "", "Job name is required."
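

# Illustrative usage sketch (not part of the original module): shows how a Config instance
# might be built, validated, frozen, and pretty-printed. The project/group/name values and
# the dummy model dict below are placeholders chosen for the example only.
if __name__ == "__main__":
    config = Config(model=LazyDict(dict(dummy=None)))
    # Fill in the job fields that validate() requires.
    config.job.project = "example_project"
    config.job.group = "example_group"
    config.job.name = "example_run"
    config.validate()
    # Freeze to prevent accidental mutation later; freeze() is added by the make_freezable decorator.
    config.freeze()
    print(config.pretty_print(use_color=False))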