jbilcke-hf's picture
jbilcke-hf HF Staff
we are going to hack into finetrainers
9fd1204
import datetime
import functools
import os
import pathlib
import shutil
import time
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
import datasets.distributed
import torch
import torch.distributed._functional_collectives
import torch.distributed.checkpoint
import torch.distributed.checkpoint.stateful
from diffusers.hooks import HookRegistry, ModelHook
from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard
from torch.distributed._composable.replicate import replicate
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
get_model_state_dict,
set_model_state_dict,
)
from torch.distributed.tensor import DTensor, Shard
from finetrainers._metadata import ContextParallelModelPlan, CPInput, CPOutput, TransformerRegistry
from finetrainers.data import DPDataLoader
from finetrainers.logging import get_logger
from finetrainers.utils import enable_determinism, get_device_info, get_submodule_by_name, unwrap_module
from finetrainers.utils._common import DIFFUSERS_TRANSFORMER_BLOCK_NAMES
from .base import BaseCheckpointer, BaseParallelBackend
if TYPE_CHECKING:
from finetrainers import optimizer
_device_type, _device_module = get_device_info()
logger = get_logger()
class PytorchDTensorParallelBackend(BaseParallelBackend):
def __init__(
self,
world_size: int,
pp_degree: int = 1,
dp_degree: int = 1,
dp_shards: int = -1,
cp_degree: int = 1,
tp_degree: int = 1,
backend: str = "nccl",
timeout: int = 180,
logging_dir: Optional[str] = None,
output_dir: Optional[str] = None,
gradient_accumulation_steps: Optional[int] = None,
) -> None:
super().__init__()
self._world_size = world_size
self._pp_degree = pp_degree
self._dp_degree = dp_degree
self._dp_shards = dp_shards
self._cp_degree = cp_degree
self._tp_degree = tp_degree
self._output_dir = pathlib.Path(output_dir) if output_dir is not None else None
self._logging_dir = (
self._output_dir / logging_dir if output_dir is not None and logging_dir is not None else None
)
self._backend = backend
self._timeout = timeout
for degree in [pp_degree, dp_degree, dp_shards, cp_degree, tp_degree]:
if degree < 1:
raise ValueError(f"Parallel degree must be at least 1, got {degree}.")
if dp_shards * pp_degree * dp_degree * cp_degree * tp_degree != world_size:
raise ValueError(
f"World size {world_size} must be divisible by the product of all parallel degrees and data parallel shards."
)
torch.distributed.init_process_group(backend=self._backend, timeout=datetime.timedelta(seconds=self._timeout))
_device_module.set_device(self.local_rank)
logger.info(
f"Initialized parallel state with:\n"
f" - World size: {world_size}\n"
f" - Pipeline parallel degree: {pp_degree}\n"
f" - Data parallel degree: {dp_degree}\n"
f" - Context parallel degree: {cp_degree}\n"
f" - Tensor parallel degree: {tp_degree}\n"
f" - Data parallel shards: {dp_shards}\n"
)
self._mesh: torch.distributed.DeviceMesh = None
def enable_determinism(self, seed):
world_mesh = self.get_mesh()
enable_determinism(seed, world_mesh)
def apply_ddp(
self, model: torch.nn.Module, device_mesh: Optional[torch.distributed.DeviceMesh] = None
) -> torch.nn.Module:
if device_mesh is None:
device_mesh = self.get_mesh()
apply_ddp(model, device_mesh)
logger.debug("Applied PytorchDTensorParallel::apply_ddp to model.")
return model
def apply_fsdp2(
self,
model: torch.nn.Module,
param_dtype: torch.dtype,
reduce_dtype: torch.dtype,
output_dtype: torch.dtype,
pp_enabled: bool = False,
cpu_offload: bool = False,
device_mesh: Optional[torch.distributed.DeviceMesh] = None,
) -> torch.nn.Module:
if device_mesh is None:
device_mesh = self.get_mesh()
apply_fsdp2(model, device_mesh, param_dtype, reduce_dtype, output_dtype, pp_enabled, cpu_offload)
logger.debug("Applied PytorchDTensorParallel::apply_fsdp2 to model.")
return model
def apply_context_parallel(
self, model: torch.nn.Module, device_mesh: Optional[torch.distributed.DeviceMesh] = None
) -> torch.nn.Module:
if device_mesh is None:
device_mesh = self.get_mesh()
apply_context_parallel(model, device_mesh)
logger.debug("Applied PytorchDTensorParallel::apply_context_parallel to model.")
return model
def prepare_model(self, model: torch.nn.Module) -> torch.nn.Module:
return model
def prepare_dataset(self, dataset: torch.utils.data.IterableDataset) -> torch.utils.data.IterableDataset:
if self._dp_degree == 1:
return dataset
dp_mesh = self.get_mesh()["dp_replicate"]
dp_local_rank, dp_world_size = dp_mesh.get_local_rank(), dp_mesh.size()
dataset._data = datasets.distributed.split_dataset_by_node(dataset._data, dp_local_rank, dp_world_size)
logger.debug("PytorchDTensorParallelBackend::prepare_dataset completed!")
return dataset
def prepare_dataloader(
self, dataset: torch.utils.data.IterableDataset, batch_size: int, num_workers: int, pin_memory: bool
) -> DPDataLoader:
if self._dp_degree == 1:
dp_local_rank = 0
else:
dp_mesh = self.get_mesh()["dp_replicate"]
dp_local_rank = dp_mesh.get_local_rank()
dataloader = DPDataLoader(dp_local_rank, dataset, batch_size=batch_size, num_workers=num_workers)
logger.debug("PytorchDTensorParallelBackend::prepare_dataloader completed!")
return dataloader
def prepare_optimizer(self, optimizer, lr_scheduler):
logger.debug("PytorchDTensorParallelBackend::prepare_optimizer completed!")
return optimizer, lr_scheduler
def get_mesh(self, name: Optional[str] = None) -> torch.distributed.DeviceMesh:
def _get_mesh():
if name is None:
return self._mesh
try:
return self._mesh[name]
except (KeyError, RuntimeError):
if self._mesh.ndim == 0:
return None
return self._mesh
if self._mesh is not None:
return _get_mesh()
mesh_list = [
("pp", self._pp_degree),
("dp_replicate", self._dp_degree),
("dp_shard", self._dp_shards),
("cp", self._cp_degree),
("tp", self._tp_degree),
]
mesh_list = [(name, degree) for name, degree in mesh_list if degree > 1]
names = [x[0] for x in mesh_list]
degrees = [x[1] for x in mesh_list]
mesh = torch.distributed.device_mesh.init_device_mesh(_device_type, mesh_shape=degrees, mesh_dim_names=names)
dp_mesh_names, dp_cp_mesh_names, dp_shard_cp_mesh_names = [], [], []
if self.data_replication_enabled:
dp_mesh_names.append("dp_replicate")
dp_cp_mesh_names.append("dp_replicate")
if self.data_sharding_enabled:
dp_mesh_names.append("dp_shard")
dp_cp_mesh_names.append("dp_shard")
dp_shard_cp_mesh_names.append("dp_shard")
if self.context_parallel_enabled:
dp_cp_mesh_names.append("cp")
dp_shard_cp_mesh_names.append("cp")
if len(dp_mesh_names) > 0:
mesh[tuple(dp_mesh_names)]._flatten(mesh_dim_name="dp")
if len(dp_cp_mesh_names) > 0:
mesh[tuple(dp_cp_mesh_names)]._flatten(mesh_dim_name="dp_cp")
if len(dp_shard_cp_mesh_names) > 0:
mesh[tuple(dp_shard_cp_mesh_names)]._flatten(mesh_dim_name="dp_shard_cp")
logger.debug(f"Device mesh: {mesh}")
self._mesh = mesh
return _get_mesh()
def get_checkpointer(self, *args, **kwargs):
return PTDCheckpointer(*args, **kwargs)
@property
def world_size(self):
return torch.distributed.get_world_size()
@property
def rank(self):
return torch.distributed.get_rank()
@property
def local_rank(self):
return int(os.environ.get("LOCAL_RANK", 0))
@property
def is_main_process(self):
r"""Returns `True` if the current process is the main process on the master node."""
return self.rank == 0
@property
def is_local_main_process(self):
r"""Returns `True` if the current process is the main process on local node."""
return self.local_rank == 0
@property
def device(self):
return torch.device(_device_type, self.local_rank)
def wait_for_everyone(self):
return torch.distributed.barrier()
# @contextmanager
# def main_process_first(self):
# if self.is_main_process:
# yield
# self.wait_for_everyone()
# else:
# self.wait_for_everyone()
# yield
def destroy(self):
if self.is_main_process and self.tracker is not None:
self.tracker.finish()
return torch.distributed.destroy_process_group()
@property
def pipeline_parallel_enabled(self):
return self._pp_degree > 1
@property
def data_parallel_enabled(self):
return self._dp_degree > 1 or self._dp_shards > 1
@property
def data_replication_enabled(self):
return self._dp_degree > 1
@property
def data_sharding_enabled(self):
return self._dp_shards > 1
@property
def context_parallel_enabled(self):
return self._cp_degree > 1
@property
def tensor_parallel_enabled(self):
return self._tp_degree > 1
class ModelWrapper(torch.distributed.checkpoint.stateful.Stateful):
def __init__(self, model: Union[torch.nn.Module, List[torch.nn.Module]]) -> None:
self.model = [model] if isinstance(model, torch.nn.Module) else model
def state_dict(self) -> Dict[str, Any]:
return {k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()}
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
func = functools.partial(
set_model_state_dict,
model_state_dict=state_dict,
options=StateDictOptions(strict=False),
)
list(map(func, self.model))
class PTDCheckpointer(BaseCheckpointer):
def __init__(
self,
dataloader: torch.utils.data.DataLoader,
model_parts: List[torch.nn.Module],
optimizers: "optimizer.OptimizerWrapper",
schedulers: "optimizer.SchedulerWrapper",
states: Dict[str, Any],
checkpointing_steps: int,
checkpointing_limit: int,
output_dir: str,
enable: bool = True,
_callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None,
_prefix: str = "finetrainers_step",
) -> None:
self.states = states
self.states.update(
{
"model": ModelWrapper(model_parts),
"optimizer": optimizers,
"dataloader": dataloader,
}
)
self.states.update(schedulers.get_lr_scheduler_state())
self.checkpointing_steps = checkpointing_steps
self.checkpointing_limit = checkpointing_limit
self.output_dir = pathlib.Path(output_dir)
self.enable = enable
self._callback_fn = _callback_fn
self._prefix = _prefix
logger.info(f"Checkpointing enabled. Checkpoints will be stored in '{self.output_dir}'")
def save(self, step: int = -1, force: bool = False, *, _device: torch.device, _is_main_process: bool) -> str:
if not self._should_checkpoint(step, force):
return None
checkpoint_dir = self._get_checkpoint_dir(step)
begin_time = time.monotonic()
torch.distributed.checkpoint.save(self.states, checkpoint_id=checkpoint_dir.as_posix())
end_time = time.monotonic()
logger.info(
f"Saved checkpoint in {end_time - begin_time:.2f} seconds at step {step}. Directory: {checkpoint_dir}"
)
self._purge_stale_checkpoints()
state_dicts = [
gather_state_dict_on_cpu_rank0(model, _device, is_main_process=_is_main_process)
for model in self.states["model"].model
]
if self._callback_fn is not None:
list(map(self._callback_fn, state_dicts))
return checkpoint_dir.as_posix()
def load(self, step: int = -1) -> bool:
if not self.enable:
return False
if not self.output_dir.exists():
return False
if step != -1 and not self._get_checkpoint_dir(step).exists():
return False
if step == -1:
latest_checkpoint_dir = self._find_latest_checkpoint_dir()
if latest_checkpoint_dir is None:
return False
step = int(latest_checkpoint_dir.name.split("_")[-1])
checkpoint_dir = self._get_checkpoint_dir(step)
logger.info(f"Loading checkpoint from '{checkpoint_dir}' at step {step}")
# For step 0, optimizers/schedulers are not available as they are created during training after first step
states = {"model": self.states["model"]} if step == 0 else self.states
# See bug: https://github.com/pytorch/pytorch/pull/138575
original_stateful_states = {
k: v for k, v in states.items() if isinstance(v, torch.distributed.checkpoint.stateful.Stateful)
}
begin_time = time.monotonic()
torch.distributed.checkpoint.load(states, checkpoint_id=checkpoint_dir.as_posix())
end_time = time.monotonic()
logger.info(f"Loaded checkpoint in {end_time - begin_time:.2f} seconds.")
# bugfix from above: restore the original stateful objects, whose states were already updated in-place by dcp.load()
states.update(original_stateful_states)
return True
def _should_checkpoint(self, step: int, force: bool) -> bool:
if not self.enable:
return False
if not force:
if step % self.checkpointing_steps != 0:
return False
return True
def _get_checkpoint_dir(self, step: int) -> pathlib.Path:
return self.output_dir / f"{self._prefix}_{step}"
def _find_latest_checkpoint_dir(self) -> Optional[pathlib.Path]:
checkpoints = sorted(self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]))
return checkpoints[-1] if len(checkpoints) > 0 else None
def _purge_stale_checkpoints(self) -> None:
if self.checkpointing_limit is None or self.checkpointing_limit <= 0:
return
checkpoints = sorted(
self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]), reverse=True
)
for checkpoint in checkpoints[self.checkpointing_limit :]:
logger.info(f"Deleting stale checkpoint: {checkpoint}")
shutil.rmtree(checkpoint, ignore_errors=True)
def gather_state_dict_on_cpu_rank0(
model, device: Optional[torch.device] = None, *, is_main_process: bool
) -> Dict[str, Any]:
cpu_state_dict = {}
sharded_sd = model.state_dict()
for param_name, param in sharded_sd.items():
if param.is_cpu:
# Move back to device if offloaded to CPU
param = param.to(device)
if hasattr(param, "_local_tensor"):
# Gather DTensor
param = param.full_tensor()
if is_main_process:
cpu_state_dict[param_name] = param.cpu()
torch.distributed.barrier()
return cpu_state_dict
# # Copied from pytorch (torch/distributed/checkpoint/format_utils.py) to support callbacks to modify state_dict
# def dcp_to_torch_save(
# dcp_checkpoint_dir: Union[str, os.PathLike],
# torch_save_path: Union[str, os.PathLike],
# callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None,
# ):
# """
# Given a directory containing a DCP checkpoint, this function will convert it into a
# Torch save file.
# Args:
# dcp_checkpoint_dir: Directory containing the DCP checkpoint.
# torch_save_path: Filename to store the converted Torch save file.
# callback_fn: Optional callback function that takes the state_dict as input and returns a modified state_dict.
# .. warning::
# To avoid OOM, it's recommended to only run this function on a single rank.
# """
# state_dict = {}
# _load_state_dict(
# state_dict,
# storage_reader=FileSystemReader(dcp_checkpoint_dir),
# planner=_EmptyStateDictLoadPlanner(),
# no_dist=True,
# )
# if callback_fn is not None:
# state_dict = callback_fn(state_dict)
# torch.save(state_dict, torch_save_path)
def apply_ddp(model: torch.nn.Module, dp_mesh: torch.distributed.device_mesh.DeviceMesh) -> None:
replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
def apply_fsdp2(
model: torch.nn.Module,
dp_mesh: torch.distributed.device_mesh.DeviceMesh,
param_dtype: torch.dtype,
reduce_dtype: torch.dtype,
output_dtype: torch.dtype,
pp_enabled: bool = False,
cpu_offload: bool = False,
) -> None:
"""Apply FSDP2 on a model."""
mp_policy = MixedPrecisionPolicy(param_dtype, reduce_dtype, output_dtype, cast_forward_inputs=True)
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
if cpu_offload:
fsdp_config["offload_policy"] = CPUOffloadPolicy(pin_memory=True)
def apply_fully_shard(blocks):
for layer_index, block in enumerate(blocks):
if pp_enabled:
# For PP, do not reshard after forward to avoid per-microbatch
# all-gathers, which can be expensive and non-overlapped
reshard_after_forward = False
else:
# As an optimization, do not reshard after forward for the last
# transformer block since FSDP would prefetch it immediately
reshard_after_forward = layer_index < len(blocks) - 1
fully_shard(block, **fsdp_config, reshard_after_forward=reshard_after_forward)
for transformer_block_name in DIFFUSERS_TRANSFORMER_BLOCK_NAMES:
blocks = getattr(model, transformer_block_name, None)
if blocks is not None:
apply_fully_shard(blocks)
fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
def apply_context_parallel(
model: torch.nn.Module,
mesh: torch.distributed.device_mesh.DeviceMesh,
plan: Optional[Dict[str, ContextParallelModelPlan]] = None,
) -> None:
"""Apply context parallel on a model."""
logger.debug(f"Applying context parallel with CP mesh: {mesh}")
model_cls = unwrap_module(model).__class__
if plan is None:
plan = TransformerRegistry.get(model_cls).cp_plan
for module_id, cp_model_plan in plan.items():
module = get_submodule_by_name(model, module_id)
if not isinstance(module, list):
module = [module]
logger.debug(f"Applying ContextParallelHook to {module_id=} identifying a total of {len(module)} modules")
for m in module:
registry = HookRegistry.check_if_exists_or_initialize(m)
if isinstance(cp_model_plan, list):
# Metadata can only be a list when it is a list of CPOutput
assert all(isinstance(x, CPOutput) for x in cp_model_plan)
hook = ContextParallelGatherHook(cp_model_plan, mesh)
hook_name = f"cp_output---{module_id}"
else:
hook = ContextParallelSplitHook(cp_model_plan, mesh)
hook_name = f"cp_input---{module_id}"
registry.register_hook(hook, hook_name)
class ContextParallelSplitHook(ModelHook):
def __init__(self, metadata: ContextParallelModelPlan, mesh: torch.distributed.device_mesh.DeviceMesh) -> None:
super().__init__()
self.metadata = metadata
self.mesh = mesh
def pre_forward(self, module, *args, **kwargs):
args_list = list(args)
for param_identifier, cpm in self.metadata.items():
name = param_identifier.name
index = param_identifier.index
if isinstance(cpm, CPInput) and cpm.split_output:
continue
# Maybe the parameter was passed as a keyword argument
is_kwarg = True
input_val = kwargs.get(name, None)
# If not, maybe it was passed as a positional argument
if input_val is None and index is not None:
if index < len(args_list): # Ensure index is within bounds
input_val = args_list[index]
is_kwarg = False
else:
logger.warning(f"Index {index} out of bounds for args of length {len(args_list)}.")
continue # Skip if index is invalid
# Either the input_val is truly None, or argument is passed as normal argument
# but user forgot to specify the index when registering metadata
if input_val is None:
continue
# The input_val may be a tensor or list/tuple of tensors. In certain cases, user may specify to shard
# the output instead of input for a particular layer by setting split_output=True
if torch.is_tensor(input_val):
input_val = self._prepare_cp_input(input_val, cpm)
elif isinstance(input_val, (list, tuple)):
if len(input_val) != len(cpm):
raise ValueError(
f"Expected input model plan to have {len(input_val)} elements, but got {len(cpm)}."
)
sharded_input_val = []
for i, x in enumerate(input_val):
if torch.is_tensor(x) and not cpm[i].split_output:
x = self._prepare_cp_input(x, cpm[i])
sharded_input_val.append(x)
input_val = sharded_input_val
else:
raise ValueError(f"Unsupported input type: {type(input_val)}")
if is_kwarg:
kwargs[name] = input_val
elif index is not None and index < len(args_list):
args_list[index] = input_val
return tuple(args_list), kwargs
def post_forward(self, module, output):
is_tensor = torch.is_tensor(output)
is_tensor_list = isinstance(output, (list, tuple)) and all(torch.is_tensor(x) for x in output)
if not is_tensor and not is_tensor_list:
raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")
output = [output] if is_tensor else list(output)
for param_identifier, cpm in self.metadata.items():
if not isinstance(cpm, CPInput) or not cpm.split_output:
continue
index = param_identifier.index
if index >= len(output):
raise ValueError(f"Index {index} out of bounds for output of length {len(output)}.")
current_output = output[index]
current_output = self._prepare_cp_input(current_output, cpm)
output[index] = current_output
return output[0] if is_tensor else tuple(output)
def _prepare_cp_input(self, x: torch.Tensor, cp_input: CPInput) -> torch.Tensor:
if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
raise ValueError(
f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
)
return _EquipartitionSharder.shard(x, cp_input.split_dim, self.mesh)
class ContextParallelGatherHook(ModelHook):
def __init__(self, metadata: ContextParallelModelPlan, mesh: torch.distributed.device_mesh.DeviceMesh) -> None:
super().__init__()
self.metadata = metadata
self.mesh = mesh
def post_forward(self, module, output):
is_tensor = torch.is_tensor(output)
if is_tensor:
output = [output]
output = list(output)
assert len(output) == len(self.metadata), f"Expected {len(self.metadata)} outputs, but got {len(output)}."
for i, cpm in enumerate(self.metadata):
if cpm is None:
continue
output[i] = _EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.mesh)
return output[0] if is_tensor else tuple(output)
class _ContextParallelSharder:
@classmethod
def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
raise NotImplementedError("_ContextParallelSharder::shard should be implemented in subclasses")
@classmethod
def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
raise NotImplementedError("_ContextParallelSharder::unshard should be implemented in subclasses")
class _EquipartitionSharder(_ContextParallelSharder):
"""
Shards the input tensor along the specified dimension into cp_mesh's world size chunks.
Essentially, rank_i gets the i-th chunk.
This sharding strategy should only be used when performing full attention. Otherwise, it will
have performance penalty. If using causal attention, please use _CausalSharder instead.
"""
@classmethod
def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
assert tensor.size()[dim] % mesh.size() == 0
return tensor.chunk(mesh.size(), dim=dim)[mesh.get_local_rank()]
@classmethod
def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
tensor = tensor.contiguous()
# TODO(aryan): pass a shape here so that we can allow uneven sharding across seq dim
result = DTensor.from_local(tensor, mesh, placements=[Shard(dim)]).full_tensor()
return result
# TODO(aryan): this class is untested
class _CausalSharder(_ContextParallelSharder):
"""
Shards the input tensor along the specified dimension into 2x cp_mesh's world size chunks.
Essentially, rank_i gets the i-th chunk and (2 * cp_world_size - 1 - i)-th chunk.
This sharding strategy improves the performance for causal attention, as it allows
equal distribution of computation across all ranks.
Causal attention mask:
```
1 0 0 0 <--- Group 0
1 1 0 0 <--- Group 1
1 1 1 0 <--- Group 1
1 1 1 1 <--- Group 0
```
"""
@classmethod
def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
world_size = mesh.size()
rank = mesh.get_local_rank()
assert tensor.size()[dim] % (2 * world_size) == 0
chunks = tensor.chunk(2 * world_size, dim=dim)
i, j = rank, 2 * world_size - 1 - rank
return torch.cat((chunks[i], chunks[j]), dim=dim)
@classmethod
def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
tensor = tensor.contiguous()
world_size = mesh.size()
# TODO(aryan): pass a shape here so that we can allow uneven sharding across seq dim
all_tensors = DTensor.from_local(tensor, mesh, placements=[Shard(dim)]).full_tensor()
sliced_tensors = [st for t in all_tensors for st in t.chunk(2, dim=dim)]
ordered_tensors = list(sliced_tensors)
for i, t in enumerate(sliced_tensors):
if i % 2 == 0:
ordered_tensors[i // 2] = t
else:
ordered_tensors[world_size * 2 - (i // 2) - 1] = t
return torch.cat(ordered_tensors, dim=dim)