# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.distributed as dist
def init_distributed_group():
"""r initialize sequence parallel group.
"""
if not dist.is_initialized():
dist.init_process_group(backend='nccl')
def get_rank():
    return dist.get_rank()


def get_world_size():
    return dist.get_world_size()

def all_to_all(x, scatter_dim, gather_dim, group=None, **kwargs):
"""
`scatter` along one dimension and `gather` along another.
"""
world_size = get_world_size()
if world_size > 1:
inputs = [u.contiguous() for u in x.chunk(world_size, dim=scatter_dim)]
outputs = [torch.empty_like(u) for u in inputs]
dist.all_to_all(outputs, inputs, group=group, **kwargs)
x = torch.cat(outputs, dim=gather_dim).contiguous()
return x
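
# Illustrative sketch (not part of the upstream file): in sequence parallelism,
# all_to_all is typically used to switch a tensor from being sharded along the
# sequence dimension to being sharded along the head dimension before attention.
# The layout below is an assumption for illustration only.
def _example_seq_to_head_shard(q):
    # q on each rank: [batch, seq_len // world_size, num_heads, head_dim]
    # after the exchange: [batch, seq_len, num_heads // world_size, head_dim]
    return all_to_all(q, scatter_dim=2, gather_dim=1)
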
def all_gather(tensor):
    world_size = dist.get_world_size()
    if world_size == 1:
        return [tensor]
    tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensor_list, tensor)
    return tensor_list

def gather_forward(input, dim):
    # skip the collective when running on a single process
    world_size = dist.get_world_size()
    if world_size == 1:
        return input
    # gather the sequence shards from all ranks and concatenate along `dim`
    output = all_gather(input)
    return torch.cat(output, dim=dim).contiguous()
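

# Minimal usage sketch (an assumption, not part of the upstream file): launch on a
# single node with e.g. `torchrun --nproc_per_node=2 this_file.py`. Each rank holds
# one slice of a sequence; gather_forward reassembles the full sequence everywhere.
if __name__ == "__main__":
    init_distributed_group()
    torch.cuda.set_device(get_rank())
    local_chunk = torch.full((1, 4, 8), float(get_rank()), device=f"cuda:{get_rank()}")
    full_seq = gather_forward(local_chunk, dim=1)
    print(f"rank {get_rank()}: gathered sequence shape {tuple(full_seq.shape)}")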