SreyanG-NVIDIA's picture
Upload 225 files
174ae06 verified
# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
# This file is modified from https://github.com/feifeibear/long-context-attention
# Implementation refers to USP Paper: https://arxiv.org/abs/2405.07719
import os
import deepspeed.comm as dist
import torch
class Singleton:
_instance = None
def __new__(cls, *args, **kwargs):
if not cls._instance:
cls._instance = super().__new__(cls)
cls._instance.__initialized = False
return cls._instance
def __init__(self):
if not self.__initialized:
self.__initialized = True
class ProcessGroupManager(Singleton):
"""
sp_degree = sp_ring_degree x sp_ulysses_degree
"""
def __init__(self, ulysses_degree, ring_degree, dp_degree, use_ulysses_low, ring_type):
if not hasattr(self, "__initialized"):
super().__init__()
self.ulysses_degree = ulysses_degree
self.ring_type = ring_type
self.ulysses_seq_len = None
self.ring_degree = ring_degree
self.sp_degree = ring_degree * ulysses_degree
self.dp_degree = dp_degree
self.rank = dist.get_rank()
if self.ring_degree == 1:
# Using Ulysses Sequence Parallelism only
num_ulysses_pgs = self.dp_degree
self.ring_pg = None
self.ring_rank = None
for i in range(num_ulysses_pgs):
ulysses_ranks = list(range(i * self.ulysses_degree, (i + 1) * self.ulysses_degree))
group = dist.new_group(ulysses_ranks)
if self.rank in ulysses_ranks:
self.ulysses_pg = group
for sp_rank in range(self.sp_degree):
dp_ranks = list(range(sp_rank, self.dp_degree * self.sp_degree, self.sp_degree))
group = dist.new_group(dp_ranks)
if self.rank in dp_ranks:
self.dp_pg = group
self.ulysses_rank = dist.get_rank(self.ulysses_pg)
self.sp_rank = self.ulysses_rank
self.dp_rank = dist.get_rank(self.dp_pg)
self.sp_pg = self.ulysses_pg
print(f"GPU {torch.cuda.current_device()} Ulysses rank: {self.ulysses_rank} out of {self.sp_degree}")
else:
# Using Hybrid Sequence Parallelism
assert self.ring_degree > 1
num_ulysses_pgs = self.ring_degree # world_size // self.ulysses_degree
num_ring_pgs = self.ulysses_degree # world_size // self.ring_degree
# Set up process groups
if use_ulysses_low:
for dp_rank in range(dp_degree):
offset = dp_rank * self.sp_degree
for i in range(num_ulysses_pgs):
ulysses_ranks = list(
range(
i * self.ulysses_degree + offset,
(i + 1) * self.ulysses_degree + offset,
)
)
group = dist.new_group(ulysses_ranks)
if self.rank in ulysses_ranks:
self.ulysses_pg = group
for i in range(num_ring_pgs):
ring_ranks = list(range(i + offset, self.sp_degree + offset, num_ring_pgs))
group = dist.new_group(ring_ranks)
if self.rank in ring_ranks:
self.ring_pg = group
else:
for dp_rank in range(dp_degree):
offset = dp_rank * self.sp_degree
for i in range(num_ring_pgs):
ring_ranks = list(range(i * self.ring_degree + offset, (i + 1) * self.ring_degree + offset))
group = dist.new_group(ring_ranks)
if self.rank in ring_ranks:
self.ring_pg = group
for i in range(num_ulysses_pgs):
ulysses_ranks = list(range(i + offset, self.sp_degree + offset, num_ulysses_pgs))
group = dist.new_group(ulysses_ranks)
if self.rank in ulysses_ranks:
self.ulysses_pg = group
for sp_rank in range(self.sp_degree):
dp_ranks = list(range(sp_rank, self.dp_degree * self.sp_degree, self.sp_degree))
group = dist.new_group(dp_ranks)
if self.rank in dp_ranks:
self.dp_pg = group
for i in range(self.dp_degree):
sp_ranks = list(range(i * self.sp_degree, (i + 1) * self.sp_degree))
group = dist.new_group(sp_ranks)
if self.rank in sp_ranks:
self.sp_pg = group
self.ulysses_rank = dist.get_rank(self.ulysses_pg)
self.ring_rank = dist.get_rank(self.ring_pg)
self.dp_rank = dist.get_rank(self.dp_pg)
if use_ulysses_low:
self.sp_rank = self.ulysses_rank + self.ring_rank * self.ulysses_degree
else:
self.sp_rank = self.ring_rank + self.ulysses_rank * self.ring_degree
print(
f"Rank {self.rank}, GPU {torch.cuda.current_device()} Hybrid SP rank: {self.sp_rank} out of {self.sp_degree} (Ulysses: {self.ulysses_rank}/{self.ulysses_degree}, Ring: {self.ring_rank}/{self.ring_degree})"
)
print("--------------ProcessGroupManager Initialized---------------------")
PROCESS_GROUP_MANAGER = None
def set_pg_manager(sp_degree, sp_ring_degree=1, use_ulysses_low=True, ring_type=None):
"""
Set the process group manager for sequence parallelism.
sp_degree = sp_ring_degree x sp_ulysses_degree
"""
# first check torch distributed group init and set device accordingly;
# (DL) TODO: Whether this can be skipped in DeepSpeed.
if dist.is_initialized():
if dist.get_rank() == 0:
print(
"torch distributed is already initialized, " "skipping initialization ...",
flush=True,
)
else:
if int(os.environ["RANK"]) == 0:
print("Initializing Torch distributed.")
dist.init_distributed(dist_backend="nccl", dist_init_required=True)
local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
torch.cuda.set_device(dist.get_rank() % local_world_size)
world_size = dist.get_world_size()
assert sp_degree <= world_size
assert world_size % sp_degree == 0, f"world_size {world_size} % sp_degree {sp_degree} != 0"
if sp_ring_degree < 1:
sp_ring_degree = 1
sp_ulysses_degree = sp_degree // sp_ring_degree
assert sp_degree % sp_ring_degree == 0, f"sp_degree {sp_degree} % sp_ring_degree {sp_ring_degree} != 0"
dp_degree = world_size // sp_degree
# Init the process group manager
global PROCESS_GROUP_MANAGER
PROCESS_GROUP_MANAGER = ProcessGroupManager(
sp_ulysses_degree, sp_ring_degree, dp_degree, use_ulysses_low, ring_type
)
def get_pg_manager():
return PROCESS_GROUP_MANAGER
def get_sequence_parallel_size():
"""Get the size of the sequence parallel group."""
return PROCESS_GROUP_MANAGER.sp_degree
def get_sequence_parallel_rank():
"""Get the rank of this process in the sequence parallel group the caller rank belongs to."""
return PROCESS_GROUP_MANAGER.sp_rank
def get_sequence_parallel_pg():
"""Get the overall sequence parallel process group (include Ring and Ulysses)."""
return PROCESS_GROUP_MANAGER.sp_pg
def get_ulysses_sp_size():
"""Get the size of the Ulysses sequence parallel group."""
return PROCESS_GROUP_MANAGER.ulysses_degree
def get_ulysses_seq_len():
"""Get the size of the Ulysses sequence parallel group."""
return PROCESS_GROUP_MANAGER.ulysses_seq_len
def set_ulysses_seq_len(seq_len):
"""Get the size of the Ulysses sequence parallel group."""
PROCESS_GROUP_MANAGER.ulysses_seq_len = seq_len
def get_ulysses_sp_rank():
"""Get the rank of this process in the Ulysses sequence parallel group the caller rank belongs to."""
return PROCESS_GROUP_MANAGER.ulysses_rank
def get_ulysses_sp_pg():
"""Get the Ulysses sequence parallel process group."""
return PROCESS_GROUP_MANAGER.ulysses_pg
def get_ring_sp_size():
"""Get the size of the RingAttn sequence parallel group."""
return PROCESS_GROUP_MANAGER.ring_degree
def get_ring_sp_rank():
"""Get the rank of this process in the RingAttn sequence parallel group the caller rank belongs to."""
return PROCESS_GROUP_MANAGER.ring_rank
def get_ring_sp_pg():
"""Get the RingAttn sequence parallel process group."""
return PROCESS_GROUP_MANAGER.ring_pg
def get_ring_type():
"""Get the RingAttn implementation type."""
return PROCESS_GROUP_MANAGER.ring_type
def get_data_parallel_size():
"""Get the size of the data parallel group."""
return PROCESS_GROUP_MANAGER.dp_degree
def get_data_parallel_rank():
"""Get the rank of this process in the data parallel group the caller rank belongs to."""
return PROCESS_GROUP_MANAGER.dp_rank