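"""Demonstrates differentiable all-gather and scatter collectives in PyTorch.

AllGather concatenates each rank's tensor along the first dimension; Scatter
returns this rank's chunk of a tensor that is replicated on every rank. Both
are wrapped in torch.autograd.Function so gradients flow back through the
collectives.

Launch with torchrun, which sets the LOCAL_RANK, RANK and WORLD_SIZE
environment variables read below, e.g.:

    torchrun --nproc_per_node=2 <this_file>.py
"""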
import os

import torch
import torch.distributed as dist
from torch.autograd import Function


class AllGather(Function):
    """Differentiable all-gather along the first dimension."""

    @staticmethod
    def forward(ctx, tensor, process_group):
        world_size = dist.get_world_size(process_group)
        ctx.world_size = world_size
        ctx.rank = dist.get_rank(process_group)

        # Collect every rank's tensor and concatenate them along dim 0.
        gathered_tensors = [torch.zeros_like(tensor) for _ in range(world_size)]
        dist.all_gather(gathered_tensors, tensor.contiguous(), group=process_group)
        return torch.cat(gathered_tensors, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        world_size = ctx.world_size
        rank = ctx.rank

        # grad_output covers the full gathered tensor; the local input only
        # produced the chunk at this rank's position, so return that slice.
        grad_chunks = grad_output.chunk(world_size)
        grad_input = grad_chunks[rank]
        return grad_input, None


def gather_along_first_dim(tensor, process_group):
    return AllGather.apply(tensor, process_group)
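
# Illustrative variant (a sketch, not used by the demo below; the class name is
# just for illustration): when every rank contributes its own loss term to a
# sum, the gradient of an all-gather w.r.t. the local input is the sum over
# ranks of the grad_output slice at this rank's position, i.e. a reduce-scatter.
# The chunk-only backward in AllGather above instead suits the case where the
# downstream gradient is identical on every rank, as it is in the demo below.
class AllGatherReduceScatterGrad(Function):
    @staticmethod
    def forward(ctx, tensor, process_group):
        world_size = dist.get_world_size(process_group)
        ctx.world_size = world_size
        ctx.rank = dist.get_rank(process_group)
        ctx.process_group = process_group

        gathered_tensors = [torch.zeros_like(tensor) for _ in range(world_size)]
        dist.all_gather(gathered_tensors, tensor.contiguous(), group=process_group)
        return torch.cat(gathered_tensors, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        # Sum each rank's slice of grad_output across ranks and keep ours.
        grad_chunks = [c.contiguous() for c in grad_output.chunk(ctx.world_size)]
        grad_input = torch.empty_like(grad_chunks[ctx.rank])
        dist.reduce_scatter(grad_input, grad_chunks, group=ctx.process_group)
        return grad_input, None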

class Scatter(Function):
    """Return this rank's chunk (along dim 0) of a tensor replicated on all ranks."""

    @staticmethod
    def forward(ctx, tensor, process_group):
        world_size = dist.get_world_size(process_group)
        ctx.world_size = world_size
        ctx.process_group = process_group
        rank = dist.get_rank(process_group)

        # Split along the first dimension and keep only the local chunk.
        tensor_chunks = tensor.chunk(world_size)
        return tensor_chunks[rank]

    @staticmethod
    def backward(ctx, grad_output):
        world_size = ctx.world_size
        process_group = ctx.process_group

        # Each rank only kept its own chunk, so rebuild the full-size gradient
        # by gathering every rank's grad_output (the forward input is assumed
        # to be replicated across ranks, as in the gather-then-scatter demo).
        gathered_grads = [torch.zeros_like(grad_output) for _ in range(world_size)]
        dist.all_gather(gathered_grads, grad_output.contiguous(), group=process_group)
        return torch.cat(gathered_grads, dim=0), None


def scatter_along_first_dim(tensor, process_group):
    return Scatter.apply(tensor, process_group)

if __name__ == "__main__":
    # torchrun sets LOCAL_RANK, RANK and WORLD_SIZE for every worker process.
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)

    x = torch.randn(10, 1, requires_grad=True, device="cuda")

    # Gather the per-rank (10, 1) tensors into (10 * world_size, 1), then
    # scatter so each rank keeps its own (10, 1) chunk again.
    y = gather_along_first_dim(x, dist.group.WORLD)
    print(f"{y.shape=}")
    y = scatter_along_first_dim(y, dist.group.WORLD)
    print(f"{y.shape=}")

    loss = y.sum()
    loss.backward()

    # scatter(gather(x)) returns this rank's own chunk, i.e. x itself, so the
    # gradient of loss = y.sum() w.r.t. x is a tensor of ones.
    print(x.grad)

    dist.destroy_process_group()