# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.

# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.

# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

# This file is modified from https://github.com/feifeibear/long-context-attention
# Implementation refers to USP Paper: https://arxiv.org/abs/2405.07719

import os

import deepspeed.comm as dist
import torch


class Singleton:
    """Base class that hands out one shared instance.

    ``__new__`` caches the first instance on ``cls._instance`` and returns it on
    every later construction. Python still calls ``__init__`` on the returned
    object each time the class is called.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # Name-mangled to ``_Singleton__initialized``.
            cls._instance.__initialized = False
        return cls._instance

    def __init__(self):
        if not self.__initialized:
            self.__initialized = True


class ProcessGroupManager(Singleton):
    """Create and hold the process groups used for (hybrid) sequence parallelism.

    sp_degree = sp_ring_degree x sp_ulysses_degree

    The global ranks are laid out as ``dp_degree`` consecutive blocks of
    ``sp_degree`` ranks. Each block is one sequence-parallel (SP) group. The
    data-parallel (DP) groups collect the ranks that sit at the same position
    inside their block.
    """

    def __init__(self, ulysses_degree, ring_degree, dp_degree, use_ulysses_low, ring_type):
        """Build all process groups and record this rank's position in each.

        Args:
            ulysses_degree: Number of ranks in each Ulysses group.
            ring_degree: Number of ranks in each ring group; 1 means Ulysses-only.
            dp_degree: Number of data-parallel replicas (SP blocks).
            use_ulysses_low: Only used when ring_degree > 1. If True, Ulysses
                groups are contiguous runs of ``ulysses_degree`` ranks and ring
                groups stride by ``ulysses_degree``. If False, ring groups are
                contiguous runs of ``ring_degree`` ranks and Ulysses groups
                stride by ``ring_degree``.
            ring_type: Ring-attention implementation identifier, stored as-is.

        Note:
            ``Singleton.__new__`` returns the same object on every construction,
            but this body runs every time. The shared instance is re-configured
            and its process groups are rebuilt. ``dist.new_group`` is entered by
            every rank for every group, so all ranks must construct this class
            the same number of times, with the same arguments.
        """
        super().__init__()
        self.ulysses_degree = ulysses_degree
        self.ring_type = ring_type
        self.ulysses_seq_len = None
        self.ring_degree = ring_degree
        self.sp_degree = ring_degree * ulysses_degree
        self.dp_degree = dp_degree
        self.rank = dist.get_rank()

        if self.ring_degree == 1:
            self._init_ulysses_only()
        else:
            self._init_hybrid(use_ulysses_low)

        print("--------------ProcessGroupManager Initialized---------------------")

    def _init_dp_groups(self):
        """Create one DP group per in-block position and keep the one containing this rank."""
        for sp_rank in range(self.sp_degree):
            dp_ranks = list(range(sp_rank, self.dp_degree * self.sp_degree, self.sp_degree))
            group = dist.new_group(dp_ranks)
            if self.rank in dp_ranks:
                self.dp_pg = group

    def _init_ulysses_only(self):
        """Set up Ulysses-only SP (ring_degree == 1): the SP group is the Ulysses group."""
        self.ring_pg = None
        self.ring_rank = None

        # One contiguous Ulysses group (ulysses_degree == sp_degree ranks) per DP replica.
        for i in range(self.dp_degree):
            ulysses_ranks = list(range(i * self.ulysses_degree, (i + 1) * self.ulysses_degree))
            group = dist.new_group(ulysses_ranks)
            if self.rank in ulysses_ranks:
                self.ulysses_pg = group

        self._init_dp_groups()

        self.ulysses_rank = dist.get_rank(self.ulysses_pg)
        self.sp_rank = self.ulysses_rank
        self.dp_rank = dist.get_rank(self.dp_pg)
        self.sp_pg = self.ulysses_pg
        print(f"GPU {torch.cuda.current_device()} Ulysses rank: {self.ulysses_rank} out of {self.sp_degree}")

    def _init_hybrid(self, use_ulysses_low):
        """Set up hybrid Ulysses x ring SP (ring_degree > 1) plus the DP and SP groups."""
        assert self.ring_degree > 1
        # Counts are per SP block, not per world:
        # ring_degree Ulysses groups (sp_degree // ulysses_degree) and
        # ulysses_degree ring groups (sp_degree // ring_degree).
        num_ulysses_pgs = self.ring_degree
        num_ring_pgs = self.ulysses_degree

        for dp_rank in range(self.dp_degree):
            # This DP replica owns the rank block [offset, offset + sp_degree).
            offset = dp_rank * self.sp_degree
            if use_ulysses_low:
                # Ulysses groups are contiguous; ring groups stride by ulysses_degree.
                for i in range(num_ulysses_pgs):
                    ulysses_ranks = list(
                        range(
                            i * self.ulysses_degree + offset,
                            (i + 1) * self.ulysses_degree + offset,
                        )
                    )
                    group = dist.new_group(ulysses_ranks)
                    if self.rank in ulysses_ranks:
                        self.ulysses_pg = group
                for i in range(num_ring_pgs):
                    ring_ranks = list(range(i + offset, self.sp_degree + offset, num_ring_pgs))
                    group = dist.new_group(ring_ranks)
                    if self.rank in ring_ranks:
                        self.ring_pg = group
            else:
                # Ring groups are contiguous; Ulysses groups stride by ring_degree.
                for i in range(num_ring_pgs):
                    ring_ranks = list(range(i * self.ring_degree + offset, (i + 1) * self.ring_degree + offset))
                    group = dist.new_group(ring_ranks)
                    if self.rank in ring_ranks:
                        self.ring_pg = group
                for i in range(num_ulysses_pgs):
                    ulysses_ranks = list(range(i + offset, self.sp_degree + offset, num_ulysses_pgs))
                    group = dist.new_group(ulysses_ranks)
                    if self.rank in ulysses_ranks:
                        self.ulysses_pg = group

        self._init_dp_groups()

        # One SP group per DP replica: the whole contiguous block of sp_degree ranks.
        for i in range(self.dp_degree):
            sp_ranks = list(range(i * self.sp_degree, (i + 1) * self.sp_degree))
            group = dist.new_group(sp_ranks)
            if self.rank in sp_ranks:
                self.sp_pg = group

        self.ulysses_rank = dist.get_rank(self.ulysses_pg)
        self.ring_rank = dist.get_rank(self.ring_pg)
        self.dp_rank = dist.get_rank(self.dp_pg)
        # sp_rank is this rank's offset inside its SP block, rebuilt from the layout above.
        if use_ulysses_low:
            self.sp_rank = self.ulysses_rank + self.ring_rank * self.ulysses_degree
        else:
            self.sp_rank = self.ring_rank + self.ulysses_rank * self.ring_degree

        print(
            f"Rank {self.rank}, GPU {torch.cuda.current_device()} Hybrid SP rank: {self.sp_rank} out of {self.sp_degree} (Ulysses: {self.ulysses_rank}/{self.ulysses_degree}, Ring: {self.ring_rank}/{self.ring_degree})"
        )


# Global manager assigned by set_pg_manager(). The getters below read attributes
# from it, so they raise AttributeError until set_pg_manager() has been called.
PROCESS_GROUP_MANAGER = None


def set_pg_manager(sp_degree, sp_ring_degree=1, use_ulysses_low=True, ring_type=None):
    """Set the process group manager for sequence parallelism.

    sp_degree = sp_ring_degree x sp_ulysses_degree

    If the distributed backend is not initialized yet, it is initialized with
    NCCL via ``deepspeed.comm.init_distributed``. In that case this process is
    also bound to CUDA device ``rank % LOCAL_WORLD_SIZE``.

    Args:
        sp_degree: Total sequence-parallel degree. Must be positive and divide
            the world size.
        sp_ring_degree: Ring-attention degree; values below 1 are treated as 1.
            Must divide ``sp_degree``.
        use_ulysses_low: Rank layout for hybrid SP; forwarded to ProcessGroupManager.
        ring_type: Ring-attention implementation identifier, stored on the manager.

    Raises:
        AssertionError: If the degrees are inconsistent with each other or with
            the world size.
        KeyError: If ``RANK`` or ``LOCAL_WORLD_SIZE`` is missing from the
            environment when distributed is not yet initialized.
    """
    # First check torch distributed group init and set the device accordingly.
    # (DL) TODO: Whether this can be skipped in DeepSpeed.
    if dist.is_initialized():
        if dist.get_rank() == 0:
            print(
                "torch distributed is already initialized, skipping initialization ...",
                flush=True,
            )
    else:
        if int(os.environ["RANK"]) == 0:
            print("Initializing Torch distributed.")
        dist.init_distributed(dist_backend="nccl", dist_init_required=True)
        local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
        torch.cuda.set_device(dist.get_rank() % local_world_size)

    world_size = dist.get_world_size()
    # Reject non-positive degrees explicitly: 0 would raise ZeroDivisionError below,
    # and negative values would slip through the modulo check.
    assert sp_degree > 0, f"sp_degree must be positive, got {sp_degree}"
    assert sp_degree <= world_size
    assert world_size % sp_degree == 0, f"world_size {world_size} % sp_degree {sp_degree} != 0"

    if sp_ring_degree < 1:
        sp_ring_degree = 1
    assert sp_degree % sp_ring_degree == 0, f"sp_degree {sp_degree} % sp_ring_degree {sp_ring_degree} != 0"
    sp_ulysses_degree = sp_degree // sp_ring_degree
    dp_degree = world_size // sp_degree

    # Init the process group manager.
    global PROCESS_GROUP_MANAGER
    PROCESS_GROUP_MANAGER = ProcessGroupManager(
        sp_ulysses_degree, sp_ring_degree, dp_degree, use_ulysses_low, ring_type
    )


def get_pg_manager():
    """Return the global ProcessGroupManager, or None before set_pg_manager() is called."""
    return PROCESS_GROUP_MANAGER


def get_sequence_parallel_size():
    """Get the size of the sequence parallel group."""
    return PROCESS_GROUP_MANAGER.sp_degree


def get_sequence_parallel_rank():
    """Get the rank of this process in the sequence parallel group the caller rank belongs to."""
    return PROCESS_GROUP_MANAGER.sp_rank


def get_sequence_parallel_pg():
    """Get the overall sequence parallel process group (include Ring and Ulysses)."""
    return PROCESS_GROUP_MANAGER.sp_pg


def get_ulysses_sp_size():
    """Get the size of the Ulysses sequence parallel group."""
    return PROCESS_GROUP_MANAGER.ulysses_degree


def get_ulysses_seq_len():
    """Get the sequence length stored by set_ulysses_seq_len().

    Returns None until it is set; every ProcessGroupManager initialization resets it to None.
    """
    return PROCESS_GROUP_MANAGER.ulysses_seq_len


def set_ulysses_seq_len(seq_len):
    """Store ``seq_len`` on the global manager for later retrieval via get_ulysses_seq_len()."""
    PROCESS_GROUP_MANAGER.ulysses_seq_len = seq_len


def get_ulysses_sp_rank():
    """Get the rank of this process in the Ulysses sequence parallel group the caller rank belongs to."""
    return PROCESS_GROUP_MANAGER.ulysses_rank


def get_ulysses_sp_pg():
    """Get the Ulysses sequence parallel process group."""
    return PROCESS_GROUP_MANAGER.ulysses_pg


def get_ring_sp_size():
    """Get the size of the RingAttn sequence parallel group."""
    return PROCESS_GROUP_MANAGER.ring_degree


def get_ring_sp_rank():
    """Get the rank of this process in the RingAttn sequence parallel group.

    Returns None when ring_degree == 1 (Ulysses-only mode).
    """
    return PROCESS_GROUP_MANAGER.ring_rank


def get_ring_sp_pg():
    """Get the RingAttn sequence parallel process group (None when ring_degree == 1)."""
    return PROCESS_GROUP_MANAGER.ring_pg


def get_ring_type():
    """Get the RingAttn implementation type."""
    return PROCESS_GROUP_MANAGER.ring_type


def get_data_parallel_size():
    """Get the size of the data parallel group."""
    return PROCESS_GROUP_MANAGER.dp_degree


def get_data_parallel_rank():
    """Get the rank of this process in the data parallel group the caller rank belongs to."""
    return PROCESS_GROUP_MANAGER.dp_rank