# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import collections import collections.abc import ctypes import functools import os from contextlib import contextmanager from datetime import timedelta from typing import TYPE_CHECKING, Any, Callable, Container, Optional import pynvml import torch import torch.distributed as dist from torch.distributed import get_process_group_ranks from cosmos_predict1.utils import log from cosmos_predict1.utils.device import Device if TYPE_CHECKING: from cosmos_predict1.utils.config import DDPConfig if dist.is_available(): from torch.distributed.distributed_c10d import _get_default_group from torch.distributed.utils import _sync_module_states, _verify_param_shape_across_processes try: from megatron.core import parallel_state except ImportError: print("Megatron-core is not installed.") def init() -> int | None: """Initialize distributed training.""" # Set GPU affinity. pynvml.nvmlInit() local_rank = int(os.getenv("LOCAL_RANK", 0)) device = Device(local_rank) # os.sched_setaffinity(0, device.get_cpu_affinity()) # Set up NCCL communication. os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "0" os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1" if dist.is_available(): if dist.is_initialized(): return torch.cuda.current_device() torch.cuda.set_device(local_rank) # Get the timeout value from environment variable timeout_seconds = os.getenv("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", 1800) # Convert the timeout to an integer (if it isn't already) and then to a timedelta timeout_timedelta = timedelta(seconds=int(timeout_seconds)) dist.init_process_group(backend="nccl", init_method="env://", timeout=timeout_timedelta) log.critical( f"Initialized distributed program with local rank {local_rank} with timeout {timeout_seconds}", rank0_only=False, ) # Increase the L2 fetch granularity for faster speed. _libcudart = ctypes.CDLL("libcudart.so") # Set device limit on the current device. p_value = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int)) _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128)) _libcudart.cudaDeviceGetLimit(p_value, ctypes.c_int(0x05)) log.info(f"Running with {get_world_size()} GPUs.") def get_rank(group: Optional[dist.ProcessGroup] = None) -> int: """Get the rank (GPU device) of the worker. Returns: rank (int): The rank of the worker. """ rank = 0 if dist.is_available() and dist.is_initialized(): rank = dist.get_rank(group) return rank def get_world_size(group: Optional[dist.ProcessGroup] = None) -> int: """Get world size. How many GPUs are available in this job. Returns: world_size (int): The total number of GPUs available in this job. """ world_size = 1 if dist.is_available() and dist.is_initialized(): world_size = dist.get_world_size(group) return world_size def is_rank0() -> bool: """Check if current process is the master GPU. Returns: (bool): True if this function is called from the master GPU, else False. """ return get_rank() == 0 def is_local_rank0() -> bool: """Check if current process is the local master GPU in the current node. Returns: (bool): True if this function is called from the local master GPU, else False. """ return torch.cuda.current_device() == 0 def device_with_rank(device: str) -> str: """If the device is 'cuda' and parallelism over GPUs is enabled, returns Otherwise, returns the device as-is.""" if device == 'cuda': return f'cuda:{get_rank()}' return device def rank0_only(func: Callable) -> Callable: """Apply this function only to the master GPU. Example usage: @rank0_only def func(x): return x + 3 Args: func (Callable): a function. Returns: (Callable): A function wrapper executing the function only on the master GPU. """ @functools.wraps(func) def wrapper(*args, **kwargs): # noqa: ANN202 if is_rank0(): return func(*args, **kwargs) else: return None return wrapper def barrier() -> None: """Barrier for all GPUs.""" if dist.is_available() and dist.is_initialized(): dist.barrier() def rank0_first(func: Callable) -> Callable: """run the function on rank 0 first, then on other ranks.""" @functools.wraps(func) def wrapper(*args, **kwargs): # noqa: ANN202 if is_rank0(): result = func(*args, **kwargs) barrier() if not is_rank0(): result = func(*args, **kwargs) return result return wrapper def parallel_model_wrapper(config_ddp: DDPConfig, model: torch.nn.Module) -> torch.nn.Module | DistributedDataParallel: """Wraps the model to enable data parallalism for training across multiple GPU devices. Args: config_ddp (DDPConfig): The data parallel config. model (torch.nn.Module): The PyTorch module. Returns: model (torch.nn.Module | DistributedDataParallel): The data parallel model wrapper if distributed environment is available, otherwise return the original model. """ if dist.is_available() and dist.is_initialized(): local_rank = int(os.getenv("LOCAL_RANK", 0)) try: ddp_group = parallel_state.get_data_parallel_group(with_context_parallel=True) except Exception as e: log.info(e) log.info("parallel_state not initialized, treating all GPUs equally for DDP") ddp_group = None model = DistributedDataParallel( model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=config_ddp.find_unused_parameters, static_graph=config_ddp.static_graph, broadcast_buffers=config_ddp.broadcast_buffers, process_group=ddp_group, ) return model class DistributedDataParallel(torch.nn.parallel.DistributedDataParallel): """This extends torch.nn.parallel.DistributedDataParallel with .training_step(). This borrows the concept of `forward-redirection` from Pytorch lightning. It wraps an coreModel such that model.training_step() would be executed when calling self.training_step(), while preserving the behavior of calling model() for Pytorch modules. Internally, this is a double rerouting mechanism (training_step -> forward -> training_step), allowing us to preserve the function names and signatures. """ def __init__(self, model: torch.nn.Module, *args, **kwargs): super().__init__(model, *args, **kwargs) self.show_sync_grad_static_graph_warning = True def training_step(self, *args, **kwargs) -> Any: # Cache the original model.forward() method. original_forward = self.module.forward def wrapped_training_step(*_args, **_kwargs): # noqa: ANN202 # Unpatch immediately before calling training_step() because itself may want to call the real forward. self.module.forward = original_forward # The actual .training_step(). return self.module.training_step(*_args, **_kwargs) # Patch the original_module's forward so we can redirect the arguments back to the real method. self.module.forward = wrapped_training_step # Call self, which implicitly calls self.forward() --> model.forward(), which is now model.training_step(). # Without calling self.forward() or model.forward() explciitly, implicit hooks are also executed. return self(*args, **kwargs) @contextmanager def ddp_sync_grad(model, enabled): r""" Context manager to enable/disable gradient synchronizations across DDP processes for DDP model. Modified from: https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel.no_sync Note that this is incompatible with static_graph=True and will be an no-op if static_graph=True. Within this context, gradients will be accumulated on module variables, which will later be synchronized in the first forward-backward pass exiting the context. .. warning:: The forward pass should be included inside the context manager, or else gradients will still be synchronized. """ assert isinstance(model, torch.nn.Module) if isinstance(model, DistributedDataParallel): old_require_backward_grad_sync = model.require_backward_grad_sync if model.static_graph and model.require_backward_grad_sync != enabled: if model.show_sync_grad_static_graph_warning: log.warning("DDP static_graph=True is incompatible with sync_grad(). Performance will be reduced.") model.show_sync_grad_static_graph_warning = False else: model.require_backward_grad_sync = enabled try: yield finally: if isinstance(model, DistributedDataParallel): model.require_backward_grad_sync = old_require_backward_grad_sync def collate_batches(data_batches: list[dict[str, torch.Tensor]]) -> torch.Tensor | dict[str, torch.Tensor]: """Aggregate the list of data batches from all devices and process the results. This is used for gathering validation data batches with utils.dataloader.DistributedEvalSampler. It will return the data/output of the entire validation set in its original index order. The sizes of data_batches in different ranks may differ by 1 (if dataset size is not evenly divisible), in which case a dummy sample will be created before calling dis.all_gather(). Args: data_batches (list[dict[str, torch.Tensor]]): List of tensors or (hierarchical) dictionary where leaf entries are tensors. Returns: data_gather (torch.Tensor | dict[str, torch.Tensor]): tensors or (hierarchical) dictionary where leaf entries are concatenated tensors. """ if isinstance(data_batches[0], torch.Tensor): # Concatenate the local data batches. data_concat = torch.cat(data_batches, dim=0) # type: ignore # Get the largest number of local samples from all ranks to determine whether to dummy-pad on this rank. max_num_local_samples = torch.tensor(len(data_concat), device="cuda") dist.all_reduce(max_num_local_samples, op=dist.ReduceOp.MAX) if len(data_concat) < max_num_local_samples: assert len(data_concat) + 1 == max_num_local_samples dummy = torch.empty_like(data_concat[:1]) data_concat = torch.cat([data_concat, dummy], dim=0) dummy_count = torch.tensor(1, device="cuda") else: dummy_count = torch.tensor(0, device="cuda") # Get all concatenated batches from all ranks and concatenate again. dist.all_reduce(dummy_count, op=dist.ReduceOp.SUM) data_concat = all_gather_tensor(data_concat.contiguous()) data_collate = torch.stack(data_concat, dim=1).flatten(start_dim=0, end_dim=1) # Remove the dummy samples. if dummy_count > 0: data_collate = data_collate[:-dummy_count] elif isinstance(data_batches[0], collections.abc.Mapping): data_collate = dict() for key in data_batches[0].keys(): data_collate[key] = collate_batches([data[key] for data in data_batches]) # type: ignore else: raise TypeError return data_collate @torch.no_grad() def all_gather_tensor(tensor: torch.Tensor) -> list[torch.Tensor]: """Gather the corresponding tensor from all GPU devices to a list. Args: tensor (torch.Tensor): Pytorch tensor. Returns: tensor_list (list[torch.Tensor]): A list of Pytorch tensors gathered from all GPU devices. """ tensor_list = [torch.zeros_like(tensor) for _ in range(get_world_size())] dist.all_gather(tensor_list, tensor) return tensor_list def broadcast(tensor, src, group=None, async_op=False): world_size = get_world_size() if world_size < 2: return tensor dist.broadcast(tensor, src=src, group=group, async_op=async_op) def sync_model_states( model: torch.nn.Module, process_group: Optional[dist.ProcessGroup] = None, src: int = 0, params_and_buffers_to_ignore: Optional[Container[str]] = None, broadcast_buffers: bool = True, ): """ Modify based on DDP source code Synchronizes the parameters and buffers of a model across different processes in a distributed setting. This function ensures that all processes in the specified process group have the same initial parameters and buffers from the source rank, typically rank 0. It is useful when different processes start with different model states and a synchronization is required to ensure consistency across all ranks. Args: model (nn.Module): The model whose parameters and buffers are to be synchronized. process_group (dist.ProcessGroup, optional): The process group for communication. If None, the default group is used. Defaults to None. src (int, optional): The source rank from which parameters and buffers will be broadcasted. Defaults to 0. params_and_buffers_to_ignore (Optional[Container[str]], optional): A container of parameter and buffer names to exclude from synchronization. Defaults to None, which means all parameters and buffers are included. broadcast_buffers (bool, optional): Whether to broadcast buffers or not. Defaults to True. Side Effects: This function modifies the state of the model in-place to synchronize it with the source rank's model state. Raises: RuntimeError: If the shapes of parameters across processes do not match, a runtime error will be raised. Examples: >>> # downloading duplicated model weights from s3 in each rank and save network bandwidth >>> # useful and save our time when model weights are huge >>> if dist.get_rank == 0: >>> model.load_state_dict(network_bound_weights_download_fn(s3_weights_path)) >>> dist.barrir() >>> sync_model_states(model) # sync rank0 weights to other ranks """ if process_group is None: process_group = _get_default_group() if not params_and_buffers_to_ignore: params_and_buffers_to_ignore = set() log.info( f"Synchronizing model states from rank {src} to all ranks in process group {get_process_group_ranks(process_group)}." ) # Build tuple of (module, parameter) for all parameters that require grads. modules_and_parameters = [ (module, parameter) for module_name, module in model.named_modules() for parameter in [ param # Note that we access module.named_parameters instead of # parameters(module). parameters(module) is only needed in the # single-process multi device case, where it accesses replicated # parameters through _former_parameters. for param_name, param in module.named_parameters(recurse=False) if f"{module_name}.{param_name}" not in params_and_buffers_to_ignore # if param.requires_grad # and f"{module_name}.{param_name}" not in params_and_buffers_to_ignore ] ] # Deduplicate any parameters that might be shared across child modules. memo = set() modules_and_parameters = [ # "p not in memo" is the deduplication check. # "not memo.add(p)" is always True, and it's only there to cause "add(p)" if needed. (m, p) for m, p in modules_and_parameters if p not in memo and not memo.add(p) # type: ignore[func-returns-value] ] # Build list of parameters. parameters = [parameter for _, parameter in modules_and_parameters] if len(parameters) == 0: return _verify_param_shape_across_processes(process_group, parameters) _sync_module_states( module=model, process_group=process_group, broadcast_bucket_size=int(250 * 1024 * 1024), src=src, params_and_buffers_to_ignore=params_and_buffers_to_ignore, broadcast_buffers=broadcast_buffers, ) def dist_reduce_tensor(tensor, rank=0, reduce="mean"): r"""Reduce to rank 0""" world_size = get_world_size() if world_size < 2: return tensor with torch.no_grad(): dist.reduce(tensor, dst=rank) if get_rank() == rank: if reduce == "mean": tensor /= world_size elif reduce == "sum": pass else: raise NotImplementedError return tensor