# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations | |
from contextlib import contextmanager | |
from functools import partial | |
import torch | |
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( | |
CheckpointImpl, | |
apply_activation_checkpointing, | |
checkpoint_wrapper, | |
) | |
from torch.distributed.device_mesh import init_device_mesh | |
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP | |
from torch.distributed.fsdp._runtime_utils import ( | |
_post_forward, | |
_post_forward_reshard, | |
_pre_forward, | |
_pre_forward_unshard, | |
_root_pre_forward, | |
) | |
from torch.distributed.utils import _p_assert | |
from cosmos_predict1.utils import distributed, log | |
def apply_fsdp_checkpointing(model, list_block_cls):
    """Apply activation checkpointing to ``model`` in place.

    Wraps every submodule whose type appears in ``list_block_cls`` with a
    non-reentrant ``checkpoint_wrapper`` so its activations are recomputed
    during backward instead of being stored.

    Args:
        model: The (typically FSDP-wrapped) model to modify in place.
        list_block_cls: Iterable of module classes whose instances should be
            checkpointed (e.g. transformer block classes).

    Returns:
        None. ``model`` is updated directly.
    """
    # Fixed typo in the original message: "fdsp" -> "fsdp".
    log.critical("--> applying fsdp activation checkpointing...")
    non_reentrant_wrapper = partial(
        checkpoint_wrapper,
        # offload_to_cpu=False,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )

    # isinstance accepts a tuple of classes, so no manual membership loop is needed.
    block_classes = tuple(list_block_cls)

    def check_fn(submodule):
        # True for submodules that should be wrapped with checkpointing.
        return isinstance(submodule, block_classes)

    apply_activation_checkpointing(model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn)
def possible_fsdp_scope( | |
model: torch.nn.Module, | |
): | |
enabled = isinstance(model, FSDP) | |
if enabled: | |
assert not torch.is_grad_enabled(), "FSDP context should be entered with grad disabled" | |
handle = model._handle | |
args, kwargs = [0], dict(dummy=0) | |
with torch.autograd.profiler.record_function("FullyShardedDataParallel.possible_fsdp_scope"): | |
args, kwargs = _root_pre_forward(model, model, args, kwargs) | |
unused = None | |
args, kwargs = _pre_forward( | |
model, | |
handle, | |
_pre_forward_unshard, | |
model._fsdp_wrapped_module, | |
args, | |
kwargs, | |
) | |
if handle: | |
_p_assert( | |
handle.flat_param.device == model.compute_device, | |
"Expected `FlatParameter` to be on the compute device " | |
f"{model.compute_device} but got {handle.flat_param.device}", | |
) | |
try: | |
yield None | |
finally: | |
if enabled: | |
output = {"output": 1} | |
_post_forward(model, handle, _post_forward_reshard, model, unused, output) | |
def hsdp_device_mesh(replica_group_size=None, sharding_group_size=None, device=None):
    """
    Initializes a device mesh for use with Hybrid Sharding strategy in FSDP (HSDP) training.

    This function requires explicit sizes for replica and sharding groups to accommodate models
    whose GPU fit is unknown, providing flexibility in distributed training setups.

    Args:
        replica_group_size (int, optional): The size of each replica group. Defaults to
            ``world_size // sharding_group_size`` when not provided.
        sharding_group_size (int, optional): The size of each sharding group that the model can
            fit. Defaults to ``min(world_size, 8)`` (i.e. shard within a node) when not provided.
        device (str, optional): The device to use (e.g., "cuda:0"). If None, defaults to "cuda"
            with the local rank as the device index.

    Returns:
        A device mesh object compatible with FSDP.

    Raises:
        ValueError: If the world size is not evenly divisible by the sharding group size, or if
            ``replica_group_size * sharding_group_size`` does not equal the world size.
        RuntimeError: If a valid device mesh cannot be created.

    Usage:
        If your model fits on 4 GPUS, and you have 3 nodes of 8 GPUs, then:
        Sharding_Group_Size = 4
        Replica_Groups_Size = (24 total gpus, 4 per sharding group) = 6 Replica Groups
        >>> device_mesh = hsdp_device_mesh(replica_group_size, sharding_group_size)
        >>> sharded_model = FSDP(model, device_mesh=device_mesh, ...)
    """
    world_size = distributed.get_world_size()
    if sharding_group_size is None:
        # Default: shard within a node (up to 8 GPUs), replicate across nodes.
        sharding_group_size = min(world_size, 8)
    sharding_group_size = min(sharding_group_size, world_size)
    if replica_group_size is None:
        replica_group_size = world_size // sharding_group_size
    device = device or "cuda"

    if world_size % sharding_group_size != 0:
        raise ValueError(
            f"World size {world_size} is not evenly divisible by " f"sharding group size {sharding_group_size}."
        )
    # The 2-D mesh must tile the world exactly. The previous check,
    # `(world_size // sharding_group_size) % replica_group_size != 0`, was tautological when
    # replica_group_size was derived above and allowed partial meshes when it was supplied
    # explicitly (init_device_mesh would then fail later with a less clear error).
    if replica_group_size * sharding_group_size != world_size:
        raise ValueError(
            f"replica_group_size ({replica_group_size}) times sharding_group_size "
            f"({sharding_group_size}) must equal world size ({world_size})."
        )

    device_mesh = init_device_mesh(
        device, (replica_group_size, sharding_group_size), mesh_dim_names=("replicate", "shard")
    )
    if device_mesh is None:
        raise RuntimeError("Failed to create a valid device mesh.")
    log.critical(
        f"Device mesh initialized with replica group size {replica_group_size} and sharding group size {sharding_group_size}"
    )
    return device_mesh