import datetime
import functools
import os
import pathlib
import shutil
import time
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

import datasets.distributed
import torch
import torch.distributed._functional_collectives
import torch.distributed.checkpoint
import torch.distributed.checkpoint.stateful
from diffusers.hooks import HookRegistry, ModelHook
from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard
from torch.distributed._composable.replicate import replicate
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    set_model_state_dict,
)
from torch.distributed.tensor import DTensor, Shard

from finetrainers._metadata import ContextParallelModelPlan, CPInput, CPOutput, TransformerRegistry
from finetrainers.data import DPDataLoader
from finetrainers.logging import get_logger
from finetrainers.utils import enable_determinism, get_device_info, get_submodule_by_name, unwrap_module
from finetrainers.utils._common import DIFFUSERS_TRANSFORMER_BLOCK_NAMES

from .base import BaseCheckpointer, BaseParallelBackend


if TYPE_CHECKING:
    from finetrainers import optimizer


_device_type, _device_module = get_device_info()
logger = get_logger()


class PytorchDTensorParallelBackend(BaseParallelBackend):
    """Parallel backend built on `torch.distributed` and a named `DeviceMesh`.

    The constructor validates the requested parallel degrees, initializes the default
    process group and binds the current process to its local device. The device mesh
    itself is created lazily on the first call to `get_mesh`.
    """

    def __init__(
        self,
        world_size: int,
        pp_degree: int = 1,
        dp_degree: int = 1,
        dp_shards: int = -1,
        cp_degree: int = 1,
        tp_degree: int = 1,
        backend: str = "nccl",
        timeout: int = 180,
        logging_dir: Optional[str] = None,
        output_dir: Optional[str] = None,
        gradient_accumulation_steps: Optional[int] = None,
    ) -> None:
        """Validate the parallel layout and initialize the process group.

        Args:
            world_size: Total number of processes. Must equal the product of all degrees.
            pp_degree: Pipeline parallel degree.
            dp_degree: Data parallel (replication) degree.
            dp_shards: Data parallel shard degree. NOTE: the default of `-1` does not pass
                the `>= 1` validation below, so callers have to pass an explicit value.
            cp_degree: Context parallel degree.
            tp_degree: Tensor parallel degree.
            backend: Backend name forwarded to `torch.distributed.init_process_group`.
            timeout: Process group timeout in seconds.
            logging_dir: Sub-directory of `output_dir` used for logs. Only resolved when
                `output_dir` is also provided.
            output_dir: Base output directory.
            gradient_accumulation_steps: Accepted for interface compatibility; not used by
                this constructor.

        Raises:
            ValueError: If any degree is smaller than 1, or if the product of all degrees
                differs from `world_size`.
        """
        super().__init__()

        self._world_size = world_size
        self._pp_degree = pp_degree
        self._dp_degree = dp_degree
        self._dp_shards = dp_shards
        self._cp_degree = cp_degree
        self._tp_degree = tp_degree
        self._output_dir = pathlib.Path(output_dir) if output_dir is not None else None
        self._logging_dir = (
            self._output_dir / logging_dir if output_dir is not None and logging_dir is not None else None
        )
        self._backend = backend
        self._timeout = timeout

        for degree in [pp_degree, dp_degree, dp_shards, cp_degree, tp_degree]:
            if degree < 1:
                raise ValueError(f"Parallel degree must be at least 1, got {degree}.")

        # The check is strict equality (not divisibility), so the message reports exactly that.
        expected_world_size = dp_shards * pp_degree * dp_degree * cp_degree * tp_degree
        if expected_world_size != world_size:
            raise ValueError(
                f"World size {world_size} must be equal to the product of all parallel degrees and data parallel "
                f"shards, but the product is {expected_world_size}."
            )

        torch.distributed.init_process_group(backend=self._backend, timeout=datetime.timedelta(seconds=self._timeout))
        _device_module.set_device(self.local_rank)

        logger.info(
            f"Initialized parallel state with:\n"
            f" - World size: {world_size}\n"
            f" - Pipeline parallel degree: {pp_degree}\n"
            f" - Data parallel degree: {dp_degree}\n"
            f" - Context parallel degree: {cp_degree}\n"
            f" - Tensor parallel degree: {tp_degree}\n"
            f" - Data parallel shards: {dp_shards}\n"
        )

        # Created lazily by `get_mesh`.
        self._mesh: Optional[torch.distributed.DeviceMesh] = None

    def enable_determinism(self, seed):
        """Call the module-level `enable_determinism` helper with `seed` and the full device mesh."""
        world_mesh = self.get_mesh()
        enable_determinism(seed, world_mesh)

    def apply_ddp(
        self, model: torch.nn.Module, device_mesh: Optional[torch.distributed.DeviceMesh] = None
    ) -> torch.nn.Module:
        """Apply the module-level `apply_ddp` to `model` and return the same model.

        When `device_mesh` is None the full mesh from `get_mesh()` is used.
        """
        if device_mesh is None:
            device_mesh = self.get_mesh()
        apply_ddp(model, device_mesh)
        logger.debug("Applied PytorchDTensorParallel::apply_ddp to model.")
        return model

    def apply_fsdp2(
        self,
        model: torch.nn.Module,
        param_dtype: torch.dtype,
        reduce_dtype: torch.dtype,
        output_dtype: torch.dtype,
        pp_enabled: bool = False,
        cpu_offload: bool = False,
        device_mesh: Optional[torch.distributed.DeviceMesh] = None,
    ) -> torch.nn.Module:
        """Apply the module-level `apply_fsdp2` to `model` and return the same model.

        When `device_mesh` is None the full mesh from `get_mesh()` is used. All other
        arguments are forwarded unchanged.
        """
        if device_mesh is None:
            device_mesh = self.get_mesh()
        apply_fsdp2(model, device_mesh, param_dtype, reduce_dtype, output_dtype, pp_enabled, cpu_offload)
        logger.debug("Applied PytorchDTensorParallel::apply_fsdp2 to model.")
        return model

    def apply_context_parallel(
        self, model: torch.nn.Module, device_mesh: Optional[torch.distributed.DeviceMesh] = None
    ) -> torch.nn.Module:
        """Apply the module-level `apply_context_parallel` to `model` and return the same model.

        When `device_mesh` is None the full mesh from `get_mesh()` is used.
        """
        if device_mesh is None:
            device_mesh = self.get_mesh()
        apply_context_parallel(model, device_mesh)
        logger.debug("Applied PytorchDTensorParallel::apply_context_parallel to model.")
        return model

    def prepare_model(self, model: torch.nn.Module) -> torch.nn.Module:
        """Return `model` unchanged."""
        return model

    def prepare_dataset(self, dataset: torch.utils.data.IterableDataset) -> torch.utils.data.IterableDataset:
        """Split `dataset._data` across the `dp_replicate` mesh dimension.

        The dataset is returned untouched when the data parallel degree is 1. Otherwise its
        `_data` attribute is replaced in place by the per-rank split.
        """
        if self._dp_degree == 1:
            return dataset
        dp_mesh = self.get_mesh()["dp_replicate"]
        dp_local_rank, dp_world_size = dp_mesh.get_local_rank(), dp_mesh.size()
        dataset._data = datasets.distributed.split_dataset_by_node(dataset._data, dp_local_rank, dp_world_size)
        logger.debug("PytorchDTensorParallelBackend::prepare_dataset completed!")
        return dataset

    def prepare_dataloader(
        self, dataset: torch.utils.data.IterableDataset, batch_size: int, num_workers: int, pin_memory: bool
    ) -> DPDataLoader:
        """Wrap `dataset` in a `DPDataLoader` keyed by this process' `dp_replicate` rank.

        NOTE: `pin_memory` is accepted but is not forwarded to `DPDataLoader`.
        """
        if self._dp_degree == 1:
            dp_local_rank = 0
        else:
            dp_mesh = self.get_mesh()["dp_replicate"]
            dp_local_rank = dp_mesh.get_local_rank()
        dataloader = DPDataLoader(dp_local_rank, dataset, batch_size=batch_size, num_workers=num_workers)
        logger.debug("PytorchDTensorParallelBackend::prepare_dataloader completed!")
        return dataloader

    def prepare_optimizer(self, optimizer, lr_scheduler):
        """Return `optimizer` and `lr_scheduler` unchanged."""
        logger.debug("PytorchDTensorParallelBackend::prepare_optimizer completed!")
        return optimizer, lr_scheduler

    def get_mesh(self, name: Optional[str] = None) -> Optional[torch.distributed.DeviceMesh]:
        """Return the device mesh, creating it on first use.

        Only dimensions whose degree is greater than 1 become mesh dimensions. After creation,
        the flattened views "dp", "dp_cp" and "dp_shard_cp" are registered for whichever of the
        `dp_replicate`, `dp_shard` and `cp` dimensions are enabled.

        Args:
            name: Optional mesh dimension name. When None the full mesh is returned.

        Returns:
            The full mesh when `name` is None, otherwise the sub-mesh `mesh[name]`. If the
            lookup raises `KeyError`/`RuntimeError`, the full mesh is returned instead, or
            None when the mesh has zero dimensions.
        """

        def _get_mesh():
            if name is None:
                return self._mesh
            try:
                return self._mesh[name]
            except (KeyError, RuntimeError):
                if self._mesh.ndim == 0:
                    return None
                return self._mesh

        if self._mesh is not None:
            return _get_mesh()

        mesh_list = [
            ("pp", self._pp_degree),
            ("dp_replicate", self._dp_degree),
            ("dp_shard", self._dp_shards),
            ("cp", self._cp_degree),
            ("tp", self._tp_degree),
        ]
        # Distinct loop names so the `name` argument captured by `_get_mesh` is clearly untouched.
        mesh_list = [(dim_name, degree) for dim_name, degree in mesh_list if degree > 1]
        names = [x[0] for x in mesh_list]
        degrees = [x[1] for x in mesh_list]
        mesh = torch.distributed.device_mesh.init_device_mesh(_device_type, mesh_shape=degrees, mesh_dim_names=names)

        dp_mesh_names, dp_cp_mesh_names, dp_shard_cp_mesh_names = [], [], []

        if self.data_replication_enabled:
            dp_mesh_names.append("dp_replicate")
            dp_cp_mesh_names.append("dp_replicate")
        if self.data_sharding_enabled:
            dp_mesh_names.append("dp_shard")
            dp_cp_mesh_names.append("dp_shard")
            dp_shard_cp_mesh_names.append("dp_shard")
        if self.context_parallel_enabled:
            dp_cp_mesh_names.append("cp")
            dp_shard_cp_mesh_names.append("cp")

        if len(dp_mesh_names) > 0:
            mesh[tuple(dp_mesh_names)]._flatten(mesh_dim_name="dp")
        if len(dp_cp_mesh_names) > 0:
            mesh[tuple(dp_cp_mesh_names)]._flatten(mesh_dim_name="dp_cp")
        if len(dp_shard_cp_mesh_names) > 0:
            mesh[tuple(dp_shard_cp_mesh_names)]._flatten(mesh_dim_name="dp_shard_cp")

        logger.debug(f"Device mesh: {mesh}")
        self._mesh = mesh
        return _get_mesh()

    def get_checkpointer(self, *args, **kwargs):
        """Construct a `PTDCheckpointer`, forwarding all arguments."""
        return PTDCheckpointer(*args, **kwargs)

    @property
    def world_size(self):
        """World size reported by `torch.distributed`."""
        return torch.distributed.get_world_size()

    @property
    def rank(self):
        """Global rank reported by `torch.distributed`."""
        return torch.distributed.get_rank()

    @property
    def local_rank(self):
        """Value of the `LOCAL_RANK` environment variable as an int (0 when unset)."""
        return int(os.environ.get("LOCAL_RANK", 0))

    @property
    def is_main_process(self):
        r"""Returns `True` if the current process is the main process on the master node."""
        return self.rank == 0

    @property
    def is_local_main_process(self):
        r"""Returns `True` if the current process is the main process on local node."""
        return self.local_rank == 0

    @property
    def device(self):
        """Torch device built from the detected device type and `local_rank`."""
        return torch.device(_device_type, self.local_rank)

    def wait_for_everyone(self):
        """Block on `torch.distributed.barrier()`."""
        return torch.distributed.barrier()

    # @contextmanager
    # def main_process_first(self):
    #     if self.is_main_process:
    #         yield
    #         self.wait_for_everyone()
    #     else:
    #         self.wait_for_everyone()
    #         yield

    def destroy(self):
        """Finish the tracker on the main process (if one is set) and destroy the process group."""
        # NOTE(review): `self.tracker` is not assigned anywhere in this class; presumably it is
        # provided by the base class or attached by the trainer - confirm.
        if self.is_main_process and self.tracker is not None:
            self.tracker.finish()
        return torch.distributed.destroy_process_group()

    @property
    def pipeline_parallel_enabled(self):
        """True when the pipeline parallel degree is greater than 1."""
        return self._pp_degree > 1

    @property
    def data_parallel_enabled(self):
        """True when either data replication or data sharding is enabled."""
        return self._dp_degree > 1 or self._dp_shards > 1

    @property
    def data_replication_enabled(self):
        """True when the data parallel (replication) degree is greater than 1."""
        return self._dp_degree > 1

    @property
    def data_sharding_enabled(self):
        """True when the data parallel shard degree is greater than 1."""
        return self._dp_shards > 1

    @property
    def context_parallel_enabled(self):
        """True when the context parallel degree is greater than 1."""
        return self._cp_degree > 1

    @property
    def tensor_parallel_enabled(self):
        """True when the tensor parallel degree is greater than 1."""
        return self._tp_degree > 1


class ModelWrapper(torch.distributed.checkpoint.stateful.Stateful):
    """`Stateful` adapter exposing one or several model parts as a single state dict."""

    def __init__(self, model: Union[torch.nn.Module, List[torch.nn.Module]]) -> None:
        # Always store a list so a single module and multiple model parts share one code path.
        self.model = [model] if isinstance(model, torch.nn.Module) else model

    def state_dict(self) -> Dict[str, Any]:
        """Merge `get_model_state_dict` of every model part into one dict (later parts win on key clashes)."""
        return {k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Load `state_dict` into every model part with `strict=False`."""
        func = functools.partial(
            set_model_state_dict,
            model_state_dict=state_dict,
            options=StateDictOptions(strict=False),
        )
        for model_part in self.model:
            func(model_part)


class PTDCheckpointer(BaseCheckpointer):
    """Checkpointer backed by `torch.distributed.checkpoint` (DCP).

    Checkpoints are written to `<output_dir>/<_prefix>_<step>`. Entries matching
    `<_prefix>_*` whose suffix is not an integer are ignored when looking for the latest
    checkpoint or purging stale ones.
    """

    def __init__(
        self,
        dataloader: torch.utils.data.DataLoader,
        model_parts: List[torch.nn.Module],
        optimizers: "optimizer.OptimizerWrapper",
        schedulers: "optimizer.SchedulerWrapper",
        states: Dict[str, Any],
        checkpointing_steps: int,
        checkpointing_limit: int,
        output_dir: str,
        enable: bool = True,
        _callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None,
        _prefix: str = "finetrainers_step",
    ) -> None:
        """Register everything that should be checkpointed.

        Args:
            dataloader: Stored under the "dataloader" key of `states`.
            model_parts: Wrapped in a `ModelWrapper` and stored under the "model" key.
            optimizers: Stored under the "optimizer" key.
            schedulers: Its `get_lr_scheduler_state()` mapping is merged into `states`.
            states: Dict of additional state. It is updated in place.
            checkpointing_steps: Save whenever `step % checkpointing_steps == 0`.
            checkpointing_limit: Number of most recent checkpoints kept. None or a value
                `<= 0` disables purging.
            output_dir: Directory that holds the checkpoint directories.
            enable: When False, `save` and `load` do nothing.
            _callback_fn: Optional callable invoked with each gathered model state dict
                after a checkpoint has been saved.
            _prefix: Prefix of the checkpoint directory names.
        """
        self.states = states
        self.states.update(
            {
                "model": ModelWrapper(model_parts),
                "optimizer": optimizers,
                "dataloader": dataloader,
            }
        )
        self.states.update(schedulers.get_lr_scheduler_state())

        self.checkpointing_steps = checkpointing_steps
        self.checkpointing_limit = checkpointing_limit
        self.output_dir = pathlib.Path(output_dir)
        self.enable = enable
        self._callback_fn = _callback_fn
        self._prefix = _prefix

        logger.info(f"Checkpointing enabled. Checkpoints will be stored in '{self.output_dir}'")

    def save(
        self, step: int = -1, force: bool = False, *, _device: torch.device, _is_main_process: bool
    ) -> Optional[str]:
        """Save a DCP checkpoint for `step`, purge stale checkpoints and run the callback.

        Args:
            step: Training step used for the checkpoint directory name.
            force: Save even when `step` is not a multiple of `checkpointing_steps`.
            _device: Device that CPU-resident parameters are moved to before gathering.
            _is_main_process: Whether this process should keep the gathered state dicts.

        Returns:
            The checkpoint directory as a POSIX path string, or None when nothing was saved.
        """
        if not self._should_checkpoint(step, force):
            return None

        checkpoint_dir = self._get_checkpoint_dir(step)
        begin_time = time.monotonic()
        torch.distributed.checkpoint.save(self.states, checkpoint_id=checkpoint_dir.as_posix())
        end_time = time.monotonic()
        logger.info(
            f"Saved checkpoint in {end_time - begin_time:.2f} seconds at step {step}. Directory: {checkpoint_dir}"
        )
        self._purge_stale_checkpoints()

        # The gather runs whether or not a callback is set: it issues collectives
        # (`full_tensor`, `barrier`), so every caller of `save` has to go through it.
        state_dicts = [
            gather_state_dict_on_cpu_rank0(model, _device, is_main_process=_is_main_process)
            for model in self.states["model"].model
        ]
        if self._callback_fn is not None:
            for state_dict in state_dicts:
                self._callback_fn(state_dict)

        return checkpoint_dir.as_posix()

    def load(self, step: int = -1) -> bool:
        """Load the checkpoint for `step`, or the latest one when `step == -1`.

        Returns:
            True when a checkpoint was loaded. False when checkpointing is disabled, the
            output directory does not exist, or no matching checkpoint is found.
        """
        if not self.enable:
            return False
        if not self.output_dir.exists():
            return False
        if step != -1 and not self._get_checkpoint_dir(step).exists():
            return False

        if step == -1:
            latest_checkpoint_dir = self._find_latest_checkpoint_dir()
            if latest_checkpoint_dir is None:
                return False
            step = self._checkpoint_step(latest_checkpoint_dir)

        checkpoint_dir = self._get_checkpoint_dir(step)
        logger.info(f"Loading checkpoint from '{checkpoint_dir}' at step {step}")

        # For step 0, optimizers/schedulers are not available as they are created during training after first step
        states = {"model": self.states["model"]} if step == 0 else self.states

        # See bug: https://github.com/pytorch/pytorch/pull/138575
        original_stateful_states = {
            k: v for k, v in states.items() if isinstance(v, torch.distributed.checkpoint.stateful.Stateful)
        }
        begin_time = time.monotonic()
        torch.distributed.checkpoint.load(states, checkpoint_id=checkpoint_dir.as_posix())
        end_time = time.monotonic()
        logger.info(f"Loaded checkpoint in {end_time - begin_time:.2f} seconds.")

        # bugfix from above: restore the original stateful objects, whose states were already updated in-place by dcp.load()
        states.update(original_stateful_states)

        return True

    def _should_checkpoint(self, step: int, force: bool) -> bool:
        """Return True when enabled and either `force` is set or `step` is a multiple of `checkpointing_steps`."""
        if not self.enable:
            return False
        if not force:
            if step % self.checkpointing_steps != 0:
                return False
        return True

    def _get_checkpoint_dir(self, step: int) -> pathlib.Path:
        """Return `<output_dir>/<_prefix>_<step>`."""
        return self.output_dir / f"{self._prefix}_{step}"

    @staticmethod
    def _checkpoint_step(checkpoint: pathlib.Path) -> Optional[int]:
        """Parse the step from the text after the last underscore of the name; None if it is not an integer."""
        try:
            return int(checkpoint.name.split("_")[-1])
        except ValueError:
            return None

    def _sorted_checkpoint_dirs(self, reverse: bool = False) -> List[pathlib.Path]:
        """Return paths matching `<_prefix>_*` with an integer step suffix, sorted by step."""
        candidates = [
            path for path in self.output_dir.glob(f"{self._prefix}_*") if self._checkpoint_step(path) is not None
        ]
        return sorted(candidates, key=self._checkpoint_step, reverse=reverse)

    def _find_latest_checkpoint_dir(self) -> Optional[pathlib.Path]:
        """Return the checkpoint path with the highest step, or None when there is none."""
        checkpoints = self._sorted_checkpoint_dirs()
        return checkpoints[-1] if len(checkpoints) > 0 else None

    def _purge_stale_checkpoints(self) -> None:
        """Delete all but the `checkpointing_limit` checkpoints with the highest steps."""
        if self.checkpointing_limit is None or self.checkpointing_limit <= 0:
            return
        checkpoints = self._sorted_checkpoint_dirs(reverse=True)
        for checkpoint in checkpoints[self.checkpointing_limit :]:
            logger.info(f"Deleting stale checkpoint: {checkpoint}")
            shutil.rmtree(checkpoint, ignore_errors=True)


def gather_state_dict_on_cpu_rank0(
    model, device: Optional[torch.device] = None, *, is_main_process: bool
) -> Dict[str, Any]:
    """Gather the full state dict of `model`, keeping a CPU copy only on the main process.

    Args:
        model: Object exposing `state_dict()`.
        device: Device that CPU-resident entries are moved to before gathering.
        is_main_process: Whether this process stores the gathered tensors.

    Returns:
        A dict of CPU tensors on the main process and an empty dict elsewhere.
    """
    cpu_state_dict = {}
    sharded_sd = model.state_dict()
    for param_name, param in sharded_sd.items():
        if param.is_cpu:
            # Move back to device if offloaded to CPU
            param = param.to(device)
        if hasattr(param, "_local_tensor"):
            # Gather DTensor
            param = param.full_tensor()
        if is_main_process:
            cpu_state_dict[param_name] = param.cpu()
        # One barrier per entry so all ranks step through the state dict together.
        torch.distributed.barrier()
    return cpu_state_dict


# # Copied from pytorch (torch/distributed/checkpoint/format_utils.py) to support callbacks to modify state_dict
# def dcp_to_torch_save(
#     dcp_checkpoint_dir: Union[str, os.PathLike],
#     torch_save_path: Union[str, os.PathLike],
#     callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None,
# ):
#     """
#     Given a directory containing a DCP checkpoint, this function will convert it into a
#     Torch save file.
#
#     Args:
#         dcp_checkpoint_dir: Directory containing the DCP checkpoint.
#         torch_save_path: Filename to store the converted Torch save file.
#         callback_fn: Optional callback function that takes the state_dict as input and returns a modified state_dict.
#
#     .. warning::
#         To avoid OOM, it's recommended to only run this function on a single rank.
#     """
#     state_dict = {}
#     _load_state_dict(
#         state_dict,
#         storage_reader=FileSystemReader(dcp_checkpoint_dir),
#         planner=_EmptyStateDictLoadPlanner(),
#         no_dist=True,
#     )
#     if callback_fn is not None:
#         state_dict = callback_fn(state_dict)
#     torch.save(state_dict, torch_save_path)


def apply_ddp(model: torch.nn.Module, dp_mesh: torch.distributed.device_mesh.DeviceMesh) -> None:
    """Apply `replicate` to `model` on `dp_mesh` with a 100 MB bucket cap."""
    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)


def apply_fsdp2(
    model: torch.nn.Module,
    dp_mesh: torch.distributed.device_mesh.DeviceMesh,
    param_dtype: torch.dtype,
    reduce_dtype: torch.dtype,
    output_dtype: torch.dtype,
    pp_enabled: bool = False,
    cpu_offload: bool = False,
) -> None:
    """Apply FSDP2 on a model.

    Every block list found on `model` under one of `DIFFUSERS_TRANSFORMER_BLOCK_NAMES` is
    sharded block by block; the model itself is sharded last.

    Args:
        model: Model to shard in place.
        dp_mesh: Device mesh passed to `fully_shard`.
        param_dtype: Parameter dtype of the mixed precision policy.
        reduce_dtype: Reduction dtype of the mixed precision policy.
        output_dtype: Output dtype of the mixed precision policy.
        pp_enabled: Whether pipeline parallelism is enabled; disables resharding after forward.
        cpu_offload: Whether to add a `CPUOffloadPolicy` with pinned memory.
    """
    mp_policy = MixedPrecisionPolicy(param_dtype, reduce_dtype, output_dtype, cast_forward_inputs=True)
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}

    if cpu_offload:
        fsdp_config["offload_policy"] = CPUOffloadPolicy(pin_memory=True)

    def apply_fully_shard(blocks):
        for layer_index, block in enumerate(blocks):
            if pp_enabled:
                # For PP, do not reshard after forward to avoid per-microbatch
                # all-gathers, which can be expensive and non-overlapped
                reshard_after_forward = False
            else:
                # As an optimization, do not reshard after forward for the last
                # transformer block since FSDP would prefetch it immediately
                reshard_after_forward = layer_index < len(blocks) - 1
            fully_shard(block, **fsdp_config, reshard_after_forward=reshard_after_forward)

    for transformer_block_name in DIFFUSERS_TRANSFORMER_BLOCK_NAMES:
        blocks = getattr(model, transformer_block_name, None)
        if blocks is not None:
            apply_fully_shard(blocks)

    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)


def apply_context_parallel(
    model: torch.nn.Module,
    mesh: torch.distributed.device_mesh.DeviceMesh,
    plan: Optional[Dict[str, ContextParallelModelPlan]] = None,
) -> None:
    """Apply context parallel on a model.

    For every entry of the plan, a hook is registered on the matching submodule(s): a
    `ContextParallelGatherHook` when the entry is a list (of `CPOutput`), otherwise a
    `ContextParallelSplitHook`.

    Args:
        model: Model whose submodules receive the hooks.
        mesh: Device mesh handed to the hooks.
        plan: Mapping from submodule identifier to its plan. When None, the `cp_plan`
            registered in `TransformerRegistry` for the unwrapped model class is used.
    """
    logger.debug(f"Applying context parallel with CP mesh: {mesh}")
    model_cls = unwrap_module(model).__class__

    if plan is None:
        plan = TransformerRegistry.get(model_cls).cp_plan

    for module_id, cp_model_plan in plan.items():
        module = get_submodule_by_name(model, module_id)
        if not isinstance(module, list):
            module = [module]
        logger.debug(f"Applying ContextParallelHook to {module_id=} identifying a total of {len(module)} modules")
        for m in module:
            registry = HookRegistry.check_if_exists_or_initialize(m)
            if isinstance(cp_model_plan, list):
                # Metadata can only be a list when it is a list of CPOutput
                assert all(isinstance(x, CPOutput) for x in cp_model_plan)
                hook = ContextParallelGatherHook(cp_model_plan, mesh)
                hook_name = f"cp_output---{module_id}"
            else:
                hook = ContextParallelSplitHook(cp_model_plan, mesh)
                hook_name = f"cp_input---{module_id}"
            registry.register_hook(hook, hook_name)


class ContextParallelSplitHook(ModelHook):
    """Hook that shards selected module inputs (or outputs) across the context parallel mesh."""

    def __init__(self, metadata: ContextParallelModelPlan, mesh: torch.distributed.device_mesh.DeviceMesh) -> None:
        super().__init__()
        self.metadata = metadata
        self.mesh = mesh

    def pre_forward(self, module, *args, **kwargs):
        """Shard the inputs named in the plan and return the updated `(args, kwargs)`.

        Entries flagged with `split_output` are skipped here; `post_forward` handles them.

        Raises:
            ValueError: If a matched input is neither a tensor nor a list/tuple, or if a
                list/tuple input and its plan differ in length.
        """
        args_list = list(args)

        for param_identifier, cpm in self.metadata.items():
            name = param_identifier.name
            index = param_identifier.index

            if isinstance(cpm, CPInput) and cpm.split_output:
                continue

            # Maybe the parameter was passed as a keyword argument
            is_kwarg = True
            input_val = kwargs.get(name, None)

            # If not, maybe it was passed as a positional argument
            if input_val is None and index is not None:
                if index < len(args_list):  # Ensure index is within bounds
                    input_val = args_list[index]
                    is_kwarg = False
                else:
                    logger.warning(f"Index {index} out of bounds for args of length {len(args_list)}.")
                    continue  # Skip if index is invalid

            # Either the input_val is truly None, or argument is passed as normal argument
            # but user forgot to specify the index when registering metadata
            if input_val is None:
                continue

            # The input_val may be a tensor or list/tuple of tensors. In certain cases, user may specify to shard
            # the output instead of input for a particular layer by setting split_output=True
            if torch.is_tensor(input_val):
                input_val = self._prepare_cp_input(input_val, cpm)
            elif isinstance(input_val, (list, tuple)):
                if len(input_val) != len(cpm):
                    raise ValueError(
                        f"Expected input model plan to have {len(input_val)} elements, but got {len(cpm)}."
                    )
                sharded_input_val = []
                for i, x in enumerate(input_val):
                    if torch.is_tensor(x) and not cpm[i].split_output:
                        x = self._prepare_cp_input(x, cpm[i])
                    sharded_input_val.append(x)
                input_val = sharded_input_val
            else:
                raise ValueError(f"Unsupported input type: {type(input_val)}")

            if is_kwarg:
                kwargs[name] = input_val
            elif index is not None and index < len(args_list):
                args_list[index] = input_val

        return tuple(args_list), kwargs

    def post_forward(self, module, output):
        """Shard the outputs whose plan entry is a `CPInput` with `split_output` set.

        Returns:
            A tensor when `output` was a tensor, otherwise a tuple.

        Raises:
            ValueError: If `output` is not a tensor or a list/tuple of tensors, or if a plan
                index lies outside the output.
        """
        is_tensor = torch.is_tensor(output)
        is_tensor_list = isinstance(output, (list, tuple)) and all(torch.is_tensor(x) for x in output)
        if not is_tensor and not is_tensor_list:
            raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")
        output = [output] if is_tensor else list(output)
        for param_identifier, cpm in self.metadata.items():
            if not isinstance(cpm, CPInput) or not cpm.split_output:
                continue
            index = param_identifier.index
            if index >= len(output):
                raise ValueError(f"Index {index} out of bounds for output of length {len(output)}.")
            current_output = output[index]
            current_output = self._prepare_cp_input(current_output, cpm)
            output[index] = current_output
        return output[0] if is_tensor else tuple(output)

    def _prepare_cp_input(self, x: torch.Tensor, cp_input: CPInput) -> torch.Tensor:
        """Check the expected rank (when specified) and return this rank's equipartition shard of `x`."""
        if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
            raise ValueError(
                f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
            )
        return _EquipartitionSharder.shard(x, cp_input.split_dim, self.mesh)


class ContextParallelGatherHook(ModelHook):
    """Hook that gathers sharded module outputs back into full tensors."""

    def __init__(self, metadata: ContextParallelModelPlan, mesh: torch.distributed.device_mesh.DeviceMesh) -> None:
        super().__init__()
        self.metadata = metadata
        self.mesh = mesh

    def post_forward(self, module, output):
        """Unshard every output whose plan entry is not None.

        Returns:
            A tensor when `output` was a tensor, otherwise a tuple.
        """
        is_tensor = torch.is_tensor(output)
        if is_tensor:
            output = [output]
        output = list(output)
        assert len(output) == len(self.metadata), f"Expected {len(self.metadata)} outputs, but got {len(output)}."
        for i, cpm in enumerate(self.metadata):
            if cpm is None:
                continue
            output[i] = _EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.mesh)
        return output[0] if is_tensor else tuple(output)


class _ContextParallelSharder:
    """Interface for context parallel sharding strategies."""

    @classmethod
    def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        raise NotImplementedError("_ContextParallelSharder::shard should be implemented in subclasses")

    @classmethod
    def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        raise NotImplementedError("_ContextParallelSharder::unshard should be implemented in subclasses")


class _EquipartitionSharder(_ContextParallelSharder):
    """
    Shards the input tensor along the specified dimension into cp_mesh's world size chunks.
    Essentially, rank_i gets the i-th chunk.

    This sharding strategy should only be used when performing full attention. Otherwise, it will
    have performance penalty. If using causal attention, please use _CausalSharder instead.
    """

    @classmethod
    def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        """Return the chunk of `tensor` along `dim` that belongs to this rank of `mesh`."""
        assert tensor.size()[dim] % mesh.size() == 0
        return tensor.chunk(mesh.size(), dim=dim)[mesh.get_local_rank()]

    @classmethod
    def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        """Gather the per-rank shards along `dim` into the full tensor."""
        tensor = tensor.contiguous()
        # TODO(aryan): pass a shape here so that we can allow uneven sharding across seq dim
        result = DTensor.from_local(tensor, mesh, placements=[Shard(dim)]).full_tensor()
        return result


# TODO(aryan): this class is untested
class _CausalSharder(_ContextParallelSharder):
    """
    Shards the input tensor along the specified dimension into 2x cp_mesh's world size chunks.
    Essentially, rank_i gets the i-th chunk and (2 * cp_world_size - 1 - i)-th chunk.

    This sharding strategy improves the performance for causal attention, as it allows
    equal distribution of computation across all ranks.

    Causal attention mask:
    ```
    1 0 0 0    <--- Group 0
    1 1 0 0    <--- Group 1
    1 1 1 0    <--- Group 1
    1 1 1 1    <--- Group 0
    ```
    """

    @classmethod
    def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        """Return chunks `rank` and `2 * world_size - 1 - rank` of `tensor`, concatenated along `dim`."""
        world_size = mesh.size()
        rank = mesh.get_local_rank()
        assert tensor.size()[dim] % (2 * world_size) == 0
        chunks = tensor.chunk(2 * world_size, dim=dim)
        i, j = rank, 2 * world_size - 1 - rank
        return torch.cat((chunks[i], chunks[j]), dim=dim)

    @classmethod
    def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        """Gather the per-rank shards and restore the chunk order that `shard` started from.

        The gathered tensor is split into `2 * world_size` chunks along `dim` (the layout
        that `shard` produced) rather than being iterated over its first dimension.
        """
        tensor = tensor.contiguous()
        world_size = mesh.size()
        num_chunks = 2 * world_size
        # TODO(aryan): pass a shape here so that we can allow uneven sharding across seq dim
        gathered = DTensor.from_local(tensor, mesh, placements=[Shard(dim)]).full_tensor()
        assert gathered.size()[dim] % num_chunks == 0
        # Assumes the gather concatenates shards in rank order along `dim`, so rank r contributes
        # gathered chunks 2r and 2r + 1, i.e. original chunks r and num_chunks - 1 - r.
        gathered_chunks = gathered.chunk(num_chunks, dim=dim)
        ordered_chunks = [None] * num_chunks
        for rank in range(world_size):
            ordered_chunks[rank] = gathered_chunks[2 * rank]
            ordered_chunks[num_chunks - 1 - rank] = gathered_chunks[2 * rank + 1]
        return torch.cat(ordered_chunks, dim=dim)