# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# //     http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

"""
Advanced distributed functions for sequence parallelism.
"""

from typing import Optional, List

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import ShardingStrategy

from .basic import get_global_rank, get_world_size

_DATA_PARALLEL_GROUP = None
_SEQUENCE_PARALLEL_GROUP = None
_SEQUENCE_PARALLEL_CPU_GROUP = None
_MODEL_SHARD_CPU_INTER_GROUP = None
_MODEL_SHARD_CPU_INTRA_GROUP = None
_MODEL_SHARD_INTER_GROUP = None
_MODEL_SHARD_INTRA_GROUP = None
_SEQUENCE_PARALLEL_GLOBAL_RANKS = None


def get_data_parallel_group() -> Optional[dist.ProcessGroup]:
    """
    Get the data parallel process group.
    """
    return _DATA_PARALLEL_GROUP


def get_sequence_parallel_group() -> Optional[dist.ProcessGroup]:
    """
    Get the sequence parallel process group.
    """
    return _SEQUENCE_PARALLEL_GROUP


def get_sequence_parallel_cpu_group() -> Optional[dist.ProcessGroup]:
    """
    Get the sequence parallel CPU process group.
    """
    return _SEQUENCE_PARALLEL_CPU_GROUP


def get_data_parallel_rank() -> int:
    """
    Get the data parallel rank.
    """
    group = get_data_parallel_group()
    return dist.get_rank(group) if group else get_global_rank()


def get_data_parallel_world_size() -> int:
    """
    Get the data parallel world size.
    """
    group = get_data_parallel_group()
    return dist.get_world_size(group) if group else get_world_size()


def get_sequence_parallel_rank() -> int:
    """
    Get the sequence parallel rank.
    """
    group = get_sequence_parallel_group()
    return dist.get_rank(group) if group else 0


def get_sequence_parallel_world_size() -> int:
    """
    Get the sequence parallel world size.
    """
    group = get_sequence_parallel_group()
    return dist.get_world_size(group) if group else 1


def get_model_shard_cpu_intra_group() -> Optional[dist.ProcessGroup]:
    """
    Get the intra-node CPU process group used for model sharding.
    """
    return _MODEL_SHARD_CPU_INTRA_GROUP


def get_model_shard_cpu_inter_group() -> Optional[dist.ProcessGroup]:
    """
    Get the inter-node CPU process group used for model sharding.
    """
    return _MODEL_SHARD_CPU_INTER_GROUP


def get_model_shard_intra_group() -> Optional[dist.ProcessGroup]:
    """
    Get the intra-node GPU process group used for model sharding.
    """
    return _MODEL_SHARD_INTRA_GROUP


def get_model_shard_inter_group() -> Optional[dist.ProcessGroup]:
    """
    Get the inter-node GPU process group used for model sharding.
    """
    return _MODEL_SHARD_INTER_GROUP
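# Note on the getters above (illustrative comment, not part of the original
# module): before `init_sequence_parallel` is called, the sequence parallel
# globals are still None, so the accessors fall back to single-process
# defaults, e.g.
#
#   get_sequence_parallel_world_size()  # -> 1
#   get_sequence_parallel_rank()        # -> 0
#
# After initialization they report the size of, and the caller's rank within,
# the sequence parallel group that the caller belongs to.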
""" global _DATA_PARALLEL_GROUP global _SEQUENCE_PARALLEL_GROUP global _SEQUENCE_PARALLEL_CPU_GROUP global _SEQUENCE_PARALLEL_GLOBAL_RANKS assert dist.is_initialized() world_size = dist.get_world_size() rank = dist.get_rank() data_parallel_size = world_size // sequence_parallel_size for i in range(data_parallel_size): start_rank = i * sequence_parallel_size end_rank = (i + 1) * sequence_parallel_size ranks = range(start_rank, end_rank) group = dist.new_group(ranks) cpu_group = dist.new_group(ranks, backend="gloo") if rank in ranks: _SEQUENCE_PARALLEL_GROUP = group _SEQUENCE_PARALLEL_CPU_GROUP = cpu_group _SEQUENCE_PARALLEL_GLOBAL_RANKS = list(ranks) def init_model_shard_group( *, sharding_strategy: ShardingStrategy, device_mesh: Optional[DeviceMesh] = None, ): """ Initialize process group of model sharding. """ global _MODEL_SHARD_INTER_GROUP global _MODEL_SHARD_INTRA_GROUP global _MODEL_SHARD_CPU_INTER_GROUP global _MODEL_SHARD_CPU_INTRA_GROUP assert dist.is_initialized() world_size = dist.get_world_size() if device_mesh is not None: num_shards_per_group = device_mesh.shape[1] elif sharding_strategy == ShardingStrategy.NO_SHARD: num_shards_per_group = 1 elif sharding_strategy in [ ShardingStrategy.HYBRID_SHARD, ShardingStrategy._HYBRID_SHARD_ZERO2, ]: num_shards_per_group = torch.cuda.device_count() else: num_shards_per_group = world_size num_groups = world_size // num_shards_per_group device_mesh = (num_groups, num_shards_per_group) gpu_mesh_2d = init_device_mesh("cuda", device_mesh, mesh_dim_names=("inter", "intra")) cpu_mesh_2d = init_device_mesh("cpu", device_mesh, mesh_dim_names=("inter", "intra")) _MODEL_SHARD_INTER_GROUP = gpu_mesh_2d.get_group("inter") _MODEL_SHARD_INTRA_GROUP = gpu_mesh_2d.get_group("intra") _MODEL_SHARD_CPU_INTER_GROUP = cpu_mesh_2d.get_group("inter") _MODEL_SHARD_CPU_INTRA_GROUP = cpu_mesh_2d.get_group("intra") def get_sequence_parallel_global_ranks() -> List[int]: """ Get all global ranks of the sequence parallel process group that the caller rank belongs to. """ if _SEQUENCE_PARALLEL_GLOBAL_RANKS is None: return [dist.get_rank()] return _SEQUENCE_PARALLEL_GLOBAL_RANKS def get_next_sequence_parallel_rank() -> int: """ Get the next global rank of the sequence parallel process group that the caller rank belongs to. """ sp_global_ranks = get_sequence_parallel_global_ranks() sp_rank = get_sequence_parallel_rank() sp_size = get_sequence_parallel_world_size() return sp_global_ranks[(sp_rank + 1) % sp_size] def get_prev_sequence_parallel_rank() -> int: """ Get the previous global rank of the sequence parallel process group that the caller rank belongs to. """ sp_global_ranks = get_sequence_parallel_global_ranks() sp_rank = get_sequence_parallel_rank() sp_size = get_sequence_parallel_world_size() return sp_global_ranks[(sp_rank + sp_size - 1) % sp_size]