from __future__ import annotations

import os
from typing import Any, Dict, Optional, Type, TypeVar, Union

import attrs
import torch

try:
    from megatron.core import ModelParallelConfig

    USE_MEGATRON = True
except ImportError:
    USE_MEGATRON = False
    print("Megatron-core is not installed.")

from cosmos_transfer1.utils.callback import EMAModelCallback, ProgressBarCallback
from cosmos_transfer1.utils.ddp_config import DDPConfig, make_freezable
from cosmos_transfer1.utils.lazy_config import LazyCall as L
from cosmos_transfer1.utils.lazy_config import LazyDict
from cosmos_transfer1.utils.misc import Color


def _pretty_print_attrs_instance(obj: object, indent: int = 0, use_color: bool = False) -> str:
    """Recursively pretty-print an attrs instance as an indented bullet list, optionally with color."""
    assert attrs.has(obj.__class__)

    lines: list[str] = []
    for attribute in attrs.fields(obj.__class__):
        value = getattr(obj, attribute.name)
        if attrs.has(value.__class__):
            # Nested attrs instance: print the field name, then recurse one indent level deeper.
            if use_color:
                lines.append(" " * indent + Color.cyan("* ") + Color.green(attribute.name) + ":")
            else:
                lines.append(" " * indent + "* " + attribute.name + ":")
            lines.append(_pretty_print_attrs_instance(value, indent + 1, use_color))
        else:
            # Leaf value: print "name: value" on a single line.
            if use_color:
                lines.append(
                    " " * indent + Color.cyan("* ") + Color.green(attribute.name) + ": " + Color.yellow(value)
                )
            else:
                lines.append(" " * indent + "* " + attribute.name + ": " + str(value))
    return "\n".join(lines)
|
|
|
|
|
@make_freezable
@attrs.define(slots=False)
class JobConfig:
    # Project name.
    project: str = ""
    # Experiment group name.
    group: str = ""
    # Run/job name.
    name: str = ""

    @property
    def path(self) -> str:
        return f"{self.project}/{self.group}/{self.name}"

    @property
    def path_local(self) -> str:
        local_root = os.environ.get("OUTPUT_ROOT", "checkpoints")
        return f"{local_root}/{self.path}"
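
# Illustrative example (values are placeholders): path_local prefixes path with the
# OUTPUT_ROOT environment variable, falling back to "checkpoints" when it is unset.
#
#   job = JobConfig(project="cosmos", group="transfer", name="debug_run")
#   job.path        # -> "cosmos/transfer/debug_run"
#   job.path_local  # -> "checkpoints/cosmos/transfer/debug_run"  (OUTPUT_ROOT unset)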
|
|
|
|
|
@make_freezable
@attrs.define(slots=False)
class EMAConfig:
    # Whether to keep an exponential moving average (EMA) copy of the model weights.
    enabled: bool = False
    # EMA decay rate.
    beta: float = 0.9999
    # Rename EMA buffers for torch.compile compatibility.
    torch_compile_buffer_renaming: bool = False


@make_freezable
@attrs.define(slots=False)
class CuDNNConfig:
    # cuDNN deterministic mode.
    deterministic: bool = False
    # cuDNN autotuner benchmark mode.
    benchmark: bool = True


@make_freezable
@attrs.define(slots=False)
class JITConfig:
    # Whether JIT export is enabled.
    enabled: bool = False
    # Input shape for JIT export (None if unused).
    input_shape: Union[list[int], None] = None
    # Device used for JIT export.
    device: str = "cuda"
    # Data type used for JIT export.
    dtype: str = "bfloat16"
    # Strict mode flag passed to the JIT export.
    strict: bool = True


@make_freezable
@attrs.define(slots=False)
class CheckpointConfig:
    # Checkpointer type (None selects the default checkpointer).
    type: Optional[Dict] = None
    # Enable asynchronous mode for distributed checkpointing (DCP).
    dcp_async_mode_enabled: bool = False
    # Checkpoint saving interval (in iterations).
    save_iter: int = 999999999
    # Path of the checkpoint to load ("" trains from scratch).
    load_path: str = ""
    # Whether to also restore the training state from the checkpoint.
    load_training_state: bool = False
    # Whether to restore only the scheduler state.
    only_load_scheduler_state: bool = False
    # Whether to enforce strict key matching when resuming.
    strict_resume: bool = True
    # Print detailed information during checkpoint saving/loading.
    verbose: bool = True
    # JIT export configs.
    jit: JITConfig = attrs.field(factory=JITConfig)
    # State-dict keys to skip when resuming.
    keys_not_to_resume: list[str] = []
    # Whether to broadcast/distribute the checkpoint via the filesystem.
    broadcast_via_filesystem: bool = False
    # Whether to load EMA weights into the regular (non-EMA) model.
    load_ema_to_reg: bool = False
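
# Illustrative resume setup (the path below is a placeholder, not a real checkpoint):
# load model weights from a file without restoring the training state.
#
#   checkpoint = CheckpointConfig(
#       load_path="checkpoints/my_project/my_group/my_run/iter_000010000.pt",
#       load_training_state=False,
#   )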
|
|
|
|
|
@make_freezable
@attrs.define(slots=False)
class TrainerConfig:
    # Note: imported inside the class body (a common pattern to avoid circular imports).
    from cosmos_transfer1.utils.trainer import Trainer

    # Trainer class to instantiate.
    type: Type[Trainer] = Trainer

    # Callbacks attached to the trainer, keyed by name.
    callbacks: LazyDict = LazyDict(
        dict(
            ema=L(EMAModelCallback)(),
            progress_bar=L(ProgressBarCallback)(),
        )
    )

    # Distributed parallelism strategy (e.g. "ddp").
    distributed_parallelism: str = "ddp"
    # Distributed data parallel configs.
    ddp: DDPConfig = attrs.field(factory=DDPConfig)
    # cuDNN configs.
    cudnn: CuDNNConfig = attrs.field(factory=CuDNNConfig)
    # Random seed.
    seed: int = 0
    # Gradient scaler arguments (for mixed-precision training).
    grad_scaler_args: dict = attrs.field(factory=lambda: dict(enabled=False))
    # Maximum number of training iterations.
    max_iter: int = 999999999
    # Maximum number of validation iterations (None for no explicit cap).
    max_val_iter: int | None = None
    # Logging interval (in iterations).
    logging_iter: int = 100
    # Whether to run validation during training.
    run_validation: bool = True
    # Validation interval (in iterations).
    validation_iter: int = 999999999
    # Timeout period before the run is considered stalled (effectively disabled by default).
    timeout_period: int = 999999999
    # Preferred memory format for model weights.
    memory_format: torch.memory_format = torch.preserve_format
    # Number of gradient accumulation steps per optimizer update.
    grad_accum_iter: int = 1
    # Whether to incorporate a timestamp into seeding.
    timestamp_seed: bool = True


@make_freezable
@attrs.define(slots=False)
class Config:
    """Config for a job.

    See /README.md/Configuration System for more info.
    """

    # Model configs.
    model: LazyDict

    # Optimizer configs.
    optimizer: LazyDict = LazyDict(dict(dummy=None))
    # Scheduler configs.
    scheduler: LazyDict = LazyDict(dict(dummy=None))
    # Training data configs.
    dataloader_train: LazyDict = LazyDict(dict(dummy=None))
    # Validation data configs.
    dataloader_val: LazyDict = LazyDict(dict(dummy=None))

    # Training job configs.
    job: JobConfig = attrs.field(factory=JobConfig)

    # Trainer configs.
    trainer: TrainerConfig = attrs.field(factory=TrainerConfig)

    # Model parallelism configs (only available when megatron.core is installed).
    if USE_MEGATRON:
        model_parallel: ModelParallelConfig = attrs.field(factory=ModelParallelConfig)
    else:
        model_parallel: None = None

    # Checkpointer configs.
    checkpoint: CheckpointConfig = attrs.field(factory=CheckpointConfig)

    def pretty_print(self, use_color: bool = False) -> str:
        return _pretty_print_attrs_instance(self, 0, use_color)

    def to_dict(self) -> dict[str, Any]:
        return attrs.asdict(self)

    def validate(self) -> None:
        """Validate that the config has all required fields."""
        assert self.job.project != "", "Project name is required."
        assert self.job.group != "", "Group name is required."
        assert self.job.name != "", "Job name is required."
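

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): real configs are normally assembled by
    # the project's configuration system; the model LazyDict here is just a placeholder.
    config = Config(model=LazyDict(dict(dummy=None)))
    config.job.project = "cosmos"
    config.job.group = "transfer"
    config.job.name = "debug_run"
    config.validate()
    print(config.pretty_print(use_color=False))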
|
|